1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <stdlib.h>
18
19 #include <limits>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24
25 #include <google/protobuf/compiler/code_generator.h>
26 #include <google/protobuf/compiler/plugin.h>
27 #include <google/protobuf/descriptor.h>
28 #include <google/protobuf/descriptor.pb.h>
29 #include <google/protobuf/io/printer.h>
30 #include <google/protobuf/io/zero_copy_stream.h>
31
32 #include "perfetto/ext/base/string_utils.h"
33
34 namespace protozero {
35 namespace {
36
37 using google::protobuf::Descriptor;
38 using google::protobuf::EnumDescriptor;
39 using google::protobuf::EnumValueDescriptor;
40 using google::protobuf::FieldDescriptor;
41 using google::protobuf::FileDescriptor;
42 using google::protobuf::compiler::GeneratorContext;
43 using google::protobuf::io::Printer;
44 using google::protobuf::io::ZeroCopyOutputStream;
45 using perfetto::base::SplitString;
46 using perfetto::base::StripChars;
47 using perfetto::base::StripPrefix;
48 using perfetto::base::StripSuffix;
49 using perfetto::base::ToUpper;
50 using perfetto::base::Uppercase;
51
// Largest field id for which the generated *_Decoder classes get accessors;
// fields with higher ids are skipped by GenerateDecoder(). This must mirror
// ProtoDecoder::kMaxDecoderFieldId: if the two values diverge, the generated
// pbzero.h files stop compiling because they hit the static_assert in at().
// Duplicated on purpose rather than pulling in an extra dependency.
constexpr int kMaxDecoderFieldId{999};
56
// Aborts the process when |condition| does not hold. Unlike <cassert>'s
// assert(), this check has no NDEBUG guard, so it is always active.
void Assert(bool condition) {
  if (condition)
    return;
  abort();
}
61
62 struct FileDescriptorComp {
operator ()protozero::__anon58692eed0111::FileDescriptorComp63 bool operator()(const FileDescriptor* lhs, const FileDescriptor* rhs) const {
64 int comp = lhs->name().compare(rhs->name());
65 Assert(comp != 0 || lhs == rhs);
66 return comp < 0;
67 }
68 };
69
70 struct DescriptorComp {
operator ()protozero::__anon58692eed0111::DescriptorComp71 bool operator()(const Descriptor* lhs, const Descriptor* rhs) const {
72 int comp = lhs->full_name().compare(rhs->full_name());
73 Assert(comp != 0 || lhs == rhs);
74 return comp < 0;
75 }
76 };
77
78 struct EnumDescriptorComp {
operator ()protozero::__anon58692eed0111::EnumDescriptorComp79 bool operator()(const EnumDescriptor* lhs, const EnumDescriptor* rhs) const {
80 int comp = lhs->full_name().compare(rhs->full_name());
81 Assert(comp != 0 || lhs == rhs);
82 return comp < 0;
83 }
84 };
85
ProtoStubName(const FileDescriptor * proto)86 inline std::string ProtoStubName(const FileDescriptor* proto) {
87 return StripSuffix(proto->name(), ".proto") + ".pbzero";
88 }
89
90 class GeneratorJob {
91 public:
GeneratorJob(const FileDescriptor * file,Printer * stub_h_printer)92 GeneratorJob(const FileDescriptor* file, Printer* stub_h_printer)
93 : source_(file), stub_h_(stub_h_printer) {}
94
GenerateStubs()95 bool GenerateStubs() {
96 Preprocess();
97 GeneratePrologue();
98 for (const EnumDescriptor* enumeration : enums_)
99 GenerateEnumDescriptor(enumeration);
100 for (const Descriptor* message : messages_)
101 GenerateMessageDescriptor(message);
102 for (const auto& key_value : extensions_)
103 GenerateExtension(key_value.first, key_value.second);
104 GenerateEpilogue();
105 return error_.empty();
106 }
107
SetOption(const std::string & name,const std::string & value)108 void SetOption(const std::string& name, const std::string& value) {
109 if (name == "wrapper_namespace") {
110 wrapper_namespace_ = value;
111 } else {
112 Abort(std::string() + "Unknown plugin option '" + name + "'.");
113 }
114 }
115
116 // If generator fails to produce stubs for a particular proto definitions
117 // it finishes with undefined output and writes the first error occured.
GetFirstError() const118 const std::string& GetFirstError() const { return error_; }
119
120 private:
121 // Only the first error will be recorded.
Abort(const std::string & reason)122 void Abort(const std::string& reason) {
123 if (error_.empty())
124 error_ = reason;
125 }
126
127 // Get full name (including outer descriptors) of proto descriptor.
128 template <class T>
GetDescriptorName(const T * descriptor)129 inline std::string GetDescriptorName(const T* descriptor) {
130 if (!package_.empty()) {
131 return StripPrefix(descriptor->full_name(), package_ + ".");
132 } else {
133 return descriptor->full_name();
134 }
135 }
136
137 // Get C++ class name corresponding to proto descriptor.
138 // Nested names are splitted by underscores. Underscores in type names aren't
139 // prohibited but not recommended in order to avoid name collisions.
140 template <class T>
GetCppClassName(const T * descriptor,bool full=false)141 inline std::string GetCppClassName(const T* descriptor, bool full = false) {
142 std::string name = StripChars(GetDescriptorName(descriptor), ".", '_');
143 if (full)
144 name = full_namespace_prefix_ + name;
145 return name;
146 }
147
GetFieldNumberConstant(const FieldDescriptor * field)148 inline std::string GetFieldNumberConstant(const FieldDescriptor* field) {
149 std::string name = field->camelcase_name();
150 if (!name.empty()) {
151 name.at(0) = Uppercase(name.at(0));
152 name = "k" + name + "FieldNumber";
153 } else {
154 // Protoc allows fields like 'bool _ = 1'.
155 Abort("Empty field name in camel case notation.");
156 }
157 return name;
158 }
159
160 // Note: intentionally avoiding depending on protozero sources, as well as
161 // protobuf-internal WireFormat/WireFormatLite classes.
FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type)162 const char* FieldTypeToProtozeroWireType(FieldDescriptor::Type proto_type) {
163 switch (proto_type) {
164 case FieldDescriptor::TYPE_INT64:
165 case FieldDescriptor::TYPE_UINT64:
166 case FieldDescriptor::TYPE_INT32:
167 case FieldDescriptor::TYPE_BOOL:
168 case FieldDescriptor::TYPE_UINT32:
169 case FieldDescriptor::TYPE_ENUM:
170 case FieldDescriptor::TYPE_SINT32:
171 case FieldDescriptor::TYPE_SINT64:
172 return "::protozero::proto_utils::ProtoWireType::kVarInt";
173
174 case FieldDescriptor::TYPE_FIXED32:
175 case FieldDescriptor::TYPE_SFIXED32:
176 case FieldDescriptor::TYPE_FLOAT:
177 return "::protozero::proto_utils::ProtoWireType::kFixed32";
178
179 case FieldDescriptor::TYPE_FIXED64:
180 case FieldDescriptor::TYPE_SFIXED64:
181 case FieldDescriptor::TYPE_DOUBLE:
182 return "::protozero::proto_utils::ProtoWireType::kFixed64";
183
184 case FieldDescriptor::TYPE_STRING:
185 case FieldDescriptor::TYPE_MESSAGE:
186 case FieldDescriptor::TYPE_BYTES:
187 return "::protozero::proto_utils::ProtoWireType::kLengthDelimited";
188
189 case FieldDescriptor::TYPE_GROUP:
190 Abort("Groups not supported.");
191 }
192 Abort("Unrecognized FieldDescriptor::Type.");
193 return "";
194 }
195
FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type)196 const char* FieldTypeToPackedBufferType(FieldDescriptor::Type proto_type) {
197 switch (proto_type) {
198 case FieldDescriptor::TYPE_INT64:
199 case FieldDescriptor::TYPE_UINT64:
200 case FieldDescriptor::TYPE_INT32:
201 case FieldDescriptor::TYPE_BOOL:
202 case FieldDescriptor::TYPE_UINT32:
203 case FieldDescriptor::TYPE_ENUM:
204 case FieldDescriptor::TYPE_SINT32:
205 case FieldDescriptor::TYPE_SINT64:
206 return "::protozero::PackedVarInt";
207
208 case FieldDescriptor::TYPE_FIXED32:
209 return "::protozero::PackedFixedSizeInt<uint32_t>";
210 case FieldDescriptor::TYPE_SFIXED32:
211 return "::protozero::PackedFixedSizeInt<int32_t>";
212 case FieldDescriptor::TYPE_FLOAT:
213 return "::protozero::PackedFixedSizeInt<float>";
214
215 case FieldDescriptor::TYPE_FIXED64:
216 return "::protozero::PackedFixedSizeInt<uint64_t>";
217 case FieldDescriptor::TYPE_SFIXED64:
218 return "::protozero::PackedFixedSizeInt<int64_t>";
219 case FieldDescriptor::TYPE_DOUBLE:
220 return "::protozero::PackedFixedSizeInt<double>";
221
222 case FieldDescriptor::TYPE_STRING:
223 case FieldDescriptor::TYPE_MESSAGE:
224 case FieldDescriptor::TYPE_BYTES:
225 case FieldDescriptor::TYPE_GROUP:
226 Abort("Unexpected FieldDescritor::Type.");
227 }
228 Abort("Unrecognized FieldDescriptor::Type.");
229 return "";
230 }
231
FieldToProtoSchemaType(const FieldDescriptor * field)232 const char* FieldToProtoSchemaType(const FieldDescriptor* field) {
233 switch (field->type()) {
234 case FieldDescriptor::TYPE_BOOL:
235 return "kBool";
236 case FieldDescriptor::TYPE_INT32:
237 return "kInt32";
238 case FieldDescriptor::TYPE_INT64:
239 return "kInt64";
240 case FieldDescriptor::TYPE_UINT32:
241 return "kUint32";
242 case FieldDescriptor::TYPE_UINT64:
243 return "kUint64";
244 case FieldDescriptor::TYPE_SINT32:
245 return "kSint32";
246 case FieldDescriptor::TYPE_SINT64:
247 return "kSint64";
248 case FieldDescriptor::TYPE_FIXED32:
249 return "kFixed32";
250 case FieldDescriptor::TYPE_FIXED64:
251 return "kFixed64";
252 case FieldDescriptor::TYPE_SFIXED32:
253 return "kSfixed32";
254 case FieldDescriptor::TYPE_SFIXED64:
255 return "kSfixed64";
256 case FieldDescriptor::TYPE_FLOAT:
257 return "kFloat";
258 case FieldDescriptor::TYPE_DOUBLE:
259 return "kDouble";
260 case FieldDescriptor::TYPE_ENUM:
261 return "kEnum";
262 case FieldDescriptor::TYPE_STRING:
263 return "kString";
264 case FieldDescriptor::TYPE_MESSAGE:
265 return "kMessage";
266 case FieldDescriptor::TYPE_BYTES:
267 return "kBytes";
268
269 case FieldDescriptor::TYPE_GROUP:
270 Abort("Groups not supported.");
271 return "";
272 }
273 Abort("Unrecognized FieldDescriptor::Type.");
274 return "";
275 }
276
FieldToCppTypeName(const FieldDescriptor * field)277 std::string FieldToCppTypeName(const FieldDescriptor* field) {
278 switch (field->type()) {
279 case FieldDescriptor::TYPE_BOOL:
280 return "bool";
281 case FieldDescriptor::TYPE_INT32:
282 return "int32_t";
283 case FieldDescriptor::TYPE_INT64:
284 return "int64_t";
285 case FieldDescriptor::TYPE_UINT32:
286 return "uint32_t";
287 case FieldDescriptor::TYPE_UINT64:
288 return "uint64_t";
289 case FieldDescriptor::TYPE_SINT32:
290 return "int32_t";
291 case FieldDescriptor::TYPE_SINT64:
292 return "int64_t";
293 case FieldDescriptor::TYPE_FIXED32:
294 return "uint32_t";
295 case FieldDescriptor::TYPE_FIXED64:
296 return "uint64_t";
297 case FieldDescriptor::TYPE_SFIXED32:
298 return "int32_t";
299 case FieldDescriptor::TYPE_SFIXED64:
300 return "int64_t";
301 case FieldDescriptor::TYPE_FLOAT:
302 return "float";
303 case FieldDescriptor::TYPE_DOUBLE:
304 return "double";
305 case FieldDescriptor::TYPE_ENUM:
306 return GetCppClassName(field->enum_type(), true);
307 case FieldDescriptor::TYPE_STRING:
308 case FieldDescriptor::TYPE_BYTES:
309 return "std::string";
310 case FieldDescriptor::TYPE_MESSAGE:
311 return GetCppClassName(field->message_type());
312 case FieldDescriptor::TYPE_GROUP:
313 Abort("Groups not supported.");
314 return "";
315 }
316 Abort("Unrecognized FieldDescriptor::Type.");
317 return "";
318 }
319
FieldToRepetitionType(const FieldDescriptor * field)320 const char* FieldToRepetitionType(const FieldDescriptor* field) {
321 if (!field->is_repeated())
322 return "kNotRepeated";
323 if (field->is_packed())
324 return "kRepeatedPacked";
325 return "kRepeatedNotPacked";
326 }
327
CollectDescriptors()328 void CollectDescriptors() {
329 // Collect message descriptors in DFS order.
330 std::vector<const Descriptor*> stack;
331 stack.reserve(static_cast<size_t>(source_->message_type_count()));
332 for (int i = 0; i < source_->message_type_count(); ++i)
333 stack.push_back(source_->message_type(i));
334
335 while (!stack.empty()) {
336 const Descriptor* message = stack.back();
337 stack.pop_back();
338
339 if (message->extension_count() > 0) {
340 if (message->field_count() > 0 || message->nested_type_count() > 0 ||
341 message->enum_type_count() > 0) {
342 Abort("message with extend blocks shouldn't contain anything else");
343 }
344
345 // Iterate over all fields in "extend" blocks.
346 for (int i = 0; i < message->extension_count(); ++i) {
347 const FieldDescriptor* extension = message->extension(i);
348
349 // Protoc plugin API does not group fields in "extend" blocks.
350 // As the support for extensions in protozero is limited, the code
351 // assumes that extend blocks are located inside a wrapper message and
352 // name of this message is used to group them.
353 std::string extension_name = extension->extension_scope()->name();
354 extensions_[extension_name].push_back(extension);
355 }
356 } else {
357 messages_.push_back(message);
358 for (int i = 0; i < message->nested_type_count(); ++i) {
359 stack.push_back(message->nested_type(i));
360 // Emit a forward declaration of nested message types, as the outer
361 // class will refer to them when creating type aliases.
362 referenced_messages_.insert(message->nested_type(i));
363 }
364 }
365 }
366
367 // Collect enums.
368 for (int i = 0; i < source_->enum_type_count(); ++i)
369 enums_.push_back(source_->enum_type(i));
370
371 if (source_->extension_count() > 0)
372 Abort("top-level extension blocks are not supported");
373
374 for (const Descriptor* message : messages_) {
375 for (int i = 0; i < message->enum_type_count(); ++i) {
376 enums_.push_back(message->enum_type(i));
377 }
378 }
379 }
380
CollectDependencies()381 void CollectDependencies() {
382 // Public import basically means that callers only need to import this
383 // proto in order to use the stuff publicly imported by this proto.
384 for (int i = 0; i < source_->public_dependency_count(); ++i)
385 public_imports_.insert(source_->public_dependency(i));
386
387 if (source_->weak_dependency_count() > 0)
388 Abort("Weak imports are not supported.");
389
390 // Validations. Collect public imports (of collected imports) in DFS order.
391 // Visibilty for current proto:
392 // - all imports listed in current proto,
393 // - public imports of everything imported (recursive).
394 std::vector<const FileDescriptor*> stack;
395 for (int i = 0; i < source_->dependency_count(); ++i) {
396 const FileDescriptor* import = source_->dependency(i);
397 stack.push_back(import);
398 if (public_imports_.count(import) == 0) {
399 private_imports_.insert(import);
400 }
401 }
402
403 while (!stack.empty()) {
404 const FileDescriptor* import = stack.back();
405 stack.pop_back();
406 // Having imports under different packages leads to unnecessary
407 // complexity with namespaces.
408 if (import->package() != package_)
409 Abort("Imported proto must be in the same package.");
410
411 for (int i = 0; i < import->public_dependency_count(); ++i) {
412 stack.push_back(import->public_dependency(i));
413 }
414 }
415
416 // Collect descriptors of messages and enums used in current proto.
417 // It will be used to generate necessary forward declarations and
418 // check that everything lays in the same namespace.
419 for (const Descriptor* message : messages_) {
420 for (int i = 0; i < message->field_count(); ++i) {
421 const FieldDescriptor* field = message->field(i);
422
423 if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
424 if (public_imports_.count(field->message_type()->file()) == 0) {
425 // Avoid multiple forward declarations since
426 // public imports have been already included.
427 referenced_messages_.insert(field->message_type());
428 }
429 } else if (field->type() == FieldDescriptor::TYPE_ENUM) {
430 if (public_imports_.count(field->enum_type()->file()) == 0) {
431 referenced_enums_.insert(field->enum_type());
432 }
433 }
434 }
435 }
436 }
437
Preprocess()438 void Preprocess() {
439 // Package name maps to a series of namespaces.
440 package_ = source_->package();
441 namespaces_ = SplitString(package_, ".");
442 if (!wrapper_namespace_.empty())
443 namespaces_.push_back(wrapper_namespace_);
444
445 full_namespace_prefix_ = "::";
446 for (const std::string& ns : namespaces_)
447 full_namespace_prefix_ += ns + "::";
448
449 CollectDescriptors();
450 CollectDependencies();
451 }
452
453 // Print top header, namespaces and forward declarations.
GeneratePrologue()454 void GeneratePrologue() {
455 std::string greeting =
456 "// Autogenerated by the ProtoZero compiler plugin. DO NOT EDIT.\n";
457 std::string guard = package_ + "_" + source_->name() + "_H_";
458 guard = ToUpper(guard);
459 guard = StripChars(guard, ".-/\\", '_');
460
461 stub_h_->Print(
462 "$greeting$\n"
463 "#ifndef $guard$\n"
464 "#define $guard$\n\n"
465 "#include <stddef.h>\n"
466 "#include <stdint.h>\n\n"
467 "#include \"perfetto/protozero/field_writer.h\"\n"
468 "#include \"perfetto/protozero/message.h\"\n"
469 "#include \"perfetto/protozero/packed_repeated_fields.h\"\n"
470 "#include \"perfetto/protozero/proto_decoder.h\"\n"
471 "#include \"perfetto/protozero/proto_utils.h\"\n",
472 "greeting", greeting, "guard", guard);
473
474 // Print includes for public imports.
475 for (const FileDescriptor* dependency : public_imports_) {
476 // Dependency name could contain slashes but importing from upper-level
477 // directories is not possible anyway since build system processes each
478 // proto file individually. Hence proto lookup path is always equal to the
479 // directory where particular proto file is located and protoc does not
480 // allow reference to upper directory (aka ..) in import path.
481 //
482 // Laconically said:
483 // - source_->name() may never have slashes,
484 // - dependency->name() may have slashes but always refers to inner path.
485 stub_h_->Print("#include \"$name$.h\"\n", "name",
486 ProtoStubName(dependency));
487 }
488 stub_h_->Print("\n");
489
490 // Print namespaces.
491 for (const std::string& ns : namespaces_) {
492 stub_h_->Print("namespace $ns$ {\n", "ns", ns);
493 }
494 stub_h_->Print("\n");
495
496 // Print forward declarations.
497 for (const Descriptor* message : referenced_messages_) {
498 stub_h_->Print("class $class$;\n", "class", GetCppClassName(message));
499 }
500 for (const EnumDescriptor* enumeration : referenced_enums_) {
501 stub_h_->Print("enum $class$ : int32_t;\n", "class",
502 GetCppClassName(enumeration));
503 }
504 stub_h_->Print("\n");
505 }
506
GenerateEnumDescriptor(const EnumDescriptor * enumeration)507 void GenerateEnumDescriptor(const EnumDescriptor* enumeration) {
508 stub_h_->Print("enum $class$ : int32_t {\n", "class",
509 GetCppClassName(enumeration));
510 stub_h_->Indent();
511
512 std::string value_name_prefix;
513 if (enumeration->containing_type() != nullptr)
514 value_name_prefix = GetCppClassName(enumeration) + "_";
515
516 std::string min_name, max_name;
517 int min_val = std::numeric_limits<int>::max();
518 int max_val = -1;
519 for (int i = 0; i < enumeration->value_count(); ++i) {
520 const EnumValueDescriptor* value = enumeration->value(i);
521 stub_h_->Print("$name$ = $number$,\n", "name",
522 value_name_prefix + value->name(), "number",
523 std::to_string(value->number()));
524 if (value->number() < min_val) {
525 min_val = value->number();
526 min_name = value_name_prefix + value->name();
527 }
528 if (value->number() > max_val) {
529 max_val = value->number();
530 max_name = value_name_prefix + value->name();
531 }
532 }
533 stub_h_->Outdent();
534 stub_h_->Print("};\n\n");
535 stub_h_->Print("const $class$ $class$_MIN = $min$;\n", "class",
536 GetCppClassName(enumeration), "min", min_name);
537 stub_h_->Print("const $class$ $class$_MAX = $max$;\n", "class",
538 GetCppClassName(enumeration), "max", max_name);
539 stub_h_->Print("\n");
540 }
541
542 // Packed repeated fields are encoded as a length-delimited field on the wire,
543 // where the payload is the concatenation of invidually encoded elements.
GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor * field)544 void GeneratePackedRepeatedFieldDescriptor(const FieldDescriptor* field) {
545 std::map<std::string, std::string> setter;
546 setter["name"] = field->lowercase_name();
547 setter["field_metadata"] = GetFieldMetadataTypeName(field);
548 setter["action"] = "set";
549 setter["buffer_type"] = FieldTypeToPackedBufferType(field->type());
550 stub_h_->Print(
551 setter,
552 "void $action$_$name$(const $buffer_type$& packed_buffer) {\n"
553 " AppendBytes($field_metadata$::kFieldId, packed_buffer.data(),\n"
554 " packed_buffer.size());\n"
555 "}\n");
556 }
557
GenerateSimpleFieldDescriptor(const FieldDescriptor * field)558 void GenerateSimpleFieldDescriptor(const FieldDescriptor* field) {
559 std::map<std::string, std::string> setter;
560 setter["id"] = std::to_string(field->number());
561 setter["name"] = field->lowercase_name();
562 setter["field_metadata"] = GetFieldMetadataTypeName(field);
563 setter["action"] = field->is_repeated() ? "add" : "set";
564 setter["cpp_type"] = FieldToCppTypeName(field);
565 setter["proto_field_type"] = FieldToProtoSchemaType(field);
566
567 const char* code_stub =
568 "void $action$_$name$($cpp_type$ value) {\n"
569 " static constexpr uint32_t field_id = $field_metadata$::kFieldId;\n"
570 " // Call the appropriate protozero::Message::Append(field_id, ...)\n"
571 " // method based on the type of the field.\n"
572 " ::protozero::internal::FieldWriter<\n"
573 " ::protozero::proto_utils::ProtoSchemaType::$proto_field_type$>\n"
574 " ::Append(*this, field_id, value);\n"
575 "}\n";
576
577 if (field->type() == FieldDescriptor::TYPE_STRING) {
578 // Strings and bytes should have an additional accessor which specifies
579 // the length explicitly.
580 const char* additional_method =
581 "void $action$_$name$(const char* data, size_t size) {\n"
582 " AppendBytes($field_metadata$::kFieldId, data, size);\n"
583 "}\n";
584 stub_h_->Print(setter, additional_method);
585 } else if (field->type() == FieldDescriptor::TYPE_BYTES) {
586 const char* additional_method =
587 "void $action$_$name$(const uint8_t* data, size_t size) {\n"
588 " AppendBytes($field_metadata$::kFieldId, data, size);\n"
589 "}\n";
590 stub_h_->Print(setter, additional_method);
591 } else if (field->type() == FieldDescriptor::TYPE_GROUP ||
592 field->type() == FieldDescriptor::TYPE_MESSAGE) {
593 Abort("Unsupported field type.");
594 return;
595 }
596
597 stub_h_->Print(setter, code_stub);
598 }
599
GenerateNestedMessageFieldDescriptor(const FieldDescriptor * field)600 void GenerateNestedMessageFieldDescriptor(const FieldDescriptor* field) {
601 std::string action = field->is_repeated() ? "add" : "set";
602 std::string inner_class = GetCppClassName(field->message_type());
603 stub_h_->Print(
604 "template <typename T = $inner_class$> T* $action$_$name$() {\n"
605 " return BeginNestedMessage<T>($id$);\n"
606 "}\n\n",
607 "id", std::to_string(field->number()), "name", field->lowercase_name(),
608 "action", action, "inner_class", inner_class);
609 if (field->options().lazy()) {
610 stub_h_->Print(
611 "void $action$_$name$_raw(const std::string& raw) {\n"
612 " return AppendBytes($id$, raw.data(), raw.size());\n"
613 "}\n\n",
614 "id", std::to_string(field->number()), "name",
615 field->lowercase_name(), "action", action, "inner_class",
616 inner_class);
617 }
618 }
619
// Emits "<ClassName>_Decoder" for |message|: a ::protozero::TypedProtoDecoder
// subclass with constructors from raw bytes and, for every field whose id is
// at most kMaxDecoderFieldId (and that is not a group):
//  - has_<name>(),
//  - <name>(bool* parse_error_ptr) returning a PackedRepeatedFieldIterator
//    for packed fields, or
//  - <name>() returning a RepeatedFieldIterator for other repeated fields, or
//  - <name>() returning the single decoded value otherwise.
// Fields above the limit only get a comment in the generated class.
void GenerateDecoder(const Descriptor* message) {
  // First pass: compute the TypedProtoDecoder template arguments, ignoring
  // fields whose id exceeds kMaxDecoderFieldId.
  int max_field_id = 0;
  bool has_nonpacked_repeated_fields = false;
  for (int i = 0; i < message->field_count(); ++i) {
    const FieldDescriptor* field = message->field(i);
    if (field->number() > kMaxDecoderFieldId)
      continue;
    max_field_id = std::max(max_field_id, field->number());
    if (field->is_repeated() && !field->is_packed())
      has_nonpacked_repeated_fields = true;
  }

  std::string class_name = GetCppClassName(message) + "_Decoder";
  stub_h_->Print(
      "class $name$ : public "
      "::protozero::TypedProtoDecoder</*MAX_FIELD_ID=*/$max$, "
      "/*HAS_NONPACKED_REPEATED_FIELDS=*/$rep$> {\n",
      "name", class_name, "max", std::to_string(max_field_id), "rep",
      has_nonpacked_repeated_fields ? "true" : "false");
  stub_h_->Print(" public:\n");
  stub_h_->Indent();
  // Constructors: from a (pointer, length) pair, a std::string and a
  // ::protozero::ConstBytes.
  stub_h_->Print(
      "$name$(const uint8_t* data, size_t len) "
      ": TypedProtoDecoder(data, len) {}\n",
      "name", class_name);
  stub_h_->Print(
      "explicit $name$(const std::string& raw) : "
      "TypedProtoDecoder(reinterpret_cast<const uint8_t*>(raw.data()), "
      "raw.size()) {}\n",
      "name", class_name);
  stub_h_->Print(
      "explicit $name$(const ::protozero::ConstBytes& raw) : "
      "TypedProtoDecoder(raw.data, raw.size) {}\n",
      "name", class_name);

  // Second pass: one group of accessors per field.
  for (int i = 0; i < message->field_count(); ++i) {
    const FieldDescriptor* field = message->field(i);
    // max_field_id is the largest id not exceeding kMaxDecoderFieldId, so
    // this skips exactly the fields that were ignored in the first pass.
    if (field->number() > max_field_id) {
      stub_h_->Print("// field $name$ omitted because its id is too high\n",
                     "name", field->name());
      continue;
    }
    // |getter| is the accessor called on the single field value; |cpp_type|
    // is the C++ type returned by the generated accessor/iterator.
    std::string getter;
    std::string cpp_type;
    switch (field->type()) {
      case FieldDescriptor::TYPE_BOOL:
        getter = "as_bool";
        cpp_type = "bool";
        break;
      case FieldDescriptor::TYPE_SFIXED32:
      case FieldDescriptor::TYPE_SINT32:
      case FieldDescriptor::TYPE_INT32:
        getter = "as_int32";
        cpp_type = "int32_t";
        break;
      case FieldDescriptor::TYPE_SFIXED64:
      case FieldDescriptor::TYPE_SINT64:
      case FieldDescriptor::TYPE_INT64:
        getter = "as_int64";
        cpp_type = "int64_t";
        break;
      case FieldDescriptor::TYPE_FIXED32:
      case FieldDescriptor::TYPE_UINT32:
        getter = "as_uint32";
        cpp_type = "uint32_t";
        break;
      case FieldDescriptor::TYPE_FIXED64:
      case FieldDescriptor::TYPE_UINT64:
        getter = "as_uint64";
        cpp_type = "uint64_t";
        break;
      case FieldDescriptor::TYPE_FLOAT:
        getter = "as_float";
        cpp_type = "float";
        break;
      case FieldDescriptor::TYPE_DOUBLE:
        getter = "as_double";
        cpp_type = "double";
        break;
      case FieldDescriptor::TYPE_ENUM:
        // Enums are exposed as their raw int32 value.
        getter = "as_int32";
        cpp_type = "int32_t";
        break;
      case FieldDescriptor::TYPE_STRING:
        getter = "as_string";
        cpp_type = "::protozero::ConstChars";
        break;
      case FieldDescriptor::TYPE_MESSAGE:
      case FieldDescriptor::TYPE_BYTES:
        // Nested messages are returned as undecoded bytes.
        getter = "as_bytes";
        cpp_type = "::protozero::ConstBytes";
        break;
      case FieldDescriptor::TYPE_GROUP:
        // Inside the switch, 'continue' moves on to the next field: groups
        // get no accessors at all, not even has_*().
        continue;
    }

    stub_h_->Print("bool has_$name$() const { return at<$id$>().valid(); }\n",
                   "name", field->lowercase_name(), "id",
                   std::to_string(field->number()));

    if (field->is_packed()) {
      // Packed fields are iterated by decoding the length-delimited payload
      // with the element wire type; parse errors are reported through the
      // out-parameter of the generated accessor.
      const char* protozero_wire_type =
          FieldTypeToProtozeroWireType(field->type());
      stub_h_->Print(
          "::protozero::PackedRepeatedFieldIterator<$wire_type$, $cpp_type$> "
          "$name$(bool* parse_error_ptr) const { return "
          "GetPackedRepeated<$wire_type$, $cpp_type$>($id$, "
          "parse_error_ptr); }\n",
          "wire_type", protozero_wire_type, "cpp_type", cpp_type, "name",
          field->lowercase_name(), "id", std::to_string(field->number()));
    } else if (field->is_repeated()) {
      stub_h_->Print(
          "::protozero::RepeatedFieldIterator<$cpp_type$> $name$() const { "
          "return "
          "GetRepeated<$cpp_type$>($id$); }\n",
          "name", field->lowercase_name(), "cpp_type", cpp_type, "id",
          std::to_string(field->number()));
    } else {
      stub_h_->Print(
          "$cpp_type$ $name$() const { return at<$id$>().$getter$(); }\n",
          "name", field->lowercase_name(), "id",
          std::to_string(field->number()), "cpp_type", cpp_type, "getter",
          getter);
    }
  }
  stub_h_->Outdent();
  stub_h_->Print("};\n\n");
}
748
GenerateConstantsForMessageFields(const Descriptor * message)749 void GenerateConstantsForMessageFields(const Descriptor* message) {
750 const bool has_fields = (message->field_count() > 0);
751
752 // Field number constants.
753 if (has_fields) {
754 stub_h_->Print("enum : int32_t {\n");
755 stub_h_->Indent();
756
757 for (int i = 0; i < message->field_count(); ++i) {
758 const FieldDescriptor* field = message->field(i);
759 stub_h_->Print("$name$ = $id$,\n", "name",
760 GetFieldNumberConstant(field), "id",
761 std::to_string(field->number()));
762 }
763 stub_h_->Outdent();
764 stub_h_->Print("};\n");
765 }
766 }
767
GenerateMessageDescriptor(const Descriptor * message)768 void GenerateMessageDescriptor(const Descriptor* message) {
769 GenerateDecoder(message);
770
771 stub_h_->Print(
772 "class $name$ : public ::protozero::Message {\n"
773 " public:\n",
774 "name", GetCppClassName(message));
775 stub_h_->Indent();
776
777 stub_h_->Print("using Decoder = $name$_Decoder;\n", "name",
778 GetCppClassName(message));
779
780 GenerateConstantsForMessageFields(message);
781
782 // Using statements for nested messages.
783 for (int i = 0; i < message->nested_type_count(); ++i) {
784 const Descriptor* nested_message = message->nested_type(i);
785 stub_h_->Print("using $local_name$ = $global_name$;\n", "local_name",
786 nested_message->name(), "global_name",
787 GetCppClassName(nested_message, true));
788 }
789
790 // Using statements for nested enums.
791 for (int i = 0; i < message->enum_type_count(); ++i) {
792 const EnumDescriptor* nested_enum = message->enum_type(i);
793 stub_h_->Print("using $local_name$ = $global_name$;\n", "local_name",
794 nested_enum->name(), "global_name",
795 GetCppClassName(nested_enum, true));
796 }
797
798 // Values of nested enums.
799 for (int i = 0; i < message->enum_type_count(); ++i) {
800 const EnumDescriptor* nested_enum = message->enum_type(i);
801 std::string value_name_prefix = GetCppClassName(nested_enum) + "_";
802
803 for (int j = 0; j < nested_enum->value_count(); ++j) {
804 const EnumValueDescriptor* value = nested_enum->value(j);
805 stub_h_->Print("static const $class$ $name$ = $full_name$;\n", "class",
806 nested_enum->name(), "name", value->name(), "full_name",
807 value_name_prefix + value->name());
808 }
809 }
810
811 // Field descriptors.
812 for (int i = 0; i < message->field_count(); ++i) {
813 GenerateFieldDescriptor(GetCppClassName(message), message->field(i));
814 }
815
816 stub_h_->Outdent();
817 stub_h_->Print("};\n\n");
818 }
819
GetFieldMetadataTypeName(const FieldDescriptor * field)820 std::string GetFieldMetadataTypeName(const FieldDescriptor* field) {
821 std::string name = field->camelcase_name();
822 if (isalpha(name[0]))
823 name[0] = static_cast<char>(toupper(name[0]));
824 return "FieldMetadata_" + name;
825 }
826
GetFieldMetadataVariableName(const FieldDescriptor * field)827 std::string GetFieldMetadataVariableName(const FieldDescriptor* field) {
828 std::string name = field->camelcase_name();
829 if (isalpha(name[0]))
830 name[0] = static_cast<char>(toupper(name[0]));
831 return "k" + name;
832 }
833
// Emits, for |field|, a ::protozero::proto_utils::FieldMetadata type alias.
// The alias encodes the field id, repetition type, schema type, C++ type
// and owning message type |message_cpp_type|. A static constexpr function
// named by GetFieldMetadataVariableName() is also emitted, returning a
// value of that alias type.
void GenerateFieldMetadata(const std::string& message_cpp_type,
                           const FieldDescriptor* field) {
  // Everything inside the raw string below, including its comments, is
  // copied into the generated header; it is not a comment of this plugin.
  const char* code_stub = R"(
using $field_metadata_type$ =
::protozero::proto_utils::FieldMetadata<
$field_id$,
::protozero::proto_utils::RepetitionType::$repetition_type$,
::protozero::proto_utils::ProtoSchemaType::$proto_field_type$,
$cpp_type$,
$message_cpp_type$>;

// Ceci n'est pas une pipe.
// This is actually a variable of FieldMetadataHelper<FieldMetadata<...>>
// type (and users are expected to use it as such, hence kCamelCase name).
// It is declared as a function to keep protozero bindings header-only as
// inline constexpr variables are not available until C++17 (while inline
// functions are).
// TODO(altimin): Use inline variable instead after adopting C++17.
static constexpr $field_metadata_type$ $field_metadata_var$() { return {}; }
)";

  stub_h_->Print(code_stub, "field_id", std::to_string(field->number()),
                 "repetition_type", FieldToRepetitionType(field),
                 "proto_field_type", FieldToProtoSchemaType(field),
                 "cpp_type", FieldToCppTypeName(field), "message_cpp_type",
                 message_cpp_type, "field_metadata_type",
                 GetFieldMetadataTypeName(field), "field_metadata_var",
                 GetFieldMetadataVariableName(field));
}
863
GenerateFieldDescriptor(const std::string & message_cpp_type,const FieldDescriptor * field)864 void GenerateFieldDescriptor(const std::string& message_cpp_type,
865 const FieldDescriptor* field) {
866 GenerateFieldMetadata(message_cpp_type, field);
867 if (field->is_packed()) {
868 GeneratePackedRepeatedFieldDescriptor(field);
869 } else if (field->type() != FieldDescriptor::TYPE_MESSAGE) {
870 GenerateSimpleFieldDescriptor(field);
871 } else {
872 GenerateNestedMessageFieldDescriptor(field);
873 }
874 }
875
876 // Generate extension class for a group of FieldDescriptor instances
877 // representing one "extend" block in proto definition. For example:
878 //
879 // message SpecificExtension {
880 // extend GeneralThing {
881 // optional Fizz fizz = 101;
882 // optional Buzz buzz = 102;
883 // }
884 // }
885 //
886 // This is going to be passed as a vector of two elements, "fizz" and
887 // "buzz". Wrapping message is used to provide a name for generated
888 // extension class.
889 //
890 // In the example above, generated code is going to look like:
891 //
892 // class SpecificExtension : public GeneralThing {
893 // Fizz* set_fizz();
894 // Buzz* set_buzz();
895 // }
GenerateExtension(const std::string & extension_name,const std::vector<const FieldDescriptor * > & descriptors)896 void GenerateExtension(
897 const std::string& extension_name,
898 const std::vector<const FieldDescriptor*>& descriptors) {
899 // Use an arbitrary descriptor in order to get generic information not
900 // specific to any of them.
901 const FieldDescriptor* descriptor = descriptors[0];
902 const Descriptor* base_message = descriptor->containing_type();
903
904 // TODO(ddrone): ensure that this code works when containing_type located in
905 // other file or namespace.
906 stub_h_->Print("class $name$ : public $extendee$ {\n", "name",
907 extension_name, "extendee",
908 GetCppClassName(base_message, /*full=*/true));
909 stub_h_->Print(" public:\n");
910 stub_h_->Indent();
911 for (const FieldDescriptor* field : descriptors) {
912 if (field->containing_type() != base_message) {
913 Abort("one wrapper should extend only one message");
914 return;
915 }
916 GenerateFieldDescriptor(extension_name, field);
917 }
918 stub_h_->Outdent();
919 stub_h_->Print("};\n");
920 }
921
GenerateEpilogue()922 void GenerateEpilogue() {
923 for (unsigned i = 0; i < namespaces_.size(); ++i) {
924 stub_h_->Print("} // Namespace.\n");
925 }
926 stub_h_->Print("#endif // Include guard.\n");
927 }
928
// The .proto file this job generates code for, and the printer that receives
// the generated header text. NOTE(review): presumably initialized from the
// GeneratorJob(file, &stub_h_printer) constructor call in Generate(); confirm.
const FileDescriptor* const source_;
Printer* const stub_h_;
// Error text for a failed generation. NOTE(review): presumably set by Abort()
// and returned by GetFirstError(), which Generate() copies into |*error|.
std::string error_;

