1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "packet_def.h"

#include <iomanip>
#include <list>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "fields/all_fields.h"
#include "packet_dependency.h"
#include "util.h"
26 
PacketDef(std::string name,FieldList fields)27 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)28 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}
29 
GetNewField(const std::string &,ParseLocation) const30 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
31   return nullptr;  // Packets can't be fields
32 }
33 
GenParserDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const34 void PacketDef::GenParserDefinition(std::ostream& s, bool generate_fuzzing, bool generate_tests) const {
35   s << "class " << name_ << "View";
36   if (parent_ != nullptr) {
37     s << " : public " << parent_->name_ << "View {";
38   } else {
39     s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
40   }
41   s << " public:";
42 
43   // Specialize function
44   if (parent_ != nullptr) {
45     s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
46     s << "{ return " << name_ << "View(std::move(parent)); }";
47     // CreateOptional
48     s << "static std::optional<" << name_ << "View> CreateOptional(";
49     s << parent_->name_ << "View parent)";
50     s << "{ auto to_validate = " << name_ << "View::Create(std::move(parent));";
51     s << "if (to_validate.IsValid()) { return to_validate; }";
52     s << "else {return {};}}";
53   } else {
54     s << "static " << name_ << "View Create(PacketView<";
55     s << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet)";
56     s << "{ return " << name_ << "View(std::move(packet)); }";
57     // CreateOptional
58     s << "static std::optional<" << name_ << "View> CreateOptional(PacketView<";
59     s << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet)";
60     s << "{ auto to_validate = " << name_ << "View::Create(std::move(packet));";
61     s << "if (to_validate.IsValid()) { return to_validate; }";
62     s << "else {return {};}}";
63   }
64 
65   if (generate_fuzzing || generate_tests) {
66     GenTestingParserFromBytes(s);
67   }
68 
69   std::set<std::string> fixed_types = {
70       FixedScalarField::kFieldType,
71       FixedEnumField::kFieldType,
72   };
73 
74   // Print all of the public fields which are all the fields minus the fixed fields.
75   const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
76   bool has_fixed_fields = public_fields.size() != fields_.size();
77   for (const auto& field : public_fields) {
78     GenParserFieldGetter(s, field);
79     s << "\n";
80   }
81   GenValidator(s);
82   s << "\n";
83 
84   s << " public:";
85   GenParserToString(s);
86   s << "\n";
87 
88   s << " protected:\n";
89   // Constructor from a View
90   if (parent_ != nullptr) {
91     s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
92     s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
93   } else {
94     s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
95     s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
96   }
97 
98   // Print the private fields which are the fixed fields.
99   if (has_fixed_fields) {
100     const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
101     s << " private:\n";
102     for (const auto& field : private_fields) {
103       GenParserFieldGetter(s, field);
104       s << "\n";
105     }
106   }
107   s << "};\n";
108 }
109 
GenTestingParserFromBytes(std::ostream & s) const110 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
111   s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
112 
113   s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
114   s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
115   s << "return " << name_ << "View::Create(";
116   auto ancestor_ptr = parent_;
117   size_t parent_parens = 0;
118   while (ancestor_ptr != nullptr) {
119     s << ancestor_ptr->name_ << "View::Create(";
120     parent_parens++;
121     ancestor_ptr = ancestor_ptr->parent_;
122   }
123   s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
124   for (size_t i = 0; i < parent_parens; i++) {
125     s << ")";
126   }
127   s << ");";
128   s << "}";
129 
130   s << "\n#endif\n";
131 }
132 
GenParserDefinitionPybind11(std::ostream & s) const133 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
134   s << "py::class_<" << name_ << "View";
135   if (parent_ != nullptr) {
136     s << ", " << parent_->name_ << "View";
137   } else {
138     s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
139   }
140   s << ">(m, \"" << name_ << "View\")";
141   if (parent_ != nullptr) {
142     s << ".def(py::init([](" << parent_->name_ << "View parent) {";
143   } else {
144     s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
145   }
146   s << "auto view =" << name_ << "View::Create(std::move(parent));";
147   s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
148   s << "return view; }))";
149 
150   s << ".def(py::init(&" << name_ << "View::Create))";
151   std::set<std::string> protected_field_types = {
152       FixedScalarField::kFieldType,
153       FixedEnumField::kFieldType,
154       SizeField::kFieldType,
155       CountField::kFieldType,
156   };
157   const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
158   for (const auto& field : public_fields) {
159     auto getter_func_name = field->GetGetterFunctionName();
160     if (getter_func_name.empty()) {
161       continue;
162     }
163     s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
164   }
165   s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
166   s << ";\n";
167 }
168 
GenParserFieldGetter(std::ostream & s,const PacketField * field) const169 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
170   // Start field offset
171   auto start_field_offset = GetOffsetForField(field->GetName(), false);
172   auto end_field_offset = GetOffsetForField(field->GetName(), true);
173 
174   if (start_field_offset.empty() && end_field_offset.empty()) {
175     ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
176                  << "no method exists to determine field location from begin() or end().\n";
177   }
178 
179   field->GenGetter(s, start_field_offset, end_field_offset);
180 }
181 
GetDefinitionType() const182 TypeDef::Type PacketDef::GetDefinitionType() const {
183   return TypeDef::Type::PACKET;
184 }
185 
GenValidator(std::ostream & s) const186 void PacketDef::GenValidator(std::ostream& s) const {
187   // Get the static offset for all of our fields.
188   int bits_size = 0;
189   for (const auto& field : fields_) {
190     if (field->GetFieldType() != PaddingField::kFieldType) {
191       bits_size += field->GetSize().bits();
192     }
193   }
194 
195   // Generate the public validator IsValid().
196   // The method only needs to be generated for the top most class.
197   if (parent_ == nullptr) {
198     s << "bool IsValid() {" << std::endl;
199     s << "  if (was_validated_) {" << std::endl;
200     s << "    return true;" << std::endl;
201     s << "  } else {" << std::endl;
202     s << "    was_validated_ = true;" << std::endl;
203     s << "    return (was_validated_ = Validate());" << std::endl;
204     s << "  }" << std::endl;
205     s << "}" << std::endl;
206   }
207 
208   // Generate the private validator Validate().
209   // The method is overridden by all child classes.
210   s << "protected:" << std::endl;
211   if (parent_ == nullptr) {
212     s << "virtual bool Validate() const {" << std::endl;
213   } else {
214     s << "bool Validate() const override {" << std::endl;
215     s << "  if (!" << parent_->name_ << "View::Validate()) {" << std::endl;
216     s << "    return false;" << std::endl;
217     s << "  }" << std::endl;
218   }
219 
220   // Offset by the parents known size. We know that any dynamic fields can
221   // already be called since the parent must have already been validated by
222   // this point.
223   auto parent_size = Size(0);
224   if (parent_ != nullptr) {
225     parent_size = parent_->GetSize(true);
226   }
227 
228   s << "auto it = begin() + (" << parent_size << ") / 8;";
229 
230   // Check if you can extract the static fields.
231   // At this point you know you can use the size getters without crashing
232   // as long as they follow the instruction that size fields cant come before
233   // their corrisponding variable length field.
234   s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
235   s << "if (it > end()) return false;";
236 
237   // For any variable length fields, use their size check.
238   for (const auto& field : fields_) {
239     if (field->GetFieldType() == ChecksumStartField::kFieldType) {
240       auto offset = GetOffsetForField(field->GetName(), false);
241       if (!offset.empty()) {
242         s << "size_t sum_index = (" << offset << ") / 8;";
243       } else {
244         offset = GetOffsetForField(field->GetName(), true);
245         if (offset.empty()) {
246           ERROR(field) << "Checksum Start Field offset can not be determined.";
247         }
248         s << "size_t sum_index = size() - (" << offset << ") / 8;";
249       }
250 
251       const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
252       const auto& started_field = fields_.GetField(field_name);
253       if (started_field == nullptr) {
254         ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
255                      << ")";
256       }
257       auto end_offset = GetOffsetForField(started_field->GetName(), false);
258       if (!end_offset.empty()) {
259         s << "size_t end_sum_index = (" << end_offset << ") / 8;";
260       } else {
261         end_offset = GetOffsetForField(started_field->GetName(), true);
262         if (end_offset.empty()) {
263           ERROR(started_field) << "Checksum Field end_offset can not be determined.";
264         }
265         s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
266       }
267       s << "if (end_sum_index >= size()) { return false; }";
268       if (is_little_endian_) {
269         s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
270       } else {
271         s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
272       }
273       s << started_field->GetDataType() << " checksum;";
274       s << "checksum.Initialize();";
275       s << "for (uint8_t byte : checksum_view) { ";
276       s << "checksum.AddByte(byte);}";
277       s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
278         << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
279 
280       continue;
281     }
282 
283     auto field_size = field->GetSize();
284     // Fixed size fields have already been handled.
285     if (!field_size.has_dynamic()) {
286       continue;
287     }
288 
289     // Custom fields with dynamic size must have the offset for the field passed in as well
290     // as the end iterator so that they may ensure that they don't try to read past the end.
291     // Custom fields with fixed sizes will be handled in the static offset checking.
292     if (field->GetFieldType() == CustomField::kFieldType) {
293       // Check if we can determine offset from begin(), otherwise error because by this point,
294       // the size of the custom field is unknown and can't be subtracted from end() to get the
295       // offset.
296       auto offset = GetOffsetForField(field->GetName(), false);
297       if (offset.empty()) {
298         ERROR(field) << "Custom Field offset can not be determined from begin().";
299       }
300 
301       if (offset.bits() % 8 != 0) {
302         ERROR(field) << "Custom fields must be byte aligned.";
303       }
304 
305       // Custom fields are special as their size field takes an argument.
306       const auto& custom_size_var = field->GetName() + "_size";
307       s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
308       s << "(begin() + (" << offset << ") / 8);";
309 
310       s << "if (!" << custom_size_var << ".has_value()) { return false; }";
311       s << "it += *" << custom_size_var << ";";
312       s << "if (it > end()) return false;";
313       continue;
314     } else {
315       s << "it += (" << field_size.dynamic_string() << ") / 8;";
316       s << "if (it > end()) return false;";
317     }
318   }
319 
320   // Validate constraints after validating the size
321   if (parent_constraints_.size() > 0 && parent_ == nullptr) {
322     ERROR() << "Can't have a constraint on a NULL parent";
323   }
324 
325   for (const auto& constraint : parent_constraints_) {
326     s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
327     const auto& field = parent_->GetParamList().GetField(constraint.first);
328     if (field->GetFieldType() == ScalarField::kFieldType) {
329       s << std::get<int64_t>(constraint.second);
330     } else {
331       s << std::get<std::string>(constraint.second);
332     }
333     s << ") return false;";
334   }
335 
336   // Validate the packets fields last
337   for (const auto& field : fields_) {
338     field->GenValidator(s);
339     s << "\n";
340   }
341 
342   s << "return true;";
343   s << "}\n";
344   if (parent_ == nullptr) {
345     s << "bool was_validated_{false};\n";
346   }
347 }
348 
GenParserToString(std::ostream & s) const349 void PacketDef::GenParserToString(std::ostream& s) const {
350   s << "virtual std::string ToString() const " << (parent_ != nullptr ? " override" : "") << " {";
351   s << "std::stringstream ss;";
352   s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
353 
354   if (fields_.size() > 0) {
355     s << "ss << \"\" ";
356     bool firstfield = true;
357     for (const auto& field : fields_) {
358       if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType ||
359           field->GetFieldType() == ChecksumStartField::kFieldType)
360         continue;
361 
362       s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
363 
364       field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
365 
366       if (firstfield) {
367         firstfield = false;
368       }
369     }
370     s << ";";
371   }
372 
373   s << "ss << \" }\";";
374   s << "return ss.str();";
375   s << "}\n";
376 }
377 
GenBuilderDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const378 void PacketDef::GenBuilderDefinition(std::ostream& s, bool generate_fuzzing, bool generate_tests) const {
379   s << "class " << name_ << "Builder";
380   if (parent_ != nullptr) {
381     s << " : public " << parent_->name_ << "Builder";
382   } else {
383     if (is_little_endian_) {
384       s << " : public PacketBuilder<kLittleEndian>";
385     } else {
386       s << " : public PacketBuilder<!kLittleEndian>";
387     }
388   }
389   s << " {";
390   s << " public:";
391   s << "  virtual ~" << name_ << "Builder() = default;";
392 
393   if (!fields_.HasBody()) {
394     GenBuilderCreate(s);
395     s << "\n";
396 
397     if (generate_fuzzing || generate_tests) {
398       GenTestingFromView(s);
399       s << "\n";
400     }
401   }
402 
403   GenSerialize(s);
404   s << "\n";
405 
406   GenSize(s);
407   s << "\n";
408 
409   s << " protected:\n";
410   GenBuilderConstructor(s);
411   s << "\n";
412 
413   GenBuilderParameterChecker(s);
414   s << "\n";
415 
416   GenMembers(s);
417   s << "};\n";
418 
419   if (generate_tests) {
420     GenTestDefine(s);
421     s << "\n";
422   }
423 
424   if (generate_fuzzing || generate_tests) {
425     GenReflectTestDefine(s);
426     s << "\n";
427   }
428 
429   if (generate_fuzzing) {
430     GenFuzzTestDefine(s);
431     s << "\n";
432   }
433 }
434 
GenTestingFromView(std::ostream & s) const435 void PacketDef::GenTestingFromView(std::ostream& s) const {
436   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
437 
438   s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
439   s << "if (!view.IsValid()) return nullptr;";
440   s << "return " << name_ << "Builder::Create(";
441   FieldList params = GetParamList().GetFieldsWithoutTypes({
442       BodyField::kFieldType,
443   });
444   for (std::size_t i = 0; i < params.size(); i++) {
445     params[i]->GenBuilderParameterFromView(s);
446     if (i != params.size() - 1) {
447       s << ", ";
448     }
449   }
450   s << ");";
451   s << "}";
452 
453   s << "\n#endif\n";
454 }
455 
GenBuilderDefinitionPybind11(std::ostream & s) const456 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
457   s << "py::class_<" << name_ << "Builder";
458   if (parent_ != nullptr) {
459     s << ", " << parent_->name_ << "Builder";
460   } else {
461     if (is_little_endian_) {
462       s << ", PacketBuilder<kLittleEndian>";
463     } else {
464       s << ", PacketBuilder<!kLittleEndian>";
465     }
466   }
467   s << ", std::shared_ptr<" << name_ << "Builder>";
468   s << ">(m, \"" << name_ << "Builder\")";
469   if (!fields_.HasBody()) {
470     GenBuilderCreatePybind11(s);
471   }
472   s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
473   s << "std::vector<uint8_t> bytes;";
474   s << "BitInserter bi(bytes);";
475   s << "builder.Serialize(bi);";
476   s << "return bytes;})";
477   s << ";\n";
478 }
479 
GenTestDefine(std::ostream & s) const480 void PacketDef::GenTestDefine(std::ostream& s) const {
481   s << "#ifdef PACKET_TESTING\n";
482   s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
483   s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
484   s << "public: ";
485   s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
486   s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
487   s << "if (!view.IsValid()) { log::info(\"Invalid Packet Bytes (size = {})\", view.size());";
488   s << "for (size_t i = 0; i < view.size(); i++) { log::info(\"{:5}:{:02x}\", i, *(view.begin() + "
489        "i)); }}";
490   s << "ASSERT_TRUE(view.IsValid());";
491   s << "auto packet = " << name_ << "Builder::FromView(view);";
492   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
493   s << "packet_bytes->reserve(packet->size());";
494   s << "BitInserter it(*packet_bytes);";
495   s << "packet->Serialize(it);";
496   s << "ASSERT_EQ(*packet_bytes, captured_packet);";
497   s << "}";
498   s << "};";
499   s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
500   s << "CompareBytes(GetParam());";
501   s << "}";
502   s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
503   s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
504   int i = 0;
505   for (const auto& bytes : test_cases_) {
506     s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";";
507     s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "(";
508     s << name_ << "_test_bytes_" << i << ",";
509     s << name_ << "_test_bytes_" << i << " + sizeof(";
510     s << name_ << "_test_bytes_" << i << ") - 1);";
511     i++;
512   }
513   if (!test_cases_.empty()) {
514     i = 0;
515     s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(";
516     for (auto bytes : test_cases_) {
517       if (i > 0) {
518         s << ",";
519       }
520       s << name_ << "_test_vec_" << i++;
521     }
522     s << ");";
523   }
524   s << "\n#endif";
525 }
526 
GenReflectTestDefine(std::ostream & s) const527 void PacketDef::GenReflectTestDefine(std::ostream& s) const {
528   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
529   s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
530   s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
531   s << "auto vec = std::vector<uint8_t>(data, data + size);";
532   s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
533   s << "if (!view.IsValid()) { return; }";
534   s << "auto packet = " << name_ << "Builder::FromView(view);";
535   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
536   s << "packet_bytes->reserve(packet->size());";
537   s << "BitInserter it(*packet_bytes);";
538   s << "packet->Serialize(it);";
539   s << "}";
540   s << "\n#endif\n";
541 }
542 
GenFuzzTestDefine(std::ostream & s) const543 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
544   s << "#ifdef PACKET_FUZZ_TESTING\n";
545   s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
546   s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
547   s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
548   s << "public: ";
549   s << "explicit " << name_
550     << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
551   s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
552   s << "}}; ";
553   s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
554   s << "\n#endif";
555 }
556 
GetParametersToValidate() const557 FieldList PacketDef::GetParametersToValidate() const {
558   FieldList params_to_validate;
559   for (const auto& field : GetParamList()) {
560     if (field->HasParameterValidator()) {
561       params_to_validate.AppendField(field);
562     }
563   }
564   return params_to_validate;
565 }
566 
GenBuilderCreate(std::ostream & s) const567 void PacketDef::GenBuilderCreate(std::ostream& s) const {
568   s << "static std::unique_ptr<" << name_ << "Builder> Create(";
569 
570   auto params = GetParamList();
571   for (std::size_t i = 0; i < params.size(); i++) {
572     params[i]->GenBuilderParameter(s);
573     if (i != params.size() - 1) {
574       s << ", ";
575     }
576   }
577   s << ") {";
578 
579   // Call the constructor
580   s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
581 
582   params = params.GetFieldsWithoutTypes({
583       PayloadField::kFieldType,
584       BodyField::kFieldType,
585   });
586   // Add the parameters.
587   for (std::size_t i = 0; i < params.size(); i++) {
588     if (params[i]->BuilderParameterMustBeMoved()) {
589       s << "std::move(" << params[i]->GetName() << ")";
590     } else {
591       s << params[i]->GetName();
592     }
593     if (i != params.size() - 1) {
594       s << ", ";
595     }
596   }
597 
598   s << "));";
599   if (fields_.HasPayload()) {
600     s << "builder->payload_ = std::move(payload);";
601   }
602   s << "return builder;";
603   s << "}\n";
604 }
605 
GenBuilderCreatePybind11(std::ostream & s) const606 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
607   s << ".def(py::init([](";
608   auto params = GetParamList();
609   std::vector<std::string> constructor_args;
610   for (const auto& param : params) {
611     std::stringstream ss;
612     auto param_type = param->GetBuilderParameterType();
613     if (param_type.empty()) {
614       continue;
615     }
616     // Use shared_ptr instead of unique_ptr for the Python interface
617     if (param->BuilderParameterMustBeMoved()) {
618       param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
619     }
620     ss << param_type << " " << param->GetName();
621     constructor_args.push_back(ss.str());
622   }
623   s << util::StringJoin(",", constructor_args) << "){";
624 
625   // Deal with move only args
626   for (const auto& param : params) {
627     std::stringstream ss;
628     auto param_type = param->GetBuilderParameterType();
629     if (param_type.empty()) {
630       continue;
631     }
632     if (!param->BuilderParameterMustBeMoved()) {
633       continue;
634     }
635     auto move_only_param_name = param->GetName() + "_move_only";
636     s << param_type << " " << move_only_param_name << ";";
637     if (param->IsContainerField()) {
638       // Assume single layer container and copy it
639       auto struct_type = param->GetElementField()->GetDataType();
640       struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
641       struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
642       s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
643       // Serialize each struct
644       s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
645       s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
646       s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
647       s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
648       // Parse it again
649       s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);";
650       s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
651       s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
652       // Push it into a new container
653       if (param->GetFieldType() == VectorField::kFieldType) {
654         s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
655       } else if (param->GetFieldType() == ArrayField::kFieldType) {
656         s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
657       } else {
658         ERROR() << param << " is not supported by Pybind11";
659       }
660       s << "}";
661     } else {
662       // Serialize the parameter and pass the bytes in a RawBuilder
663       s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
664       s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
665       s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
666       s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
667       s << move_only_param_name << " = ";
668       s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
669     }
670   }
671   s << "return " << name_ << "Builder::Create(";
672   std::vector<std::string> builder_vars;
673   for (const auto& param : params) {
674     std::stringstream ss;
675     auto param_type = param->GetBuilderParameterType();
676     if (param_type.empty()) {
677       continue;
678     }
679     auto param_name = param->GetName();
680     if (param->BuilderParameterMustBeMoved()) {
681       ss << "std::move(" << param_name << "_move_only)";
682     } else {
683       ss << param_name;
684     }
685     builder_vars.push_back(ss.str());
686   }
687   s << util::StringJoin(",", builder_vars) << ");}";
688   s << "))";
689 }
690 
GenBuilderParameterChecker(std::ostream & s) const691 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
692   FieldList params_to_validate = GetParametersToValidate();
693 
694   // Skip writing this function if there is nothing to validate.
695   if (params_to_validate.size() == 0) {
696     return;
697   }
698 
699   // Generate function arguments.
700   s << "void CheckParameterValues(";
701   for (std::size_t i = 0; i < params_to_validate.size(); i++) {
702     params_to_validate[i]->GenBuilderParameter(s);
703     if (i != params_to_validate.size() - 1) {
704       s << ", ";
705     }
706   }
707   s << ") {";
708 
709   // Check the parameters.
710   for (const auto& field : params_to_validate) {
711     field->GenParameterValidator(s);
712   }
713   s << "}\n";
714 }
715 
GenBuilderConstructor(std::ostream & s) const716 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
717   s << "explicit " << name_ << "Builder(";
718 
719   // Generate the constructor parameters.
720   auto params = GetParamList().GetFieldsWithoutTypes({
721       PayloadField::kFieldType,
722       BodyField::kFieldType,
723   });
724   for (std::size_t i = 0; i < params.size(); i++) {
725     params[i]->GenBuilderParameter(s);
726     if (i != params.size() - 1) {
727       s << ", ";
728     }
729   }
730   if (params.size() > 0 || parent_constraints_.size() > 0) {
731     s << ") :";
732   } else {
733     s << ")";
734   }
735 
736   // Get the list of parent params to call the parent constructor with.
737   FieldList parent_params;
738   if (parent_ != nullptr) {
739     // Pass parameters to the parent constructor
740     s << parent_->name_ << "Builder(";
741     parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
742         PayloadField::kFieldType,
743         BodyField::kFieldType,
744     });
745 
746     // Go through all the fields and replace constrained fields with fixed values
747     // when calling the parent constructor.
748     for (std::size_t i = 0; i < parent_params.size(); i++) {
749       const auto& field = parent_params[i];
750       const auto& constraint = parent_constraints_.find(field->GetName());
751       if (constraint != parent_constraints_.end()) {
752         if (field->GetFieldType() == ScalarField::kFieldType) {
753           s << std::get<int64_t>(constraint->second);
754         } else if (field->GetFieldType() == EnumField::kFieldType) {
755           s << std::get<std::string>(constraint->second);
756         } else {
757           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
758         }
759 
760         s << "/* " << field->GetName() << "_ */";
761       } else {
762         s << field->GetName();
763       }
764 
765       if (i != parent_params.size() - 1) {
766         s << ", ";
767       }
768     }
769     s << ") ";
770   }
771 
772   // Build a list of parameters that excludes all parent parameters.
773   FieldList saved_params;
774   for (const auto& field : params) {
775     if (parent_params.GetField(field->GetName()) == nullptr) {
776       saved_params.AppendField(field);
777     }
778   }
779   if (parent_ != nullptr && saved_params.size() > 0) {
780     s << ",";
781   }
782   for (std::size_t i = 0; i < saved_params.size(); i++) {
783     const auto& saved_param_name = saved_params[i]->GetName();
784     if (saved_params[i]->BuilderParameterMustBeMoved()) {
785       s << saved_param_name << "_(std::move(" << saved_param_name << "))";
786     } else {
787       s << saved_param_name << "_(" << saved_param_name << ")";
788     }
789     if (i != saved_params.size() - 1) {
790       s << ",";
791     }
792   }
793   s << " {";
794 
795   FieldList params_to_validate = GetParametersToValidate();
796 
797   if (params_to_validate.size() > 0) {
798     s << "CheckParameterValues(";
799     for (std::size_t i = 0; i < params_to_validate.size(); i++) {
800       s << params_to_validate[i]->GetName() << "_";
801       if (i != params_to_validate.size() - 1) {
802         s << ", ";
803       }
804     }
805     s << ");";
806   }
807 
808   s << "}\n";
809 }
810