1 /*
2  * Copyright 2020 Google Inc. All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
/*
 * NOTE: The following implementation is a translation of the Swift-grpc
 * generator, since flatbuffers doesn't allow plugins for now. If an issue
 * arises, please open an issue in the flatbuffers repository. This file should
 * always be maintained according to the Swift-grpc repository.
 */
23 #include <map>
24 #include <sstream>
25 
26 #include "flatbuffers/util.h"
27 #include "src/compiler/schema_interface.h"
28 #include "src/compiler/swift_generator.h"
29 
30 namespace grpc_swift_generator {
31 
// Wraps a type name in the Swift generic `Message<...>`, e.g. "Foo" becomes
// "Message<Foo>".
grpc::string GenerateMessage(const grpc::string &name) {
  grpc::string wrapped("Message<");
  wrapped += name;
  wrapped += ">";
  return wrapped;
}
35 
36 // MARK: - Client
37 
GenerateClientFuncName(const grpc_generator::Method * method)38 grpc::string GenerateClientFuncName(const grpc_generator::Method *method) {
39   if (method->NoStreaming()) {
40     return "$GenAccess$ func $MethodName$(_ request: $Input$"
41            ", callOptions: CallOptions?$isNil$) -> UnaryCall<$Input$,$Output$>";
42   }
43 
44   if (method->ClientStreaming()) {
45     return "$GenAccess$ func $MethodName$"
46            "(callOptions: CallOptions?$isNil$) -> "
47            "ClientStreamingCall<$Input$,$Output$>";
48   }
49 
50   if (method->ServerStreaming()) {
51     return "$GenAccess$ func $MethodName$(_ request: $Input$"
52            ", callOptions: CallOptions?$isNil$, handler: @escaping ($Output$"
53            ") -> Void) -> ServerStreamingCall<$Input$, $Output$>";
54   }
55   return "$GenAccess$ func $MethodName$"
56          "(callOptions: CallOptions?$isNil$, handler: @escaping ($Output$"
57          ") -> Void) -> BidirectionalStreamingCall<$Input$, $Output$>";
58 }
59 
GenerateClientFuncBody(const grpc_generator::Method * method)60 grpc::string GenerateClientFuncBody(const grpc_generator::Method *method) {
61   if (method->NoStreaming()) {
62     return "return self.makeUnaryCall(path: "
63            "\"/$PATH$$ServiceName$/$MethodName$\", request: request, "
64            "callOptions: callOptions ?? self.defaultCallOptions)";
65   }
66 
67   if (method->ClientStreaming()) {
68     return "return self.makeClientStreamingCall(path: "
69            "\"/$PATH$$ServiceName$/$MethodName$\", callOptions: callOptions ?? "
70            "self.defaultCallOptions)";
71   }
72 
73   if (method->ServerStreaming()) {
74     return "return self.makeServerStreamingCall(path: "
75            "\"/$PATH$$ServiceName$/$MethodName$\", request: request, "
76            "callOptions: callOptions ?? self.defaultCallOptions, handler: "
77            "handler)";
78   }
79   return "return self.makeBidirectionalStreamingCall(path: "
80          "\"/$PATH$$ServiceName$/$MethodName$\", callOptions: callOptions ?? "
81          "self.defaultCallOptions, handler: handler)";
82 }
83 
GenerateClientProtocol(const grpc_generator::Service * service,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)84 void GenerateClientProtocol(const grpc_generator::Service *service,
85                             grpc_generator::Printer *printer,
86                             std::map<grpc::string, grpc::string> *dictonary) {
87   auto vars = *dictonary;
88   printer->Print(vars, "$ACCESS$ protocol $ServiceName$Service {\n");
89   vars["GenAccess"] = "";
90   for (auto it = 0; it < service->method_count(); it++) {
91     auto method = service->method(it);
92     vars["Input"] = GenerateMessage(method->get_input_type_name());
93     vars["Output"] = GenerateMessage(method->get_output_type_name());
94     vars["MethodName"] = method->name();
95     vars["isNil"] = "";
96     printer->Print("\t");
97     auto func = GenerateClientFuncName(method.get());
98     printer->Print(vars, func.c_str());
99     printer->Print("\n");
100   }
101   printer->Print("}\n\n");
102 }
103 
GenerateClientClass(const grpc_generator::Service * service,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)104 void GenerateClientClass(const grpc_generator::Service *service,
105                          grpc_generator::Printer *printer,
106                          std::map<grpc::string, grpc::string> *dictonary) {
107   auto vars = *dictonary;
108   printer->Print(vars,
109                  "$ACCESS$ final class $ServiceName$ServiceClient: GRPCClient, "
110                  "$ServiceName$Service {\n");
111   printer->Print(vars, "\t$ACCESS$ let connection: ClientConnection\n");
112   printer->Print(vars, "\t$ACCESS$ var defaultCallOptions: CallOptions\n");
113   printer->Print("\n");
114   printer->Print(vars,
115                  "\t$ACCESS$ init(connection: ClientConnection, "
116                  "defaultCallOptions: CallOptions = CallOptions()) {\n");
117   printer->Print("\t\tself.connection = connection\n");
118   printer->Print("\t\tself.defaultCallOptions = defaultCallOptions\n");
119   printer->Print("\t}");
120   printer->Print("\n");
121   vars["GenAccess"] = "public";
122   for (auto it = 0; it < service->method_count(); it++) {
123     auto method = service->method(it);
124     vars["Input"] = GenerateMessage(method->get_input_type_name());
125     vars["Output"] = GenerateMessage(method->get_output_type_name());
126     vars["MethodName"] = method->name();
127     vars["isNil"] = " = nil";
128     printer->Print("\n\t");
129     auto func = GenerateClientFuncName(method.get());
130     printer->Print(vars, func.c_str());
131     printer->Print(" {\n");
132     auto body = GenerateClientFuncBody(method.get());
133     printer->Print("\t\t");
134     printer->Print(vars, body.c_str());
135     printer->Print("\n\t}\n");
136   }
137   printer->Print("}\n");
138 }
139 
140 // MARK: - Server
141 
GenerateServerFuncName(const grpc_generator::Method * method)142 grpc::string GenerateServerFuncName(const grpc_generator::Method *method) {
143   if (method->NoStreaming()) {
144     return "func $MethodName$(_ request: $Input$"
145            ", context: StatusOnlyCallContext) -> EventLoopFuture<$Output$>";
146   }
147 
148   if (method->ClientStreaming()) {
149     return "func $MethodName$(context: UnaryResponseCallContext<$Output$>) -> "
150            "EventLoopFuture<(StreamEvent<$Input$"
151            ">) -> Void>";
152   }
153 
154   if (method->ServerStreaming()) {
155     return "func $MethodName$(request: $Input$"
156            ", context: StreamingResponseCallContext<$Output$>) -> "
157            "EventLoopFuture<GRPCStatus>";
158   }
159   return "func $MethodName$(context: StreamingResponseCallContext<$Output$>) "
160          "-> EventLoopFuture<(StreamEvent<$Input$>) -> Void>";
161 }
162 
GenerateServerExtensionBody(const grpc_generator::Method * method)163 grpc::string GenerateServerExtensionBody(const grpc_generator::Method *method) {
164   grpc::string start = "\t\tcase \"$MethodName$\":\n\t\t";
165   if (method->NoStreaming()) {
166     return start +
167            "return UnaryCallHandler(callHandlerContext: callHandlerContext) { "
168            "context in"
169            "\n\t\t\t"
170            "return { request in"
171            "\n\t\t\t\t"
172            "self.$MethodName$(request, context: context)"
173            "\n\t\t\t}"
174            "\n\t\t}";
175   }
176   if (method->ClientStreaming()) {
177     return start +
178            "return ClientStreamingCallHandler(callHandlerContext: "
179            "callHandlerContext) { context in"
180            "\n\t\t\t"
181            "return { request in"
182            "\n\t\t\t\t"
183            "self.$MethodName$(request: request, context: context)"
184            "\n\t\t\t}"
185            "\n\t\t}";
186   }
187   if (method->ServerStreaming()) {
188     return start +
189            "return ServerStreamingCallHandler(callHandlerContext: "
190            "callHandlerContext) { context in"
191            "\n\t\t\t"
192            "return { request in"
193            "\n\t\t\t\t"
194            "self.$MethodName$(request: request, context: context)"
195            "\n\t\t\t}"
196            "\n\t\t}";
197   }
198   if (method->BidiStreaming()) {
199     return start +
200            "return BidirectionalStreamingCallHandler(callHandlerContext: "
201            "callHandlerContext) { context in"
202            "\n\t\t\t"
203            "return { request in"
204            "\n\t\t\t\t"
205            "self.$MethodName$(request: request, context: context)"
206            "\n\t\t\t}"
207            "\n\t\t}";
208   }
209   return "";
210 }
211 
GenerateServerProtocol(const grpc_generator::Service * service,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)212 void GenerateServerProtocol(const grpc_generator::Service *service,
213                             grpc_generator::Printer *printer,
214                             std::map<grpc::string, grpc::string> *dictonary) {
215   auto vars = *dictonary;
216   printer->Print(
217       vars, "$ACCESS$ protocol $ServiceName$Provider: CallHandlerProvider {\n");
218   for (auto it = 0; it < service->method_count(); it++) {
219     auto method = service->method(it);
220     vars["Input"] = GenerateMessage(method->get_input_type_name());
221     vars["Output"] = GenerateMessage(method->get_output_type_name());
222     vars["MethodName"] = method->name();
223     printer->Print("\t");
224     auto func = GenerateServerFuncName(method.get());
225     printer->Print(vars, func.c_str());
226     printer->Print("\n");
227   }
228   printer->Print("}\n\n");
229 
230   printer->Print(vars, "$ACCESS$ extension $ServiceName$Provider {\n");
231   printer->Print(vars,
232                  "\tvar serviceName: String { return "
233                  "\"$PATH$$ServiceName$\" }\n");
234   printer->Print(
235       "\tfunc handleMethod(_ methodName: String, callHandlerContext: "
236       "CallHandlerContext) -> GRPCCallHandler? {\n");
237   printer->Print("\t\tswitch methodName {\n");
238   for (auto it = 0; it < service->method_count(); it++) {
239     auto method = service->method(it);
240     vars["Input"] = GenerateMessage(method->get_input_type_name());
241     vars["Output"] = GenerateMessage(method->get_output_type_name());
242     vars["MethodName"] = method->name();
243     auto body = GenerateServerExtensionBody(method.get());
244     printer->Print(vars, body.c_str());
245     printer->Print("\n");
246   }
247   printer->Print("\t\tdefault: return nil;\n");
248   printer->Print("\t\t}\n");
249   printer->Print("\t}\n\n");
250   printer->Print("}\n\n");
251 }
252 
Generate(grpc_generator::File * file,const grpc_generator::Service * service)253 grpc::string Generate(grpc_generator::File *file,
254                       const grpc_generator::Service *service) {
255   grpc::string output;
256   std::map<grpc::string, grpc::string> vars;
257   vars["PATH"] = file->package();
258   if (!file->package().empty()) { vars["PATH"].append("."); }
259   vars["ServiceName"] = service->name();
260   vars["ACCESS"] = "public";
261   auto printer = file->CreatePrinter(&output);
262   printer->Print(vars,
263                  "/// Usage: instantiate $ServiceName$ServiceClient, then call "
264                  "methods of this protocol to make API calls.\n");
265   GenerateClientProtocol(service, &*printer, &vars);
266   GenerateClientClass(service, &*printer, &vars);
267   printer->Print("\n");
268   GenerateServerProtocol(service, &*printer, &vars);
269   return output;
270 }
271 
GenerateHeader()272 grpc::string GenerateHeader() {
273   grpc::string code;
274   code +=
275       "/// The following code is generated by the Flatbuffers library which "
276       "might not be in sync with grpc-swift\n";
277   code +=
278       "/// in case of an issue please open github issue, though it would be "
279       "maintained\n";
280   code += "import Foundation\n";
281   code += "import GRPC\n";
282   code += "import NIO\n";
283   code += "import NIOHTTP1\n";
284   code += "import FlatBuffers\n";
285   code += "\n";
286   code +=
287       "public protocol GRPCFlatBufPayload: GRPCPayload, FlatBufferGRPCMessage "
288       "{}\n";
289 
290   code += "public extension GRPCFlatBufPayload {\n";
291   code += "    init(serializedByteBuffer: inout NIO.ByteBuffer) throws {\n";
292   code +=
293       "        self.init(byteBuffer: FlatBuffers.ByteBuffer(contiguousBytes: "
294       "serializedByteBuffer.readableBytesView, count: "
295       "serializedByteBuffer.readableBytes))\n";
296   code += "    }\n";
297 
298   code += "    func serialize(into buffer: inout NIO.ByteBuffer) throws {\n";
299   code +=
300       "        let buf = UnsafeRawBufferPointer(start: self.rawPointer, count: "
301       "Int(self.size))\n";
302   code += "        buffer.writeBytes(buf)\n";
303   code += "    }\n";
304   code += "}\n";
305   code += "extension Message: GRPCFlatBufPayload {}\n";
306   return code;
307 }
308 }  // namespace grpc_swift_generator
309