1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_def_util.h"
17 
18 #include <set>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/op_def.pb_text.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/stringpiece.h"
27 #include "tensorflow/core/lib/gtl/map_util.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29 #include "tensorflow/core/lib/strings/scanner.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 #include "tensorflow/core/platform/types.h"
34 
35 namespace tensorflow {
36 namespace {  // ------ Helper functions ------
37 
HasAttrStyleType(const OpDef::ArgDef & arg)38 bool HasAttrStyleType(const OpDef::ArgDef& arg) {
39   return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
40          !arg.type_list_attr().empty();
41 }
42 
AllowedTypeValue(DataType dt,const OpDef::AttrDef & attr)43 Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
44   const AttrValue& allowed_values(attr.allowed_values());
45   for (auto allowed : allowed_values.list().type()) {
46     if (dt == allowed) {
47       return Status::OK();
48     }
49   }
50   string allowed_str;
51   for (int i = 0; i < allowed_values.list().type_size(); ++i) {
52     if (!allowed_str.empty()) {
53       strings::StrAppend(&allowed_str, ", ");
54     }
55     strings::StrAppend(&allowed_str,
56                        DataTypeString(allowed_values.list().type(i)));
57   }
58   return errors::InvalidArgument(
59       "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
60       " is not in the list of allowed values: ", allowed_str);
61 }
62 
AllowedStringValue(const string & str,const OpDef::AttrDef & attr)63 Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
64   const AttrValue& allowed_values(attr.allowed_values());
65   for (const auto& allowed : allowed_values.list().s()) {
66     if (str == allowed) {
67       return Status::OK();
68     }
69   }
70   string allowed_str;
71   for (const string& allowed : allowed_values.list().s()) {
72     if (!allowed_str.empty()) {
73       strings::StrAppend(&allowed_str, ", ");
74     }
75     strings::StrAppend(&allowed_str, "\"", allowed, "\"");
76   }
77   return errors::InvalidArgument(
78       "Value for attr '", attr.name(), "' of \"", str,
79       "\" is not in the list of allowed values: ", allowed_str);
80 }
81 
82 }  // namespace
83 
84 // Requires: attr has already been validated.
ValidateAttrValue(const AttrValue & attr_value,const OpDef::AttrDef & attr)85 Status ValidateAttrValue(const AttrValue& attr_value,
86                          const OpDef::AttrDef& attr) {
87   // Is it a valid value?
88   TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
89                                   " for attr '", attr.name(), "'");
90 
91   // Does the value satisfy the minimum constraint in the AttrDef?
92   if (attr.has_minimum()) {
93     if (attr.type() == "int") {
94       if (attr_value.i() < attr.minimum()) {
95         return errors::InvalidArgument(
96             "Value for attr '", attr.name(), "' of ", attr_value.i(),
97             " must be at least minimum ", attr.minimum());
98       }
99     } else {
100       int length = -1;
101       if (attr.type() == "list(string)") {
102         length = attr_value.list().s_size();
103       } else if (attr.type() == "list(int)") {
104         length = attr_value.list().i_size();
105       } else if (attr.type() == "list(float)") {
106         length = attr_value.list().f_size();
107       } else if (attr.type() == "list(bool)") {
108         length = attr_value.list().b_size();
109       } else if (attr.type() == "list(type)") {
110         length = attr_value.list().type_size();
111       } else if (attr.type() == "list(shape)") {
112         length = attr_value.list().shape_size();
113       } else if (attr.type() == "list(tensor)") {
114         length = attr_value.list().tensor_size();
115       }
116       if (length < attr.minimum()) {
117         return errors::InvalidArgument(
118             "Length for attr '", attr.name(), "' of ", length,
119             " must be at least minimum ", attr.minimum());
120       }
121     }
122   }
123 
124   // Does the value satisfy the allowed_value constraint in the AttrDef?
125   if (attr.has_allowed_values()) {
126     if (attr.type() == "type") {
127       TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
128     } else if (attr.type() == "list(type)") {
129       for (int dt : attr_value.list().type()) {
130         TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
131       }
132     } else if (attr.type() == "string") {
133       TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
134     } else if (attr.type() == "list(string)") {
135       for (const string& str : attr_value.list().s()) {
136         TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
137       }
138     } else {
139       return errors::Unimplemented(
140           "Support for allowed_values not implemented for type ", attr.type());
141     }
142   }
143   return Status::OK();
144 }
145 
FindAttr(StringPiece name,const OpDef & op_def)146 const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
147   for (int i = 0; i < op_def.attr_size(); ++i) {
148     if (op_def.attr(i).name() == name) {
149       return &op_def.attr(i);
150     }
151   }
152   return nullptr;
153 }
154 
FindAttrMutable(StringPiece name,OpDef * op_def)155 OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
156   for (int i = 0; i < op_def->attr_size(); ++i) {
157     if (op_def->attr(i).name() == name) {
158       return op_def->mutable_attr(i);
159     }
160   }
161   return nullptr;
162 }
163 
FindInputArg(StringPiece name,const OpDef & op_def)164 const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
165   for (int i = 0; i < op_def.input_arg_size(); ++i) {
166     if (op_def.input_arg(i).name() == name) {
167       return &op_def.input_arg(i);
168     }
169   }
170   return nullptr;
171 }
172 
// Usable only where a variable `op_def` is in scope. If EXPR is false,
// returns InvalidArgument built from the remaining arguments followed by
// "; in OpDef: " and a short debug dump of `op_def`. Expands to a single
// statement.
#define VALIDATE(EXPR, ...)                                        \
  do {                                                             \
    if (EXPR) break;                                               \
    return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ",    \
                                   ProtoShortDebugString(op_def)); \
  } while (false)
180 
ValidateArg(const OpDef::ArgDef & arg,const OpDef & op_def,bool output,std::set<string> * names)181 static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
182                           bool output, std::set<string>* names) {
183   const string suffix = strings::StrCat(
184       output ? " for output '" : " for input '", arg.name(), "'");
185   VALIDATE(gtl::InsertIfNotPresent(names, arg.name()),
186            "Duplicate name: ", arg.name());
187   VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);
188 
189   if (!arg.number_attr().empty()) {
190     const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
191     VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
192              suffix);
193     VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
194              suffix, " has type ", attr->type(), " != int");
195     VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
196              suffix, " must have minimum");
197     VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
198              suffix, " must have minimum >= 0");
199     VALIDATE(arg.type_list_attr().empty(),
200              "Can't have both number_attr and type_list_attr", suffix);
201     VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
202                      (!arg.type_attr().empty() ? 1 : 0) ==
203                  1,
204              "Exactly one of type, type_attr must be set", suffix);
205   } else {
206     const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
207                                 (!arg.type_attr().empty() ? 1 : 0) +
208                                 (!arg.type_list_attr().empty() ? 1 : 0);
209     VALIDATE(num_type_fields == 1,
210              "Exactly one of type, type_attr, type_list_attr must be set",
211              suffix);
212   }
213 
214   if (!arg.type_attr().empty()) {
215     const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
216     VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
217              suffix);
218     VALIDATE(attr->type() == "type", "Attr '", attr->name(),
219              "' used as type_attr", suffix, " has type ", attr->type(),
220              " != type");
221   } else if (!arg.type_list_attr().empty()) {
222     const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
223     VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
224              suffix);
225     VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
226              "' used as type_list_attr", suffix, " has type ", attr->type(),
227              " != list(type)");
228   } else {
229     // All argument types should be non-reference types at this point.
230     // ArgDef.is_ref is set to true for reference arguments.
231     VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
232              DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
233   }
234 
235   return Status::OK();
236 }
237 
ValidateOpDef(const OpDef & op_def)238 Status ValidateOpDef(const OpDef& op_def) {
239   using ::tensorflow::strings::Scanner;
240 
241   if (!StringPiece(op_def.name()).starts_with("_")) {
242     VALIDATE(Scanner(op_def.name())
243                  .One(Scanner::UPPERLETTER)
244                  .Any(Scanner::LETTER_DIGIT)
245                  .Eos()
246                  .GetResult(),
247              "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
248   }
249 
250   std::set<string> names;  // for detecting duplicate names
251   for (const auto& attr : op_def.attr()) {
252     // Validate name
253     VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()),
254              "Duplicate name: ", attr.name());
255     DataType dt;
256     VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
257              attr.name(), " that matches a data type");
258 
259     // Validate type
260     StringPiece type(attr.type());
261     bool is_list = type.Consume("list(");
262     bool found = false;
263     for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
264                               "tensor", "func"}) {
265       if (type.Consume(valid)) {
266         found = true;
267         break;
268       }
269     }
270     VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
271              "'");
272     if (is_list) {
273       VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ",
274                attr.name(), "'s type ", attr.type());
275     }
276     VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
277              attr.name(), "'s type ", attr.type());
278 
279     // Validate minimum
280     if (attr.has_minimum()) {
281       VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
282                "' has minimum for unsupported type ", attr.type());
283       if (is_list) {
284         VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
285                  "' with list type must have a non-negative minimum, not ",
286                  attr.minimum());
287       }
288     } else {
289       VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
290                "' with has_minimum = false but minimum ", attr.minimum(),
291                " not equal to default of 0");
292     }
293 
294     // Validate allowed_values
295     if (attr.has_allowed_values()) {
296       const string list_type =
297           is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
298       TF_RETURN_WITH_CONTEXT_IF_ERROR(
299           AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
300           attr.name(), "' in Op '", op_def.name(), "'");
301     }
302 
303     // Validate default_value (after we have validated the rest of the attr,
304     // so we can use ValidateAttrValue()).
305     if (attr.has_default_value()) {
306       TF_RETURN_WITH_CONTEXT_IF_ERROR(
307           ValidateAttrValue(attr.default_value(), attr), " in Op '",
308           op_def.name(), "'");
309     }
310   }
311 
312   for (const auto& arg : op_def.input_arg()) {
313     TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
314   }
315 
316   for (const auto& arg : op_def.output_arg()) {
317     TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
318   }
319 
320   return Status::OK();
321 }
322 
323 #undef VALIDATE
324 
CheckOpDeprecation(const OpDef & op_def,int graph_def_version)325 Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) {
326   if (op_def.has_deprecation()) {
327     const OpDeprecation& dep = op_def.deprecation();
328     if (graph_def_version >= dep.version()) {
329       return errors::Unimplemented(
330           "Op ", op_def.name(), " is not available in GraphDef version ",
331           graph_def_version, ". It has been removed in version ", dep.version(),
332           ". ", dep.explanation(), ".");
333     } else {
334       // Warn only once for each op name, and do it in a threadsafe manner.
335       static mutex mu(LINKER_INITIALIZED);
336       static std::unordered_set<string> warned;
337       bool warn;
338       {
339         mutex_lock lock(mu);
340         warn = warned.insert(op_def.name()).second;
341       }
342       if (warn) {
343         LOG(WARNING) << "Op " << op_def.name() << " is deprecated."
344                      << " It will cease to work in GraphDef version "
345                      << dep.version() << ". " << dep.explanation() << ".";
346       }
347     }
348   }
349   return Status::OK();
350 }
351 
352 namespace {
353 
SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args)354 string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
355   string ret;
356   for (const OpDef::ArgDef& arg : args) {
357     if (!ret.empty()) strings::StrAppend(&ret, ", ");
358     strings::StrAppend(&ret, arg.name(), ":");
359     if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
360     if (!arg.number_attr().empty()) {
361       strings::StrAppend(&ret, arg.number_attr(), "*");
362     }
363     if (arg.type() != DT_INVALID) {
364       strings::StrAppend(&ret, DataTypeString(arg.type()));
365     } else {
366       strings::StrAppend(&ret, arg.type_attr());
367     }
368     if (arg.is_ref()) strings::StrAppend(&ret, ")");
369   }
370   return ret;
371 }
372 
373 }  // namespace
374 
SummarizeOpDef(const OpDef & op_def)375 string SummarizeOpDef(const OpDef& op_def) {
376   string ret = strings::StrCat("Op<name=", op_def.name());
377   strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
378                      " -> ", SummarizeArgs(op_def.output_arg()));
379   for (int i = 0; i < op_def.attr_size(); ++i) {
380     strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
381                        op_def.attr(i).type());
382     if (op_def.attr(i).has_default_value()) {
383       strings::StrAppend(&ret, ",default=",
384                          SummarizeAttrValue(op_def.attr(i).default_value()));
385     }
386     if (op_def.attr(i).has_minimum()) {
387       strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
388     }
389     if (op_def.attr(i).has_allowed_values()) {
390       strings::StrAppend(&ret, ",allowed=",
391                          SummarizeAttrValue(op_def.attr(i).allowed_values()));
392     }
393   }
394   if (op_def.is_commutative()) {
395     strings::StrAppend(&ret, "; is_commutative=true");
396   }
397   if (op_def.is_aggregate()) {
398     strings::StrAppend(&ret, "; is_aggregate=true");
399   }
400   if (op_def.is_stateful()) {
401     strings::StrAppend(&ret, "; is_stateful=true");
402   }
403   if (op_def.allows_uninitialized_input()) {
404     strings::StrAppend(&ret, "; allows_uninitialized_input=true");
405   }
406   strings::StrAppend(&ret, ">");
407   return ret;
408 }
409 
410 namespace {
411 
// Returns true if every element of `sub` is contained in `super`, comparing
// with `element_of_sub == element_of_super`. Multiplicity is ignored, and an
// empty `sub` is always a subset. Each element is found by a linear scan of
// `super`.
template <class T>
bool IsSubsetOf(const T& sub, const T& super) {
  for (const auto& wanted : sub) {
    bool present = false;
    for (auto it = super.begin(); !present && it != super.end(); ++it) {
      present = (wanted == *it);
    }
    if (!present) return false;
  }
  return true;
}
427 
MoreRestrictive(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)428 bool MoreRestrictive(const OpDef::AttrDef& old_attr,
429                      const OpDef::AttrDef& new_attr) {
430   // Anything -> no restriction : not more restrictive.
431   if (!new_attr.has_allowed_values()) return false;
432   // No restriction -> restriction : more restrictive.
433   if (!old_attr.has_allowed_values()) return true;
434   // If anything that was previously allowed is no longer allowed:
435   // more restrictive.
436   if (!IsSubsetOf(old_attr.allowed_values().list().type(),
437                   new_attr.allowed_values().list().type())) {
438     return true;
439   }
440   if (!IsSubsetOf(old_attr.allowed_values().list().s(),
441                   new_attr.allowed_values().list().s())) {
442     return true;
443   }
444   return false;
445 }
446 
AllowedStr(const OpDef::AttrDef & attr)447 string AllowedStr(const OpDef::AttrDef& attr) {
448   if (!attr.has_allowed_values()) return "no restriction";
449   return SummarizeAttrValue(attr.allowed_values());
450 }
451 
DefaultAttrStr(const OpDef::AttrDef & attr)452 string DefaultAttrStr(const OpDef::AttrDef& attr) {
453   if (!attr.has_default_value()) return "no default";
454   return SummarizeAttrValue(attr.default_value());
455 }
456 
HigherMinimum(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)457 bool HigherMinimum(const OpDef::AttrDef& old_attr,
458                    const OpDef::AttrDef& new_attr) {
459   // Anything -> no restriction : not more restrictive.
460   if (!new_attr.has_minimum()) return false;
461   // No restriction -> restriction : more restrictive.
462   if (!old_attr.has_minimum()) return true;
463   // If anything that was previously allowed is no longer allowed:
464   // more restrictive.
465   return new_attr.minimum() > old_attr.minimum();
466 }
467 
MinStr(const OpDef::AttrDef & attr)468 string MinStr(const OpDef::AttrDef& attr) {
469   if (!attr.has_minimum()) return "no minimum";
470   return strings::StrCat(attr.minimum());
471 }
472 
473 typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap;
FillAttrMap(const OpDef & op_def,AttrMap * attr_map)474 void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) {
475   for (const auto& attr : op_def.attr()) {
476     (*attr_map)[attr.name()] = &attr;
477   }
478 }
479 
480 // Add a comma to *s every call but the first (*add_comma should be
481 // initialized to false).
AddComma(string * s,bool * add_comma)482 void AddComma(string* s, bool* add_comma) {
483   if (*add_comma) {
484     strings::StrAppend(s, ", ");
485   } else {
486     *add_comma = true;
487   }
488 }
489 
490 // Will add the `name` from arg if name is true.
AddName(string * s,bool name,const OpDef::ArgDef & arg)491 void AddName(string* s, bool name, const OpDef::ArgDef& arg) {
492   if (name) {
493     strings::StrAppend(s, arg.name(), ":");
494   }
495 }
496 
497 // Compute a signature for either inputs or outputs that will be the
498 // same for both the old and new OpDef if they are compatible.  We
499 // assume that new_attrs is a superset of old_attrs, and that any attr
500 // in the difference has a default.  Our strategy is to make a list of
501 // types, where the types are things like:
502 // * "int32", "float", etc.,
503 // * "T" for some attr "T" in old_attrs, or
504 // * "N * type" for "N" either some attr in old_attrs.
505 //
506 // We get the types by either using the attrs in args if they are in
507 // old_attrs, or substituting the default value from new_attrs.
ComputeArgSignature(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const AttrMap & old_attrs,const AttrMap & new_attrs,std::vector<bool> * ref,bool names)508 string ComputeArgSignature(
509     const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
510     const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref,
511     bool names) {
512   string s;
513   bool add_comma = false;
514   for (const OpDef::ArgDef& arg : args) {
515     if (!arg.type_list_attr().empty()) {
516       const OpDef::AttrDef* old_attr =
517           gtl::FindPtrOrNull(old_attrs, arg.type_list_attr());
518       if (old_attr) {
519         // Both old and new have the list(type) attr, so can use it directly.
520         AddComma(&s, &add_comma);
521         AddName(&s, names, arg);
522         strings::StrAppend(&s, arg.type_list_attr());
523         ref->push_back(arg.is_ref());
524       } else {
525         // Missing the list(type) attr in the old, so use the default
526         // value for the attr from new instead.
527         const OpDef::AttrDef* new_attr =
528             gtl::FindPtrOrNull(new_attrs, arg.type_list_attr());
529         const auto& type_list = new_attr->default_value().list().type();
530         if (type_list.empty()) continue;
531         for (int i = 0; i < type_list.size(); ++i) {
532           AddComma(&s, &add_comma);
533           AddName(&s, names, arg);
534           strings::StrAppend(
535               &s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
536           ref->push_back(arg.is_ref());
537         }
538       }
539     } else {
540       int num = 1;  // How many input/outputs does this represent?
541       string type;  // What is the type of this arg?
542       AddName(&type, names, arg);
543       if (!arg.number_attr().empty()) {
544         // N * type case.
545         const OpDef::AttrDef* old_attr =
546             gtl::FindPtrOrNull(old_attrs, arg.number_attr());
547         if (old_attr) {
548           // Both old and new have the number attr, so can use it directly.
549           strings::StrAppend(&type, arg.number_attr(), " * ");
550         } else {
551           // Missing the number attr in the old, so use the default
552           // value for the attr from new instead.
553           const OpDef::AttrDef* new_attr =
554               gtl::FindPtrOrNull(new_attrs, arg.number_attr());
555           num = new_attr->default_value().i();
556         }
557       }
558 
559       if (arg.type() != DT_INVALID) {
560         // int32, float, etc. case
561         strings::StrAppend(&type, DataTypeString(arg.type()));
562       } else {
563         const OpDef::AttrDef* old_attr =
564             gtl::FindPtrOrNull(old_attrs, arg.type_attr());
565         if (old_attr) {
566           // Both old and new have the type attr, so can use it directly.
567           strings::StrAppend(&type, arg.type_attr());
568         } else {
569           // Missing the type attr in the old, so use the default
570           // value for the attr from new instead.
571           const OpDef::AttrDef* new_attr =
572               gtl::FindPtrOrNull(new_attrs, arg.type_attr());
573           strings::StrAppend(&type,
574                              DataTypeString(new_attr->default_value().type()));
575         }
576       }
577 
578       // Record `num` * `type` in the signature.
579       for (int i = 0; i < num; ++i) {
580         AddComma(&s, &add_comma);
581         strings::StrAppend(&s, type);
582         ref->push_back(arg.is_ref());
583       }
584     }
585   }
586 
587   return s;
588 }
589 
590 }  // namespace
591 
OpDefCompatible(const OpDef & old_op,const OpDef & new_op)592 Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) {
593 #define VALIDATE(CONDITION, ...)                                            \
594   if (!(CONDITION)) {                                                       \
595     return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \
596                                    "; old: ", SummarizeOpDef(old_op),       \
597                                    "; new: ", SummarizeOpDef(new_op));      \
598   }
599 
600   VALIDATE(old_op.name() == new_op.name(), "Name mismatch");
601 
602   AttrMap new_attrs, old_attrs;
603   FillAttrMap(old_op, &old_attrs);
604   FillAttrMap(new_op, &new_attrs);
605   for (const auto& old_attr : old_op.attr()) {
606     const OpDef::AttrDef* new_attr =
607         gtl::FindPtrOrNull(new_attrs, old_attr.name());
608     VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed");
609     VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(),
610              "' changed type '", old_attr.type(), "' -> '", new_attr->type(),
611              "'");
612     VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '", old_attr.name(),
613              "' has a stricter set of allowed values; from ",
614              AllowedStr(old_attr), " to ", AllowedStr(*new_attr));
615     VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(),
616              "' has a higher minimum; from ", MinStr(old_attr), " to ",
617              MinStr(*new_attr));
618   }
619 
620   for (const auto& new_attr : new_op.attr()) {
621     const OpDef::AttrDef* old_attr =
622         gtl::FindPtrOrNull(old_attrs, new_attr.name());
623     VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '",
624              new_attr.name(), "' added without default");
625   }
626 
627   std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref;
628   const string old_in_sig = ComputeArgSignature(
629       old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */);
630   const string new_in_sig = ComputeArgSignature(
631       new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */);
632   VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig,
633            "' vs. '", new_in_sig, "'");
634   VALIDATE(old_in_ref.size() == new_in_ref.size(),  // Should not happen
635            "Unexpected change in input ref lists.");
636   for (int i = 0; i < old_in_ref.size(); ++i) {
637     // Allowed to remove "ref" from an input (or leave it unchanged).
638     VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i,
639              " changed from non-ref to ref");
640   }
641 
642   const string old_out_sig =
643       ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs,
644                           &old_out_ref, true /* names */);
645   const string new_out_sig =
646       ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs,
647                           &new_out_ref, true /* names */);
648   VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '",
649            old_out_sig, "' vs. '", new_out_sig, "'");
650   VALIDATE(old_out_ref.size() == new_out_ref.size(),  // Should not happen
651            "Unexpected change in output ref lists");
652   for (int i = 0; i < old_out_ref.size(); ++i) {
653     // Allowed to add "ref" to an output (or leave it unchanged).
654     VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i,
655              " changed from ref to non-ref");
656   }
657 
658   return Status::OK();
659 }
660 
OpDefAddedDefaultsUnchanged(const OpDef & old_op,const OpDef & penultimate_op,const OpDef & new_op)661 Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
662                                    const OpDef& penultimate_op,
663                                    const OpDef& new_op) {
664   AttrMap new_attrs, old_attrs;
665   FillAttrMap(old_op, &old_attrs);
666   FillAttrMap(new_op, &new_attrs);
667 
668   for (const auto& penultimate_attr : penultimate_op.attr()) {
669     const OpDef::AttrDef* old_attr =
670         gtl::FindPtrOrNull(old_attrs, penultimate_attr.name());
671     if (old_attr != nullptr) continue;  // attr wasn't added
672     const OpDef::AttrDef* new_attr =
673         gtl::FindPtrOrNull(new_attrs, penultimate_attr.name());
674 
675     // These shouldn't happen if the op passed OpDefCompatible().
676     if (new_attr == nullptr) {
677       return errors::InvalidArgument("Missing attr '", penultimate_attr.name(),
678                                      "' in op: ", SummarizeOpDef(new_op));
679     }
680     if (!penultimate_attr.has_default_value() ||
681         !new_attr->has_default_value()) {
682       return errors::InvalidArgument("Missing default for attr '",
683                                      penultimate_attr.name(),
684                                      "' in op: ", SummarizeOpDef(new_op));
685     }
686 
687     // Actually test that the attr's default value hasn't changed.
688     if (!AreAttrValuesEqual(penultimate_attr.default_value(),
689                             new_attr->default_value())) {
690       return errors::InvalidArgument(
691           "Can't change default value for attr '", penultimate_attr.name(),
692           "' from ", SummarizeAttrValue(penultimate_attr.default_value()),
693           " in op: ", SummarizeOpDef(new_op));
694     }
695   }
696 
697   return Status::OK();
698 }
699 
OpDefAttrDefaultsUnchanged(const OpDef & old_op,const OpDef & new_op)700 Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
701   AttrMap new_attrs, old_attrs;
702   FillAttrMap(old_op, &old_attrs);
703   FillAttrMap(new_op, &new_attrs);
704 
705   for (const auto& old_attr : old_op.attr()) {
706     const OpDef::AttrDef* new_attr =
707         gtl::FindPtrOrNull(new_attrs, old_attr.name());
708     if (new_attr == nullptr) continue;
709     if (old_attr.has_default_value() != new_attr->has_default_value()) {
710       return errors::InvalidArgument(
711           "Attr '", old_attr.name(), "' has added/removed it's default; ",
712           "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
713     }
714     if (old_attr.has_default_value() &&
715         !AreAttrValuesEqual(old_attr.default_value(),
716                             new_attr->default_value())) {
717       return errors::InvalidArgument(
718           "Attr '", old_attr.name(), "' has changed it's default value; ",
719           "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
720     }
721   }
722 
723   return Status::OK();
724 }
725 
RemoveNonDeprecationDescriptionsFromOpDef(OpDef * op_def)726 void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
727   for (int i = 0; i < op_def->input_arg_size(); ++i) {
728     op_def->mutable_input_arg(i)->clear_description();
729   }
730   for (int i = 0; i < op_def->output_arg_size(); ++i) {
731     op_def->mutable_output_arg(i)->clear_description();
732   }
733   for (int i = 0; i < op_def->attr_size(); ++i) {
734     op_def->mutable_attr(i)->clear_description();
735   }
736   op_def->clear_summary();
737   op_def->clear_description();
738 }
739 
RemoveDescriptionsFromOpDef(OpDef * op_def)740 void RemoveDescriptionsFromOpDef(OpDef* op_def) {
741   RemoveNonDeprecationDescriptionsFromOpDef(op_def);
742   if (op_def->has_deprecation()) {
743     op_def->mutable_deprecation()->clear_explanation();
744   }
745 }
746 
RemoveDescriptionsFromOpList(OpList * op_list)747 void RemoveDescriptionsFromOpList(OpList* op_list) {
748   for (int i = 0; i < op_list->op_size(); ++i) {
749     OpDef* op_def = op_list->mutable_op(i);
750     RemoveDescriptionsFromOpDef(op_def);
751   }
752 }
753 
AttrDefEqual(const OpDef::AttrDef & a1,const OpDef::AttrDef & a2)754 bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) {
755 #ifndef TENSORFLOW_LITE_PROTOS
756   DCHECK_EQ(7, a1.GetDescriptor()->field_count())
757       << "Please modify these equality and hash functions to reflect the "
758          "changes to the AttrDef protobuf";
759 #endif  // TENSORFLOW_LITE_PROTOS
760 
761   if (a1.name() != a2.name()) return false;
762   if (a1.type() != a2.type()) return false;
763   if (a1.description() != a2.description()) return false;
764   if (a1.has_minimum() != a2.has_minimum()) return false;
765   if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false;
766   if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false;
767   if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values()))
768     return false;
769   return true;
770 }
771 
AttrDefHash(const OpDef::AttrDef & a)772 uint64 AttrDefHash(const OpDef::AttrDef& a) {
773   uint64 h = Hash64(a.name());
774   h = Hash64(a.type().data(), a.type().size(), h);
775   h = Hash64Combine(AttrValueHash(a.default_value()), h);
776   h = Hash64(a.description().data(), a.description().size(), h);
777   h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h);
778   h = Hash64Combine(static_cast<uint64>(a.minimum()), h);
779   h = Hash64Combine(AttrValueHash(a.allowed_values()), h);
780   return h;
781 }
782 
RepeatedAttrDefEqual(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a1,const protobuf::RepeatedPtrField<OpDef::AttrDef> & a2)783 bool RepeatedAttrDefEqual(
784     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1,
785     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2) {
786   std::unordered_map<string, const OpDef::AttrDef*> a1_set;
787   for (const OpDef::AttrDef& def : a1) {
788     DCHECK(a1_set.find(def.name()) == a1_set.end())
789         << "AttrDef names must be unique, but '" << def.name()
790         << "' appears more than once";
791     a1_set[def.name()] = &def;
792   }
793   for (const OpDef::AttrDef& def : a2) {
794     auto iter = a1_set.find(def.name());
795     if (iter == a1_set.end()) return false;
796     if (!AttrDefEqual(*iter->second, def)) return false;
797     a1_set.erase(iter);
798   }
799   if (!a1_set.empty()) return false;
800   return true;
801 }
802 
RepeatedAttrDefHash(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a)803 uint64 RepeatedAttrDefHash(
804     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) {
805   // Insert AttrDefs into map to deterministically sort by name
806   std::map<string, const OpDef::AttrDef*> a_set;
807   for (const OpDef::AttrDef& def : a) {
808     a_set[def.name()] = &def;
809   }
810   // Iterate and combines hashes of keys and values
811   uint64 h = 0xDECAFCAFFE;
812   for (const auto& pair : a_set) {
813     h = Hash64(pair.first.data(), pair.first.size(), h);
814     h = Hash64Combine(AttrDefHash(*pair.second), h);
815   }
816   return h;
817 }
818 
OpDefEqual(const OpDef & o1,const OpDef & o2)819 bool OpDefEqual(const OpDef& o1, const OpDef& o2) {
820   // attr order doesn't matter.
821   // Compare it separately here instead of serializing below.
822   if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false;
823 
824   // Clear attr field, serialize, and compare serialized strings
825   OpDef o1_copy = o1;
826   OpDef o2_copy = o2;
827   o1_copy.clear_attr();
828   o2_copy.clear_attr();
829   string s1, s2;
830   SerializeToStringDeterministic(o1_copy, &s1);
831   SerializeToStringDeterministic(o2_copy, &s2);
832   if (s1 != s2) return false;
833   return true;
834 }
835 
OpDefHash(const OpDef & o)836 uint64 OpDefHash(const OpDef& o) {
837   uint64 h = RepeatedAttrDefHash(o.attr());
838   OpDef o_copy = o;
839   o_copy.clear_attr();
840   string s;
841   SerializeToStringDeterministic(o_copy, &s);
842   return Hash64(s.data(), s.size(), h);
843 }
844 
845 }  // namespace tensorflow
846