1 /*
2  * Copyright 2016 Google Inc. All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/compiler/java_generator.h"
18 
#include <algorithm>
#include <cctype>
#include <iostream>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>
25 
26 // just to get flatbuffer_version_string()
27 #include <flatbuffers/flatbuffers.h>
28 #include <flatbuffers/util.h>
29 #define to_string flatbuffers::NumToString
30 
// Stringify helpers, used solely to turn GRPC_VERSION into a string literal
32 #ifndef STR
33 #define STR(s) #s
34 #endif
35 
36 #ifndef XSTR
37 #define XSTR(s) STR(s)
38 #endif
39 
40 
41 typedef grpc_generator::Printer Printer;
42 typedef std::map<grpc::string, grpc::string> VARS;
43 typedef grpc_generator::Service ServiceDescriptor;
44 typedef grpc_generator::CommentHolder
45     DescriptorType;  // base class of all 'descriptors'
46 typedef grpc_generator::Method MethodDescriptor;
47 
48 namespace grpc_java_generator {
using string = std::string;  // shorthand used throughout this generator
50 // Generates imports for the service
GenerateImports(grpc_generator::File * file,grpc_generator::Printer * printer,VARS & vars)51 void GenerateImports(grpc_generator::File* file,
52                      grpc_generator::Printer* printer, VARS& vars) {
53   vars["filename"] = file->filename();
54   printer->Print(
55       vars,
56       "//Generated by flatc compiler (version $flatc_version$)\n");
57   printer->Print("//If you make any local changes, they will be lost\n");
58   printer->Print(vars, "//source: $filename$.fbs\n\n");
59   printer->Print(vars, "package $Package$;\n\n");
60   vars["Package"] = vars["Package"] + ".";
61   if (!file->additional_headers().empty()) {
62     printer->Print(file->additional_headers().c_str());
63     printer->Print("\n\n");
64   }
65 }
66 
// Adjusts a method-name-derived identifier to follow the JavaBean spec
// (mixedCase):
//   - the first character is lower-cased (it is kept even if it is '_');
//   - every later '_' is removed and the character after it is upper-cased.
// Example: "get_SayHello_method" -> "getSayHelloMethod".
// An empty input yields an empty result.
//
// Characters are passed to <cctype> as unsigned char: handing a negative
// `char` (e.g. a UTF-8 byte) to tolower/toupper is undefined behavior.
static string MixedLower(const string& word) {
  string w;
  if (word.empty()) {
    // Previously word[0] read the terminating '\0' and returned "\0".
    return w;
  }
  w.reserve(word.size());
  w += static_cast<string::value_type>(
      std::tolower(static_cast<unsigned char>(word[0])));
  bool after_underscore = false;
  for (size_t i = 1; i < word.length(); ++i) {
    const char c = word[i];
    if (c == '_') {
      after_underscore = true;
    } else {
      w += after_underscore ? static_cast<string::value_type>(std::toupper(
                                  static_cast<unsigned char>(c)))
                            : c;
      after_underscore = false;
    }
  }
  return w;
}
85 
// Converts an identifier to the ALL_UPPER_CASE format:
//   - an underscore is inserted where a lower-case letter is followed by an
//     upper-case letter;
//   - all letters are converted to upper case.
// Example: "SayHello" -> "SAY_HELLO"; runs of capitals are not split
// ("getHTTPResponse" -> "GET_HTTPRESPONSE").
//
// Characters are passed to <cctype> as unsigned char: a negative `char`
// (e.g. a UTF-8 byte) would be undefined behavior for toupper/islower/isupper.
static string ToAllUpperCase(const string& word) {
  string w;
  w.reserve(word.size() * 2);
  for (size_t i = 0; i < word.length(); ++i) {
    const unsigned char c = static_cast<unsigned char>(word[i]);
    w += static_cast<string::value_type>(std::toupper(c));
    if (i + 1 < word.length() && std::islower(c) &&
        std::isupper(static_cast<unsigned char>(word[i + 1]))) {
      w += '_';
    }
  }
  return w;
}
100 
LowerMethodName(const MethodDescriptor * method)101 static inline string LowerMethodName(const MethodDescriptor* method) {
102   return MixedLower(method->name());
103 }
104 
MethodPropertiesFieldName(const MethodDescriptor * method)105 static inline string MethodPropertiesFieldName(const MethodDescriptor* method) {
106   return "METHOD_" + ToAllUpperCase(method->name());
107 }
108 
MethodPropertiesGetterName(const MethodDescriptor * method)109 static inline string MethodPropertiesGetterName(
110     const MethodDescriptor* method) {
111   return MixedLower("get_" + method->name() + "_method");
112 }
113 
MethodIdFieldName(const MethodDescriptor * method)114 static inline string MethodIdFieldName(const MethodDescriptor* method) {
115   return "METHODID_" + ToAllUpperCase(method->name());
116 }
117 
JavaClassName(VARS & vars,const string & name)118 static inline string JavaClassName(VARS& vars, const string& name) {
119   // string name = google::protobuf::compiler::java::ClassName(desc);
120   return vars["Package"] + name;
121 }
122 
// Name of the generated outer class for a service, e.g. "Greeter" ->
// "GreeterGrpc".
static inline string ServiceClassName(const string& service_name) {
  string class_name(service_name);
  class_name.append("Grpc");
  return class_name;
}
126 
// TODO(nmittler): Remove once protobuf includes javadoc methods in
// distribution.
//
// Splits `full` into the maximal runs of characters that are not in `delim`.
// Each run is written as a string through the output iterator `result`, which
// is advanced. Leading, trailing and repeated delimiters never produce empty
// pieces. With an empty `delim`, a non-empty `full` comes out as one piece.
template <typename ITR>
static void GrpcSplitStringToIteratorUsing(const string& full,
                                           const char* delim, ITR& result) {
  // Fast path for the common single-character delimiter: a plain byte scan.
  if (delim[0] != '\0' && delim[1] == '\0') {
    const char separator = delim[0];
    const char* cursor = full.data();
    const char* const limit = cursor + full.size();
    while (cursor != limit) {
      if (*cursor == separator) {
        ++cursor;  // delimiters never start a piece
        continue;
      }
      const char* const piece_begin = cursor;
      do {
        ++cursor;
      } while (cursor != limit && *cursor != separator);
      *result++ = string(piece_begin, cursor - piece_begin);
    }
    return;
  }

  // General path: any character of `delim` separates pieces.
  for (string::size_type from = full.find_first_not_of(delim);
       from != string::npos;) {
    const string::size_type to = full.find_first_of(delim, from);
    if (to == string::npos) {
      *result++ = full.substr(from);  // the last piece runs to the end
      return;
    }
    *result++ = full.substr(from, to - from);
    from = full.find_first_not_of(delim, to);
  }
}
162 
GrpcSplitStringUsing(const string & full,const char * delim,std::vector<string> * result)163 static void GrpcSplitStringUsing(const string& full, const char* delim,
164                                  std::vector<string>* result) {
165   std::back_insert_iterator<std::vector<string>> it(*result);
166   GrpcSplitStringToIteratorUsing(full, delim, it);
167 }
168 
GrpcSplit(const string & full,const char * delim)169 static std::vector<string> GrpcSplit(const string& full, const char* delim) {
170   std::vector<string> result;
171   GrpcSplitStringUsing(full, delim, &result);
172   return result;
173 }
174 
// TODO(nmittler): Remove once protobuf includes javadoc methods in
// distribution.
//
// Escapes `input` so it can be embedded in a Javadoc comment:
//   - a '*' right after '/' becomes &#42; and a '/' right after '*' becomes
//     &#47;, so the text can neither open nor close a block comment;
//   - '@' becomes &#64;, because it starts javadoc tags such as @deprecated,
//     which cause a compile-time error before a declaration that lacks a
//     matching @Deprecated annotation;
//   - '<', '>' and '&' become HTML entities so they are not read as markup;
//   - '\\' becomes &#92;, because Java interprets Unicode escape sequences
//     anywhere, comments included.
// The "previous character" starts out as '*', so a leading '/' is escaped too.
static string GrpcEscapeJavadoc(const string& input) {
  string escaped;
  escaped.reserve(input.size() * 2);
  char last = '*';
  for (const char ch : input) {
    if (ch == '*' && last == '/') {
      escaped.append("&#42;");
    } else if (ch == '/' && last == '*') {
      escaped.append("&#47;");
    } else if (ch == '@') {
      escaped.append("&#64;");
    } else if (ch == '<') {
      escaped.append("&lt;");
    } else if (ch == '>') {
      escaped.append("&gt;");
    } else if (ch == '&') {
      escaped.append("&amp;");
    } else if (ch == '\\') {
      escaped.append("&#92;");
    } else {
      escaped.push_back(ch);
    }
    last = ch;  // the unescaped character drives the next decision
  }
  return escaped;
}
234 
GrpcGetDocLines(const string & comments)235 static std::vector<string> GrpcGetDocLines(const string& comments) {
236   if (!comments.empty()) {
237     // TODO(kenton):  Ideally we should parse the comment text as Markdown and
238     //   write it back as HTML, but this requires a Markdown parser.  For now
239     //   we just use <pre> to get fixed-width text formatting.
240 
241     // If the comment itself contains block comment start or end markers,
242     // HTML-escape them so that they don't accidentally close the doc comment.
243     string escapedComments = GrpcEscapeJavadoc(comments);
244 
245     std::vector<string> lines = GrpcSplit(escapedComments, "\n");
246     while (!lines.empty() && lines.back().empty()) {
247       lines.pop_back();
248     }
249     return lines;
250   }
251   return std::vector<string>();
252 }
253 
GrpcGetDocLinesForDescriptor(const DescriptorType * descriptor)254 static std::vector<string> GrpcGetDocLinesForDescriptor(
255     const DescriptorType* descriptor) {
256   return descriptor->GetAllComments();
257   // return GrpcGetDocLines(descriptor->GetLeadingComments("///"));
258 }
259 
GrpcWriteDocCommentBody(Printer * printer,VARS & vars,const std::vector<string> & lines,bool surroundWithPreTag)260 static void GrpcWriteDocCommentBody(Printer* printer, VARS& vars,
261                                     const std::vector<string>& lines,
262                                     bool surroundWithPreTag) {
263   if (!lines.empty()) {
264     if (surroundWithPreTag) {
265       printer->Print(" * <pre>\n");
266     }
267 
268     for (size_t i = 0; i < lines.size(); i++) {
269       // Most lines should start with a space.  Watch out for lines that start
270       // with a /, since putting that right after the leading asterisk will
271       // close the comment.
272       vars["line"] = lines[i];
273       if (!lines[i].empty() && lines[i][0] == '/') {
274         printer->Print(vars, " * $line$\n");
275       } else {
276         printer->Print(vars, " *$line$\n");
277       }
278     }
279 
280     if (surroundWithPreTag) {
281       printer->Print(" * </pre>\n");
282     }
283   }
284 }
285 
GrpcWriteDocComment(Printer * printer,VARS & vars,const string & comments)286 static void GrpcWriteDocComment(Printer* printer, VARS& vars,
287                                 const string& comments) {
288   printer->Print("/**\n");
289   std::vector<string> lines = GrpcGetDocLines(comments);
290   GrpcWriteDocCommentBody(printer, vars, lines, false);
291   printer->Print(" */\n");
292 }
293 
GrpcWriteServiceDocComment(Printer * printer,VARS & vars,const ServiceDescriptor * service)294 static void GrpcWriteServiceDocComment(Printer* printer, VARS& vars,
295                                        const ServiceDescriptor* service) {
296   printer->Print("/**\n");
297   std::vector<string> lines = GrpcGetDocLinesForDescriptor(service);
298   GrpcWriteDocCommentBody(printer, vars, lines, true);
299   printer->Print(" */\n");
300 }
301 
GrpcWriteMethodDocComment(Printer * printer,VARS & vars,const MethodDescriptor * method)302 void GrpcWriteMethodDocComment(Printer* printer, VARS& vars,
303                                const MethodDescriptor* method) {
304   printer->Print("/**\n");
305   std::vector<string> lines = GrpcGetDocLinesForDescriptor(method);
306   GrpcWriteDocCommentBody(printer, vars, lines, true);
307   printer->Print(" */\n");
308 }
309 
// Outputs a lazily created, cached extractor for one flatbuffers type.
//
// Reads three vars:
//   - "extr_type":          the Java class of the type (as used in generics);
//   - "extr_type_name":     its simple name, used in the generated member names
//                           and in the $extr_type$.getRootAs<name>(buffer) call;
//   - "service_class_name": the outer class, used as the lock.
// The generated Java holds the extractor in a volatile static field. Its getter
// getExtractorOf$extr_type_name$() uses double-checked locking: it checks the
// field, synchronizes on $service_class_name$.class, checks again, and only
// then creates the anonymous FBExtactor.
//
// NOTE(review): "FBExtactor" (sic) is emitted verbatim into the Java code and
// presumably matches the name in the FlatbuffersUtils Java runtime. Confirm
// before correcting the spelling here.
static void PrintTypeExtractor(Printer* p, VARS& vars) {
  p->Print(
    vars,
    "private static volatile FlatbuffersUtils.FBExtactor<$extr_type$> "
    "extractorOf$extr_type_name$;\n"
    "private static FlatbuffersUtils.FBExtactor<$extr_type$> "
    "getExtractorOf$extr_type_name$() {\n"
    "    if (extractorOf$extr_type_name$ != null) return "
    "extractorOf$extr_type_name$;\n"
    "    synchronized ($service_class_name$.class) {\n"
    "        if (extractorOf$extr_type_name$ != null) return "
    "extractorOf$extr_type_name$;\n"
    "        extractorOf$extr_type_name$ = new "
    "FlatbuffersUtils.FBExtactor<$extr_type$>() {\n"
    "            public $extr_type$ extract (ByteBuffer buffer) {\n"
    "                return "
    "$extr_type$.getRootAs$extr_type_name$(buffer);\n"
    "            }\n"
    "        };\n"
    "        return extractorOf$extr_type_name$;\n"
    "    }\n"
    "}\n\n");
}
PrintMethodFields(Printer * p,VARS & vars,const ServiceDescriptor * service)334 static void PrintMethodFields(Printer* p, VARS& vars,
335                               const ServiceDescriptor* service) {
336   p->Print("// Static method descriptors that strictly reflect the proto.\n");
337   vars["service_name"] = service->name();
338 
339   //set of names of rpc input- and output- types that were already encountered.
340   //this is needed to avoid duplicating type extractor since it's possible that
341   //the same type is used as an input or output type of more than a single RPC method
342   std::set<std::string> encounteredTypes;
343 
344   for (int i = 0; i < service->method_count(); ++i) {
345     auto method = service->method(i);
346     vars["arg_in_id"] = to_string(2L * i); //trying to make msvc 10 happy
347     vars["arg_out_id"] = to_string(2L * i + 1);
348     vars["method_name"] = method->name();
349     vars["input_type_name"] = method->get_input_type_name();
350     vars["output_type_name"] = method->get_output_type_name();
351     vars["input_type"] = JavaClassName(vars, method->get_input_type_name());
352     vars["output_type"] = JavaClassName(vars, method->get_output_type_name());
353     vars["method_field_name"] = MethodPropertiesFieldName(method.get());
354     vars["method_new_field_name"] = MethodPropertiesGetterName(method.get());
355     vars["method_method_name"] = MethodPropertiesGetterName(method.get());
356     bool client_streaming = method->ClientStreaming() || method->BidiStreaming();
357     bool server_streaming = method->ServerStreaming() || method->BidiStreaming();
358     if (client_streaming) {
359       if (server_streaming) {
360         vars["method_type"] = "BIDI_STREAMING";
361       } else {
362         vars["method_type"] = "CLIENT_STREAMING";
363       }
364     } else {
365       if (server_streaming) {
366         vars["method_type"] = "SERVER_STREAMING";
367       } else {
368         vars["method_type"] = "UNARY";
369       }
370     }
371 
372     p->Print(
373         vars,
374         "@$ExperimentalApi$(\"https://github.com/grpc/grpc-java/issues/"
375         "1901\")\n"
376         "@$Deprecated$ // Use {@link #$method_method_name$()} instead. \n"
377         "public static final $MethodDescriptor$<$input_type$,\n"
378         "    $output_type$> $method_field_name$ = $method_method_name$();\n"
379         "\n"
380         "private static volatile $MethodDescriptor$<$input_type$,\n"
381         "    $output_type$> $method_new_field_name$;\n"
382         "\n");
383 
384     if (encounteredTypes.insert(vars["input_type_name"]).second) {
385       vars["extr_type"] = vars["input_type"];
386       vars["extr_type_name"] = vars["input_type_name"];
387       PrintTypeExtractor(p, vars);
388     }
389 
390     if (encounteredTypes.insert(vars["output_type_name"]).second) {
391       vars["extr_type"] = vars["output_type"];
392       vars["extr_type_name"] = vars["output_type_name"];
393       PrintTypeExtractor(p, vars);
394     }
395 
396     p->Print(
397       vars,
398       "@$ExperimentalApi$(\"https://github.com/grpc/grpc-java/issues/"
399       "1901\")\n"
400       "public static $MethodDescriptor$<$input_type$,\n"
401       "    $output_type$> $method_method_name$() {\n"
402       "  $MethodDescriptor$<$input_type$, $output_type$> "
403       "$method_new_field_name$;\n"
404       "  if (($method_new_field_name$ = "
405       "$service_class_name$.$method_new_field_name$) == null) {\n"
406       "    synchronized ($service_class_name$.class) {\n"
407       "      if (($method_new_field_name$ = "
408       "$service_class_name$.$method_new_field_name$) == null) {\n"
409       "        $service_class_name$.$method_new_field_name$ = "
410       "$method_new_field_name$ = \n"
411       "            $MethodDescriptor$.<$input_type$, "
412       "$output_type$>newBuilder()\n"
413       "            .setType($MethodType$.$method_type$)\n"
414       "            .setFullMethodName(generateFullMethodName(\n"
415       "                \"$Package$$service_name$\", \"$method_name$\"))\n"
416       "            .setSampledToLocalTracing(true)\n"
417       "            .setRequestMarshaller(FlatbuffersUtils.marshaller(\n"
418       "                $input_type$.class, "
419       "getExtractorOf$input_type_name$()))\n"
420       "            .setResponseMarshaller(FlatbuffersUtils.marshaller(\n"
421       "                $output_type$.class, "
422       "getExtractorOf$output_type_name$()))\n");
423 
424     //            vars["proto_method_descriptor_supplier"] = service->name() +
425     //            "MethodDescriptorSupplier";
426     p->Print(vars, "                .setSchemaDescriptor(null)\n");
427     //"                .setSchemaDescriptor(new
428     //$proto_method_descriptor_supplier$(\"$method_name$\"))\n");
429 
430     p->Print(vars, "                .build();\n");
431     p->Print(vars,
432              "        }\n"
433              "      }\n"
434              "   }\n"
435              "   return $method_new_field_name$;\n"
436              "}\n");
437 
438     p->Print("\n");
439   }
440 }
// The kinds of class PrintStub can emit: client-side interfaces, client stub
// implementations, and the server-side abstract base class (ABSTRACT_CLASS).
// The numeric values appear in PrintStub's error message.
enum StubType {
  ASYNC_INTERFACE = 0,
  BLOCKING_CLIENT_INTERFACE = 1,
  FUTURE_CLIENT_INTERFACE = 2,
  BLOCKING_SERVER_INTERFACE = 3,
  ASYNC_CLIENT_IMPL = 4,
  BLOCKING_CLIENT_IMPL = 5,
  FUTURE_CLIENT_IMPL = 6,
  ABSTRACT_CLASS = 7,
};