// NOTE(review): the members below appear to be filled while parsing plugin
// options and collecting descriptors (code not visible in this chunk).
std::string package_;
std::string wrapper_namespace_;
// GenerateEpilogue() prints one namespace-closing brace per entry.
std::vector<std::string> namespaces_;
std::string full_namespace_prefix_;
std::vector<const Descriptor*> messages_;
std::vector<const EnumDescriptor*> enums_;
// Fields grouped per "extend" block. Keys presumably are the extension class
// names later passed to GenerateExtension(); verify against the caller.
std::map<std::string, std::vector<const FieldDescriptor*>> extensions_;

// The custom *Comp comparators are to ensure determinism of the generator.
std::set<const FileDescriptor*, FileDescriptorComp> public_imports_;
std::set<const FileDescriptor*, FileDescriptorComp> private_imports_;
std::set<const Descriptor*, DescriptorComp> referenced_messages_;
std::set<const EnumDescriptor*, EnumDescriptorComp> referenced_enums_;
946 };
947
948 class ProtoZeroGenerator : public ::google::protobuf::compiler::CodeGenerator {
949 public:
950 explicit ProtoZeroGenerator();
951 ~ProtoZeroGenerator() override;
952
953 // CodeGenerator implementation
954 bool Generate(const google::protobuf::FileDescriptor* file,
955 const std::string& options,
956 GeneratorContext* context,
957 std::string* error) const override;
958 };
959
ProtoZeroGenerator()960 ProtoZeroGenerator::ProtoZeroGenerator() {}
961
~ProtoZeroGenerator()962 ProtoZeroGenerator::~ProtoZeroGenerator() {}
963
Generate(const FileDescriptor * file,const std::string & options,GeneratorContext * context,std::string * error) const964 bool ProtoZeroGenerator::Generate(const FileDescriptor* file,
965 const std::string& options,
966 GeneratorContext* context,
967 std::string* error) const {
968 const std::unique_ptr<ZeroCopyOutputStream> stub_h_file_stream(
969 context->Open(ProtoStubName(file) + ".h"));
970 const std::unique_ptr<ZeroCopyOutputStream> stub_cc_file_stream(
971 context->Open(ProtoStubName(file) + ".cc"));
972
973 // Variables are delimited by $.
974 Printer stub_h_printer(stub_h_file_stream.get(), '$');
975 GeneratorJob job(file, &stub_h_printer);
976
977 Printer stub_cc_printer(stub_cc_file_stream.get(), '$');
978 stub_cc_printer.Print("// Intentionally empty (crbug.com/998165)\n");
979
980 // Parse additional options.
981 for (const std::string& option : SplitString(options, ",")) {
982 std::vector<std::string> option_pair = SplitString(option, "=");
983 job.SetOption(option_pair[0], option_pair[1]);
984 }
985
986 if (!job.GenerateStubs()) {
987 *error = job.GetFirstError();
988 return false;
989 }
990 return true;
991 }
992
993 } // namespace
994 } // namespace protozero
995
main(int argc,char * argv[])996 int main(int argc, char* argv[]) {
997 ::protozero::ProtoZeroGenerator generator;
998 return google::protobuf::compiler::PluginMain(argc, argv, &generator);
999 }
1000