1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fields/vector_field.h"
18 
19 #include "fields/count_field.h"
20 #include "fields/custom_field.h"
21 #include "util.h"
22 
23 const std::string VectorField::kFieldType = "VectorField";
24 
VectorField(std::string name,int element_size,std::string size_modifier,ParseLocation loc)25 VectorField::VectorField(std::string name, int element_size, std::string size_modifier, ParseLocation loc)
26     : PacketField(name, loc), element_field_(new ScalarField("val", element_size, loc)), element_size_(element_size),
27       size_modifier_(size_modifier) {
28   if (element_size > 64 || element_size < 0)
29     ERROR(this) << __func__ << ": Not implemented for element size = " << element_size;
30   if (element_size % 8 != 0) {
31     ERROR(this) << "Can only have arrays with elements that are byte aligned (" << element_size << ")";
32   }
33 }
34 
VectorField(std::string name,TypeDef * type_def,std::string size_modifier,ParseLocation loc)35 VectorField::VectorField(std::string name, TypeDef* type_def, std::string size_modifier, ParseLocation loc)
36     : PacketField(name, loc), element_field_(type_def->GetNewField("val", loc)),
37       element_size_(element_field_->GetSize()), size_modifier_(size_modifier) {
38   if (!element_size_.empty() && element_size_.bits() % 8 != 0) {
39     ERROR(this) << "Can only have arrays with elements that are byte aligned (" << element_size_ << ")";
40   }
41 }
42 
GetFieldType() const43 const std::string& VectorField::GetFieldType() const {
44   return VectorField::kFieldType;
45 }
46 
GetSize() const47 Size VectorField::GetSize() const {
48   // If there is no size field, then it is of unknown size.
49   if (size_field_ == nullptr) {
50     return Size();
51   }
52 
53   // size_field_ is of type SIZE
54   if (size_field_->GetFieldType() == SizeField::kFieldType) {
55     std::string ret = "(static_cast<size_t>(Get" + util::UnderscoreToCamelCase(size_field_->GetName()) + "()) * 8)";
56     if (!size_modifier_.empty()) ret += size_modifier_;
57     return ret;
58   }
59 
60   // size_field_ is of type COUNT and elements have a fixed size
61   if (!element_size_.empty() && !element_size_.has_dynamic()) {
62     return "(static_cast<size_t>(Get" + util::UnderscoreToCamelCase(size_field_->GetName()) + "()) * " +
63            std::to_string(element_size_.bits()) + ")";
64   }
65 
66   return Size();
67 }
68 
GetBuilderSize() const69 Size VectorField::GetBuilderSize() const {
70   if (!element_size_.empty() && !element_size_.has_dynamic()) {
71     std::string ret = "(static_cast<size_t>(" + GetName() + "_.size()) * " + std::to_string(element_size_.bits()) + ")";
72     return ret;
73   } else if (element_field_->BuilderParameterMustBeMoved()) {
74     std::string ret = "[this](){ size_t length = 0; for (const auto& elem : " + GetName() +
75                       "_) { length += elem->size() * 8; } return length; }()";
76     return ret;
77   } else {
78     std::string ret = "[this](){ size_t length = 0; for (const auto& elem : " + GetName() +
79                       "_) { length += elem.size() * 8; } return length; }()";
80     return ret;
81   }
82 }
83 
GetStructSize() const84 Size VectorField::GetStructSize() const {
85   // If there is no size field, then it is of unknown size.
86   if (size_field_ == nullptr) {
87     return Size();
88   }
89 
90   // size_field_ is of type SIZE
91   if (size_field_->GetFieldType() == SizeField::kFieldType) {
92     std::string ret = "(static_cast<size_t>(to_fill->" + size_field_->GetName() + "_extracted_) * 8)";
93     if (!size_modifier_.empty()) ret += "-" + size_modifier_;
94     return ret;
95   }
96 
97   // size_field_ is of type COUNT and elements have a fixed size
98   if (!element_size_.empty() && !element_size_.has_dynamic()) {
99     return "(static_cast<size_t>(to_fill->" + size_field_->GetName() + "_extracted_) * " +
100            std::to_string(element_size_.bits()) + ")";
101   }
102 
103   return Size();
104 }
105 
GetDataType() const106 std::string VectorField::GetDataType() const {
107   return "std::vector<" + element_field_->GetDataType() + ">";
108 }
109 
GenExtractor(std::ostream & s,int num_leading_bits,bool for_struct) const110 void VectorField::GenExtractor(std::ostream& s, int num_leading_bits, bool for_struct) const {
111   s << "auto " << element_field_->GetName() << "_it = " << GetName() << "_it;";
112   if (size_field_ != nullptr && size_field_->GetFieldType() == CountField::kFieldType) {
113     s << "size_t " << element_field_->GetName() << "_count = ";
114     if (for_struct) {
115       s << "to_fill->" << size_field_->GetName() << "_extracted_;";
116     } else {
117       s << "Get" << util::UnderscoreToCamelCase(size_field_->GetName()) << "();";
118     }
119   }
120   s << "while (";
121   if (size_field_ != nullptr && size_field_->GetFieldType() == CountField::kFieldType) {
122     s << "(" << element_field_->GetName() << "_count-- > 0) && ";
123   }
124   if (!element_size_.empty()) {
125     s << element_field_->GetName() << "_it.NumBytesRemaining() >= " << element_size_.bytes() << ") {";
126   } else {
127     s << element_field_->GetName() << "_it.NumBytesRemaining() > 0) {";
128   }
129   if (element_field_->BuilderParameterMustBeMoved()) {
130     s << element_field_->GetDataType() << " " << element_field_->GetName() << "_ptr;";
131   } else {
132     s << element_field_->GetDataType() << " " << element_field_->GetName() << "_value;";
133     s << element_field_->GetDataType() << "* " << element_field_->GetName() << "_ptr = &" << element_field_->GetName()
134       << "_value;";
135   }
136   element_field_->GenExtractor(s, num_leading_bits, for_struct);
137   s << "if (" << element_field_->GetName() << "_ptr != nullptr) { ";
138   if (element_field_->BuilderParameterMustBeMoved()) {
139     s << GetName() << "_ptr->push_back(std::move(" << element_field_->GetName() << "_ptr));";
140   } else {
141     s << GetName() << "_ptr->push_back(" << element_field_->GetName() << "_value);";
142   }
143   s << "}";
144   s << "}";
145 }
146 
GetGetterFunctionName() const147 std::string VectorField::GetGetterFunctionName() const {
148   std::stringstream ss;
149   ss << "Get" << util::UnderscoreToCamelCase(GetName());
150   return ss.str();
151 }
152 
GenGetter(std::ostream & s,Size start_offset,Size end_offset) const153 void VectorField::GenGetter(std::ostream& s, Size start_offset, Size end_offset) const {
154   s << GetDataType() << " " << GetGetterFunctionName() << "() {";
155   s << "ASSERT(was_validated_);";
156   s << "size_t end_index = size();";
157   s << "auto to_bound = begin();";
158 
159   int num_leading_bits = GenBounds(s, start_offset, end_offset, GetSize());
160   s << GetDataType() << " " << GetName() << "_value{};";
161   s << GetDataType() << "* " << GetName() << "_ptr = &" << GetName() << "_value;";
162   GenExtractor(s, num_leading_bits, false);
163 
164   s << "return " << GetName() << "_value;";
165   s << "}\n";
166 }
167 
GetBuilderParameterType() const168 std::string VectorField::GetBuilderParameterType() const {
169   std::stringstream ss;
170   if (element_field_->BuilderParameterMustBeMoved()) {
171     ss << "std::vector<" << element_field_->GetDataType() << ">";
172   } else {
173     ss << "const std::vector<" << element_field_->GetDataType() << ">&";
174   }
175   return ss.str();
176 }
177 
BuilderParameterMustBeMoved() const178 bool VectorField::BuilderParameterMustBeMoved() const {
179   return element_field_->BuilderParameterMustBeMoved();
180 }
181 
GenBuilderMember(std::ostream & s) const182 bool VectorField::GenBuilderMember(std::ostream& s) const {
183   s << "std::vector<" << element_field_->GetDataType() << "> " << GetName();
184   return true;
185 }
186 
HasParameterValidator() const187 bool VectorField::HasParameterValidator() const {
188   // Does not have parameter validator yet.
189   // TODO: See comment in GenParameterValidator
190   return false;
191 }
192 
GenParameterValidator(std::ostream &) const193 void VectorField::GenParameterValidator(std::ostream&) const {
194   // No Parameter validator if its dynamically size.
195   // TODO: Maybe add a validator to ensure that the size isn't larger than what the size field can hold.
196   return;
197 }
198 
GenInserter(std::ostream & s) const199 void VectorField::GenInserter(std::ostream& s) const {
200   s << "for (const auto& val_ : " << GetName() << "_) {";
201   element_field_->GenInserter(s);
202   s << "}\n";
203 }
204 
GenValidator(std::ostream &) const205 void VectorField::GenValidator(std::ostream&) const {
206   // NOTE: We could check if the element size divides cleanly into the array size, but we decided to forgo that
207   // in favor of just returning as many elements as possible in a best effort style.
208   //
209   // Other than that there is nothing that arrays need to be validated on other than length so nothing needs to
210   // be done here.
211 }
212 
SetSizeField(const SizeField * size_field)213 void VectorField::SetSizeField(const SizeField* size_field) {
214   if (size_field->GetFieldType() == CountField::kFieldType && !size_modifier_.empty()) {
215     ERROR(this, size_field) << "Can not use count field to describe array with a size modifier."
216                             << " Use size instead";
217   }
218 
219   size_field_ = size_field;
220 }
221 
GetSizeModifier() const222 const std::string& VectorField::GetSizeModifier() const {
223   return size_modifier_;
224 }
225 
IsContainerField() const226 bool VectorField::IsContainerField() const {
227   return true;
228 }
229 
GetElementField() const230 const PacketField* VectorField::GetElementField() const {
231   return element_field_;
232 }
233 
GenStringRepresentation(std::ostream & s,std::string accessor) const234 void VectorField::GenStringRepresentation(std::ostream& s, std::string accessor) const {
235   s << "\"VECTOR[\";";
236 
237   std::string arr_idx = "arridx_" + accessor;
238   std::string vec_size = accessor + ".size()";
239   s << "for (size_t index = 0; index < " << vec_size << "; index++) {";
240   std::string element_accessor = "(" + accessor + "[index])";
241   s << "ss << ((index == 0) ? \"\" : \", \") << ";
242 
243   if (element_field_->GetFieldType() == CustomField::kFieldType) {
244     s << element_accessor << ".ToString()";
245   } else {
246     element_field_->GenStringRepresentation(s, element_accessor);
247   }
248 
249   s << ";}";
250   s << "ss << \"]\"";
251 }
252 
GetRustDataType() const253 std::string VectorField::GetRustDataType() const {
254   return "Vec::<" + element_field_->GetRustDataType() + ">";
255 }
256 
GenBoundsCheck(std::ostream & s,Size start_offset,Size,std::string context) const257 void VectorField::GenBoundsCheck(std::ostream& s, Size start_offset, Size, std::string context) const {
258   auto element_field_type = GetElementField()->GetFieldType();
259   auto element_field = GetElementField();
260   auto element_size = element_field->GetSize().bytes();
261 
262   if (element_field_type == ScalarField::kFieldType) {
263     if (size_field_ == nullptr) {
264       s << "let rem_ = (bytes.len() - " << start_offset.bytes() << ") % " << element_size << ";";
265       s << "if rem_ != 0 {";
266       s << " return Err(Error::InvalidLengthError{";
267       s << "    obj: \"" << context << "\".to_string(),";
268       s << "    field: \"" << GetName() << "\".to_string(),";
269       s << "    wanted: bytes.len() + rem_,";
270       s << "    got: bytes.len()});";
271       s << "}";
272     } else if (size_field_->GetFieldType() == CountField::kFieldType) {
273       s << "let want_ = " << start_offset.bytes() << " + ((" << size_field_->GetName() << " as usize) * "
274         << element_size << ");";
275       s << "if bytes.len() < want_ {";
276       s << " return Err(Error::InvalidLengthError{";
277       s << "    obj: \"" << context << "\".to_string(),";
278       s << "    field: \"" << GetName() << "\".to_string(),";
279       s << "    wanted: want_,";
280       s << "    got: bytes.len()});";
281       s << "}";
282     } else {
283       s << "let want_ = " << start_offset.bytes() << " + (" << size_field_->GetName() << " as usize)";
284       if (GetSizeModifier() != "") {
285         s << " - ((" << GetSizeModifier().substr(1) << ") / 8)";
286       }
287       s << ";";
288       s << "if bytes.len() < want_ {";
289       s << " return Err(Error::InvalidLengthError{";
290       s << "    obj: \"" << context << "\".to_string(),";
291       s << "    field: \"" << GetName() << "\".to_string(),";
292       s << "    wanted: want_,";
293       s << "    got: bytes.len()});";
294       s << "}";
295       if (GetSizeModifier() != "") {
296         s << "if ((" << size_field_->GetName() << " as usize) < ((" << GetSizeModifier().substr(1) << ") / 8)) {";
297         s << " return Err(Error::ImpossibleStructError);";
298         s << "}";
299       }
300     }
301   }
302 }
303 
GenRustGetter(std::ostream & s,Size start_offset,Size) const304 void VectorField::GenRustGetter(std::ostream& s, Size start_offset, Size) const {
305   auto element_field_type = GetElementField()->GetFieldType();
306   auto element_field = GetElementField();
307   auto element_size = element_field->GetSize().bytes();
308 
309   if (element_field_type == ScalarField::kFieldType) {
310     s << "let " << GetName() << ": " << GetRustDataType() << " = ";
311     if (size_field_ == nullptr) {
312       s << "bytes[" << start_offset.bytes() << "..]";
313     } else if (size_field_->GetFieldType() == CountField::kFieldType) {
314       s << "bytes[" << start_offset.bytes() << ".." << start_offset.bytes() << " + ((";
315       s << size_field_->GetName() << " as usize) * " << element_size << ")]";
316     } else {
317       s << "bytes[" << start_offset.bytes() << "..(";
318       s << start_offset.bytes() << " + " << size_field_->GetName();
319       s << " as usize)";
320       if (GetSizeModifier() != "") {
321         s << " - ((" << GetSizeModifier().substr(1) << ") / 8)";
322       }
323       s << "]";
324     }
325 
326     s << ".to_vec().chunks_exact(" << element_size << ").into_iter().map(|i| ";
327     s << element_field->GetRustDataType() << "::from_le_bytes([";
328 
329     for (int j=0; j < element_size; j++) {
330       s << "i[" << j << "]";
331       if (j != element_size - 1) {
332         s << ", ";
333       }
334     }
335     s << "])).collect();";
336   } else {
337     s << "let mut " << GetName() << ": " << GetRustDataType() << " = Vec::new();";
338     if (size_field_ == nullptr) {
339       s << "let mut parsable_ = &bytes[" << start_offset.bytes() << "..];";
340       s << "while parsable_.len() > 0 {";
341     } else if (size_field_->GetFieldType() == CountField::kFieldType) {
342       s << "let mut parsable_ = &bytes[" << start_offset.bytes() << "..];";
343       s << "let count_ = " << size_field_->GetName() << " as usize;";
344       s << "for _ in 0..count_ {";
345     } else {
346       s << "let mut parsable_ = &bytes[" << start_offset.bytes() << ".." << start_offset.bytes() << " + ("
347         << size_field_->GetName() << " as usize)";
348       if (GetSizeModifier() != "") {
349         s << " - ((" << GetSizeModifier().substr(1) << ") / 8)";
350       }
351       s << "];";
352       s << "while parsable_.len() > 0 {";
353     }
354     s << " match " << element_field->GetRustDataType() << "::parse(&parsable_) {";
355     s << "  Ok(parsed) => {";
356     s << "   parsable_ = &parsable_[parsed.get_total_size()..];";
357     s << GetName() << ".push(parsed);";
358     s << "  },";
359     s << "  Err(Error::ImpossibleStructError) => break,";
360     s << "  Err(e) => return Err(e),";
361     s << " }";
362     s << "}";
363   }
364 }
365 
GenRustWriter(std::ostream & s,Size start_offset,Size) const366 void VectorField::GenRustWriter(std::ostream& s, Size start_offset, Size) const {
367   if (GetElementField()->GetFieldType() == ScalarField::kFieldType) {
368     s << "for (i, e) in self." << GetName() << ".iter().enumerate() {";
369     s << "buffer[" << start_offset.bytes() << "+i..";
370     s << start_offset.bytes() << "+i+" << GetElementField()->GetSize().bytes() << "]";
371     s << ".copy_from_slice(&e.to_le_bytes())";
372     s << "}";
373   } else {
374     s << "let mut vec_buffer_ = &mut buffer[" << start_offset.bytes() << "..];";
375     s << "for e_ in &self." << GetName() << " {";
376     s << " e_.write_to(&mut vec_buffer_[0..e_.get_total_size()]);";
377     s << " vec_buffer_ = &mut vec_buffer_[e_.get_total_size()..];";
378     s << "}";
379   }
380 }
381