• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <list>
17 #include <map>
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/op_gen_lib.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/env.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/java/src/gen/cc/java_defs.h"
31 #include "tensorflow/java/src/gen/cc/op_generator.h"
32 #include "tensorflow/java/src/gen/cc/op_specs.h"
33 #include "tensorflow/java/src/gen/cc/source_writer.h"
34 
35 namespace tensorflow {
36 namespace java {
37 namespace {
38 
// Apache 2.0 license banner written verbatim at the top of every generated
// Java source file (see GenerateOp).
constexpr char kLicense[] =
    "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
    "\n"
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    "you may not use this file except in compliance with the License.\n"
    "You may obtain a copy of the License at\n"
    "\n"
    "    http://www.apache.org/licenses/LICENSE-2.0\n"
    "\n"
    "Unless required by applicable law or agreed to in writing, software\n"
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    "See the License for the specific language governing permissions and\n"
    "limitations under the License.\n"
    "=======================================================================*/\n";
55 
// An op class is rendered in one of three modes, chosen from the number and
// kind of outputs the op declares:
//
// DEFAULT:      no specialization of the op class; used whenever the op does
//               not qualify for any of the other modes.
//
// OPERAND:      the op class implements the Operand<T> interface, so an
//               instance can be passed directly as an input to another
//               operation.
//
// LIST_OPERAND: the op class implements the Iterable<Operand<T>> interface,
//               so an instance can be passed directly as a list input to
//               another operation.
enum RenderMode {
  DEFAULT,       // op does not have exactly one output
  OPERAND,       // op has exactly one output, which is not a list
  LIST_OPERAND,  // op has exactly one output, which is a list
};
70 
AddArgument(const Variable & var,const string & description,Method * method_out,Javadoc * javadoc_out)71 void AddArgument(const Variable& var, const string& description,
72                  Method* method_out, Javadoc* javadoc_out) {
73   method_out->add_argument(var);
74   javadoc_out->add_param_tag(var.name(), description);
75 }
76 
CollectOpDependencies(const OpSpec & op,RenderMode mode,std::list<Type> * out)77 void CollectOpDependencies(const OpSpec& op, RenderMode mode,
78                            std::list<Type>* out) {
79   out->push_back(Type::Class("Operation", "org.tensorflow"));
80   out->push_back(Type::Class("OperationBuilder", "org.tensorflow"));
81   out->push_back(Type::Class("Scope", "org.tensorflow.op"));
82   if (mode == OPERAND) {
83     out->push_back(Type::Class("Output", "org.tensorflow"));
84   } else if (mode == LIST_OPERAND) {
85     out->push_back(Type::Interface("Iterator", "java.util"));
86   }
87   // Don't pay attention to duplicate types in the dependency list, they will
88   // be filtered out by the SourceWriter.
89   for (const ArgumentSpec& input : op.inputs()) {
90     out->push_back(input.var().type());
91     if (input.iterable()) {
92       out->push_back(Type::Class("Operands", "org.tensorflow.op"));
93     }
94   }
95   for (const ArgumentSpec& output : op.outputs()) {
96     out->push_back(output.var().type());
97     if (output.iterable()) {
98       out->push_back(Type::Class("Arrays", "java.util"));
99     }
100   }
101   for (const AttributeSpec& attribute : op.attributes()) {
102     out->push_back(attribute.var().type());
103     out->push_back(attribute.jni_type());
104     if (attribute.has_default_value() &&
105         attribute.type().kind() == Type::GENERIC) {
106       out->push_back(Type::ForDataType(attribute.default_value()->type()));
107     }
108   }
109   for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
110     out->push_back(optional_attribute.var().type());
111   }
112 }
113 
WriteSetAttrDirective(const AttributeSpec & attr,bool optional,SourceWriter * writer)114 void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
115                            SourceWriter* writer) {
116   string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
117   if (attr.iterable()) {
118     string array_name = attr.var().name() + "Array";
119     writer->AppendType(attr.jni_type())
120         .Append("[] " + array_name + " = new ")
121         .AppendType(attr.jni_type())
122         .Append("[" + var_name + ".size()];")
123         .EndLine()
124         .BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
125         .Append(array_name + "[i] = ");
126     if (attr.type().kind() == Type::GENERIC) {
127       writer->Append("DataType.fromClass(" + var_name + ".get(i));");
128     } else {
129       writer->Append(var_name + ".get(i);");
130     }
131     writer->EndLine()
132         .EndBlock()
133         .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
134         .Append(array_name + ");")
135         .EndLine();
136   } else {
137     writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ");
138     if (attr.var().type().name() == "Class") {
139       writer->Append("DataType.fromClass(" + var_name + "));");
140     } else {
141       writer->Append(var_name + ");");
142     }
143     writer->EndLine();
144   }
145 }
146 
RenderSecondaryFactoryMethod(const OpSpec & op,const Type & op_class,std::map<string,Type> default_types,SourceWriter * writer)147 void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
148                                   std::map<string, Type> default_types,
149                                   SourceWriter* writer) {
150   // Build the return type for the secondary factory, replacing generic
151   // parameters with their default value if any
152   Type return_type = Type::Class(op_class.name(), op_class.package());
153   for (const Type& parameter : op_class.parameters()) {
154     if (parameter.kind() == Type::GENERIC &&
155         default_types.find(parameter.name()) != default_types.end()) {
156       return_type.add_parameter(default_types.at(parameter.name()));
157     } else {
158       return_type.add_parameter(parameter);
159     }
160   }
161   Method factory = Method::Create("create", return_type);
162   Javadoc factory_doc = Javadoc::Create(
163       "Factory method to create a class to wrap a new " + op_class.name() +
164       " operation to the graph, using "
165       "default output types.");
166   Variable scope =
167       Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
168   AddArgument(scope, "current graph scope", &factory, &factory_doc);
169   std::stringstream factory_statement;
170   factory_statement << "return create(scope";
171   for (const ArgumentSpec& input : op.inputs()) {
172     AddArgument(input.var(), input.description(), &factory, &factory_doc);
173     factory_statement << ", " << input.var().name();
174   }
175   for (const AttributeSpec& attr : op.attributes()) {
176     // Only add attributes that are not types or have no default value to the
177     // signature of the secondary factory
178     factory_statement << ", ";
179     if (attr.type().kind() == Type::GENERIC &&
180         default_types.find(attr.type().name()) != default_types.end()) {
181       factory_statement << default_types.at(attr.type().name()).name()
182                         << ".class";
183     } else {
184       AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
185       factory_statement << attr.var().name();
186     }
187   }
188   if (!op.optional_attributes().empty()) {
189     Variable options_var = Variable::Varargs("options", Type::Class("Options"));
190     AddArgument(options_var, "carries optional attributes values", &factory,
191                 &factory_doc);
192     factory_statement << ", " << options_var.name();
193   }
194   factory_doc.add_tag("return", "a new instance of " + op_class.name());
195 
196   writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
197   writer->Append(factory_statement.str().c_str()).Append(");").EndLine();
198   writer->EndMethod();
199 }
200 
RenderFactoryMethods(const OpSpec & op,const Type & op_class,SourceWriter * writer)201 void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
202                           SourceWriter* writer) {
203   Method factory = Method::Create("create", op_class);
204   Javadoc factory_doc =
205       Javadoc::Create("Factory method to create a class to wrap a new " +
206                       op_class.name() + " operation to the graph.");
207   Variable scope =
208       Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
209   AddArgument(scope, "current graph scope", &factory, &factory_doc);
210   for (const ArgumentSpec& input : op.inputs()) {
211     AddArgument(input.var(), input.description(), &factory, &factory_doc);
212   }
213   std::map<string, Type> default_types;
214   for (const AttributeSpec& attr : op.attributes()) {
215     AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
216     // If this attribute is a type with a default value, save its value
217     // for passing it implicitly in a secondary factory method
218     if (attr.has_default_value() && attr.type().kind() == Type::GENERIC) {
219       Type default_type = Type::ForDataType(attr.default_value()->type());
220       if (!default_type.wildcard()) {
221         default_types.insert(std::make_pair(attr.type().name(), default_type));
222       }
223     }
224   }
225   if (!op.optional_attributes().empty()) {
226     AddArgument(Variable::Varargs("options", Type::Class("Options")),
227                 "carries optional attributes values", &factory, &factory_doc);
228   }
229   factory_doc.add_tag("return", "a new instance of " + op_class.name());
230 
231   writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
232   writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" +
233                  op.graph_op_name() + "\", scope.makeOpName(\"" +
234                  op_class.name() + "\"));");
235   writer->EndLine();
236   for (const ArgumentSpec& input : op.inputs()) {
237     if (input.iterable()) {
238       writer->Append("opBuilder.addInputList(Operands.asOutputs(" +
239                      input.var().name() + "));");
240       writer->EndLine();
241     } else {
242       writer->Append("opBuilder.addInput(" + input.var().name() +
243                      ".asOutput());");
244       writer->EndLine();
245     }
246   }
247   for (const AttributeSpec& attribute : op.attributes()) {
248     WriteSetAttrDirective(attribute, false, writer);
249   }
250   if (!op.optional_attributes().empty()) {
251     writer->BeginBlock("if (options != null)")
252         .BeginBlock("for (Options opts : options)");
253     for (const AttributeSpec& attribute : op.optional_attributes()) {
254       writer->BeginBlock("if (opts." + attribute.var().name() + " != null)");
255       WriteSetAttrDirective(attribute, true, writer);
256       writer->EndBlock();
257     }
258     writer->EndBlock().EndBlock();
259   }
260   writer->Append("return new ")
261       .AppendType(op_class)
262       .Append("(opBuilder.build());")
263       .EndLine();
264   writer->EndMethod();
265 
266   // If this operation has type attributes with a default value, create a
267   // second factory method that infers those values implicitly
268   if (!default_types.empty()) {
269     RenderSecondaryFactoryMethod(op, op_class, default_types, writer);
270   }
271 }
272 
RenderConstructor(const OpSpec & op,const Type & op_class,SourceWriter * writer)273 void RenderConstructor(const OpSpec& op, const Type& op_class,
274                        SourceWriter* writer) {
275   Variable operation =
276       Variable::Create("operation", Type::Class("Operation", "org.tensorflow"));
277   Method constructor = Method::ConstructorFor(op_class).add_argument(operation);
278   for (const ArgumentSpec& output : op.outputs()) {
279     if (output.iterable() && !output.type().wildcard()) {
280       constructor.add_annotation(
281           Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
282       break;
283     }
284   }
285   writer->BeginMethod(constructor, PRIVATE)
286       .Append("super(operation);")
287       .EndLine();
288   if (!op.outputs().empty()) {
289     writer->Append("int outputIdx = 0;").EndLine();
290     for (const ArgumentSpec& output : op.outputs()) {
291       if (output.iterable()) {
292         string var_length = output.var().name() + "Length";
293         writer->Append("int " + var_length)
294             .Append(" = operation.outputListLength(\"" + output.op_def_name() +
295                     "\");")
296             .EndLine()
297             .Append(output.var().name() + " = Arrays.asList(");
298         if (!output.type().wildcard()) {
299           writer->Append("(")
300               .AppendType(output.var().type().parameters().front())
301               .Append("[])");
302         }
303         writer->Append("operation.outputList(outputIdx, " + var_length + "));")
304             .EndLine()
305             .Append("outputIdx += " + var_length + ";")
306             .EndLine();
307       } else {
308         writer
309             ->Append(output.var().name() + " = operation.output(outputIdx++);")
310             .EndLine();
311       }
312     }
313   }
314   writer->EndMethod();
315 }
316 
RenderGettersAndSetters(const OpSpec & op,SourceWriter * writer)317 void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
318   for (const AttributeSpec& attr : op.optional_attributes()) {
319     Method setter = Method::Create(attr.var().name(), Type::Class("Options"));
320     Javadoc setter_doc = Javadoc::Create();
321     AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
322     writer->BeginMethod(setter, PUBLIC | STATIC, &setter_doc)
323         .Append("return new Options()." + attr.var().name() + "(" +
324                 attr.var().name() + ");")
325         .EndLine()
326         .EndMethod();
327   }
328   for (const ArgumentSpec& output : op.outputs()) {
329     Method getter = Method::Create(output.var().name(), output.var().type());
330     Javadoc getter_doc = Javadoc::Create(output.description());
331     writer->BeginMethod(getter, PUBLIC, &getter_doc)
332         .Append("return " + output.var().name() + ";")
333         .EndLine()
334         .EndMethod();
335   }
336 }
337 
RenderInterfaceImpl(const OpSpec & op,RenderMode mode,SourceWriter * writer)338 void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
339                          SourceWriter* writer) {
340   ArgumentSpec output = op.outputs().front();
341 
342   if (mode == OPERAND) {
343     bool cast2obj = output.type().wildcard();
344     Type return_type =
345         Type::Class("Output", "org.tensorflow")
346             .add_parameter(cast2obj ? Type::Class("Object") : output.type());
347     Method as_output = Method::Create("asOutput", return_type)
348                            .add_annotation(Annotation::Create("Override"));
349     if (cast2obj) {
350       as_output.add_annotation(
351           Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
352     }
353     writer->BeginMethod(as_output, PUBLIC);
354     if (cast2obj) {
355       writer->Append("return (").AppendType(return_type).Append(") ");
356     } else {
357       writer->Append("return ");
358     }
359     writer->Append(output.var().name() + ";").EndLine().EndMethod();
360 
361   } else if (mode == LIST_OPERAND) {
362     Type operand = Type::Interface("Operand", "org.tensorflow");
363     if (output.type().wildcard()) {
364       operand.add_parameter(Type::Class("Object"));
365     } else {
366       operand.add_parameter(output.type());
367     }
368     Type return_type =
369         Type::Interface("Iterator", "java.util").add_parameter(operand);
370     Method iterator =
371         Method::Create("iterator", return_type)
372             .add_annotation(Annotation::Create("Override"))
373             .add_annotation(Annotation::Create("SuppressWarnings")
374                                 .attributes("{\"rawtypes\", \"unchecked\"}"));
375     // cast the output list using a raw List
376     writer->BeginMethod(iterator, PUBLIC)
377         .Append("return (" + return_type.name() + ") ")
378         .Append(output.var().name() + ".iterator();")
379         .EndLine()
380         .EndMethod();
381   }
382 }
383 
RenderOptionsClass(const OpSpec & op,const Type & op_class,SourceWriter * writer)384 void RenderOptionsClass(const OpSpec& op, const Type& op_class,
385                         SourceWriter* writer) {
386   Type options_class = Type::Class("Options");
387   Javadoc options_doc = Javadoc::Create("Optional attributes for {@link " +
388                                         op_class.canonical_name() + "}");
389   writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
390   for (const AttributeSpec& attr : op.optional_attributes()) {
391     Method setter = Method::Create(attr.var().name(), options_class);
392     Javadoc setter_doc = Javadoc::Create();
393     AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
394     writer->BeginMethod(setter, PUBLIC, &setter_doc)
395         .Append("this." + attr.var().name() + " = " + attr.var().name() + ";")
396         .EndLine()
397         .Append("return this;")
398         .EndLine()
399         .EndMethod();
400   }
401   writer->EndLine();
402   for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
403     writer->WriteField(optional_attribute.var(), PRIVATE);
404   }
405   Method constructor = Method::ConstructorFor(options_class);
406   writer->BeginMethod(constructor, PRIVATE).EndMethod();
407   writer->EndType();
408 }
409 
ClassOf(const EndpointSpec & endpoint,const string & base_package)410 inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) {
411   return Type::Class(
412       endpoint.name(),
413       base_package + "." + str_util::Lowercase(endpoint.package()));
414 }
415 
GenerateOp(const OpSpec & op,const EndpointSpec & endpoint,const string & base_package,const string & output_dir,Env * env)416 void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
417                 const string& base_package, const string& output_dir,
418                 Env* env) {
419   Type op_class(
420       ClassOf(endpoint, base_package)
421           .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
422   Javadoc op_javadoc(endpoint.javadoc());
423 
424   // op interfaces
425   RenderMode mode = DEFAULT;
426   if (op.outputs().size() == 1) {
427     const ArgumentSpec& output = op.outputs().front();
428     Type operand_type(output.type().wildcard() ? Type::Class("Object")
429                                                : output.type());
430     Type operand_inf(Type::Interface("Operand", "org.tensorflow")
431                          .add_parameter(operand_type));
432     if (output.iterable()) {
433       mode = LIST_OPERAND;
434       op_class.add_supertype(Type::IterableOf(operand_inf));
435     } else {
436       mode = OPERAND;
437       op_class.add_supertype(operand_inf);
438     }
439   }
440   // op generic parameters
441   std::set<string> generics;
442   for (const ArgumentSpec& output : op.outputs()) {
443     if (output.type().kind() == Type::GENERIC && !output.type().wildcard() &&
444         generics.find(output.type().name()) == generics.end()) {
445       op_class.add_parameter(output.type());
446       op_javadoc.add_param_tag(
447           "<" + output.type().name() + ">",
448           "data type for {@code " + output.var().name() + "()} output");
449       generics.insert(output.type().name());
450     }
451   }
452   // op annotations
453   if (endpoint.deprecated()) {
454     op_class.add_annotation(Annotation::Create("Deprecated"));
455     string explanation;
456     if (!op.endpoints().front().deprecated()) {
457       explanation =
458           "use {@link " +
459           ClassOf(op.endpoints().front(), base_package).canonical_name() +
460           "} instead";
461     } else {
462       explanation = op.deprecation_explanation();
463     }
464     op_javadoc.add_tag("deprecated", explanation);
465   }
466   if (!op.hidden()) {
467     // expose the op in the Ops Graph API only if it is visible
468     Annotation oper_annot =
469         Annotation::Create("Operator", "org.tensorflow.op.annotation");
470     if (endpoint.package() != kDefaultEndpointPackage) {
471       oper_annot.attributes("group = \"" + endpoint.package() + "\"");
472     }
473     op_class.add_annotation(oper_annot);
474   }
475   // create op class file
476   const string op_dir_name = io::JoinPath(
477       output_dir, str_util::StringReplace(op_class.package(), ".", "/", true));
478   if (!env->FileExists(op_dir_name).ok()) {
479     TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name))
480         << op_dir_name;
481   }
482   const string op_file_name = op_class.name() + ".java";
483   std::unique_ptr<tensorflow::WritableFile> op_file;
484   TF_CHECK_OK(
485       env->NewWritableFile(io::JoinPath(op_dir_name, op_file_name), &op_file))
486       << op_file_name;
487 
488   // render endpoint source code
489   SourceFileWriter writer(op_file.get());
490   std::list<Type> dependencies;
491   CollectOpDependencies(op, mode, &dependencies);
492   writer.Write(kLicense)
493       .EndLine()
494       .Write("// This class has been generated, DO NOT EDIT!")
495       .EndLine()
496       .EndLine()
497       .BeginType(op_class, PUBLIC | FINAL, &dependencies, &op_javadoc);
498   if (!op.optional_attributes().empty()) {
499     RenderOptionsClass(op, op_class, &writer);
500   }
501   RenderFactoryMethods(op, op_class, &writer);
502   RenderGettersAndSetters(op, &writer);
503   if (mode != DEFAULT) {
504     RenderInterfaceImpl(op, mode, &writer);
505   }
506   writer.EndLine();
507   for (const ArgumentSpec& output : op.outputs()) {
508     writer.WriteField(output.var(), PRIVATE);
509   }
510   RenderConstructor(op, op_class, &writer);
511   writer.EndType();
512 }
513 
CanGenerateOp(const OpDef & op_def,const ApiDef & api_def)514 bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) {
515   if (api_def.visibility() == ApiDef::SKIP) {
516     return false;
517   }
518   for (const auto& attr : op_def.attr()) {
519     if (attr.type() == "func" || attr.type() == "list(func)") {
520       return false;  // TODO(karllessard) add support for function attributes
521     }
522   }
523   return true;
524 }
525 
526 }  // namespace
527 
Run(const OpList & op_list,const string & base_package,const string & output_dir)528 Status OpGenerator::Run(const OpList& op_list, const string& base_package,
529                         const string& output_dir) {
530   ApiDefMap api_map(op_list);
531   if (!api_dirs_.empty()) {
532     // Only load api files that correspond to the requested "op_list"
533     for (const auto& op : op_list.op()) {
534       for (const auto& api_def_dir : api_dirs_) {
535         const std::string api_def_file_pattern =
536             io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
537         if (env_->FileExists(api_def_file_pattern).ok()) {
538           TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern))
539               << api_def_file_pattern;
540         }
541       }
542     }
543   }
544   api_map.UpdateDocs();
545   for (const auto& op_def : op_list.op()) {
546     const ApiDef* api_def = api_map.GetApiDef(op_def.name());
547     if (CanGenerateOp(op_def, *api_def)) {
548       OpSpec op(OpSpec::Create(op_def, *api_def));
549       for (const EndpointSpec& endpoint : op.endpoints()) {
550         GenerateOp(op, endpoint, base_package, output_dir, env_);
551       }
552     }
553   }
554   return Status::OK();
555 }
556 
557 }  // namespace java
558 }  // namespace tensorflow
559