1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/framework/op_gen_lib.h"

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/proto_utils.h"
29 
30 namespace tensorflow {
31 
WordWrap(StringPiece prefix,StringPiece str,int width)32 string WordWrap(StringPiece prefix, StringPiece str, int width) {
33   const string indent_next_line = "\n" + Spaces(prefix.size());
34   width -= prefix.size();
35   string result;
36   strings::StrAppend(&result, prefix);
37 
38   while (!str.empty()) {
39     if (static_cast<int>(str.size()) <= width) {
40       // Remaining text fits on one line.
41       strings::StrAppend(&result, str);
42       break;
43     }
44     auto space = str.rfind(' ', width);
45     if (space == StringPiece::npos) {
46       // Rather make a too-long line and break at a space.
47       space = str.find(' ');
48       if (space == StringPiece::npos) {
49         strings::StrAppend(&result, str);
50         break;
51       }
52     }
53     // Breaking at character at position <space>.
54     StringPiece to_append = str.substr(0, space);
55     str.remove_prefix(space + 1);
56     // Remove spaces at break.
57     while (str_util::EndsWith(to_append, " ")) {
58       to_append.remove_suffix(1);
59     }
60     while (absl::ConsumePrefix(&str, " ")) {
61     }
62 
63     // Go on to the next line.
64     strings::StrAppend(&result, to_append);
65     if (!str.empty()) strings::StrAppend(&result, indent_next_line);
66   }
67 
68   return result;
69 }
70 
ConsumeEquals(StringPiece * description)71 bool ConsumeEquals(StringPiece* description) {
72   if (absl::ConsumePrefix(description, "=")) {
73     while (absl::ConsumePrefix(description,
74                                " ")) {  // Also remove spaces after "=".
75     }
76     return true;
77   }
78   return false;
79 }
80 
81 // Split `*orig` into two pieces at the first occurrence of `split_ch`.
82 // Returns whether `split_ch` was found. Afterwards, `*before_split`
83 // contains the maximum prefix of the input `*orig` that doesn't
84 // contain `split_ch`, and `*orig` contains everything after the
85 // first `split_ch`.
SplitAt(char split_ch,StringPiece * orig,StringPiece * before_split)86 static bool SplitAt(char split_ch, StringPiece* orig,
87                     StringPiece* before_split) {
88   auto pos = orig->find(split_ch);
89   if (pos == StringPiece::npos) {
90     *before_split = *orig;
91     *orig = StringPiece();
92     return false;
93   } else {
94     *before_split = orig->substr(0, pos);
95     orig->remove_prefix(pos + 1);
96     return true;
97   }
98 }
99 
100 // Does this line start with "<spaces><field>:" where "<field>" is
101 // in multi_line_fields? Sets *colon_pos to the position of the colon.
StartsWithFieldName(StringPiece line,const std::vector<string> & multi_line_fields)102 static bool StartsWithFieldName(StringPiece line,
103                                 const std::vector<string>& multi_line_fields) {
104   StringPiece up_to_colon;
105   if (!SplitAt(':', &line, &up_to_colon)) return false;
106   while (absl::ConsumePrefix(&up_to_colon, " "))
107     ;  // Remove leading spaces.
108   for (const auto& field : multi_line_fields) {
109     if (up_to_colon == field) {
110       return true;
111     }
112   }
113   return false;
114 }
115 
ConvertLine(StringPiece line,const std::vector<string> & multi_line_fields,string * ml)116 static bool ConvertLine(StringPiece line,
117                         const std::vector<string>& multi_line_fields,
118                         string* ml) {
119   // Is this a field we should convert?
120   if (!StartsWithFieldName(line, multi_line_fields)) {
121     return false;
122   }
123   // Has a matching field name, so look for "..." after the colon.
124   StringPiece up_to_colon;
125   StringPiece after_colon = line;
126   SplitAt(':', &after_colon, &up_to_colon);
127   while (absl::ConsumePrefix(&after_colon, " "))
128     ;  // Remove leading spaces.
129   if (!absl::ConsumePrefix(&after_colon, "\"")) {
130     // We only convert string fields, so don't convert this line.
131     return false;
132   }
133   auto last_quote = after_colon.rfind('\"');
134   if (last_quote == StringPiece::npos) {
135     // Error: we don't see the expected matching quote, abort the conversion.
136     return false;
137   }
138   StringPiece escaped = after_colon.substr(0, last_quote);
139   StringPiece suffix = after_colon.substr(last_quote + 1);
140   // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
141 
142   string unescaped;
143   if (!absl::CUnescape(escaped, &unescaped, nullptr)) {
144     // Error unescaping, abort the conversion.
145     return false;
146   }
147   // No more errors possible at this point.
148 
149   // Find a string to mark the end that isn't in unescaped.
150   string end = "END";
151   for (int s = 0; unescaped.find(end) != string::npos; ++s) {
152     end = strings::StrCat("END", s);
153   }
154 
155   // Actually start writing the converted output.
156   strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end);
157   if (!suffix.empty()) {
158     // Output suffix, in case there was a trailing comment in the source.
159     strings::StrAppend(ml, suffix);
160   }
161   strings::StrAppend(ml, "\n");
162   return true;
163 }
164 
PBTxtToMultiline(StringPiece pbtxt,const std::vector<string> & multi_line_fields)165 string PBTxtToMultiline(StringPiece pbtxt,
166                         const std::vector<string>& multi_line_fields) {
167   string ml;
168   // Probably big enough, since the input and output are about the
169   // same size, but just a guess.
170   ml.reserve(pbtxt.size() * (17. / 16));
171   StringPiece line;
172   while (!pbtxt.empty()) {
173     // Split pbtxt into its first line and everything after.
174     SplitAt('\n', &pbtxt, &line);
175     // Convert line or output it unchanged
176     if (!ConvertLine(line, multi_line_fields, &ml)) {
177       strings::StrAppend(&ml, line, "\n");
178     }
179   }
180   return ml;
181 }
182 
183 // Given a single line of text `line` with first : at `colon`, determine if
184 // there is an "<<END" expression after the colon and if so return true and set
185 // `*end` to everything after the "<<".
FindMultiline(StringPiece line,size_t colon,string * end)186 static bool FindMultiline(StringPiece line, size_t colon, string* end) {
187   if (colon == StringPiece::npos) return false;
188   line.remove_prefix(colon + 1);
189   while (absl::ConsumePrefix(&line, " ")) {
190   }
191   if (absl::ConsumePrefix(&line, "<<")) {
192     *end = string(line);
193     return true;
194   }
195   return false;
196 }
197 
PBTxtFromMultiline(StringPiece multiline_pbtxt)198 string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
199   string pbtxt;
200   // Probably big enough, since the input and output are about the
201   // same size, but just a guess.
202   pbtxt.reserve(multiline_pbtxt.size() * (33. / 32));
203   StringPiece line;
204   while (!multiline_pbtxt.empty()) {
205     // Split multiline_pbtxt into its first line and everything after.
206     if (!SplitAt('\n', &multiline_pbtxt, &line)) {
207       strings::StrAppend(&pbtxt, line);
208       break;
209     }
210 
211     string end;
212     auto colon = line.find(':');
213     if (!FindMultiline(line, colon, &end)) {
214       // Normal case: not a multi-line string, just output the line as-is.
215       strings::StrAppend(&pbtxt, line, "\n");
216       continue;
217     }
218 
219     // Multi-line case:
220     //     something: <<END
221     // xx
222     // yy
223     // END
224     // Should be converted to:
225     //     something: "xx\nyy"
226 
227     // Output everything up to the colon ("    something:").
228     strings::StrAppend(&pbtxt, line.substr(0, colon + 1));
229 
230     // Add every line to unescaped until we see the "END" string.
231     string unescaped;
232     bool first = true;
233     while (!multiline_pbtxt.empty()) {
234       SplitAt('\n', &multiline_pbtxt, &line);
235       if (absl::ConsumePrefix(&line, end)) break;
236       if (first) {
237         first = false;
238       } else {
239         unescaped.push_back('\n');
240       }
241       strings::StrAppend(&unescaped, line);
242       line = StringPiece();
243     }
244 
245     // Escape what we extracted and then output it in quotes.
246     strings::StrAppend(&pbtxt, " \"", absl::CEscape(unescaped), "\"", line,
247                        "\n");
248   }
249   return pbtxt;
250 }
251 
// Replaces every non-overlapping occurrence of `from` in `*s` with `to`,
// scanning left to right. An empty `from` is a no-op. Without that guard the
// scan position would never advance, and the loop would run forever.
static void StringReplace(const std::string& from, const std::string& to,
                          std::string* s) {
  if (from.empty()) return;
  std::string result;
  std::string::size_type pos = 0;
  for (auto found = s->find(from); found != std::string::npos;
       found = s->find(from, pos)) {
    result.append(*s, pos, found - pos);  // Text before this match.
    result.append(to);
    pos = found + from.size();
  }
  result.append(*s, pos, std::string::npos);  // Text after the last match.
  *s = std::move(result);
}
272 
RenameInDocs(const string & from,const string & to,ApiDef * api_def)273 static void RenameInDocs(const string& from, const string& to,
274                          ApiDef* api_def) {
275   const string from_quoted = strings::StrCat("`", from, "`");
276   const string to_quoted = strings::StrCat("`", to, "`");
277   for (int i = 0; i < api_def->in_arg_size(); ++i) {
278     if (!api_def->in_arg(i).description().empty()) {
279       StringReplace(from_quoted, to_quoted,
280                     api_def->mutable_in_arg(i)->mutable_description());
281     }
282   }
283   for (int i = 0; i < api_def->out_arg_size(); ++i) {
284     if (!api_def->out_arg(i).description().empty()) {
285       StringReplace(from_quoted, to_quoted,
286                     api_def->mutable_out_arg(i)->mutable_description());
287     }
288   }
289   for (int i = 0; i < api_def->attr_size(); ++i) {
290     if (!api_def->attr(i).description().empty()) {
291       StringReplace(from_quoted, to_quoted,
292                     api_def->mutable_attr(i)->mutable_description());
293     }
294   }
295   if (!api_def->summary().empty()) {
296     StringReplace(from_quoted, to_quoted, api_def->mutable_summary());
297   }
298   if (!api_def->description().empty()) {
299     StringReplace(from_quoted, to_quoted, api_def->mutable_description());
300   }
301 }
302 
303 namespace {
304 
305 // Initializes given ApiDef with data in OpDef.
InitApiDefFromOpDef(const OpDef & op_def,ApiDef * api_def)306 void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
307   api_def->set_graph_op_name(op_def.name());
308   api_def->set_visibility(ApiDef::VISIBLE);
309 
310   auto* endpoint = api_def->add_endpoint();
311   endpoint->set_name(op_def.name());
312 
313   for (const auto& op_in_arg : op_def.input_arg()) {
314     auto* api_in_arg = api_def->add_in_arg();
315     api_in_arg->set_name(op_in_arg.name());
316     api_in_arg->set_rename_to(op_in_arg.name());
317     api_in_arg->set_description(op_in_arg.description());
318 
319     *api_def->add_arg_order() = op_in_arg.name();
320   }
321   for (const auto& op_out_arg : op_def.output_arg()) {
322     auto* api_out_arg = api_def->add_out_arg();
323     api_out_arg->set_name(op_out_arg.name());
324     api_out_arg->set_rename_to(op_out_arg.name());
325     api_out_arg->set_description(op_out_arg.description());
326   }
327   for (const auto& op_attr : op_def.attr()) {
328     auto* api_attr = api_def->add_attr();
329     api_attr->set_name(op_attr.name());
330     api_attr->set_rename_to(op_attr.name());
331     if (op_attr.has_default_value()) {
332       *api_attr->mutable_default_value() = op_attr.default_value();
333     }
334     api_attr->set_description(op_attr.description());
335   }
336   api_def->set_summary(op_def.summary());
337   api_def->set_description(op_def.description());
338 }
339 
340 // Updates base_arg based on overrides in new_arg.
MergeArg(ApiDef::Arg * base_arg,const ApiDef::Arg & new_arg)341 void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) {
342   if (!new_arg.rename_to().empty()) {
343     base_arg->set_rename_to(new_arg.rename_to());
344   }
345   if (!new_arg.description().empty()) {
346     base_arg->set_description(new_arg.description());
347   }
348 }
349 
350 // Updates base_attr based on overrides in new_attr.
MergeAttr(ApiDef::Attr * base_attr,const ApiDef::Attr & new_attr)351 void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) {
352   if (!new_attr.rename_to().empty()) {
353     base_attr->set_rename_to(new_attr.rename_to());
354   }
355   if (new_attr.has_default_value()) {
356     *base_attr->mutable_default_value() = new_attr.default_value();
357   }
358   if (!new_attr.description().empty()) {
359     base_attr->set_description(new_attr.description());
360   }
361 }
362 
363 // Updates base_api_def based on overrides in new_api_def.
MergeApiDefs(ApiDef * base_api_def,const ApiDef & new_api_def)364 Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
365   // Merge visibility
366   if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) {
367     base_api_def->set_visibility(new_api_def.visibility());
368   }
369   // Merge endpoints
370   if (new_api_def.endpoint_size() > 0) {
371     base_api_def->clear_endpoint();
372     std::copy(
373         new_api_def.endpoint().begin(), new_api_def.endpoint().end(),
374         protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint()));
375   }
376   // Merge args
377   for (const auto& new_arg : new_api_def.in_arg()) {
378     bool found_base_arg = false;
379     for (int i = 0; i < base_api_def->in_arg_size(); ++i) {
380       auto* base_arg = base_api_def->mutable_in_arg(i);
381       if (base_arg->name() == new_arg.name()) {
382         MergeArg(base_arg, new_arg);
383         found_base_arg = true;
384         break;
385       }
386     }
387     if (!found_base_arg) {
388       return errors::FailedPrecondition("Argument ", new_arg.name(),
389                                         " not defined in base api for ",
390                                         base_api_def->graph_op_name());
391     }
392   }
393   for (const auto& new_arg : new_api_def.out_arg()) {
394     bool found_base_arg = false;
395     for (int i = 0; i < base_api_def->out_arg_size(); ++i) {
396       auto* base_arg = base_api_def->mutable_out_arg(i);
397       if (base_arg->name() == new_arg.name()) {
398         MergeArg(base_arg, new_arg);
399         found_base_arg = true;
400         break;
401       }
402     }
403     if (!found_base_arg) {
404       return errors::FailedPrecondition("Argument ", new_arg.name(),
405                                         " not defined in base api for ",
406                                         base_api_def->graph_op_name());
407     }
408   }
409   // Merge arg order
410   if (new_api_def.arg_order_size() > 0) {
411     // Validate that new arg_order is correct.
412     if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
413       return errors::FailedPrecondition(
414           "Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
415           base_api_def->graph_op_name(),
416           ". Expected: ", base_api_def->arg_order_size());
417     }
418     if (!std::is_permutation(new_api_def.arg_order().begin(),
419                              new_api_def.arg_order().end(),
420                              base_api_def->arg_order().begin())) {
421       return errors::FailedPrecondition(
422           "Invalid arg_order: ", absl::StrJoin(new_api_def.arg_order(), ", "),
423           " for ", base_api_def->graph_op_name(),
424           ". All elements in arg_order override must match base arg_order: ",
425           absl::StrJoin(base_api_def->arg_order(), ", "));
426     }
427 
428     base_api_def->clear_arg_order();
429     std::copy(
430         new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
431         protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order()));
432   }
433   // Merge attributes
434   for (const auto& new_attr : new_api_def.attr()) {
435     bool found_base_attr = false;
436     for (int i = 0; i < base_api_def->attr_size(); ++i) {
437       auto* base_attr = base_api_def->mutable_attr(i);
438       if (base_attr->name() == new_attr.name()) {
439         MergeAttr(base_attr, new_attr);
440         found_base_attr = true;
441         break;
442       }
443     }
444     if (!found_base_attr) {
445       return errors::FailedPrecondition("Attribute ", new_attr.name(),
446                                         " not defined in base api for ",
447                                         base_api_def->graph_op_name());
448     }
449   }
450   // Merge summary
451   if (!new_api_def.summary().empty()) {
452     base_api_def->set_summary(new_api_def.summary());
453   }
454   // Merge description
455   auto description = new_api_def.description().empty()
456                          ? base_api_def->description()
457                          : new_api_def.description();
458 
459   if (!new_api_def.description_prefix().empty()) {
460     description =
461         strings::StrCat(new_api_def.description_prefix(), "\n", description);
462   }
463   if (!new_api_def.description_suffix().empty()) {
464     description =
465         strings::StrCat(description, "\n", new_api_def.description_suffix());
466   }
467   base_api_def->set_description(description);
468   return Status::OK();
469 }
470 }  // namespace
471 
ApiDefMap(const OpList & op_list)472 ApiDefMap::ApiDefMap(const OpList& op_list) {
473   for (const auto& op : op_list.op()) {
474     ApiDef api_def;
475     InitApiDefFromOpDef(op, &api_def);
476     map_[op.name()] = api_def;
477   }
478 }
479 
~ApiDefMap()480 ApiDefMap::~ApiDefMap() {}
481 
LoadFileList(Env * env,const std::vector<string> & filenames)482 Status ApiDefMap::LoadFileList(Env* env, const std::vector<string>& filenames) {
483   for (const auto& filename : filenames) {
484     TF_RETURN_IF_ERROR(LoadFile(env, filename));
485   }
486   return Status::OK();
487 }
488 
LoadFile(Env * env,const string & filename)489 Status ApiDefMap::LoadFile(Env* env, const string& filename) {
490   if (filename.empty()) return Status::OK();
491   string contents;
492   TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents));
493   Status status = LoadApiDef(contents);
494   if (!status.ok()) {
495     // Return failed status annotated with filename to aid in debugging.
496     return Status(status.code(),
497                   strings::StrCat("Error parsing ApiDef file ", filename, ": ",
498                                   status.error_message()));
499   }
500   return Status::OK();
501 }
502 
LoadApiDef(const string & api_def_file_contents)503 Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
504   const string contents = PBTxtFromMultiline(api_def_file_contents);
505   ApiDefs api_defs;
506   TF_RETURN_IF_ERROR(
507       proto_utils::ParseTextFormatFromString(contents, &api_defs));
508   for (const auto& api_def : api_defs.op()) {
509     // Check if the op definition is loaded. If op definition is not
510     // loaded, then we just skip this ApiDef.
511     if (map_.find(api_def.graph_op_name()) != map_.end()) {
512       // Overwrite current api def with data in api_def.
513       TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def));
514     }
515   }
516   return Status::OK();
517 }
518 
UpdateDocs()519 void ApiDefMap::UpdateDocs() {
520   for (auto& name_and_api_def : map_) {
521     auto& api_def = name_and_api_def.second;
522     CHECK_GT(api_def.endpoint_size(), 0);
523     const string canonical_name = api_def.endpoint(0).name();
524     if (api_def.graph_op_name() != canonical_name) {
525       RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
526     }
527     for (const auto& in_arg : api_def.in_arg()) {
528       if (in_arg.name() != in_arg.rename_to()) {
529         RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
530       }
531     }
532     for (const auto& out_arg : api_def.out_arg()) {
533       if (out_arg.name() != out_arg.rename_to()) {
534         RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
535       }
536     }
537     for (const auto& attr : api_def.attr()) {
538       if (attr.name() != attr.rename_to()) {
539         RenameInDocs(attr.name(), attr.rename_to(), &api_def);
540       }
541     }
542   }
543 }
544 
GetApiDef(const string & name) const545 const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
546   return gtl::FindOrNull(map_, name);
547 }
548 }  // namespace tensorflow
549