1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <google/protobuf/descriptor.h>
#include <google/protobuf/compiler/plugin.h>
#include <google/protobuf/compiler/code_generator.h>
#include <google/protobuf/io/printer.h>
#include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/stubs/strutil.h>

#include "nugget/protobuf/options.pb.h"
30 
31 using ::google::protobuf::FileDescriptor;
32 using ::google::protobuf::JoinStrings;
33 using ::google::protobuf::MethodDescriptor;
34 using ::google::protobuf::ServiceDescriptor;
35 using ::google::protobuf::Split;
36 using ::google::protobuf::SplitStringUsing;
37 using ::google::protobuf::StripSuffixString;
38 using ::google::protobuf::compiler::CodeGenerator;
39 using ::google::protobuf::compiler::OutputDirectory;
40 using ::google::protobuf::io::Printer;
41 using ::google::protobuf::io::ZeroCopyOutputStream;
42 
43 using ::nugget::protobuf::app_id;
44 using ::nugget::protobuf::request_buffer_size;
45 using ::nugget::protobuf::response_buffer_size;
46 
47 namespace {
48 
validateServiceOptions(const ServiceDescriptor & service)49 std::string validateServiceOptions(const ServiceDescriptor& service) {
50     if (!service.options().HasExtension(app_id)) {
51         return "nugget.protobuf.app_id is not defined for service " + service.name();
52     }
53     if (!service.options().HasExtension(request_buffer_size)) {
54         return "nugget.protobuf.request_buffer_size is not defined for service " + service.name();
55     }
56     if (!service.options().HasExtension(response_buffer_size)) {
57         return "nugget.protobuf.response_buffer_size is not defined for service " + service.name();
58     }
59     return "";
60 }
61 
// Returns the '.'-separated components of `descriptor.full_name()` with the
// final component (the descriptor's own name) removed. For a top-level
// descriptor this is its proto package split into parts; for a nested message
// it also contains the names of the enclosing messages.
//
// Empty components (e.g. from consecutive dots) are skipped. If the full name
// has no components at all, an empty vector is returned rather than calling
// pop_back() on an empty vector (which is undefined behavior).
template <typename Descriptor>
std::vector<std::string> Packages(const Descriptor& descriptor) {
    const std::string& full_name = descriptor.full_name();

    std::vector<std::string> namespaces;
    std::string::size_type start = 0;
    while (start <= full_name.size()) {
        std::string::size_type end = full_name.find('.', start);
        if (end == std::string::npos) {
            end = full_name.size();
        }
        if (end > start) {
            namespaces.push_back(full_name.substr(start, end - start));
        }
        start = end + 1;
    }

    // Drop the descriptor's own name, keeping only the enclosing scopes.
    if (!namespaces.empty()) {
        namespaces.pop_back();
    }
    return namespaces;
}
69 
// Returns the C++ name of `descriptor` qualified from the global namespace,
// e.g. "::pkg::sub::Name", or "::Name" when Packages() yields no scopes.
template <typename Descriptor>
std::string FullyQualifiedIdentifier(const Descriptor& descriptor) {
    std::string qualified;
    for (const std::string& scope : Packages(descriptor)) {
        qualified += "::";
        qualified += scope;
    }
    // The descriptor's own name is always preceded by "::", even at the root.
    qualified += "::";
    qualified += descriptor.name();
    return qualified;
}
81 
82 template <typename Descriptor>
FullyQualifiedHeader(const Descriptor & descriptor)83 std::string FullyQualifiedHeader(const Descriptor& descriptor) {
84     const auto packages = Packages(descriptor);
85     const auto file = Split(descriptor.file()->name(), "/").back();
86     const auto header = StripSuffixString(file, ".proto") + ".pb.h";
87     if (packages.empty()) {
88         return header;
89     } else {
90         std::string package_path;
91         JoinStrings(packages, "/", &package_path);
92         return package_path + "/" + header;
93     }
94 }
95 
96 template <typename Descriptor>
OpenNamespaces(Printer & printer,const Descriptor & descriptor)97 void OpenNamespaces(Printer& printer, const Descriptor& descriptor) {
98     const auto namespaces = Packages(descriptor);
99     for (const auto& ns : namespaces) {
100         std::map<std::string, std::string> namespaceVars;
101         namespaceVars["namespace"] = ns;
102         printer.Print(namespaceVars, R"(
103 namespace $namespace$ {)");
104     }
105 }
106 
107 template <typename Descriptor>
CloseNamespaces(Printer & printer,const Descriptor & descriptor)108 void CloseNamespaces(Printer& printer, const Descriptor& descriptor) {
109     const auto namespaces = Packages(descriptor);
110     for (auto it = namespaces.crbegin(); it != namespaces.crend(); ++it) {
111         std::map<std::string, std::string> namespaceVars;
112         namespaceVars["namespace"] = *it;
113         printer.Print(namespaceVars, R"(
114 } // namespace $namespace$)");
115     }
116 }
117 
ForEachMethod(const ServiceDescriptor & service,std::function<void (std::map<std::string,std::string>)> handler)118 void ForEachMethod(const ServiceDescriptor& service,
119                    std::function<void(std::map<std::string, std::string>)> handler) {
120     for (int i = 0; i < service.method_count(); ++i) {
121         const MethodDescriptor& method = *service.method(i);
122         std::map<std::string, std::string> vars;
123         vars["method_id"] = std::to_string(i);
124         vars["method_name"] = method.name();
125         vars["method_input_type"] = FullyQualifiedIdentifier(*method.input_type());
126         vars["method_output_type"] = FullyQualifiedIdentifier(*method.output_type());
127         handler(vars);
128     }
129 }
130 
GenerateMockClient(Printer & printer,const ServiceDescriptor & service)131 void GenerateMockClient(Printer& printer, const ServiceDescriptor& service) {
132     std::map<std::string, std::string> vars;
133     vars["include_guard"] = "PROTOC_GENERATED_MOCK_" + service.name() + "_CLIENT_H";
134     vars["service_header"] = service.name() + ".client.h";
135     vars["mock_class"] = "Mock" + service.name();
136     vars["class"] = service.name();
137 
138     printer.Print(vars, R"(
139 #ifndef $include_guard$
140 #define $include_guard$
141 
142 #include <gmock/gmock.h>
143 
144 #include <$service_header$>)");
145 
146     OpenNamespaces(printer, service);
147 
148     printer.Print(vars, R"(
149 struct $mock_class$ : public I$class$ {)");
150 
151     ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
152         printer.Print(methodVars, R"(
153     MOCK_METHOD2($method_name$, uint32_t(const $method_input_type$&, $method_output_type$*));)");
154     });
155 
156     printer.Print(vars, R"(
157 };)");
158 
159     CloseNamespaces(printer, service);
160 
161     printer.Print(vars, R"(
162 #endif)");
163 }
164 
GenerateClientHeader(Printer & printer,const ServiceDescriptor & service)165 void GenerateClientHeader(Printer& printer, const ServiceDescriptor& service) {
166     std::map<std::string, std::string> vars;
167     vars["include_guard"] = "PROTOC_GENERATED_" + service.name() + "_CLIENT_H";
168     vars["protobuf_header"] = FullyQualifiedHeader(service);
169     vars["class"] = service.name();
170     vars["iface_class"] = "I" + service.name();
171     vars["app_id"] = "APP_ID_" + service.options().GetExtension(app_id);
172 
173     printer.Print(vars, R"(
174 #ifndef $include_guard$
175 #define $include_guard$
176 
177 #include <application.h>
178 #include <nos/AppClient.h>
179 #include <nos/NuggetClientInterface.h>
180 
181 #include "$protobuf_header$")");
182 
183     OpenNamespaces(printer, service);
184 
185     // Pure virtual interface to make testing easier
186     printer.Print(vars, R"(
187 class $iface_class$ {
188 public:
189     virtual ~$iface_class$() = default;)");
190 
191     ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
192         printer.Print(methodVars, R"(
193     virtual uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) = 0;)");
194     });
195 
196     printer.Print(vars, R"(
197 };)");
198 
199     // Implementation of the interface for Nugget
200     printer.Print(vars, R"(
201 class $class$ : public $iface_class$ {
202     ::nos::AppClient _app;
203 public:
204     $class$(::nos::NuggetClientInterface& client) : _app{client, $app_id$} {}
205     ~$class$() override = default;)");
206 
207     ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
208         printer.Print(methodVars, R"(
209     uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) override;)");
210     });
211 
212     printer.Print(vars, R"(
213 };)");
214 
215     CloseNamespaces(printer, service);
216 
217     printer.Print(vars, R"(
218 #endif)");
219 }
220 
GenerateClientSource(Printer & printer,const ServiceDescriptor & service)221 void GenerateClientSource(Printer& printer, const ServiceDescriptor& service) {
222     std::map<std::string, std::string> vars;
223     vars["generated_header"] = service.name() + ".client.h";
224     vars["class"] = service.name();
225 
226     const uint32_t max_request_size = service.options().GetExtension(request_buffer_size);
227     const uint32_t max_response_size = service.options().GetExtension(response_buffer_size);
228     vars["max_request_size"] = std::to_string(max_request_size);
229     vars["max_response_size"] = std::to_string(max_response_size);
230 
231     printer.Print(vars, R"(
232 #include <$generated_header$>
233 
234 #include <application.h>)");
235 
236     OpenNamespaces(printer, service);
237 
238     // Methods
239     ForEachMethod(service, [&](std::map<std::string, std::string>  methodVars) {
240         methodVars.insert(vars.begin(), vars.end());
241         printer.Print(methodVars, R"(
242 uint32_t $class$::$method_name$(const $method_input_type$& request, $method_output_type$* response) {
243     const size_t request_size = request.ByteSize();
244     if (request_size > $max_request_size$) {
245         return APP_ERROR_TOO_MUCH;
246     }
247     std::vector<uint8_t> buffer(request_size);
248     if (!request.SerializeToArray(buffer.data(), buffer.size())) {
249         return APP_ERROR_RPC;
250     }
251     std::vector<uint8_t> responseBuffer;
252     if (response != nullptr) {
253       responseBuffer.resize($max_response_size$);
254     }
255     const uint32_t appStatus = _app.Call($method_id$, buffer,
256                                          (response != nullptr) ? &responseBuffer : nullptr);
257     if (appStatus == APP_SUCCESS && response != nullptr) {
258         if (!response->ParseFromArray(responseBuffer.data(), responseBuffer.size())) {
259             return APP_ERROR_RPC;
260         }
261     }
262     return appStatus;
263 })");
264     });
265 
266     CloseNamespaces(printer, service);
267 }
268 
269 // Generator for C++ Nugget service client
270 class CppNuggetServiceClientGenerator : public CodeGenerator {
271 public:
272     CppNuggetServiceClientGenerator() = default;
273     ~CppNuggetServiceClientGenerator() override = default;
274 
Generate(const FileDescriptor * file,const std::string & parameter,OutputDirectory * output_directory,std::string * error) const275     bool Generate(const FileDescriptor* file,
276                   const std::string& parameter,
277                   OutputDirectory* output_directory,
278                   std::string* error) const override {
279         for (int i = 0; i < file->service_count(); ++i) {
280             const auto& service = *file->service(i);
281 
282             *error = validateServiceOptions(service);
283             if (!error->empty()) {
284                 return false;
285             }
286 
287             if (parameter == "mock") {
288                 std::unique_ptr<ZeroCopyOutputStream> output{
289                         output_directory->Open("Mock" + service.name() + ".client.h")};
290                 Printer printer(output.get(), '$');
291                 GenerateMockClient(printer, service);
292             } else if (parameter == "header") {
293                 std::unique_ptr<ZeroCopyOutputStream> output{
294                         output_directory->Open(service.name() + ".client.h")};
295                 Printer printer(output.get(), '$');
296                 GenerateClientHeader(printer, service);
297             } else if (parameter == "source") {
298                 std::unique_ptr<ZeroCopyOutputStream> output{
299                         output_directory->Open(service.name() + ".client.cpp")};
300                 Printer printer(output.get(), '$');
301                 GenerateClientSource(printer, service);
302             } else {
303                 *error = "Illegal parameter: must be mock|header|source";
304                 return false;
305             }
306         }
307 
308         return true;
309     }
310 
311 private:
312     GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CppNuggetServiceClientGenerator);
313 };
314 
315 } // namespace
316 
main(int argc,char * argv[])317 int main(int argc, char* argv[]) {
318     GOOGLE_PROTOBUF_VERIFY_VERSION;
319     CppNuggetServiceClientGenerator generator;
320     return google::protobuf::compiler::PluginMain(argc, argv, &generator);
321 }
322