1 /*
2 * Copyright 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "packet_def.h"
18
#include <iomanip>
#include <list>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
22
23 #include "fields/all_fields.h"
24 #include "packet_dependency.h"
25 #include "util.h"
26
PacketDef(std::string name,FieldList fields)27 PacketDef::PacketDef(std::string name, FieldList fields) : ParentDef(name, fields, nullptr) {}
PacketDef(std::string name,FieldList fields,PacketDef * parent)28 PacketDef::PacketDef(std::string name, FieldList fields, PacketDef* parent) : ParentDef(name, fields, parent) {}
29
GetNewField(const std::string &,ParseLocation) const30 PacketField* PacketDef::GetNewField(const std::string&, ParseLocation) const {
31 return nullptr; // Packets can't be fields
32 }
33
GenParserDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const34 void PacketDef::GenParserDefinition(std::ostream& s, bool generate_fuzzing, bool generate_tests) const {
35 s << "class " << name_ << "View";
36 if (parent_ != nullptr) {
37 s << " : public " << parent_->name_ << "View {";
38 } else {
39 s << " : public PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> {";
40 }
41 s << " public:";
42
43 // Specialize function
44 if (parent_ != nullptr) {
45 s << "static " << name_ << "View Create(" << parent_->name_ << "View parent)";
46 s << "{ return " << name_ << "View(std::move(parent)); }";
47 // CreateOptional
48 s << "static std::optional<" << name_ << "View> CreateOptional(";
49 s << parent_->name_ << "View parent)";
50 s << "{ auto to_validate = " << name_ << "View::Create(std::move(parent));";
51 s << "if (to_validate.IsValid()) { return to_validate; }";
52 s << "else {return {};}}";
53 } else {
54 s << "static " << name_ << "View Create(PacketView<";
55 s << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet)";
56 s << "{ return " << name_ << "View(std::move(packet)); }";
57 // CreateOptional
58 s << "static std::optional<" << name_ << "View> CreateOptional(PacketView<";
59 s << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet)";
60 s << "{ auto to_validate = " << name_ << "View::Create(std::move(packet));";
61 s << "if (to_validate.IsValid()) { return to_validate; }";
62 s << "else {return {};}}";
63 }
64
65 if (generate_fuzzing || generate_tests) {
66 GenTestingParserFromBytes(s);
67 }
68
69 std::set<std::string> fixed_types = {
70 FixedScalarField::kFieldType,
71 FixedEnumField::kFieldType,
72 };
73
74 // Print all of the public fields which are all the fields minus the fixed fields.
75 const auto& public_fields = fields_.GetFieldsWithoutTypes(fixed_types);
76 bool has_fixed_fields = public_fields.size() != fields_.size();
77 for (const auto& field : public_fields) {
78 GenParserFieldGetter(s, field);
79 s << "\n";
80 }
81 GenValidator(s);
82 s << "\n";
83
84 s << " public:";
85 GenParserToString(s);
86 s << "\n";
87
88 s << " protected:\n";
89 // Constructor from a View
90 if (parent_ != nullptr) {
91 s << "explicit " << name_ << "View(" << parent_->name_ << "View parent)";
92 s << " : " << parent_->name_ << "View(std::move(parent)) { was_validated_ = false; }";
93 } else {
94 s << "explicit " << name_ << "View(PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> packet) ";
95 s << " : PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(packet) { was_validated_ = false;}";
96 }
97
98 // Print the private fields which are the fixed fields.
99 if (has_fixed_fields) {
100 const auto& private_fields = fields_.GetFieldsWithTypes(fixed_types);
101 s << " private:\n";
102 for (const auto& field : private_fields) {
103 GenParserFieldGetter(s, field);
104 s << "\n";
105 }
106 }
107 s << "};\n";
108 }
109
GenTestingParserFromBytes(std::ostream & s) const110 void PacketDef::GenTestingParserFromBytes(std::ostream& s) const {
111 s << "\n#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
112
113 s << "static " << name_ << "View FromBytes(std::vector<uint8_t> bytes) {";
114 s << "auto vec = std::make_shared<std::vector<uint8_t>>(bytes);";
115 s << "return " << name_ << "View::Create(";
116 auto ancestor_ptr = parent_;
117 size_t parent_parens = 0;
118 while (ancestor_ptr != nullptr) {
119 s << ancestor_ptr->name_ << "View::Create(";
120 parent_parens++;
121 ancestor_ptr = ancestor_ptr->parent_;
122 }
123 s << "PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>(vec)";
124 for (size_t i = 0; i < parent_parens; i++) {
125 s << ")";
126 }
127 s << ");";
128 s << "}";
129
130 s << "\n#endif\n";
131 }
132
GenParserDefinitionPybind11(std::ostream & s) const133 void PacketDef::GenParserDefinitionPybind11(std::ostream& s) const {
134 s << "py::class_<" << name_ << "View";
135 if (parent_ != nullptr) {
136 s << ", " << parent_->name_ << "View";
137 } else {
138 s << ", PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian>";
139 }
140 s << ">(m, \"" << name_ << "View\")";
141 if (parent_ != nullptr) {
142 s << ".def(py::init([](" << parent_->name_ << "View parent) {";
143 } else {
144 s << ".def(py::init([](PacketView<" << (is_little_endian_ ? "" : "!") << "kLittleEndian> parent) {";
145 }
146 s << "auto view =" << name_ << "View::Create(std::move(parent));";
147 s << "if (!view.IsValid()) { throw std::invalid_argument(\"Bad packet view\"); }";
148 s << "return view; }))";
149
150 s << ".def(py::init(&" << name_ << "View::Create))";
151 std::set<std::string> protected_field_types = {
152 FixedScalarField::kFieldType,
153 FixedEnumField::kFieldType,
154 SizeField::kFieldType,
155 CountField::kFieldType,
156 };
157 const auto& public_fields = fields_.GetFieldsWithoutTypes(protected_field_types);
158 for (const auto& field : public_fields) {
159 auto getter_func_name = field->GetGetterFunctionName();
160 if (getter_func_name.empty()) {
161 continue;
162 }
163 s << ".def(\"" << getter_func_name << "\", &" << name_ << "View::" << getter_func_name << ")";
164 }
165 s << ".def(\"IsValid\", &" << name_ << "View::IsValid)";
166 s << ";\n";
167 }
168
GenParserFieldGetter(std::ostream & s,const PacketField * field) const169 void PacketDef::GenParserFieldGetter(std::ostream& s, const PacketField* field) const {
170 // Start field offset
171 auto start_field_offset = GetOffsetForField(field->GetName(), false);
172 auto end_field_offset = GetOffsetForField(field->GetName(), true);
173
174 if (start_field_offset.empty() && end_field_offset.empty()) {
175 ERROR(field) << "Field location for " << field->GetName() << " is ambiguous, "
176 << "no method exists to determine field location from begin() or end().\n";
177 }
178
179 field->GenGetter(s, start_field_offset, end_field_offset);
180 }
181
GetDefinitionType() const182 TypeDef::Type PacketDef::GetDefinitionType() const {
183 return TypeDef::Type::PACKET;
184 }
185
GenValidator(std::ostream & s) const186 void PacketDef::GenValidator(std::ostream& s) const {
187 // Get the static offset for all of our fields.
188 int bits_size = 0;
189 for (const auto& field : fields_) {
190 if (field->GetFieldType() != PaddingField::kFieldType) {
191 bits_size += field->GetSize().bits();
192 }
193 }
194
195 // Generate the public validator IsValid().
196 // The method only needs to be generated for the top most class.
197 if (parent_ == nullptr) {
198 s << "bool IsValid() {" << std::endl;
199 s << " if (was_validated_) {" << std::endl;
200 s << " return true;" << std::endl;
201 s << " } else {" << std::endl;
202 s << " was_validated_ = true;" << std::endl;
203 s << " return (was_validated_ = Validate());" << std::endl;
204 s << " }" << std::endl;
205 s << "}" << std::endl;
206 }
207
208 // Generate the private validator Validate().
209 // The method is overridden by all child classes.
210 s << "protected:" << std::endl;
211 if (parent_ == nullptr) {
212 s << "virtual bool Validate() const {" << std::endl;
213 } else {
214 s << "bool Validate() const override {" << std::endl;
215 s << " if (!" << parent_->name_ << "View::Validate()) {" << std::endl;
216 s << " return false;" << std::endl;
217 s << " }" << std::endl;
218 }
219
220 // Offset by the parents known size. We know that any dynamic fields can
221 // already be called since the parent must have already been validated by
222 // this point.
223 auto parent_size = Size(0);
224 if (parent_ != nullptr) {
225 parent_size = parent_->GetSize(true);
226 }
227
228 s << "auto it = begin() + (" << parent_size << ") / 8;";
229
230 // Check if you can extract the static fields.
231 // At this point you know you can use the size getters without crashing
232 // as long as they follow the instruction that size fields cant come before
233 // their corrisponding variable length field.
234 s << "it += " << ((bits_size + 7) / 8) << " /* Total size of the fixed fields */;";
235 s << "if (it > end()) return false;";
236
237 // For any variable length fields, use their size check.
238 for (const auto& field : fields_) {
239 if (field->GetFieldType() == ChecksumStartField::kFieldType) {
240 auto offset = GetOffsetForField(field->GetName(), false);
241 if (!offset.empty()) {
242 s << "size_t sum_index = (" << offset << ") / 8;";
243 } else {
244 offset = GetOffsetForField(field->GetName(), true);
245 if (offset.empty()) {
246 ERROR(field) << "Checksum Start Field offset can not be determined.";
247 }
248 s << "size_t sum_index = size() - (" << offset << ") / 8;";
249 }
250
251 const auto& field_name = ((ChecksumStartField*)field)->GetStartedFieldName();
252 const auto& started_field = fields_.GetField(field_name);
253 if (started_field == nullptr) {
254 ERROR(field) << __func__ << ": Can't find checksum field named " << field_name << "(" << field->GetName()
255 << ")";
256 }
257 auto end_offset = GetOffsetForField(started_field->GetName(), false);
258 if (!end_offset.empty()) {
259 s << "size_t end_sum_index = (" << end_offset << ") / 8;";
260 } else {
261 end_offset = GetOffsetForField(started_field->GetName(), true);
262 if (end_offset.empty()) {
263 ERROR(started_field) << "Checksum Field end_offset can not be determined.";
264 }
265 s << "size_t end_sum_index = size() - (" << started_field->GetSize() << " - " << end_offset << ") / 8;";
266 }
267 s << "if (end_sum_index >= size()) { return false; }";
268 if (is_little_endian_) {
269 s << "auto checksum_view = GetLittleEndianSubview(sum_index, end_sum_index);";
270 } else {
271 s << "auto checksum_view = GetBigEndianSubview(sum_index, end_sum_index);";
272 }
273 s << started_field->GetDataType() << " checksum;";
274 s << "checksum.Initialize();";
275 s << "for (uint8_t byte : checksum_view) { ";
276 s << "checksum.AddByte(byte);}";
277 s << "if (checksum.GetChecksum() != (begin() + end_sum_index).extract<"
278 << util::GetTypeForSize(started_field->GetSize().bits()) << ">()) { return false; }";
279
280 continue;
281 }
282
283 auto field_size = field->GetSize();
284 // Fixed size fields have already been handled.
285 if (!field_size.has_dynamic()) {
286 continue;
287 }
288
289 // Custom fields with dynamic size must have the offset for the field passed in as well
290 // as the end iterator so that they may ensure that they don't try to read past the end.
291 // Custom fields with fixed sizes will be handled in the static offset checking.
292 if (field->GetFieldType() == CustomField::kFieldType) {
293 // Check if we can determine offset from begin(), otherwise error because by this point,
294 // the size of the custom field is unknown and can't be subtracted from end() to get the
295 // offset.
296 auto offset = GetOffsetForField(field->GetName(), false);
297 if (offset.empty()) {
298 ERROR(field) << "Custom Field offset can not be determined from begin().";
299 }
300
301 if (offset.bits() % 8 != 0) {
302 ERROR(field) << "Custom fields must be byte aligned.";
303 }
304
305 // Custom fields are special as their size field takes an argument.
306 const auto& custom_size_var = field->GetName() + "_size";
307 s << "const auto& " << custom_size_var << " = " << field_size.dynamic_string();
308 s << "(begin() + (" << offset << ") / 8);";
309
310 s << "if (!" << custom_size_var << ".has_value()) { return false; }";
311 s << "it += *" << custom_size_var << ";";
312 s << "if (it > end()) return false;";
313 continue;
314 } else {
315 s << "it += (" << field_size.dynamic_string() << ") / 8;";
316 s << "if (it > end()) return false;";
317 }
318 }
319
320 // Validate constraints after validating the size
321 if (parent_constraints_.size() > 0 && parent_ == nullptr) {
322 ERROR() << "Can't have a constraint on a NULL parent";
323 }
324
325 for (const auto& constraint : parent_constraints_) {
326 s << "if (Get" << util::UnderscoreToCamelCase(constraint.first) << "() != ";
327 const auto& field = parent_->GetParamList().GetField(constraint.first);
328 if (field->GetFieldType() == ScalarField::kFieldType) {
329 s << std::get<int64_t>(constraint.second);
330 } else {
331 s << std::get<std::string>(constraint.second);
332 }
333 s << ") return false;";
334 }
335
336 // Validate the packets fields last
337 for (const auto& field : fields_) {
338 field->GenValidator(s);
339 s << "\n";
340 }
341
342 s << "return true;";
343 s << "}\n";
344 if (parent_ == nullptr) {
345 s << "bool was_validated_{false};\n";
346 }
347 }
348
GenParserToString(std::ostream & s) const349 void PacketDef::GenParserToString(std::ostream& s) const {
350 s << "virtual std::string ToString() const " << (parent_ != nullptr ? " override" : "") << " {";
351 s << "std::stringstream ss;";
352 s << "ss << std::showbase << std::hex << \"" << name_ << " { \";";
353
354 if (fields_.size() > 0) {
355 s << "ss << \"\" ";
356 bool firstfield = true;
357 for (const auto& field : fields_) {
358 if (field->GetFieldType() == ReservedField::kFieldType || field->GetFieldType() == FixedScalarField::kFieldType ||
359 field->GetFieldType() == ChecksumStartField::kFieldType)
360 continue;
361
362 s << (firstfield ? " << \"" : " << \", ") << field->GetName() << " = \" << ";
363
364 field->GenStringRepresentation(s, field->GetGetterFunctionName() + "()");
365
366 if (firstfield) {
367 firstfield = false;
368 }
369 }
370 s << ";";
371 }
372
373 s << "ss << \" }\";";
374 s << "return ss.str();";
375 s << "}\n";
376 }
377
GenBuilderDefinition(std::ostream & s,bool generate_fuzzing,bool generate_tests) const378 void PacketDef::GenBuilderDefinition(std::ostream& s, bool generate_fuzzing, bool generate_tests) const {
379 s << "class " << name_ << "Builder";
380 if (parent_ != nullptr) {
381 s << " : public " << parent_->name_ << "Builder";
382 } else {
383 if (is_little_endian_) {
384 s << " : public PacketBuilder<kLittleEndian>";
385 } else {
386 s << " : public PacketBuilder<!kLittleEndian>";
387 }
388 }
389 s << " {";
390 s << " public:";
391 s << " virtual ~" << name_ << "Builder() = default;";
392
393 if (!fields_.HasBody()) {
394 GenBuilderCreate(s);
395 s << "\n";
396
397 if (generate_fuzzing || generate_tests) {
398 GenTestingFromView(s);
399 s << "\n";
400 }
401 }
402
403 GenSerialize(s);
404 s << "\n";
405
406 GenSize(s);
407 s << "\n";
408
409 s << " protected:\n";
410 GenBuilderConstructor(s);
411 s << "\n";
412
413 GenBuilderParameterChecker(s);
414 s << "\n";
415
416 GenMembers(s);
417 s << "};\n";
418
419 if (generate_tests) {
420 GenTestDefine(s);
421 s << "\n";
422 }
423
424 if (generate_fuzzing || generate_tests) {
425 GenReflectTestDefine(s);
426 s << "\n";
427 }
428
429 if (generate_fuzzing) {
430 GenFuzzTestDefine(s);
431 s << "\n";
432 }
433 }
434
GenTestingFromView(std::ostream & s) const435 void PacketDef::GenTestingFromView(std::ostream& s) const {
436 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING) || defined(FUZZ_TARGET)\n";
437
438 s << "static std::unique_ptr<" << name_ << "Builder> FromView(" << name_ << "View view) {";
439 s << "if (!view.IsValid()) return nullptr;";
440 s << "return " << name_ << "Builder::Create(";
441 FieldList params = GetParamList().GetFieldsWithoutTypes({
442 BodyField::kFieldType,
443 });
444 for (std::size_t i = 0; i < params.size(); i++) {
445 params[i]->GenBuilderParameterFromView(s);
446 if (i != params.size() - 1) {
447 s << ", ";
448 }
449 }
450 s << ");";
451 s << "}";
452
453 s << "\n#endif\n";
454 }
455
GenBuilderDefinitionPybind11(std::ostream & s) const456 void PacketDef::GenBuilderDefinitionPybind11(std::ostream& s) const {
457 s << "py::class_<" << name_ << "Builder";
458 if (parent_ != nullptr) {
459 s << ", " << parent_->name_ << "Builder";
460 } else {
461 if (is_little_endian_) {
462 s << ", PacketBuilder<kLittleEndian>";
463 } else {
464 s << ", PacketBuilder<!kLittleEndian>";
465 }
466 }
467 s << ", std::shared_ptr<" << name_ << "Builder>";
468 s << ">(m, \"" << name_ << "Builder\")";
469 if (!fields_.HasBody()) {
470 GenBuilderCreatePybind11(s);
471 }
472 s << ".def(\"Serialize\", [](" << name_ << "Builder& builder){";
473 s << "std::vector<uint8_t> bytes;";
474 s << "BitInserter bi(bytes);";
475 s << "builder.Serialize(bi);";
476 s << "return bytes;})";
477 s << ";\n";
478 }
479
GenTestDefine(std::ostream & s) const480 void PacketDef::GenTestDefine(std::ostream& s) const {
481 s << "#ifdef PACKET_TESTING\n";
482 s << "#define DEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(...)";
483 s << "class " << name_ << "ReflectionTest : public testing::TestWithParam<std::vector<uint8_t>> { ";
484 s << "public: ";
485 s << "void CompareBytes(std::vector<uint8_t> captured_packet) {";
486 s << name_ << "View view = " << name_ << "View::FromBytes(captured_packet);";
487 s << "if (!view.IsValid()) { log::info(\"Invalid Packet Bytes (size = {})\", view.size());";
488 s << "for (size_t i = 0; i < view.size(); i++) { log::info(\"{:5}:{:02x}\", i, *(view.begin() + "
489 "i)); }}";
490 s << "ASSERT_TRUE(view.IsValid());";
491 s << "auto packet = " << name_ << "Builder::FromView(view);";
492 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
493 s << "packet_bytes->reserve(packet->size());";
494 s << "BitInserter it(*packet_bytes);";
495 s << "packet->Serialize(it);";
496 s << "ASSERT_EQ(*packet_bytes, captured_packet);";
497 s << "}";
498 s << "};";
499 s << "TEST_P(" << name_ << "ReflectionTest, generatedReflectionTest) {";
500 s << "CompareBytes(GetParam());";
501 s << "}";
502 s << "INSTANTIATE_TEST_SUITE_P(" << name_ << "_reflection, ";
503 s << name_ << "ReflectionTest, testing::Values(__VA_ARGS__))";
504 int i = 0;
505 for (const auto& bytes : test_cases_) {
506 s << "\nuint8_t " << name_ << "_test_bytes_" << i << "[] = \"" << bytes << "\";";
507 s << "std::vector<uint8_t> " << name_ << "_test_vec_" << i << "(";
508 s << name_ << "_test_bytes_" << i << ",";
509 s << name_ << "_test_bytes_" << i << " + sizeof(";
510 s << name_ << "_test_bytes_" << i << ") - 1);";
511 i++;
512 }
513 if (!test_cases_.empty()) {
514 i = 0;
515 s << "\nDEFINE_AND_INSTANTIATE_" << name_ << "ReflectionTest(";
516 for (auto bytes : test_cases_) {
517 if (i > 0) {
518 s << ",";
519 }
520 s << name_ << "_test_vec_" << i++;
521 }
522 s << ");";
523 }
524 s << "\n#endif";
525 }
526
GenReflectTestDefine(std::ostream & s) const527 void PacketDef::GenReflectTestDefine(std::ostream& s) const {
528 s << "#if defined(PACKET_FUZZ_TESTING) || defined(PACKET_TESTING)\n";
529 s << "#define DEFINE_" << name_ << "ReflectionFuzzTest() ";
530 s << "void Run" << name_ << "ReflectionFuzzTest(const uint8_t* data, size_t size) {";
531 s << "auto vec = std::vector<uint8_t>(data, data + size);";
532 s << name_ << "View view = " << name_ << "View::FromBytes(vec);";
533 s << "if (!view.IsValid()) { return; }";
534 s << "auto packet = " << name_ << "Builder::FromView(view);";
535 s << "std::shared_ptr<std::vector<uint8_t>> packet_bytes = std::make_shared<std::vector<uint8_t>>();";
536 s << "packet_bytes->reserve(packet->size());";
537 s << "BitInserter it(*packet_bytes);";
538 s << "packet->Serialize(it);";
539 s << "}";
540 s << "\n#endif\n";
541 }
542
GenFuzzTestDefine(std::ostream & s) const543 void PacketDef::GenFuzzTestDefine(std::ostream& s) const {
544 s << "#ifdef PACKET_FUZZ_TESTING\n";
545 s << "#define DEFINE_AND_REGISTER_" << name_ << "ReflectionFuzzTest(REGISTRY) ";
546 s << "DEFINE_" << name_ << "ReflectionFuzzTest();";
547 s << " class " << name_ << "ReflectionFuzzTestRegistrant {";
548 s << "public: ";
549 s << "explicit " << name_
550 << "ReflectionFuzzTestRegistrant(std::vector<void(*)(const uint8_t*, size_t)>& fuzz_test_registry) {";
551 s << "fuzz_test_registry.push_back(Run" << name_ << "ReflectionFuzzTest);";
552 s << "}}; ";
553 s << name_ << "ReflectionFuzzTestRegistrant " << name_ << "_reflection_fuzz_test_registrant(REGISTRY);";
554 s << "\n#endif";
555 }
556
GetParametersToValidate() const557 FieldList PacketDef::GetParametersToValidate() const {
558 FieldList params_to_validate;
559 for (const auto& field : GetParamList()) {
560 if (field->HasParameterValidator()) {
561 params_to_validate.AppendField(field);
562 }
563 }
564 return params_to_validate;
565 }
566
GenBuilderCreate(std::ostream & s) const567 void PacketDef::GenBuilderCreate(std::ostream& s) const {
568 s << "static std::unique_ptr<" << name_ << "Builder> Create(";
569
570 auto params = GetParamList();
571 for (std::size_t i = 0; i < params.size(); i++) {
572 params[i]->GenBuilderParameter(s);
573 if (i != params.size() - 1) {
574 s << ", ";
575 }
576 }
577 s << ") {";
578
579 // Call the constructor
580 s << "auto builder = std::unique_ptr<" << name_ << "Builder>(new " << name_ << "Builder(";
581
582 params = params.GetFieldsWithoutTypes({
583 PayloadField::kFieldType,
584 BodyField::kFieldType,
585 });
586 // Add the parameters.
587 for (std::size_t i = 0; i < params.size(); i++) {
588 if (params[i]->BuilderParameterMustBeMoved()) {
589 s << "std::move(" << params[i]->GetName() << ")";
590 } else {
591 s << params[i]->GetName();
592 }
593 if (i != params.size() - 1) {
594 s << ", ";
595 }
596 }
597
598 s << "));";
599 if (fields_.HasPayload()) {
600 s << "builder->payload_ = std::move(payload);";
601 }
602 s << "return builder;";
603 s << "}\n";
604 }
605
GenBuilderCreatePybind11(std::ostream & s) const606 void PacketDef::GenBuilderCreatePybind11(std::ostream& s) const {
607 s << ".def(py::init([](";
608 auto params = GetParamList();
609 std::vector<std::string> constructor_args;
610 for (const auto& param : params) {
611 std::stringstream ss;
612 auto param_type = param->GetBuilderParameterType();
613 if (param_type.empty()) {
614 continue;
615 }
616 // Use shared_ptr instead of unique_ptr for the Python interface
617 if (param->BuilderParameterMustBeMoved()) {
618 param_type = util::StringFindAndReplaceAll(param_type, "unique_ptr", "shared_ptr");
619 }
620 ss << param_type << " " << param->GetName();
621 constructor_args.push_back(ss.str());
622 }
623 s << util::StringJoin(",", constructor_args) << "){";
624
625 // Deal with move only args
626 for (const auto& param : params) {
627 std::stringstream ss;
628 auto param_type = param->GetBuilderParameterType();
629 if (param_type.empty()) {
630 continue;
631 }
632 if (!param->BuilderParameterMustBeMoved()) {
633 continue;
634 }
635 auto move_only_param_name = param->GetName() + "_move_only";
636 s << param_type << " " << move_only_param_name << ";";
637 if (param->IsContainerField()) {
638 // Assume single layer container and copy it
639 auto struct_type = param->GetElementField()->GetDataType();
640 struct_type = util::StringFindAndReplaceAll(struct_type, "std::unique_ptr<", "");
641 struct_type = util::StringFindAndReplaceAll(struct_type, ">", "");
642 s << "for (size_t i = 0; i < " << param->GetName() << ".size(); i++) {";
643 // Serialize each struct
644 s << "auto " << param->GetName() + "_bytes = std::make_shared<std::vector<uint8_t>>();";
645 s << param->GetName() + "_bytes->reserve(" << param->GetName() << "[i]->size());";
646 s << "BitInserter " << param->GetName() + "_bi(*" << param->GetName() << "_bytes);";
647 s << param->GetName() << "[i]->Serialize(" << param->GetName() << "_bi);";
648 // Parse it again
649 s << "auto " << param->GetName() << "_view = PacketView<kLittleEndian>(" << param->GetName() << "_bytes);";
650 s << param->GetElementField()->GetDataType() << " " << param->GetName() << "_reparsed = ";
651 s << "Parse" << struct_type << "(" << param->GetName() + "_view.begin());";
652 // Push it into a new container
653 if (param->GetFieldType() == VectorField::kFieldType) {
654 s << move_only_param_name << ".push_back(std::move(" << param->GetName() + "_reparsed));";
655 } else if (param->GetFieldType() == ArrayField::kFieldType) {
656 s << move_only_param_name << "[i] = std::move(" << param->GetName() << "_reparsed);";
657 } else {
658 ERROR() << param << " is not supported by Pybind11";
659 }
660 s << "}";
661 } else {
662 // Serialize the parameter and pass the bytes in a RawBuilder
663 s << "std::vector<uint8_t> " << param->GetName() + "_bytes;";
664 s << param->GetName() + "_bytes.reserve(" << param->GetName() << "->size());";
665 s << "BitInserter " << param->GetName() + "_bi(" << param->GetName() << "_bytes);";
666 s << param->GetName() << "->Serialize(" << param->GetName() + "_bi);";
667 s << move_only_param_name << " = ";
668 s << "std::make_unique<RawBuilder>(" << param->GetName() << "_bytes);";
669 }
670 }
671 s << "return " << name_ << "Builder::Create(";
672 std::vector<std::string> builder_vars;
673 for (const auto& param : params) {
674 std::stringstream ss;
675 auto param_type = param->GetBuilderParameterType();
676 if (param_type.empty()) {
677 continue;
678 }
679 auto param_name = param->GetName();
680 if (param->BuilderParameterMustBeMoved()) {
681 ss << "std::move(" << param_name << "_move_only)";
682 } else {
683 ss << param_name;
684 }
685 builder_vars.push_back(ss.str());
686 }
687 s << util::StringJoin(",", builder_vars) << ");}";
688 s << "))";
689 }
690
GenBuilderParameterChecker(std::ostream & s) const691 void PacketDef::GenBuilderParameterChecker(std::ostream& s) const {
692 FieldList params_to_validate = GetParametersToValidate();
693
694 // Skip writing this function if there is nothing to validate.
695 if (params_to_validate.size() == 0) {
696 return;
697 }
698
699 // Generate function arguments.
700 s << "void CheckParameterValues(";
701 for (std::size_t i = 0; i < params_to_validate.size(); i++) {
702 params_to_validate[i]->GenBuilderParameter(s);
703 if (i != params_to_validate.size() - 1) {
704 s << ", ";
705 }
706 }
707 s << ") {";
708
709 // Check the parameters.
710 for (const auto& field : params_to_validate) {
711 field->GenParameterValidator(s);
712 }
713 s << "}\n";
714 }
715
GenBuilderConstructor(std::ostream & s) const716 void PacketDef::GenBuilderConstructor(std::ostream& s) const {
717 s << "explicit " << name_ << "Builder(";
718
719 // Generate the constructor parameters.
720 auto params = GetParamList().GetFieldsWithoutTypes({
721 PayloadField::kFieldType,
722 BodyField::kFieldType,
723 });
724 for (std::size_t i = 0; i < params.size(); i++) {
725 params[i]->GenBuilderParameter(s);
726 if (i != params.size() - 1) {
727 s << ", ";
728 }
729 }
730 if (params.size() > 0 || parent_constraints_.size() > 0) {
731 s << ") :";
732 } else {
733 s << ")";
734 }
735
736 // Get the list of parent params to call the parent constructor with.
737 FieldList parent_params;
738 if (parent_ != nullptr) {
739 // Pass parameters to the parent constructor
740 s << parent_->name_ << "Builder(";
741 parent_params = parent_->GetParamList().GetFieldsWithoutTypes({
742 PayloadField::kFieldType,
743 BodyField::kFieldType,
744 });
745
746 // Go through all the fields and replace constrained fields with fixed values
747 // when calling the parent constructor.
748 for (std::size_t i = 0; i < parent_params.size(); i++) {
749 const auto& field = parent_params[i];
750 const auto& constraint = parent_constraints_.find(field->GetName());
751 if (constraint != parent_constraints_.end()) {
752 if (field->GetFieldType() == ScalarField::kFieldType) {
753 s << std::get<int64_t>(constraint->second);
754 } else if (field->GetFieldType() == EnumField::kFieldType) {
755 s << std::get<std::string>(constraint->second);
756 } else {
757 ERROR(field) << "Constraints on non enum/scalar fields should be impossible.";
758 }
759
760 s << "/* " << field->GetName() << "_ */";
761 } else {
762 s << field->GetName();
763 }
764
765 if (i != parent_params.size() - 1) {
766 s << ", ";
767 }
768 }
769 s << ") ";
770 }
771
772 // Build a list of parameters that excludes all parent parameters.
773 FieldList saved_params;
774 for (const auto& field : params) {
775 if (parent_params.GetField(field->GetName()) == nullptr) {
776 saved_params.AppendField(field);
777 }
778 }
779 if (parent_ != nullptr && saved_params.size() > 0) {
780 s << ",";
781 }
782 for (std::size_t i = 0; i < saved_params.size(); i++) {
783 const auto& saved_param_name = saved_params[i]->GetName();
784 if (saved_params[i]->BuilderParameterMustBeMoved()) {
785 s << saved_param_name << "_(std::move(" << saved_param_name << "))";
786 } else {
787 s << saved_param_name << "_(" << saved_param_name << ")";
788 }
789 if (i != saved_params.size() - 1) {
790 s << ",";
791 }
792 }
793 s << " {";
794
795 FieldList params_to_validate = GetParametersToValidate();
796
797 if (params_to_validate.size() > 0) {
798 s << "CheckParameterValues(";
799 for (std::size_t i = 0; i < params_to_validate.size(); i++) {
800 s << params_to_validate[i]->GetName() << "_";
801 if (i != params_to_validate.size() - 1) {
802 s << ", ";
803 }
804 }
805 s << ");";
806 }
807
808 s << "}\n";
809 }
810