1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <unordered_map>
17 #include <unordered_set>
18 #include <vector>
19
20 #include "tensorflow/cc/framework/cc_op_gen.h"
21 #include "tensorflow/core/framework/api_def.pb.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/op_def_util.h"
25 #include "tensorflow/core/framework/op_gen_lib.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 #include "tensorflow/core/framework/tensor_shape.pb.h"
28 #include "tensorflow/core/framework/types.pb_text.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 #include "tensorflow/core/lib/gtl/stl_util.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/platform/env.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/core/public/version.h"
38
39 namespace tensorflow {
40 namespace {
41
// Column at which generated declarations are word-wrapped (see WordWrap
// calls below).
const int kRightMargin = 79;
43
// Converts:
//   bazel-out/.../genfiles/(external/YYY/)?XX
// to: XX.
string GetPath(const string& dot_h_fname) {
  string result = dot_h_fname;
  // Drop everything up to and including the "/genfiles/" segment, if any.
  // "sizeof - 1" excludes the terminating '\0' from the literal's length.
  const auto genfiles_pos = dot_h_fname.find("/genfiles/");
  if (genfiles_pos != string::npos) {
    result = dot_h_fname.substr(genfiles_pos + sizeof("/genfiles/") - 1);
  }
  // For external dependencies, additionally strip the leading
  // "external/<repository>/" component.
  if (result.size() > sizeof("external/") &&
      result.compare(0, sizeof("external/") - 1, "external/") == 0) {
    const size_t repo_end = result.find('/', sizeof("external/") - 1);
    result = (repo_end == string::npos)
                 ? result.substr(sizeof("external/") - 1)
                 : result.substr(repo_end + 1);
  }
  return result;
}
64
// Converts: some/path/to/file.xx
// to: file
// (note that suffix is removed)
string GetFilename(const string& path) {
  const size_t slash_pos = path.rfind('/');
  // Start of the basename; 0 when the path has no directory component.
  const size_t start = (slash_pos == string::npos) ? 0 : slash_pos + 1;
  const size_t dot_pos = path.rfind('.');
  // If there is no '.' inside the basename itself, keep the whole
  // basename. (The previous implementation reached the same result by
  // relying on unsigned wrap-around of the substr length, and also
  // contained a no-op "slash_pos = -1" assignment since npos already
  // equals size_t(-1).)
  if (dot_pos == string::npos || dot_pos < start) {
    return path.substr(start);
  }
  return path.substr(start, dot_pos - start);
}
74
// Converts:
//   cc/ops/gen_foo_ops.h
// to:
//   CC_OPS_GEN_FOO_OPS_H_
string ToGuard(const string& path) {
  string guard;
  guard.reserve(path.size() + 1);  // +1 for the trailing '_'
  for (const char c : path) {
    if (c >= 'a' && c <= 'z') {
      guard += c - 'a' + 'A';  // lowercase -> uppercase
    } else if (c >= 'A' && c <= 'Z') {
      guard += c;  // already uppercase
    } else {
      guard += '_';  // everything else ('/', '.', digits, ...) -> '_'
    }
  }
  guard += '_';
  return guard;
}
94
95 // Converts: some_name_xyz
96 // to: Some Name Xyz
ToTitle(const string & name)97 string ToTitle(const string& name) {
98 string title = name;
99 for (int i = 0; i < title.size(); ++i) {
100 if (title[i] == '_') title[i] = ' ';
101 }
102 str_util::TitlecaseString(&title, " ");
103 return title;
104 }
105
// Change:     Into:
//   ABC         /// ABC
//               ///
//   DEF         /// DEF
// Converts free text into a Doxygen-style "///" comment block, one
// output line per input line, each prefixed with `indent`.
string MakeComment(StringPiece text, StringPiece indent) {
  string ret;
  while (!text.empty()) {
    int last_non_space = -1;
    int newline;
    // Scan to the end of the current line, tracking the last non-space
    // character so trailing spaces are trimmed from the output.
    for (newline = 0; newline < static_cast<int>(text.size()); ++newline) {
      if (text[newline] == '\n') break;
      if (text[newline] != ' ') last_non_space = newline;
    }
    if (last_non_space == -1) {
      // Blank (or all-space) line: emit a bare "///".
      strings::StrAppend(&ret, indent, "///\n");
    } else {
      strings::StrAppend(&ret, indent, "/// ",
                         text.substr(0, last_non_space + 1), "\n");
    }
    // NOTE(review): assumes `text` ends with '\n'; otherwise the final
    // iteration removes newline + 1 > text.size() bytes -- confirm all
    // callers pass newline-terminated text.
    text.remove_prefix(newline + 1);
  }
  return ret;
}
129
// Returns `str` rendered as a double-quoted, C-escaped C++ string
// literal.
string PrintString(const string& str) {
  return strings::StrCat("\"", str_util::CEscape(str), "\"");
}
133
// Renders `shape_proto` as a C++ expression: a "{d0, d1, ...}"
// brace-initializer for known shapes, or a default-constructed
// PartialTensorShape for a fully unknown shape.
string PrintTensorShape(const TensorShapeProto& shape_proto) {
  PartialTensorShape shape(shape_proto);
  if (shape.IsIdenticalTo(PartialTensorShape())) {
    return "::tensorflow::PartialTensorShape() /* unknown */";
  }
  string ret = "{";
  for (int d = 0; d < shape.dims(); ++d) {
    if (d > 0) strings::StrAppend(&ret, ", ");
    strings::StrAppend(&ret, shape.dim_size(d));
  }
  strings::StrAppend(&ret, "}");
  return ret;
}
147
// Joins the first `num_elts` elements of `array` into a single
// comma-separated string (formatted by StrAppend's overload for T).
template <typename T>
string PrintArray(int64 num_elts, const T* array) {
  string ret;
  for (int64 i = 0; i < num_elts; ++i) {
    if (i > 0) strings::StrAppend(&ret, ", ");
    strings::StrAppend(&ret, array[i]);
  }
  return ret;
}
157
// Renders the elements of `tensor_proto` as a list of C++ literals:
// comma-separated for numeric/bool types, space-separated escaped
// strings for DT_STRING. LOG(FATAL)s on unhandled dtypes.
string PrintTensor(const TensorProto& tensor_proto) {
  Tensor t(tensor_proto.dtype());
  CHECK(t.FromProto(tensor_proto));
  const int64 num_elts = t.NumElements();
  switch (t.dtype()) {
    case DT_FLOAT:
      return PrintArray(num_elts, t.flat<float>().data());
    case DT_DOUBLE:
      return PrintArray(num_elts, t.flat<double>().data());
    case DT_INT32:
      return PrintArray(num_elts, t.flat<int32>().data());
    // Quantized types are printed via their underlying integer type.
    case DT_UINT8:
    case DT_QUINT8:
      return PrintArray(num_elts, t.flat<uint8>().data());
    case DT_UINT16:
    case DT_QUINT16:
      return PrintArray(num_elts, t.flat<uint16>().data());
    case DT_INT16:
    case DT_QINT16:
      return PrintArray(num_elts, t.flat<int16>().data());
    case DT_INT8:
    case DT_QINT8:
      return PrintArray(num_elts, t.flat<int8>().data());
    case DT_INT64:
      return PrintArray(num_elts, t.flat<int64>().data());
    case DT_BOOL:
      return PrintArray(num_elts, t.flat<bool>().data());
    case DT_STRING: {
      // String elements are CEscape'd and joined with spaces rather
      // than commas.
      string ret;
      for (int64 i = 0; i < num_elts; ++i) {
        if (i > 0) strings::StrAppend(&ret, " ");
        strings::StrAppend(&ret, str_util::CEscape(t.flat<string>()(i)));
      }
      return ret;
    }
    default: {
      LOG(FATAL) << "Not handling type " << EnumName_DataType(t.dtype());
      return string();
    }
  }
}
199
// Renders `proto` as a C++ expression that rebuilds an equivalent
// TensorProto at generated-code runtime via Input::Initializer.
string PrintTensorProto(const TensorProto& proto) {
  return strings::StrCat("Input::Initializer(", "{", PrintTensor(proto), "}, ",
                         PrintTensorShape(proto.tensor_shape()),
                         ").AsTensorProto()");
}
205
// Renders `attr_value` as a C++ expression suitable for a default value
// or initializer in generated code. `op` is used only in the fatal
// error message. LOG(FATAL)s on unsupported value cases.
string PrintAttrValue(const string& op, const AttrValue& attr_value) {
  switch (attr_value.value_case()) {
    case AttrValue::kS:
      return PrintString(attr_value.s());
    case AttrValue::kI:
      return strings::StrCat(attr_value.i());
    case AttrValue::kF: {
      const float f = attr_value.f();
      // Append ".0" to whole numbers so the token parses as a floating
      // point literal, plus the "f" float suffix.
      // NOTE(review): StrCat may format large magnitudes in exponent
      // form, which would yield invalid C++ like "1e+10.0f" -- confirm
      // the range of float defaults actually encountered.
      return strings::StrCat(attr_value.f(), floorf(f) == f ? ".0" : "", "f");
    }
    case AttrValue::kB:
      return attr_value.b() ? "true" : "false";
    case AttrValue::kType:
      return EnumName_DataType(attr_value.type());
    case AttrValue::kShape:
      return PrintTensorShape(attr_value.shape());
    case AttrValue::kTensor:
      return PrintTensorProto(attr_value.tensor());
    case AttrValue::kList: {
      // A ListValue populates at most one of its typed sub-lists; emit
      // whichever is non-empty as a brace-initializer expression.
      string ret = "{";
      if (attr_value.list().s_size() > 0) {
        for (int i = 0; i < attr_value.list().s_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret, PrintString(attr_value.list().s(i)));
        }
      } else if (attr_value.list().i_size() > 0) {
        for (int i = 0; i < attr_value.list().i_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret, attr_value.list().i(i));
        }
      } else if (attr_value.list().f_size() > 0) {
        for (int i = 0; i < attr_value.list().f_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          // Same whole-number / "f"-suffix handling as the scalar kF
          // case above.
          const float f = attr_value.list().f(i);
          strings::StrAppend(&ret, f, floorf(f) == f ? ".0" : "", "f");
        }
      } else if (attr_value.list().b_size() > 0) {
        for (int i = 0; i < attr_value.list().b_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false");
        }
      } else if (attr_value.list().type_size() > 0) {
        for (int i = 0; i < attr_value.list().type_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret,
                             EnumName_DataType(attr_value.list().type(i)));
        }
      } else if (attr_value.list().shape_size() > 0) {
        for (int i = 0; i < attr_value.list().shape_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret,
                             PrintTensorShape(attr_value.list().shape(i)));
        }
      } else if (attr_value.list().tensor_size() > 0) {
        for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
          if (i > 0) strings::StrAppend(&ret, ", ");
          strings::StrAppend(&ret,
                             PrintTensorProto(attr_value.list().tensor(i)));
        }
      }
      strings::StrAppend(&ret, "}");
      return ret;
    }
    default:
      LOG(FATAL) << "Unsupported Attr type: " << op << " "
                 << attr_value.value_case();
  }
  return "<Unknown AttrValue type>";  // Prevent missing return warning
}
275
IsEmptyList(const AttrValue::ListValue & list)276 bool IsEmptyList(const AttrValue::ListValue& list) {
277 return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 &&
278 list.b_size() == 0 && list.type_size() == 0 &&
279 list.shape_size() == 0 && list.tensor_size() == 0;
280 }
281
// Converts a snake_case name to CamelCase: underscores are dropped and
// the character following each (and the first character) is
// uppercased.
string ToCamelCase(const string& str) {
  string result;
  bool capitalize_next = true;
  for (const char c : str) {
    if (c == '_') {
      capitalize_next = true;
    } else if (capitalize_next) {
      result += toupper(c);
      capitalize_next = false;
    } else {
      result += c;
    }
  }
  return result;
}
300
// Returns a <string, bool> pair. The string is the C++ type name to be used for
// attr_type when defining an object of that type. The bool is a flag to
// indicate whether to treat the type as const when accepting the C++ type as an
// argument to a function.
std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
  // Keyed by the OpDef attr type string; lazily-initialized, leaked map.
  static const auto* attr_type_map =
      new std::unordered_map<StringPiece, std::pair<const char*, bool>,
                             StringPieceHasher>{
          {"string", {"StringPiece", false}},
          {"list(string)", {"gtl::ArraySlice<string>", true}},
          {"int", {"int64", false}},
          {"list(int)", {"gtl::ArraySlice<int>", true}},
          {"float", {"float", false}},
          {"list(float)", {"gtl::ArraySlice<float>", true}},
          {"bool", {"bool", false}},
          {"list(bool)", {"gtl::ArraySlice<bool>", true}},
          {"type", {"DataType", false}},
          {"list(type)", {"DataTypeSlice", true}},
          {"shape", {"PartialTensorShape", false}},
          {"list(shape)", {"gtl::ArraySlice<PartialTensorShape>", true}},
          {"tensor", {"TensorProto", true}},
          {"list(tensor)", {"gtl::ArraySlice<TensorProto>", true}},
          {"func", {"NameAttrList", true}},
          {"list(func)", {"gtl::ArraySlice<NameAttrList>", true}},
      };

  auto entry = attr_type_map->find(attr_type);
  if (entry == attr_type_map->end()) {
    LOG(FATAL) << "Unsupported Attr type: " << attr_type;
    return {"", false};  // Unreachable; placates the compiler.
  }
  return entry->second;
}
334
// Returns the C++ element type used for static default-value storage of
// a list-valued attr type. LOG(FATAL)s for non-list (or unsupported)
// attr types.
const char* ListElementTypeName(StringPiece attr_type) {
  static const auto* attr_list_type_map =
      new std::unordered_map<StringPiece, const char*, StringPieceHasher>{
          {"list(string)", "string"},
          {"list(int)", "int"},
          {"list(float)", "float"},
          {"list(bool)", "bool"},
          {"list(type)", "DataType"},
          {"list(shape)", "PartialTensorShape"},
          {"list(tensor)", "TensorProto"},
      };

  auto entry = attr_list_type_map->find(attr_type);
  if (entry == attr_list_type_map->end()) {
    LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type;
    return "";  // Unreachable; placates the compiler.
  }
  return entry->second;
}
354
// Returns true if `name` must not be used as a C++ identifier in the
// generated code -- either a C++ keyword or a local name the generated
// op constructors already use.
bool IsCPPKeyword(StringPiece name) {
  static const std::unordered_set<StringPiece, StringPieceHasher>
      // Keywords obtained from http://en.cppreference.com/w/cpp/keyword
      kCPPReserved{
          "alignas",
          "alignof",
          "and",
          "and_eq",
          "asm",
          "atomic_cancel",
          "atomic_commit",
          "atomic_noexcept",
          "auto",
          "bitand",
          "bitor",
          "bool",
          "break",
          "case",
          "catch",
          "char",
          "char16_t",
          "char32_t",
          "class",
          "compl",
          "concept",
          "const",
          "const_cast",
          "constexpr",
          "continue",
          "decltype",
          "default",
          "delete",
          "do",
          "double",
          "dynamic_cast",
          "else",
          "enum",
          "explicit",
          "export",
          "extern",
          "false",
          "final",
          "float",
          "for",
          "friend",
          "goto",
          "if",
          "import",
          "inline",
          "int",
          "long",
          "module",
          "mutable",
          "namespace",
          "new",
          "noexcept",
          "not",
          "not_eq",
          "nullptr",
          "operator",
          "or",
          "or_eq",
          "override",
          "private",
          "protected",
          "public",
          "register",
          "reinterpret_cast",
          "requires",
          "return",
          "short",
          "signed",
          "sizeof",
          "static",
          "static_assert",
          "static_cast",
          "struct",
          "switch",
          "synchronized",
          "template",
          "this",
          "thread_local",
          "throw",
          "true",
          "try",
          "typedef",
          "typeid",
          "typename",
          "union",
          "unsigned",
          "using",
          "virtual",
          "void",
          "volatile",
          "wchar_t",
          "while",
          "xor",
          "xor_eq",

          // The following are not C++ keywords, but names of local variables
          // and parameters used in the op constructor. Treating them as
          // keywords, so that other parameter names don't conflict with these.
          "builder",
          "node",
          "ret",
          "scope",
          "unique_name",
      };
  return kCPPReserved.count(name) > 0;
}
465
AvoidCPPKeywords(StringPiece name)466 string AvoidCPPKeywords(StringPiece name) {
467 if (IsCPPKeyword(name)) {
468 return strings::StrCat(name, "_");
469 }
470 return string(name);
471 }
472
InferArgAttributes(const OpDef::ArgDef & arg,std::unordered_map<string,string> * inferred_attrs)473 void InferArgAttributes(const OpDef::ArgDef& arg,
474 std::unordered_map<string, string>* inferred_attrs) {
475 if (!arg.type_attr().empty()) {
476 gtl::InsertIfNotPresent(inferred_attrs, arg.type_attr(), arg.name());
477 } else if (!arg.type_list_attr().empty()) {
478 gtl::InsertIfNotPresent(inferred_attrs, arg.type_list_attr(), arg.name());
479 }
480 if (!arg.number_attr().empty()) {
481 gtl::InsertIfNotPresent(inferred_attrs, arg.number_attr(), arg.name());
482 }
483 }
484
InferOpAttributes(const OpDef & op_def,std::unordered_map<string,string> * inferred_input_attrs)485 void InferOpAttributes(
486 const OpDef& op_def,
487 std::unordered_map<string, string>* inferred_input_attrs) {
488 for (int i = 0; i < op_def.input_arg_size(); ++i) {
489 const auto& arg(op_def.input_arg(i));
490 InferArgAttributes(arg, inferred_input_attrs);
491 }
492 }
493
ArgIsList(const OpDef::ArgDef & arg)494 bool ArgIsList(const OpDef::ArgDef& arg) {
495 return !arg.type_list_attr().empty() || !arg.number_attr().empty();
496 }
497
HasOptionalAttrs(const ApiDef & api_def,const std::unordered_map<string,string> & inferred_input_attrs)498 bool HasOptionalAttrs(
499 const ApiDef& api_def,
500 const std::unordered_map<string, string>& inferred_input_attrs) {
501 for (int i = 0; i < api_def.attr_size(); ++i) {
502 const auto& attr(api_def.attr(i));
503 if ((inferred_input_attrs.find(attr.name()) ==
504 inferred_input_attrs.end()) &&
505 attr.has_default_value()) {
506 return true;
507 }
508 }
509 return false;
510 }
511
// Gathers everything needed to emit the C++ wrapper class for one op.
struct OpInfo {
  // graph_op_def: The OpDef used by the runtime, has the names that
  //   must be used when calling NodeBuilder.
  // interface_op_def: The OpDef used in the interface in the generated
  //   code, with possibly overridden names and defaults.
  explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
                  const std::vector<string>& aliases);
  // Returns the nested "Attrs" struct (with chained setters), or "" if
  // the op has no optional attrs.
  string GetOpAttrStruct() const;
  // Returns the wrapper-class constructor declaration; the Attrs
  // overload is produced when include_attr is true.
  string GetConstructorDecl(StringPiece op_name_prefix,
                            bool include_attr) const;
  // Writes the class declaration (plus alias typedefs) to the header.
  void WriteClassDecl(WritableFile* h) const;
  // Appends code that binds the class's output members to the node.
  void GetOutput(string* out) const;
  string GetConstructorBody() const;
  void WriteClassDef(WritableFile* cc) const;

  string op_name;                  // Generated C++ class name.
  std::vector<string> arg_types;   // Constructor parameter types,
  std::vector<string> arg_names;   //   parallel to arg_names ([0] is scope).
  std::vector<string> output_types;  // Output member types,
  std::vector<string> output_names;  //   parallel to output_names.
  std::vector<bool> is_list_output;  // Whether each output is an OutputList.
  bool has_optional_attrs;
  string comment;  // "///" doc comment emitted above the class.

  const OpDef& graph_op_def;
  const ApiDef& api_def;
  const std::vector<string>& aliases;
  // Map from type attribute to corresponding original argument name.
  std::unordered_map<string, string> inferred_input_attrs;
};
542
// Builds the per-op metadata: constructor argument lists, output member
// lists, and the full "///" class comment (summary, deprecation notice,
// arguments, optional attrs, returns, aliases).
OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
               const std::vector<string>& aliases)
    : graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
  op_name = api_def.endpoint(0).name();
  InferOpAttributes(graph_op_def, &inferred_input_attrs);
  has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
  // Every generated constructor takes the Scope as its first argument.
  arg_types.push_back("const ::tensorflow::Scope&");
  arg_names.push_back("scope");

  if (graph_op_def.has_deprecation()) {
    if (!api_def.summary().empty()) {
      comment = strings::StrCat(api_def.summary(), "\n");
    }
    strings::StrAppend(&comment, "DEPRECATED at GraphDef version ",
                       graph_op_def.deprecation().version(), ":\n",
                       graph_op_def.deprecation().explanation(), ".\n");
  } else if (api_def.summary().empty()) {
    comment = "TODO: add doc.\n";
  } else {
    comment = strings::StrCat(api_def.summary(), "\n");
  }
  if (!api_def.description().empty()) {
    strings::StrAppend(&comment, "\n", api_def.description(), "\n");
  }
  strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n");

  // Process inputs
  for (int i = 0; i < api_def.arg_order_size(); ++i) {
    const auto& arg = *FindInputArg(api_def.arg_order(i), graph_op_def);
    const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def);
    arg_types.push_back(strings::StrCat(
        "::tensorflow::", ArgIsList(arg) ? "InputList" : "Input"));
    arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));

    // TODO(keveman): Include input type information.
    StringPiece description = api_def_arg.description();
    if (!description.empty()) {
      ConsumeEquals(&description);
      // NOTE(review): the ConsumeEquals result above is discarded and
      // the raw description is appended -- confirm whether the trimmed
      // `description` was intended here instead.
      strings::StrAppend(&comment, "* ",
                         AvoidCPPKeywords(api_def_arg.rename_to()), ": ",
                         api_def_arg.description(), "\n");
    }
  }

  // Process attrs
  string required_attrs_comment;
  string optional_attrs_comment;
  for (int i = 0; i < graph_op_def.attr_size(); ++i) {
    // ApiDef attributes must be in the same order as in OpDef since
    // we initialize ApiDef based on OpDef.
    const auto& attr(graph_op_def.attr(i));
    const auto& api_def_attr(api_def.attr(i));
    CHECK_EQ(attr.name(), api_def_attr.name());
    // Skip inferred arguments
    if (inferred_input_attrs.count(attr.name()) > 0) continue;

    const auto entry = AttrTypeName(attr.type());
    const auto attr_type_name = entry.first;
    const bool use_const = entry.second;
    string attr_name = AvoidCPPKeywords(api_def_attr.rename_to());

    string attr_comment;
    if (!api_def_attr.description().empty()) {
      // TODO(keveman): Word wrap and indent this, to handle multi-line
      // descriptions.
      strings::StrAppend(&attr_comment, "* ", attr_name, ": ",
                         api_def_attr.description(), "\n");
    }
    // Attrs with defaults become setters on the Attrs struct; required
    // attrs become constructor parameters.
    if (api_def_attr.has_default_value()) {
      strings::StrAppend(&optional_attrs_comment, attr_comment);
    } else {
      strings::StrAppend(&required_attrs_comment, attr_comment);
      arg_types.push_back(strings::StrCat(
          use_const ? "const " : "", attr_type_name, use_const ? "&" : ""));
      arg_names.push_back(attr_name);
    }
  }

  strings::StrAppend(&comment, required_attrs_comment);

  if (!optional_attrs_comment.empty()) {
    strings::StrAppend(&comment, "\nOptional attributes (see `Attrs`):\n");
    strings::StrAppend(&comment, optional_attrs_comment);
  }

  // Process outputs
  for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
    // ApiDef arguments must be in the same order as in OpDef since
    // we initialize ApiDef based on OpDef.
    const auto& arg = graph_op_def.output_arg(i);
    const auto& api_def_arg(api_def.out_arg(i));
    CHECK_EQ(arg.name(), api_def_arg.name());

    bool is_list = ArgIsList(arg);
    output_types.push_back(
        strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output"));
    output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
    is_list_output.push_back(is_list);
  }

  strings::StrAppend(&comment, "\nReturns:\n");
  if (graph_op_def.output_arg_size() == 0) {  // No outputs.
    strings::StrAppend(&comment, "* the created `Operation`\n");
  } else if (graph_op_def.output_arg_size() == 1) {  // One output
    if (is_list_output[0]) {
      strings::StrAppend(&comment, "* `OutputList`: ");
    } else {
      strings::StrAppend(&comment, "* `Output`: ");
    }
    if (api_def.out_arg(0).description().empty()) {
      strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(),
                         " tensor.\n");
    } else {
      // TODO(josh11b): Word wrap this.
      strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n");
    }
  } else {  // Multiple outputs.
    for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
      if (is_list_output[i]) {
        strings::StrAppend(&comment, "* `OutputList`");
      } else {
        strings::StrAppend(&comment, "* `Output`");
      }
      strings::StrAppend(&comment, " ", output_names[i]);
      if (api_def.out_arg(i).description().empty()) {
        strings::StrAppend(&comment, "\n");
      } else {
        // TODO(josh11b): Word wrap this.
        strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(),
                           "\n");
      }
    }
  }

  if (!aliases.empty()) {
    strings::StrAppend(&comment, "\nAliases:\n");
    for (const auto& alias : aliases) {
      strings::StrAppend(&comment, "* ", alias, "\n");
    }
  }
  // Convert the accumulated plain text into a "///" comment block.
  comment = MakeComment(comment, "");
}
685
// Emits the nested "Attrs" struct: one documented chainable setter and
// one defaulted field per optional attr, plus private static storage
// helpers for non-empty list defaults. Returns "" when the op has no
// optional attrs.
string OpInfo::GetOpAttrStruct() const {
  string struct_fields;
  string setters;
  string defaults_static_storage;

  for (int i = 0; i < graph_op_def.attr_size(); ++i) {
    const auto& attr(graph_op_def.attr(i));
    const auto& api_def_attr(api_def.attr(i));
    // If attr will be inferred or it doesn't have a default value, don't
    // add it to the struct.
    if ((inferred_input_attrs.find(attr.name()) !=
         inferred_input_attrs.end()) ||
        !api_def_attr.has_default_value()) {
      continue;
    }
    const auto entry = AttrTypeName(attr.type());
    const auto attr_type_name = entry.first;
    const bool use_const = entry.second;
    const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
    // Avoid a setter name that collides with the op class or the Attrs
    // struct itself.
    const string suffix =
        (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
    const string attr_func_def =
        strings::StrCat(camel_case_name, suffix, "(", use_const ? "const " : "",
                        attr_type_name, use_const ? "&" : "");

    string attr_comment;
    if (!api_def_attr.description().empty()) {
      strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n");
    }
    strings::StrAppend(&attr_comment, "Defaults to ",
                       SummarizeAttrValue(api_def_attr.default_value()), "\n");
    attr_comment = MakeComment(attr_comment, "    ");

    // Setters return a modified copy so calls can be chained.
    strings::StrAppend(&setters, attr_comment);
    strings::StrAppend(&setters, "    TF_MUST_USE_RESULT Attrs ", attr_func_def,
                       " x) {\n");
    strings::StrAppend(&setters, "      Attrs ret = *this;\n");
    strings::StrAppend(&setters, "      ret.", api_def_attr.rename_to(),
                       "_ = x;\n");
    strings::StrAppend(&setters, "      return ret;\n    }\n\n");

    string field_initiliazer;
    auto& default_value = api_def_attr.default_value();
    if (default_value.value_case() == AttrValue::kList &&
        !IsEmptyList(default_value.list())) {
      // Non-empty lists need static storage for their defaults. Define a
      // function with static local variable that stores the array.
      strings::StrAppend(&defaults_static_storage, "    static ",
                         attr_type_name, " Default_", api_def_attr.rename_to(),
                         "() {\n");
      strings::StrAppend(
          &defaults_static_storage, "      static const ",
          ListElementTypeName(attr.type()), " kStorage[] = ",
          PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
          ";\n");
      strings::StrAppend(&defaults_static_storage, "      return ",
                         attr_type_name, "(kStorage);\n    }\n");
      // Set the field_initializer to call the defined function.
      strings::StrAppend(&field_initiliazer, "Default_",
                         api_def_attr.rename_to(), "()");
    } else {
      field_initiliazer =
          PrintAttrValue(graph_op_def.name(), api_def_attr.default_value());
    }
    strings::StrAppend(&struct_fields, "    ", attr_type_name, " ",
                       api_def_attr.rename_to(), "_ = ", field_initiliazer,
                       ";\n");
  }

  // No optional attrs: the class gets no Attrs struct at all.
  if (struct_fields.empty()) {
    return "";
  }

  string attrs_comment =
      strings::StrCat("Optional attribute setters for ", op_name, "\n");
  string struct_decl = MakeComment(attrs_comment, "  ");
  strings::StrAppend(&struct_decl, "  struct Attrs {\n");
  strings::StrAppend(&struct_decl, setters, struct_fields);
  if (!defaults_static_storage.empty()) {
    strings::StrAppend(&struct_decl, "  private:\n", defaults_static_storage);
  }
  strings::StrAppend(&struct_decl, "  };\n");

  return struct_decl;
}
771
GetConstructorDecl(StringPiece op_name_prefix,bool include_attr) const772 string OpInfo::GetConstructorDecl(StringPiece op_name_prefix,
773 bool include_attr) const {
774 const string prefix = strings::StrCat(op_name_prefix, op_name, "(");
775 string c_decl;
776 for (int i = 0; i < arg_types.size(); ++i) {
777 if (i > 0) strings::StrAppend(&c_decl, ", ");
778 strings::StrAppend(&c_decl, arg_types[i], " ", arg_names[i]);
779 }
780 if (include_attr && has_optional_attrs) {
781 strings::StrAppend(&c_decl, ", const ", op_name, "::Attrs& attrs");
782 }
783 strings::StrAppend(&c_decl, ")");
784 return WordWrap(prefix, c_decl, kRightMargin);
785 }
786
// Writes the full class declaration for the op to the header file:
// doc comment, Attrs struct, constructors, conversion helpers, static
// attr setters, public members, and alias typedefs.
void OpInfo::WriteClassDecl(WritableFile* h) const {
  string class_decl = comment;
  strings::StrAppend(&class_decl, "class ", op_name, " {\n");
  strings::StrAppend(&class_decl, " public:\n");
  if (has_optional_attrs) {
    strings::StrAppend(&class_decl, GetOpAttrStruct());
  }
  // Constructor without attrs; plus an overload taking Attrs when the
  // op has optional attrs.
  strings::StrAppend(&class_decl, "  ",
                     GetConstructorDecl("", /* include_attr */ false), ";\n");
  if (has_optional_attrs) {
    strings::StrAppend(&class_decl, "  ",
                       GetConstructorDecl("", /* include_attr */ true), ";\n");
  }
  if (output_types.empty()) {
    // Allow casting this class to Operation.
    strings::StrAppend(&class_decl,
                       "  operator ::tensorflow::Operation() const { "
                       "return operation; }\n");
  } else if (output_types.size() == 1) {
    if (is_list_output[0]) {
      // Write the subscript operator, allowing out[i] for the list-typed
      // output.
      strings::StrAppend(&class_decl,
                         "  ::tensorflow::Output operator[](size_t index) "
                         "const { return ",
                         output_names[0], "[index]; }\n\n");

    } else {
      // Write type cast functions, allowing casting this class to Input and
      // Output.
      strings::StrAppend(&class_decl,
                         "  operator ::tensorflow::Output() const { return ",
                         output_names[0], "; }\n");
      strings::StrAppend(&class_decl,
                         "  operator ::tensorflow::Input() const { return ",
                         output_names[0], "; }\n");
      // Write node() to get the Node* directly.
      strings::StrAppend(&class_decl,
                         "  ::tensorflow::Node* node() const { return ",
                         output_names[0], ".node(); }\n");
    }
  }
  // Add the static functions to set optional attrs
  if (has_optional_attrs) {
    strings::StrAppend(&class_decl, "\n");
    for (int i = 0; i < graph_op_def.attr_size(); ++i) {
      const auto& attr(graph_op_def.attr(i));
      const auto& api_def_attr(api_def.attr(i));
      // Same skip logic as GetOpAttrStruct: inferred or non-defaulted
      // attrs get no static setter.
      if ((inferred_input_attrs.find(attr.name()) !=
           inferred_input_attrs.end()) ||
          !api_def_attr.has_default_value()) {
        continue;
      }
      const auto entry = AttrTypeName(attr.type());
      const auto attr_type_name = entry.first;
      const bool use_const = entry.second;
      const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
      const string suffix =
          (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
      const string attr_func_def = strings::StrCat(
          camel_case_name, suffix, "(", use_const ? "const " : "",
          attr_type_name, use_const ? "&" : "");
      // static Attrs Foo(x) { return Attrs().Foo(x); } -- lets callers
      // start an Attrs chain from the op class.
      strings::StrAppend(&class_decl, "  static Attrs ", attr_func_def,
                         " x) {\n");
      strings::StrAppend(&class_decl, "    return Attrs().", camel_case_name,
                         suffix, "(x);\n");
      strings::StrAppend(&class_decl, "  }\n");
    }
  }

  strings::StrAppend(&class_decl, "\n  Operation operation;\n");
  for (int i = 0; i < output_types.size(); ++i) {
    strings::StrAppend(&class_decl, "  ", output_types[i], " ", output_names[i],
                       ";\n");
  }

  strings::StrAppend(&class_decl, "};\n");
  if (!aliases.empty()) {
    for (const auto& alias : aliases) {
      strings::StrAppend(&class_decl, "typedef ", op_name, " ", alias, ";\n");
    }
  }
  strings::StrAppend(&class_decl, "\n");
  TF_CHECK_OK(h->Append(class_decl));
}
872
GetOutput(string * out) const873 void OpInfo::GetOutput(string* out) const {
874 const string scope_str = arg_names[0];
875 string return_on_error =
876 strings::StrCat("if (!", scope_str, ".ok()) return;");
877
878 strings::StrAppend(out, " this->operation = Operation(ret);\n");
879
880 // No outputs.
881 if (graph_op_def.output_arg_size() == 0) {
882 strings::StrAppend(out, " return;\n");
883 return;
884 }
885 if (graph_op_def.output_arg_size() == 1) {
886 // One output, no need for NameRangeMap
887 if (is_list_output[0]) {
888 strings::StrAppend(out,
889 " for (int32 i = 0; i < ret->num_outputs(); ++i)\n");
890 strings::StrAppend(out, " this->", output_names[0],
891 ".push_back(Output(ret, i));\n");
892 } else {
893 strings::StrAppend(out, " this->", output_names[0],
894 " = Output(ret, 0);\n");
895 }
896 return;
897 }
898 strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n");
899 strings::StrAppend(out,
900 " ::tensorflow::Status _status_ = "
901 "::tensorflow::NameRangesForNode(*ret, ret->op_def(), "
902 "nullptr, &_outputs_range);\n");
903 strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str,
904 ".UpdateStatus(_status_);\n", " return;\n");
905 strings::StrAppend(out, " }\n\n");
906
907 for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
908 const string arg_range = strings::StrCat(
909 "_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]");
910 if (is_list_output[i]) {
911 strings::StrAppend(out, " for (int32 i = ", arg_range, ".first; i < ",
912 arg_range, ".second; ++i)\n");
913 strings::StrAppend(out, " this->", output_names[i],
914 ".push_back(Output(ret, i));\n");
915 } else {
916 strings::StrAppend(out, " this->", output_names[i], " = Output(ret, ",
917 arg_range, ".first);\n");
918 }
919 }
920 }
921
// Returns the body of the generated op class's constructor: converts each
// input to a NodeOut(List), chains a NodeBuilder with inputs and attrs,
// finalizes it into a Node*, runs shape inference, and finally populates
// the output members via GetOutput().
string OpInfo::GetConstructorBody() const {
  // arg_names[0] is the Scope argument of the generated constructor.
  const string scope_str = arg_names[0];

  string body;
  // Emitted after every fallible step so the generated constructor returns
  // early once the scope carries an error status.
  string return_on_error =
      strings::StrCat("if (!", scope_str, ".ok()) return;");

  strings::StrAppend(&body, " ", return_on_error, "\n");

  // Convert each input argument into a local `_<name>` NodeOut/NodeOutList.
  for (int i = 0; i < graph_op_def.input_arg_size(); ++i) {
    const auto& arg(graph_op_def.input_arg(i));
    const auto& api_def_arg(api_def.in_arg(i));
    strings::StrAppend(
        &body, " auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::",
        ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ",
        AvoidCPPKeywords(api_def_arg.rename_to()), ");\n");
    strings::StrAppend(&body, " ", return_on_error, "\n");
  }

  strings::StrAppend(&body, " ::tensorflow::Node* ret;\n");
  strings::StrAppend(&body, " const auto unique_name = ", scope_str,
                     ".GetUniqueNameForOp(\"", op_name, "\");\n");
  strings::StrAppend(
      &body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
      graph_op_def.name(), "\")\n");
  // Indentation prefix for the chained .Input()/.Attr() calls below.
  const string spaces = " ";
  for (int i = 0; i < api_def.in_arg_size(); ++i) {
    const auto& arg(api_def.in_arg(i));
    strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n");
  }
  // NOTE(review): this loop indexes graph_op_def.attr(i) with an index
  // bounded by api_def.attr_size() — it assumes the ApiDef attrs parallel
  // the OpDef attrs one-to-one and in order; verify against ApiDef
  // construction if either proto source changes.
  for (int i = 0; i < api_def.attr_size(); ++i) {
    const auto& graph_attr(graph_op_def.attr(i));
    const auto& api_def_attr(api_def.attr(i));
    // Attrs inferred from the inputs (e.g. element dtypes) are not settable
    // by the caller, so no .Attr() call is emitted for them.
    if (inferred_input_attrs.find(api_def_attr.name()) !=
        inferred_input_attrs.end()) {
      continue;
    }
    // Optional attrs (those with defaults) are read from the `attrs` struct
    // member (trailing underscore field); required attrs arrive as plain
    // constructor parameters.
    const string attr_name =
        api_def_attr.has_default_value()
            ? strings::StrCat("attrs.", api_def_attr.rename_to(), "_")
            : AvoidCPPKeywords(api_def_attr.rename_to());
    strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ",
                       attr_name, ")\n");
  }
  // Terminate the builder expression, then let the scope post-process it.
  strings::StrAppend(&body, " ;\n");
  strings::StrAppend(&body, " ", scope_str, ".UpdateBuilder(&builder);\n");
  strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(",
                     scope_str, ".graph(), &ret));\n");
  strings::StrAppend(&body, " ", return_on_error, "\n");
  strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str,
                     ".DoShapeInference(ret));\n");

  GetOutput(&body);
  return body;
}
977
WriteClassDef(WritableFile * cc) const978 void OpInfo::WriteClassDef(WritableFile* cc) const {
979 string class_def;
980 strings::StrAppend(&class_def,
981 GetConstructorDecl(strings::StrCat(op_name, "::"),
982 /* include_attr */ true),
983 " {\n");
984 strings::StrAppend(&class_def, GetConstructorBody());
985 strings::StrAppend(&class_def, "}\n\n");
986
987 if (has_optional_attrs) {
988 strings::StrAppend(&class_def,
989 GetConstructorDecl(strings::StrCat(op_name, "::"),
990 /* include_attr */ false));
991 strings::StrAppend(&class_def, "\n : ", op_name, "(");
992 int i = 0;
993 for (; i < arg_names.size(); ++i) {
994 if (i > 0) strings::StrAppend(&class_def, ", ");
995 strings::StrAppend(&class_def, arg_names[i]);
996 }
997 if (i > 0) strings::StrAppend(&class_def, ", ");
998 strings::StrAppend(&class_def, op_name, "::Attrs()");
999 strings::StrAppend(&class_def, ") {}\n\n");
1000 }
1001 TF_CHECK_OK(cc->Append(class_def));
1002 }
1003
// Generates the C++ wrapper for a single op: the class declaration (plus
// any alias typedefs) goes into the header file `h`, and the constructor
// definitions go into the source file `cc`.
void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def,
               const std::vector<string>& aliases, WritableFile* h,
               WritableFile* cc) {
  // OpInfo derives all rendering metadata (names, attrs, outputs) once.
  OpInfo op_info(graph_op_def, api_def, aliases);

  op_info.WriteClassDecl(h);
  op_info.WriteClassDef(cc);
}
1012
StartFiles(bool internal,const string & dot_h_fname,WritableFile * h,WritableFile * cc,string * op_header_guard)1013 void StartFiles(bool internal, const string& dot_h_fname, WritableFile* h,
1014 WritableFile* cc, string* op_header_guard) {
1015 const string header =
1016 R"header(// This file is MACHINE GENERATED! Do not edit.
1017
1018 #include "tensorflow/cc/framework/ops.h"
1019 #include "tensorflow/cc/framework/scope.h"
1020 #include "tensorflow/core/framework/tensor.h"
1021 #include "tensorflow/core/framework/tensor_shape.h"
1022 #include "tensorflow/core/framework/types.h"
1023 #include "tensorflow/core/lib/gtl/array_slice.h"
1024 )header";
1025
1026 // TODO(keveman): Make namespaces configurable.
1027 const string namespace_begin = internal ? R"namespace(
1028 namespace tensorflow {
1029 namespace ops {
1030 namespace internal {
1031 // NOTE: This namespace has internal TensorFlow details that
1032 // are not part of TensorFlow's public API.
1033
1034 )namespace"
1035 : R"namespace(
1036 namespace tensorflow {
1037 namespace ops {
1038
1039 )namespace";
1040
1041 const string op_header = GetPath(dot_h_fname);
1042 *op_header_guard = ToGuard(op_header);
1043 const string cc_header = strings::StrCat(
1044 R"include(// This file is MACHINE GENERATED! Do not edit.
1045
1046
1047 #include "tensorflow/cc/ops/const_op.h"
1048 )include",
1049 "#include \"", op_header, "\"\n", namespace_begin);
1050
1051 const string filename = GetFilename(dot_h_fname);
1052 const string doxygen = strings::StrCat("/// @defgroup ", filename, " ",
1053 ToTitle(filename), "\n", "/// @{\n\n");
1054
1055 TF_CHECK_OK(h->Append(
1056 strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
1057 "#ifndef ",
1058 *op_header_guard,
1059 "\n"
1060 "#define ",
1061 *op_header_guard, "\n\n")));
1062 TF_CHECK_OK(h->Append(header));
1063 TF_CHECK_OK(h->Append(namespace_begin));
1064 TF_CHECK_OK(h->Append(doxygen));
1065 TF_CHECK_OK(cc->Append(cc_header));
1066 }
1067
FinishFiles(bool internal,WritableFile * h,WritableFile * cc,const string & op_header_guard)1068 void FinishFiles(bool internal, WritableFile* h, WritableFile* cc,
1069 const string& op_header_guard) {
1070 const string footer = internal ? R"footer(} // namespace internal
1071 } // namespace ops
1072 } // namespace tensorflow
1073 )footer"
1074 :
1075 R"footer(/// @}
1076
1077 } // namespace ops
1078 } // namespace tensorflow
1079 )footer";
1080
1081 TF_CHECK_OK(h->Append(footer));
1082 TF_CHECK_OK(
1083 h->Append(strings::StrCat("\n#endif ", "// ", op_header_guard, "\n")));
1084 TF_CHECK_OK(cc->Append(footer));
1085
1086 TF_CHECK_OK(cc->Close());
1087 TF_CHECK_OK(h->Close());
1088 }
1089
// Derives the "internal" variant of a generated filename by inserting
// "_internal" before the extension, e.g. "ops.h" -> "ops_internal.h".
// A name without an extension simply gets "_internal" appended.
string MakeInternal(const string& fname) {
  const auto ext_pos = fname.rfind('.');
  if (ext_pos == string::npos) {
    return strings::StrCat(fname, "_internal");
  }
  return strings::StrCat(fname.substr(0, ext_pos), "_internal",
                         fname.substr(ext_pos));
}
1099
1100 } // namespace
1101
WriteCCOps(const OpList & ops,const ApiDefMap & api_def_map,const string & dot_h_fname,const string & dot_cc_fname)1102 void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
1103 const string& dot_h_fname, const string& dot_cc_fname) {
1104 Env* env = Env::Default();
1105
1106 // Write the initial boilerplate to the .h and .cc files.
1107 std::unique_ptr<WritableFile> h = nullptr;
1108 std::unique_ptr<WritableFile> cc = nullptr;
1109 TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
1110 TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
1111 string op_header_guard;
1112 StartFiles(false, dot_h_fname, h.get(), cc.get(), &op_header_guard);
1113
1114 // Create the internal versions of these files for the hidden ops.
1115 std::unique_ptr<WritableFile> internal_h = nullptr;
1116 std::unique_ptr<WritableFile> internal_cc = nullptr;
1117 const string internal_dot_h_fname = MakeInternal(dot_h_fname);
1118 TF_CHECK_OK(env->NewWritableFile(internal_dot_h_fname, &internal_h));
1119 TF_CHECK_OK(env->NewWritableFile(MakeInternal(dot_cc_fname), &internal_cc));
1120 string internal_op_header_guard;
1121 StartFiles(true /* internal */, internal_dot_h_fname, internal_h.get(),
1122 internal_cc.get(), &internal_op_header_guard);
1123
1124 for (const auto& graph_op_def : ops.op()) {
1125 // Skip deprecated ops.
1126 // TODO(josh11b): If needed, can put them into a "deprecated" namespace
1127 // instead of skipping.
1128 if (graph_op_def.has_deprecation() &&
1129 graph_op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
1130 continue;
1131 }
1132
1133 // We use a hand-written wrapper for "Const", since the generated
1134 // code depends on it.
1135 if (graph_op_def.name() == "Const") continue;
1136
1137 const auto* api_def = api_def_map.GetApiDef(graph_op_def.name());
1138
1139 std::vector<string> aliases;
1140 if (api_def->visibility() == ApiDef::SKIP) continue;
1141 // First endpoint is canonical, the rest are aliases.
1142 for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size();
1143 ++endpoint_i) {
1144 aliases.push_back(api_def->endpoint(endpoint_i).name());
1145 }
1146 if (api_def->visibility() == ApiDef::HIDDEN) {
1147 // Write hidden ops to _internal.h and _internal.cc.
1148 WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(),
1149 internal_cc.get());
1150 continue;
1151 }
1152 // This isn't a hidden op, write it to the main files.
1153 WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get());
1154 }
1155
1156 FinishFiles(false, h.get(), cc.get(), op_header_guard);
1157 FinishFiles(true /* internal */, internal_h.get(), internal_cc.get(),
1158 internal_op_header_guard);
1159 }
1160
1161 } // namespace tensorflow
1162