1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <list>
17 #include <map>
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "tensorflow/core/framework/op_gen_lib.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/env.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/java/src/gen/cc/java_defs.h"
31 #include "tensorflow/java/src/gen/cc/op_generator.h"
32 #include "tensorflow/java/src/gen/cc/op_specs.h"
33 #include "tensorflow/java/src/gen/cc/source_writer.h"
34
35 namespace tensorflow {
36 namespace java {
37 namespace {
38
// Apache 2.0 license banner written verbatim at the top of every generated
// Java source file (it is the first thing GenerateOp writes).
constexpr char kLicense[] =
    "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n\n"
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
    "you may not use this file except in compliance with the License.\n"
    "You may obtain a copy of the License at\n\n"
    " http://www.apache.org/licenses/LICENSE-2.0\n\n"
    "Unless required by applicable law or agreed to in writing, software\n"
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
    "See the License for the specific language governing permissions and\n"
    "limitations under the License.\n"
    "=======================================================================*/\n";
55
// There are three different modes for rendering an op class, depending on the
// number and type of outputs it has. GenerateOp selects OPERAND for an op with
// exactly one non-list output, LIST_OPERAND for an op with exactly one list
// output, and DEFAULT otherwise.
enum RenderMode {
  // No specialization of the op class; applied when the operation does not
  // qualify for any other mode.
  DEFAULT,
  // The op class implements the Operand<T> interface, so an instance can be
  // passed directly as an input to another operation.
  OPERAND,
  // The op class implements the Iterable<Operand<T>> interface, so an
  // instance can be passed directly as a list input to another operation.
  LIST_OPERAND
};
70
AddArgument(const Variable & var,const string & description,Method * method_out,Javadoc * javadoc_out)71 void AddArgument(const Variable& var, const string& description,
72 Method* method_out, Javadoc* javadoc_out) {
73 method_out->add_argument(var);
74 javadoc_out->add_param_tag(var.name(), description);
75 }
76
CollectOpDependencies(const OpSpec & op,RenderMode mode,std::list<Type> * out)77 void CollectOpDependencies(const OpSpec& op, RenderMode mode,
78 std::list<Type>* out) {
79 out->push_back(Type::Class("Operation", "org.tensorflow"));
80 out->push_back(Type::Class("OperationBuilder", "org.tensorflow"));
81 out->push_back(Type::Class("Scope", "org.tensorflow.op"));
82 if (mode == OPERAND) {
83 out->push_back(Type::Class("Output", "org.tensorflow"));
84 } else if (mode == LIST_OPERAND) {
85 out->push_back(Type::Interface("Iterator", "java.util"));
86 }
87 // Don't pay attention to duplicate types in the dependency list, they will
88 // be filtered out by the SourceWriter.
89 for (const ArgumentSpec& input : op.inputs()) {
90 out->push_back(input.var().type());
91 if (input.iterable()) {
92 out->push_back(Type::Class("Operands", "org.tensorflow.op"));
93 }
94 }
95 for (const ArgumentSpec& output : op.outputs()) {
96 out->push_back(output.var().type());
97 if (output.iterable()) {
98 out->push_back(Type::Class("Arrays", "java.util"));
99 }
100 }
101 for (const AttributeSpec& attribute : op.attributes()) {
102 out->push_back(attribute.var().type());
103 out->push_back(attribute.jni_type());
104 if (attribute.has_default_value() &&
105 attribute.type().kind() == Type::GENERIC) {
106 out->push_back(Type::ForDataType(attribute.default_value()->type()));
107 }
108 }
109 for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
110 out->push_back(optional_attribute.var().type());
111 }
112 }
113
WriteSetAttrDirective(const AttributeSpec & attr,bool optional,SourceWriter * writer)114 void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
115 SourceWriter* writer) {
116 string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
117 if (attr.iterable()) {
118 string array_name = attr.var().name() + "Array";
119 writer->AppendType(attr.jni_type())
120 .Append("[] " + array_name + " = new ")
121 .AppendType(attr.jni_type())
122 .Append("[" + var_name + ".size()];")
123 .EndLine()
124 .BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
125 .Append(array_name + "[i] = ");
126 if (attr.type().kind() == Type::GENERIC) {
127 writer->Append("DataType.fromClass(" + var_name + ".get(i));");
128 } else {
129 writer->Append(var_name + ".get(i);");
130 }
131 writer->EndLine()
132 .EndBlock()
133 .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
134 .Append(array_name + ");")
135 .EndLine();
136 } else {
137 writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ");
138 if (attr.var().type().name() == "Class") {
139 writer->Append("DataType.fromClass(" + var_name + "));");
140 } else {
141 writer->Append(var_name + ");");
142 }
143 writer->EndLine();
144 }
145 }
146
RenderSecondaryFactoryMethod(const OpSpec & op,const Type & op_class,std::map<string,Type> default_types,SourceWriter * writer)147 void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
148 std::map<string, Type> default_types,
149 SourceWriter* writer) {
150 // Build the return type for the secondary factory, replacing generic
151 // parameters with their default value if any
152 Type return_type = Type::Class(op_class.name(), op_class.package());
153 for (const Type& parameter : op_class.parameters()) {
154 if (parameter.kind() == Type::GENERIC &&
155 default_types.find(parameter.name()) != default_types.end()) {
156 return_type.add_parameter(default_types.at(parameter.name()));
157 } else {
158 return_type.add_parameter(parameter);
159 }
160 }
161 Method factory = Method::Create("create", return_type);
162 Javadoc factory_doc = Javadoc::Create(
163 "Factory method to create a class to wrap a new " + op_class.name() +
164 " operation to the graph, using "
165 "default output types.");
166 Variable scope =
167 Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
168 AddArgument(scope, "current graph scope", &factory, &factory_doc);
169 std::stringstream factory_statement;
170 factory_statement << "return create(scope";
171 for (const ArgumentSpec& input : op.inputs()) {
172 AddArgument(input.var(), input.description(), &factory, &factory_doc);
173 factory_statement << ", " << input.var().name();
174 }
175 for (const AttributeSpec& attr : op.attributes()) {
176 // Only add attributes that are not types or have no default value to the
177 // signature of the secondary factory
178 factory_statement << ", ";
179 if (attr.type().kind() == Type::GENERIC &&
180 default_types.find(attr.type().name()) != default_types.end()) {
181 factory_statement << default_types.at(attr.type().name()).name()
182 << ".class";
183 } else {
184 AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
185 factory_statement << attr.var().name();
186 }
187 }
188 if (!op.optional_attributes().empty()) {
189 Variable options_var = Variable::Varargs("options", Type::Class("Options"));
190 AddArgument(options_var, "carries optional attributes values", &factory,
191 &factory_doc);
192 factory_statement << ", " << options_var.name();
193 }
194 factory_doc.add_tag("return", "a new instance of " + op_class.name());
195
196 writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
197 writer->Append(factory_statement.str().c_str()).Append(");").EndLine();
198 writer->EndMethod();
199 }
200
RenderFactoryMethods(const OpSpec & op,const Type & op_class,SourceWriter * writer)201 void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
202 SourceWriter* writer) {
203 Method factory = Method::Create("create", op_class);
204 Javadoc factory_doc =
205 Javadoc::Create("Factory method to create a class to wrap a new " +
206 op_class.name() + " operation to the graph.");
207 Variable scope =
208 Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
209 AddArgument(scope, "current graph scope", &factory, &factory_doc);
210 for (const ArgumentSpec& input : op.inputs()) {
211 AddArgument(input.var(), input.description(), &factory, &factory_doc);
212 }
213 std::map<string, Type> default_types;
214 for (const AttributeSpec& attr : op.attributes()) {
215 AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
216 // If this attribute is a type with a default value, save its value
217 // for passing it implicitly in a secondary factory method
218 if (attr.has_default_value() && attr.type().kind() == Type::GENERIC) {
219 Type default_type = Type::ForDataType(attr.default_value()->type());
220 if (!default_type.wildcard()) {
221 default_types.insert(std::make_pair(attr.type().name(), default_type));
222 }
223 }
224 }
225 if (!op.optional_attributes().empty()) {
226 AddArgument(Variable::Varargs("options", Type::Class("Options")),
227 "carries optional attributes values", &factory, &factory_doc);
228 }
229 factory_doc.add_tag("return", "a new instance of " + op_class.name());
230
231 writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
232 writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" +
233 op.graph_op_name() + "\", scope.makeOpName(\"" +
234 op_class.name() + "\"));");
235 writer->EndLine();
236 for (const ArgumentSpec& input : op.inputs()) {
237 if (input.iterable()) {
238 writer->Append("opBuilder.addInputList(Operands.asOutputs(" +
239 input.var().name() + "));");
240 writer->EndLine();
241 } else {
242 writer->Append("opBuilder.addInput(" + input.var().name() +
243 ".asOutput());");
244 writer->EndLine();
245 }
246 }
247 for (const AttributeSpec& attribute : op.attributes()) {
248 WriteSetAttrDirective(attribute, false, writer);
249 }
250 if (!op.optional_attributes().empty()) {
251 writer->BeginBlock("if (options != null)")
252 .BeginBlock("for (Options opts : options)");
253 for (const AttributeSpec& attribute : op.optional_attributes()) {
254 writer->BeginBlock("if (opts." + attribute.var().name() + " != null)");
255 WriteSetAttrDirective(attribute, true, writer);
256 writer->EndBlock();
257 }
258 writer->EndBlock().EndBlock();
259 }
260 writer->Append("return new ")
261 .AppendType(op_class)
262 .Append("(opBuilder.build());")
263 .EndLine();
264 writer->EndMethod();
265
266 // If this operation has type attributes with a default value, create a
267 // second factory method that infers those values implicitly
268 if (!default_types.empty()) {
269 RenderSecondaryFactoryMethod(op, op_class, default_types, writer);
270 }
271 }
272
RenderConstructor(const OpSpec & op,const Type & op_class,SourceWriter * writer)273 void RenderConstructor(const OpSpec& op, const Type& op_class,
274 SourceWriter* writer) {
275 Variable operation =
276 Variable::Create("operation", Type::Class("Operation", "org.tensorflow"));
277 Method constructor = Method::ConstructorFor(op_class).add_argument(operation);
278 for (const ArgumentSpec& output : op.outputs()) {
279 if (output.iterable() && !output.type().wildcard()) {
280 constructor.add_annotation(
281 Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
282 break;
283 }
284 }
285 writer->BeginMethod(constructor, PRIVATE)
286 .Append("super(operation);")
287 .EndLine();
288 if (!op.outputs().empty()) {
289 writer->Append("int outputIdx = 0;").EndLine();
290 for (const ArgumentSpec& output : op.outputs()) {
291 if (output.iterable()) {
292 string var_length = output.var().name() + "Length";
293 writer->Append("int " + var_length)
294 .Append(" = operation.outputListLength(\"" + output.op_def_name() +
295 "\");")
296 .EndLine()
297 .Append(output.var().name() + " = Arrays.asList(");
298 if (!output.type().wildcard()) {
299 writer->Append("(")
300 .AppendType(output.var().type().parameters().front())
301 .Append("[])");
302 }
303 writer->Append("operation.outputList(outputIdx, " + var_length + "));")
304 .EndLine()
305 .Append("outputIdx += " + var_length + ";")
306 .EndLine();
307 } else {
308 writer
309 ->Append(output.var().name() + " = operation.output(outputIdx++);")
310 .EndLine();
311 }
312 }
313 }
314 writer->EndMethod();
315 }
316
RenderGettersAndSetters(const OpSpec & op,SourceWriter * writer)317 void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
318 for (const AttributeSpec& attr : op.optional_attributes()) {
319 Method setter = Method::Create(attr.var().name(), Type::Class("Options"));
320 Javadoc setter_doc = Javadoc::Create();
321 AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
322 writer->BeginMethod(setter, PUBLIC | STATIC, &setter_doc)
323 .Append("return new Options()." + attr.var().name() + "(" +
324 attr.var().name() + ");")
325 .EndLine()
326 .EndMethod();
327 }
328 for (const ArgumentSpec& output : op.outputs()) {
329 Method getter = Method::Create(output.var().name(), output.var().type());
330 Javadoc getter_doc = Javadoc::Create(output.description());
331 writer->BeginMethod(getter, PUBLIC, &getter_doc)
332 .Append("return " + output.var().name() + ";")
333 .EndLine()
334 .EndMethod();
335 }
336 }
337
RenderInterfaceImpl(const OpSpec & op,RenderMode mode,SourceWriter * writer)338 void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
339 SourceWriter* writer) {
340 ArgumentSpec output = op.outputs().front();
341
342 if (mode == OPERAND) {
343 bool cast2obj = output.type().wildcard();
344 Type return_type =
345 Type::Class("Output", "org.tensorflow")
346 .add_parameter(cast2obj ? Type::Class("Object") : output.type());
347 Method as_output = Method::Create("asOutput", return_type)
348 .add_annotation(Annotation::Create("Override"));
349 if (cast2obj) {
350 as_output.add_annotation(
351 Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
352 }
353 writer->BeginMethod(as_output, PUBLIC);
354 if (cast2obj) {
355 writer->Append("return (").AppendType(return_type).Append(") ");
356 } else {
357 writer->Append("return ");
358 }
359 writer->Append(output.var().name() + ";").EndLine().EndMethod();
360
361 } else if (mode == LIST_OPERAND) {
362 Type operand = Type::Interface("Operand", "org.tensorflow");
363 if (output.type().wildcard()) {
364 operand.add_parameter(Type::Class("Object"));
365 } else {
366 operand.add_parameter(output.type());
367 }
368 Type return_type =
369 Type::Interface("Iterator", "java.util").add_parameter(operand);
370 Method iterator =
371 Method::Create("iterator", return_type)
372 .add_annotation(Annotation::Create("Override"))
373 .add_annotation(Annotation::Create("SuppressWarnings")
374 .attributes("{\"rawtypes\", \"unchecked\"}"));
375 // cast the output list using a raw List
376 writer->BeginMethod(iterator, PUBLIC)
377 .Append("return (" + return_type.name() + ") ")
378 .Append(output.var().name() + ".iterator();")
379 .EndLine()
380 .EndMethod();
381 }
382 }
383
RenderOptionsClass(const OpSpec & op,const Type & op_class,SourceWriter * writer)384 void RenderOptionsClass(const OpSpec& op, const Type& op_class,
385 SourceWriter* writer) {
386 Type options_class = Type::Class("Options");
387 Javadoc options_doc = Javadoc::Create("Optional attributes for {@link " +
388 op_class.canonical_name() + "}");
389 writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
390 for (const AttributeSpec& attr : op.optional_attributes()) {
391 Method setter = Method::Create(attr.var().name(), options_class);
392 Javadoc setter_doc = Javadoc::Create();
393 AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
394 writer->BeginMethod(setter, PUBLIC, &setter_doc)
395 .Append("this." + attr.var().name() + " = " + attr.var().name() + ";")
396 .EndLine()
397 .Append("return this;")
398 .EndLine()
399 .EndMethod();
400 }
401 writer->EndLine();
402 for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
403 writer->WriteField(optional_attribute.var(), PRIVATE);
404 }
405 Method constructor = Method::ConstructorFor(options_class);
406 writer->BeginMethod(constructor, PRIVATE).EndMethod();
407 writer->EndType();
408 }
409
ClassOf(const EndpointSpec & endpoint,const string & base_package)410 inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) {
411 return Type::Class(
412 endpoint.name(),
413 base_package + "." + str_util::Lowercase(endpoint.package()));
414 }
415
GenerateOp(const OpSpec & op,const EndpointSpec & endpoint,const string & base_package,const string & output_dir,Env * env)416 void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
417 const string& base_package, const string& output_dir,
418 Env* env) {
419 Type op_class(
420 ClassOf(endpoint, base_package)
421 .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
422 Javadoc op_javadoc(endpoint.javadoc());
423
424 // op interfaces
425 RenderMode mode = DEFAULT;
426 if (op.outputs().size() == 1) {
427 const ArgumentSpec& output = op.outputs().front();
428 Type operand_type(output.type().wildcard() ? Type::Class("Object")
429 : output.type());
430 Type operand_inf(Type::Interface("Operand", "org.tensorflow")
431 .add_parameter(operand_type));
432 if (output.iterable()) {
433 mode = LIST_OPERAND;
434 op_class.add_supertype(Type::IterableOf(operand_inf));
435 } else {
436 mode = OPERAND;
437 op_class.add_supertype(operand_inf);
438 }
439 }
440 // op generic parameters
441 std::set<string> generics;
442 for (const ArgumentSpec& output : op.outputs()) {
443 if (output.type().kind() == Type::GENERIC && !output.type().wildcard() &&
444 generics.find(output.type().name()) == generics.end()) {
445 op_class.add_parameter(output.type());
446 op_javadoc.add_param_tag(
447 "<" + output.type().name() + ">",
448 "data type for {@code " + output.var().name() + "()} output");
449 generics.insert(output.type().name());
450 }
451 }
452 // op annotations
453 if (endpoint.deprecated()) {
454 op_class.add_annotation(Annotation::Create("Deprecated"));
455 string explanation;
456 if (!op.endpoints().front().deprecated()) {
457 explanation =
458 "use {@link " +
459 ClassOf(op.endpoints().front(), base_package).canonical_name() +
460 "} instead";
461 } else {
462 explanation = op.deprecation_explanation();
463 }
464 op_javadoc.add_tag("deprecated", explanation);
465 }
466 if (!op.hidden()) {
467 // expose the op in the Ops Graph API only if it is visible
468 Annotation oper_annot =
469 Annotation::Create("Operator", "org.tensorflow.op.annotation");
470 if (endpoint.package() != kDefaultEndpointPackage) {
471 oper_annot.attributes("group = \"" + endpoint.package() + "\"");
472 }
473 op_class.add_annotation(oper_annot);
474 }
475 // create op class file
476 const string op_dir_name = io::JoinPath(
477 output_dir, str_util::StringReplace(op_class.package(), ".", "/", true));
478 if (!env->FileExists(op_dir_name).ok()) {
479 TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name))
480 << op_dir_name;
481 }
482 const string op_file_name = op_class.name() + ".java";
483 std::unique_ptr<tensorflow::WritableFile> op_file;
484 TF_CHECK_OK(
485 env->NewWritableFile(io::JoinPath(op_dir_name, op_file_name), &op_file))
486 << op_file_name;
487
488 // render endpoint source code
489 SourceFileWriter writer(op_file.get());
490 std::list<Type> dependencies;
491 CollectOpDependencies(op, mode, &dependencies);
492 writer.Write(kLicense)
493 .EndLine()
494 .Write("// This class has been generated, DO NOT EDIT!")
495 .EndLine()
496 .EndLine()
497 .BeginType(op_class, PUBLIC | FINAL, &dependencies, &op_javadoc);
498 if (!op.optional_attributes().empty()) {
499 RenderOptionsClass(op, op_class, &writer);
500 }
501 RenderFactoryMethods(op, op_class, &writer);
502 RenderGettersAndSetters(op, &writer);
503 if (mode != DEFAULT) {
504 RenderInterfaceImpl(op, mode, &writer);
505 }
506 writer.EndLine();
507 for (const ArgumentSpec& output : op.outputs()) {
508 writer.WriteField(output.var(), PRIVATE);
509 }
510 RenderConstructor(op, op_class, &writer);
511 writer.EndType();
512 }
513
CanGenerateOp(const OpDef & op_def,const ApiDef & api_def)514 bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) {
515 if (api_def.visibility() == ApiDef::SKIP) {
516 return false;
517 }
518 for (const auto& attr : op_def.attr()) {
519 if (attr.type() == "func" || attr.type() == "list(func)") {
520 return false; // TODO(karllessard) add support for function attributes
521 }
522 }
523 return true;
524 }
525
526 } // namespace
527
Run(const OpList & op_list,const string & base_package,const string & output_dir)528 Status OpGenerator::Run(const OpList& op_list, const string& base_package,
529 const string& output_dir) {
530 ApiDefMap api_map(op_list);
531 if (!api_dirs_.empty()) {
532 // Only load api files that correspond to the requested "op_list"
533 for (const auto& op : op_list.op()) {
534 for (const auto& api_def_dir : api_dirs_) {
535 const std::string api_def_file_pattern =
536 io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
537 if (env_->FileExists(api_def_file_pattern).ok()) {
538 TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern))
539 << api_def_file_pattern;
540 }
541 }
542 }
543 }
544 api_map.UpdateDocs();
545 for (const auto& op_def : op_list.op()) {
546 const ApiDef* api_def = api_map.GetApiDef(op_def.name());
547 if (CanGenerateOp(op_def, *api_def)) {
548 OpSpec op(OpSpec::Create(op_def, *api_def));
549 for (const EndpointSpec& endpoint : op.endpoints()) {
550 GenerateOp(op, endpoint, base_package, output_dir, env_);
551 }
552 }
553 }
554 return Status::OK();
555 }
556
557 } // namespace java
558 } // namespace tensorflow
559