1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "struct_def.h"
18 
19 #include "fields/all_fields.h"
20 #include "util.h"
21 
StructDef(std::string name,FieldList fields)22 StructDef::StructDef(std::string name, FieldList fields) : StructDef(name, fields, nullptr) {}
StructDef(std::string name,FieldList fields,StructDef * parent)23 StructDef::StructDef(std::string name, FieldList fields, StructDef* parent)
24     : ParentDef(name, fields, parent), total_size_(GetSize(true)) {}
25 
GetNewField(const std::string & name,ParseLocation loc) const26 PacketField* StructDef::GetNewField(const std::string& name, ParseLocation loc) const {
27   if (fields_.HasBody()) {
28     return new VariableLengthStructField(name, name_, loc);
29   } else {
30     return new StructField(name, name_, total_size_, loc);
31   }
32 }
33 
GetDefinitionType() const34 TypeDef::Type StructDef::GetDefinitionType() const {
35   return TypeDef::Type::STRUCT;
36 }
37 
GenSpecialize(std::ostream & s) const38 void StructDef::GenSpecialize(std::ostream& s) const {
39   if (parent_ == nullptr) {
40     return;
41   }
42   s << "static " << name_ << "* Specialize(" << parent_->name_ << "* parent) {";
43   s << "ASSERT(" << name_ << "::IsInstance(*parent));";
44   s << "return static_cast<" << name_ << "*>(parent);";
45   s << "}";
46 }
47 
GenToString(std::ostream & s) const48 void StructDef::GenToString(std::ostream& s) const {
49   s << "std::string ToString() {";
50   s << "std::stringstream ss;";
51   s << "ss << std::hex << std::showbase << \"" << name_ << " { \";";
52 
53   if (fields_.size() > 0) {
54     s << "ss";
55     bool firstfield = true;
56     for (const auto& field : fields_) {
57       if (field->GetFieldType() == ReservedField::kFieldType ||
58           field->GetFieldType() == ChecksumStartField::kFieldType ||
59           field->GetFieldType() == FixedScalarField::kFieldType || field->GetFieldType() == CountField::kFieldType ||
60           field->GetFieldType() == SizeField::kFieldType)
61         continue;
62 
63       s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
64 
65       field->GenStringRepresentation(s, field->GetName() + "_");
66 
67       if (firstfield) {
68         firstfield = false;
69       }
70     }
71     s << ";";
72   }
73 
74   s << "ss << \" }\";";
75   s << "return ss.str();";
76   s << "}\n";
77 }
78 
GenParse(std::ostream & s) const79 void StructDef::GenParse(std::ostream& s) const {
80   std::string iterator = (is_little_endian_ ? "Iterator<kLittleEndian>" : "Iterator<!kLittleEndian>");
81 
82   if (fields_.HasBody()) {
83     s << "static std::optional<" << iterator << ">";
84   } else {
85     s << "static " << iterator;
86   }
87 
88   s << " Parse(" << name_ << "* to_fill, " << iterator << " struct_begin_it ";
89 
90   if (parent_ != nullptr) {
91     s << ", bool fill_parent = true) {";
92   } else {
93     s << ") {";
94   }
95   s << "auto to_bound = struct_begin_it;";
96 
97   if (parent_ != nullptr) {
98     s << "if (fill_parent) {";
99     std::string parent_param = (parent_->parent_ == nullptr ? "" : ", true");
100     if (parent_->fields_.HasBody()) {
101       s << "auto parent_optional_it = " << parent_->name_ << "::Parse(to_fill, to_bound" << parent_param << ");";
102       if (fields_.HasBody()) {
103         s << "if (!parent_optional_it) { return {}; }";
104       } else {
105         s << "ASSERT(parent_optional_it);";
106       }
107     } else {
108       s << parent_->name_ << "::Parse(to_fill, to_bound" << parent_param << ");";
109     }
110     s << "}";
111   }
112 
113   if (!fields_.HasBody()) {
114     s << "size_t end_index = struct_begin_it.NumBytesRemaining();";
115     s << "if (end_index < " << GetSize().bytes() << ")";
116     s << "{ return struct_begin_it.Subrange(0,0);}";
117   }
118 
119   Size total_bits{0};
120   for (const auto& field : fields_) {
121     if (field->GetFieldType() != ReservedField::kFieldType && field->GetFieldType() != BodyField::kFieldType &&
122         field->GetFieldType() != FixedScalarField::kFieldType &&
123         field->GetFieldType() != ChecksumStartField::kFieldType && field->GetFieldType() != ChecksumField::kFieldType &&
124         field->GetFieldType() != CountField::kFieldType) {
125       total_bits += field->GetSize().bits();
126     }
127   }
128   s << "{";
129   s << "if (to_bound.NumBytesRemaining() < " << total_bits.bytes() << ")";
130   if (!fields_.HasBody()) {
131     s << "{ return to_bound.Subrange(to_bound.NumBytesRemaining(),0);}";
132   } else {
133     s << "{ return {};}";
134   }
135   s << "}";
136   for (const auto& field : fields_) {
137     if (field->GetFieldType() != ReservedField::kFieldType && field->GetFieldType() != BodyField::kFieldType &&
138         field->GetFieldType() != FixedScalarField::kFieldType && field->GetFieldType() != SizeField::kFieldType &&
139         field->GetFieldType() != ChecksumStartField::kFieldType && field->GetFieldType() != ChecksumField::kFieldType &&
140         field->GetFieldType() != CountField::kFieldType) {
141       s << "{";
142       int num_leading_bits =
143           field->GenBounds(s, GetStructOffsetForField(field->GetName()), Size(), field->GetStructSize());
144       s << "auto " << field->GetName() << "_ptr = &to_fill->" << field->GetName() << "_;";
145       field->GenExtractor(s, num_leading_bits, true);
146       s << "}";
147     }
148     if (field->GetFieldType() == CountField::kFieldType || field->GetFieldType() == SizeField::kFieldType) {
149       s << "{";
150       int num_leading_bits =
151           field->GenBounds(s, GetStructOffsetForField(field->GetName()), Size(), field->GetStructSize());
152       s << "auto " << field->GetName() << "_ptr = &to_fill->" << field->GetName() << "_extracted_;";
153       field->GenExtractor(s, num_leading_bits, true);
154       s << "}";
155     }
156   }
157   s << "return struct_begin_it + to_fill->size();";
158   s << "}";
159 }
160 
GenParseFunctionPrototype(std::ostream & s) const161 void StructDef::GenParseFunctionPrototype(std::ostream& s) const {
162   s << "std::unique_ptr<" << name_ << "> Parse" << name_ << "(";
163   if (is_little_endian_) {
164     s << "Iterator<kLittleEndian>";
165   } else {
166     s << "Iterator<!kLittleEndian>";
167   }
168   s << "it);";
169 }
170 
GenDefinition(std::ostream & s) const171 void StructDef::GenDefinition(std::ostream& s) const {
172   s << "class " << name_;
173   if (parent_ != nullptr) {
174     s << " : public " << parent_->name_;
175   } else {
176     if (is_little_endian_) {
177       s << " : public PacketStruct<kLittleEndian>";
178     } else {
179       s << " : public PacketStruct<!kLittleEndian>";
180     }
181   }
182   s << " {";
183   s << " public:";
184 
185   GenConstructor(s);
186 
187   s << " public:\n";
188   s << "  virtual ~" << name_ << "() = default;\n";
189 
190   GenSerialize(s);
191   s << "\n";
192 
193   GenParse(s);
194   s << "\n";
195 
196   GenSize(s);
197   s << "\n";
198 
199   GenInstanceOf(s);
200   s << "\n";
201 
202   GenSpecialize(s);
203   s << "\n";
204 
205   GenToString(s);
206   s << "\n";
207 
208   GenMembers(s);
209   for (const auto& field : fields_) {
210     if (field->GetFieldType() == CountField::kFieldType || field->GetFieldType() == SizeField::kFieldType) {
211       s << "\n private:\n";
212       s << " mutable " << field->GetDataType() << " " << field->GetName() << "_extracted_{0};";
213     }
214   }
215   s << "};\n";
216 
217   if (fields_.HasBody()) {
218     GenParseFunctionPrototype(s);
219   }
220   s << "\n";
221 }
222 
GenDefinitionPybind11(std::ostream & s) const223 void StructDef::GenDefinitionPybind11(std::ostream& s) const {
224   s << "py::class_<" << name_;
225   if (parent_ != nullptr) {
226     s << ", " << parent_->name_;
227   } else {
228     if (is_little_endian_) {
229       s << ", PacketStruct<kLittleEndian>";
230     } else {
231       s << ", PacketStruct<!kLittleEndian>";
232     }
233   }
234   s << ", std::shared_ptr<" << name_ << ">";
235   s << ">(m, \"" << name_ << "\")";
236   s << ".def(py::init<>())";
237   s << ".def(\"Serialize\", [](" << GetTypeName() << "& obj){";
238   s << "std::vector<uint8_t> bytes;";
239   s << "BitInserter bi(bytes);";
240   s << "obj.Serialize(bi);";
241   s << "return bytes;})";
242   s << ".def(\"Parse\", &" << name_ << "::Parse)";
243   s << ".def(\"size\", &" << name_ << "::size)";
244   for (const auto& field : fields_) {
245     if (field->GetBuilderParameterType().empty()) {
246       continue;
247     }
248     s << ".def_readwrite(\"" << field->GetName() << "\", &" << name_ << "::" << field->GetName() << "_)";
249   }
250   s << ";\n";
251 }
252 
GenConstructor(std::ostream & s) const253 void StructDef::GenConstructor(std::ostream& s) const {
254   if (parent_ != nullptr) {
255     s << name_ << "(const " << parent_->name_ << "& parent) : " << parent_->name_ << "(parent) {}";
256     s << name_ << "() : " << parent_->name_ << "() {";
257   } else {
258     s << name_ << "() {";
259   }
260 
261   // Get the list of parent params.
262   FieldList parent_params;
263   if (parent_ != nullptr) {
264     parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
265         PayloadField::kFieldType,
266         BodyField::kFieldType,
267     });
268 
269     // Set constrained parent fields to their correct values.
270     for (const auto& field : parent_params) {
271       const auto& constraint = parent_constraints_.find(field->GetName());
272       if (constraint != parent_constraints_.end()) {
273         s << parent_->name_ << "::" << field->GetName() << "_ = ";
274         if (field->GetFieldType() == ScalarField::kFieldType) {
275           s << std::get<int64_t>(constraint->second) << ";";
276         } else if (field->GetFieldType() == EnumField::kFieldType) {
277           s << std::get<std::string>(constraint->second) << ";";
278         } else {
279           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
280         }
281       }
282     }
283   }
284 
285   s << "}\n";
286 }
287 
GetStructOffsetForField(std::string field_name) const288 Size StructDef::GetStructOffsetForField(std::string field_name) const {
289   auto size = Size(0);
290   for (auto it = fields_.begin(); it != fields_.end(); it++) {
291     // We've reached the field, end the loop.
292     if ((*it)->GetName() == field_name) break;
293     const auto& field = *it;
294     // When we need to parse this field, all previous fields should already be parsed.
295     if (field->GetStructSize().empty()) {
296       ERROR() << "Empty size for field " << (*it)->GetName() << " finding the offset for field: " << field_name;
297     }
298     size += field->GetStructSize();
299   }
300 
301   // We need the offset until a body field.
302   if (parent_ != nullptr) {
303     auto parent_body_offset = static_cast<StructDef*>(parent_)->GetStructOffsetForField("body");
304     if (parent_body_offset.empty()) {
305       ERROR() << "Empty offset for body in " << parent_->name_ << " finding the offset for field: " << field_name;
306     }
307     size += parent_body_offset;
308   }
309 
310   return size;
311 }
312 
GenRustFieldNameAndType(std::ostream & s,bool include_fixed) const313 void StructDef::GenRustFieldNameAndType(std::ostream& s, bool include_fixed) const {
314   auto fields = fields_.GetFieldsWithoutTypes({
315       BodyField::kFieldType,
316       CountField::kFieldType,
317       PaddingField::kFieldType,
318       ReservedField::kFieldType,
319       SizeField::kFieldType,
320   });
321   for (const auto& field : fields) {
322     if (!include_fixed && field->GetFieldType() == FixedScalarField::kFieldType) {
323       continue;
324     }
325     field->GenRustNameAndType(s);
326     s << ", ";
327   }
328 }
329 
GenRustFieldNames(std::ostream & s) const330 void StructDef::GenRustFieldNames(std::ostream& s) const {
331   auto fields = fields_.GetFieldsWithoutTypes({
332       BodyField::kFieldType,
333       CountField::kFieldType,
334       PaddingField::kFieldType,
335       ReservedField::kFieldType,
336       SizeField::kFieldType,
337   });
338   for (const auto& field : fields) {
339     s << field->GetName();
340     s << ", ";
341   }
342 }
343 
GenRustDeclarations(std::ostream & s) const344 void StructDef::GenRustDeclarations(std::ostream& s) const {
345   s << "#[derive(Debug, Clone)] ";
346   s << "pub struct " << name_ << "{";
347 
348   // Generate struct fields
349   auto fields = fields_.GetFieldsWithoutTypes({
350       BodyField::kFieldType,
351       CountField::kFieldType,
352       PaddingField::kFieldType,
353       ReservedField::kFieldType,
354       SizeField::kFieldType,
355   });
356   for (const auto& field : fields) {
357     s << "pub ";
358     field->GenRustNameAndType(s);
359     s << ", ";
360   }
361   s << "}\n";
362 }
363 
GenRustImpls(std::ostream & s) const364 void StructDef::GenRustImpls(std::ostream& s) const {
365   s << "impl " << name_ << "{";
366 
367   s << "fn conforms(bytes: &[u8]) -> bool {";
368   GenRustConformanceCheck(s);
369   s << " true";
370   s << "}";
371 
372   s << "pub fn parse(bytes: &[u8]) -> Result<Self> {";
373   auto fields = fields_.GetFieldsWithoutTypes({
374       BodyField::kFieldType,
375   });
376 
377   for (const auto& field : fields) {
378     auto start_field_offset = GetOffsetForField(field->GetName(), false);
379     auto end_field_offset = GetOffsetForField(field->GetName(), true);
380 
381     if (start_field_offset.empty() && end_field_offset.empty()) {
382       ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
383                    << "no method exists to determine field location from begin() or end().\n";
384     }
385 
386     field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_);
387     field->GenRustGetter(s, start_field_offset, end_field_offset);
388   }
389 
390   fields = fields_.GetFieldsWithoutTypes({
391       BodyField::kFieldType,
392       CountField::kFieldType,
393       PaddingField::kFieldType,
394       ReservedField::kFieldType,
395       SizeField::kFieldType,
396   });
397 
398   s << "Ok(Self {";
399   for (const auto& field : fields) {
400     if (field->GetFieldType() == FixedScalarField::kFieldType) {
401       s << field->GetName() << ": ";
402       static_cast<FixedScalarField*>(field)->GenValue(s);
403     } else {
404       s << field->GetName();
405     }
406     s << ", ";
407   }
408   s << "})}\n";
409 
410   // write_to function
411   s << "fn write_to(&self, buffer: &mut [u8]) {";
412   GenRustWriteToFields(s);
413   s << "}\n";
414 
415   s << "fn get_total_size(&self) -> usize {";
416   GenSizeRetVal(s);
417   s << "}";
418   s << "}\n";
419 }
420 
GenRustDef(std::ostream & s) const421 void StructDef::GenRustDef(std::ostream& s) const {
422   GenRustDeclarations(s);
423   GenRustImpls(s);
424 }
425