// Calling convention of the RPC methods PrintStub emits.
enum CallType {
  ASYNC_CALL = 0,     // StreamObserver-based
  BLOCKING_CALL = 1,  // returns the response, or an Iterator of responses
  FUTURE_CALL = 2,    // returns a ListenableFuture of the response
};
453 
// Forward declaration: PrintStub uses this to print the body of the generated
// bindService() method for ABSTRACT_CLASS, with vars["instance"] set to "this".
static void PrintBindServiceMethodBody(Printer* p, VARS& vars,
                                       const ServiceDescriptor* service);
456 
// Prints a client interface or implementation class, or a server interface.
//
// `type` selects the flavour:
//   - ABSTRACT_CLASS (impl_base): "public static abstract class
//     $abstract_name$ implements $BindableService$". Its async methods
//     default to asyncUnimplemented*Call, and bindService() follows.
//   - *_INTERFACE (interface): method signatures end in ';'. No Javadoc and
//     no bodies are printed.
//   - *_IMPL: "public static final class $stub_name$ extends $AbstractStub$".
//     It gets constructors, build(), and methods that forward to the
//     blocking/async/future call helpers.
// Methods the chosen call style cannot express are skipped:
//   - client- and bidi-streaming methods for BLOCKING_CALL;
//   - every streaming method for FUTURE_CALL.
//
// NOTE(review): the *_INTERFACE flavours still print the
// "public static final class ... extends $AbstractStub$" header. Their methods
// have no bodies, which would not be valid Java. Presumably the caller never
// requests them; confirm before relying on them. BLOCKING_SERVER_INTERFACE
// hits the default case (GRPC_CODEGEN_FAIL).
//
// Sets vars: service_name, abstract_name, stub_name, client_name, plus
// per-method vars (input_type, output_type, lower_method_name,
// method_method_name, calls_method, params, last_line_prefix, instance).
static void PrintStub(Printer* p, VARS& vars, const ServiceDescriptor* service,
                      StubType type) {
  const string service_name = service->name();
  vars["service_name"] = service_name;
  vars["abstract_name"] = service_name + "ImplBase";
  string stub_name = service_name;
  string client_name = service_name;
  CallType call_type = ASYNC_CALL;
  bool impl_base = false;
  bool interface = false;
  // Derive the class-name suffix, the call style and the class kind.
  switch (type) {
    case ABSTRACT_CLASS:
      call_type = ASYNC_CALL;
      impl_base = true;
      break;
    case ASYNC_CLIENT_IMPL:
      call_type = ASYNC_CALL;
      stub_name += "Stub";
      break;
    case BLOCKING_CLIENT_INTERFACE:
      interface = true;
      FLATBUFFERS_FALLTHROUGH(); // fall thru
    case BLOCKING_CLIENT_IMPL:
      call_type = BLOCKING_CALL;
      stub_name += "BlockingStub";
      client_name += "BlockingClient";
      break;
    case FUTURE_CLIENT_INTERFACE:
      interface = true;
      FLATBUFFERS_FALLTHROUGH(); // fall thru
    case FUTURE_CLIENT_IMPL:
      call_type = FUTURE_CALL;
      stub_name += "FutureStub";
      client_name += "FutureClient";
      break;
    case ASYNC_INTERFACE:
      call_type = ASYNC_CALL;
      interface = true;
      break;
    default:
      GRPC_CODEGEN_FAIL << "Cannot determine class name for StubType: " << type;
  }
  vars["stub_name"] = stub_name;
  vars["client_name"] = client_name;

  // Class head
  if (!interface) {
    GrpcWriteServiceDocComment(p, vars, service);
  }
  if (impl_base) {
    p->Print(vars,
             "public static abstract class $abstract_name$ implements "
             "$BindableService$ {\n");
  } else {
    p->Print(vars,
             "public static final class $stub_name$ extends "
             "$AbstractStub$<$stub_name$> {\n");
  }
  p->Indent();

  // Constructor and build() method (client stub implementations only).
  if (!impl_base && !interface) {
    p->Print(vars, "private $stub_name$($Channel$ channel) {\n");
    p->Indent();
    p->Print("super(channel);\n");
    p->Outdent();
    p->Print("}\n\n");
    p->Print(vars,
             "private $stub_name$($Channel$ channel,\n"
             "    $CallOptions$ callOptions) {\n");
    p->Indent();
    p->Print("super(channel, callOptions);\n");
    p->Outdent();
    p->Print("}\n\n");
    p->Print(vars,
             "@$Override$\n"
             "protected $stub_name$ build($Channel$ channel,\n"
             "    $CallOptions$ callOptions) {\n");
    p->Indent();
    p->Print(vars, "return new $stub_name$(channel, callOptions);\n");
    p->Outdent();
    p->Print("}\n");
  }

  // RPC methods
  for (int i = 0; i < service->method_count(); ++i) {
    auto method = service->method(i);
    vars["input_type"] = JavaClassName(vars, method->get_input_type_name());
    vars["output_type"] = JavaClassName(vars, method->get_output_type_name());
    vars["lower_method_name"] = LowerMethodName(&*method);
    vars["method_method_name"] = MethodPropertiesGetterName(&*method);
    // Bidi methods count as both client- and server-streaming here.
    bool client_streaming = method->ClientStreaming() || method->BidiStreaming();
    bool server_streaming = method->ServerStreaming() || method->BidiStreaming();

    if (call_type == BLOCKING_CALL && client_streaming) {
      // Blocking client interface with client streaming is not available
      continue;
    }

    if (call_type == FUTURE_CALL && (client_streaming || server_streaming)) {
      // Future interface doesn't support streaming.
      continue;
    }

    // Method signature
    p->Print("\n");
    // TODO(nmittler): Replace with WriteMethodDocComment once included by the
    // protobuf distro.
    if (!interface) {
      GrpcWriteMethodDocComment(p, vars, &*method);
    }
    p->Print("public ");
    switch (call_type) {
      case BLOCKING_CALL:
        GRPC_CODEGEN_CHECK(!client_streaming)
            << "Blocking client interface with client streaming is unavailable";
        if (server_streaming) {
          // Server streaming
          p->Print(vars,
                   "$Iterator$<$output_type$> $lower_method_name$(\n"
                   "    $input_type$ request)");
        } else {
          // Simple RPC
          p->Print(vars,
                   "$output_type$ $lower_method_name$($input_type$ request)");
        }
        break;
      case ASYNC_CALL:
        if (client_streaming) {
          // Bidirectional streaming or client streaming
          p->Print(vars,
                   "$StreamObserver$<$input_type$> $lower_method_name$(\n"
                   "    $StreamObserver$<$output_type$> responseObserver)");
        } else {
          // Server streaming or simple RPC
          p->Print(vars,
                   "void $lower_method_name$($input_type$ request,\n"
                   "    $StreamObserver$<$output_type$> responseObserver)");
        }
        break;
      case FUTURE_CALL:
        GRPC_CODEGEN_CHECK(!client_streaming && !server_streaming)
            << "Future interface doesn't support streaming. "
            << "client_streaming=" << client_streaming << ", "
            << "server_streaming=" << server_streaming;
        p->Print(vars,
                 "$ListenableFuture$<$output_type$> $lower_method_name$(\n"
                 "    $input_type$ request)");
        break;
    }

    // Interface methods are declarations only.
    if (interface) {
      p->Print(";\n");
      continue;
    }
    // Method body.
    p->Print(" {\n");
    p->Indent();
    if (impl_base) {
      // Only ABSTRACT_CLASS sets impl_base, and it always uses ASYNC_CALL.
      switch (call_type) {
          // NB: Skipping validation of service methods. If something is wrong,
          // we wouldn't get to this point as compiler would return errors when
          // generating service interface.
        case ASYNC_CALL:
          if (client_streaming) {
            p->Print(vars,
                     "return "
                     "asyncUnimplementedStreamingCall($method_method_name$(), "
                     "responseObserver);\n");
          } else {
            p->Print(vars,
                     "asyncUnimplementedUnaryCall($method_method_name$(), "
                     "responseObserver);\n");
          }
          break;
        default:
          break;
      }
    } else if (!interface) {
      // `interface` is always false here (interface methods `continue` above),
      // so this is the client stub implementation: forward to the call helper.
      switch (call_type) {
        case BLOCKING_CALL:
          GRPC_CODEGEN_CHECK(!client_streaming)
              << "Blocking client streaming interface is not available";
          if (server_streaming) {
            vars["calls_method"] = "blockingServerStreamingCall";
            vars["params"] = "request";
          } else {
            vars["calls_method"] = "blockingUnaryCall";
            vars["params"] = "request";
          }
          p->Print(vars,
                   "return $calls_method$(\n"
                   "    getChannel(), $method_method_name$(), "
                   "getCallOptions(), $params$);\n");
          break;
        case ASYNC_CALL:
          if (server_streaming) {
            if (client_streaming) {
              vars["calls_method"] = "asyncBidiStreamingCall";
              vars["params"] = "responseObserver";
            } else {
              vars["calls_method"] = "asyncServerStreamingCall";
              vars["params"] = "request, responseObserver";
            }
          } else {
            if (client_streaming) {
              vars["calls_method"] = "asyncClientStreamingCall";
              vars["params"] = "responseObserver";
            } else {
              vars["calls_method"] = "asyncUnaryCall";
              vars["params"] = "request, responseObserver";
            }
          }
          // Client-streaming calls return the request StreamObserver.
          vars["last_line_prefix"] = client_streaming ? "return " : "";
          p->Print(vars,
                   "$last_line_prefix$$calls_method$(\n"
                   "    getChannel().newCall($method_method_name$(), "
                   "getCallOptions()), $params$);\n");
          break;
        case FUTURE_CALL:
          GRPC_CODEGEN_CHECK(!client_streaming && !server_streaming)
              << "Future interface doesn't support streaming. "
              << "client_streaming=" << client_streaming << ", "
              << "server_streaming=" << server_streaming;
          vars["calls_method"] = "futureUnaryCall";
          p->Print(vars,
                   "return $calls_method$(\n"
                   "    getChannel().newCall($method_method_name$(), "
                   "getCallOptions()), request);\n");
          break;
      }
    }
    p->Outdent();
    p->Print("}\n");
  }

  // The server base class also gets bindService().
  if (impl_base) {
    p->Print("\n");
    p->Print(
        vars,
        "@$Override$ public final $ServerServiceDefinition$ bindService() {\n");
    vars["instance"] = "this";
    PrintBindServiceMethodBody(p, vars, service);
    p->Print("}\n");
  }

  p->Outdent();
  p->Print("}\n\n");
}
707 
CompareMethodClientStreaming(const std::unique_ptr<const grpc_generator::Method> & method1,const std::unique_ptr<const grpc_generator::Method> & method2)708 static bool CompareMethodClientStreaming(
709     const std::unique_ptr<const grpc_generator::Method>& method1,
710     const std::unique_ptr<const grpc_generator::Method>& method2) {
711   return method1->ClientStreaming() < method2->ClientStreaming();
712 }
713 
714 // Place all method invocations into a single class to reduce memory footprint
715 // on Android.
PrintMethodHandlerClass(Printer * p,VARS & vars,const ServiceDescriptor * service)716 static void PrintMethodHandlerClass(Printer* p, VARS& vars,
717                                     const ServiceDescriptor* service) {
718   // Sort method ids based on ClientStreaming() so switch tables are compact.
719   std::vector<std::unique_ptr<const grpc_generator::Method>> sorted_methods(
720       service->method_count());
721   for (int i = 0; i < service->method_count(); ++i) {
722     sorted_methods[i] = service->method(i);
723   }
724   stable_sort(sorted_methods.begin(), sorted_methods.end(),
725               CompareMethodClientStreaming);
726   for (size_t i = 0; i < sorted_methods.size(); i++) {
727     auto& method = sorted_methods[i];
728     vars["method_id"] = to_string(i);
729     vars["method_id_name"] = MethodIdFieldName(&*method);
730     p->Print(vars,
731              "private static final int $method_id_name$ = $method_id$;\n");
732   }
733   p->Print("\n");
734   vars["service_name"] = service->name() + "ImplBase";
735   p->Print(vars,
736            "private static final class MethodHandlers<Req, Resp> implements\n"
737            "    io.grpc.stub.ServerCalls.UnaryMethod<Req, Resp>,\n"
738            "    io.grpc.stub.ServerCalls.ServerStreamingMethod<Req, Resp>,\n"
739            "    io.grpc.stub.ServerCalls.ClientStreamingMethod<Req, Resp>,\n"
740            "    io.grpc.stub.ServerCalls.BidiStreamingMethod<Req, Resp> {\n"
741            "  private final $service_name$ serviceImpl;\n"
742            "  private final int methodId;\n"
743            "\n"
744            "  MethodHandlers($service_name$ serviceImpl, int methodId) {\n"
745            "    this.serviceImpl = serviceImpl;\n"
746            "    this.methodId = methodId;\n"
747            "  }\n\n");
748   p->Indent();
749   p->Print(vars,
750            "@$Override$\n"
751            "@java.lang.SuppressWarnings(\"unchecked\")\n"
752            "public void invoke(Req request, $StreamObserver$<Resp> "
753            "responseObserver) {\n"
754            "  switch (methodId) {\n");
755   p->Indent();
756   p->Indent();
757 
758   for (int i = 0; i < service->method_count(); ++i) {
759     auto method = service->method(i);
760     if (method->ClientStreaming() || method->BidiStreaming()) {
761       continue;
762     }
763     vars["method_id_name"] = MethodIdFieldName(&*method);
764     vars["lower_method_name"] = LowerMethodName(&*method);
765     vars["input_type"] = JavaClassName(vars, method->get_input_type_name());
766     vars["output_type"] = JavaClassName(vars, method->get_output_type_name());
767     p->Print(vars,
768              "case $method_id_name$:\n"
769              "  serviceImpl.$lower_method_name$(($input_type$) request,\n"
770              "      ($StreamObserver$<$output_type$>) responseObserver);\n"
771              "  break;\n");
772   }
773   p->Print(
774       "default:\n"
775       "  throw new AssertionError();\n");
776 
777   p->Outdent();
778   p->Outdent();
779   p->Print(
780       "  }\n"
781       "}\n\n");
782 
783   p->Print(vars,
784            "@$Override$\n"
785            "@java.lang.SuppressWarnings(\"unchecked\")\n"
786            "public $StreamObserver$<Req> invoke(\n"
787            "    $StreamObserver$<Resp> responseObserver) {\n"
788            "  switch (methodId) {\n");
789   p->Indent();
790   p->Indent();
791 
792   for (int i = 0; i < service->method_count(); ++i) {
793     auto method = service->method(i);
794     if (!(method->ClientStreaming() || method->BidiStreaming())) {
795       continue;
796     }
797     vars["method_id_name"] = MethodIdFieldName(&*method);
798     vars["lower_method_name"] = LowerMethodName(&*method);
799     vars["input_type"] = JavaClassName(vars, method->get_input_type_name());
800     vars["output_type"] = JavaClassName(vars, method->get_output_type_name());
801     p->Print(
802         vars,
803         "case $method_id_name$:\n"
804         "  return ($StreamObserver$<Req>) serviceImpl.$lower_method_name$(\n"
805         "      ($StreamObserver$<$output_type$>) responseObserver);\n");
806   }
807   p->Print(
808       "default:\n"
809       "  throw new AssertionError();\n");
810 
811   p->Outdent();
812   p->Outdent();
813   p->Print(
814       "  }\n"
815       "}\n");
816 
817   p->Outdent();
818   p->Print("}\n\n");
819 }
820 
PrintGetServiceDescriptorMethod(Printer * p,VARS & vars,const ServiceDescriptor * service)821 static void PrintGetServiceDescriptorMethod(Printer* p, VARS& vars,
822                                             const ServiceDescriptor* service) {
823   vars["service_name"] = service->name();
824   //        vars["proto_base_descriptor_supplier"] = service->name() +
825   //        "BaseDescriptorSupplier"; vars["proto_file_descriptor_supplier"] =
826   //        service->name() + "FileDescriptorSupplier";
827   //        vars["proto_method_descriptor_supplier"] = service->name() +
828   //        "MethodDescriptorSupplier"; vars["proto_class_name"] =
829   //        google::protobuf::compiler::java::ClassName(service->file());
830   //        p->Print(
831   //                 vars,
832   //                 "private static abstract class
833   //                 $proto_base_descriptor_supplier$\n" "    implements
834   //                 $ProtoFileDescriptorSupplier$,
835   //                 $ProtoServiceDescriptorSupplier$ {\n" "
836   //                 $proto_base_descriptor_supplier$() {}\n"
837   //                 "\n"
838   //                 "  @$Override$\n"
839   //                 "  public com.google.protobuf.Descriptors.FileDescriptor
840   //                 getFileDescriptor() {\n" "    return
841   //                 $proto_class_name$.getDescriptor();\n" "  }\n"
842   //                 "\n"
843   //                 "  @$Override$\n"
844   //                 "  public com.google.protobuf.Descriptors.ServiceDescriptor
845   //                 getServiceDescriptor() {\n" "    return
846   //                 getFileDescriptor().findServiceByName(\"$service_name$\");\n"
847   //                 "  }\n"
848   //                 "}\n"
849   //                 "\n"
850   //                 "private static final class
851   //                 $proto_file_descriptor_supplier$\n" "    extends
852   //                 $proto_base_descriptor_supplier$ {\n" "
853   //                 $proto_file_descriptor_supplier$() {}\n"
854   //                 "}\n"
855   //                 "\n"
856   //                 "private static final class
857   //                 $proto_method_descriptor_supplier$\n" "    extends
858   //                 $proto_base_descriptor_supplier$\n" "    implements
859   //                 $ProtoMethodDescriptorSupplier$ {\n" "  private final
860   //                 String methodName;\n"
861   //                 "\n"
862   //                 "  $proto_method_descriptor_supplier$(String methodName)
863   //                 {\n" "    this.methodName = methodName;\n" "  }\n"
864   //                 "\n"
865   //                 "  @$Override$\n"
866   //                 "  public com.google.protobuf.Descriptors.MethodDescriptor
867   //                 getMethodDescriptor() {\n" "    return
868   //                 getServiceDescriptor().findMethodByName(methodName);\n" "
869   //                 }\n"
870   //                 "}\n\n");
871 
872   p->Print(
873       vars,
874       "private static volatile $ServiceDescriptor$ serviceDescriptor;\n\n");
875 
876   p->Print(vars,
877            "public static $ServiceDescriptor$ getServiceDescriptor() {\n");
878   p->Indent();
879   p->Print(vars, "$ServiceDescriptor$ result = serviceDescriptor;\n");
880   p->Print("if (result == null) {\n");
881   p->Indent();
882   p->Print(vars, "synchronized ($service_class_name$.class) {\n");
883   p->Indent();
884   p->Print("result = serviceDescriptor;\n");
885   p->Print("if (result == null) {\n");
886   p->Indent();
887 
888   p->Print(vars,
889            "serviceDescriptor = result = "
890            "$ServiceDescriptor$.newBuilder(SERVICE_NAME)");
891   p->Indent();
892   p->Indent();
893   p->Print(vars, "\n.setSchemaDescriptor(null)");
894   for (int i = 0; i < service->method_count(); ++i) {
895     auto method = service->method(i);
896     vars["method_method_name"] = MethodPropertiesGetterName(&*method);
897     p->Print(vars, "\n.addMethod($method_method_name$())");
898   }
899   p->Print("\n.build();\n");
900   p->Outdent();
901   p->Outdent();
902 
903   p->Outdent();
904   p->Print("}\n");
905   p->Outdent();
906   p->Print("}\n");
907   p->Outdent();
908   p->Print("}\n");
909   p->Print("return result;\n");
910   p->Outdent();
911   p->Print("}\n");
912 }
913 
PrintBindServiceMethodBody(Printer * p,VARS & vars,const ServiceDescriptor * service)914 static void PrintBindServiceMethodBody(Printer* p, VARS& vars,
915                                        const ServiceDescriptor* service) {
916   vars["service_name"] = service->name();
917   p->Indent();
918   p->Print(vars,
919            "return "
920            "$ServerServiceDefinition$.builder(getServiceDescriptor())\n");
921   p->Indent();
922   p->Indent();
923   for (int i = 0; i < service->method_count(); ++i) {
924     auto method = service->method(i);
925     vars["lower_method_name"] = LowerMethodName(&*method);
926     vars["method_method_name"] = MethodPropertiesGetterName(&*method);
927     vars["input_type"] = JavaClassName(vars, method->get_input_type_name());
928     vars["output_type"] = JavaClassName(vars, method->get_output_type_name());
929     vars["method_id_name"] = MethodIdFieldName(&*method);
930     bool client_streaming = method->ClientStreaming() || method->BidiStreaming();
931     bool server_streaming = method->ServerStreaming() || method->BidiStreaming();
932     if (client_streaming) {
933       if (server_streaming) {
934         vars["calls_method"] = "asyncBidiStreamingCall";
935       } else {
936         vars["calls_method"] = "asyncClientStreamingCall";
937       }
938     } else {
939       if (server_streaming) {
940         vars["calls_method"] = "asyncServerStreamingCall";
941       } else {
942         vars["calls_method"] = "asyncUnaryCall";
943       }
944     }
945     p->Print(vars, ".addMethod(\n");
946     p->Indent();
947     p->Print(vars,
948              "$method_method_name$(),\n"
949              "$calls_method$(\n");
950     p->Indent();
951     p->Print(vars,
952              "new MethodHandlers<\n"
953              "  $input_type$,\n"
954              "  $output_type$>(\n"
955              "    $instance$, $method_id_name$)))\n");
956     p->Outdent();
957     p->Outdent();
958   }
959   p->Print(".build();\n");
960   p->Outdent();
961   p->Outdent();
962   p->Outdent();
963 }
964 
PrintService(Printer * p,VARS & vars,const ServiceDescriptor * service,bool disable_version)965 static void PrintService(Printer* p, VARS& vars,
966                          const ServiceDescriptor* service,
967                          bool disable_version) {
968   vars["service_name"] = service->name();
969   vars["service_class_name"] = ServiceClassName(service->name());
970   vars["grpc_version"] = "";
971 #ifdef GRPC_VERSION
972   if (!disable_version) {
973     vars["grpc_version"] = " (version " XSTR(GRPC_VERSION) ")";
974   }
975 #else
976   (void)disable_version;
977 #endif
978   // TODO(nmittler): Replace with WriteServiceDocComment once included by
979   // protobuf distro.
980   GrpcWriteServiceDocComment(p, vars, service);
981   p->Print(vars,
982            "@$Generated$(\n"
983            "    value = \"by gRPC proto compiler$grpc_version$\",\n"
984            "    comments = \"Source: $file_name$.fbs\")\n"
985            "public final class $service_class_name$ {\n\n");
986   p->Indent();
987   p->Print(vars, "private $service_class_name$() {}\n\n");
988 
989   p->Print(vars,
990            "public static final String SERVICE_NAME = "
991            "\"$Package$$service_name$\";\n\n");
992 
993   PrintMethodFields(p, vars, service);
994 
995   // TODO(nmittler): Replace with WriteDocComment once included by protobuf
996   // distro.
997   GrpcWriteDocComment(
998       p, vars,
999       " Creates a new async stub that supports all call types for the service");
1000   p->Print(vars,
1001            "public static $service_name$Stub newStub($Channel$ channel) {\n");
1002   p->Indent();
1003   p->Print(vars, "return new $service_name$Stub(channel);\n");
1004   p->Outdent();
1005   p->Print("}\n\n");
1006 
1007   // TODO(nmittler): Replace with WriteDocComment once included by protobuf
1008   // distro.
1009   GrpcWriteDocComment(
1010       p, vars,
1011       " Creates a new blocking-style stub that supports unary and streaming "
1012       "output calls on the service");
1013   p->Print(vars,
1014            "public static $service_name$BlockingStub newBlockingStub(\n"
1015            "    $Channel$ channel) {\n");
1016   p->Indent();
1017   p->Print(vars, "return new $service_name$BlockingStub(channel);\n");
1018   p->Outdent();
1019   p->Print("}\n\n");
1020 
1021   // TODO(nmittler): Replace with WriteDocComment once included by protobuf
1022   // distro.
1023   GrpcWriteDocComment(
1024       p, vars,
1025       " Creates a new ListenableFuture-style stub that supports unary calls "
1026       "on the service");
1027   p->Print(vars,
1028            "public static $service_name$FutureStub newFutureStub(\n"
1029            "    $Channel$ channel) {\n");
1030   p->Indent();
1031   p->Print(vars, "return new $service_name$FutureStub(channel);\n");
1032   p->Outdent();
1033   p->Print("}\n\n");
1034 
1035   PrintStub(p, vars, service, ABSTRACT_CLASS);
1036   PrintStub(p, vars, service, ASYNC_CLIENT_IMPL);
1037   PrintStub(p, vars, service, BLOCKING_CLIENT_IMPL);
1038   PrintStub(p, vars, service, FUTURE_CLIENT_IMPL);
1039 
1040   PrintMethodHandlerClass(p, vars, service);
1041   PrintGetServiceDescriptorMethod(p, vars, service);
1042   p->Outdent();
1043   p->Print("}\n");
1044 }
1045 
PrintStaticImports(Printer * p)1046 void PrintStaticImports(Printer* p) {
1047   p->Print(
1048       "import java.nio.ByteBuffer;\n"
1049       "import static "
1050       "io.grpc.MethodDescriptor.generateFullMethodName;\n"
1051       "import static "
1052       "io.grpc.stub.ClientCalls.asyncBidiStreamingCall;\n"
1053       "import static "
1054       "io.grpc.stub.ClientCalls.asyncClientStreamingCall;\n"
1055       "import static "
1056       "io.grpc.stub.ClientCalls.asyncServerStreamingCall;\n"
1057       "import static "
1058       "io.grpc.stub.ClientCalls.asyncUnaryCall;\n"
1059       "import static "
1060       "io.grpc.stub.ClientCalls.blockingServerStreamingCall;\n"
1061       "import static "
1062       "io.grpc.stub.ClientCalls.blockingUnaryCall;\n"
1063       "import static "
1064       "io.grpc.stub.ClientCalls.futureUnaryCall;\n"
1065       "import static "
1066       "io.grpc.stub.ServerCalls.asyncBidiStreamingCall;\n"
1067       "import static "
1068       "io.grpc.stub.ServerCalls.asyncClientStreamingCall;\n"
1069       "import static "
1070       "io.grpc.stub.ServerCalls.asyncServerStreamingCall;\n"
1071       "import static "
1072       "io.grpc.stub.ServerCalls.asyncUnaryCall;\n"
1073       "import static "
1074       "io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall;\n"
1075       "import static "
1076       "io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall;\n\n");
1077 }
1078 
GenerateService(const grpc_generator::Service * service,grpc_generator::Printer * printer,VARS & vars,bool disable_version)1079 void GenerateService(const grpc_generator::Service* service,
1080                      grpc_generator::Printer* printer, VARS& vars,
1081                      bool disable_version) {
1082   // All non-generated classes must be referred by fully qualified names to
1083   // avoid collision with generated classes.
1084   vars["String"] = "java.lang.String";
1085   vars["Deprecated"] = "java.lang.Deprecated";
1086   vars["Override"] = "java.lang.Override";
1087   vars["Channel"] = "io.grpc.Channel";
1088   vars["CallOptions"] = "io.grpc.CallOptions";
1089   vars["MethodType"] = "io.grpc.MethodDescriptor.MethodType";
1090   vars["ServerMethodDefinition"] = "io.grpc.ServerMethodDefinition";
1091   vars["BindableService"] = "io.grpc.BindableService";
1092   vars["ServerServiceDefinition"] = "io.grpc.ServerServiceDefinition";
1093   vars["ServiceDescriptor"] = "io.grpc.ServiceDescriptor";
1094   vars["ProtoFileDescriptorSupplier"] =
1095       "io.grpc.protobuf.ProtoFileDescriptorSupplier";
1096   vars["ProtoServiceDescriptorSupplier"] =
1097       "io.grpc.protobuf.ProtoServiceDescriptorSupplier";
1098   vars["ProtoMethodDescriptorSupplier"] =
1099       "io.grpc.protobuf.ProtoMethodDescriptorSupplier";
1100   vars["AbstractStub"] = "io.grpc.stub.AbstractStub";
1101   vars["MethodDescriptor"] = "io.grpc.MethodDescriptor";
1102   vars["NanoUtils"] = "io.grpc.protobuf.nano.NanoUtils";
1103   vars["StreamObserver"] = "io.grpc.stub.StreamObserver";
1104   vars["Iterator"] = "java.util.Iterator";
1105   vars["Generated"] = "javax.annotation.Generated";
1106   vars["ListenableFuture"] =
1107       "com.google.common.util.concurrent.ListenableFuture";
1108   vars["ExperimentalApi"] = "io.grpc.ExperimentalApi";
1109 
1110   PrintStaticImports(printer);
1111 
1112   PrintService(printer, vars, service, disable_version);
1113 }
1114 
GenerateServiceSource(grpc_generator::File * file,const grpc_generator::Service * service,grpc_java_generator::Parameters * parameters)1115 grpc::string GenerateServiceSource(
1116     grpc_generator::File* file, const grpc_generator::Service* service,
1117     grpc_java_generator::Parameters* parameters) {
1118   grpc::string out;
1119   auto printer = file->CreatePrinter(&out);
1120   VARS vars;
1121   vars["flatc_version"] = grpc::string(
1122       FLATBUFFERS_STRING(FLATBUFFERS_VERSION_MAJOR) "." FLATBUFFERS_STRING(
1123           FLATBUFFERS_VERSION_MINOR) "." FLATBUFFERS_STRING(FLATBUFFERS_VERSION_REVISION));
1124 
1125   vars["file_name"] = file->filename();
1126 
1127   if (!parameters->package_name.empty()) {
1128     vars["Package"] = parameters->package_name;  // ServiceJavaPackage(service);
1129   }
1130   GenerateImports(file, &*printer, vars);
1131   GenerateService(service, &*printer, vars, false);
1132   return out;
1133 }
1134 
1135 }  // namespace grpc_java_generator
1136