1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdio.h>
18 #include <stdlib.h>
19 
20 #include <fstream>
21 #include <iostream>
22 #include <map>
23 #include <set>
24 #include <stack>
25 #include <vector>
26 
27 #include <google/protobuf/compiler/code_generator.h>
28 #include <google/protobuf/compiler/importer.h>
29 #include <google/protobuf/compiler/plugin.h>
30 #include <google/protobuf/dynamic_message.h>
31 #include <google/protobuf/io/printer.h>
32 #include <google/protobuf/io/zero_copy_stream_impl.h>
33 #include <google/protobuf/util/field_comparator.h>
34 #include <google/protobuf/util/message_differencer.h>
35 
36 #include "perfetto/ext/base/string_utils.h"
37 
38 namespace protozero {
39 namespace {
40 
41 using namespace google::protobuf;
42 using namespace google::protobuf::compiler;
43 using namespace google::protobuf::io;
44 using perfetto::base::SplitString;
45 using perfetto::base::StripChars;
46 using perfetto::base::StripSuffix;
47 using perfetto::base::ToUpper;
48 
49 static constexpr auto TYPE_MESSAGE = FieldDescriptor::TYPE_MESSAGE;
50 static constexpr auto TYPE_SINT32 = FieldDescriptor::TYPE_SINT32;
51 static constexpr auto TYPE_SINT64 = FieldDescriptor::TYPE_SINT64;
52 
53 static const char kHeader[] =
54     "// DO NOT EDIT. Autogenerated by Perfetto cppgen_plugin\n";
55 
56 class CppObjGenerator : public ::google::protobuf::compiler::CodeGenerator {
57  public:
58   CppObjGenerator();
59   ~CppObjGenerator() override;
60 
61   // CodeGenerator implementation
62   bool Generate(const google::protobuf::FileDescriptor* file,
63                 const std::string& options,
64                 GeneratorContext* context,
65                 std::string* error) const override;
66 
67  private:
68   std::string GetCppType(const FieldDescriptor* field, bool constref) const;
69   std::string GetProtozeroSetter(const FieldDescriptor* field) const;
70   std::string GetPackedBuffer(const FieldDescriptor* field) const;
71   std::string GetPackedWireType(const FieldDescriptor* field) const;
72 
73   void GenEnum(const EnumDescriptor*, Printer*) const;
74   void GenEnumAliases(const EnumDescriptor*, Printer*) const;
75   void GenClassDecl(const Descriptor*, Printer*) const;
76   void GenClassDef(const Descriptor*, Printer*) const;
77 
GetNamespaces(const FileDescriptor * file) const78   std::vector<std::string> GetNamespaces(const FileDescriptor* file) const {
79     std::string pkg = file->package() + wrapper_namespace_;
80     return SplitString(pkg, ".");
81   }
82 
83   template <typename T = Descriptor>
GetFullName(const T * msg,bool with_namespace=false) const84   std::string GetFullName(const T* msg, bool with_namespace = false) const {
85     std::string full_type;
86     full_type.append(msg->name());
87     for (const Descriptor* par = msg->containing_type(); par;
88          par = par->containing_type()) {
89       full_type.insert(0, par->name() + "_");
90     }
91     if (with_namespace) {
92       std::string prefix;
93       for (const std::string& ns : GetNamespaces(msg->file())) {
94         prefix += ns + "::";
95       }
96       full_type = prefix + full_type;
97     }
98     return full_type;
99   }
100 
101   mutable std::string wrapper_namespace_;
102 };
103 
// The constructor and destructor are declared in the class and defaulted
// here, outside of the class body.
CppObjGenerator::CppObjGenerator() = default;
CppObjGenerator::~CppObjGenerator() = default;
106 
Generate(const google::protobuf::FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const107 bool CppObjGenerator::Generate(const google::protobuf::FileDescriptor* file,
108                                const std::string& options,
109                                GeneratorContext* context,
110                                std::string* error) const {
111   for (const std::string& option : SplitString(options, ",")) {
112     std::vector<std::string> option_pair = SplitString(option, "=");
113     if (option_pair[0] == "wrapper_namespace") {
114       wrapper_namespace_ =
115           option_pair.size() == 2 ? "." + option_pair[1] : std::string();
116     } else {
117       *error = "Unknown plugin option: " + option_pair[0];
118       return false;
119     }
120   }
121 
122   auto get_file_name = [](const FileDescriptor* proto) {
123     return StripSuffix(proto->name(), ".proto") + ".gen";
124   };
125 
126   const std::unique_ptr<ZeroCopyOutputStream> h_fstream(
127       context->Open(get_file_name(file) + ".h"));
128   const std::unique_ptr<ZeroCopyOutputStream> cc_fstream(
129       context->Open(get_file_name(file) + ".cc"));
130 
131   // Variables are delimited by $.
132   Printer h_printer(h_fstream.get(), '$');
133   Printer cc_printer(cc_fstream.get(), '$');
134 
135   std::string include_guard = file->package() + "_" + file->name() + "_CPP_H_";
136   include_guard = ToUpper(include_guard);
137   include_guard = StripChars(include_guard, ".-/\\", '_');
138 
139   h_printer.Print(kHeader);
140   h_printer.Print("#ifndef $g$\n#define $g$\n\n", "g", include_guard);
141   h_printer.Print("#include <stdint.h>\n");
142   h_printer.Print("#include <bitset>\n");
143   h_printer.Print("#include <vector>\n");
144   h_printer.Print("#include <string>\n");
145   h_printer.Print("#include <type_traits>\n\n");
146   h_printer.Print("#include \"perfetto/protozero/cpp_message_obj.h\"\n");
147   h_printer.Print("#include \"perfetto/protozero/copyable_ptr.h\"\n");
148   h_printer.Print("#include \"perfetto/base/export.h\"\n\n");
149 
150   cc_printer.Print("#include \"perfetto/protozero/message.h\"\n");
151   cc_printer.Print(
152       "#include \"perfetto/protozero/packed_repeated_fields.h\"\n");
153   cc_printer.Print("#include \"perfetto/protozero/proto_decoder.h\"\n");
154   cc_printer.Print("#include \"perfetto/protozero/scattered_heap_buffer.h\"\n");
155   cc_printer.Print(kHeader);
156   cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
157   cc_printer.Print("#pragma GCC diagnostic push\n");
158   cc_printer.Print("#pragma GCC diagnostic ignored \"-Wfloat-equal\"\n");
159   cc_printer.Print("#endif\n");
160 
161   // Generate includes for translated types of dependencies.
162 
163   // Figure out the subset of imports that are used only for lazy fields. We
164   // won't emit a C++ #include for them. This code is overly aggressive at
165   // removing imports: it rules them out as soon as it sees one lazy field
166   // whose type is defined in that import. A 100% correct solution would require
167   // to check that *all* dependent types for a given import are lazy before
168   // excluding that. In practice we don't need that because we don't use imports
169   // for both lazy and non-lazy fields.
170   std::set<std::string> lazy_imports;
171   for (int m = 0; m < file->message_type_count(); m++) {
172     const Descriptor* msg = file->message_type(m);
173     for (int i = 0; i < msg->field_count(); i++) {
174       const FieldDescriptor* field = msg->field(i);
175       if (field->options().lazy()) {
176         lazy_imports.insert(field->message_type()->file()->name());
177       }
178     }
179   }
180 
181   // Recursively traverse all imports and turn them into #include(s).
182   std::vector<const FileDescriptor*> imports_to_visit;
183   std::set<const FileDescriptor*> imports_visited;
184   imports_to_visit.push_back(file);
185 
186   while (!imports_to_visit.empty()) {
187     const FileDescriptor* cur = imports_to_visit.back();
188     imports_to_visit.pop_back();
189     imports_visited.insert(cur);
190     std::string base_name = StripSuffix(cur->name(), ".proto");
191     cc_printer.Print("#include \"$f$.gen.h\"\n", "f", base_name);
192     for (int i = 0; i < cur->dependency_count(); i++) {
193       const FileDescriptor* dep = cur->dependency(i);
194       if (imports_visited.count(dep) || lazy_imports.count(dep->name()))
195         continue;
196       imports_to_visit.push_back(dep);
197     }
198   }
199 
200   // Compute all nested types to generate forward declarations later.
201 
202   std::set<const Descriptor*> all_types_seen;  // All deps
203   std::set<const EnumDescriptor*> all_enums_seen;
204 
205   // We track the types additionally in vectors to guarantee a stable order in
206   // the generated output.
207   std::vector<const Descriptor*> local_types;  // Cur .proto file only.
208   std::vector<const Descriptor*> all_types;    // All deps
209   std::vector<const EnumDescriptor*> local_enums;
210   std::vector<const EnumDescriptor*> all_enums;
211 
212   auto add_enum = [&local_enums, &all_enums, &all_enums_seen,
213                    &file](const EnumDescriptor* enum_desc) {
214     if (all_enums_seen.count(enum_desc))
215       return;
216     all_enums_seen.insert(enum_desc);
217     all_enums.push_back(enum_desc);
218     if (enum_desc->file() == file)
219       local_enums.push_back(enum_desc);
220   };
221 
222   for (int i = 0; i < file->enum_type_count(); i++)
223     add_enum(file->enum_type(i));
224 
225   std::stack<const Descriptor*> recursion_stack;
226   for (int i = 0; i < file->message_type_count(); i++)
227     recursion_stack.push(file->message_type(i));
228 
229   while (!recursion_stack.empty()) {
230     const Descriptor* msg = recursion_stack.top();
231     recursion_stack.pop();
232     if (all_types_seen.count(msg))
233       continue;
234     all_types_seen.insert(msg);
235     all_types.push_back(msg);
236     if (msg->file() == file)
237       local_types.push_back(msg);
238 
239     for (int i = 0; i < msg->nested_type_count(); i++)
240       recursion_stack.push(msg->nested_type(i));
241 
242     for (int i = 0; i < msg->enum_type_count(); i++)
243       add_enum(msg->enum_type(i));
244 
245     for (int i = 0; i < msg->field_count(); i++) {
246       const FieldDescriptor* field = msg->field(i);
247       if (field->has_default_value()) {
248         *error = "field " + field->name() +
249                  ": Explicitly declared default values are not supported";
250         return false;
251       }
252       if (field->options().lazy() &&
253           (field->is_repeated() || field->type() != TYPE_MESSAGE)) {
254         *error = "[lazy=true] is supported only on non-repeated fields\n";
255         return false;
256       }
257 
258       if (field->type() == TYPE_MESSAGE && !field->options().lazy())
259         recursion_stack.push(field->message_type());
260 
261       if (field->type() == FieldDescriptor::TYPE_ENUM)
262         add_enum(field->enum_type());
263     }
264   }  //  while (!recursion_stack.empty())
265 
266   // Generate forward declarations in the header for proto types.
267   // Note: do NOT add #includes to other generated headers (either .gen.h or
268   // .pbzero.h). Doing so is extremely hard to handle at the build-system level
269   // and requires propagating public_deps everywhere.
270   cc_printer.Print("\n");
271 
272   // -- Begin of fwd declarations.
273 
274   // Build up the map of forward declarations.
275   std::multimap<std::string /*namespace*/, std::string /*decl*/> fwd_decls;
276   enum FwdType { kClass, kEnum };
277   auto add_fwd_decl = [&fwd_decls](FwdType cpp_type,
278                                    const std::string& full_name) {
279     auto dot = full_name.rfind("::");
280     PERFETTO_CHECK(dot != std::string::npos);
281     auto package = full_name.substr(0, dot);
282     auto name = full_name.substr(dot + 2);
283     if (cpp_type == kClass) {
284       fwd_decls.emplace(package, "class " + name + ";");
285     } else {
286       PERFETTO_CHECK(cpp_type == kEnum);
287       fwd_decls.emplace(package, "enum " + name + " : int;");
288     }
289   };
290 
291   add_fwd_decl(kClass, "protozero::Message");
292   for (const Descriptor* msg : all_types) {
293     add_fwd_decl(kClass, GetFullName(msg, true));
294   }
295   for (const EnumDescriptor* enm : all_enums) {
296     add_fwd_decl(kEnum, GetFullName(enm, true));
297   }
298 
299   // Emit forward declarations grouping by package.
300   std::string last_package;
301   auto close_last_package = [&last_package, &h_printer] {
302     if (!last_package.empty()) {
303       for (const std::string& ns : SplitString(last_package, "::"))
304         h_printer.Print("}  // namespace $ns$\n", "ns", ns);
305       h_printer.Print("\n");
306     }
307   };
308   for (const auto& kv : fwd_decls) {
309     const std::string& package = kv.first;
310     if (package != last_package) {
311       close_last_package();
312       last_package = package;
313       for (const std::string& ns : SplitString(package, "::"))
314         h_printer.Print("namespace $ns$ {\n", "ns", ns);
315     }
316     h_printer.Print("$decl$\n", "decl", kv.second);
317   }
318   close_last_package();
319 
320   // -- End of fwd declarations.
321 
322   for (const std::string& ns : GetNamespaces(file)) {
323     h_printer.Print("namespace $n$ {\n", "n", ns);
324     cc_printer.Print("namespace $n$ {\n", "n", ns);
325   }
326 
327   // Generate declarations and definitions.
328   for (const EnumDescriptor* enm : local_enums)
329     GenEnum(enm, &h_printer);
330 
331   for (const Descriptor* msg : local_types) {
332     GenClassDecl(msg, &h_printer);
333     GenClassDef(msg, &cc_printer);
334   }
335 
336   for (const std::string& ns : GetNamespaces(file)) {
337     h_printer.Print("}  // namespace $n$\n", "n", ns);
338     cc_printer.Print("}  // namespace $n$\n", "n", ns);
339   }
340   cc_printer.Print("#if defined(__GNUC__) || defined(__clang__)\n");
341   cc_printer.Print("#pragma GCC diagnostic pop\n");
342   cc_printer.Print("#endif\n");
343 
344   h_printer.Print("\n#endif  // $g$\n", "g", include_guard);
345 
346   return true;
347 }
348 
GetCppType(const FieldDescriptor * field,bool constref) const349 std::string CppObjGenerator::GetCppType(const FieldDescriptor* field,
350                                         bool constref) const {
351   switch (field->type()) {
352     case FieldDescriptor::TYPE_DOUBLE:
353       return "double";
354     case FieldDescriptor::TYPE_FLOAT:
355       return "float";
356     case FieldDescriptor::TYPE_FIXED32:
357     case FieldDescriptor::TYPE_UINT32:
358       return "uint32_t";
359     case FieldDescriptor::TYPE_SFIXED32:
360     case FieldDescriptor::TYPE_INT32:
361     case FieldDescriptor::TYPE_SINT32:
362       return "int32_t";
363     case FieldDescriptor::TYPE_FIXED64:
364     case FieldDescriptor::TYPE_UINT64:
365       return "uint64_t";
366     case FieldDescriptor::TYPE_SFIXED64:
367     case FieldDescriptor::TYPE_SINT64:
368     case FieldDescriptor::TYPE_INT64:
369       return "int64_t";
370     case FieldDescriptor::TYPE_BOOL:
371       return "bool";
372     case FieldDescriptor::TYPE_STRING:
373     case FieldDescriptor::TYPE_BYTES:
374       return constref ? "const std::string&" : "std::string";
375     case FieldDescriptor::TYPE_MESSAGE:
376       assert(!field->options().lazy());
377       return constref ? "const " + GetFullName(field->message_type()) + "&"
378                       : GetFullName(field->message_type());
379     case FieldDescriptor::TYPE_ENUM:
380       return GetFullName(field->enum_type());
381     case FieldDescriptor::TYPE_GROUP:
382       abort();
383   }
384   abort();  // for gcc
385 }
386 
GetProtozeroSetter(const FieldDescriptor * field) const387 std::string CppObjGenerator::GetProtozeroSetter(
388     const FieldDescriptor* field) const {
389   switch (field->type()) {
390     case FieldDescriptor::TYPE_BOOL:
391       return "AppendTinyVarInt";
392     case FieldDescriptor::TYPE_INT32:
393     case FieldDescriptor::TYPE_INT64:
394     case FieldDescriptor::TYPE_UINT32:
395     case FieldDescriptor::TYPE_UINT64:
396     case FieldDescriptor::TYPE_ENUM:
397       return "AppendVarInt";
398     case FieldDescriptor::TYPE_SINT32:
399     case FieldDescriptor::TYPE_SINT64:
400       return "AppendSignedVarInt";
401     case FieldDescriptor::TYPE_FIXED32:
402     case FieldDescriptor::TYPE_FIXED64:
403     case FieldDescriptor::TYPE_SFIXED32:
404     case FieldDescriptor::TYPE_SFIXED64:
405     case FieldDescriptor::TYPE_FLOAT:
406     case FieldDescriptor::TYPE_DOUBLE:
407       return "AppendFixed";
408     case FieldDescriptor::TYPE_STRING:
409     case FieldDescriptor::TYPE_BYTES:
410       return "AppendString";
411     case FieldDescriptor::TYPE_GROUP:
412     case FieldDescriptor::TYPE_MESSAGE:
413       abort();
414   }
415   abort();
416 }
417 
GetPackedBuffer(const FieldDescriptor * field) const418 std::string CppObjGenerator::GetPackedBuffer(
419     const FieldDescriptor* field) const {
420   switch (field->type()) {
421     case FieldDescriptor::TYPE_FIXED32:
422       return "::protozero::PackedFixedSizeInt<uint32_t>";
423     case FieldDescriptor::TYPE_SFIXED32:
424       return "::protozero::PackedFixedSizeInt<int32_t>";
425     case FieldDescriptor::TYPE_FIXED64:
426       return "::protozero::PackedFixedSizeInt<uint64_t>";
427     case FieldDescriptor::TYPE_SFIXED64:
428       return "::protozero::PackedFixedSizeInt<int64_t>";
429     case FieldDescriptor::TYPE_DOUBLE:
430       return "::protozero::PackedFixedSizeInt<double>";
431     case FieldDescriptor::TYPE_FLOAT:
432       return "::protozero::PackedFixedSizeInt<float>";
433     case FieldDescriptor::TYPE_INT32:
434     case FieldDescriptor::TYPE_SINT32:
435     case FieldDescriptor::TYPE_UINT32:
436     case FieldDescriptor::TYPE_INT64:
437     case FieldDescriptor::TYPE_UINT64:
438     case FieldDescriptor::TYPE_SINT64:
439     case FieldDescriptor::TYPE_BOOL:
440       return "::protozero::PackedVarInt";
441     case FieldDescriptor::TYPE_STRING:
442     case FieldDescriptor::TYPE_BYTES:
443     case FieldDescriptor::TYPE_MESSAGE:
444     case FieldDescriptor::TYPE_ENUM:
445     case FieldDescriptor::TYPE_GROUP:
446       break;  // Will abort()
447   }
448   abort();
449 }
450 
GetPackedWireType(const FieldDescriptor * field) const451 std::string CppObjGenerator::GetPackedWireType(
452     const FieldDescriptor* field) const {
453   switch (field->type()) {
454     case FieldDescriptor::TYPE_FIXED32:
455     case FieldDescriptor::TYPE_SFIXED32:
456     case FieldDescriptor::TYPE_FLOAT:
457       return "::protozero::proto_utils::ProtoWireType::kFixed32";
458     case FieldDescriptor::TYPE_FIXED64:
459     case FieldDescriptor::TYPE_SFIXED64:
460     case FieldDescriptor::TYPE_DOUBLE:
461       return "::protozero::proto_utils::ProtoWireType::kFixed64";
462     case FieldDescriptor::TYPE_INT32:
463     case FieldDescriptor::TYPE_SINT32:
464     case FieldDescriptor::TYPE_UINT32:
465     case FieldDescriptor::TYPE_INT64:
466     case FieldDescriptor::TYPE_UINT64:
467     case FieldDescriptor::TYPE_SINT64:
468     case FieldDescriptor::TYPE_BOOL:
469       return "::protozero::proto_utils::ProtoWireType::kVarInt";
470     case FieldDescriptor::TYPE_STRING:
471     case FieldDescriptor::TYPE_BYTES:
472     case FieldDescriptor::TYPE_MESSAGE:
473     case FieldDescriptor::TYPE_ENUM:
474     case FieldDescriptor::TYPE_GROUP:
475       break;  // Will abort()
476   }
477   abort();
478 }
479 
GenEnum(const EnumDescriptor * enum_desc,Printer * p) const480 void CppObjGenerator::GenEnum(const EnumDescriptor* enum_desc,
481                               Printer* p) const {
482   std::string full_name = GetFullName(enum_desc);
483 
484   // When generating enums, there are two cases:
485   // 1. Enums nested in a message (most frequent case), e.g.:
486   //    message MyMsg { enum MyEnum { FOO=1; BAR=2; } }
487   // 2. Enum defined at the package level, outside of any message.
488   //
489   // In the case 1, the C++ code generated by the official protobuf library is:
490   // enum MyEnum {  MyMsg_MyEnum_FOO=1, MyMsg_MyEnum_BAR=2 }
491   // class MyMsg { static const auto FOO = MyMsg_MyEnum_FOO; ... same for BAR }
492   //
493   // In the case 2, the C++ code is simply:
494   // enum MyEnum { FOO=1, BAR=2 }
495   // Hence this |prefix| logic.
496   std::string prefix = enum_desc->containing_type() ? full_name + "_" : "";
497   p->Print("enum $f$ : int {\n", "f", full_name);
498   for (int e = 0; e < enum_desc->value_count(); e++) {
499     const EnumValueDescriptor* value = enum_desc->value(e);
500     p->Print("  $p$$n$ = $v$,\n", "p", prefix, "n", value->name(), "v",
501              std::to_string(value->number()));
502   }
503   p->Print("};\n");
504 }
505 
GenEnumAliases(const EnumDescriptor * enum_desc,Printer * p) const506 void CppObjGenerator::GenEnumAliases(const EnumDescriptor* enum_desc,
507                                      Printer* p) const {
508   int min_value = std::numeric_limits<int>::max();
509   int max_value = std::numeric_limits<int>::min();
510   std::string min_name;
511   std::string max_name;
512   std::string full_name = GetFullName(enum_desc);
513   for (int e = 0; e < enum_desc->value_count(); e++) {
514     const EnumValueDescriptor* value = enum_desc->value(e);
515     p->Print("static constexpr auto $n$ = $f$_$n$;\n", "f", full_name, "n",
516              value->name());
517     if (value->number() < min_value) {
518       min_value = value->number();
519       min_name = full_name + "_" + value->name();
520     }
521     if (value->number() > max_value) {
522       max_value = value->number();
523       max_name = full_name + "_" + value->name();
524     }
525   }
526   p->Print("static constexpr auto $n$_MIN = $m$;\n", "n", enum_desc->name(),
527            "m", min_name);
528   p->Print("static constexpr auto $n$_MAX = $m$;\n", "n", enum_desc->name(),
529            "m", max_name);
530 }
531 
GenClassDecl(const Descriptor * msg,Printer * p) const532 void CppObjGenerator::GenClassDecl(const Descriptor* msg, Printer* p) const {
533   std::string full_name = GetFullName(msg);
534   p->Print(
535       "\nclass PERFETTO_EXPORT $n$ : public ::protozero::CppMessageObj {\n",
536       "n", full_name);
537   p->Print(" public:\n");
538   p->Indent();
539 
540   // Do a first pass to generate aliases for nested types.
541   // e.g., using Foo = Parent_Foo;
542   for (int i = 0; i < msg->nested_type_count(); i++) {
543     const Descriptor* nested_msg = msg->nested_type(i);
544     p->Print("using $n$ = $f$;\n", "n", nested_msg->name(), "f",
545              GetFullName(nested_msg));
546   }
547   for (int i = 0; i < msg->enum_type_count(); i++) {
548     const EnumDescriptor* nested_enum = msg->enum_type(i);
549     p->Print("using $n$ = $f$;\n", "n", nested_enum->name(), "f",
550              GetFullName(nested_enum));
551     GenEnumAliases(nested_enum, p);
552   }
553 
554   // Generate constants with field numbers.
555   p->Print("enum FieldNumbers {\n");
556   for (int i = 0; i < msg->field_count(); i++) {
557     const FieldDescriptor* field = msg->field(i);
558     std::string name = field->camelcase_name();
559     name[0] = perfetto::base::Uppercase(name[0]);
560     p->Print("  k$n$FieldNumber = $num$,\n", "n", name, "num",
561              std::to_string(field->number()));
562   }
563   p->Print("};\n\n");
564 
565   p->Print("$n$();\n", "n", full_name);
566   p->Print("~$n$() override;\n", "n", full_name);
567   p->Print("$n$($n$&&) noexcept;\n", "n", full_name);
568   p->Print("$n$& operator=($n$&&);\n", "n", full_name);
569   p->Print("$n$(const $n$&);\n", "n", full_name);
570   p->Print("$n$& operator=(const $n$&);\n", "n", full_name);
571   p->Print("bool operator==(const $n$&) const;\n", "n", full_name);
572   p->Print(
573       "bool operator!=(const $n$& other) const { return !(*this == other); }\n",
574       "n", full_name);
575   p->Print("\n");
576 
577   std::string proto_type = GetFullName(msg, true);
578   p->Print("bool ParseFromArray(const void*, size_t) override;\n");
579   p->Print("std::string SerializeAsString() const override;\n");
580   p->Print("std::vector<uint8_t> SerializeAsArray() const override;\n");
581   p->Print("void Serialize(::protozero::Message*) const;\n");
582 
583   // Generate accessors.
584   for (int i = 0; i < msg->field_count(); i++) {
585     const FieldDescriptor* field = msg->field(i);
586     auto set_bit = "_has_field_.set(" + std::to_string(field->number()) + ")";
587     p->Print("\n");
588     if (field->options().lazy()) {
589       p->Print("const std::string& $n$_raw() const { return $n$_; }\n", "n",
590                field->lowercase_name());
591       p->Print(
592           "void set_$n$_raw(const std::string& raw) { $n$_ = raw; $s$; }\n",
593           "n", field->lowercase_name(), "s", set_bit);
594     } else if (!field->is_repeated()) {
595       p->Print("bool has_$n$() const { return _has_field_[$bit$]; }\n", "n",
596                field->lowercase_name(), "bit", std::to_string(field->number()));
597       if (field->type() == TYPE_MESSAGE) {
598         p->Print("$t$ $n$() const { return *$n$_; }\n", "t",
599                  GetCppType(field, true), "n", field->lowercase_name());
600         p->Print("$t$* mutable_$n$() { $s$; return $n$_.get(); }\n", "t",
601                  GetCppType(field, false), "n", field->lowercase_name(), "s",
602                  set_bit);
603       } else {
604         p->Print("$t$ $n$() const { return $n$_; }\n", "t",
605                  GetCppType(field, true), "n", field->lowercase_name());
606         p->Print("void set_$n$($t$ value) { $n$_ = value; $s$; }\n", "t",
607                  GetCppType(field, true), "n", field->lowercase_name(), "s",
608                  set_bit);
609         if (field->type() == FieldDescriptor::TYPE_BYTES) {
610           p->Print(
611               "void set_$n$(const void* p, size_t s) { "
612               "$n$_.assign(reinterpret_cast<const char*>(p), s); $s$; }\n",
613               "n", field->lowercase_name(), "s", set_bit);
614         }
615       }
616     } else {  // is_repeated()
617       p->Print("const std::vector<$t$>& $n$() const { return $n$_; }\n", "t",
618                GetCppType(field, false), "n", field->lowercase_name());
619       p->Print("std::vector<$t$>* mutable_$n$() { return &$n$_; }\n", "t",
620                GetCppType(field, false), "n", field->lowercase_name());
621 
622       // Generate accessors for repeated message types in the .cc file so that
623       // the header doesn't depend on the full definition of all nested types.
624       if (field->type() == TYPE_MESSAGE) {
625         p->Print("int $n$_size() const;\n", "t", GetCppType(field, false), "n",
626                  field->lowercase_name());
627         p->Print("void clear_$n$();\n", "n", field->lowercase_name());
628         p->Print("$t$* add_$n$();\n", "t", GetCppType(field, false), "n",
629                  field->lowercase_name());
630       } else {  // Primitive type.
631         p->Print(
632             "int $n$_size() const { return static_cast<int>($n$_.size()); }\n",
633             "t", GetCppType(field, false), "n", field->lowercase_name());
634         p->Print("void clear_$n$() { $n$_.clear(); }\n", "n",
635                  field->lowercase_name());
636         p->Print("void add_$n$($t$ value) { $n$_.emplace_back(value); }\n", "t",
637                  GetCppType(field, false), "n", field->lowercase_name());
638         // TODO(primiano): this should be done only for TYPE_MESSAGE.
639         // Unfortuntely we didn't realize before and now we have a bunch of code
640         // that does: *msg->add_int_value() = 42 instead of
641         // msg->add_int_value(42).
642         p->Print(
643             "$t$* add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
644             "t", GetCppType(field, false), "n", field->lowercase_name());
645       }
646     }
647   }
648   p->Outdent();
649   p->Print("\n private:\n");
650   p->Indent();
651 
652   // Generate fields.
653   int max_field_id = 1;
654   for (int i = 0; i < msg->field_count(); i++) {
655     const FieldDescriptor* field = msg->field(i);
656     max_field_id = std::max(max_field_id, field->number());
657     if (field->options().lazy()) {
658       p->Print("std::string $n$_;  // [lazy=true]\n", "n",
659                field->lowercase_name());
660     } else if (!field->is_repeated()) {
661       std::string type = GetCppType(field, false);
662       if (field->type() == TYPE_MESSAGE) {
663         type = "::protozero::CopyablePtr<" + type + ">";
664         p->Print("$t$ $n$_;\n", "t", type, "n", field->lowercase_name());
665       } else {
666         p->Print("$t$ $n$_{};\n", "t", type, "n", field->lowercase_name());
667       }
668     } else {  // is_repeated()
669       p->Print("std::vector<$t$> $n$_;\n", "t", GetCppType(field, false), "n",
670                field->lowercase_name());
671     }
672   }
673   p->Print("\n");
674   p->Print("// Allows to preserve unknown protobuf fields for compatibility\n");
675   p->Print("// with future versions of .proto files.\n");
676   p->Print("std::string unknown_fields_;\n");
677 
678   p->Print("\nstd::bitset<$id$> _has_field_{};\n", "id",
679            std::to_string(max_field_id + 1));
680 
681   p->Outdent();
682   p->Print("};\n\n");
683 }
684 
GenClassDef(const Descriptor * msg,Printer * p) const685 void CppObjGenerator::GenClassDef(const Descriptor* msg, Printer* p) const {
686   p->Print("\n");
687   std::string full_name = GetFullName(msg);
688 
689   p->Print("$n$::$n$() = default;\n", "n", full_name);
690   p->Print("$n$::~$n$() = default;\n", "n", full_name);
691   p->Print("$n$::$n$(const $n$&) = default;\n", "n", full_name);
692   p->Print("$n$& $n$::operator=(const $n$&) = default;\n", "n", full_name);
693   p->Print("$n$::$n$($n$&&) noexcept = default;\n", "n", full_name);
694   p->Print("$n$& $n$::operator=($n$&&) = default;\n", "n", full_name);
695 
696   p->Print("\n");
697 
698   // Comparison operator
699   p->Print("bool $n$::operator==(const $n$& other) const {\n", "n", full_name);
700   p->Indent();
701 
702   p->Print("return unknown_fields_ == other.unknown_fields_");
703   for (int i = 0; i < msg->field_count(); i++)
704     p->Print("\n && $n$_ == other.$n$_", "n", msg->field(i)->lowercase_name());
705   p->Print(";");
706   p->Outdent();
707   p->Print("\n}\n\n");
708 
709   // Accessors for repeated message fields.
710   for (int i = 0; i < msg->field_count(); i++) {
711     const FieldDescriptor* field = msg->field(i);
712     if (field->options().lazy() || !field->is_repeated() ||
713         field->type() != TYPE_MESSAGE) {
714       continue;
715     }
716     p->Print(
717         "int $c$::$n$_size() const { return static_cast<int>($n$_.size()); }\n",
718         "c", full_name, "t", GetCppType(field, false), "n",
719         field->lowercase_name());
720     p->Print("void $c$::clear_$n$() { $n$_.clear(); }\n", "c", full_name, "n",
721              field->lowercase_name());
722     p->Print(
723         "$t$* $c$::add_$n$() { $n$_.emplace_back(); return &$n$_.back(); }\n",
724         "c", full_name, "t", GetCppType(field, false), "n",
725         field->lowercase_name());
726   }
727 
728   std::string proto_type = GetFullName(msg, true);
729 
730   // Generate the ParseFromArray() method definition.
731   p->Print("bool $f$::ParseFromArray(const void* raw, size_t size) {\n", "f",
732            full_name);
733   p->Indent();
734   for (int i = 0; i < msg->field_count(); i++) {
735     const FieldDescriptor* field = msg->field(i);
736     if (field->is_repeated())
737       p->Print("$n$_.clear();\n", "n", field->lowercase_name());
738   }
739   p->Print("unknown_fields_.clear();\n");
740   p->Print("bool packed_error = false;\n");
741   p->Print("\n");
742   p->Print("::protozero::ProtoDecoder dec(raw, size);\n");
743   p->Print("for (auto field = dec.ReadField(); field.valid(); ");
744   p->Print("field = dec.ReadField()) {\n");
745   p->Indent();
746   p->Print("if (field.id() < _has_field_.size()) {\n");
747   p->Print("  _has_field_.set(field.id());\n");
748   p->Print("}\n");
749   p->Print("switch (field.id()) {\n");
750   p->Indent();
751   for (int i = 0; i < msg->field_count(); i++) {
752     const FieldDescriptor* field = msg->field(i);
753     p->Print("case $id$ /* $n$ */:\n", "id", std::to_string(field->number()),
754              "n", field->lowercase_name());
755     p->Indent();
756     if (field->options().lazy()) {
757       p->Print("$n$_ = field.as_std_string();\n", "n", field->lowercase_name());
758     } else {
759       std::string statement;
760       if (field->type() == TYPE_MESSAGE) {
761         statement = "$rval$.ParseFromArray(field.data(), field.size());\n";
762       } else {
763         if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
764           // sint32/64 fields are special and need to be zig-zag-decoded.
765           statement = "field.get_signed(&$rval$);\n";
766         } else {
767           statement = "field.get(&$rval$);\n";
768         }
769       }
770       if (field->is_packed()) {
771         PERFETTO_CHECK(field->is_repeated());
772         if (field->type() == TYPE_SINT32 || field->type() == TYPE_SINT64) {
773           PERFETTO_FATAL("packed signed (zigzag) fields are not supported");
774         }
775         p->Print(
776             "for (::protozero::PackedRepeatedFieldIterator<$w$, $c$> "
777             "rep(field.data(), field.size(), &packed_error); rep; ++rep) {\n",
778             "w", GetPackedWireType(field), "c", GetCppType(field, false));
779         p->Print("  $n$_.emplace_back(*rep);\n", "n", field->lowercase_name());
780         p->Print("}\n");
781       } else if (field->is_repeated()) {
782         p->Print("$n$_.emplace_back();\n", "n", field->lowercase_name());
783         p->Print(statement.c_str(), "rval",
784                  field->lowercase_name() + "_.back()");
785       } else if (field->type() == TYPE_MESSAGE) {
786         p->Print(statement.c_str(), "rval",
787                  "(*" + field->lowercase_name() + "_)");
788       } else {
789         p->Print(statement.c_str(), "rval", field->lowercase_name() + "_");
790       }
791     }
792     p->Print("break;\n");
793     p->Outdent();
794   }  // for (field)
795   p->Print("default:\n");
796   p->Print("  field.SerializeAndAppendTo(&unknown_fields_);\n");
797   p->Print("  break;\n");
798   p->Outdent();
799   p->Print("}\n");  // switch(field.id)
800   p->Outdent();
801   p->Print("}\n");                                           // for(field)
802   p->Print("return !packed_error && !dec.bytes_left();\n");  // for(field)
803   p->Outdent();
804   p->Print("}\n\n");
805 
806   // Generate the SerializeAsString() method definition.
807   p->Print("std::string $f$::SerializeAsString() const {\n", "f", full_name);
808   p->Indent();
809   p->Print("::protozero::HeapBuffered<::protozero::Message> msg;\n");
810   p->Print("Serialize(msg.get());\n");
811   p->Print("return msg.SerializeAsString();\n");
812   p->Outdent();
813   p->Print("}\n\n");
814 
815   // Generate the SerializeAsArray() method definition.
816   p->Print("std::vector<uint8_t> $f$::SerializeAsArray() const {\n", "f",
817            full_name);
818   p->Indent();
819   p->Print("::protozero::HeapBuffered<::protozero::Message> msg;\n");
820   p->Print("Serialize(msg.get());\n");
821   p->Print("return msg.SerializeAsArray();\n");
822   p->Outdent();
823   p->Print("}\n\n");
824 
  // Generate the Serialize() method, which writes the fields into the
  // write-only protozero interface passed in as |msg|.
827   p->Print("void $f$::Serialize(::protozero::Message* msg) const {\n", "f",
828            full_name);
829   p->Indent();
830   for (int i = 0; i < msg->field_count(); i++) {
831     const FieldDescriptor* field = msg->field(i);
832     std::map<std::string, std::string> args;
833     args["id"] = std::to_string(field->number());
834     args["n"] = field->lowercase_name();
835     p->Print(args, "// Field $id$: $n$\n");
836     if (field->is_packed()) {
837       PERFETTO_CHECK(field->is_repeated());
838       p->Print("{\n");
839       p->Indent();
840       p->Print("$p$ pack;\n", "p", GetPackedBuffer(field));
841       p->Print(args, "for (auto& it : $n$_)\n");
842       p->Print(args, "  pack.Append(it);\n");
843       p->Print(args, "msg->AppendBytes($id$, pack.data(), pack.size());\n");
844       p->Outdent();
845       p->Print("}\n");
846     } else {
847       if (field->is_repeated()) {
848         p->Print(args, "for (auto& it : $n$_) {\n");
849         args["lvalue"] = "it";
850         args["rvalue"] = "it";
851       } else {
852         p->Print(args, "if (_has_field_[$id$]) {\n");
853         args["lvalue"] = "(*" + field->lowercase_name() + "_)";
854         args["rvalue"] = field->lowercase_name() + "_";
855       }
856       p->Indent();
857       if (field->options().lazy()) {
858         p->Print(args, "msg->AppendString($id$, $rvalue$);\n");
859       } else if (field->type() == TYPE_MESSAGE) {
860         p->Print(args,
861                  "$lvalue$.Serialize("
862                  "msg->BeginNestedMessage<::protozero::Message>($id$));\n");
863       } else {
864         args["setter"] = GetProtozeroSetter(field);
865         p->Print(args, "msg->$setter$($id$, $rvalue$);\n");
866       }
867       p->Outdent();
868       p->Print("}\n");
869     }
870 
871     p->Print("\n");
872   }  // for (field)
873   p->Print(
874       "msg->AppendRawProtoBytes(unknown_fields_.data(), "
875       "unknown_fields_.size());\n");
876   p->Outdent();
877   p->Print("}\n\n");
878 }
879 
880 }  // namespace
881 }  // namespace protozero
882 
main(int argc,char ** argv)883 int main(int argc, char** argv) {
884   ::protozero::CppObjGenerator generator;
885   return google::protobuf::compiler::PluginMain(argc, argv, &generator);
886 }
887