• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <unordered_map>
17 #include <unordered_set>
18 #include <vector>
19 
20 #include "tensorflow/cc/framework/cc_op_gen.h"
21 #include "tensorflow/core/framework/api_def.pb.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/op_def_util.h"
25 #include "tensorflow/core/framework/op_gen_lib.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 #include "tensorflow/core/framework/tensor_shape.pb.h"
28 #include "tensorflow/core/framework/types.pb_text.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 #include "tensorflow/core/lib/gtl/stl_util.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/platform/env.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/core/public/version.h"
38 
39 namespace tensorflow {
40 namespace {
41 
// Column limit passed to WordWrap when laying out generated declarations.
constexpr int kRightMargin = 79;
43 
// Strips the build-output prefix from a generated header path.
//
// Converts:
//   bazel-out/.../genfiles/(external/YYY/)?XX
// to: XX.
// Paths without "/genfiles/" are only subject to the "external/YYY/" strip.
string GetPath(const string& dot_h_fname) {
  static const char kGenfiles[] = "/genfiles/";
  static const char kExternal[] = "external/";
  // sizeof() counts the terminating '\0', hence the - 1.
  const size_t genfiles_len = sizeof(kGenfiles) - 1;
  const size_t external_len = sizeof(kExternal) - 1;

  const size_t genfiles_pos = dot_h_fname.find(kGenfiles);
  string path = (genfiles_pos == string::npos)
                    ? dot_h_fname
                    : dot_h_fname.substr(genfiles_pos + genfiles_len);

  // Only strip when something follows "external/" (the length check is
  // against sizeof, i.e. the prefix length plus one).
  const bool starts_with_external =
      path.size() > sizeof(kExternal) &&
      path.compare(0, external_len, kExternal) == 0;
  if (!starts_with_external) return path;

  // Drop "external/" and then the repository directory that follows it.
  path.erase(0, external_len);
  const size_t repo_end = path.find('/');
  if (repo_end != string::npos) path.erase(0, repo_end + 1);
  return path;
}
64 
// Returns the last path component with its suffix removed.
//
// Converts: some/path/to/file.xx
// to: file
// When there is no '.' after the last '/', the whole last component is
// returned.
string GetFilename(const string& path) {
  const size_t last_slash = path.rfind('/');
  const size_t name_begin = (last_slash == string::npos) ? 0 : last_slash + 1;
  const size_t last_dot = path.rfind('.');
  // If the dot is missing or precedes name_begin, the unsigned difference
  // wraps to a huge length and substr() simply takes the rest of the string.
  return path.substr(name_begin, last_dot - name_begin);
}
74 
// Builds an include-guard macro name from a header path.
//
// Converts:
//   cc/ops/gen_foo_ops.h
// to:
//   CC_OPS_GEN_FOO_OPS_H_
// ASCII letters are upper-cased and digits are kept, except as the first
// character, where a digit cannot start a macro name. Every other character
// becomes '_', and a trailing '_' is appended.
string ToGuard(const string& path) {
  string guard;
  guard.reserve(path.size() + 1);  // + 1 -> trailing _
  for (const char c : path) {
    if (c >= 'A' && c <= 'Z') {
      guard += c;
    } else if (c >= 'a' && c <= 'z') {
      guard += static_cast<char>(c + 'A' - 'a');
    } else if (c >= '0' && c <= '9' && !guard.empty()) {
      // Keeping digits stops paths that differ only in digits (for example
      // gen_v1_ops.h vs. gen_v2_ops.h) from producing the same guard, which
      // would silently empty whichever header is included second.
      guard += c;
    } else {
      guard += '_';
    }
  }
  guard += '_';
  return guard;
}
94 
95 // Converts: some_name_xyz
96 // to: Some Name Xyz
ToTitle(const string & name)97 string ToTitle(const string& name) {
98   string title = name;
99   for (int i = 0; i < title.size(); ++i) {
100     if (title[i] == '_') title[i] = ' ';
101   }
102   str_util::TitlecaseString(&title, " ");
103   return title;
104 }
105 
106 // Change:     Into:
107 //   ABC         /// ABC
108 //               ///
109 //   DEF         /// DEF
MakeComment(StringPiece text,StringPiece indent)110 string MakeComment(StringPiece text, StringPiece indent) {
111   string ret;
112   while (!text.empty()) {
113     int last_non_space = -1;
114     int newline;
115     for (newline = 0; newline < static_cast<int>(text.size()); ++newline) {
116       if (text[newline] == '\n') break;
117       if (text[newline] != ' ') last_non_space = newline;
118     }
119     if (last_non_space == -1) {
120       strings::StrAppend(&ret, indent, "///\n");
121     } else {
122       strings::StrAppend(&ret, indent, "/// ",
123                          text.substr(0, last_non_space + 1), "\n");
124     }
125     text.remove_prefix(newline + 1);
126   }
127   return ret;
128 }
129 
// Returns `str` as a C++ string literal: C-escaped and wrapped in quotes.
string PrintString(const string& str) {
  string quoted = "\"";
  strings::StrAppend(&quoted, str_util::CEscape(str), "\"");
  return quoted;
}
133 
PrintTensorShape(const TensorShapeProto & shape_proto)134 string PrintTensorShape(const TensorShapeProto& shape_proto) {
135   PartialTensorShape shape(shape_proto);
136   if (shape.IsIdenticalTo(PartialTensorShape())) {
137     return "::tensorflow::PartialTensorShape() /* unknown */";
138   }
139   string ret = "{";
140   for (int d = 0; d < shape.dims(); ++d) {
141     if (d > 0) strings::StrAppend(&ret, ", ");
142     strings::StrAppend(&ret, shape.dim_size(d));
143   }
144   strings::StrAppend(&ret, "}");
145   return ret;
146 }
147 
148 template <typename T>
PrintArray(int64 num_elts,const T * array)149 string PrintArray(int64 num_elts, const T* array) {
150   string ret;
151   for (int64 i = 0; i < num_elts; ++i) {
152     if (i > 0) strings::StrAppend(&ret, ", ");
153     strings::StrAppend(&ret, array[i]);
154   }
155   return ret;
156 }
157 
PrintTensor(const TensorProto & tensor_proto)158 string PrintTensor(const TensorProto& tensor_proto) {
159   Tensor t(tensor_proto.dtype());
160   CHECK(t.FromProto(tensor_proto));
161   const int64 num_elts = t.NumElements();
162   switch (t.dtype()) {
163     case DT_FLOAT:
164       return PrintArray(num_elts, t.flat<float>().data());
165     case DT_DOUBLE:
166       return PrintArray(num_elts, t.flat<double>().data());
167     case DT_INT32:
168       return PrintArray(num_elts, t.flat<int32>().data());
169     case DT_UINT8:
170     case DT_QUINT8:
171       return PrintArray(num_elts, t.flat<uint8>().data());
172     case DT_UINT16:
173     case DT_QUINT16:
174       return PrintArray(num_elts, t.flat<uint16>().data());
175     case DT_INT16:
176     case DT_QINT16:
177       return PrintArray(num_elts, t.flat<int16>().data());
178     case DT_INT8:
179     case DT_QINT8:
180       return PrintArray(num_elts, t.flat<int8>().data());
181     case DT_INT64:
182       return PrintArray(num_elts, t.flat<int64>().data());
183     case DT_BOOL:
184       return PrintArray(num_elts, t.flat<bool>().data());
185     case DT_STRING: {
186       string ret;
187       for (int64 i = 0; i < num_elts; ++i) {
188         if (i > 0) strings::StrAppend(&ret, " ");
189         strings::StrAppend(&ret, str_util::CEscape(t.flat<string>()(i)));
190       }
191       return ret;
192     }
193     default: {
194       LOG(FATAL) << "Not handling type " << EnumName_DataType(t.dtype());
195       return string();
196     }
197   }
198 }
199 
PrintTensorProto(const TensorProto & proto)200 string PrintTensorProto(const TensorProto& proto) {
201   return strings::StrCat("Input::Initializer(", "{", PrintTensor(proto), "}, ",
202                          PrintTensorShape(proto.tensor_shape()),
203                          ").AsTensorProto()");
204 }
205 
PrintAttrValue(const string & op,const AttrValue & attr_value)206 string PrintAttrValue(const string& op, const AttrValue& attr_value) {
207   switch (attr_value.value_case()) {
208     case AttrValue::kS:
209       return PrintString(attr_value.s());
210     case AttrValue::kI:
211       return strings::StrCat(attr_value.i());
212     case AttrValue::kF: {
213       const float f = attr_value.f();
214       return strings::StrCat(attr_value.f(), floorf(f) == f ? ".0" : "", "f");
215     }
216     case AttrValue::kB:
217       return attr_value.b() ? "true" : "false";
218     case AttrValue::kType:
219       return EnumName_DataType(attr_value.type());
220     case AttrValue::kShape:
221       return PrintTensorShape(attr_value.shape());
222     case AttrValue::kTensor:
223       return PrintTensorProto(attr_value.tensor());
224     case AttrValue::kList: {
225       string ret = "{";
226       if (attr_value.list().s_size() > 0) {
227         for (int i = 0; i < attr_value.list().s_size(); ++i) {
228           if (i > 0) strings::StrAppend(&ret, ", ");
229           strings::StrAppend(&ret, PrintString(attr_value.list().s(i)));
230         }
231       } else if (attr_value.list().i_size() > 0) {
232         for (int i = 0; i < attr_value.list().i_size(); ++i) {
233           if (i > 0) strings::StrAppend(&ret, ", ");
234           strings::StrAppend(&ret, attr_value.list().i(i));
235         }
236       } else if (attr_value.list().f_size() > 0) {
237         for (int i = 0; i < attr_value.list().f_size(); ++i) {
238           if (i > 0) strings::StrAppend(&ret, ", ");
239           const float f = attr_value.list().f(i);
240           strings::StrAppend(&ret, f, floorf(f) == f ? ".0" : "", "f");
241         }
242       } else if (attr_value.list().b_size() > 0) {
243         for (int i = 0; i < attr_value.list().b_size(); ++i) {
244           if (i > 0) strings::StrAppend(&ret, ", ");
245           strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false");
246         }
247       } else if (attr_value.list().type_size() > 0) {
248         for (int i = 0; i < attr_value.list().type_size(); ++i) {
249           if (i > 0) strings::StrAppend(&ret, ", ");
250           strings::StrAppend(&ret,
251                              EnumName_DataType(attr_value.list().type(i)));
252         }
253       } else if (attr_value.list().shape_size() > 0) {
254         for (int i = 0; i < attr_value.list().shape_size(); ++i) {
255           if (i > 0) strings::StrAppend(&ret, ", ");
256           strings::StrAppend(&ret,
257                              PrintTensorShape(attr_value.list().shape(i)));
258         }
259       } else if (attr_value.list().tensor_size() > 0) {
260         for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
261           if (i > 0) strings::StrAppend(&ret, ", ");
262           strings::StrAppend(&ret,
263                              PrintTensorProto(attr_value.list().tensor(i)));
264         }
265       }
266       strings::StrAppend(&ret, "}");
267       return ret;
268     }
269     default:
270       LOG(FATAL) << "Unsupported Attr type: " << op << " "
271                  << attr_value.value_case();
272   }
273   return "<Unknown AttrValue type>";  // Prevent missing return warning
274 }
275 
IsEmptyList(const AttrValue::ListValue & list)276 bool IsEmptyList(const AttrValue::ListValue& list) {
277   return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 &&
278          list.b_size() == 0 && list.type_size() == 0 &&
279          list.shape_size() == 0 && list.tensor_size() == 0;
280 }
281 
// Converts a snake_case name to CamelCase, e.g. "foo_bar" -> "FooBar".
// Every '_' is dropped. The first character and each character following
// one or more '_' are upper-cased; all others are copied unchanged.
string ToCamelCase(const string& str) {
  const char joiner = '_';
  string result;
  result.reserve(str.size());
  bool cap = true;
  for (const char c : str) {
    if (c == joiner) {
      cap = true;
    } else if (cap) {
      // toupper() is only defined for values representable as unsigned char
      // (or EOF); passing a negative char (non-ASCII byte) is UB.
      result += static_cast<char>(toupper(static_cast<unsigned char>(c)));
      cap = false;
    } else {
      result += c;
    }
  }
  return result;
}
300 
301 // Returns a <string, bool> pair. The string is the C++ type name to be used for
302 // attr_type when defining an object of that type. The bool is a flag to
303 // indicate whether to treat the type as const when accepting the C++ type as an
304 // argument to a function.
AttrTypeName(StringPiece attr_type)305 std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
306   static const auto* attr_type_map =
307       new std::unordered_map<StringPiece, std::pair<const char*, bool>,
308                              StringPieceHasher>{
309           {"string", {"StringPiece", false}},
310           {"list(string)", {"gtl::ArraySlice<string>", true}},
311           {"int", {"int64", false}},
312           {"list(int)", {"gtl::ArraySlice<int>", true}},
313           {"float", {"float", false}},
314           {"list(float)", {"gtl::ArraySlice<float>", true}},
315           {"bool", {"bool", false}},
316           {"list(bool)", {"gtl::ArraySlice<bool>", true}},
317           {"type", {"DataType", false}},
318           {"list(type)", {"DataTypeSlice", true}},
319           {"shape", {"PartialTensorShape", false}},
320           {"list(shape)", {"gtl::ArraySlice<PartialTensorShape>", true}},
321           {"tensor", {"TensorProto", true}},
322           {"list(tensor)", {"gtl::ArraySlice<TensorProto>", true}},
323           {"func", {"NameAttrList", true}},
324           {"list(func)", {"gtl::ArraySlice<NameAttrList>", true}},
325       };
326 
327   auto entry = attr_type_map->find(attr_type);
328   if (entry == attr_type_map->end()) {
329     LOG(FATAL) << "Unsupported Attr type: " << attr_type;
330     return {"", false};
331   }
332   return entry->second;
333 }
334 
ListElementTypeName(StringPiece attr_type)335 const char* ListElementTypeName(StringPiece attr_type) {
336   static const auto* attr_list_type_map =
337       new std::unordered_map<StringPiece, const char*, StringPieceHasher>{
338           {"list(string)", "string"},
339           {"list(int)", "int"},
340           {"list(float)", "float"},
341           {"list(bool)", "bool"},
342           {"list(type)", "DataType"},
343           {"list(shape)", "PartialTensorShape"},
344           {"list(tensor)", "TensorProto"},
345       };
346 
347   auto entry = attr_list_type_map->find(attr_type);
348   if (entry == attr_list_type_map->end()) {
349     LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type;
350     return "";
351   }
352   return entry->second;
353 }
354 
IsCPPKeyword(StringPiece name)355 bool IsCPPKeyword(StringPiece name) {
356   static const std::unordered_set<StringPiece, StringPieceHasher>
357       // Keywords obtained from http://en.cppreference.com/w/cpp/keyword
358       kCPPReserved{
359           "alignas",
360           "alignof",
361           "and",
362           "and_eq",
363           "asm",
364           "atomic_cancel",
365           "atomic_commit",
366           "atomic_noexcept",
367           "auto",
368           "bitand",
369           "bitor",
370           "bool",
371           "break",
372           "case",
373           "catch",
374           "char",
375           "char16_t",
376           "char32_t",
377           "class",
378           "compl",
379           "concept",
380           "const",
381           "const_cast",
382           "constexpr",
383           "continue",
384           "decltype",
385           "default",
386           "delete",
387           "do",
388           "double",
389           "dynamic_cast",
390           "else",
391           "enum",
392           "explicit",
393           "export",
394           "extern",
395           "false",
396           "final",
397           "float",
398           "for",
399           "friend",
400           "goto",
401           "if",
402           "import",
403           "inline",
404           "int",
405           "long",
406           "module",
407           "mutable",
408           "namespace",
409           "new",
410           "noexcept",
411           "not",
412           "not_eq",
413           "nullptr",
414           "operator",
415           "or",
416           "or_eq",
417           "override",
418           "private",
419           "protected",
420           "public",
421           "register",
422           "reinterpret_cast",
423           "requires",
424           "return",
425           "short",
426           "signed",
427           "sizeof",
428           "static",
429           "static_assert",
430           "static_cast",
431           "struct",
432           "switch",
433           "synchronized",
434           "template",
435           "this",
436           "thread_local",
437           "throw",
438           "true",
439           "try",
440           "typedef",
441           "typeid",
442           "typename",
443           "union",
444           "unsigned",
445           "using",
446           "virtual",
447           "void",
448           "volatile",
449           "wchar_t",
450           "while",
451           "xor",
452           "xor_eq",
453 
454           // The following are not C++ keywords, but names of local variables
455           // and parameters used in the op constructor. Treating them as
456           // keywords, so that other parameter names don't conflict with these.
457           "builder",
458           "node",
459           "ret",
460           "scope",
461           "unique_name",
462       };
463   return kCPPReserved.count(name) > 0;
464 }
465 
AvoidCPPKeywords(StringPiece name)466 string AvoidCPPKeywords(StringPiece name) {
467   if (IsCPPKeyword(name)) {
468     return strings::StrCat(name, "_");
469   }
470   return string(name);
471 }
472 
InferArgAttributes(const OpDef::ArgDef & arg,std::unordered_map<string,string> * inferred_attrs)473 void InferArgAttributes(const OpDef::ArgDef& arg,
474                         std::unordered_map<string, string>* inferred_attrs) {
475   if (!arg.type_attr().empty()) {
476     gtl::InsertIfNotPresent(inferred_attrs, arg.type_attr(), arg.name());
477   } else if (!arg.type_list_attr().empty()) {
478     gtl::InsertIfNotPresent(inferred_attrs, arg.type_list_attr(), arg.name());
479   }
480   if (!arg.number_attr().empty()) {
481     gtl::InsertIfNotPresent(inferred_attrs, arg.number_attr(), arg.name());
482   }
483 }
484 
InferOpAttributes(const OpDef & op_def,std::unordered_map<string,string> * inferred_input_attrs)485 void InferOpAttributes(
486     const OpDef& op_def,
487     std::unordered_map<string, string>* inferred_input_attrs) {
488   for (int i = 0; i < op_def.input_arg_size(); ++i) {
489     const auto& arg(op_def.input_arg(i));
490     InferArgAttributes(arg, inferred_input_attrs);
491   }
492 }
493 
ArgIsList(const OpDef::ArgDef & arg)494 bool ArgIsList(const OpDef::ArgDef& arg) {
495   return !arg.type_list_attr().empty() || !arg.number_attr().empty();
496 }
497 
HasOptionalAttrs(const ApiDef & api_def,const std::unordered_map<string,string> & inferred_input_attrs)498 bool HasOptionalAttrs(
499     const ApiDef& api_def,
500     const std::unordered_map<string, string>& inferred_input_attrs) {
501   for (int i = 0; i < api_def.attr_size(); ++i) {
502     const auto& attr(api_def.attr(i));
503     if ((inferred_input_attrs.find(attr.name()) ==
504          inferred_input_attrs.end()) &&
505         attr.has_default_value()) {
506       return true;
507     }
508   }
509   return false;
510 }
511 
512 struct OpInfo {
513   // graph_op_def: The OpDef used by the runtime, has the names that
514   //   must be used when calling NodeBuilder.
515   // interface_op_def: The OpDef used in the interface in the generated
516   //   code, with possibly overridden names and defaults.
517   explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
518                   const std::vector<string>& aliases);
519   string GetOpAttrStruct() const;
520   string GetConstructorDecl(StringPiece op_name_prefix,
521                             bool include_attr) const;
522   void WriteClassDecl(WritableFile* h) const;
523   void GetOutput(string* out) const;
524   string GetConstructorBody() const;
525   void WriteClassDef(WritableFile* cc) const;
526 
527   string op_name;
528   std::vector<string> arg_types;
529   std::vector<string> arg_names;
530   std::vector<string> output_types;
531   std::vector<string> output_names;
532   std::vector<bool> is_list_output;
533   bool has_optional_attrs;
534   string comment;
535 
536   const OpDef& graph_op_def;
537   const ApiDef& api_def;
538   const std::vector<string>& aliases;
539   // Map from type attribute to corresponding original argument name.
540   std::unordered_map<string, string> inferred_input_attrs;
541 };
542 
OpInfo(const OpDef & graph_op_def,const ApiDef & api_def,const std::vector<string> & aliases)543 OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
544                const std::vector<string>& aliases)
545     : graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
546   op_name = api_def.endpoint(0).name();
547   InferOpAttributes(graph_op_def, &inferred_input_attrs);
548   has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
549   arg_types.push_back("const ::tensorflow::Scope&");
550   arg_names.push_back("scope");
551 
552   if (graph_op_def.has_deprecation()) {
553     if (!api_def.summary().empty()) {
554       comment = strings::StrCat(api_def.summary(), "\n");
555     }
556     strings::StrAppend(&comment, "DEPRECATED at GraphDef version ",
557                        graph_op_def.deprecation().version(), ":\n",
558                        graph_op_def.deprecation().explanation(), ".\n");
559   } else if (api_def.summary().empty()) {
560     comment = "TODO: add doc.\n";
561   } else {
562     comment = strings::StrCat(api_def.summary(), "\n");
563   }
564   if (!api_def.description().empty()) {
565     strings::StrAppend(&comment, "\n", api_def.description(), "\n");
566   }
567   strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n");
568 
569   // Process inputs
570   for (int i = 0; i < api_def.arg_order_size(); ++i) {
571     const auto& arg = *FindInputArg(api_def.arg_order(i), graph_op_def);
572     const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def);
573     arg_types.push_back(strings::StrCat(
574         "::tensorflow::", ArgIsList(arg) ? "InputList" : "Input"));
575     arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
576 
577     // TODO(keveman): Include input type information.
578     StringPiece description = api_def_arg.description();
579     if (!description.empty()) {
580       ConsumeEquals(&description);
581       strings::StrAppend(&comment, "* ",
582                          AvoidCPPKeywords(api_def_arg.rename_to()), ": ",
583                          api_def_arg.description(), "\n");
584     }
585   }
586 
587   // Process attrs
588   string required_attrs_comment;
589   string optional_attrs_comment;
590   for (int i = 0; i < graph_op_def.attr_size(); ++i) {
591     // ApiDef attributes must be in the same order as in OpDef since
592     // we initialize ApiDef based on OpDef.
593     const auto& attr(graph_op_def.attr(i));
594     const auto& api_def_attr(api_def.attr(i));
595     CHECK_EQ(attr.name(), api_def_attr.name());
596     // Skip inferred arguments
597     if (inferred_input_attrs.count(attr.name()) > 0) continue;
598 
599     const auto entry = AttrTypeName(attr.type());
600     const auto attr_type_name = entry.first;
601     const bool use_const = entry.second;
602     string attr_name = AvoidCPPKeywords(api_def_attr.rename_to());
603 
604     string attr_comment;
605     if (!api_def_attr.description().empty()) {
606       // TODO(keveman): Word wrap and indent this, to handle multi-line
607       // descriptions.
608       strings::StrAppend(&attr_comment, "* ", attr_name, ": ",
609                          api_def_attr.description(), "\n");
610     }
611     if (api_def_attr.has_default_value()) {
612       strings::StrAppend(&optional_attrs_comment, attr_comment);
613     } else {
614       strings::StrAppend(&required_attrs_comment, attr_comment);
615       arg_types.push_back(strings::StrCat(
616           use_const ? "const " : "", attr_type_name, use_const ? "&" : ""));
617       arg_names.push_back(attr_name);
618     }
619   }
620 
621   strings::StrAppend(&comment, required_attrs_comment);
622 
623   if (!optional_attrs_comment.empty()) {
624     strings::StrAppend(&comment, "\nOptional attributes (see `Attrs`):\n");
625     strings::StrAppend(&comment, optional_attrs_comment);
626   }
627 
628   // Process outputs
629   for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
630     // ApiDef arguments must be in the same order as in OpDef since
631     // we initialize ApiDef based on OpDef.
632     const auto& arg = graph_op_def.output_arg(i);
633     const auto& api_def_arg(api_def.out_arg(i));
634     CHECK_EQ(arg.name(), api_def_arg.name());
635 
636     bool is_list = ArgIsList(arg);
637     output_types.push_back(
638         strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output"));
639     output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
640     is_list_output.push_back(is_list);
641   }
642 
643   strings::StrAppend(&comment, "\nReturns:\n");
644   if (graph_op_def.output_arg_size() == 0) {  // No outputs.
645     strings::StrAppend(&comment, "* the created `Operation`\n");
646   } else if (graph_op_def.output_arg_size() == 1) {  // One output
647     if (is_list_output[0]) {
648       strings::StrAppend(&comment, "* `OutputList`: ");
649     } else {
650       strings::StrAppend(&comment, "* `Output`: ");
651     }
652     if (api_def.out_arg(0).description().empty()) {
653       strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(),
654                          " tensor.\n");
655     } else {
656       // TODO(josh11b): Word wrap this.
657       strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n");
658     }
659   } else {  // Multiple outputs.
660     for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
661       if (is_list_output[i]) {
662         strings::StrAppend(&comment, "* `OutputList`");
663       } else {
664         strings::StrAppend(&comment, "* `Output`");
665       }
666       strings::StrAppend(&comment, " ", output_names[i]);
667       if (api_def.out_arg(i).description().empty()) {
668         strings::StrAppend(&comment, "\n");
669       } else {
670         // TODO(josh11b): Word wrap this.
671         strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(),
672                            "\n");
673       }
674     }
675   }
676 
677   if (!aliases.empty()) {
678     strings::StrAppend(&comment, "\nAliases:\n");
679     for (const auto& alias : aliases) {
680       strings::StrAppend(&comment, "* ", alias, "\n");
681     }
682   }
683   comment = MakeComment(comment, "");
684 }
685 
GetOpAttrStruct() const686 string OpInfo::GetOpAttrStruct() const {
687   string struct_fields;
688   string setters;
689   string defaults_static_storage;
690 
691   for (int i = 0; i < graph_op_def.attr_size(); ++i) {
692     const auto& attr(graph_op_def.attr(i));
693     const auto& api_def_attr(api_def.attr(i));
694     // If attr will be inferred or it doesn't have a default value, don't
695     // add it to the struct.
696     if ((inferred_input_attrs.find(attr.name()) !=
697          inferred_input_attrs.end()) ||
698         !api_def_attr.has_default_value()) {
699       continue;
700     }
701     const auto entry = AttrTypeName(attr.type());
702     const auto attr_type_name = entry.first;
703     const bool use_const = entry.second;
704     const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
705     const string suffix =
706         (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
707     const string attr_func_def =
708         strings::StrCat(camel_case_name, suffix, "(", use_const ? "const " : "",
709                         attr_type_name, use_const ? "&" : "");
710 
711     string attr_comment;
712     if (!api_def_attr.description().empty()) {
713       strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n");
714     }
715     strings::StrAppend(&attr_comment, "Defaults to ",
716                        SummarizeAttrValue(api_def_attr.default_value()), "\n");
717     attr_comment = MakeComment(attr_comment, "    ");
718 
719     strings::StrAppend(&setters, attr_comment);
720     strings::StrAppend(&setters, "    TF_MUST_USE_RESULT Attrs ", attr_func_def,
721                        " x) {\n");
722     strings::StrAppend(&setters, "      Attrs ret = *this;\n");
723     strings::StrAppend(&setters, "      ret.", api_def_attr.rename_to(),
724                        "_ = x;\n");
725     strings::StrAppend(&setters, "      return ret;\n    }\n\n");
726 
727     string field_initiliazer;
728     auto& default_value = api_def_attr.default_value();
729     if (default_value.value_case() == AttrValue::kList &&
730         !IsEmptyList(default_value.list())) {
731       // Non-empty lists need static storage for their defaults. Define a
732       // function with static local variable that stores the array.
733       strings::StrAppend(&defaults_static_storage, "    static ",
734                          attr_type_name, " Default_", api_def_attr.rename_to(),
735                          "() {\n");
736       strings::StrAppend(
737           &defaults_static_storage, "      static const ",
738           ListElementTypeName(attr.type()), " kStorage[] = ",
739           PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
740           ";\n");
741       strings::StrAppend(&defaults_static_storage, "      return ",
742                          attr_type_name, "(kStorage);\n    }\n");
743       // Set the field_initializer to call the defined function.
744       strings::StrAppend(&field_initiliazer, "Default_",
745                          api_def_attr.rename_to(), "()");
746     } else {
747       field_initiliazer =
748           PrintAttrValue(graph_op_def.name(), api_def_attr.default_value());
749     }
750     strings::StrAppend(&struct_fields, "    ", attr_type_name, " ",
751                        api_def_attr.rename_to(), "_ = ", field_initiliazer,
752                        ";\n");
753   }
754 
755   if (struct_fields.empty()) {
756     return "";
757   }
758 
759   string attrs_comment =
760       strings::StrCat("Optional attribute setters for ", op_name, "\n");
761   string struct_decl = MakeComment(attrs_comment, "  ");
762   strings::StrAppend(&struct_decl, "  struct Attrs {\n");
763   strings::StrAppend(&struct_decl, setters, struct_fields);
764   if (!defaults_static_storage.empty()) {
765     strings::StrAppend(&struct_decl, "  private:\n", defaults_static_storage);
766   }
767   strings::StrAppend(&struct_decl, "  };\n");
768 
769   return struct_decl;
770 }
771 
GetConstructorDecl(StringPiece op_name_prefix,bool include_attr) const772 string OpInfo::GetConstructorDecl(StringPiece op_name_prefix,
773                                   bool include_attr) const {
774   const string prefix = strings::StrCat(op_name_prefix, op_name, "(");
775   string c_decl;
776   for (int i = 0; i < arg_types.size(); ++i) {
777     if (i > 0) strings::StrAppend(&c_decl, ", ");
778     strings::StrAppend(&c_decl, arg_types[i], " ", arg_names[i]);
779   }
780   if (include_attr && has_optional_attrs) {
781     strings::StrAppend(&c_decl, ", const ", op_name, "::Attrs& attrs");
782   }
783   strings::StrAppend(&c_decl, ")");
784   return WordWrap(prefix, c_decl, kRightMargin);
785 }
786 
WriteClassDecl(WritableFile * h) const787 void OpInfo::WriteClassDecl(WritableFile* h) const {
788   string class_decl = comment;
789   strings::StrAppend(&class_decl, "class ", op_name, " {\n");
790   strings::StrAppend(&class_decl, " public:\n");
791   if (has_optional_attrs) {
792     strings::StrAppend(&class_decl, GetOpAttrStruct());
793   }
794   strings::StrAppend(&class_decl, "  ",
795                      GetConstructorDecl("", /* include_attr */ false), ";\n");
796   if (has_optional_attrs) {
797     strings::StrAppend(&class_decl, "  ",
798                        GetConstructorDecl("", /* include_attr */ true), ";\n");
799   }
800   if (output_types.empty()) {
801     // Allow casting this class to Operation.
802     strings::StrAppend(&class_decl,
803                        "  operator ::tensorflow::Operation() const { "
804                        "return operation; }\n");
805   } else if (output_types.size() == 1) {
806     if (is_list_output[0]) {
807       // Write the subscript operator, allowing out[i] for the list-typed
808       // output.
809       strings::StrAppend(&class_decl,
810                          "  ::tensorflow::Output operator[](size_t index) "
811                          "const { return ",
812                          output_names[0], "[index]; }\n\n");
813 
814     } else {
815       // Write type cast functions, allowing casting this class to Input and
816       // Output.
817       strings::StrAppend(&class_decl,
818                          "  operator ::tensorflow::Output() const { return ",
819                          output_names[0], "; }\n");
820       strings::StrAppend(&class_decl,
821                          "  operator ::tensorflow::Input() const { return ",
822                          output_names[0], "; }\n");
823       // Write node() to get the Node* directly.
824       strings::StrAppend(&class_decl,
825                          "  ::tensorflow::Node* node() const { return ",
826                          output_names[0], ".node(); }\n");
827     }
828   }
829   // Add the static functions to set optional attrs
830   if (has_optional_attrs) {
831     strings::StrAppend(&class_decl, "\n");
832     for (int i = 0; i < graph_op_def.attr_size(); ++i) {
833       const auto& attr(graph_op_def.attr(i));
834       const auto& api_def_attr(api_def.attr(i));
835       if ((inferred_input_attrs.find(attr.name()) !=
836            inferred_input_attrs.end()) ||
837           !api_def_attr.has_default_value()) {
838         continue;
839       }
840       const auto entry = AttrTypeName(attr.type());
841       const auto attr_type_name = entry.first;
842       const bool use_const = entry.second;
843       const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
844       const string suffix =
845           (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
846       const string attr_func_def = strings::StrCat(
847           camel_case_name, suffix, "(", use_const ? "const " : "",
848           attr_type_name, use_const ? "&" : "");
849       strings::StrAppend(&class_decl, "  static Attrs ", attr_func_def,
850                          " x) {\n");
851       strings::StrAppend(&class_decl, "    return Attrs().", camel_case_name,
852                          suffix, "(x);\n");
853       strings::StrAppend(&class_decl, "  }\n");
854     }
855   }
856 
857   strings::StrAppend(&class_decl, "\n  Operation operation;\n");
858   for (int i = 0; i < output_types.size(); ++i) {
859     strings::StrAppend(&class_decl, "  ", output_types[i], " ", output_names[i],
860                        ";\n");
861   }
862 
863   strings::StrAppend(&class_decl, "};\n");
864   if (!aliases.empty()) {
865     for (const auto& alias : aliases) {
866       strings::StrAppend(&class_decl, "typedef ", op_name, " ", alias, ";\n");
867     }
868   }
869   strings::StrAppend(&class_decl, "\n");
870   TF_CHECK_OK(h->Append(class_decl));
871 }
872 
GetOutput(string * out) const873 void OpInfo::GetOutput(string* out) const {
874   const string scope_str = arg_names[0];
875   string return_on_error =
876       strings::StrCat("if (!", scope_str, ".ok()) return;");
877 
878   strings::StrAppend(out, "  this->operation = Operation(ret);\n");
879 
880   // No outputs.
881   if (graph_op_def.output_arg_size() == 0) {
882     strings::StrAppend(out, "  return;\n");
883     return;
884   }
885   if (graph_op_def.output_arg_size() == 1) {
886     // One output, no need for NameRangeMap
887     if (is_list_output[0]) {
888       strings::StrAppend(out,
889                          "  for (int32 i = 0; i < ret->num_outputs(); ++i)\n");
890       strings::StrAppend(out, "    this->", output_names[0],
891                          ".push_back(Output(ret, i));\n");
892     } else {
893       strings::StrAppend(out, "  this->", output_names[0],
894                          " = Output(ret, 0);\n");
895     }
896     return;
897   }
898   strings::StrAppend(out, "  ::tensorflow::NameRangeMap _outputs_range;\n");
899   strings::StrAppend(out,
900                      "  ::tensorflow::Status _status_ = "
901                      "::tensorflow::NameRangesForNode(*ret, ret->op_def(), "
902                      "nullptr, &_outputs_range);\n");
903   strings::StrAppend(out, "  if (!_status_.ok()) {\n", "    ", scope_str,
904                      ".UpdateStatus(_status_);\n", "    return;\n");
905   strings::StrAppend(out, "  }\n\n");
906 
907   for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
908     const string arg_range = strings::StrCat(
909         "_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]");
910     if (is_list_output[i]) {
911       strings::StrAppend(out, "  for (int32 i = ", arg_range, ".first; i < ",
912                          arg_range, ".second; ++i)\n");
913       strings::StrAppend(out, "    this->", output_names[i],
914                          ".push_back(Output(ret, i));\n");
915     } else {
916       strings::StrAppend(out, "  this->", output_names[i], " = Output(ret, ",
917                          arg_range, ".first);\n");
918     }
919   }
920 }
921 
GetConstructorBody() const922 string OpInfo::GetConstructorBody() const {
923   const string scope_str = arg_names[0];
924 
925   string body;
926   string return_on_error =
927       strings::StrCat("if (!", scope_str, ".ok()) return;");
928 
929   strings::StrAppend(&body, "  ", return_on_error, "\n");
930 
931   for (int i = 0; i < graph_op_def.input_arg_size(); ++i) {
932     const auto& arg(graph_op_def.input_arg(i));
933     const auto& api_def_arg(api_def.in_arg(i));
934     strings::StrAppend(
935         &body, "  auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::",
936         ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ",
937         AvoidCPPKeywords(api_def_arg.rename_to()), ");\n");
938     strings::StrAppend(&body, "  ", return_on_error, "\n");
939   }
940 
941   strings::StrAppend(&body, "  ::tensorflow::Node* ret;\n");
942   strings::StrAppend(&body, "  const auto unique_name = ", scope_str,
943                      ".GetUniqueNameForOp(\"", op_name, "\");\n");
944   strings::StrAppend(
945       &body, "  auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
946       graph_op_def.name(), "\")\n");
947   const string spaces = "                     ";
948   for (int i = 0; i < api_def.in_arg_size(); ++i) {
949     const auto& arg(api_def.in_arg(i));
950     strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n");
951   }
952   for (int i = 0; i < api_def.attr_size(); ++i) {
953     const auto& graph_attr(graph_op_def.attr(i));
954     const auto& api_def_attr(api_def.attr(i));
955     if (inferred_input_attrs.find(api_def_attr.name()) !=
956         inferred_input_attrs.end()) {
957       continue;
958     }
959     const string attr_name =
960         api_def_attr.has_default_value()
961             ? strings::StrCat("attrs.", api_def_attr.rename_to(), "_")
962             : AvoidCPPKeywords(api_def_attr.rename_to());
963     strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ",
964                        attr_name, ")\n");
965   }
966   strings::StrAppend(&body, "  ;\n");
967   strings::StrAppend(&body, "  ", scope_str, ".UpdateBuilder(&builder);\n");
968   strings::StrAppend(&body, "  ", scope_str, ".UpdateStatus(builder.Finalize(",
969                      scope_str, ".graph(), &ret));\n");
970   strings::StrAppend(&body, "  ", return_on_error, "\n");
971   strings::StrAppend(&body, "  ", scope_str, ".UpdateStatus(", scope_str,
972                      ".DoShapeInference(ret));\n");
973 
974   GetOutput(&body);
975   return body;
976 }
977 
WriteClassDef(WritableFile * cc) const978 void OpInfo::WriteClassDef(WritableFile* cc) const {
979   string class_def;
980   strings::StrAppend(&class_def,
981                      GetConstructorDecl(strings::StrCat(op_name, "::"),
982                                         /* include_attr */ true),
983                      " {\n");
984   strings::StrAppend(&class_def, GetConstructorBody());
985   strings::StrAppend(&class_def, "}\n\n");
986 
987   if (has_optional_attrs) {
988     strings::StrAppend(&class_def,
989                        GetConstructorDecl(strings::StrCat(op_name, "::"),
990                                           /* include_attr */ false));
991     strings::StrAppend(&class_def, "\n  : ", op_name, "(");
992     int i = 0;
993     for (; i < arg_names.size(); ++i) {
994       if (i > 0) strings::StrAppend(&class_def, ", ");
995       strings::StrAppend(&class_def, arg_names[i]);
996     }
997     if (i > 0) strings::StrAppend(&class_def, ", ");
998     strings::StrAppend(&class_def, op_name, "::Attrs()");
999     strings::StrAppend(&class_def, ") {}\n\n");
1000   }
1001   TF_CHECK_OK(cc->Append(class_def));
1002 }
1003 
WriteCCOp(const OpDef & graph_op_def,const ApiDef & api_def,const std::vector<string> & aliases,WritableFile * h,WritableFile * cc)1004 void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def,
1005                const std::vector<string>& aliases, WritableFile* h,
1006                WritableFile* cc) {
1007   OpInfo op_info(graph_op_def, api_def, aliases);
1008 
1009   op_info.WriteClassDecl(h);
1010   op_info.WriteClassDef(cc);
1011 }
1012 
StartFiles(bool internal,const string & dot_h_fname,WritableFile * h,WritableFile * cc,string * op_header_guard)1013 void StartFiles(bool internal, const string& dot_h_fname, WritableFile* h,
1014                 WritableFile* cc, string* op_header_guard) {
1015   const string header =
1016       R"header(// This file is MACHINE GENERATED! Do not edit.
1017 
1018 #include "tensorflow/cc/framework/ops.h"
1019 #include "tensorflow/cc/framework/scope.h"
1020 #include "tensorflow/core/framework/tensor.h"
1021 #include "tensorflow/core/framework/tensor_shape.h"
1022 #include "tensorflow/core/framework/types.h"
1023 #include "tensorflow/core/lib/gtl/array_slice.h"
1024 )header";
1025 
1026   // TODO(keveman): Make namespaces configurable.
1027   const string namespace_begin = internal ? R"namespace(
1028 namespace tensorflow {
1029 namespace ops {
1030 namespace internal {
1031 // NOTE: This namespace has internal TensorFlow details that
1032 // are not part of TensorFlow's public API.
1033 
1034 )namespace"
1035                                           : R"namespace(
1036 namespace tensorflow {
1037 namespace ops {
1038 
1039 )namespace";
1040 
1041   const string op_header = GetPath(dot_h_fname);
1042   *op_header_guard = ToGuard(op_header);
1043   const string cc_header = strings::StrCat(
1044       R"include(// This file is MACHINE GENERATED! Do not edit.
1045 
1046 
1047 #include "tensorflow/cc/ops/const_op.h"
1048 )include",
1049       "#include \"", op_header, "\"\n", namespace_begin);
1050 
1051   const string filename = GetFilename(dot_h_fname);
1052   const string doxygen = strings::StrCat("/// @defgroup ", filename, " ",
1053                                          ToTitle(filename), "\n", "/// @{\n\n");
1054 
1055   TF_CHECK_OK(h->Append(
1056       strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
1057                       "#ifndef ",
1058                       *op_header_guard,
1059                       "\n"
1060                       "#define ",
1061                       *op_header_guard, "\n\n")));
1062   TF_CHECK_OK(h->Append(header));
1063   TF_CHECK_OK(h->Append(namespace_begin));
1064   TF_CHECK_OK(h->Append(doxygen));
1065   TF_CHECK_OK(cc->Append(cc_header));
1066 }
1067 
FinishFiles(bool internal,WritableFile * h,WritableFile * cc,const string & op_header_guard)1068 void FinishFiles(bool internal, WritableFile* h, WritableFile* cc,
1069                  const string& op_header_guard) {
1070   const string footer = internal ? R"footer(}  // namespace internal
1071 }  // namespace ops
1072 }  // namespace tensorflow
1073 )footer"
1074                                  :
1075                                  R"footer(/// @}
1076 
1077 }  // namespace ops
1078 }  // namespace tensorflow
1079 )footer";
1080 
1081   TF_CHECK_OK(h->Append(footer));
1082   TF_CHECK_OK(
1083       h->Append(strings::StrCat("\n#endif  ", "// ", op_header_guard, "\n")));
1084   TF_CHECK_OK(cc->Append(footer));
1085 
1086   TF_CHECK_OK(cc->Close());
1087   TF_CHECK_OK(h->Close());
1088 }
1089 
// Returns `fname` with "_internal" inserted before its extension, e.g.
// "foo/array_ops.h" -> "foo/array_ops_internal.h". If the final path
// component has no '.', "_internal" is appended instead.
//
// Only the final path component is searched for the extension. A '.' in a
// directory name (e.g. "k8.opt/genfiles/array_ops") is therefore not
// mistaken for an extension separator.
std::string MakeInternal(const std::string& fname) {
  const std::string suffix = "_internal";
  const std::string::size_type last_sep = fname.find_last_of("/\\");
  const std::string::size_type dot_pos = fname.rfind('.');
  const bool has_extension =
      dot_pos != std::string::npos &&
      (last_sep == std::string::npos || dot_pos > last_sep);
  if (!has_extension) {
    return fname + suffix;
  }
  std::string result = fname;
  result.insert(dot_pos, suffix);
  return result;
}
1099 
1100 }  // namespace
1101 
WriteCCOps(const OpList & ops,const ApiDefMap & api_def_map,const string & dot_h_fname,const string & dot_cc_fname)1102 void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
1103                 const string& dot_h_fname, const string& dot_cc_fname) {
1104   Env* env = Env::Default();
1105 
1106   // Write the initial boilerplate to the .h and .cc files.
1107   std::unique_ptr<WritableFile> h = nullptr;
1108   std::unique_ptr<WritableFile> cc = nullptr;
1109   TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
1110   TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
1111   string op_header_guard;
1112   StartFiles(false, dot_h_fname, h.get(), cc.get(), &op_header_guard);
1113 
1114   // Create the internal versions of these files for the hidden ops.
1115   std::unique_ptr<WritableFile> internal_h = nullptr;
1116   std::unique_ptr<WritableFile> internal_cc = nullptr;
1117   const string internal_dot_h_fname = MakeInternal(dot_h_fname);
1118   TF_CHECK_OK(env->NewWritableFile(internal_dot_h_fname, &internal_h));
1119   TF_CHECK_OK(env->NewWritableFile(MakeInternal(dot_cc_fname), &internal_cc));
1120   string internal_op_header_guard;
1121   StartFiles(true /* internal */, internal_dot_h_fname, internal_h.get(),
1122              internal_cc.get(), &internal_op_header_guard);
1123 
1124   for (const auto& graph_op_def : ops.op()) {
1125     // Skip deprecated ops.
1126     // TODO(josh11b): If needed, can put them into a "deprecated" namespace
1127     // instead of skipping.
1128     if (graph_op_def.has_deprecation() &&
1129         graph_op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
1130       continue;
1131     }
1132 
1133     // We use a hand-written wrapper for "Const", since the generated
1134     // code depends on it.
1135     if (graph_op_def.name() == "Const") continue;
1136 
1137     const auto* api_def = api_def_map.GetApiDef(graph_op_def.name());
1138 
1139     std::vector<string> aliases;
1140     if (api_def->visibility() == ApiDef::SKIP) continue;
1141     // First endpoint is canonical, the rest are aliases.
1142     for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size();
1143          ++endpoint_i) {
1144       aliases.push_back(api_def->endpoint(endpoint_i).name());
1145     }
1146     if (api_def->visibility() == ApiDef::HIDDEN) {
1147       // Write hidden ops to _internal.h and _internal.cc.
1148       WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(),
1149                 internal_cc.get());
1150       continue;
1151     }
1152     // This isn't a hidden op, write it to the main files.
1153     WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get());
1154   }
1155 
1156   FinishFiles(false, h.get(), cc.get(), op_header_guard);
1157   FinishFiles(true /* internal */, internal_h.get(), internal_cc.get(),
1158               internal_op_header_guard);
1159 }
1160 
1161 }  // namespace tensorflow
1162