1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdlib.h>
18 
19 #include <limits>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 
25 #include <google/protobuf/compiler/code_generator.h>
26 #include <google/protobuf/compiler/plugin.h>
27 #include <google/protobuf/descriptor.h>
28 #include <google/protobuf/descriptor.pb.h>
29 #include <google/protobuf/io/printer.h>
30 #include <google/protobuf/io/zero_copy_stream.h>
31 
32 #include "perfetto/ext/base/string_utils.h"
33 
34 namespace protozero {
35 namespace {
36 
37 using google::protobuf::Descriptor;
38 using google::protobuf::EnumDescriptor;
39 using google::protobuf::EnumValueDescriptor;
40 using google::protobuf::FieldDescriptor;
41 using google::protobuf::FileDescriptor;
42 using google::protobuf::compiler::GeneratorContext;
43 using google::protobuf::io::Printer;
44 using google::protobuf::io::ZeroCopyOutputStream;
45 using perfetto::base::SplitString;
46 using perfetto::base::StripChars;
47 using perfetto::base::StripPrefix;
48 using perfetto::base::StripSuffix;
49 using perfetto::base::ToUpper;
50 using perfetto::base::Uppercase;
51 
// Highest field id the generated *_Decoder classes expose. It must match
// ProtoDecoder::kMaxDecoderFieldId: if the two diverge, the generated pbzero.h
// files stop compiling on the at() static_assert. The value is duplicated here
// on purpose rather than taking an extra dependency.
constexpr int kMaxDecoderFieldId{999};
56 
// Calls abort() when |condition| is false. Unlike the assert() macro, this
// check is an ordinary function call and is therefore never compiled out.
void Assert(bool condition) {
  if (condition)
    return;
  abort();
}
61 
62 struct FileDescriptorComp {
operator ()protozero::__anon58692eed0111::FileDescriptorComp63   bool operator()(const FileDescriptor* lhs, const FileDescriptor* rhs) const {
64     int comp = lhs->name().compare(rhs->name());
65     Assert(comp != 0 || lhs == rhs);
66     return comp < 0;
67   }
68 };
69 
70 struct DescriptorComp {
operator ()protozero::__anon58692eed0111::DescriptorComp71   bool operator()(const Descriptor* lhs, const Descriptor* rhs) const {
72     int comp = lhs->full_name().compare(rhs->full_name());
73     Assert(comp != 0 || lhs == rhs);
74     return comp < 0;
75   }
76 };
77 
78 struct EnumDescriptorComp {
operator ()protozero::__anon58692eed0111::EnumDescriptorComp79   bool operator()(const EnumDescriptor* lhs, const EnumDescriptor* rhs) const {
80     int comp = lhs->full_name().compare(rhs->full_name());
81     Assert(comp != 0 || lhs == rhs);
82     return comp < 0;
83   }
84 };
85 
ProtoStubName(const FileDescriptor * proto)86 inline std::string ProtoStubName(const FileDescriptor* proto) {
87   return StripSuffix(proto->name(), ".proto") + ".pbzero";
88 }
89 
90 class GeneratorJob {
91  public:
GeneratorJob(const FileDescriptor * file,Printer * stub_h_printer)92   GeneratorJob(const FileDescriptor* file, Printer* stub_h_printer)
93       : source_(file), stub_h_(stub_h_printer) {}
94 
GenerateStubs()95   bool GenerateStubs() {
96     Preprocess();
97     GeneratePrologue();
98     for (const EnumDescriptor* enumeration : enums_)
99       GenerateEnumDescriptor(enumeration);
100     for (const Descriptor* message : messages_)
101       GenerateMessageDescriptor(message);
102     for (const auto& key_value : extensions_)
103       GenerateExtension(key_value.first, key_value.second);
104     GenerateEpilogue();
105     return error_.empty();
106   }
107 
SetOption(const std::string & name,const std::string & value)108   void SetOption(const std::string& name, const std::string& value) {
109     if (name == "wrapper_namespace") {
110       wrapper_namespace_ = value;
111     } else {
112       Abort(std::string() + "Unknown plugin option '" + name + "'.");
113     }
114   }
115 
116   // If generator fails to produce stubs for a particular proto definitions
117   // it finishes with undefined output and writes the first error occured.
GetFirstError() const118   const std::string& GetFirstError() const { return error_; }
119 
120  private:
121   // Only the first error will be recorded.
Abort(const std::string & reason)122   void Abort(const std::string& reason) {
123     if (error_.empty())
124       error_ = reason;
125   }
126 
127   // Get full name (including outer descriptors) of proto descriptor.
128   template <class T>
GetDescriptorName(const T * descriptor)129   inline std::string GetDescriptorName(const T* descriptor) {
130     if (!package_.empty()) {
131       return StripPrefix(descriptor->full_name(), package_ + ".");
132     } else {
133       return descriptor->full_name();
134     }
135   }
136 
137   // Get C++ class name corresponding to proto descriptor.
138   // Nested names are splitted by underscores. Underscores in type names aren't
139   // prohibited but not recommended in order to avoid name collisions.
140   template <class T>
GetCppClassName(const T * descriptor,bool full=false)141   inline std::string GetCppClassName(const T* descriptor, bool full = false) {
142     std::string name = StripChars(GetDescriptorName(descriptor), ".", '_');
143     if (full)
144       name = full_namespace_prefix_ + name;
145     return name;
146   }
147 
GetFieldNumberConstant(const FieldDescriptor * field)148   inline std::string GetFieldNumberConstant(const FieldDescriptor* field) {
149     std::string name = field->camelcase_name();
150     if (!name.empty()) {
151       name.at(0) = Uppercase(name.at(0));
152       name = "k" + name + "FieldNumber";
153     } else {
154       // Protoc allows fields like 'bool _ = 1'.
155       Abort("Empty field name in camel case notation.");
156     }
157     return name;
158   }
159 
160   // Note: intentionally avoiding depending on protozero sources, as well as
161   // protobuf-internal WireFormat/WireFormatLite classes.
FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type)162   const char* FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type) {
163     switch (proto_type) {
164       case FieldDescriptor::TYPE_INT64:
165       case FieldDescriptor::TYPE_UINT64:
166       case FieldDescriptor::TYPE_INT32:
167       case FieldDescriptor::TYPE_BOOL:
168       case FieldDescriptor::TYPE_UINT32:
169       case FieldDescriptor::TYPE_ENUM:
170       case FieldDescriptor::TYPE_SINT32:
171       case FieldDescriptor::TYPE_SINT64:
172         return "::protozero::proto_utils::ProtoWireType::kVarInt";
173 
174       case FieldDescriptor::TYPE_FIXED32:
175       case FieldDescriptor::TYPE_SFIXED32:
176       case FieldDescriptor::TYPE_FLOAT:
177         return "::protozero::proto_utils::ProtoWireType::kFixed32";
178 
179       case FieldDescriptor::TYPE_FIXED64:
180       case FieldDescriptor::TYPE_SFIXED64:
181       case FieldDescriptor::TYPE_DOUBLE:
182         return "::protozero::proto_utils::ProtoWireType::kFixed64";
183 
184       case FieldDescriptor::TYPE_STRING:
185       case FieldDescriptor::TYPE_MESSAGE:
186       case FieldDescriptor::TYPE_BYTES:
187         return "::protozero::proto_utils::ProtoWireType::kLengthDelimited";
188 
189       case FieldDescriptor::TYPE_GROUP:
190         Abort("Groups not supported.");
191     }
192     Abort("Unrecognized FieldDescriptor::Type.");
193     return "";
194   }
195 
FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type)196   const char* FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type) {
197     switch (proto_type) {
198       case FieldDescriptor::TYPE_INT64:
199       case FieldDescriptor::TYPE_UINT64:
200       case FieldDescriptor::TYPE_INT32:
201       case FieldDescriptor::TYPE_BOOL:
202       case FieldDescriptor::TYPE_UINT32:
203       case FieldDescriptor::TYPE_ENUM:
204       case FieldDescriptor::TYPE_SINT32:
205       case FieldDescriptor::TYPE_SINT64:
206         return "::protozero::PackedVarInt";
207 
208       case FieldDescriptor::TYPE_FIXED32:
209         return "::protozero::PackedFixedSizeInt<uint32_t>";
210       case FieldDescriptor::TYPE_SFIXED32:
211         return "::protozero::PackedFixedSizeInt<int32_t>";
212       case FieldDescriptor::TYPE_FLOAT:
213         return "::protozero::PackedFixedSizeInt<float>";
214 
215       case FieldDescriptor::TYPE_FIXED64:
216         return "::protozero::PackedFixedSizeInt<uint64_t>";
217       case FieldDescriptor::TYPE_SFIXED64:
218         return "::protozero::PackedFixedSizeInt<int64_t>";
219       case FieldDescriptor::TYPE_DOUBLE:
220         return "::protozero::PackedFixedSizeInt<double>";
221 
222       case FieldDescriptor::TYPE_STRING:
223       case FieldDescriptor::TYPE_MESSAGE:
224       case FieldDescriptor::TYPE_BYTES:
225       case FieldDescriptor::TYPE_GROUP:
226         Abort("Unexpected FieldDescritor::Type.");
227     }
228     Abort("Unrecognized FieldDescriptor::Type.");
229     return "";
230   }
231 
FieldToProtoSchemaType(const FieldDescriptor * field)232   const char* FieldToProtoSchemaType(const FieldDescriptor* field) {
233     switch (field->type()) {
234       case FieldDescriptor::TYPE_BOOL:
235         return "kBool";
236       case FieldDescriptor::TYPE_INT32:
237         return "kInt32";
238       case FieldDescriptor::TYPE_INT64:
239         return "kInt64";
240       case FieldDescriptor::TYPE_UINT32:
241         return "kUint32";
242       case FieldDescriptor::TYPE_UINT64:
243         return "kUint64";
244       case FieldDescriptor::TYPE_SINT32:
245         return "kSint32";
246       case FieldDescriptor::TYPE_SINT64:
247         return "kSint64";
248       case FieldDescriptor::TYPE_FIXED32:
249         return "kFixed32";
250       case FieldDescriptor::TYPE_FIXED64:
251         return "kFixed64";
252       case FieldDescriptor::TYPE_SFIXED32:
253         return "kSfixed32";
254       case FieldDescriptor::TYPE_SFIXED64:
255         return "kSfixed64";
256       case FieldDescriptor::TYPE_FLOAT:
257         return "kFloat";
258       case FieldDescriptor::TYPE_DOUBLE:
259         return "kDouble";
260       case FieldDescriptor::TYPE_ENUM:
261         return "kEnum";
262       case FieldDescriptor::TYPE_STRING:
263         return "kString";
264       case FieldDescriptor::TYPE_MESSAGE:
265         return "kMessage";
266       case FieldDescriptor::TYPE_BYTES:
267         return "kBytes";
268 
269       case FieldDescriptor::TYPE_GROUP:
270         Abort("Groups not supported.");
271         return "";
272     }
273     Abort("Unrecognized FieldDescriptor::Type.");
274     return "";
275   }
276 
FieldToCppTypeName(const FieldDescriptor * field)277   std::string FieldToCppTypeName(const FieldDescriptor* field) {
278     switch (field->type()) {
279       case FieldDescriptor::TYPE_BOOL:
280         return "bool";
281       case FieldDescriptor::TYPE_INT32:
282         return "int32_t";
283       case FieldDescriptor::TYPE_INT64:
284         return "int64_t";
285       case FieldDescriptor::TYPE_UINT32:
286         return "uint32_t";
287       case FieldDescriptor::TYPE_UINT64:
288         return "uint64_t";
289       case FieldDescriptor::TYPE_SINT32:
290         return "int32_t";
291       case FieldDescriptor::TYPE_SINT64:
292         return "int64_t";
293       case FieldDescriptor::TYPE_FIXED32:
294         return "uint32_t";
295       case FieldDescriptor::TYPE_FIXED64:
296         return "uint64_t";
297       case FieldDescriptor::TYPE_SFIXED32:
298         return "int32_t";
299       case FieldDescriptor::TYPE_SFIXED64:
300         return "int64_t";
301       case FieldDescriptor::TYPE_FLOAT:
302         return "float";
303       case FieldDescriptor::TYPE_DOUBLE:
304         return "double";
305       case FieldDescriptor::TYPE_ENUM:
306         return GetCppClassName(field->enum_type(), true);
307       case FieldDescriptor::TYPE_STRING:
308       case FieldDescriptor::TYPE_BYTES:
309         return "std::string";
310       case FieldDescriptor::TYPE_MESSAGE:
311         return GetCppClassName(field->message_type());
312       case FieldDescriptor::TYPE_GROUP:
313         Abort("Groups not supported.");
314         return "";
315     }
316     Abort("Unrecognized FieldDescriptor::Type.");
317     return "";
318   }
319 
FieldToRepetitionType(const FieldDescriptor * field)320   const char* FieldToRepetitionType(const FieldDescriptor* field) {
321     if (!field->is_repeated())
322       return "kNotRepeated";
323     if (field->is_packed())
324       return "kRepeatedPacked";
325     return "kRepeatedNotPacked";
326   }
327 
CollectDescriptors()328   void CollectDescriptors() {
329     // Collect message descriptors in DFS order.
330     std::vector<const Descriptor*> stack;
331     stack.reserve(static_cast<size_t>(source_->message_type_count()));
332     for (int i = 0; i < source_->message_type_count(); ++i)
333       stack.push_back(source_->message_type(i));
334 
335     while (!stack.empty()) {
336       const Descriptor* message = stack.back();
337       stack.pop_back();
338 
339       if (message->extension_count() > 0) {
340         if (message->field_count() > 0 || message->nested_type_count() > 0 ||
341             message->enum_type_count() > 0) {
342           Abort("message with extend blocks shouldn't contain anything else");
343         }
344 
345         // Iterate over all fields in "extend" blocks.
346         for (int i = 0; i < message->extension_count(); ++i) {
347           const FieldDescriptor* extension = message->extension(i);
348 
349           // Protoc plugin API does not group fields in "extend" blocks.
350           // As the support for extensions in protozero is limited, the code
351           // assumes that extend blocks are located inside a wrapper message and
352           // name of this message is used to group them.
353           std::string extension_name = extension->extension_scope()->name();
354           extensions_[extension_name].push_back(extension);
355         }
356       } else {
357         messages_.push_back(message);
358         for (int i = 0; i < message->nested_type_count(); ++i) {
359           stack.push_back(message->nested_type(i));
360           // Emit a forward declaration of nested message types, as the outer
361           // class will refer to them when creating type aliases.
362           referenced_messages_.insert(message->nested_type(i));
363         }
364       }
365     }
366 
367     // Collect enums.
368     for (int i = 0; i < source_->enum_type_count(); ++i)
369       enums_.push_back(source_->enum_type(i));
370 
371     if (source_->extension_count() > 0)
372       Abort("top-level extension blocks are not supported");
373 
374     for (const Descriptor* message : messages_) {
375       for (int i = 0; i < message->enum_type_count(); ++i) {
376         enums_.push_back(message->enum_type(i));
377       }
378     }
379   }
380 
CollectDependencies()381   void CollectDependencies() {
382     // Public import basically means that callers only need to import this
383     // proto in order to use the stuff publicly imported by this proto.
384     for (int i = 0; i < source_->public_dependency_count(); ++i)
385       public_imports_.insert(source_->public_dependency(i));
386 
387     if (source_->weak_dependency_count() > 0)
388       Abort("Weak imports are not supported.");
389 
390     // Validations. Collect public imports (of collected imports) in DFS order.
391     // Visibilty for current proto:
392     // - all imports listed in current proto,
393     // - public imports of everything imported (recursive).
394     std::vector<const FileDescriptor*> stack;
395     for (int i = 0; i < source_->dependency_count(); ++i) {
396       const FileDescriptor* import = source_->dependency(i);
397       stack.push_back(import);
398       if (public_imports_.count(import) == 0) {
399         private_imports_.insert(import);
400       }
401     }
402 
403     while (!stack.empty()) {
404       const FileDescriptor* import = stack.back();
405       stack.pop_back();
406       // Having imports under different packages leads to unnecessary
407       // complexity with namespaces.
408       if (import->package() != package_)
409         Abort("Imported proto must be in the same package.");
410 
411       for (int i = 0; i < import->public_dependency_count(); ++i) {
412         stack.push_back(import->public_dependency(i));
413       }
414     }
415 
416     // Collect descriptors of messages and enums used in current proto.
417     // It will be used to generate necessary forward declarations and
418     // check that everything lays in the same namespace.
419     for (const Descriptor* message : messages_) {
420       for (int i = 0; i < message->field_count(); ++i) {
421         const FieldDescriptor* field = message->field(i);
422 
423         if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
424           if (public_imports_.count(field->message_type()->file()) == 0) {
425             // Avoid multiple forward declarations since
426             // public imports have been already included.
427             referenced_messages_.insert(field->message_type());
428           }
429         } else if (field->type() == FieldDescriptor::TYPE_ENUM) {
430           if (public_imports_.count(field->enum_type()->file()) == 0) {
431             referenced_enums_.insert(field->enum_type());
432           }
433         }
434       }
435     }
436   }
437 
Preprocess()438   void Preprocess() {
439     // Package name maps to a series of namespaces.
440     package_ = source_->package();
441     namespaces_ = SplitString(package_, ".");
442     if (!wrapper_namespace_.empty())
443       namespaces_.push_back(wrapper_namespace_);
444 
445     full_namespace_prefix_ = "::";
446     for (const std::string& ns : namespaces_)
447       full_namespace_prefix_ += ns + "::";
448 
449     CollectDescriptors();
450     CollectDependencies();
451   }
452 
453   // Print top header, namespaces and forward declarations.
GeneratePrologue()454   void GeneratePrologue() {
455     std::string greeting =
456         "// Autogenerated by the ProtoZero compiler plugin. DO NOT EDIT.\n";
457     std::string guard = package_ + "_" + source_->name() + "_H_";
458     guard = ToUpper(guard);
459     guard = StripChars(guard, ".-/\\", '_');
460 
461     stub_h_->Print(
462         "$greeting$\n"
463         "#ifndef $guard$\n"
464         "#define $guard$\n\n"
465         "#include <stddef.h>\n"
466         "#include <stdint.h>\n\n"
467         "#include \"perfetto/protozero/field_writer.h\"\n"
468         "#include \"perfetto/protozero/message.h\"\n"
469         "#include \"perfetto/protozero/packed_repeated_fields.h\"\n"
470         "#include \"perfetto/protozero/proto_decoder.h\"\n"
471         "#include \"perfetto/protozero/proto_utils.h\"\n",
472         "greeting", greeting, "guard", guard);
473 
474     // Print includes for public imports.
475     for (const FileDescriptor* dependency : public_imports_) {
476       // Dependency name could contain slashes but importing from upper-level
477       // directories is not possible anyway since build system processes each
478       // proto file individually. Hence proto lookup path is always equal to the
479       // directory where particular proto file is located and protoc does not
480       // allow reference to upper directory (aka ..) in import path.
481       //
482       // Laconically said:
483       // - source_->name() may never have slashes,
484       // - dependency->name() may have slashes but always refers to inner path.
485       stub_h_->Print("#include \"$name$.h\"\n", "name",
486                      ProtoStubName(dependency));
487     }
488     stub_h_->Print("\n");
489 
490     // Print namespaces.
491     for (const std::string& ns : namespaces_) {
492       stub_h_->Print("namespace $ns$ {\n", "ns", ns);
493     }
494     stub_h_->Print("\n");
495 
496     // Print forward declarations.
497     for (const Descriptor* message : referenced_messages_) {
498       stub_h_->Print("class $class$;\n", "class", GetCppClassName(message));
499     }
500     for (const EnumDescriptor* enumeration : referenced_enums_) {
501       stub_h_->Print("enum $class$ : int32_t;\n", "class",
502                      GetCppClassName(enumeration));
503     }
504     stub_h_->Print("\n");
505   }
506 
GenerateEnumDescriptor(const EnumDescriptor * enumeration)507   void GenerateEnumDescriptor(const EnumDescriptor* enumeration) {
508     stub_h_->Print("enum $class$ : int32_t {\n", "class",
509                    GetCppClassName(enumeration));
510     stub_h_->Indent();
511 
512     std::string value_name_prefix;
513     if (enumeration->containing_type() != nullptr)
514       value_name_prefix = GetCppClassName(enumeration) + "_";
515 
516     std::string min_name, max_name;
517     int min_val = std::numeric_limits<int>::max();
518     int max_val = -1;
519     for (int i = 0; i < enumeration->value_count(); ++i) {
520       const EnumValueDescriptor* value = enumeration->value(i);
521       stub_h_->Print("$name$ = $number$,\n", "name",
522                      value_name_prefix + value->name(), "number",
523                      std::to_string(value->number()));
524       if (value->number() < min_val) {
525         min_val = value->number();
526         min_name = value_name_prefix + value->name();
527       }
528       if (value->number() > max_val) {
529         max_val = value->number();
530         max_name = value_name_prefix + value->name();
531       }
532     }
533     stub_h_->Outdent();
534     stub_h_->Print("};\n\n");
535     stub_h_->Print("const $class$ $class$_MIN = $min$;\n", "class",
536                    GetCppClassName(enumeration), "min", min_name);
537     stub_h_->Print("const $class$ $class$_MAX = $max$;\n", "class",
538                    GetCppClassName(enumeration), "max", max_name);
539     stub_h_->Print("\n");
540   }
541 
542   // Packed repeated fields are encoded as a length-delimited field on the wire,
543   // where the payload is the concatenation of invidually encoded elements.
GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor * field)544   void GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor* field) {
545     std::map<std::string, std::string> setter;
546     setter["name"] = field->lowercase_name();
547     setter["field_metadata"] = GetFieldMetadataTypeName(field);
548     setter["action"] = "set";
549     setter["buffer_type"] = FieldTypeToPackedBufferType(field->type());
550     stub_h_->Print(
551         setter,
552         "void $action$_$name$(const $buffer_type$& packed_buffer) {\n"
553         "  AppendBytes($field_metadata$::kFieldId, packed_buffer.data(),\n"
554         "              packed_buffer.size());\n"
555         "}\n");
556   }
557 
GenerateSimpleFieldDescriptor(const FieldDescriptor * field)558   void GenerateSimpleFieldDescriptor(const FieldDescriptor* field) {
559     std::map<std::string, std::string> setter;
560     setter["id"] = std::to_string(field->number());
561     setter["name"] = field->lowercase_name();
562     setter["field_metadata"] = GetFieldMetadataTypeName(field);
563     setter["action"] = field->is_repeated() ? "add" : "set";
564     setter["cpp_type"] = FieldToCppTypeName(field);
565     setter["proto_field_type"] = FieldToProtoSchemaType(field);
566 
567     const char* code_stub =
568         "void $action$_$name$($cpp_type$ value) {\n"
569         "  static constexpr uint32_t field_id = $field_metadata$::kFieldId;\n"
570         "  // Call the appropriate protozero::Message::Append(field_id, ...)\n"
571         "  // method based on the type of the field.\n"
572         "  ::protozero::internal::FieldWriter<\n"
573         "    ::protozero::proto_utils::ProtoSchemaType::$proto_field_type$>\n"
574         "      ::Append(*this, field_id, value);\n"
575         "}\n";
576 
577     if (field->type() == FieldDescriptor::TYPE_STRING) {
578       // Strings and bytes should have an additional accessor which specifies
579       // the length explicitly.
580       const char* additional_method =
581           "void $action$_$name$(const char* data, size_t size) {\n"
582           "  AppendBytes($field_metadata$::kFieldId, data, size);\n"
583           "}\n";
584       stub_h_->Print(setter, additional_method);
585     } else if (field->type() == FieldDescriptor::TYPE_BYTES) {
586       const char* additional_method =
587           "void $action$_$name$(const uint8_t* data, size_t size) {\n"
588           "  AppendBytes($field_metadata$::kFieldId, data, size);\n"
589           "}\n";
590       stub_h_->Print(setter, additional_method);
591     } else if (field->type() == FieldDescriptor::TYPE_GROUP ||
592                field->type() == FieldDescriptor::TYPE_MESSAGE) {
593       Abort("Unsupported field type.");
594       return;
595     }
596 
597     stub_h_->Print(setter, code_stub);
598   }
599 
GenerateNestedMessageFieldDescriptor(const FieldDescriptor * field)600   void GenerateNestedMessageFieldDescriptor(const FieldDescriptor* field) {
601     std::string action = field->is_repeated() ? "add" : "set";
602     std::string inner_class = GetCppClassName(field->message_type());
603     stub_h_->Print(
604         "template <typename T = $inner_class$> T* $action$_$name$() {\n"
605         "  return BeginNestedMessage<T>($id$);\n"
606         "}\n\n",
607         "id", std::to_string(field->number()), "name", field->lowercase_name(),
608         "action", action, "inner_class", inner_class);
609     if (field->options().lazy()) {
610       stub_h_->Print(
611           "void $action$_$name$_raw(const std::string& raw) {\n"
612           "  return AppendBytes($id$, raw.data(), raw.size());\n"
613           "}\n\n",
614           "id", std::to_string(field->number()), "name",
615           field->lowercase_name(), "action", action, "inner_class",
616           inner_class);
617     }
618   }
619 
GenerateDecoder(const Descriptor * message)620   void GenerateDecoder(const Descriptor* message) {
621     int max_field_id = 0;
622     bool has_nonpacked_repeated_fields = false;
623     for (int i = 0; i < message->field_count(); ++i) {
624       const FieldDescriptor* field = message->field(i);
625       if (field->number() > kMaxDecoderFieldId)
626         continue;
627       max_field_id = std::max(max_field_id, field->number());
628       if (field->is_repeated() && !field->is_packed())
629         has_nonpacked_repeated_fields = true;
630     }
631 
632     std::string class_name = GetCppClassName(message) + "_Decoder";
633     stub_h_->Print(
634         "class $name$ : public "
635         "::protozero::TypedProtoDecoder</*MAX_FIELD_ID=*/$max$, "
636         "/*HAS_NONPACKED_REPEATED_FIELDS=*/$rep$> {\n",
637         "name", class_name, "max", std::to_string(max_field_id), "rep",
638         has_nonpacked_repeated_fields ? "true" : "false");
639     stub_h_->Print(" public:\n");
640     stub_h_->Indent();
641     stub_h_->Print(
642         "$name$(const uint8_t* data, size_t len) "
643         ": TypedProtoDecoder(data, len) {}\n",
644         "name", class_name);
645     stub_h_->Print(
646         "explicit $name$(const std::string& raw) : "
647         "TypedProtoDecoder(reinterpret_cast<const uint8_t*>(raw.data()), "
648         "raw.size()) {}\n",
649         "name", class_name);
650     stub_h_->Print(
651         "explicit $name$(const ::protozero::ConstBytes& raw) : "
652         "TypedProtoDecoder(raw.data, raw.size) {}\n",
653         "name", class_name);
654 
655     for (int i = 0; i < message->field_count(); ++i) {
656       const FieldDescriptor* field = message->field(i);
657       if (field->number() > max_field_id) {
658         stub_h_->Print("// field $name$ omitted because its id is too high\n",
659                        "name", field->name());
660         continue;
661       }
662       std::string getter;
663       std::string cpp_type;
664       switch (field->type()) {
665         case FieldDescriptor::TYPE_BOOL:
666           getter = "as_bool";
667           cpp_type = "bool";
668           break;
669         case FieldDescriptor::TYPE_SFIXED32:
670         case FieldDescriptor::TYPE_SINT32:
671         case FieldDescriptor::TYPE_INT32:
672           getter = "as_int32";
673           cpp_type = "int32_t";
674           break;
675         case FieldDescriptor::TYPE_SFIXED64:
676         case FieldDescriptor::TYPE_SINT64:
677         case FieldDescriptor::TYPE_INT64:
678           getter = "as_int64";
679           cpp_type = "int64_t";
680           break;
681         case FieldDescriptor::TYPE_FIXED32:
682         case FieldDescriptor::TYPE_UINT32:
683           getter = "as_uint32";
684           cpp_type = "uint32_t";
685           break;
686         case FieldDescriptor::TYPE_FIXED64:
687         case FieldDescriptor::TYPE_UINT64:
688           getter = "as_uint64";
689           cpp_type = "uint64_t";
690           break;
691         case FieldDescriptor::TYPE_FLOAT:
692           getter = "as_float";
693           cpp_type = "float";
694           break;
695         case FieldDescriptor::TYPE_DOUBLE:
696           getter = "as_double";
697           cpp_type = "double";
698           break;
699         case FieldDescriptor::TYPE_ENUM:
700           getter = "as_int32";
701           cpp_type = "int32_t";
702           break;
703         case FieldDescriptor::TYPE_STRING:
704           getter = "as_string";
705           cpp_type = "::protozero::ConstChars";
706           break;
707         case FieldDescriptor::TYPE_MESSAGE:
708         case FieldDescriptor::TYPE_BYTES:
709           getter = "as_bytes";
710           cpp_type = "::protozero::ConstBytes";
711           break;
712         case FieldDescriptor::TYPE_GROUP:
713           continue;
714       }
715 
716       stub_h_->Print("bool has_$name$() const { return at<$id$>().valid(); }\n",
717                      "name", field->lowercase_name(), "id",
718                      std::to_string(field->number()));
719 
720       if (field->is_packed()) {
721         const char* protozero_wire_type =
722             FieldTypeToProtozeroWireType(field->type());
723         stub_h_->Print(
724             "::protozero::PackedRepeatedFieldIterator<$wire_type$, $cpp_type$> "
725             "$name$(bool* parse_error_ptr) const { return "
726             "GetPackedRepeated<$wire_type$, $cpp_type$>($id$, "
727             "parse_error_ptr); }\n",
728             "wire_type", protozero_wire_type, "cpp_type", cpp_type, "name",
729             field->lowercase_name(), "id", std::to_string(field->number()));
730       } else if (field->is_repeated()) {
731         stub_h_->Print(
732             "::protozero::RepeatedFieldIterator<$cpp_type$> $name$() const { "
733             "return "
734             "GetRepeated<$cpp_type$>($id$); }\n",
735             "name", field->lowercase_name(), "cpp_type", cpp_type, "id",
736             std::to_string(field->number()));
737       } else {
738         stub_h_->Print(
739             "$cpp_type$ $name$() const { return at<$id$>().$getter$(); }\n",
740             "name", field->lowercase_name(), "id",
741             std::to_string(field->number()), "cpp_type", cpp_type, "getter",
742             getter);
743       }
744     }
745     stub_h_->Outdent();
746     stub_h_->Print("};\n\n");
747   }
748 
GenerateConstantsForMessageFields(const Descriptor * message)749   void GenerateConstantsForMessageFields(const Descriptor* message) {
750     const bool has_fields = (message->field_count() > 0);
751 
752     // Field number constants.
753     if (has_fields) {
754       stub_h_->Print("enum : int32_t {\n");
755       stub_h_->Indent();
756 
757       for (int i = 0; i < message->field_count(); ++i) {
758         const FieldDescriptor* field = message->field(i);
759         stub_h_->Print("$name$ = $id$,\n", "name",
760                        GetFieldNumberConstant(field), "id",
761                        std::to_string(field->number()));
762       }
763       stub_h_->Outdent();
764       stub_h_->Print("};\n");
765     }
766   }
767 
GenerateMessageDescriptor(const Descriptor * message)768   void GenerateMessageDescriptor(const Descriptor* message) {
769     GenerateDecoder(message);
770 
771     stub_h_->Print(
772         "class $name$ : public ::protozero::Message {\n"
773         " public:\n",
774         "name", GetCppClassName(message));
775     stub_h_->Indent();
776 
777     stub_h_->Print("using Decoder = $name$_Decoder;\n", "name",
778                    GetCppClassName(message));
779 
780     GenerateConstantsForMessageFields(message);
781 
782     // Using statements for nested messages.
783     for (int i = 0; i < message->nested_type_count(); ++i) {
784       const Descriptor* nested_message = message->nested_type(i);
785       stub_h_->Print("using $local_name$ = $global_name$;\n", "local_name",
786                      nested_message->name(), "global_name",
787                      GetCppClassName(nested_message, true));
788     }
789 
790     // Using statements for nested enums.
791     for (int i = 0; i < message->enum_type_count(); ++i) {
792       const EnumDescriptor* nested_enum = message->enum_type(i);
793       stub_h_->Print("using $local_name$ = $global_name$;\n", "local_name",
794                      nested_enum->name(), "global_name",
795                      GetCppClassName(nested_enum, true));
796     }
797 
798     // Values of nested enums.
799     for (int i = 0; i < message->enum_type_count(); ++i) {
800       const EnumDescriptor* nested_enum = message->enum_type(i);
801       std::string value_name_prefix = GetCppClassName(nested_enum) + "_";
802 
803       for (int j = 0; j < nested_enum->value_count(); ++j) {
804         const EnumValueDescriptor* value = nested_enum->value(j);
805         stub_h_->Print("static const $class$ $name$ = $full_name$;\n", "class",
806                        nested_enum->name(), "name", value->name(), "full_name",
807                        value_name_prefix + value->name());
808       }
809     }
810 
811     // Field descriptors.
812     for (int i = 0; i < message->field_count(); ++i) {
813       GenerateFieldDescriptor(GetCppClassName(message), message->field(i));
814     }
815 
816     stub_h_->Outdent();
817     stub_h_->Print("};\n\n");
818   }
819 
GetFieldMetadataTypeName(const FieldDescriptor * field)820   std::string GetFieldMetadataTypeName(const FieldDescriptor* field) {
821     std::string name = field->camelcase_name();
822     if (isalpha(name[0]))
823       name[0] = static_cast<char>(toupper(name[0]));
824     return "FieldMetadata_" + name;
825   }
826 
GetFieldMetadataVariableName(const FieldDescriptor * field)827   std::string GetFieldMetadataVariableName(const FieldDescriptor* field) {
828     std::string name = field->camelcase_name();
829     if (isalpha(name[0]))
830       name[0] = static_cast<char>(toupper(name[0]));
831     return "k" + name;
832   }
833 
GenerateFieldMetadata(const std::string & message_cpp_type,const FieldDescriptor * field)834   void GenerateFieldMetadata(const std::string& message_cpp_type,
835                              const FieldDescriptor* field) {
836     const char* code_stub = R"(
837 using $field_metadata_type$ =
838   ::protozero::proto_utils::FieldMetadata<
839     $field_id$,
840     ::protozero::proto_utils::RepetitionType::$repetition_type$,
841     ::protozero::proto_utils::ProtoSchemaType::$proto_field_type$,
842     $cpp_type$,
843     $message_cpp_type$>;
844 
845 // Ceci n'est pas une pipe.
846 // This is actually a variable of FieldMetadataHelper<FieldMetadata<...>>
847 // type (and users are expected to use it as such, hence kCamelCase name).
848 // It is declared as a function to keep protozero bindings header-only as
849 // inline constexpr variables are not available until C++17 (while inline
850 // functions are).
851 // TODO(altimin): Use inline variable instead after adopting C++17.
852 static constexpr $field_metadata_type$ $field_metadata_var$() { return {}; }
853 )";
854 
855     stub_h_->Print(code_stub, "field_id", std::to_string(field->number()),
856                    "repetition_type", FieldToRepetitionType(field),
857                    "proto_field_type", FieldToProtoSchemaType(field),
858                    "cpp_type", FieldToCppTypeName(field), "message_cpp_type",
859                    message_cpp_type, "field_metadata_type",
860                    GetFieldMetadataTypeName(field), "field_metadata_var",
861                    GetFieldMetadataVariableName(field));
862   }
863 
GenerateFieldDescriptor(const std::string & message_cpp_type,const FieldDescriptor * field)864   void GenerateFieldDescriptor(const std::string& message_cpp_type,
865                                const FieldDescriptor* field) {
866     GenerateFieldMetadata(message_cpp_type, field);
867     if (field->is_packed()) {
868       GeneratePackedRepeatedFieldDescriptor(field);
869     } else if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
870       GenerateSimpleFieldDescriptor(field);
871     } else {
872       GenerateNestedMessageFieldDescriptor(field);
873     }
874   }
875 
876   // Generate extension class for a group of FieldDescriptor instances
877   // representing one "extend" block in proto definition. For example:
878   //
879   //   message SpecificExtension {
880   //     extend GeneralThing {
881   //       optional Fizz fizz = 101;
882   //       optional Buzz buzz = 102;
883   //     }
884   //   }
885   //
  // These fields are passed as a vector of two elements, "fizz" and
  // "buzz". The wrapping message supplies the name of the generated
  // extension class.
