1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/framework/python_op_gen_internal.h"
17 
#include <ctype.h>
#include <float.h>
#include <stdio.h>

#include <cmath>
#include <iomanip>
#include <locale>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
43 
44 namespace tensorflow {
45 namespace python_op_gen_internal {
46 
// Column limit for generated docstring text. AppendWithinWidth() callers pass
// kRightMargin minus the 4-space indent of the entry being built.
const int kRightMargin = 78;
// Names specified in tf_export decorators are exported to
// TensorFlow 2.0 by default.
// In AddExport(), an op or endpoint with a nonzero deprecation_version that is
// <= this value is kept out of the v2 export names (it still appears in v1).
const int kLatestAPIExportVersion = 2;
51 
// Returns true if `s` must not be used verbatim as an identifier in generated
// Python: it is a keyword (Python 2 or Python 3) or one of the capitalized
// built-in names listed below. Callers use AvoidPythonReserved() to append an
// underscore in that case.
bool IsPythonReserved(const string& s) {
  static const std::set<string>* const kPythonReserved = new std::set<string>(
      {// Keywords in Python, from:
       //   import keyword
       //   print keyword.kwlist
       "and", "as", "assert", "break", "class", "continue", "def", "del",
       "elif", "else", "except", "exec", "finally", "for", "from", "global",
       "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",
       "raise", "return", "try", "while", "with", "yield",
       // Keywords added in Python 3 ("async"/"await" became reserved in 3.7).
       // Using any of these as a parameter name is a SyntaxError.
       "async", "await", "nonlocal",
       // Built-in functions and types in Python, from:
       //   [x for x in dir(__builtins__) if not x[0].islower()]
       "ArithmeticError", "AssertionError", "AttributeError", "BaseException",
       "BufferError", "BytesWarning", "DeprecationWarning", "EOFError",
       "Ellipsis", "EnvironmentError", "Exception", "False",
       "FloatingPointError", "FutureWarning", "GeneratorExit", "IOError",
       "ImportError", "ImportWarning", "IndentationError", "IndexError",
       "KeyError", "KeyboardInterrupt", "LookupError", "MemoryError",
       "NameError", "None", "NotImplemented", "NotImplementedError", "OSError",
       "OverflowError", "PendingDeprecationWarning", "ReferenceError",
       "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration",
       "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit", "TabError",
       "True", "TypeError", "UnboundLocalError", "UnicodeDecodeError",
       "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
       "UnicodeWarning", "UserWarning", "ValueError", "Warning",
       "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__",
       "__package__"});

  return kPythonReserved->count(s) > 0;
}
81 
// Returns true if the generated Python function for op `s` needs a leading
// underscore. That covers two groups of names:
//  * lowercase Python built-ins (everything in dir(__builtins__) starting with
//    a lowercase letter, except "round"), which the generated modules would
//    otherwise shadow through '*' imports;
//  * ops that share a name with a hand-written Python function, where the
//    winner would depend on '*' import order.
bool IsOpWithUnderscorePrefix(const string& s) {
  // Heap-allocated and never freed, so the set is not torn down at exit.
  static const std::set<string>* const kNamesNeedingUnderscore =
      new std::set<string>({
          // Lowercase Python built-ins.
          "abs", "all", "any", "apply", "bin", "bool", "buffer", "bytearray",
          "bytes", "callable", "chr", "classmethod", "cmp", "coerce",
          "compile", "complex", "copyright", "credits", "delattr", "dict",
          "dir", "divmod", "enumerate", "eval", "execfile", "exit", "file",
          "filter", "float", "format", "frozenset", "getattr", "globals",
          "hasattr", "hash", "help", "hex", "id", "input", "int", "intern",
          "isinstance", "issubclass", "iter", "len", "license", "list",
          "locals", "long", "map", "max", "memoryview", "min", "next",
          "object", "oct", "open", "ord", "pow", "print", "property", "quit",
          "range", "raw_input", "reduce", "reload", "repr", "reversed", "set",
          "setattr", "slice", "sorted", "staticmethod", "str", "sum", "super",
          "tuple", "type", "unichr", "unicode", "vars", "xrange", "zip",
          // Ops that also exist as Python-defined functions.
          // TODO(annarev): reduce usage of '*' imports and remove these from
          // the list.
          "fused_batch_norm", "histogram_fixed_width", "stack",
          "batch_norm_with_global_normalization", "clip_by_value",
      });
  return kNamesNeedingUnderscore->find(s) != kNamesNeedingUnderscore->end();
}
108 
// Returns `s` with "_" appended when IsPythonReserved(s), otherwise `s`
// unchanged. This makes the name usable as a generated Python identifier.
string AvoidPythonReserved(const string& s) {
  return IsPythonReserved(s) ? strings::StrCat(s, "_") : s;
}
113 
114 // Indent the first line by "initial" spaces and all following lines
115 // by "rest" spaces.
Indent(int initial,int rest,StringPiece in)116 string Indent(int initial, int rest, StringPiece in) {
117   // TODO(josh11b): Also word-wrapping?
118   string copy(in.data(), in.size());
119   str_util::StripTrailingWhitespace(&copy);
120   std::vector<string> v = str_util::Split(copy, '\n');
121 
122   string result;
123   bool first = true;
124   for (const string& line : v) {
125     if (first) {
126       result = strings::StrCat(Spaces(initial), line, "\n");
127       first = false;
128     } else {
129       if (line.empty()) {
130         strings::StrAppend(&result, "\n");
131       } else {
132         strings::StrAppend(&result, Spaces(rest), line, "\n");
133       }
134     }
135   }
136   return result;
137 }
138 
139 // Adds append to *dest, with a space if the first line will be <= width,
140 // or a newline otherwise.
AppendWithinWidth(string * dest,StringPiece append,int width)141 void AppendWithinWidth(string* dest, StringPiece append, int width) {
142   auto first_line = append.find('\n');
143   if (first_line == string::npos) first_line = append.size();
144   if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) {
145     strings::StrAppend(dest, "\n", append);
146   } else {
147     strings::StrAppend(dest, " ", append);
148   }
149 }
150 
151 // Like DataTypeString() but uses the Python names for the
152 // float types.
PythonDataTypeString(DataType dtype)153 string PythonDataTypeString(DataType dtype) {
154   switch (dtype) {
155     case DT_FLOAT:
156       return "float32";
157     case DT_DOUBLE:
158       return "float64";
159     default:
160       return DataTypeString(dtype);
161   }
162 }
163 
TypeString(DataType dtype,bool ref)164 string TypeString(DataType dtype, bool ref) {
165   if (ref) {
166     return strings::StrCat("mutable `", PythonDataTypeString(dtype), "`");
167   } else {
168     return strings::StrCat("`", PythonDataTypeString(dtype), "`");
169   }
170 }
171 
TypeListString(const AttrValue & value)172 string TypeListString(const AttrValue& value) {
173   string ret;
174   for (int t : value.list().type()) {
175     if (!ret.empty()) strings::StrAppend(&ret, ", ");
176     DataType dtype = static_cast<DataType>(t);
177     if (IsRefType(dtype)) {
178       strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)),
179                          " mutable");
180     } else {
181       strings::StrAppend(&ret, "`", PythonDataTypeString(dtype), "`");
182     }
183   }
184   return ret;
185 }
186 
SingleTensorName(DataType dtype,bool is_ref)187 string SingleTensorName(DataType dtype, bool is_ref) {
188   const string type_str = TypeString(dtype, is_ref);
189   return strings::StrCat("A `Tensor` of type ", type_str, ".");
190 }
191 
// Generic type description, used when an output's type text is unknown.
// GetReturns() compares against this value and, on a match, prefers the
// output's own description or name.
const char kUnknownTensorType[] = {"A `Tensor`."};
193 
// Returns the docstring type description of input/output `arg` of `op_def`,
// e.g. "A `Tensor` of type `int32`." or
// "A list of `Tensor` objects with the same type as `x`.".
//
// `inferred_attrs` maps an attr name to the name of the arg recorded for it.
// GenPythonOp::Code(), for example, records the first input that determines
// each attr. For the attr that fixes this arg's length or type:
//  * not in the map: the raw attr name is quoted;
//  * recorded for this very arg: the attr's own constraints (minimum,
//    allowed_values) are described;
//  * recorded for another arg: the text refers to that arg by its
//    AvoidPythonReserved() name.
// `is_output` only changes the wording ("Has ..." vs "Must be/have ...") for
// single-tensor and type-list args.
//
// NOTE(review): FindAttr() results are dereferenced without a null check, so
// this assumes `op_def` is well-formed (every referenced attr exists).
string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
                   const std::unordered_map<string, string>& inferred_attrs,
                   bool is_output) {
  if (!arg.number_attr().empty()) {
    // N Tensors with the same type
    // First build the "A list ..." prefix that describes the length.
    const string* original_arg =
        gtl::FindOrNull(inferred_attrs, arg.number_attr());
    string prefix;
    if (original_arg == nullptr) {
      prefix = strings::StrCat("A list of `", arg.number_attr(), "`");
    } else if (*original_arg == arg.name()) {
      const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
      if (attr->has_minimum() && attr->minimum() > 0) {
        prefix = strings::StrCat("A list of at least ", attr->minimum());
      } else {
        prefix = "A list of";
      }
    } else {
      prefix = strings::StrCat("A list with the same length as `",
                               AvoidPythonReserved(*original_arg), "` of");
    }

    // Then describe the element type: either a fixed dtype or a type attr.
    if (arg.type() != DT_INVALID) {
      return strings::StrCat(prefix, " `Tensor` objects with type ",
                             TypeString(arg.type(), arg.is_ref()), ".");
    } else {
      original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr());
      if (arg.is_ref()) {
        strings::StrAppend(&prefix, " mutable");
      }
      if (original_arg == nullptr) {
        return strings::StrCat(prefix, " `Tensor` objects with type `",
                               arg.type_attr(), "`.");
      } else if (*original_arg == arg.name()) {
        const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
        if (attr->has_allowed_values()) {
          return strings::StrCat(prefix,
                                 " `Tensor` objects with the same type in: ",
                                 TypeListString(attr->allowed_values()), ".");
        } else {
          return strings::StrCat(prefix,
                                 " `Tensor` objects with the same type.");
        }
      } else {
        return strings::StrCat(prefix,
                               " `Tensor` objects with the same type as `",
                               AvoidPythonReserved(*original_arg), "`.");
      }
    }
  } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) {
    // A single tensor typed by `type_attr`, or a list of tensors whose
    // per-element types come from `type_list_attr`.
    const bool is_list = !arg.type_list_attr().empty();
    const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr();
    const OpDef::AttrDef* attr = FindAttr(attr_name, op_def);
    const string mutable_str = arg.is_ref() ? "mutable " : "";
    const string prefix =
        is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects")
                : strings::StrCat("A ", mutable_str, "`Tensor`");
    const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name);
    if (original_arg == nullptr) {
      return strings::StrCat(prefix, " of type `", attr_name, "`.");
    } else if (*original_arg == arg.name()) {
      if (attr->has_allowed_values()) {
        if (is_list) {
          return strings::StrCat(prefix, " with types from: ",
                                 TypeListString(attr->allowed_values()), ".");
        } else {
          return strings::StrCat(
              prefix, is_output ? ". Has one of the following types: "
                                : ". Must be one of the following types: ",
              TypeListString(attr->allowed_values()), ".");
        }
      } else {
        return strings::StrCat(prefix, ".");
      }
    } else {
      return strings::StrCat(prefix,
                             is_output ? ". Has the same type as `"
                                       : ". Must have the same type as `",
                             AvoidPythonReserved(*original_arg), "`.");
    }
  } else {
    // Neither a length nor a type attr: the arg has a fixed dtype.
    return SingleTensorName(arg.type(), arg.is_ref());
  }
}
278 
GetReturns(const OpDef & op_def,const std::vector<string> & output_type_string)279 string GetReturns(const OpDef& op_def,
280                   const std::vector<string>& output_type_string) {
281   string result;
282   DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
283   const int num_outs = op_def.output_arg_size();
284   strings::StrAppend(&result, "\n  Returns:\n");
285   if (num_outs == 0) {
286     strings::StrAppend(&result, "    The created Operation.\n");
287   } else {
288     if (num_outs == 1) {
289       StringPiece description = op_def.output_arg(0).description();
290       if (ConsumeEquals(&description)) {  // Skip the generated type info.
291         strings::StrAppend(&result, Indent(4, 4, description));
292       } else {
293         // Special case of one output, don't use the name of the output unless
294         // there is no description.
295         string desc = output_type_string.empty() ? kUnknownTensorType
296                                                  : output_type_string[0];
297         if (desc == kUnknownTensorType) {
298           // Special case where we don't understand how the output tensor type
299           // depends on the input tensor types, just use the output arg
300           // description if we can.
301           if (!description.empty()) {
302             desc = op_def.output_arg(0).description();
303           } else if (!op_def.output_arg(0).name().empty()) {
304             desc = strings::StrCat(" The ", op_def.output_arg(0).name(),
305                                    " `Tensor`.");
306           }
307         } else if (!description.empty()) {
308           AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
309         }
310         strings::StrAppend(&result, Indent(4, 4, desc));
311       }
312     } else {
313       std::vector<string> out_names(num_outs);
314       for (int i = 0; i < num_outs; ++i) {
315         if (!op_def.output_arg(i).name().empty()) {
316           out_names[i] = op_def.output_arg(i).name();
317         } else {
318           out_names[i] = strings::StrCat("output", i);
319         }
320       }
321       strings::StrAppend(&result, "    A tuple of `Tensor` objects (",
322                          str_util::Join(out_names, ", "), ").\n\n");
323       for (int i = 0; i < num_outs; ++i) {
324         string desc = strings::StrCat(out_names[i], ": ");
325         StringPiece description = op_def.output_arg(i).description();
326         if (ConsumeEquals(&description)) {  // Skip the generated type info.
327           strings::StrAppend(&desc, description);
328         } else {
329           const string type = static_cast<size_t>(i) < output_type_string.size()
330                                   ? output_type_string[i]
331                                   : kUnknownTensorType;
332           if (!description.empty()) {
333             if (type == kUnknownTensorType) {
334               // Special case where we don't understand how the output tensor
335               // type depends on the input tensor types, so we just use the
336               // output arg description.
337               strings::StrAppend(&desc, description);
338             } else {
339               strings::StrAppend(&desc, type, " ", description);
340             }
341           } else {
342             strings::StrAppend(&desc, type);
343           }
344         }
345         strings::StrAppend(&result, Indent(4, 6, desc));
346       }
347     }
348   }
349   return result;
350 }
351 
// Returns `str` as a double-quoted Python string literal. The contents are
// escaped with str_util::CEscape().
string StringToPython(const string& str) {
  const string escaped = str_util::CEscape(str);
  return strings::StrCat("\"", escaped, "\"");
}
355 
DataTypeToPython(DataType dtype,const string & dtype_module)356 string DataTypeToPython(DataType dtype, const string& dtype_module) {
357   return strings::StrCat(dtype_module, PythonDataTypeString(dtype));
358 }
359 
ShapeToPython(const TensorShapeProto & shape)360 string ShapeToPython(const TensorShapeProto& shape) {
361   if (shape.unknown_rank()) {
362     return "None";
363   }
364   string python = "[";
365   for (const auto& dim : shape.dim()) {
366     if (python.size() > 1) strings::StrAppend(&python, ", ");
367     if (!dim.name().empty()) {
368       strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ",
369                          dim.size(), ")");
370     } else {
371       strings::StrAppend(&python, dim.size());
372     }
373   }
374   strings::StrAppend(&python, "]");
375   return python;
376 }
377 
TensorToPython(const TensorProto & proto)378 string TensorToPython(const TensorProto& proto) {
379   return ProtoShortDebugString(proto);
380 }
381 
// Renders one element of a list(float) attr as Python source. Finite values
// are formatted with strings::StrCat, as before. NaN and infinities are
// wrapped as float('...'), the same form AttrValueToPython() uses for scalar
// floats, because the bare text StrCat produces for them (presumably "nan",
// "inf", "-inf") is not a valid Python expression.
static string FloatListElementToPython(float f) {
  if (std::isnan(f) || std::isinf(f)) {
    return strings::StrCat("float('", f, "')");
  }
  return strings::StrCat(f);
}

