1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "parent_def.h"
18 
19 #include "fields/all_fields.h"
20 #include "util.h"
21 
ParentDef(std::string name,FieldList fields)22 ParentDef::ParentDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
ParentDef(std::string name,FieldList fields,ParentDef * parent)23 ParentDef::ParentDef(std::string name, FieldList fields, ParentDef* parent)
24     : TypeDef(name), fields_(fields), parent_(parent) {}
25 
AddParentConstraint(std::string field_name,std::variant<int64_t,std::string> value)26 void ParentDef::AddParentConstraint(std::string field_name, std::variant<int64_t, std::string> value) {
27   // NOTE: This could end up being very slow if there are a lot of constraints.
28   const auto& parent_params = parent_->GetParamList();
29   const auto& constrained_field = parent_params.GetField(field_name);
30   if (constrained_field == nullptr) {
31     ERROR() << "Attempting to constrain field " << field_name << " in parent " << parent_->name_
32             << ", but no such field exists.";
33   }
34 
35   if (constrained_field->GetFieldType() == ScalarField::kFieldType) {
36     if (!std::holds_alternative<int64_t>(value)) {
37       ERROR(constrained_field) << "Attempting to constrain a scalar field to an enum value in " << parent_->name_;
38     }
39   } else if (constrained_field->GetFieldType() == EnumField::kFieldType) {
40     if (!std::holds_alternative<std::string>(value)) {
41       ERROR(constrained_field) << "Attempting to constrain an enum field to a scalar value in " << parent_->name_;
42     }
43     const auto& enum_def = static_cast<EnumField*>(constrained_field)->GetEnumDef();
44     if (!enum_def.HasEntry(std::get<std::string>(value))) {
45       ERROR(constrained_field) << "No matching enumeration \"" << std::get<std::string>(value)
46                                << "\" for constraint on enum in parent " << parent_->name_ << ".";
47     }
48 
49     // For enums, we have to qualify the value using the enum type name.
50     value = enum_def.GetTypeName() + "::" + std::get<std::string>(value);
51   } else {
52     ERROR(constrained_field) << "Field in parent " << parent_->name_ << " is not viable for constraining.";
53   }
54 
55   parent_constraints_.insert(std::pair(field_name, value));
56 }
57 
AddTestCase(std::string packet_bytes)58 void ParentDef::AddTestCase(std::string packet_bytes) {
59   test_cases_.insert(std::move(packet_bytes));
60 }
61 
62 // Assign all size fields to their corresponding variable length fields.
63 // Will crash if
64 //  - there aren't any fields that don't match up to a field.
65 //  - the size field points to a fixed size field.
66 //  - if the size field comes after the variable length field.
AssignSizeFields()67 void ParentDef::AssignSizeFields() {
68   for (const auto& field : fields_) {
69     DEBUG() << "field name: " << field->GetName();
70 
71     if (field->GetFieldType() != SizeField::kFieldType && field->GetFieldType() != CountField::kFieldType) {
72       continue;
73     }
74 
75     const SizeField* size_field = static_cast<SizeField*>(field);
76     // Check to see if a corresponding field can be found.
77     const auto& var_len_field = fields_.GetField(size_field->GetSizedFieldName());
78     if (var_len_field == nullptr) {
79       ERROR(field) << "Could not find corresponding field for size/count field.";
80     }
81 
82     // Do the ordering check to ensure the size field comes before the
83     // variable length field.
84     for (auto it = fields_.begin(); *it != size_field; it++) {
85       DEBUG() << "field name: " << (*it)->GetName();
86       if (*it == var_len_field) {
87         ERROR(var_len_field, size_field) << "Size/count field must come before the variable length field it describes.";
88       }
89     }
90 
91     if (var_len_field->GetFieldType() == PayloadField::kFieldType) {
92       const auto& payload_field = static_cast<PayloadField*>(var_len_field);
93       payload_field->SetSizeField(size_field);
94       continue;
95     }
96 
97     if (var_len_field->GetFieldType() == BodyField::kFieldType) {
98       const auto& body_field = static_cast<BodyField*>(var_len_field);
99       body_field->SetSizeField(size_field);
100       continue;
101     }
102 
103     if (var_len_field->GetFieldType() == VectorField::kFieldType) {
104       const auto& vector_field = static_cast<VectorField*>(var_len_field);
105       vector_field->SetSizeField(size_field);
106       continue;
107     }
108 
109     // If we've reached this point then the field wasn't a variable length field.
110     // Check to see if the field is a variable length field
111     ERROR(field, size_field) << "Can not use size/count in reference to a fixed size field.\n";
112   }
113 }
114 
SetEndianness(bool is_little_endian)115 void ParentDef::SetEndianness(bool is_little_endian) {
116   is_little_endian_ = is_little_endian;
117 }
118 
119 // Get the size. You scan specify without_payload in order to exclude payload fields as children will be overriding it.
GetSize(bool without_payload) const120 Size ParentDef::GetSize(bool without_payload) const {
121   auto size = Size(0);
122 
123   for (const auto& field : fields_) {
124     if (without_payload &&
125         (field->GetFieldType() == PayloadField::kFieldType || field->GetFieldType() == BodyField::kFieldType)) {
126       continue;
127     }
128 
129     // The offset to the field must be passed in as an argument for dynamically sized custom fields.
130     if (field->GetFieldType() == CustomField::kFieldType && field->GetSize().has_dynamic()) {
131       std::stringstream custom_field_size;
132 
133       // Custom fields are special as their size field takes an argument.
134       custom_field_size << field->GetSize().dynamic_string() << "(begin()";
135 
136       // Check if we can determine offset from begin(), otherwise error because by this point,
137       // the size of the custom field is unknown and can't be subtracted from end() to get the
138       // offset.
139       auto offset = GetOffsetForField(field->GetName(), false);
140       if (offset.empty()) {
141         ERROR(field) << "Custom Field offset can not be determined from begin().";
142       }
143 
144       if (offset.bits() % 8 != 0) {
145         ERROR(field) << "Custom fields must be byte aligned.";
146       }
147       if (offset.has_bits()) custom_field_size << " + " << offset.bits() / 8;
148       if (offset.has_dynamic()) custom_field_size << " + " << offset.dynamic_string();
149       custom_field_size << ")";
150 
151       size += custom_field_size.str();
152       continue;
153     }
154 
155     size += field->GetSize();
156   }
157 
158   if (parent_ != nullptr) {
159     size += parent_->GetSize(true);
160   }
161 
162   return size;
163 }
164 
165 // Get the offset until the field is reached, if there is no field
166 // returns an empty Size. from_end requests the offset to the field
167 // starting from the end() iterator. If there is a field with an unknown
168 // size along the traversal, then an empty size is returned.
GetOffsetForField(std::string field_name,bool from_end) const169 Size ParentDef::GetOffsetForField(std::string field_name, bool from_end) const {
170   // Check first if the field exists.
171   if (fields_.GetField(field_name) == nullptr) {
172     ERROR() << "Can't find a field offset for nonexistent field named: " << field_name << " in " << name_;
173   }
174 
175   PacketField* padded_field = nullptr;
176   {
177     PacketField* last_field = nullptr;
178     for (const auto field : fields_) {
179       if (field->GetFieldType() == PaddingField::kFieldType) {
180         padded_field = last_field;
181       }
182       last_field = field;
183     }
184   }
185 
186   // We have to use a generic lambda to conditionally change iteration direction
187   // due to iterator and reverse_iterator being different types.
188   auto size_lambda = [&](auto from, auto to) -> Size {
189     auto size = Size(0);
190     for (auto it = from; it != to; it++) {
191       // We've reached the field, end the loop.
192       if ((*it)->GetName() == field_name) break;
193       const auto& field = *it;
194       // If there is a field with an unknown size before the field, return an empty Size.
195       if (field->GetSize().empty() && padded_field != field) {
196         return Size();
197       }
198       if (field != padded_field) {
199         if (!from_end || field->GetFieldType() != PaddingField::kFieldType) {
200           size += field->GetSize();
201         }
202       }
203     }
204     return size;
205   };
206 
207   // Change iteration direction based on from_end.
208   auto size = Size();
209   if (from_end)
210     size = size_lambda(fields_.rbegin(), fields_.rend());
211   else
212     size = size_lambda(fields_.begin(), fields_.end());
213   if (size.empty()) return size;
214 
215   // We need the offset until a payload or body field.
216   if (parent_ != nullptr) {
217     if (parent_->fields_.HasPayload()) {
218       auto parent_payload_offset = parent_->GetOffsetForField("payload", from_end);
219       if (parent_payload_offset.empty()) {
220         ERROR() << "Empty offset for payload in " << parent_->name_ << " finding the offset for field: " << field_name;
221       }
222       size += parent_payload_offset;
223     } else {
224       auto parent_body_offset = parent_->GetOffsetForField("body", from_end);
225       if (parent_body_offset.empty()) {
226         ERROR() << "Empty offset for body in " << parent_->name_ << " finding the offset for field: " << field_name;
227       }
228       size += parent_body_offset;
229     }
230   }
231 
232   return size;
233 }
234 
GetParamList() const235 FieldList ParentDef::GetParamList() const {
236   FieldList params;
237 
238   std::set<std::string> param_types = {
239       ScalarField::kFieldType,
240       EnumField::kFieldType,
241       ArrayField::kFieldType,
242       VectorField::kFieldType,
243       CustomField::kFieldType,
244       StructField::kFieldType,
245       VariableLengthStructField::kFieldType,
246       PayloadField::kFieldType,
247   };
248 
249   if (parent_ != nullptr) {
250     auto parent_params = parent_->GetParamList().GetFieldsWithTypes(param_types);
251 
252     // Do not include constrained fields in the params
253     for (const auto& field : parent_params) {
254       if (parent_constraints_.find(field->GetName()) == parent_constraints_.end()) {
255         params.AppendField(field);
256       }
257     }
258   }
259   // Add our parameters.
260   return params.Merge(fields_.GetFieldsWithTypes(param_types));
261 }
262 
GenMembers(std::ostream & s) const263 void ParentDef::GenMembers(std::ostream& s) const {
264   // Add the parameter list.
265   for (const auto& field : fields_) {
266     if (field->GenBuilderMember(s)) {
267       s << "_{};";
268     }
269   }
270 }
271 
GenSize(std::ostream & s) const272 void ParentDef::GenSize(std::ostream& s) const {
273   auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
274   auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
275 
276   Size padded_size;
277   const PacketField* padded_field = nullptr;
278   const PacketField* last_field = nullptr;
279   for (const auto& field : fields_) {
280     if (field->GetFieldType() == PaddingField::kFieldType) {
281       if (!padded_size.empty()) {
282         ERROR() << "Only one padding field is allowed.  Second field: " << field->GetName();
283       }
284       padded_field = last_field;
285       padded_size = field->GetSize();
286     }
287     last_field = field;
288   }
289 
290   s << "protected:";
291   s << "size_t BitsOfHeader() const {";
292   s << "return 0";
293 
294   if (parent_ != nullptr) {
295     if (parent_->GetDefinitionType() == Type::PACKET) {
296       s << " + " << parent_->name_ << "Builder::BitsOfHeader() ";
297     } else {
298       s << " + " << parent_->name_ << "::BitsOfHeader() ";
299     }
300   }
301 
302   for (const auto& field : header_fields) {
303     if (field == padded_field) {
304       s << " + " << padded_size;
305     } else {
306       s << " + " << field->GetBuilderSize();
307     }
308   }
309   s << ";";
310 
311   s << "}\n\n";
312 
313   s << "size_t BitsOfFooter() const {";
314   s << "return 0";
315   for (const auto& field : footer_fields) {
316     if (field == padded_field) {
317       s << " + " << padded_size;
318     } else {
319       s << " + " << field->GetBuilderSize();
320     }
321   }
322 
323   if (parent_ != nullptr) {
324     if (parent_->GetDefinitionType() == Type::PACKET) {
325       s << " + " << parent_->name_ << "Builder::BitsOfFooter() ";
326     } else {
327       s << " + " << parent_->name_ << "::BitsOfFooter() ";
328     }
329   }
330   s << ";";
331   s << "}\n\n";
332 
333   if (fields_.HasPayload()) {
334     s << "size_t GetPayloadSize() const {";
335     s << "if (payload_ != nullptr) {return payload_->size();}";
336     s << "else { return size() - (BitsOfHeader() + BitsOfFooter()) / 8;}";
337     s << ";}\n\n";
338   }
339 
340   s << "public:";
341   s << "virtual size_t size() const override {";
342   s << "return (BitsOfHeader() / 8)";
343   if (fields_.HasPayload()) {
344     s << "+ payload_->size()";
345   }
346   if (fields_.HasBody()) {
347     for (const auto& field : header_fields) {
348       if (field->GetFieldType() == SizeField::kFieldType) {
349         const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
350         if (field_name == "body") {
351           s << "+ body_size_extracted_";
352         }
353       }
354     }
355   }
356   s << " + (BitsOfFooter() / 8);";
357   s << "}\n";
358 }
359 
GenSerialize(std::ostream & s) const360 void ParentDef::GenSerialize(std::ostream& s) const {
361   auto header_fields = fields_.GetFieldsBeforePayloadOrBody();
362   auto footer_fields = fields_.GetFieldsAfterPayloadOrBody();
363 
364   s << "protected:";
365   s << "void SerializeHeader(BitInserter&";
366   if (parent_ != nullptr || header_fields.size() != 0) {
367     s << " i ";
368   }
369   s << ") const {";
370 
371   if (parent_ != nullptr) {
372     if (parent_->GetDefinitionType() == Type::PACKET) {
373       s << parent_->name_ << "Builder::SerializeHeader(i);";
374     } else {
375       s << parent_->name_ << "::SerializeHeader(i);";
376     }
377   }
378 
379   const PacketField* padded_field = nullptr;
380   {
381     PacketField* last_field = nullptr;
382     for (const auto field : header_fields) {
383       if (field->GetFieldType() == PaddingField::kFieldType) {
384         padded_field = last_field;
385       }
386       last_field = field;
387     }
388   }
389 
390   for (const auto& field : header_fields) {
391     if (field->GetFieldType() == SizeField::kFieldType) {
392       const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
393       const auto& sized_field = fields_.GetField(field_name);
394       if (sized_field == nullptr) {
395         ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
396       }
397       if (sized_field->GetFieldType() == PayloadField::kFieldType) {
398         s << "size_t payload_bytes = GetPayloadSize();";
399         std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
400         if (modifier != "") {
401           s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
402           s << "payload_bytes = payload_bytes + (" << modifier << ") / 8;";
403         }
404         s << "ASSERT(payload_bytes < (static_cast<size_t>(1) << " << field->GetSize().bits() << "));";
405         s << "insert(static_cast<" << field->GetDataType() << ">(payload_bytes), i," << field->GetSize().bits() << ");";
406       } else if (sized_field->GetFieldType() == BodyField::kFieldType) {
407         s << field->GetName() << "_extracted_ = 0;";
408         s << "size_t local_size = " << name_ << "::size();";
409 
410         s << "ASSERT((size() - local_size) < (static_cast<size_t>(1) << " << field->GetSize().bits() << "));";
411         s << "insert(static_cast<" << field->GetDataType() << ">(size() - local_size), i," << field->GetSize().bits()
412           << ");";
413       } else {
414         if (sized_field->GetFieldType() != VectorField::kFieldType) {
415           ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
416         }
417         const auto& vector_name = field_name + "_";
418         const VectorField* vector = (VectorField*)sized_field;
419         s << "size_t " << vector_name + "bytes = 0;";
420         if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
421           s << "for (auto elem : " << vector_name << ") {";
422           s << vector_name + "bytes += elem.size(); }";
423         } else {
424           s << vector_name + "bytes = ";
425           s << vector_name << ".size() * ((" << vector->element_size_ << ") / 8);";
426         }
427         std::string modifier = vector->GetSizeModifier();
428         if (modifier != "") {
429           s << "static_assert((" << modifier << ")%8 == 0, \"Modifiers must be byte-aligned\");";
430           s << vector_name << "bytes = ";
431           s << vector_name << "bytes + (" << modifier << ") / 8;";
432         }
433         s << "ASSERT(" << vector_name + "bytes < (1 << " << field->GetSize().bits() << "));";
434         s << "insert(" << vector_name << "bytes, i, ";
435         s << field->GetSize().bits() << ");";
436       }
437     } else if (field->GetFieldType() == ChecksumStartField::kFieldType) {
438       const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
439       const auto& started_field = fields_.GetField(field_name);
440       if (started_field == nullptr) {
441         ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
442                      << ")";
443       }
444       s << "auto shared_checksum_ptr = std::make_shared<" << started_field->GetDataType() << ">();";
445       s << "shared_checksum_ptr->Initialize();";
446       s << "i.RegisterObserver(packet::ByteObserver(";
447       s << "[shared_checksum_ptr](uint8_t byte){ shared_checksum_ptr->AddByte(byte);},";
448       s << "[shared_checksum_ptr](){ return static_cast<uint64_t>(shared_checksum_ptr->GetChecksum());}));";
449     } else if (field->GetFieldType() == PaddingField::kFieldType) {
450       s << "ASSERT(unpadded_size <= " << field->GetSize().bytes() << ");";
451       s << "size_t padding_bytes = ";
452       s << field->GetSize().bytes() << " - unpadded_size;";
453       s << "for (size_t padding = 0; padding < padding_bytes; padding++) {i.insert_byte(0);}";
454     } else if (field->GetFieldType() == CountField::kFieldType) {
455       const auto& vector_name = ((SizeField*)field)->GetSizedFieldName() + "_";
456       s << "insert(" << vector_name << ".size(), i, " << field->GetSize().bits() << ");";
457     } else {
458       if (field == padded_field) {
459         s << "size_t unpadded_size = (" << field->GetBuilderSize() << ") / 8;";
460       }
461       field->GenInserter(s);
462     }
463   }
464   s << "}\n\n";
465 
466   s << "void SerializeFooter(BitInserter&";
467   if (parent_ != nullptr || footer_fields.size() != 0) {
468     s << " i ";
469   }
470   s << ") const {";
471 
472   for (const auto& field : footer_fields) {
473     field->GenInserter(s);
474   }
475   if (parent_ != nullptr) {
476     if (parent_->GetDefinitionType() == Type::PACKET) {
477       s << parent_->name_ << "Builder::SerializeFooter(i);";
478     } else {
479       s << parent_->name_ << "::SerializeFooter(i);";
480     }
481   }
482   s << "}\n\n";
483 
484   s << "public:";
485   s << "virtual void Serialize(BitInserter& i) const override {";
486   s << "SerializeHeader(i);";
487   if (fields_.HasPayload()) {
488     s << "payload_->Serialize(i);";
489   }
490   s << "SerializeFooter(i);";
491 
492   s << "}\n";
493 }
494 
GenInstanceOf(std::ostream & s) const495 void ParentDef::GenInstanceOf(std::ostream& s) const {
496   if (parent_ != nullptr && parent_constraints_.size() > 0) {
497     s << "static bool IsInstance(const " << parent_->name_ << "& parent) {";
498     // Get the list of parent params.
499     FieldList parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
500         PayloadField::kFieldType,
501         BodyField::kFieldType,
502     });
503 
504     // Check if constrained parent fields are set to their correct values.
505     for (const auto& field : parent_params) {
506       const auto& constraint = parent_constraints_.find(field->GetName());
507       if (constraint != parent_constraints_.end()) {
508         s << "if (parent." << field->GetName() << "_ != ";
509         if (field->GetFieldType() == ScalarField::kFieldType) {
510           s << std::get<int64_t>(constraint->second) << ")";
511           s << "{ return false;}";
512         } else if (field->GetFieldType() == EnumField::kFieldType) {
513           s << std::get<std::string>(constraint->second) << ")";
514           s << "{ return false;}";
515         } else {
516           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
517         }
518       }
519     }
520     s << "return true;}";
521   }
522 }
523 
GetRootDef() const524 const ParentDef* ParentDef::GetRootDef() const {
525   if (parent_ == nullptr) {
526     return this;
527   }
528 
529   return parent_->GetRootDef();
530 }
531 
GetAncestors() const532 std::vector<const ParentDef*> ParentDef::GetAncestors() const {
533   std::vector<const ParentDef*> res;
534   auto parent = parent_;
535   while (parent != nullptr) {
536     res.push_back(parent);
537     parent = parent->parent_;
538   }
539   std::reverse(res.begin(), res.end());
540   return res;
541 }
542 
GetAllConstraints() const543 std::map<std::string, std::variant<int64_t, std::string>> ParentDef::GetAllConstraints() const {
544   std::map<std::string, std::variant<int64_t, std::string>> res;
545   res.insert(parent_constraints_.begin(), parent_constraints_.end());
546   for (auto parent : GetAncestors()) {
547     res.insert(parent->parent_constraints_.begin(), parent->parent_constraints_.end());
548   }
549   return res;
550 }
551 
HasAncestorNamed(std::string name) const552 bool ParentDef::HasAncestorNamed(std::string name) const {
553   auto parent = parent_;
554   while (parent != nullptr) {
555     if (parent->name_ == name) {
556       return true;
557     }
558     parent = parent->parent_;
559   }
560   return false;
561 }
562 
FindConstraintField() const563 std::string ParentDef::FindConstraintField() const {
564   std::string res;
565   for (const auto& child : children_) {
566     if (!child->parent_constraints_.empty()) {
567       return child->parent_constraints_.begin()->first;
568     }
569     res = child->FindConstraintField();
570   }
571   return res;
572 }
573 
574 std::map<const ParentDef*, const std::variant<int64_t, std::string>>
FindDescendantsWithConstraint(std::string constraint_name) const575     ParentDef::FindDescendantsWithConstraint(
576     std::string constraint_name) const {
577   std::map<const ParentDef*, const std::variant<int64_t, std::string>> res;
578 
579   for (auto const& child : children_) {
580     auto constraint = child->parent_constraints_.find(constraint_name);
581     if (constraint != child->parent_constraints_.end()) {
582       res.insert(std::pair(child, constraint->second));
583     }
584     auto m = child->FindDescendantsWithConstraint(constraint_name);
585     res.insert(m.begin(), m.end());
586   }
587   return res;
588 }
589 
FindPathToDescendant(std::string descendant) const590 std::vector<const ParentDef*> ParentDef::FindPathToDescendant(std::string descendant) const {
591   std::vector<const ParentDef*> res;
592 
593   for (auto const& child : children_) {
594     auto v = child->FindPathToDescendant(descendant);
595     if (v.size() > 0) {
596       res.insert(res.begin(), v.begin(), v.end());
597       res.push_back(child);
598     }
599     if (child->name_ == descendant) {
600       res.push_back(child);
601       return res;
602     }
603   }
604   return res;
605 }
606 
HasChildEnums() const607 bool ParentDef::HasChildEnums() const {
608   return !children_.empty() || fields_.HasPayload();
609 }
610 
GenRustConformanceCheck(std::ostream & s) const611 void ParentDef::GenRustConformanceCheck(std::ostream& s) const {
612   auto fields = fields_.GetFieldsWithTypes({
613       FixedScalarField::kFieldType,
614   });
615 
616   for (auto const& field : fields) {
617     auto start_offset = GetOffsetForField(field->GetName(), false);
618     auto end_offset = GetOffsetForField(field->GetName(), true);
619 
620     auto f = (FixedScalarField*)field;
621     f->GenRustGetter(s, start_offset, end_offset);
622     s << "if " << f->GetName() << " != ";
623     f->GenValue(s);
624     s << " { return false; } ";
625   }
626 }
627 
GenRustWriteToFields(std::ostream & s) const628 void ParentDef::GenRustWriteToFields(std::ostream& s) const {
629   auto fields = fields_.GetFieldsWithoutTypes({
630       BodyField::kFieldType,
631       PaddingField::kFieldType,
632       ReservedField::kFieldType,
633   });
634 
635   for (auto const& field : fields) {
636     auto start_field_offset = GetOffsetForField(field->GetName(), false);
637     auto end_field_offset = GetOffsetForField(field->GetName(), true);
638 
639     if (start_field_offset.empty() && end_field_offset.empty()) {
640       ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
641                    << "no method exists to determine field location from begin() or end().\n";
642     }
643 
644     if (field->GetFieldType() == SizeField::kFieldType) {
645       const auto& field_name = ((SizeField*)field)->GetSizedFieldName();
646       const auto& sized_field = fields_.GetField(field_name);
647       if (sized_field == nullptr) {
648         ERROR(field) << __func__ << ": Can't find sized field named " << field_name;
649       }
650       if (sized_field->GetFieldType() == PayloadField::kFieldType) {
651         std::string modifier = ((PayloadField*)sized_field)->size_modifier_;
652         if (modifier != "") {
653           ERROR(field) << __func__ << ": size modifiers not implemented yet for " << field_name;
654         }
655 
656         s << "let " << field->GetName() << " = " << field->GetRustDataType()
657           << "::try_from(self.child.get_total_size()).expect(\"payload size did not fit\");";
658       } else if (sized_field->GetFieldType() == BodyField::kFieldType) {
659         s << "let " << field->GetName() << " = " << field->GetRustDataType()
660           << "::try_from(self.get_total_size() - self.get_size()).expect(\"payload size did not fit\");";
661       } else if (sized_field->GetFieldType() == VectorField::kFieldType) {
662         const auto& vector_name = field_name + "_bytes";
663         const VectorField* vector = (VectorField*)sized_field;
664         if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
665           s << "let " << vector_name + " = self." << field_name
666             << ".iter().fold(0, |acc, x| acc + x.get_total_size());";
667         } else {
668           s << "let " << vector_name + " = self." << field_name << ".len() * ((" << vector->element_size_ << ") / 8);";
669         }
670         std::string modifier = vector->GetSizeModifier();
671         if (modifier != "") {
672           s << "let " << vector_name << " = " << vector_name << " + (" << modifier.substr(1) << ") / 8;";
673         }
674 
675         s << "let " << field->GetName() << " = " << field->GetRustDataType() << "::try_from(" << vector_name
676           << ").expect(\"payload size did not fit\");";
677       } else {
678         ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name;
679       }
680     }
681 
682     field->GenRustWriter(s, start_field_offset, end_field_offset);
683   }
684 }
685 
GenSizeRetVal(std::ostream & s) const686 void ParentDef::GenSizeRetVal(std::ostream& s) const {
687   int size = 0;
688   auto fields = fields_.GetFieldsWithoutTypes({
689       BodyField::kFieldType,
690   });
691   const PacketField* padded_field = nullptr;
692   auto padding_fields = fields_.GetFieldsWithTypes({
693       PaddingField::kFieldType,
694   });
695   if (padding_fields.size()) {
696     PacketField* last_field = nullptr;
697     for (const auto field : fields) {
698       if (field->GetFieldType() == PaddingField::kFieldType) {
699         padded_field = last_field;
700       }
701       last_field = field;
702     }
703   }
704 
705   s << "let ret = 0;";
706   for (const auto field : fields) {
707     bool is_vector = field->GetFieldType() == VectorField::kFieldType;
708     if (field != padded_field) {  // Skip the size of padded fields
709       if (is_vector) {
710         if (size > 0) {
711           if (size % 8 != 0) {
712             ERROR() << "size is not a multiple of 8!\n";
713           }
714           s << "let ret = ret + " << size / 8 << ";";
715           size = 0;
716         }
717 
718         const VectorField* vector = (VectorField*)field;
719         if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) {
720           s << "let ret = ret + self." << vector->GetName() << ".iter().fold(0, |acc, x| acc + x.get_total_size());";
721         } else {
722           s << "let ret = ret + (self." << vector->GetName() << ".len() * ((" << vector->element_size_ << ") / 8));";
723         }
724       } else {
725         size += field->GetSize().bits();
726       }
727     } else {
728       s << "/* Skipping " << field->GetName() << " since it is padded */";
729     }
730   }
731   if (size > 0) {
732     if (size % 8 != 0) {
733       ERROR() << "size is not a multiple of 8!\n";
734     }
735     s << "let ret = ret + " << size / 8 << ";";
736   }
737 
738   s << "ret";
739 }
740