1 /*
2  * Copyright 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "packet_def.h"
18 
19 #include <iomanip>
20 #include <list>
21 #include <set>
22 
23 #include "fields/all_fields.h"
24 #include "util.h"
25 
PacketDef(std::string name,FieldList fields)26 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)27 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}
28 
GetNewField(const std::string &,ParseLocation) const29 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
30   return nullptr;  // Packets can't be fields
31 }
32 
GenParserDefinition(std::ostream & s) const33 void PacketDef::GenParserDefinition(std::ostream& s) const {
34   s << "class " << name_ << "View";
35   if (parent_ != nullptr) {
36     s << " : public " << parent_->name_ << "View {";
37   } else {
38     s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
39   }
40   s << " public:";
41 
42   // Specialize function
43   if (parent_ != nullptr) {
44     s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
45     s << "{ return " << name_ << "View(std::move(parent)); }";
46   } else {
47     s << "static " << name_ << "View Create(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
48     s << "{ return " << name_ << "View(std::move(packet)); }";
49   }
50 
51   GenTestingParserFromBytes(s);
52 
53   std::set<std::string> fixed_types = {
54       FixedScalarField::kFieldType,
55       FixedEnumField::kFieldType,
56   };
57 
58   // Print all of the public fields which are all the fields minus the fixed fields.
59   const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
60   bool has_fixed_fields = public_fields.size() != fields_.size();
61   for (const auto& field : public_fields) {
62     GenParserFieldGetter(s, field);
63     s << "\n";
64   }
65   GenValidator(s);
66   s << "\n";
67 
68   s << " public:";
69   GenParserToString(s);
70   s << "\n";
71 
72   s << " protected:\n";
73   // Constructor from a View
74   if (parent_ != nullptr) {
75     s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
76     s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
77   } else {
78     s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
79     s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
80   }
81 
82   // Print the private fields which are the fixed fields.
83   if (has_fixed_fields) {
84     const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
85     s << " private:\n";
86     for (const auto& field : private_fields) {
87       GenParserFieldGetter(s, field);
88       s << "\n";
89     }
90   }
91   s << "};\n";
92 }
93 
GenTestingParserFromBytes(std::ostream & s) const94 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
95   s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
96 
97   s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
98   s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
99   s << "return " << name_ << "View::Create(";
100   auto ancestor_ptr = parent_;
101   size_t parent_parens = 0;
102   while (ancestor_ptr != nullptr) {
103     s << ancestor_ptr->name_ << "View::Create(";
104     parent_parens++;
105     ancestor_ptr = ancestor_ptr->parent_;
106   }
107   s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
108   for (size_t i = 0; i < parent_parens; i++) {
109     s << ")";
110   }
111   s << ");";
112   s << "}";
113 
114   s << "\n#endif\n";
115 }
116 
GenParserDefinitionPybind11(std::ostream & s) const117 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
118   s << "py::class_<" << name_ << "View";
119   if (parent_ != nullptr) {
120     s << ", " << parent_->name_ << "View";
121   } else {
122     s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
123   }
124   s << ">(m, \"" << name_ << "View\")";
125   if (parent_ != nullptr) {
126     s << ".def(py::init([](" << parent_->name_ << "View parent) {";
127   } else {
128     s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
129   }
130   s << "auto view =" << name_ << "View::Create(std::move(parent));";
131   s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
132   s << "return view; }))";
133 
134   s << ".def(py::init(&" << name_ << "View::Create))";
135   std::set<std::string> protected_field_types = {
136       FixedScalarField::kFieldType,
137       FixedEnumField::kFieldType,
138       SizeField::kFieldType,
139       CountField::kFieldType,
140   };
141   const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
142   for (const auto& field : public_fields) {
143     auto getter_func_name = field->GetGetterFunctionName();
144     if (getter_func_name.empty()) {
145       continue;
146     }
147     s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
148   }
149   s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
150   s << ";\n";
151 }
152 
GenParserFieldGetter(std::ostream & s,const PacketField * field) const153 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
154   // Start field offset
155   auto start_field_offset = GetOffsetForField(field->GetName(), false);
156   auto end_field_offset = GetOffsetForField(field->GetName(), true);
157 
158   if (start_field_offset.empty() && end_field_offset.empty()) {
159     ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
160                  << "no method exists to determine field location from begin() or end().\n";
161   }
162 
163   field->GenGetter(s, start_field_offset, end_field_offset);
164 }
165 
GetDefinitionType() const166 TypeDef::Type PacketDef::GetDefinitionType() const {
167   return TypeDef::Type::PACKET;
168 }
169 
GenValidator(std::ostream & s) const170 void PacketDef::GenValidator(std::ostream& s) const {
171   // Get the static offset for all of our fields.
172   int bits_size = 0;
173   for (const auto& field : fields_) {
174     if (field->GetFieldType() != PaddingField::kFieldType) {
175       bits_size += field->GetSize().bits();
176     }
177   }
178 
179   // Write the function declaration.
180   s << "virtual bool IsValid() " << (parent_ != nullptr ? " override" : "") << " {";
181   s << "if (was_validated_) { return true; } ";
182   s << "else { was_validated_ = true; was_validated_ = IsValid_(); return was_validated_; }";
183   s << "}";
184 
185   s << "protected:";
186   s << "virtual bool IsValid_() const {";
187 
188   if (parent_ != nullptr) {
189     s << "if (!" << parent_->name_ << "View::IsValid_()) { return false; } ";
190   }
191 
192   // Offset by the parents known size. We know that any dynamic fields can
193   // already be called since the parent must have already been validated by
194   // this point.
195   auto parent_size = Size(0);
196   if (parent_ != nullptr) {
197     parent_size = parent_->GetSize(true);
198   }
199 
200   s << "auto it = begin() + (" << parent_size << ") / 8;";
201 
202   // Check if you can extract the static fields.
203   // At this point you know you can use the size getters without crashing
204   // as long as they follow the instruction that size fields cant come before
205   // their corrisponding variable length field.
206   s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
207   s << "if (it > end()) return false;";
208 
209   // For any variable length fields, use their size check.
210   for (const auto& field : fields_) {
211     if (field->GetFieldType() == ChecksumStartField::kFieldType) {
212       auto offset = GetOffsetForField(field->GetName(), false);
213       if (!offset.empty()) {
214         s << "size_t sum_index = (" << offset << ") / 8;";
215       } else {
216         offset = GetOffsetForField(field->GetName(), true);
217         if (offset.empty()) {
218           ERROR(field) << "Checksum Start Field offset can not be determined.";
219         }
220         s << "size_t sum_index = size() - (" << offset << ") / 8;";
221       }
222 
223       const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
224       const auto& started_field = fields_.GetField(field_name);
225       if (started_field == nullptr) {
226         ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
227                      << ")";
228       }
229       auto end_offset = GetOffsetForField(started_field->GetName(), false);
230       if (!end_offset.empty()) {
231         s << "size_t end_sum_index = (" << end_offset << ") / 8;";
232       } else {
233         end_offset = GetOffsetForField(started_field->GetName(), true);
234         if (end_offset.empty()) {
235           ERROR(started_field) << "Checksum Field end_offset can not be determined.";
236         }
237         s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
238       }
239       if (is_little_endian_) {
240         s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
241       } else {
242         s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
243       }
244       s << started_field->GetDataType() << " checksum;";
245       s << "checksum.Initialize();";
246       s << "for (uint8_t byte : checksum_view) { ";
247       s << "checksum.AddByte(byte);}";
248       s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
249         << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
250 
251       continue;
252     }
253 
254     auto field_size = field->GetSize();
255     // Fixed size fields have already been handled.
256     if (!field_size.has_dynamic()) {
257       continue;
258     }
259 
260     // Custom fields with dynamic size must have the offset for the field passed in as well
261     // as the end iterator so that they may ensure that they don't try to read past the end.
262     // Custom fields with fixed sizes will be handled in the static offset checking.
263     if (field->GetFieldType() == CustomField::kFieldType) {
264       // Check if we can determine offset from begin(), otherwise error because by this point,
265       // the size of the custom field is unknown and can't be subtracted from end() to get the
266       // offset.
267       auto offset = GetOffsetForField(field->GetName(), false);
268       if (offset.empty()) {
269         ERROR(field) << "Custom Field offset can not be determined from begin().";
270       }
271 
272       if (offset.bits() % 8 != 0) {
273         ERROR(field) << "Custom fields must be byte aligned.";
274       }
275 
276       // Custom fields are special as their size field takes an argument.
277       const auto& custom_size_var = field->GetName() + "_size";
278       s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
279       s << "(begin() + (" << offset << ") / 8);";
280 
281       s << "if (!" << custom_size_var << ".has_value()) { return false; }";
282       s << "it += *" << custom_size_var << ";";
283       s << "if (it > end()) return false;";
284       continue;
285     } else {
286       s << "it += (" << field_size.dynamic_string() << ") / 8;";
287       s << "if (it > end()) return false;";
288     }
289   }
290 
291   // Validate constraints after validating the size
292   if (parent_constraints_.size() > 0 && parent_ == nullptr) {
293     ERROR() << "Can't have a constraint on a NULL parent";
294   }
295 
296   for (const auto& constraint : parent_constraints_) {
297     s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
298     const auto& field = parent_->GetParamList().GetField(constraint.first);
299     if (field->GetFieldType() == ScalarField::kFieldType) {
300       s << std::get<int64_t>(constraint.second);
301     } else {
302       s << std::get<std::string>(constraint.second);
303     }
304     s << ") return false;";
305   }
306 
307   // Validate the packets fields last
308   for (const auto& field : fields_) {
309     field->GenValidator(s);
310     s << "\n";
311   }
312 
313   s << "return true;";
314   s << "}\n";
315   if (parent_ == nullptr) {
316     s << "bool was_validated_{false};\n";
317   }
318 }
319 
GenParserToString(std::ostream & s) const320 void PacketDef::GenParserToString(std::ostream& s) const {
321   s << "virtual std::string ToString() " << (parent_ != nullptr ? " override" : "") << " {";
322   s << "std::stringstream ss;";
323   s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
324 
325   if (fields_.size() > 0) {
326     s << "ss << \"\" ";
327     bool firstfield = true;
328     for (const auto& field : fields_) {
329       if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType ||
330           field->GetFieldType() == ChecksumStartField::kFieldType)
331         continue;
332 
333       s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
334 
335       field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
336 
337       if (firstfield) {
338         firstfield = false;
339       }
340     }
341     s << ";";
342   }
343 
344   s << "ss << \" }\";";
345   s << "return ss.str();";
346   s << "}\n";
347 }
348 
GenBuilderDefinition(std::ostream & s) const349 void PacketDef::GenBuilderDefinition(std::ostream& s) const {
350   s << "class " << name_ << "Builder";
351   if (parent_ != nullptr) {
352     s << " : public " << parent_->name_ << "Builder";
353   } else {
354     if (is_little_endian_) {
355       s << " : public PacketBuilder<kLittleEndian>";
356     } else {
357       s << " : public PacketBuilder<!kLittleEndian>";
358     }
359   }
360   s << " {";
361   s << " public:";
362   s << "  virtual ~" << name_ << "Builder() = default;";
363 
364   if (!fields_.HasBody()) {
365     GenBuilderCreate(s);
366     s << "\n";
367 
368     GenTestingFromView(s);
369     s << "\n";
370   }
371 
372   GenSerialize(s);
373   s << "\n";
374 
375   GenSize(s);
376   s << "\n";
377 
378   s << " protected:\n";
379   GenBuilderConstructor(s);
380   s << "\n";
381 
382   GenBuilderParameterChecker(s);
383   s << "\n";
384 
385   GenMembers(s);
386   s << "};\n";
387 
388   GenTestDefine(s);
389   s << "\n";
390 
391   GenFuzzTestDefine(s);
392   s << "\n";
393 }
394 
GenTestingFromView(std::ostream & s) const395 void PacketDef::GenTestingFromView(std::ostream& s) const {
396   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
397 
398   s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
399   s << "return " << name_ << "Builder::Create(";
400   FieldList params = GetParamList().GetFieldsWithoutTypes({
401       BodyField::kFieldType,
402   });
403   for (std::size_t i = 0; i < params.size(); i++) {
404     params[i]->GenBuilderParameterFromView(s);
405     if (i != params.size() - 1) {
406       s << ", ";
407     }
408   }
409   s << ");";
410   s << "}";
411 
412   s << "\n#endif\n";
413 }
414 
GenBuilderDefinitionPybind11(std::ostream & s) const415 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
416   s << "py::class_<" << name_ << "Builder";
417   if (parent_ != nullptr) {
418     s << ", " << parent_->name_ << "Builder";
419   } else {
420     if (is_little_endian_) {
421       s << ", PacketBuilder<kLittleEndian>";
422     } else {
423       s << ", PacketBuilder<!kLittleEndian>";
424     }
425   }
426   s << ", std::shared_ptr<" << name_ << "Builder>";
427   s << ">(m, \"" << name_ << "Builder\")";
428   if (!fields_.HasBody()) {
429     GenBuilderCreatePybind11(s);
430   }
431   s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
432   s << "std::vector<uint8_t> bytes;";
433   s << "BitInserter bi(bytes);";
434   s << "builder.Serialize(bi);";
435   s << "return bytes;})";
436   s << ";\n";
437 }
438 
GenTestDefine(std::ostream & s) const439 void PacketDef::GenTestDefine(std::ostream& s) const {
440   s << "#ifdef PACKET_TESTING\n";
441   s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
442   s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
443   s << "public: ";
444   s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
445   s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
446   s << "if (!view.IsValid()) { LOG_INFO(\"Invalid Packet Bytes (size = %zu)\", view.size());";
447   s << "for (size_t i = 0; i < view.size(); i++) { LOG_INFO(\"%5zd:%02X\", i, *(view.begin() + i)); }}";
448   s << "ASSERT_TRUE(view.IsValid());";
449   s << "auto packet = " << name_ << "Builder::FromView(view);";
450   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
451   s << "packet_bytes->reserve(packet->size());";
452   s << "BitInserter it(*packet_bytes);";
453   s << "packet->Serialize(it);";
454   s << "ASSERT_EQ(*packet_bytes, captured_packet);";
455   s << "}";
456   s << "};";
457   s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
458   s << "CompareBytes(GetParam());";
459   s << "}";
460   s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
461   s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
462   int i = 0;
463   for (const auto& bytes : test_cases_) {
464     s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";";
465     s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "(";
466     s << name_ << "_test_bytes_" << i << ",";
467     s << name_ << "_test_bytes_" << i << " + sizeof(";
468     s << name_ << "_test_bytes_" << i << ") - 1);";
469     i++;
470   }
471   if (!test_cases_.empty()) {
472     i = 0;
473     s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(";
474     for (auto bytes : test_cases_) {
475       if (i > 0) {
476         s << ",";
477       }
478       s << name_ << "_test_vec_" << i++;
479     }
480     s << ");";
481   }
482   s << "\n#endif";
483 }
484 
GenFuzzTestDefine(std::ostream & s) const485 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
486   s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
487   s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
488   s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
489   s << "auto vec = std::vector<uint8_t>(data, data + size);";
490   s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
491   s << "if (!view.IsValid()) { return; }";
492   s << "auto packet = " << name_ << "Builder::FromView(view);";
493   s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
494   s << "packet_bytes->reserve(packet->size());";
495   s << "BitInserter it(*packet_bytes);";
496   s << "packet->Serialize(it);";
497   s << "}";
498   s << "\n#endif\n";
499   s << "#ifdef PACKET_FUZZ_TESTING\n";
500   s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
501   s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
502   s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
503   s << "public: ";
504   s << "explicit " << name_
505     << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
506   s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
507   s << "}}; ";
508   s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
509   s << "\n#endif";
510 }
511 
GetParametersToValidate() const512 FieldList PacketDef::GetParametersToValidate() const {
513   FieldList params_to_validate;
514   for (const auto& field : GetParamList()) {
515     if (field->HasParameterValidator()) {
516       params_to_validate.AppendField(field);
517     }
518   }
519   return params_to_validate;
520 }
521 
GenBuilderCreate(std::ostream & s) const522 void PacketDef::GenBuilderCreate(std::ostream& s) const {
523   s << "static std::unique_ptr<" << name_ << "Builder> Create(";
524 
525   auto params = GetParamList();
526   for (std::size_t i = 0; i < params.size(); i++) {
527     params[i]->GenBuilderParameter(s);
528     if (i != params.size() - 1) {
529       s << ", ";
530     }
531   }
532   s << ") {";
533 
534   // Call the constructor
535   s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
536 
537   params = params.GetFieldsWithoutTypes({
538       PayloadField::kFieldType,
539       BodyField::kFieldType,
540   });
541   // Add the parameters.
542   for (std::size_t i = 0; i < params.size(); i++) {
543     if (params[i]->BuilderParameterMustBeMoved()) {
544       s << "std::move(" << params[i]->GetName() << ")";
545     } else {
546       s << params[i]->GetName();
547     }
548     if (i != params.size() - 1) {
549       s << ", ";
550     }
551   }
552 
553   s << "));";
554   if (fields_.HasPayload()) {
555     s << "builder->payload_ = std::move(payload);";
556   }
557   s << "return builder;";
558   s << "}\n";
559 }
560 
GenBuilderCreatePybind11(std::ostream & s) const561 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
562   s << ".def(py::init([](";
563   auto params = GetParamList();
564   std::vector<std::string> constructor_args;
565   int i = 1;
566   for (const auto& param : params) {
567     i++;
568     std::stringstream ss;
569     auto param_type = param->GetBuilderParameterType();
570     if (param_type.empty()) {
571       continue;
572     }
573     // Use shared_ptr instead of unique_ptr for the Python interface
574     if (param->BuilderParameterMustBeMoved()) {
575       param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
576     }
577     ss << param_type << " " << param->GetName();
578     constructor_args.push_back(ss.str());
579   }
580   s << util::StringJoin(",", constructor_args) << "){";
581 
582   // Deal with move only args
583   for (const auto& param : params) {
584     std::stringstream ss;
585     auto param_type = param->GetBuilderParameterType();
586     if (param_type.empty()) {
587       continue;
588     }
589     if (!param->BuilderParameterMustBeMoved()) {
590       continue;
591     }
592     auto move_only_param_name = param->GetName() + "_move_only";
593     s << param_type << " " << move_only_param_name << ";";
594     if (param->IsContainerField()) {
595       // Assume single layer container and copy it
596       auto struct_type = param->GetElementField()->GetDataType();
597       struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
598       struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
599       s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
600       // Serialize each struct
601       s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
602       s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
603       s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
604       s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
605       // Parse it again
606       s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);";
607       s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
608       s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
609       // Push it into a new container
610       if (param->GetFieldType() == VectorField::kFieldType) {
611         s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
612       } else if (param->GetFieldType() == ArrayField::kFieldType) {
613         s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
614       } else {
615         ERROR() << param << " is not supported by Pybind11";
616       }
617       s << "}";
618     } else {
619       // Serialize the parameter and pass the bytes in a RawBuilder
620       s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
621       s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
622       s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
623       s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
624       s << move_only_param_name << " = ";
625       s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
626     }
627   }
628   s << "return " << name_ << "Builder::Create(";
629   std::vector<std::string> builder_vars;
630   for (const auto& param : params) {
631     std::stringstream ss;
632     auto param_type = param->GetBuilderParameterType();
633     if (param_type.empty()) {
634       continue;
635     }
636     auto param_name = param->GetName();
637     if (param->BuilderParameterMustBeMoved()) {
638       ss << "std::move(" << param_name << "_move_only)";
639     } else {
640       ss << param_name;
641     }
642     builder_vars.push_back(ss.str());
643   }
644   s << util::StringJoin(",", builder_vars) << ");}";
645   s << "))";
646 }
647 
GenBuilderParameterChecker(std::ostream & s) const648 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
649   FieldList params_to_validate = GetParametersToValidate();
650 
651   // Skip writing this function if there is nothing to validate.
652   if (params_to_validate.size() == 0) {
653     return;
654   }
655 
656   // Generate function arguments.
657   s << "void CheckParameterValues(";
658   for (std::size_t i = 0; i < params_to_validate.size(); i++) {
659     params_to_validate[i]->GenBuilderParameter(s);
660     if (i != params_to_validate.size() - 1) {
661       s << ", ";
662     }
663   }
664   s << ") {";
665 
666   // Check the parameters.
667   for (const auto& field : params_to_validate) {
668     field->GenParameterValidator(s);
669   }
670   s << "}\n";
671 }
672 
GenBuilderConstructor(std::ostream & s) const673 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
674   s << "explicit " << name_ << "Builder(";
675 
676   // Generate the constructor parameters.
677   auto params = GetParamList().GetFieldsWithoutTypes({
678       PayloadField::kFieldType,
679       BodyField::kFieldType,
680   });
681   for (std::size_t i = 0; i < params.size(); i++) {
682     params[i]->GenBuilderParameter(s);
683     if (i != params.size() - 1) {
684       s << ", ";
685     }
686   }
687   if (params.size() > 0 || parent_constraints_.size() > 0) {
688     s << ") :";
689   } else {
690     s << ")";
691   }
692 
693   // Get the list of parent params to call the parent constructor with.
694   FieldList parent_params;
695   if (parent_ != nullptr) {
696     // Pass parameters to the parent constructor
697     s << parent_->name_ << "Builder(";
698     parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
699         PayloadField::kFieldType,
700         BodyField::kFieldType,
701     });
702 
703     // Go through all the fields and replace constrained fields with fixed values
704     // when calling the parent constructor.
705     for (std::size_t i = 0; i < parent_params.size(); i++) {
706       const auto& field = parent_params[i];
707       const auto& constraint = parent_constraints_.find(field->GetName());
708       if (constraint != parent_constraints_.end()) {
709         if (field->GetFieldType() == ScalarField::kFieldType) {
710           s << std::get<int64_t>(constraint->second);
711         } else if (field->GetFieldType() == EnumField::kFieldType) {
712           s << std::get<std::string>(constraint->second);
713         } else {
714           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
715         }
716 
717         s << "/* " << field->GetName() << "_ */";
718       } else {
719         s << field->GetName();
720       }
721 
722       if (i != parent_params.size() - 1) {
723         s << ", ";
724       }
725     }
726     s << ") ";
727   }
728 
729   // Build a list of parameters that excludes all parent parameters.
730   FieldList saved_params;
731   for (const auto& field : params) {
732     if (parent_params.GetField(field->GetName()) == nullptr) {
733       saved_params.AppendField(field);
734     }
735   }
736   if (parent_ != nullptr && saved_params.size() > 0) {
737     s << ",";
738   }
739   for (std::size_t i = 0; i < saved_params.size(); i++) {
740     const auto& saved_param_name = saved_params[i]->GetName();
741     if (saved_params[i]->BuilderParameterMustBeMoved()) {
742       s << saved_param_name << "_(std::move(" << saved_param_name << "))";
743     } else {
744       s << saved_param_name << "_(" << saved_param_name << ")";
745     }
746     if (i != saved_params.size() - 1) {
747       s << ",";
748     }
749   }
750   s << " {";
751 
752   FieldList params_to_validate = GetParametersToValidate();
753 
754   if (params_to_validate.size() > 0) {
755     s << "CheckParameterValues(";
756     for (std::size_t i = 0; i < params_to_validate.size(); i++) {
757       s << params_to_validate[i]->GetName() << "_";
758       if (i != params_to_validate.size() - 1) {
759         s << ", ";
760       }
761     }
762     s << ");";
763   }
764 
765   s << "}\n";
766 }
767 
GenRustChildEnums(std::ostream & s) const768 void PacketDef::GenRustChildEnums(std::ostream& s) const {
769   if (HasChildEnums()) {
770     bool payload = fields_.HasPayload();
771     s << "#[derive(Debug)] ";
772     s << "enum " << name_ << "DataChild {";
773     for (const auto& child : children_) {
774       s << child->name_ << "(Arc<" << child->name_ << "Data>),";
775     }
776     if (payload) {
777       s << "Payload(Bytes),";
778     }
779     s << "None,";
780     s << "}\n";
781 
782     s << "impl " << name_ << "DataChild {";
783     s << "fn get_total_size(&self) -> usize {";
784     s << "match self {";
785     for (const auto& child : children_) {
786       s << name_ << "DataChild::" << child->name_ << "(value) => value.get_total_size(),";
787     }
788     if (payload) {
789       s << name_ << "DataChild::Payload(p) => p.len(),";
790     }
791     s << name_ << "DataChild::None => 0,";
792     s << "}\n";
793     s << "}\n";
794     s << "}\n";
795 
796     s << "#[derive(Debug)] ";
797     s << "pub enum " << name_ << "Child {";
798     for (const auto& child : children_) {
799       s << child->name_ << "(" << child->name_ << "Packet),";
800     }
801     if (payload) {
802       s << "Payload(Bytes),";
803     }
804     s << "None,";
805     s << "}\n";
806   }
807 }
808 
GenRustStructDeclarations(std::ostream & s) const809 void PacketDef::GenRustStructDeclarations(std::ostream& s) const {
810   s << "#[derive(Debug)] ";
811   s << "struct " << name_ << "Data {";
812 
813   // Generate struct fields
814   GenRustStructFieldNameAndType(s);
815   if (HasChildEnums()) {
816     s << "child: " << name_ << "DataChild,";
817   }
818   s << "}\n";
819 
820   // Generate accessor struct
821   s << "#[derive(Debug, Clone)] ";
822   s << "pub struct " << name_ << "Packet {";
823   auto lineage = GetAncestors();
824   lineage.push_back(this);
825   for (auto it = lineage.begin(); it != lineage.end(); it++) {
826     auto def = *it;
827     s << util::CamelCaseToUnderScore(def->name_) << ": Arc<" << def->name_ << "Data>,";
828   }
829   s << "}\n";
830 
831   // Generate builder struct
832   s << "#[derive(Debug)] ";
833   s << "pub struct " << name_ << "Builder {";
834   auto params = GetParamList().GetFieldsWithoutTypes({
835       PayloadField::kFieldType,
836       BodyField::kFieldType,
837   });
838   for (auto param : params) {
839     s << "pub ";
840     param->GenRustNameAndType(s);
841     s << ", ";
842   }
843   if (fields_.HasPayload()) {
844     s << "pub payload: Option<Bytes>,";
845   }
846   s << "}\n";
847 }
848 
GenRustStructFieldNameAndType(std::ostream & s) const849 bool PacketDef::GenRustStructFieldNameAndType(std::ostream& s) const {
850   auto fields = fields_.GetFieldsWithoutTypes({
851       BodyField::kFieldType,
852       CountField::kFieldType,
853       PaddingField::kFieldType,
854       ReservedField::kFieldType,
855       SizeField::kFieldType,
856       PayloadField::kFieldType,
857       FixedScalarField::kFieldType,
858   });
859   if (fields.size() == 0) {
860     return false;
861   }
862   for (const auto& field : fields) {
863     field->GenRustNameAndType(s);
864     s << ", ";
865   }
866   return true;
867 }
868 
GenRustStructFieldNames(std::ostream & s) const869 void PacketDef::GenRustStructFieldNames(std::ostream& s) const {
870   auto fields = fields_.GetFieldsWithoutTypes({
871       BodyField::kFieldType,
872       CountField::kFieldType,
873       PaddingField::kFieldType,
874       ReservedField::kFieldType,
875       SizeField::kFieldType,
876       PayloadField::kFieldType,
877       FixedScalarField::kFieldType,
878   });
879   for (const auto field : fields) {
880     s << field->GetName();
881     s << ", ";
882   }
883 }
884 
GenRustStructImpls(std::ostream & s) const885 void PacketDef::GenRustStructImpls(std::ostream& s) const {
886   s << "impl " << name_ << "Data {";
887 
888   // conforms function
889   s << "fn conforms(bytes: &[u8]) -> bool {";
890   GenRustConformanceCheck(s);
891 
892   auto fields = fields_.GetFieldsWithTypes({
893       StructField::kFieldType,
894   });
895 
896   for (auto const& field : fields) {
897     auto start_offset = GetOffsetForField(field->GetName(), false);
898     auto end_offset = GetOffsetForField(field->GetName(), true);
899 
900     s << "if !" << field->GetRustDataType() << "::conforms(&bytes[" << start_offset.bytes();
901     s << ".." << start_offset.bytes() + field->GetSize().bytes() << "]) { return false; }";
902   }
903 
904   s << " true";
905   s << "}";
906 
907   // parse function
908   if (parent_constraints_.empty() && children_.size() > 1 && parent_ != nullptr) {
909     auto constraint = FindConstraintField();
910     auto constraint_field = GetParamList().GetField(constraint);
911     auto constraint_type = constraint_field->GetRustDataType();
912     s << "fn parse(bytes: &[u8], " << constraint << ": " << constraint_type << ") -> Result<Self> {";
913   } else {
914     s << "fn parse(bytes: &[u8]) -> Result<Self> {";
915   }
916   fields = fields_.GetFieldsWithoutTypes({
917       BodyField::kFieldType,
918   });
919 
920   for (auto const& field : fields) {
921     auto start_field_offset = GetOffsetForField(field->GetName(), false);
922     auto end_field_offset = GetOffsetForField(field->GetName(), true);
923 
924     if (start_field_offset.empty() && end_field_offset.empty()) {
925       ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
926                    << "no method exists to determine field location from begin() or end().\n";
927     }
928 
929     field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_);
930     field->GenRustGetter(s, start_field_offset, end_field_offset);
931   }
932 
933   auto payload_field = fields_.GetFieldsWithTypes({
934     PayloadField::kFieldType,
935   });
936 
937   Size payload_offset;
938 
939   if (payload_field.HasPayload()) {
940     payload_offset = GetOffsetForField(payload_field[0]->GetName(), false);
941   }
942 
943   auto constraint_name = FindConstraintField();
944   auto constrained_descendants = FindDescendantsWithConstraint(constraint_name);
945 
946   if (children_.size() > 1) {
947     s << "let child = match " << constraint_name << " {";
948 
949     for (const auto& desc : constrained_descendants) {
950       auto desc_path = FindPathToDescendant(desc.first->name_);
951       std::reverse(desc_path.begin(), desc_path.end());
952       auto constraint_field = GetParamList().GetField(constraint_name);
953       auto constraint_type = constraint_field->GetFieldType();
954 
955       if (constraint_type == EnumField::kFieldType) {
956         auto type = std::get<std::string>(desc.second);
957         auto variant_name = type.substr(type.find("::") + 2, type.length());
958         auto enum_type = type.substr(0, type.find("::"));
959         auto enum_variant = enum_type + "::"
960             + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name));
961         s << enum_variant;
962         s << " if " << desc_path[0]->name_ << "Data::conforms(&bytes[..])";
963         s << " => {";
964         s << name_ << "DataChild::";
965         s << desc_path[0]->name_ << "(Arc::new(";
966         if (desc_path[0]->parent_constraints_.empty()) {
967           s << desc_path[0]->name_ << "Data::parse(&bytes[..]";
968           s << ", " << enum_variant << ")?))";
969         } else {
970           s << desc_path[0]->name_ << "Data::parse(&bytes[..])?))";
971         }
972       } else if (constraint_type == ScalarField::kFieldType) {
973         s << std::get<int64_t>(desc.second) << " => {";
974         s << "unimplemented!();";
975       }
976       s << "}\n";
977     }
978 
979     if (!constrained_descendants.empty()) {
980       s << "v => return Err(Error::ConstraintOutOfBounds{field: \"" << constraint_name
981         << "\".to_string(), value: v as u64}),";
982     }
983 
984     s << "};\n";
985   } else if (children_.size() == 1) {
986     auto child = children_.at(0);
987     s << "let child = match " << child->name_ << "Data::parse(&bytes[..]) {";
988     s << " Ok(c) if " << child->name_ << "Data::conforms(&bytes[..]) => {";
989     s << name_ << "DataChild::" << child->name_ << "(Arc::new(c))";
990     s << " },";
991     s << " Err(Error::InvalidLengthError { .. }) => " << name_ << "DataChild::None,";
992     s << " _ => return Err(Error::InvalidPacketError),";
993     s << "};";
994   } else if (fields_.HasPayload()) {
995     s << "let child = if payload.len() > 0 {";
996     s << name_ << "DataChild::Payload(Bytes::from(payload))";
997     s << "} else {";
998     s << name_ << "DataChild::None";
999     s << "};";
1000   }
1001 
1002   s << "Ok(Self {";
1003   fields = fields_.GetFieldsWithoutTypes({
1004       BodyField::kFieldType,
1005       CountField::kFieldType,
1006       PaddingField::kFieldType,
1007       ReservedField::kFieldType,
1008       SizeField::kFieldType,
1009       PayloadField::kFieldType,
1010       FixedScalarField::kFieldType,
1011   });
1012 
1013   if (fields.size() > 0) {
1014     for (const auto& field : fields) {
1015       auto field_type = field->GetFieldType();
1016       s << field->GetName();
1017       s << ", ";
1018     }
1019   }
1020 
1021   if (HasChildEnums()) {
1022     s << "child,";
1023   }
1024   s << "})\n";
1025   s << "}\n";
1026 
1027   // write_to function
1028   s << "fn write_to(&self, buffer: &mut BytesMut) {";
1029   GenRustWriteToFields(s);
1030 
1031   if (HasChildEnums()) {
1032     s << "match &self.child {";
1033     for (const auto& child : children_) {
1034       s << name_ << "DataChild::" << child->name_ << "(value) => value.write_to(buffer),";
1035     }
1036     if (fields_.HasPayload()) {
1037       auto offset = GetOffsetForField("payload");
1038       s << name_ << "DataChild::Payload(p) => buffer[" << offset.bytes() << "..].copy_from_slice(&p[..]),";
1039     }
1040     s << name_ << "DataChild::None => {}";
1041     s << "}";
1042   }
1043 
1044   s << "}\n";
1045 
1046   s << "fn get_total_size(&self) -> usize {";
1047   if (HasChildEnums()) {
1048     s << "self.get_size() + self.child.get_total_size()";
1049   } else {
1050     s << "self.get_size()";
1051   }
1052   s << "}\n";
1053 
1054   s << "fn get_size(&self) -> usize {";
1055   GenSizeRetVal(s);
1056   s << "}\n";
1057   s << "}\n";
1058 }
1059 
GenRustAccessStructImpls(std::ostream & s) const1060 void PacketDef::GenRustAccessStructImpls(std::ostream& s) const {
1061   if (complement_ != nullptr) {
1062     auto complement_root = complement_->GetRootDef();
1063     auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1064     s << "impl CommandExpectations for " << name_ << "Packet {";
1065     s << " type ResponseType = " << complement_->name_ << "Packet;";
1066     s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1067     s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())";
1068     s << " }";
1069     s << "}";
1070   }
1071 
1072   s << "impl Packet for " << name_ << "Packet {";
1073   auto root = GetRootDef();
1074   auto root_accessor = util::CamelCaseToUnderScore(root->name_);
1075 
1076   s << "fn to_bytes(self) -> Bytes {";
1077   s << " let mut buffer = BytesMut::new();";
1078   s << " buffer.resize(self." << root_accessor << ".get_total_size(), 0);";
1079   s << " self." << root_accessor << ".write_to(&mut buffer);";
1080   s << " buffer.freeze()";
1081   s << "}\n";
1082 
1083   s << "fn to_vec(self) -> Vec<u8> { self.to_bytes().to_vec() }\n";
1084   s << "}";
1085 
1086   s << "impl " << name_ << "Packet {";
1087   if (parent_ == nullptr) {
1088     s << "pub fn parse(bytes: &[u8]) -> Result<Self> { ";
1089     s << "Ok(Self::new(Arc::new(" << name_ << "Data::parse(bytes)?)))";
1090     s << "}";
1091   }
1092 
1093   if (HasChildEnums()) {
1094     s << " pub fn specialize(&self) -> " << name_ << "Child {";
1095     s << " match &self." << util::CamelCaseToUnderScore(name_) << ".child {";
1096     for (const auto& child : children_) {
1097       s << name_ << "DataChild::" << child->name_ << "(_) => " << name_ << "Child::" << child->name_ << "("
1098         << child->name_ << "Packet::new(self." << root_accessor << ".clone())),";
1099     }
1100     if (fields_.HasPayload()) {
1101       s << name_ << "DataChild::Payload(p) => " << name_ << "Child::Payload(p.clone()),";
1102     }
1103     s << name_ << "DataChild::None => " << name_ << "Child::None,";
1104     s << "}}";
1105   }
1106   auto lineage = GetAncestors();
1107   lineage.push_back(this);
1108   const ParentDef* prev = nullptr;
1109 
1110   s << " fn new(root: Arc<" << root->name_ << "Data>) -> Self {";
1111   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1112     auto def = *it;
1113     auto accessor_name = util::CamelCaseToUnderScore(def->name_);
1114     if (prev == nullptr) {
1115       s << "let " << accessor_name << " = root;";
1116     } else {
1117       s << "let " << accessor_name << " = match &" << util::CamelCaseToUnderScore(prev->name_) << ".child {";
1118       s << prev->name_ << "DataChild::" << def->name_ << "(value) => (*value).clone(),";
1119       s << "_ => panic!(\"inconsistent state - child was not " << def->name_ << "\"),";
1120       s << "};";
1121     }
1122     prev = def;
1123   }
1124   s << "Self {";
1125   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1126     auto def = *it;
1127     s << util::CamelCaseToUnderScore(def->name_) << ",";
1128   }
1129   s << "}}";
1130 
1131   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1132     auto def = *it;
1133     auto fields = def->fields_.GetFieldsWithoutTypes({
1134         BodyField::kFieldType,
1135         CountField::kFieldType,
1136         PaddingField::kFieldType,
1137         ReservedField::kFieldType,
1138         SizeField::kFieldType,
1139         PayloadField::kFieldType,
1140         FixedScalarField::kFieldType,
1141     });
1142 
1143     for (auto const& field : fields) {
1144       if (field->GetterIsByRef()) {
1145         s << "pub fn get_" << field->GetName() << "(&self) -> &" << field->GetRustDataType() << "{";
1146         s << " &self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1147         s << "}\n";
1148       } else {
1149         s << "pub fn get_" << field->GetName() << "(&self) -> " << field->GetRustDataType() << "{";
1150         s << " self." << util::CamelCaseToUnderScore(def->name_) << ".as_ref()." << field->GetName();
1151         s << "}\n";
1152       }
1153     }
1154   }
1155 
1156   s << "}\n";
1157 
1158   lineage = GetAncestors();
1159   for (auto it = lineage.begin(); it != lineage.end(); it++) {
1160     auto def = *it;
1161     s << "impl Into<" << def->name_ << "Packet> for " << name_ << "Packet {";
1162     s << " fn into(self) -> " << def->name_ << "Packet {";
1163     s << def->name_ << "Packet::new(self." << util::CamelCaseToUnderScore(root->name_) << ")";
1164     s << " }";
1165     s << "}\n";
1166   }
1167 }
1168 
GenRustBuilderStructImpls(std::ostream & s) const1169 void PacketDef::GenRustBuilderStructImpls(std::ostream& s) const {
1170   if (complement_ != nullptr) {
1171     auto complement_root = complement_->GetRootDef();
1172     auto complement_root_accessor = util::CamelCaseToUnderScore(complement_root->name_);
1173     s << "impl CommandExpectations for " << name_ << "Builder {";
1174     s << " type ResponseType = " << complement_->name_ << "Packet;";
1175     s << " fn _to_response_type(pkt: EventPacket) -> Self::ResponseType { ";
1176     s << complement_->name_ << "Packet::new(pkt." << complement_root_accessor << ".clone())";
1177     s << " }";
1178     s << "}";
1179   }
1180 
1181   s << "impl " << name_ << "Builder {";
1182   s << "pub fn build(self) -> " << name_ << "Packet {";
1183   auto lineage = GetAncestors();
1184   lineage.push_back(this);
1185   std::reverse(lineage.begin(), lineage.end());
1186 
1187   auto all_constraints = GetAllConstraints();
1188 
1189   const ParentDef* prev = nullptr;
1190   for (auto ancestor : lineage) {
1191     auto fields = ancestor->fields_.GetFieldsWithoutTypes({
1192         BodyField::kFieldType,
1193         CountField::kFieldType,
1194         PaddingField::kFieldType,
1195         ReservedField::kFieldType,
1196         SizeField::kFieldType,
1197         PayloadField::kFieldType,
1198         FixedScalarField::kFieldType,
1199     });
1200 
1201     auto accessor_name = util::CamelCaseToUnderScore(ancestor->name_);
1202     s << "let " << accessor_name << "= Arc::new(" << ancestor->name_ << "Data {";
1203     for (auto field : fields) {
1204       auto constraint = all_constraints.find(field->GetName());
1205       s << field->GetName() << ": ";
1206       if (constraint != all_constraints.end()) {
1207         if (field->GetFieldType() == ScalarField::kFieldType) {
1208           s << std::get<int64_t>(constraint->second);
1209         } else if (field->GetFieldType() == EnumField::kFieldType) {
1210           auto value = std::get<std::string>(constraint->second);
1211           auto constant = value.substr(value.find("::") + 2, std::string::npos);
1212           s << field->GetDataType() << "::" << util::ConstantCaseToCamelCase(constant);
1213           ;
1214         } else {
1215           ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
1216         }
1217       } else {
1218         s << "self." << field->GetName();
1219       }
1220       s << ", ";
1221     }
1222     if (ancestor->HasChildEnums()) {
1223       if (prev == nullptr) {
1224         if (ancestor->fields_.HasPayload()) {
1225           s << "child: match self.payload { ";
1226           s << "None => " << name_ << "DataChild::None,";
1227           s << "Some(bytes) => " << name_ << "DataChild::Payload(bytes),";
1228           s << "},";
1229         } else {
1230           s << "child: " << name_ << "DataChild::None,";
1231         }
1232       } else {
1233         s << "child: " << ancestor->name_ << "DataChild::" << prev->name_ << "("
1234           << util::CamelCaseToUnderScore(prev->name_) << "),";
1235       }
1236     }
1237     s << "});";
1238     prev = ancestor;
1239   }
1240 
1241   s << name_ << "Packet::new(" << util::CamelCaseToUnderScore(prev->name_) << ")";
1242   s << "}\n";
1243 
1244   s << "}\n";
1245   for (const auto ancestor : GetAncestors()) {
1246     s << "impl Into<" << ancestor->name_ << "Packet> for " << name_ << "Builder {";
1247     s << " fn into(self) -> " << ancestor->name_ << "Packet { self.build().into() }";
1248     s << "}\n";
1249   }
1250 }
1251 
GenRustBuilderTest(std::ostream & s) const1252 void PacketDef::GenRustBuilderTest(std::ostream& s) const {
1253   auto lineage = GetAncestors();
1254   lineage.push_back(this);
1255   if (!lineage.empty() && !test_cases_.empty()) {
1256     s << "macro_rules! " << util::CamelCaseToUnderScore(name_) << "_builder_tests { ";
1257     s << "($($name:ident: $byte_string:expr,)*) => {";
1258     s << "$(";
1259     s << "\n#[test]\n";
1260     s << "pub fn $name() { ";
1261     s << "let raw_bytes = $byte_string;";
1262     for (size_t i = 0; i < lineage.size(); i++) {
1263       s << "/* (" << i << ") */\n";
1264       if (i == 0) {
1265         s << "match " << lineage[i]->name_ << "Packet::parse(raw_bytes) {";
1266         s << "Ok(" << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1267         s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1268       } else if (i != lineage.size() - 1) {
1269         s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(";
1270         s << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet) => {";
1271         s << "match " << util::CamelCaseToUnderScore(lineage[i]->name_) << "_packet.specialize() {";
1272       } else {
1273         s << lineage[i - 1]->name_ << "Child::" << lineage[i]->name_ << "(packet) => {";
1274         s << "let rebuilder = " << lineage[i]->name_ << "Builder {";
1275         FieldList params = GetParamList();
1276         if (params.HasBody()) {
1277           ERROR() << "Packets with body fields can't be auto-tested.  Test a child.";
1278         }
1279         for (const auto param : params) {
1280           s << param->GetName() << " : packet.";
1281           if (param->GetFieldType() == VectorField::kFieldType) {
1282             s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1283           } else if (param->GetFieldType() == ArrayField::kFieldType) {
1284             const auto array_param = static_cast<const ArrayField*>(param);
1285             const auto element_field = array_param->GetElementField();
1286             if (element_field->GetFieldType() == StructField::kFieldType) {
1287               s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().to_vec(),";
1288             } else {
1289               s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1290             }
1291           } else if (param->GetFieldType() == StructField::kFieldType) {
1292             s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "().clone(),";
1293           } else {
1294             s << util::CamelCaseToUnderScore(param->GetGetterFunctionName()) << "(),";
1295           }
1296         }
1297         s << "};";
1298         s << "let rebuilder_base : " << lineage[0]->name_ << "Packet = rebuilder.into();";
1299         s << "let rebuilder_bytes : &[u8] = &rebuilder_base.to_bytes();";
1300         s << "assert_eq!(rebuilder_bytes, raw_bytes);";
1301         s << "}";
1302       }
1303     }
1304     for (size_t i = 1; i < lineage.size(); i++) {
1305       s << "_ => {";
1306       s << "println!(\"Couldn't parse " << util::CamelCaseToUnderScore(lineage[lineage.size() - i]->name_);
1307       s << "{:02x?}\", " << util::CamelCaseToUnderScore(lineage[lineage.size() - i - 1]->name_) << "_packet); ";
1308       s << "}}}";
1309     }
1310 
1311     s << ",";
1312     s << "Err(e) => panic!(\"could not parse " << lineage[0]->name_ << ": {:?} {:02x?}\", e, raw_bytes),";
1313     s << "}";
1314     s << "}";
1315     s << ")*";
1316     s << "}";
1317     s << "}";
1318 
1319     s << util::CamelCaseToUnderScore(name_) << "_builder_tests! { ";
1320     int number = 0;
1321     for (const auto& test_case : test_cases_) {
1322       s << util::CamelCaseToUnderScore(name_) << "_builder_test_";
1323       s << std::setfill('0') << std::setw(2) << number++ << ": ";
1324       s << "b\"" << test_case << "\",";
1325     }
1326     s << "}";
1327     s << "\n";
1328   }
1329 }
1330 
GenRustDef(std::ostream & s) const1331 void PacketDef::GenRustDef(std::ostream& s) const {
1332   GenRustChildEnums(s);
1333   GenRustStructDeclarations(s);
1334   GenRustStructImpls(s);
1335   GenRustAccessStructImpls(s);
1336   GenRustBuilderStructImpls(s);
1337   GenRustBuilderTest(s);
1338 }
1339