// Renders the populated list field of `value` as comma-separated Python
// source, without surrounding brackets (callers such as AttrValueToPython()
// add them). Only the first non-empty field is rendered, checked in the order
// s, i, f, b, type, shape, tensor, func. An empty list yields "".
// DataTypes are prefixed with `dtype_module`. Functions render as their
// quoted names.
string AttrListToPython(const AttrValue& value,
                        const string& dtype_module = "tf.") {
  string ret;
  if (value.list().s_size() > 0) {
    for (int i = 0; i < value.list().s_size(); ++i) {
      if (i > 0) strings::StrAppend(&ret, ", ");
      strings::StrAppend(&ret, StringToPython(value.list().s(i)));
    }
  } else if (value.list().i_size() > 0) {
    for (int i = 0; i < value.list().i_size(); ++i) {
      if (i > 0) strings::StrAppend(&ret, ", ");
      strings::StrAppend(&ret, value.list().i(i));
    }
  } else if (value.list().f_size() > 0) {
    for (int i = 0; i < value.list().f_size(); ++i) {
      if (i > 0) strings::StrAppend(&ret, ", ");
      strings::StrAppend(&ret, FloatListElementToPython(value.list().f(i)));
    }
  } else if (value.list().b_size() > 0) {
    for (int i = 0; i < value.list().b_size(); ++i) {
      if (i > 0) strings::StrAppend(&ret, ", ");
      strings::StrAppend(&ret, value.list().b(i) ? "True" : "False");
    }
  } else if (value.list().type_size() > 0) {
    for (int i = 0; i < value.list().type_size(); ++i) {
      if (i > 0) strings::StrAppend(&ret, ", ");
      strings::StrAppend(&ret,
                         DataTypeToPython(value.list().type(i), dtype_module));
    }
  } else if (value.list().shape_size() > 0) {
    for (int i = 0; i < value.list().shape_size(); ++i) {
      if (i > 0) strings::StrAppend(&ret, ", ");
      strings::StrAppend(&ret, ShapeToPython(value.list().shape(i)));
    }
  } else if (value.list().tensor_size() > 0) {
    for (int i = 0; i < value.list().tensor_size(); ++i) {
      if (i > 0) strings::StrAppend(&ret, ", ");
      strings::StrAppend(&ret, TensorToPython(value.list().tensor(i)));
    }
  } else if (value.list().func_size() > 0) {
    for (int i = 0; i < value.list().func_size(); ++i) {
      if (i > 0) strings::StrAppend(&ret, ", ");
      strings::StrAppend(&ret, StringToPython(value.list().func(i).name()));
    }
  }
  return ret;
}
429 
430 // NOTE: The return value may contain spaces (for example, it could be
431 // a string "foo bar" with an embedded space) and is not safe to pass
432 // to WordWrap().
AttrValueToPython(const string & type,const AttrValue & value,const string & dtype_module)433 string AttrValueToPython(const string& type, const AttrValue& value,
434                          const string& dtype_module) {
435   if (type == "string") {
436     return StringToPython(value.s());
437   } else if (type == "int") {
438     return strings::StrCat(value.i());
439   } else if (type == "float") {
440     if (std::isnan(value.f()) || std::isinf(value.f())) {
441       return strings::StrCat("float('", value.f(), "')");
442     } else {
443       // Use locale-independent conversion.
444       static_assert(FLT_DIG < 10, "FLT_DIG is too big");
445       std::ostringstream s;
446       s.imbue(std::locale::classic());
447       s << std::setprecision(FLT_DIG) << value.f();
448       return s.str();
449     }
450   } else if (type == "bool") {
451     return value.b() ? "True" : "False";
452   } else if (type == "type") {
453     return DataTypeToPython(value.type(), dtype_module);
454   } else if (type == "shape") {
455     return ShapeToPython(value.shape());
456   } else if (type == "tensor") {
457     return TensorToPython(value.tensor());
458   } else if (type == "func") {
459     return StringToPython(value.func().name());
460   } else if (str_util::StartsWith(type, "list(")) {
461     return strings::StrCat("[", AttrListToPython(value, dtype_module), "]");
462   } else {
463     return "?";
464   }
465 }
466 
// Converts a CamelCase op name to snake_case and appends it to `*result`;
// existing contents of `*result` are kept. A '_' is inserted before an
// uppercase letter (other than the first character) when the previous or the
// next character is lowercase, so acronyms stay together:
// "MatMul" -> "mat_mul", "RGBToHSV" -> "rgb_to_hsv", "Conv2D" -> "conv2d".
void GenerateLowerCaseOpName(const string& str, string* result) {
  const char joiner = '_';
  const size_t size = str.size();
  // The <ctype.h> classifiers require a value representable as unsigned char.
  // Passing a plain (possibly negative) char, e.g. a non-ASCII byte, is
  // undefined behavior.
  auto at = [&str](size_t i) { return static_cast<unsigned char>(str[i]); };
  for (size_t i = 0; i < size; ++i) {
    const unsigned char c = at(i);
    if (i > 0 && isupper(c)) {
      const bool prev_is_lower = islower(at(i - 1)) != 0;
      const bool next_is_lower = i + 1 < size && islower(at(i + 1)) != 0;
      if (prev_is_lower || next_is_lower) result->push_back(joiner);
    }
    result->push_back(static_cast<char>(tolower(c)));
  }
}
482 
// Appends `delim` to `*append_to` unless it is still empty, so the first item
// of a separated list gets no leading delimiter.
static void AddDelimiter(string* append_to, const string& delim) {
  if (append_to->empty()) return;
  append_to->append(delim);
}
486 
FindAttr(StringPiece name,const ApiDef & api_def)487 const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
488   for (int i = 0; i < api_def.attr_size(); ++i) {
489     if (api_def.attr(i).name() == name) {
490       return &api_def.attr(i);
491     }
492   }
493   return nullptr;
494 }
495 
// Stores the op/API definitions and the Python function name to generate.
// num_outs_ caches the number of output args. The constructor does no code
// generation; that happens in Code().
GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
                         const string& function_name)
    : op_def_(op_def),
      api_def_(api_def),
      function_name_(function_name),
      num_outs_(op_def.output_arg_size()) {}
