1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_gen_lib.h"
17 
18 #include <algorithm>
19 #include <vector>
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/gtl/map_util.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/protobuf.h"
26 #include "tensorflow/core/util/proto/proto_utils.h"
27 
28 namespace tensorflow {
29 
WordWrap(StringPiece prefix,StringPiece str,int width)30 string WordWrap(StringPiece prefix, StringPiece str, int width) {
31   const string indent_next_line = "\n" + Spaces(prefix.size());
32   width -= prefix.size();
33   string result;
34   strings::StrAppend(&result, prefix);
35 
36   while (!str.empty()) {
37     if (static_cast<int>(str.size()) <= width) {
38       // Remaining text fits on one line.
39       strings::StrAppend(&result, str);
40       break;
41     }
42     auto space = str.rfind(' ', width);
43     if (space == StringPiece::npos) {
44       // Rather make a too-long line and break at a space.
45       space = str.find(' ');
46       if (space == StringPiece::npos) {
47         strings::StrAppend(&result, str);
48         break;
49       }
50     }
51     // Breaking at character at position <space>.
52     StringPiece to_append = str.substr(0, space);
53     str.remove_prefix(space + 1);
54     // Remove spaces at break.
55     while (str_util::EndsWith(to_append, " ")) {
56       to_append.remove_suffix(1);
57     }
58     while (str_util::ConsumePrefix(&str, " ")) {
59     }
60 
61     // Go on to the next line.
62     strings::StrAppend(&result, to_append);
63     if (!str.empty()) strings::StrAppend(&result, indent_next_line);
64   }
65 
66   return result;
67 }
68 
ConsumeEquals(StringPiece * description)69 bool ConsumeEquals(StringPiece* description) {
70   if (str_util::ConsumePrefix(description, "=")) {
71     while (str_util::ConsumePrefix(description,
72                                    " ")) {  // Also remove spaces after "=".
73     }
74     return true;
75   }
76   return false;
77 }
78 
79 // Split `*orig` into two pieces at the first occurrence of `split_ch`.
80 // Returns whether `split_ch` was found. Afterwards, `*before_split`
81 // contains the maximum prefix of the input `*orig` that doesn't
82 // contain `split_ch`, and `*orig` contains everything after the
83 // first `split_ch`.
SplitAt(char split_ch,StringPiece * orig,StringPiece * before_split)84 static bool SplitAt(char split_ch, StringPiece* orig,
85                     StringPiece* before_split) {
86   auto pos = orig->find(split_ch);
87   if (pos == StringPiece::npos) {
88     *before_split = *orig;
89     *orig = StringPiece();
90     return false;
91   } else {
92     *before_split = orig->substr(0, pos);
93     orig->remove_prefix(pos + 1);
94     return true;
95   }
96 }
97 
98 // Does this line start with "<spaces><field>:" where "<field>" is
99 // in multi_line_fields? Sets *colon_pos to the position of the colon.
StartsWithFieldName(StringPiece line,const std::vector<string> & multi_line_fields)100 static bool StartsWithFieldName(StringPiece line,
101                                 const std::vector<string>& multi_line_fields) {
102   StringPiece up_to_colon;
103   if (!SplitAt(':', &line, &up_to_colon)) return false;
104   while (str_util::ConsumePrefix(&up_to_colon, " "))
105     ;  // Remove leading spaces.
106   for (const auto& field : multi_line_fields) {
107     if (up_to_colon == field) {
108       return true;
109     }
110   }
111   return false;
112 }
113 
ConvertLine(StringPiece line,const std::vector<string> & multi_line_fields,string * ml)114 static bool ConvertLine(StringPiece line,
115                         const std::vector<string>& multi_line_fields,
116                         string* ml) {
117   // Is this a field we should convert?
118   if (!StartsWithFieldName(line, multi_line_fields)) {
119     return false;
120   }
121   // Has a matching field name, so look for "..." after the colon.
122   StringPiece up_to_colon;
123   StringPiece after_colon = line;
124   SplitAt(':', &after_colon, &up_to_colon);
125   while (str_util::ConsumePrefix(&after_colon, " "))
126     ;  // Remove leading spaces.
127   if (!str_util::ConsumePrefix(&after_colon, "\"")) {
128     // We only convert string fields, so don't convert this line.
129     return false;
130   }
131   auto last_quote = after_colon.rfind('\"');
132   if (last_quote == StringPiece::npos) {
133     // Error: we don't see the expected matching quote, abort the conversion.
134     return false;
135   }
136   StringPiece escaped = after_colon.substr(0, last_quote);
137   StringPiece suffix = after_colon.substr(last_quote + 1);
138   // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
139 
140   string unescaped;
141   if (!str_util::CUnescape(escaped, &unescaped, nullptr)) {
142     // Error unescaping, abort the conversion.
143     return false;
144   }
145   // No more errors possible at this point.
146 
147   // Find a string to mark the end that isn't in unescaped.
148   string end = "END";
149   for (int s = 0; unescaped.find(end) != string::npos; ++s) {
150     end = strings::StrCat("END", s);
151   }
152 
153   // Actually start writing the converted output.
154   strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end);
155   if (!suffix.empty()) {
156     // Output suffix, in case there was a trailing comment in the source.
157     strings::StrAppend(ml, suffix);
158   }
159   strings::StrAppend(ml, "\n");
160   return true;
161 }
162 
PBTxtToMultiline(StringPiece pbtxt,const std::vector<string> & multi_line_fields)163 string PBTxtToMultiline(StringPiece pbtxt,
164                         const std::vector<string>& multi_line_fields) {
165   string ml;
166   // Probably big enough, since the input and output are about the
167   // same size, but just a guess.
168   ml.reserve(pbtxt.size() * (17. / 16));
169   StringPiece line;
170   while (!pbtxt.empty()) {
171     // Split pbtxt into its first line and everything after.
172     SplitAt('\n', &pbtxt, &line);
173     // Convert line or output it unchanged
174     if (!ConvertLine(line, multi_line_fields, &ml)) {
175       strings::StrAppend(&ml, line, "\n");
176     }
177   }
178   return ml;
179 }
180 
181 // Given a single line of text `line` with first : at `colon`, determine if
182 // there is an "<<END" expression after the colon and if so return true and set
183 // `*end` to everything after the "<<".
FindMultiline(StringPiece line,size_t colon,string * end)184 static bool FindMultiline(StringPiece line, size_t colon, string* end) {
185   if (colon == StringPiece::npos) return false;
186   line.remove_prefix(colon + 1);
187   while (str_util::ConsumePrefix(&line, " ")) {
188   }
189   if (str_util::ConsumePrefix(&line, "<<")) {
190     *end = string(line);
191     return true;
192   }
193   return false;
194 }
195 
PBTxtFromMultiline(StringPiece multiline_pbtxt)196 string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
197   string pbtxt;
198   // Probably big enough, since the input and output are about the
199   // same size, but just a guess.
200   pbtxt.reserve(multiline_pbtxt.size() * (33. / 32));
201   StringPiece line;
202   while (!multiline_pbtxt.empty()) {
203     // Split multiline_pbtxt into its first line and everything after.
204     if (!SplitAt('\n', &multiline_pbtxt, &line)) {
205       strings::StrAppend(&pbtxt, line);
206       break;
207     }
208 
209     string end;
210     auto colon = line.find(':');
211     if (!FindMultiline(line, colon, &end)) {
212       // Normal case: not a multi-line string, just output the line as-is.
213       strings::StrAppend(&pbtxt, line, "\n");
214       continue;
215     }
216 
217     // Multi-line case:
218     //     something: <<END
219     // xx
220     // yy
221     // END
222     // Should be converted to:
223     //     something: "xx\nyy"
224 
225     // Output everything up to the colon ("    something:").
226     strings::StrAppend(&pbtxt, line.substr(0, colon + 1));
227 
228     // Add every line to unescaped until we see the "END" string.
229     string unescaped;
230     bool first = true;
231     while (!multiline_pbtxt.empty()) {
232       SplitAt('\n', &multiline_pbtxt, &line);
233       if (str_util::ConsumePrefix(&line, end)) break;
234       if (first) {
235         first = false;
236       } else {
237         unescaped.push_back('\n');
238       }
239       strings::StrAppend(&unescaped, line);
240       line = StringPiece();
241     }
242 
243     // Escape what we extracted and then output it in quotes.
244     strings::StrAppend(&pbtxt, " \"", str_util::CEscape(unescaped), "\"", line,
245                        "\n");
246   }
247   return pbtxt;
248 }
249 
// Replaces every non-overlapping occurrence of `from` in `*s` with `to`,
// scanning left to right. An empty `from` leaves `*s` unchanged; it would
// otherwise match at every position without ever advancing.
static void StringReplace(const string& from, const string& to, string* s) {
  if (from.empty()) return;
  string result;
  result.reserve(s->size());
  string::size_type pos = 0;
  for (string::size_type found = s->find(from, pos); found != string::npos;
       found = s->find(from, pos)) {
    result.append(*s, pos, found - pos);  // Text before this match.
    result.append(to);
    pos = found + from.size();            // Resume after the match.
  }
  result.append(*s, pos, string::npos);   // Tail after the last match.
  s->swap(result);
}
270 
RenameInDocs(const string & from,const string & to,ApiDef * api_def)271 static void RenameInDocs(const string& from, const string& to,
272                          ApiDef* api_def) {
273   const string from_quoted = strings::StrCat("`", from, "`");
274   const string to_quoted = strings::StrCat("`", to, "`");
275   for (int i = 0; i < api_def->in_arg_size(); ++i) {
276     if (!api_def->in_arg(i).description().empty()) {
277       StringReplace(from_quoted, to_quoted,
278                     api_def->mutable_in_arg(i)->mutable_description());
279     }
280   }
281   for (int i = 0; i < api_def->out_arg_size(); ++i) {
282     if (!api_def->out_arg(i).description().empty()) {
283       StringReplace(from_quoted, to_quoted,
284                     api_def->mutable_out_arg(i)->mutable_description());
285     }
286   }
287   for (int i = 0; i < api_def->attr_size(); ++i) {
288     if (!api_def->attr(i).description().empty()) {
289       StringReplace(from_quoted, to_quoted,
290                     api_def->mutable_attr(i)->mutable_description());
291     }
292   }
293   if (!api_def->summary().empty()) {
294     StringReplace(from_quoted, to_quoted, api_def->mutable_summary());
295   }
296   if (!api_def->description().empty()) {
297     StringReplace(from_quoted, to_quoted, api_def->mutable_description());
298   }
299 }
300 
301 namespace {
302 
303 // Initializes given ApiDef with data in OpDef.
InitApiDefFromOpDef(const OpDef & op_def,ApiDef * api_def)304 void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
305   api_def->set_graph_op_name(op_def.name());
306   api_def->set_visibility(ApiDef::VISIBLE);
307 
308   auto* endpoint = api_def->add_endpoint();
309   endpoint->set_name(op_def.name());
310 
311   for (const auto& op_in_arg : op_def.input_arg()) {
312     auto* api_in_arg = api_def->add_in_arg();
313     api_in_arg->set_name(op_in_arg.name());
314     api_in_arg->set_rename_to(op_in_arg.name());
315     api_in_arg->set_description(op_in_arg.description());
316 
317     *api_def->add_arg_order() = op_in_arg.name();
318   }
319   for (const auto& op_out_arg : op_def.output_arg()) {
320     auto* api_out_arg = api_def->add_out_arg();
321     api_out_arg->set_name(op_out_arg.name());
322     api_out_arg->set_rename_to(op_out_arg.name());
323     api_out_arg->set_description(op_out_arg.description());
324   }
325   for (const auto& op_attr : op_def.attr()) {
326     auto* api_attr = api_def->add_attr();
327     api_attr->set_name(op_attr.name());
328     api_attr->set_rename_to(op_attr.name());
329     if (op_attr.has_default_value()) {
330       *api_attr->mutable_default_value() = op_attr.default_value();
331     }
332     api_attr->set_description(op_attr.description());
333   }
334   api_def->set_summary(op_def.summary());
335   api_def->set_description(op_def.description());
336 }
337 
338 // Updates base_arg based on overrides in new_arg.
MergeArg(ApiDef::Arg * base_arg,const ApiDef::Arg & new_arg)339 void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) {
340   if (!new_arg.rename_to().empty()) {
341     base_arg->set_rename_to(new_arg.rename_to());
342   }
343   if (!new_arg.description().empty()) {
344     base_arg->set_description(new_arg.description());
345   }
346 }
347 
348 // Updates base_attr based on overrides in new_attr.
MergeAttr(ApiDef::Attr * base_attr,const ApiDef::Attr & new_attr)349 void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) {
350   if (!new_attr.rename_to().empty()) {
351     base_attr->set_rename_to(new_attr.rename_to());
352   }
353   if (new_attr.has_default_value()) {
354     *base_attr->mutable_default_value() = new_attr.default_value();
355   }
356   if (!new_attr.description().empty()) {
357     base_attr->set_description(new_attr.description());
358   }
359 }
360 
361 // Updates base_api_def based on overrides in new_api_def.
MergeApiDefs(ApiDef * base_api_def,const ApiDef & new_api_def)362 Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
363   // Merge visibility
364   if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) {
365     base_api_def->set_visibility(new_api_def.visibility());
366   }
367   // Merge endpoints
368   if (new_api_def.endpoint_size() > 0) {
369     base_api_def->clear_endpoint();
370     std::copy(
371         new_api_def.endpoint().begin(), new_api_def.endpoint().end(),
372         protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint()));
373   }
374   // Merge args
375   for (const auto& new_arg : new_api_def.in_arg()) {
376     bool found_base_arg = false;
377     for (int i = 0; i < base_api_def->in_arg_size(); ++i) {
378       auto* base_arg = base_api_def->mutable_in_arg(i);
379       if (base_arg->name() == new_arg.name()) {
380         MergeArg(base_arg, new_arg);
381         found_base_arg = true;
382         break;
383       }
384     }
385     if (!found_base_arg) {
386       return errors::FailedPrecondition("Argument ", new_arg.name(),
387                                         " not defined in base api for ",
388                                         base_api_def->graph_op_name());
389     }
390   }
391   for (const auto& new_arg : new_api_def.out_arg()) {
392     bool found_base_arg = false;
393     for (int i = 0; i < base_api_def->out_arg_size(); ++i) {
394       auto* base_arg = base_api_def->mutable_out_arg(i);
395       if (base_arg->name() == new_arg.name()) {
396         MergeArg(base_arg, new_arg);
397         found_base_arg = true;
398         break;
399       }
400     }
401     if (!found_base_arg) {
402       return errors::FailedPrecondition("Argument ", new_arg.name(),
403                                         " not defined in base api for ",
404                                         base_api_def->graph_op_name());
405     }
406   }
407   // Merge arg order
408   if (new_api_def.arg_order_size() > 0) {
409     // Validate that new arg_order is correct.
410     if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
411       return errors::FailedPrecondition(
412           "Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
413           base_api_def->graph_op_name(),
414           ". Expected: ", base_api_def->arg_order_size());
415     }
416     if (!std::is_permutation(new_api_def.arg_order().begin(),
417                              new_api_def.arg_order().end(),
418                              base_api_def->arg_order().begin())) {
419       return errors::FailedPrecondition(
420           "Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "),
421           " for ", base_api_def->graph_op_name(),
422           ". All elements in arg_order override must match base arg_order: ",
423           str_util::Join(base_api_def->arg_order(), ", "));
424     }
425 
426     base_api_def->clear_arg_order();
427     std::copy(
428         new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
429         protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order()));
430   }
431   // Merge attributes
432   for (const auto& new_attr : new_api_def.attr()) {
433     bool found_base_attr = false;
434     for (int i = 0; i < base_api_def->attr_size(); ++i) {
435       auto* base_attr = base_api_def->mutable_attr(i);
436       if (base_attr->name() == new_attr.name()) {
437         MergeAttr(base_attr, new_attr);
438         found_base_attr = true;
439         break;
440       }
441     }
442     if (!found_base_attr) {
443       return errors::FailedPrecondition("Attribute ", new_attr.name(),
444                                         " not defined in base api for ",
445                                         base_api_def->graph_op_name());
446     }
447   }
448   // Merge summary
449   if (!new_api_def.summary().empty()) {
450     base_api_def->set_summary(new_api_def.summary());
451   }
452   // Merge description
453   auto description = new_api_def.description().empty()
454                          ? base_api_def->description()
455                          : new_api_def.description();
456 
457   if (!new_api_def.description_prefix().empty()) {
458     description =
459         strings::StrCat(new_api_def.description_prefix(), "\n", description);
460   }
461   if (!new_api_def.description_suffix().empty()) {
462     description =
463         strings::StrCat(description, "\n", new_api_def.description_suffix());
464   }
465   base_api_def->set_description(description);
466   return Status::OK();
467 }
468 }  // namespace
469 
ApiDefMap(const OpList & op_list)470 ApiDefMap::ApiDefMap(const OpList& op_list) {
471   for (const auto& op : op_list.op()) {
472     ApiDef api_def;
473     InitApiDefFromOpDef(op, &api_def);
474     map_[op.name()] = api_def;
475   }
476 }
477 
~ApiDefMap()478 ApiDefMap::~ApiDefMap() {}
479 
LoadFileList(Env * env,const std::vector<string> & filenames)480 Status ApiDefMap::LoadFileList(Env* env, const std::vector<string>& filenames) {
481   for (const auto& filename : filenames) {
482     TF_RETURN_IF_ERROR(LoadFile(env, filename));
483   }
484   return Status::OK();
485 }
486 
LoadFile(Env * env,const string & filename)487 Status ApiDefMap::LoadFile(Env* env, const string& filename) {
488   if (filename.empty()) return Status::OK();
489   string contents;
490   TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents));
491   Status status = LoadApiDef(contents);
492   if (!status.ok()) {
493     // Return failed status annotated with filename to aid in debugging.
494     return Status(status.code(),
495                   strings::StrCat("Error parsing ApiDef file ", filename, ": ",
496                                   status.error_message()));
497   }
498   return Status::OK();
499 }
500 
LoadApiDef(const string & api_def_file_contents)501 Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
502   const string contents = PBTxtFromMultiline(api_def_file_contents);
503   ApiDefs api_defs;
504   TF_RETURN_IF_ERROR(
505       proto_utils::ParseTextFormatFromString(contents, &api_defs));
506   for (const auto& api_def : api_defs.op()) {
507     // Check if the op definition is loaded. If op definition is not
508     // loaded, then we just skip this ApiDef.
509     if (map_.find(api_def.graph_op_name()) != map_.end()) {
510       // Overwrite current api def with data in api_def.
511       TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def));
512     }
513   }
514   return Status::OK();
515 }
516 
UpdateDocs()517 void ApiDefMap::UpdateDocs() {
518   for (auto& name_and_api_def : map_) {
519     auto& api_def = name_and_api_def.second;
520     CHECK_GT(api_def.endpoint_size(), 0);
521     const string canonical_name = api_def.endpoint(0).name();
522     if (api_def.graph_op_name() != canonical_name) {
523       RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
524     }
525     for (const auto& in_arg : api_def.in_arg()) {
526       if (in_arg.name() != in_arg.rename_to()) {
527         RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
528       }
529     }
530     for (const auto& out_arg : api_def.out_arg()) {
531       if (out_arg.name() != out_arg.rename_to()) {
532         RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
533       }
534     }
535     for (const auto& attr : api_def.attr()) {
536       if (attr.name() != attr.rename_to()) {
537         RenameInDocs(attr.name(), attr.rename_to(), &api_def);
538       }
539     }
540   }
541 }
542 
GetApiDef(const string & name) const543 const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
544   return gtl::FindOrNull(map_, name);
545 }
546 }  // namespace tensorflow
547