889   //
890   // In the example above, generated code is going to look like:
891   //
  //   class SpecificExtension : public GeneralThing {
  //    public:
  //     Fizz* set_fizz();
  //     Buzz* set_buzz();
  //   };
GenerateExtension(const std::string & extension_name,const std::vector<const FieldDescriptor * > & descriptors)896   void GenerateExtension(
897       const std::string& extension_name,
898       const std::vector<const FieldDescriptor*>& descriptors) {
899     // Use an arbitrary descriptor in order to get generic information not
900     // specific to any of them.
901     const FieldDescriptor* descriptor = descriptors[0];
902     const Descriptor* base_message = descriptor->containing_type();
903 
904     // TODO(ddrone): ensure that this code works when containing_type located in
905     // other file or namespace.
906     stub_h_->Print("class $name$ : public $extendee$ {\n", "name",
907                    extension_name, "extendee",
908                    GetCppClassName(base_message, /*full=*/true));
909     stub_h_->Print(" public:\n");
910     stub_h_->Indent();
911     for (const FieldDescriptor* field : descriptors) {
912       if (field->containing_type() != base_message) {
913         Abort("one wrapper should extend only one message");
914         return;
915       }
916       GenerateFieldDescriptor(extension_name, field);
917     }
918     stub_h_->Outdent();
919     stub_h_->Print("};\n");
920   }
921 
GenerateEpilogue()922   void GenerateEpilogue() {
923     for (unsigned i = 0; i < namespaces_.size(); ++i) {
924       stub_h_->Print("} // Namespace.\n");
925     }
926     stub_h_->Print("#endif  // Include guard.\n");
927   }
928 
929   const FileDescriptor* const source_;
930   Printer* const stub_h_;
931   std::string error_;
932 
933   std::string package_;
934   std::string wrapper_namespace_;
935   std::vector<std::string> namespaces_;
936   std::string full_namespace_prefix_;
937   std::vector<const Descriptor*> messages_;
938   std::vector<const EnumDescriptor*> enums_;
939   std::map<std::string, std::vector<const FieldDescriptor*>> extensions_;
940 
941   // The custom *Comp comparators are to ensure determinism of the generator.
942   std::set<const FileDescriptor*, FileDescriptorComp> public_imports_;
943   std::set<const FileDescriptor*, FileDescriptorComp> private_imports_;
944   std::set<const Descriptor*, DescriptorComp> referenced_messages_;
945   std::set<const EnumDescriptor*, EnumDescriptorComp> referenced_enums_;
946 };
947 
948 class ProtoZeroGenerator : public ::google::protobuf::compiler::CodeGenerator {
949  public:
950   explicit ProtoZeroGenerator();
951   ~ProtoZeroGenerator() override;
952 
953   // CodeGenerator implementation
954   bool Generate(const google::protobuf::FileDescriptor* file,
955                 const std::string& options,
956                 GeneratorContext* context,
957                 std::string* error) const override;
958 };
959 
ProtoZeroGenerator()960 ProtoZeroGenerator::ProtoZeroGenerator() {}
961 
~ProtoZeroGenerator()962 ProtoZeroGenerator::~ProtoZeroGenerator() {}
963 
Generate(const FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const964 bool ProtoZeroGenerator::Generate(const FileDescriptor* file,
965                                   const std::string& options,
966                                   GeneratorContext* context,
967                                   std::string* error) const {
968   const std::unique_ptr<ZeroCopyOutputStream> stub_h_file_stream(
969       context->Open(ProtoStubName(file) + ".h"));
970   const std::unique_ptr<ZeroCopyOutputStream> stub_cc_file_stream(
971       context->Open(ProtoStubName(file) + ".cc"));
972 
973   // Variables are delimited by $.
974   Printer stub_h_printer(stub_h_file_stream.get(), '$');
975   GeneratorJob job(file, &stub_h_printer);
976 
977   Printer stub_cc_printer(stub_cc_file_stream.get(), '$');
978   stub_cc_printer.Print("// Intentionally empty (crbug.com/998165)\n");
979 
980   // Parse additional options.
981   for (const std::string& option : SplitString(options, ",")) {
982     std::vector<std::string> option_pair = SplitString(option, "=");
983     job.SetOption(option_pair[0], option_pair[1]);
984   }
985 
986   if (!job.GenerateStubs()) {
987     *error = job.GetFirstError();
988     return false;
989   }
990   return true;
991 }
992 
993 }  // namespace
994 }  // namespace protozero
995 
main(int argc,char * argv[])996 int main(int argc, char* argv[]) {
997   ::protozero::ProtoZeroGenerator generator;
998   return google::protobuf::compiler::PluginMain(argc, argv, &generator);
999 }
1000