502 
~GenPythonOp()503 GenPythonOp::~GenPythonOp() {}
504 
Code()505 string GenPythonOp::Code() {
506   // This has all the input args followed by those attrs that don't have
507   // defaults.
508   std::vector<ParamNames> params_no_default;
509   // The parameters with defaults (these have to be listed after those without).
510   // No input args are included, just attrs.
511   std::vector<ParamNames> params_with_default;
512 
513   for (int i = 0; i < api_def_.arg_order_size(); ++i) {
514     const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
515     const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
516     params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
517     if (!arg.type_attr().empty()) {
518       gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name());
519     } else if (!arg.type_list_attr().empty()) {
520       gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(),
521                               arg.name());
522     }
523     if (!arg.number_attr().empty()) {
524       gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name());
525     }
526   }
527   for (int i = 0; i < api_def_.attr_size(); ++i) {
528     const auto& attr(api_def_.attr(i));
529     // Do not add inferred attrs to the Python function signature.
530     if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
531       if (attr.has_default_value()) {
532         params_with_default.emplace_back(attr.name(), attr.rename_to());
533       } else {
534         params_no_default.emplace_back(attr.name(), attr.rename_to());
535       }
536     }
537   }
538 
539   // Save the list of attr parameters (attrs that won't be inferred),
540   // those with defaults go at the end.
541   // Get the attrs in the order we want by taking the attrs without defaults
542   // from the end of args_no_default, and adding args_no_default.
543   attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
544                  params_with_default.size());
545   for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) {
546     attrs_.push_back(params_no_default[i].GetName());
547   }
548   for (int i = 0; i < params_with_default.size(); ++i) {
549     attrs_.push_back(params_with_default[i].GetName());
550   }
551 
552   param_names_.reserve(params_no_default.size() + params_with_default.size());
553   param_names_.insert(param_names_.begin(), params_no_default.begin(),
554                       params_no_default.end());
555   for (const auto& param : params_with_default) {
556     param_names_.push_back(param);
557   }
558 
559   string parameters;
560   for (const auto& param : params_no_default) {
561     AddDelimiter(&parameters, ", ");
562     strings::StrAppend(&parameters, param.GetRenameTo());
563   }
564   for (const auto& param_and_default : params_with_default) {
565     AddDelimiter(&parameters, ", ");
566     strings::StrAppend(&parameters, param_and_default.GetRenameTo(), "=None");
567   }
568   AddDelimiter(&parameters, ", ");
569   strings::StrAppend(&parameters, "name=None");
570 
571   AddExport();
572   AddDefLine(parameters);
573   AddDocStringDescription();
574   AddDocStringArgs();
575   AddDocStringInputs();
576   AddDocStringAttrs();
577   AddDocStringNameArg();
578   AddOutputGlobals();
579   AddDocStringOutputs();
580   strings::StrAppend(&result_, "  \"\"\"\n");
581   AddBody("  ");
582   strings::StrAppend(&result_, "\n\n");
583 
584   return prelude_ + result_;
585 }
586 
AddExport()587 void GenPythonOp::AddExport() {
588   if (api_def_.visibility() != ApiDef::VISIBLE) {
589     return;
590   }
591   // Whether op should be available in latest export version.
592   bool op_available_in_latest =
593       !api_def_.deprecation_version() ||
594       api_def_.deprecation_version() > kLatestAPIExportVersion;
595 
596   string names;
597   string names_v1;
598   string deprecated_endpoints;
599 
600   for (const auto& endpoint : api_def_.endpoint()) {
601     string endpoint_name;
602     python_op_gen_internal::GenerateLowerCaseOpName(endpoint.name(),
603                                                     &endpoint_name);
604     if (endpoint.deprecated() || endpoint.deprecation_version() > 0) {
605       AddDelimiter(&deprecated_endpoints, ", ");
606       strings::StrAppend(&deprecated_endpoints, "'", endpoint_name, "'");
607     }
608     // Add all endpoints to TensorFlow 1.* API.
609     AddDelimiter(&names_v1, ", ");
610     strings::StrAppend(&names_v1, "'", endpoint_name, "'");
611     // Add non-deprecated endpoints to TensorFlow 2.* API.
612     if (op_available_in_latest &&
613         (!endpoint.deprecation_version() ||
614          endpoint.deprecation_version() > kLatestAPIExportVersion)) {
615       AddDelimiter(&names, ", ");
616       strings::StrAppend(&names, "'", endpoint_name, "'");
617     }
618   }
619 
620   // tf_export decorator has the following format:
621   // @tf_export(v2_name, v2_name, v1=[v1_name, v1_name])
622   if (names != names_v1) {
623     AddDelimiter(&names, ", ");
624     strings::StrAppend(&names, "v1=[", names_v1, "]");
625   }
626   strings::StrAppend(&result_, "@tf_export(", names, ")\n");
627 
628   // If all endpoints are deprecated, add @deprecated decorator.
629   if (!api_def_.deprecation_message().empty()) {
630     const string instructions = api_def_.deprecation_message();
631     strings::StrAppend(&result_, "@deprecated(None, '", instructions, "')\n");
632   }
633   // Add @deprecated_endpoints decorator.
634   if (!deprecated_endpoints.empty()) {
635     strings::StrAppend(&result_, "@deprecated_endpoints(", deprecated_endpoints,
636                        ")\n");
637   }
638 }
639 
AddDefLine(const string & function_name,const string & parameters)640 void GenPythonOp::AddDefLine(const string& function_name,
641                              const string& parameters) {
642   strings::StrAppend(&result_, "def ", function_name, "(", parameters, "):\n");
643 }
644 
AddDefLine(const string & parameters)645 void GenPythonOp::AddDefLine(const string& parameters) {
646   AddDefLine(function_name_, parameters);
647 }
648 
AddDocStringDescription()649 void GenPythonOp::AddDocStringDescription() {
650   string comment;
651   if (api_def_.summary().empty()) {
652     comment = "TODO: add doc.\n";
653   } else {
654     comment = strings::StrCat(api_def_.summary(), "\n");
655     if (!api_def_.description().empty()) {
656       strings::StrAppend(&comment, "\n", Indent(2, 2, api_def_.description()));
657     }
658   }
659   strings::StrAppend(&result_, "  r\"\"\"", comment, "\n");
660 }
661 
AddDocStringArgs()662 void GenPythonOp::AddDocStringArgs() {
663   strings::StrAppend(&result_, "  Args:\n");
664 }
665 
AddDocStringInputs()666 void GenPythonOp::AddDocStringInputs() {
667   for (int i = 0; i < api_def_.arg_order_size(); ++i) {
668     const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
669     const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
670     StringPiece description = api_def_arg.description();
671     string desc;
672     if (ConsumeEquals(&description)) {  // Skip the generated type info.
673       desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ");
674     } else {
675       desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ",
676                              ArgTypeName(op_def_, arg, inferred_attrs_, false));
677     }
678     if (!description.empty()) {
679       AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
680     }
681     strings::StrAppend(&result_, Indent(4, 6, desc));
682   }
683 }
684 
AddDocStringAttrs()685 void GenPythonOp::AddDocStringAttrs() {
686   for (const string& name : attrs_) {
687     const auto& attr = *FindAttr(name, op_def_);
688     const auto& api_def_attr = *FindAttr(name, api_def_);
689     string desc =
690         strings::StrCat(AvoidPythonReserved(api_def_attr.rename_to()), ": ");
691 
692     static const char* const kAttrTypeName[][2] = {
693         {"string", "`string`"},
694         {"list(string)", "list of `strings`"},
695         {"int", "`int`"},
696         {"list(int)", "list of `ints`"},
697         {"float", "`float`"},
698         {"list(float)", "list of `floats`"},
699         {"bool", "`bool`"},
700         {"list(bool)", "list of `bools`"},
701         {"type", "`tf.DType`"},
702         {"list(type)", "list of `tf.DTypes`"},
703         {"shape", "`tf.TensorShape` or list of `ints`"},
704         {"list(shape)",
705          "list of shapes (each a `tf.TensorShape` or list of `ints`)"},
706         {"tensor", "`tf.TensorProto`"},
707         {"list(tensor)", "list of `tf.TensorProto` objects"},
708         {"func", "function decorated with @Defun"},
709         {"list(func)", "list of functions decorated with @Defun"},
710     };
711     for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
712       if (attr.type() == kAttrTypeName[i][0]) {
713         string s;
714         if (api_def_attr.has_default_value()) {
715           s = strings::StrCat("optional ", kAttrTypeName[i][1]);
716         } else {
717           s = kAttrTypeName[i][1];
718         }
719         if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) {
720           strings::StrAppend(&desc, "An ", s);
721         } else {
722           strings::StrAppend(&desc, "A ", s);
723         }
724         break;
725       }
726     }
727 
728     if (attr.has_allowed_values()) {
729       strings::StrAppend(&desc, " from: `",
730                          AttrListToPython(attr.allowed_values()), "`");
731     }
732 
733     if (attr.has_minimum()) {
734       if (attr.type() == "int") {
735         strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`");
736       } else if (attr.minimum() > 0) {
737         strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`");
738       }
739     }
740 
741     strings::StrAppend(&desc, ".");
742 
743     if (api_def_attr.has_default_value()) {
744       strings::StrAppend(
745           &desc, " Defaults to `",
746           AttrValueToPython(attr.type(), api_def_attr.default_value()), "`.");
747     }
748     if (!api_def_attr.description().empty()) {
749       AppendWithinWidth(&desc, api_def_attr.description(),
750                         kRightMargin - 4 /* indent */);
751     }
752     strings::StrAppend(&result_, Indent(4, 6, desc));
753   }
754 }
755 
AddDocStringNameArg()756 void GenPythonOp::AddDocStringNameArg() {
757   strings::StrAppend(&result_,
758                      "    name: A name for the operation (optional).\n");
759 }
760 
AddOutputGlobals()761 void GenPythonOp::AddOutputGlobals() {
762   // Prepare a NamedTuple type to hold the outputs, if there are multiple
763   if (num_outs_ > 1) {
764     // Prepare the list of output names
765     std::vector<string> out_names(num_outs_);
766     for (int i = 0; i < num_outs_; ++i) {
767       if (!api_def_.out_arg(i).rename_to().empty()) {
768         out_names[i] = api_def_.out_arg(i).rename_to();
769       } else {
770         out_names[i] = strings::StrCat("output", i);
771       }
772     }
773     string out_names_list =
774         strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]");
775 
776     // Provide the output names as a Python list
777     string lower_op_name_outputs =
778         strings::StrCat("_", function_name_, "_outputs");
779     const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = ");
780     strings::StrAppend(&prelude_, "\n",
781                        WordWrap(outputs_prefix, out_names_list, kRightMargin),
782                        "\n");
783 
784     strings::StrAppend(&prelude_, "_", op_def_.name(),
785                        "Output = _collections.namedtuple(\n");
786     const string tuple_type_prefix = "    ";
787     const string tuple_type_suffix = strings::StrCat(
788         "\"", op_def_.name(), "\", ", lower_op_name_outputs, ")");
789     strings::StrAppend(
790         &prelude_, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin),
791         "\n\n");
792   }
793   strings::StrAppend(&prelude_, "\n");
794 }
795 
AddDocStringOutputs()796 void GenPythonOp::AddDocStringOutputs() {
797   std::vector<string> output_type_string;
798   output_type_string.reserve(num_outs_);
799   for (int i = 0; i < num_outs_; ++i) {
800     output_type_string.push_back(
801         ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true));
802   }
803   strings::StrAppend(&result_, GetReturns(op_def_, output_type_string));
804 }
805 
AddBody(const string & prefix)806 void GenPythonOp::AddBody(const string& prefix) {
807   const string apply_prefix = strings::StrCat(
808       prefix, "_result = _op_def_lib.apply_op(\"", op_def_.name(), "\", ");
809   AddBodyNoReturn(apply_prefix);
810   if (num_outs_ > 1) {
811     strings::StrAppend(&result_, prefix, "_result = _", op_def_.name(),
812                        "Output._make(_result)\n");
813   }
814   strings::StrAppend(&result_, prefix, "return _result\n");
815 }
816 
AddBodyNoReturn(const string & apply_prefix)817 void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) {
818   string args;
819   for (size_t i = 0; i < param_names_.size(); ++i) {
820     strings::StrAppend(&args, AvoidPythonReserved(param_names_[i].GetName()),
821                        "=", param_names_[i].GetRenameTo(), ", ");
822   }
823   strings::StrAppend(&args, "name=name)");
824 
825   strings::StrAppend(&result_,
826                      // Wrap the arguments, and indent to the (.
827                      WordWrap(apply_prefix, args, kRightMargin), "\n");
828 }
829 
830 }  // namespace python_op_gen_internal
831 }  // namespace tensorflow
832