1 //
2 // Copyright (C) 2011 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "shill/net/byte_string.h"
18 
19 #include <netinet/in.h>
20 #include <string.h>
21 
22 #include <algorithm>
23 
24 #include <base/strings/string_number_conversions.h>
25 
26 using std::min;
27 using std::string;
28 using std::vector;
29 
30 namespace shill {
31 
// Copy constructor: builds this ByteString holding a copy of |b|'s bytes.
// The byte vector is copy-constructed in the member-initializer list rather
// than default-constructed and then assigned in the body.
ByteString::ByteString(const ByteString& b) : data_(b.data_) {}
35 
// Copy assignment: replaces this object's bytes with a copy of |b|'s bytes.
// Returns a reference to this object so that assignments can be chained.
ByteString& ByteString::operator=(const ByteString& b) {
  ByteString& self = *this;
  self.data_ = b.data_;
  return self;
}
40 
GetData()41 unsigned char* ByteString::GetData() {
42   return (GetLength() == 0) ? nullptr : &data_.front();
43 }
44 
GetConstData() const45 const unsigned char* ByteString::GetConstData() const {
46   return (GetLength() == 0) ? nullptr : &data_.front();
47 }
48 
GetLength() const49 size_t ByteString::GetLength() const {
50   return data_.size();
51 }
52 
GetSubstring(size_t offset,size_t length) const53 ByteString ByteString::GetSubstring(size_t offset, size_t length) const {
54   if (offset > GetLength()) {
55     offset = GetLength();
56   }
57   if (length > GetLength() - offset) {
58     length = GetLength() - offset;
59   }
60   return ByteString(GetConstData() + offset, length);
61 }
62 
63 // static
CreateFromCPUUInt32(uint32_t val)64 ByteString ByteString::CreateFromCPUUInt32(uint32_t val) {
65   return ByteString(reinterpret_cast<unsigned char*>(&val), sizeof(val));
66 }
67 
68 // static
CreateFromNetUInt32(uint32_t val)69 ByteString ByteString::CreateFromNetUInt32(uint32_t val) {
70   return CreateFromCPUUInt32(ntohl(val));
71 }
72 
73 // static
CreateFromHexString(const string & hex_string)74 ByteString ByteString::CreateFromHexString(const string& hex_string) {
75   vector<uint8_t> bytes;
76   if (!base::HexStringToBytes(hex_string, &bytes)) {
77     return ByteString();
78   }
79   return ByteString(&bytes.front(), bytes.size());
80 }
81 
ConvertToCPUUInt32(uint32_t * val) const82 bool ByteString::ConvertToCPUUInt32(uint32_t* val) const {
83   if (val == nullptr || GetLength() != sizeof(*val)) {
84     return false;
85   }
86   memcpy(val, GetConstData(), sizeof(*val));
87 
88   return true;
89 }
90 
ConvertToNetUInt32(uint32_t * val) const91 bool ByteString::ConvertToNetUInt32(uint32_t* val) const {
92   if (!ConvertToCPUUInt32(val)) {
93     return false;
94   }
95   *val = ntohl(*val);
96   return true;
97 }
98 
99 template <typename T>
ConvertByteOrderAsUIntArray(T (* converter)(T))100 bool ByteString::ConvertByteOrderAsUIntArray(T (*converter)(T)) {
101   size_t length = GetLength();
102   if ((length % sizeof(T)) != 0) {
103     return false;
104   }
105   for (auto i = data_.begin(); i != data_.end(); i += sizeof(T)) {
106     // Take care of word alignment.
107     T val;
108     memcpy(&val, &(*i), sizeof(T));
109     val = converter(val);
110     memcpy(&(*i), &val, sizeof(T));
111   }
112   return true;
113 }
114 
ConvertFromNetToCPUUInt32Array()115 bool ByteString::ConvertFromNetToCPUUInt32Array() {
116   return ConvertByteOrderAsUIntArray(ntohl);
117 }
118 
ConvertFromCPUToNetUInt32Array()119 bool ByteString::ConvertFromCPUToNetUInt32Array() {
120   return ConvertByteOrderAsUIntArray(htonl);
121 }
122 
IsZero() const123 bool ByteString::IsZero() const {
124   for (const auto& i : data_) {
125     if (i != 0) {
126       return false;
127     }
128   }
129   return true;
130 }
131 
BitwiseAnd(const ByteString & b)132 bool ByteString::BitwiseAnd(const ByteString& b) {
133   if (GetLength() != b.GetLength()) {
134     return false;
135   }
136   auto lhs = data_.begin();
137   for (const auto& rhs : b.data_) {
138     *lhs++ &= rhs;
139   }
140   return true;
141 }
142 
BitwiseOr(const ByteString & b)143 bool ByteString::BitwiseOr(const ByteString& b) {
144   if (GetLength() != b.GetLength()) {
145     return false;
146   }
147   auto lhs = data_.begin();
148   for (const auto& rhs : b.data_) {
149     *lhs++ |= rhs;
150   }
151   return true;
152 }
153 
BitwiseInvert()154 void ByteString::BitwiseInvert() {
155   for (auto& i : data_) {
156     i = ~i;
157   }
158 }
159 
Equals(const ByteString & b) const160 bool ByteString::Equals(const ByteString& b) const {
161   if (GetLength() != b.GetLength()) {
162     return false;
163   }
164   auto lhs = data_.begin();
165   for (const auto& rhs : b.data_) {
166     if (*lhs++ != rhs) {
167       return false;
168     }
169   }
170   return true;
171 }
172 
Append(const ByteString & b)173 void ByteString::Append(const ByteString& b) {
174   data_.insert(data_.end(), b.data_.begin(), b.data_.end());
175 }
176 
Clear()177 void ByteString::Clear() {
178   data_.clear();
179 }
180 
Resize(int size)181 void ByteString::Resize(int size) {
182   data_.resize(size, 0);
183 }
184 
HexEncode() const185 string ByteString::HexEncode() const {
186   return base::HexEncode(GetConstData(), GetLength());
187 }
188 
CopyData(size_t size,void * output) const189 bool ByteString::CopyData(size_t size, void* output) const {
190   if (output == nullptr || GetLength() < size) {
191     return false;
192   }
193   memcpy(output, GetConstData(), size);
194   return true;
195 }
196 
197 // static
IsLessThan(const ByteString & lhs,const ByteString & rhs)198 bool ByteString::IsLessThan(const ByteString& lhs, const ByteString& rhs) {
199   size_t byte_count = min(lhs.GetLength(), rhs.GetLength());
200   int result = memcmp(lhs.GetConstData(), rhs.GetConstData(), byte_count);
201   if (result == 0) {
202     return lhs.GetLength() < rhs.GetLength();
203   }
204   return result < 0;
205 }
206 
207 }  // namespace shill
208