1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_gen_lib.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/attr_value.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/gtl/map_util.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 #include "tensorflow/core/lib/strings/strcat.h"
24 #include "tensorflow/core/platform/protobuf.h"
25
26 namespace tensorflow {
27
WordWrap(StringPiece prefix,StringPiece str,int width)28 string WordWrap(StringPiece prefix, StringPiece str, int width) {
29 const string indent_next_line = "\n" + Spaces(prefix.size());
30 width -= prefix.size();
31 string result;
32 strings::StrAppend(&result, prefix);
33
34 while (!str.empty()) {
35 if (static_cast<int>(str.size()) <= width) {
36 // Remaining text fits on one line.
37 strings::StrAppend(&result, str);
38 break;
39 }
40 auto space = str.rfind(' ', width);
41 if (space == StringPiece::npos) {
42 // Rather make a too-long line and break at a space.
43 space = str.find(' ');
44 if (space == StringPiece::npos) {
45 strings::StrAppend(&result, str);
46 break;
47 }
48 }
49 // Breaking at character at position <space>.
50 StringPiece to_append = str.substr(0, space);
51 str.remove_prefix(space + 1);
52 // Remove spaces at break.
53 while (to_append.ends_with(" ")) {
54 to_append.remove_suffix(1);
55 }
56 while (str.Consume(" ")) {
57 }
58
59 // Go on to the next line.
60 strings::StrAppend(&result, to_append);
61 if (!str.empty()) strings::StrAppend(&result, indent_next_line);
62 }
63
64 return result;
65 }
66
ConsumeEquals(StringPiece * description)67 bool ConsumeEquals(StringPiece* description) {
68 if (description->Consume("=")) {
69 while (description->Consume(" ")) { // Also remove spaces after "=".
70 }
71 return true;
72 }
73 return false;
74 }
75
76 // Split `*orig` into two pieces at the first occurrence of `split_ch`.
77 // Returns whether `split_ch` was found. Afterwards, `*before_split`
78 // contains the maximum prefix of the input `*orig` that doesn't
79 // contain `split_ch`, and `*orig` contains everything after the
80 // first `split_ch`.
SplitAt(char split_ch,StringPiece * orig,StringPiece * before_split)81 static bool SplitAt(char split_ch, StringPiece* orig,
82 StringPiece* before_split) {
83 auto pos = orig->find(split_ch);
84 if (pos == StringPiece::npos) {
85 *before_split = *orig;
86 *orig = StringPiece();
87 return false;
88 } else {
89 *before_split = orig->substr(0, pos);
90 orig->remove_prefix(pos + 1);
91 return true;
92 }
93 }
94
95 // Does this line start with "<spaces><field>:" where "<field>" is
96 // in multi_line_fields? Sets *colon_pos to the position of the colon.
StartsWithFieldName(StringPiece line,const std::vector<string> & multi_line_fields)97 static bool StartsWithFieldName(StringPiece line,
98 const std::vector<string>& multi_line_fields) {
99 StringPiece up_to_colon;
100 if (!SplitAt(':', &line, &up_to_colon)) return false;
101 while (up_to_colon.Consume(" "))
102 ; // Remove leading spaces.
103 for (const auto& field : multi_line_fields) {
104 if (up_to_colon == field) {
105 return true;
106 }
107 }
108 return false;
109 }
110
ConvertLine(StringPiece line,const std::vector<string> & multi_line_fields,string * ml)111 static bool ConvertLine(StringPiece line,
112 const std::vector<string>& multi_line_fields,
113 string* ml) {
114 // Is this a field we should convert?
115 if (!StartsWithFieldName(line, multi_line_fields)) {
116 return false;
117 }
118 // Has a matching field name, so look for "..." after the colon.
119 StringPiece up_to_colon;
120 StringPiece after_colon = line;
121 SplitAt(':', &after_colon, &up_to_colon);
122 while (after_colon.Consume(" "))
123 ; // Remove leading spaces.
124 if (!after_colon.Consume("\"")) {
125 // We only convert string fields, so don't convert this line.
126 return false;
127 }
128 auto last_quote = after_colon.rfind('\"');
129 if (last_quote == StringPiece::npos) {
130 // Error: we don't see the expected matching quote, abort the conversion.
131 return false;
132 }
133 StringPiece escaped = after_colon.substr(0, last_quote);
134 StringPiece suffix = after_colon.substr(last_quote + 1);
135 // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
136
137 string unescaped;
138 if (!str_util::CUnescape(escaped, &unescaped, nullptr)) {
139 // Error unescaping, abort the conversion.
140 return false;
141 }
142 // No more errors possible at this point.
143
144 // Find a string to mark the end that isn't in unescaped.
145 string end = "END";
146 for (int s = 0; unescaped.find(end) != string::npos; ++s) {
147 end = strings::StrCat("END", s);
148 }
149
150 // Actually start writing the converted output.
151 strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end);
152 if (!suffix.empty()) {
153 // Output suffix, in case there was a trailing comment in the source.
154 strings::StrAppend(ml, suffix);
155 }
156 strings::StrAppend(ml, "\n");
157 return true;
158 }
159
PBTxtToMultiline(StringPiece pbtxt,const std::vector<string> & multi_line_fields)160 string PBTxtToMultiline(StringPiece pbtxt,
161 const std::vector<string>& multi_line_fields) {
162 string ml;
163 // Probably big enough, since the input and output are about the
164 // same size, but just a guess.
165 ml.reserve(pbtxt.size() * (17. / 16));
166 StringPiece line;
167 while (!pbtxt.empty()) {
168 // Split pbtxt into its first line and everything after.
169 SplitAt('\n', &pbtxt, &line);
170 // Convert line or output it unchanged
171 if (!ConvertLine(line, multi_line_fields, &ml)) {
172 strings::StrAppend(&ml, line, "\n");
173 }
174 }
175 return ml;
176 }
177
178 // Given a single line of text `line` with first : at `colon`, determine if
179 // there is an "<<END" expression after the colon and if so return true and set
180 // `*end` to everything after the "<<".
FindMultiline(StringPiece line,size_t colon,string * end)181 static bool FindMultiline(StringPiece line, size_t colon, string* end) {
182 if (colon == StringPiece::npos) return false;
183 line.remove_prefix(colon + 1);
184 while (line.Consume(" ")) {
185 }
186 if (line.Consume("<<")) {
187 *end = line.ToString();
188 return true;
189 }
190 return false;
191 }
192
PBTxtFromMultiline(StringPiece multiline_pbtxt)193 string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
194 string pbtxt;
195 // Probably big enough, since the input and output are about the
196 // same size, but just a guess.
197 pbtxt.reserve(multiline_pbtxt.size() * (33. / 32));
198 StringPiece line;
199 while (!multiline_pbtxt.empty()) {
200 // Split multiline_pbtxt into its first line and everything after.
201 if (!SplitAt('\n', &multiline_pbtxt, &line)) {
202 strings::StrAppend(&pbtxt, line);
203 break;
204 }
205
206 string end;
207 auto colon = line.find(':');
208 if (!FindMultiline(line, colon, &end)) {
209 // Normal case: not a multi-line string, just output the line as-is.
210 strings::StrAppend(&pbtxt, line, "\n");
211 continue;
212 }
213
214 // Multi-line case:
215 // something: <<END
216 // xx
217 // yy
218 // END
219 // Should be converted to:
220 // something: "xx\nyy"
221
222 // Output everything up to the colon (" something:").
223 strings::StrAppend(&pbtxt, line.substr(0, colon + 1));
224
225 // Add every line to unescaped until we see the "END" string.
226 string unescaped;
227 bool first = true;
228 string suffix;
229 while (!multiline_pbtxt.empty()) {
230 SplitAt('\n', &multiline_pbtxt, &line);
231 if (line.Consume(end)) break;
232 if (first) {
233 first = false;
234 } else {
235 unescaped.push_back('\n');
236 }
237 strings::StrAppend(&unescaped, line);
238 line = StringPiece();
239 }
240
241 // Escape what we extracted and then output it in quotes.
242 strings::StrAppend(&pbtxt, " \"", str_util::CEscape(unescaped), "\"", line,
243 "\n");
244 }
245 return pbtxt;
246 }
247
// Replaces every non-overlapping occurrence of `from` in `*s` with `to`,
// scanning left to right. Text inserted by a replacement is never re-scanned.
//
// An empty `from` matches nothing and leaves `*s` unchanged. Without this
// guard the search would succeed at the same position forever and never
// terminate.
static void StringReplace(const string& from, const string& to, string* s) {
  if (from.empty()) {
    return;
  }

  string replaced;
  replaced.reserve(s->size());

  string::size_type pos = 0;
  for (;;) {
    const string::size_type found = s->find(from, pos);
    if (found == string::npos) {
      break;
    }
    // Copy the untouched text before the match, then the replacement.
    replaced.append(*s, pos, found - pos);
    replaced.append(to);
    pos = found + from.size();
  }
  // Copy whatever follows the last match (possibly nothing).
  replaced.append(*s, pos, string::npos);

  s->swap(replaced);
}
268
RenameInDocs(const string & from,const string & to,ApiDef * api_def)269 static void RenameInDocs(const string& from, const string& to,
270 ApiDef* api_def) {
271 const string from_quoted = strings::StrCat("`", from, "`");
272 const string to_quoted = strings::StrCat("`", to, "`");
273 for (int i = 0; i < api_def->in_arg_size(); ++i) {
274 if (!api_def->in_arg(i).description().empty()) {
275 StringReplace(from_quoted, to_quoted,
276 api_def->mutable_in_arg(i)->mutable_description());
277 }
278 }
279 for (int i = 0; i < api_def->out_arg_size(); ++i) {
280 if (!api_def->out_arg(i).description().empty()) {
281 StringReplace(from_quoted, to_quoted,
282 api_def->mutable_out_arg(i)->mutable_description());
283 }
284 }
285 for (int i = 0; i < api_def->attr_size(); ++i) {
286 if (!api_def->attr(i).description().empty()) {
287 StringReplace(from_quoted, to_quoted,
288 api_def->mutable_attr(i)->mutable_description());
289 }
290 }
291 if (!api_def->summary().empty()) {
292 StringReplace(from_quoted, to_quoted, api_def->mutable_summary());
293 }
294 if (!api_def->description().empty()) {
295 StringReplace(from_quoted, to_quoted, api_def->mutable_description());
296 }
297 }
298
299 namespace {
300
301 // Initializes given ApiDef with data in OpDef.
InitApiDefFromOpDef(const OpDef & op_def,ApiDef * api_def)302 void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
303 api_def->set_graph_op_name(op_def.name());
304 api_def->set_visibility(ApiDef::VISIBLE);
305
306 auto* endpoint = api_def->add_endpoint();
307 endpoint->set_name(op_def.name());
308 if (op_def.has_deprecation()) {
309 endpoint->set_deprecation_version(op_def.deprecation().version());
310 }
311
312 for (const auto& op_in_arg : op_def.input_arg()) {
313 auto* api_in_arg = api_def->add_in_arg();
314 api_in_arg->set_name(op_in_arg.name());
315 api_in_arg->set_rename_to(op_in_arg.name());
316 api_in_arg->set_description(op_in_arg.description());
317
318 *api_def->add_arg_order() = op_in_arg.name();
319 }
320 for (const auto& op_out_arg : op_def.output_arg()) {
321 auto* api_out_arg = api_def->add_out_arg();
322 api_out_arg->set_name(op_out_arg.name());
323 api_out_arg->set_rename_to(op_out_arg.name());
324 api_out_arg->set_description(op_out_arg.description());
325 }
326 for (const auto& op_attr : op_def.attr()) {
327 auto* api_attr = api_def->add_attr();
328 api_attr->set_name(op_attr.name());
329 api_attr->set_rename_to(op_attr.name());
330 if (op_attr.has_default_value()) {
331 *api_attr->mutable_default_value() = op_attr.default_value();
332 }
333 api_attr->set_description(op_attr.description());
334 }
335 api_def->set_summary(op_def.summary());
336 api_def->set_description(op_def.description());
337 }
338
339 // Updates base_arg based on overrides in new_arg.
MergeArg(ApiDef::Arg * base_arg,const ApiDef::Arg & new_arg)340 void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) {
341 if (!new_arg.rename_to().empty()) {
342 base_arg->set_rename_to(new_arg.rename_to());
343 }
344 if (!new_arg.description().empty()) {
345 base_arg->set_description(new_arg.description());
346 }
347 }
348
349 // Updates base_attr based on overrides in new_attr.
MergeAttr(ApiDef::Attr * base_attr,const ApiDef::Attr & new_attr)350 void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) {
351 if (!new_attr.rename_to().empty()) {
352 base_attr->set_rename_to(new_attr.rename_to());
353 }
354 if (new_attr.has_default_value()) {
355 *base_attr->mutable_default_value() = new_attr.default_value();
356 }
357 if (!new_attr.description().empty()) {
358 base_attr->set_description(new_attr.description());
359 }
360 }
361
362 // Updates base_api_def based on overrides in new_api_def.
MergeApiDefs(ApiDef * base_api_def,const ApiDef & new_api_def)363 Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
364 // Merge visibility
365 if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) {
366 base_api_def->set_visibility(new_api_def.visibility());
367 }
368 // Merge endpoints
369 if (new_api_def.endpoint_size() > 0) {
370 base_api_def->clear_endpoint();
371 std::copy(
372 new_api_def.endpoint().begin(), new_api_def.endpoint().end(),
373 protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint()));
374 }
375 // Merge args
376 for (const auto& new_arg : new_api_def.in_arg()) {
377 bool found_base_arg = false;
378 for (int i = 0; i < base_api_def->in_arg_size(); ++i) {
379 auto* base_arg = base_api_def->mutable_in_arg(i);
380 if (base_arg->name() == new_arg.name()) {
381 MergeArg(base_arg, new_arg);
382 found_base_arg = true;
383 break;
384 }
385 }
386 if (!found_base_arg) {
387 return errors::FailedPrecondition("Argument ", new_arg.name(),
388 " not defined in base api for ",
389 base_api_def->graph_op_name());
390 }
391 }
392 for (const auto& new_arg : new_api_def.out_arg()) {
393 bool found_base_arg = false;
394 for (int i = 0; i < base_api_def->out_arg_size(); ++i) {
395 auto* base_arg = base_api_def->mutable_out_arg(i);
396 if (base_arg->name() == new_arg.name()) {
397 MergeArg(base_arg, new_arg);
398 found_base_arg = true;
399 break;
400 }
401 }
402 if (!found_base_arg) {
403 return errors::FailedPrecondition("Argument ", new_arg.name(),
404 " not defined in base api for ",
405 base_api_def->graph_op_name());
406 }
407 }
408 // Merge arg order
409 if (new_api_def.arg_order_size() > 0) {
410 // Validate that new arg_order is correct.
411 if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
412 return errors::FailedPrecondition(
413 "Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
414 base_api_def->graph_op_name(),
415 ". Expected: ", base_api_def->arg_order_size());
416 }
417 if (!std::is_permutation(new_api_def.arg_order().begin(),
418 new_api_def.arg_order().end(),
419 base_api_def->arg_order().begin())) {
420 return errors::FailedPrecondition(
421 "Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "),
422 " for ", base_api_def->graph_op_name(),
423 ". All elements in arg_order override must match base arg_order: ",
424 str_util::Join(base_api_def->arg_order(), ", "));
425 }
426
427 base_api_def->clear_arg_order();
428 std::copy(
429 new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
430 protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order()));
431 }
432 // Merge attributes
433 for (const auto& new_attr : new_api_def.attr()) {
434 bool found_base_attr = false;
435 for (int i = 0; i < base_api_def->attr_size(); ++i) {
436 auto* base_attr = base_api_def->mutable_attr(i);
437 if (base_attr->name() == new_attr.name()) {
438 MergeAttr(base_attr, new_attr);
439 found_base_attr = true;
440 break;
441 }
442 }
443 if (!found_base_attr) {
444 return errors::FailedPrecondition("Attribute ", new_attr.name(),
445 " not defined in base api for ",
446 base_api_def->graph_op_name());
447 }
448 }
449 // Merge summary
450 if (!new_api_def.summary().empty()) {
451 base_api_def->set_summary(new_api_def.summary());
452 }
453 // Merge description
454 auto description = new_api_def.description().empty()
455 ? base_api_def->description()
456 : new_api_def.description();
457
458 if (!new_api_def.description_prefix().empty()) {
459 description =
460 strings::StrCat(new_api_def.description_prefix(), "\n", description);
461 }
462 if (!new_api_def.description_suffix().empty()) {
463 description =
464 strings::StrCat(description, "\n", new_api_def.description_suffix());
465 }
466 base_api_def->set_description(description);
467 return Status::OK();
468 }
469 } // namespace
470
ApiDefMap(const OpList & op_list)471 ApiDefMap::ApiDefMap(const OpList& op_list) {
472 for (const auto& op : op_list.op()) {
473 ApiDef api_def;
474 InitApiDefFromOpDef(op, &api_def);
475 map_[op.name()] = api_def;
476 }
477 }
478
~ApiDefMap()479 ApiDefMap::~ApiDefMap() {}
480
LoadFileList(Env * env,const std::vector<string> & filenames)481 Status ApiDefMap::LoadFileList(Env* env, const std::vector<string>& filenames) {
482 for (const auto& filename : filenames) {
483 TF_RETURN_IF_ERROR(LoadFile(env, filename));
484 }
485 return Status::OK();
486 }
487
LoadFile(Env * env,const string & filename)488 Status ApiDefMap::LoadFile(Env* env, const string& filename) {
489 if (filename.empty()) return Status::OK();
490 string contents;
491 TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents));
492 TF_RETURN_IF_ERROR(LoadApiDef(contents));
493 return Status::OK();
494 }
495
LoadApiDef(const string & api_def_file_contents)496 Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
497 const string contents = PBTxtFromMultiline(api_def_file_contents);
498 ApiDefs api_defs;
499 protobuf::TextFormat::ParseFromString(contents, &api_defs);
500 for (const auto& api_def : api_defs.op()) {
501 // Check if the op definition is loaded. If op definition is not
502 // loaded, then we just skip this ApiDef.
503 if (map_.find(api_def.graph_op_name()) != map_.end()) {
504 // Overwrite current api def with data in api_def.
505 TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def));
506 }
507 }
508 return Status::OK();
509 }
510
UpdateDocs()511 void ApiDefMap::UpdateDocs() {
512 for (auto& name_and_api_def : map_) {
513 auto& api_def = name_and_api_def.second;
514 CHECK_GT(api_def.endpoint_size(), 0);
515 const string canonical_name = api_def.endpoint(0).name();
516 if (api_def.graph_op_name() != canonical_name) {
517 RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
518 }
519 for (const auto& in_arg : api_def.in_arg()) {
520 if (in_arg.name() != in_arg.rename_to()) {
521 RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
522 }
523 }
524 for (const auto& out_arg : api_def.out_arg()) {
525 if (out_arg.name() != out_arg.rename_to()) {
526 RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
527 }
528 }
529 for (const auto& attr : api_def.attr()) {
530 if (attr.name() != attr.rename_to()) {
531 RenameInDocs(attr.name(), attr.rename_to(), &api_def);
532 }
533 }
534 }
535 }
536
GetApiDef(const string & name) const537 const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
538 return gtl::FindOrNull(map_, name);
539 }
540 } // namespace tensorflow
541