1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/framework/op_def_util.h"

#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
34
35 namespace tensorflow {
36 namespace { // ------ Helper functions ------
37
HasAttrStyleType(const OpDef::ArgDef & arg)38 bool HasAttrStyleType(const OpDef::ArgDef& arg) {
39 return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
40 !arg.type_list_attr().empty();
41 }
42
AllowedTypeValue(DataType dt,const OpDef::AttrDef & attr)43 Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
44 const AttrValue& allowed_values(attr.allowed_values());
45 for (auto allowed : allowed_values.list().type()) {
46 if (dt == allowed) {
47 return Status::OK();
48 }
49 }
50 string allowed_str;
51 for (int i = 0; i < allowed_values.list().type_size(); ++i) {
52 if (!allowed_str.empty()) {
53 strings::StrAppend(&allowed_str, ", ");
54 }
55 strings::StrAppend(&allowed_str,
56 DataTypeString(allowed_values.list().type(i)));
57 }
58 return errors::InvalidArgument(
59 "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
60 " is not in the list of allowed values: ", allowed_str);
61 }
62
AllowedStringValue(const string & str,const OpDef::AttrDef & attr)63 Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
64 const AttrValue& allowed_values(attr.allowed_values());
65 for (const auto& allowed : allowed_values.list().s()) {
66 if (str == allowed) {
67 return Status::OK();
68 }
69 }
70 string allowed_str;
71 for (const string& allowed : allowed_values.list().s()) {
72 if (!allowed_str.empty()) {
73 strings::StrAppend(&allowed_str, ", ");
74 }
75 strings::StrAppend(&allowed_str, "\"", allowed, "\"");
76 }
77 return errors::InvalidArgument(
78 "Value for attr '", attr.name(), "' of \"", str,
79 "\" is not in the list of allowed values: ", allowed_str);
80 }
81
82 } // namespace
83
84 // Requires: attr has already been validated.
ValidateAttrValue(const AttrValue & attr_value,const OpDef::AttrDef & attr)85 Status ValidateAttrValue(const AttrValue& attr_value,
86 const OpDef::AttrDef& attr) {
87 // Is it a valid value?
88 TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
89 " for attr '", attr.name(), "'");
90
91 // Does the value satisfy the minimum constraint in the AttrDef?
92 if (attr.has_minimum()) {
93 if (attr.type() == "int") {
94 if (attr_value.i() < attr.minimum()) {
95 return errors::InvalidArgument(
96 "Value for attr '", attr.name(), "' of ", attr_value.i(),
97 " must be at least minimum ", attr.minimum());
98 }
99 } else {
100 int length = -1;
101 if (attr.type() == "list(string)") {
102 length = attr_value.list().s_size();
103 } else if (attr.type() == "list(int)") {
104 length = attr_value.list().i_size();
105 } else if (attr.type() == "list(float)") {
106 length = attr_value.list().f_size();
107 } else if (attr.type() == "list(bool)") {
108 length = attr_value.list().b_size();
109 } else if (attr.type() == "list(type)") {
110 length = attr_value.list().type_size();
111 } else if (attr.type() == "list(shape)") {
112 length = attr_value.list().shape_size();
113 } else if (attr.type() == "list(tensor)") {
114 length = attr_value.list().tensor_size();
115 }
116 if (length < attr.minimum()) {
117 return errors::InvalidArgument(
118 "Length for attr '", attr.name(), "' of ", length,
119 " must be at least minimum ", attr.minimum());
120 }
121 }
122 }
123
124 // Does the value satisfy the allowed_value constraint in the AttrDef?
125 if (attr.has_allowed_values()) {
126 if (attr.type() == "type") {
127 TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
128 } else if (attr.type() == "list(type)") {
129 for (int dt : attr_value.list().type()) {
130 TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
131 }
132 } else if (attr.type() == "string") {
133 TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
134 } else if (attr.type() == "list(string)") {
135 for (const string& str : attr_value.list().s()) {
136 TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
137 }
138 } else {
139 return errors::Unimplemented(
140 "Support for allowed_values not implemented for type ", attr.type());
141 }
142 }
143 return Status::OK();
144 }
145
FindAttr(StringPiece name,const OpDef & op_def)146 const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
147 for (int i = 0; i < op_def.attr_size(); ++i) {
148 if (op_def.attr(i).name() == name) {
149 return &op_def.attr(i);
150 }
151 }
152 return nullptr;
153 }
154
FindAttrMutable(StringPiece name,OpDef * op_def)155 OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
156 for (int i = 0; i < op_def->attr_size(); ++i) {
157 if (op_def->attr(i).name() == name) {
158 return op_def->mutable_attr(i);
159 }
160 }
161 return nullptr;
162 }
163
FindInputArg(StringPiece name,const OpDef & op_def)164 const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
165 for (int i = 0; i < op_def.input_arg_size(); ++i) {
166 if (op_def.input_arg(i).name() == name) {
167 return &op_def.input_arg(i);
168 }
169 }
170 return nullptr;
171 }
172
// Returns an InvalidArgument error from the enclosing function when EXPR is
// false. The message is the concatenation of the variadic arguments followed
// by a short dump of `op_def`, which must be in scope at the expansion site.
// EXPR is evaluated exactly once.
#define VALIDATE(EXPR, ...)                                        \
  do {                                                             \
    if (EXPR) break;                                               \
    return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ",    \
                                   ProtoShortDebugString(op_def)); \
  } while (false)
180
ValidateArg(const OpDef::ArgDef & arg,const OpDef & op_def,bool output,std::set<string> * names)181 static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
182 bool output, std::set<string>* names) {
183 const string suffix = strings::StrCat(
184 output ? " for output '" : " for input '", arg.name(), "'");
185 VALIDATE(gtl::InsertIfNotPresent(names, arg.name()),
186 "Duplicate name: ", arg.name());
187 VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);
188
189 if (!arg.number_attr().empty()) {
190 const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
191 VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
192 suffix);
193 VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
194 suffix, " has type ", attr->type(), " != int");
195 VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
196 suffix, " must have minimum");
197 VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
198 suffix, " must have minimum >= 0");
199 VALIDATE(arg.type_list_attr().empty(),
200 "Can't have both number_attr and type_list_attr", suffix);
201 VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
202 (!arg.type_attr().empty() ? 1 : 0) ==
203 1,
204 "Exactly one of type, type_attr must be set", suffix);
205 } else {
206 const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
207 (!arg.type_attr().empty() ? 1 : 0) +
208 (!arg.type_list_attr().empty() ? 1 : 0);
209 VALIDATE(num_type_fields == 1,
210 "Exactly one of type, type_attr, type_list_attr must be set",
211 suffix);
212 }
213
214 if (!arg.type_attr().empty()) {
215 const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
216 VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
217 suffix);
218 VALIDATE(attr->type() == "type", "Attr '", attr->name(),
219 "' used as type_attr", suffix, " has type ", attr->type(),
220 " != type");
221 } else if (!arg.type_list_attr().empty()) {
222 const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
223 VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
224 suffix);
225 VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
226 "' used as type_list_attr", suffix, " has type ", attr->type(),
227 " != list(type)");
228 } else {
229 // All argument types should be non-reference types at this point.
230 // ArgDef.is_ref is set to true for reference arguments.
231 VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
232 DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
233 }
234
235 return Status::OK();
236 }
237
ValidateOpDef(const OpDef & op_def)238 Status ValidateOpDef(const OpDef& op_def) {
239 using ::tensorflow::strings::Scanner;
240
241 if (!StringPiece(op_def.name()).starts_with("_")) {
242 VALIDATE(Scanner(op_def.name())
243 .One(Scanner::UPPERLETTER)
244 .Any(Scanner::LETTER_DIGIT)
245 .Eos()
246 .GetResult(),
247 "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
248 }
249
250 std::set<string> names; // for detecting duplicate names
251 for (const auto& attr : op_def.attr()) {
252 // Validate name
253 VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()),
254 "Duplicate name: ", attr.name());
255 DataType dt;
256 VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
257 attr.name(), " that matches a data type");
258
259 // Validate type
260 StringPiece type(attr.type());
261 bool is_list = type.Consume("list(");
262 bool found = false;
263 for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
264 "tensor", "func"}) {
265 if (type.Consume(valid)) {
266 found = true;
267 break;
268 }
269 }
270 VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
271 "'");
272 if (is_list) {
273 VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ",
274 attr.name(), "'s type ", attr.type());
275 }
276 VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
277 attr.name(), "'s type ", attr.type());
278
279 // Validate minimum
280 if (attr.has_minimum()) {
281 VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
282 "' has minimum for unsupported type ", attr.type());
283 if (is_list) {
284 VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
285 "' with list type must have a non-negative minimum, not ",
286 attr.minimum());
287 }
288 } else {
289 VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
290 "' with has_minimum = false but minimum ", attr.minimum(),
291 " not equal to default of 0");
292 }
293
294 // Validate allowed_values
295 if (attr.has_allowed_values()) {
296 const string list_type =
297 is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
298 TF_RETURN_WITH_CONTEXT_IF_ERROR(
299 AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
300 attr.name(), "' in Op '", op_def.name(), "'");
301 }
302
303 // Validate default_value (after we have validated the rest of the attr,
304 // so we can use ValidateAttrValue()).
305 if (attr.has_default_value()) {
306 TF_RETURN_WITH_CONTEXT_IF_ERROR(
307 ValidateAttrValue(attr.default_value(), attr), " in Op '",
308 op_def.name(), "'");
309 }
310 }
311
312 for (const auto& arg : op_def.input_arg()) {
313 TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
314 }
315
316 for (const auto& arg : op_def.output_arg()) {
317 TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
318 }
319
320 return Status::OK();
321 }
322
323 #undef VALIDATE
324
CheckOpDeprecation(const OpDef & op_def,int graph_def_version)325 Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) {
326 if (op_def.has_deprecation()) {
327 const OpDeprecation& dep = op_def.deprecation();
328 if (graph_def_version >= dep.version()) {
329 return errors::Unimplemented(
330 "Op ", op_def.name(), " is not available in GraphDef version ",
331 graph_def_version, ". It has been removed in version ", dep.version(),
332 ". ", dep.explanation(), ".");
333 } else {
334 // Warn only once for each op name, and do it in a threadsafe manner.
335 static mutex mu(LINKER_INITIALIZED);
336 static std::unordered_set<string> warned;
337 bool warn;
338 {
339 mutex_lock lock(mu);
340 warn = warned.insert(op_def.name()).second;
341 }
342 if (warn) {
343 LOG(WARNING) << "Op " << op_def.name() << " is deprecated."
344 << " It will cease to work in GraphDef version "
345 << dep.version() << ". " << dep.explanation() << ".";
346 }
347 }
348 }
349 return Status::OK();
350 }
351
352 namespace {
353
SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args)354 string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
355 string ret;
356 for (const OpDef::ArgDef& arg : args) {
357 if (!ret.empty()) strings::StrAppend(&ret, ", ");
358 strings::StrAppend(&ret, arg.name(), ":");
359 if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
360 if (!arg.number_attr().empty()) {
361 strings::StrAppend(&ret, arg.number_attr(), "*");
362 }
363 if (arg.type() != DT_INVALID) {
364 strings::StrAppend(&ret, DataTypeString(arg.type()));
365 } else {
366 strings::StrAppend(&ret, arg.type_attr());
367 }
368 if (arg.is_ref()) strings::StrAppend(&ret, ")");
369 }
370 return ret;
371 }
372
373 } // namespace
374
SummarizeOpDef(const OpDef & op_def)375 string SummarizeOpDef(const OpDef& op_def) {
376 string ret = strings::StrCat("Op<name=", op_def.name());
377 strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
378 " -> ", SummarizeArgs(op_def.output_arg()));
379 for (int i = 0; i < op_def.attr_size(); ++i) {
380 strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
381 op_def.attr(i).type());
382 if (op_def.attr(i).has_default_value()) {
383 strings::StrAppend(&ret, ",default=",
384 SummarizeAttrValue(op_def.attr(i).default_value()));
385 }
386 if (op_def.attr(i).has_minimum()) {
387 strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
388 }
389 if (op_def.attr(i).has_allowed_values()) {
390 strings::StrAppend(&ret, ",allowed=",
391 SummarizeAttrValue(op_def.attr(i).allowed_values()));
392 }
393 }
394 if (op_def.is_commutative()) {
395 strings::StrAppend(&ret, "; is_commutative=true");
396 }
397 if (op_def.is_aggregate()) {
398 strings::StrAppend(&ret, "; is_aggregate=true");
399 }
400 if (op_def.is_stateful()) {
401 strings::StrAppend(&ret, "; is_stateful=true");
402 }
403 if (op_def.allows_uninitialized_input()) {
404 strings::StrAppend(&ret, "; allows_uninitialized_input=true");
405 }
406 strings::StrAppend(&ret, ">");
407 return ret;
408 }
409
410 namespace {
411
// Returns true if every element of `sub` compares equal (operator==) to some
// element of `super`. Duplicates are not counted, and an empty `sub` is
// always a subset. Runs a linear scan of `super` per element of `sub`.
template <class T>
bool IsSubsetOf(const T& sub, const T& super) {
  for (const auto& wanted : sub) {
    auto it = super.begin();
    while (it != super.end() && !(wanted == *it)) ++it;
    if (it == super.end()) return false;
  }
  return true;
}
427
MoreRestrictive(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)428 bool MoreRestrictive(const OpDef::AttrDef& old_attr,
429 const OpDef::AttrDef& new_attr) {
430 // Anything -> no restriction : not more restrictive.
431 if (!new_attr.has_allowed_values()) return false;
432 // No restriction -> restriction : more restrictive.
433 if (!old_attr.has_allowed_values()) return true;
434 // If anything that was previously allowed is no longer allowed:
435 // more restrictive.
436 if (!IsSubsetOf(old_attr.allowed_values().list().type(),
437 new_attr.allowed_values().list().type())) {
438 return true;
439 }
440 if (!IsSubsetOf(old_attr.allowed_values().list().s(),
441 new_attr.allowed_values().list().s())) {
442 return true;
443 }
444 return false;
445 }
446
AllowedStr(const OpDef::AttrDef & attr)447 string AllowedStr(const OpDef::AttrDef& attr) {
448 if (!attr.has_allowed_values()) return "no restriction";
449 return SummarizeAttrValue(attr.allowed_values());
450 }
451
DefaultAttrStr(const OpDef::AttrDef & attr)452 string DefaultAttrStr(const OpDef::AttrDef& attr) {
453 if (!attr.has_default_value()) return "no default";
454 return SummarizeAttrValue(attr.default_value());
455 }
456
HigherMinimum(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)457 bool HigherMinimum(const OpDef::AttrDef& old_attr,
458 const OpDef::AttrDef& new_attr) {
459 // Anything -> no restriction : not more restrictive.
460 if (!new_attr.has_minimum()) return false;
461 // No restriction -> restriction : more restrictive.
462 if (!old_attr.has_minimum()) return true;
463 // If anything that was previously allowed is no longer allowed:
464 // more restrictive.
465 return new_attr.minimum() > old_attr.minimum();
466 }
467
MinStr(const OpDef::AttrDef & attr)468 string MinStr(const OpDef::AttrDef& attr) {
469 if (!attr.has_minimum()) return "no minimum";
470 return strings::StrCat(attr.minimum());
471 }
472
473 typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap;
FillAttrMap(const OpDef & op_def,AttrMap * attr_map)474 void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) {
475 for (const auto& attr : op_def.attr()) {
476 (*attr_map)[attr.name()] = &attr;
477 }
478 }
479
480 // Add a comma to *s every call but the first (*add_comma should be
481 // initialized to false).
AddComma(string * s,bool * add_comma)482 void AddComma(string* s, bool* add_comma) {
483 if (*add_comma) {
484 strings::StrAppend(s, ", ");
485 } else {
486 *add_comma = true;
487 }
488 }
489
490 // Will add the `name` from arg if name is true.
AddName(string * s,bool name,const OpDef::ArgDef & arg)491 void AddName(string* s, bool name, const OpDef::ArgDef& arg) {
492 if (name) {
493 strings::StrAppend(s, arg.name(), ":");
494 }
495 }
496
497 // Compute a signature for either inputs or outputs that will be the
498 // same for both the old and new OpDef if they are compatible. We
499 // assume that new_attrs is a superset of old_attrs, and that any attr
500 // in the difference has a default. Our strategy is to make a list of
501 // types, where the types are things like:
502 // * "int32", "float", etc.,
503 // * "T" for some attr "T" in old_attrs, or
504 // * "N * type" for "N" either some attr in old_attrs.
505 //
506 // We get the types by either using the attrs in args if they are in
507 // old_attrs, or substituting the default value from new_attrs.
ComputeArgSignature(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const AttrMap & old_attrs,const AttrMap & new_attrs,std::vector<bool> * ref,bool names)508 string ComputeArgSignature(
509 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
510 const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref,
511 bool names) {
512 string s;
513 bool add_comma = false;
514 for (const OpDef::ArgDef& arg : args) {
515 if (!arg.type_list_attr().empty()) {
516 const OpDef::AttrDef* old_attr =
517 gtl::FindPtrOrNull(old_attrs, arg.type_list_attr());
518 if (old_attr) {
519 // Both old and new have the list(type) attr, so can use it directly.
520 AddComma(&s, &add_comma);
521 AddName(&s, names, arg);
522 strings::StrAppend(&s, arg.type_list_attr());
523 ref->push_back(arg.is_ref());
524 } else {
525 // Missing the list(type) attr in the old, so use the default
526 // value for the attr from new instead.
527 const OpDef::AttrDef* new_attr =
528 gtl::FindPtrOrNull(new_attrs, arg.type_list_attr());
529 const auto& type_list = new_attr->default_value().list().type();
530 if (type_list.empty()) continue;
531 for (int i = 0; i < type_list.size(); ++i) {
532 AddComma(&s, &add_comma);
533 AddName(&s, names, arg);
534 strings::StrAppend(
535 &s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
536 ref->push_back(arg.is_ref());
537 }
538 }
539 } else {
540 int num = 1; // How many input/outputs does this represent?
541 string type; // What is the type of this arg?
542 AddName(&type, names, arg);
543 if (!arg.number_attr().empty()) {
544 // N * type case.
545 const OpDef::AttrDef* old_attr =
546 gtl::FindPtrOrNull(old_attrs, arg.number_attr());
547 if (old_attr) {
548 // Both old and new have the number attr, so can use it directly.
549 strings::StrAppend(&type, arg.number_attr(), " * ");
550 } else {
551 // Missing the number attr in the old, so use the default
552 // value for the attr from new instead.
553 const OpDef::AttrDef* new_attr =
554 gtl::FindPtrOrNull(new_attrs, arg.number_attr());
555 num = new_attr->default_value().i();
556 }
557 }
558
559 if (arg.type() != DT_INVALID) {
560 // int32, float, etc. case
561 strings::StrAppend(&type, DataTypeString(arg.type()));
562 } else {
563 const OpDef::AttrDef* old_attr =
564 gtl::FindPtrOrNull(old_attrs, arg.type_attr());
565 if (old_attr) {
566 // Both old and new have the type attr, so can use it directly.
567 strings::StrAppend(&type, arg.type_attr());
568 } else {
569 // Missing the type attr in the old, so use the default
570 // value for the attr from new instead.
571 const OpDef::AttrDef* new_attr =
572 gtl::FindPtrOrNull(new_attrs, arg.type_attr());
573 strings::StrAppend(&type,
574 DataTypeString(new_attr->default_value().type()));
575 }
576 }
577
578 // Record `num` * `type` in the signature.
579 for (int i = 0; i < num; ++i) {
580 AddComma(&s, &add_comma);
581 strings::StrAppend(&s, type);
582 ref->push_back(arg.is_ref());
583 }
584 }
585 }
586
587 return s;
588 }
589
590 } // namespace
591
OpDefCompatible(const OpDef & old_op,const OpDef & new_op)592 Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) {
593 #define VALIDATE(CONDITION, ...) \
594 if (!(CONDITION)) { \
595 return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \
596 "; old: ", SummarizeOpDef(old_op), \
597 "; new: ", SummarizeOpDef(new_op)); \
598 }
599
600 VALIDATE(old_op.name() == new_op.name(), "Name mismatch");
601
602 AttrMap new_attrs, old_attrs;
603 FillAttrMap(old_op, &old_attrs);
604 FillAttrMap(new_op, &new_attrs);
605 for (const auto& old_attr : old_op.attr()) {
606 const OpDef::AttrDef* new_attr =
607 gtl::FindPtrOrNull(new_attrs, old_attr.name());
608 VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed");
609 VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(),
610 "' changed type '", old_attr.type(), "' -> '", new_attr->type(),
611 "'");
612 VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '", old_attr.name(),
613 "' has a stricter set of allowed values; from ",
614 AllowedStr(old_attr), " to ", AllowedStr(*new_attr));
615 VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(),
616 "' has a higher minimum; from ", MinStr(old_attr), " to ",
617 MinStr(*new_attr));
618 }
619
620 for (const auto& new_attr : new_op.attr()) {
621 const OpDef::AttrDef* old_attr =
622 gtl::FindPtrOrNull(old_attrs, new_attr.name());
623 VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '",
624 new_attr.name(), "' added without default");
625 }
626
627 std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref;
628 const string old_in_sig = ComputeArgSignature(
629 old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */);
630 const string new_in_sig = ComputeArgSignature(
631 new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */);
632 VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig,
633 "' vs. '", new_in_sig, "'");
634 VALIDATE(old_in_ref.size() == new_in_ref.size(), // Should not happen
635 "Unexpected change in input ref lists.");
636 for (int i = 0; i < old_in_ref.size(); ++i) {
637 // Allowed to remove "ref" from an input (or leave it unchanged).
638 VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i,
639 " changed from non-ref to ref");
640 }
641
642 const string old_out_sig =
643 ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs,
644 &old_out_ref, true /* names */);
645 const string new_out_sig =
646 ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs,
647 &new_out_ref, true /* names */);
648 VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '",
649 old_out_sig, "' vs. '", new_out_sig, "'");
650 VALIDATE(old_out_ref.size() == new_out_ref.size(), // Should not happen
651 "Unexpected change in output ref lists");
652 for (int i = 0; i < old_out_ref.size(); ++i) {
653 // Allowed to add "ref" to an output (or leave it unchanged).
654 VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i,
655 " changed from ref to non-ref");
656 }
657
658 return Status::OK();
659 }
660
OpDefAddedDefaultsUnchanged(const OpDef & old_op,const OpDef & penultimate_op,const OpDef & new_op)661 Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
662 const OpDef& penultimate_op,
663 const OpDef& new_op) {
664 AttrMap new_attrs, old_attrs;
665 FillAttrMap(old_op, &old_attrs);
666 FillAttrMap(new_op, &new_attrs);
667
668 for (const auto& penultimate_attr : penultimate_op.attr()) {
669 const OpDef::AttrDef* old_attr =
670 gtl::FindPtrOrNull(old_attrs, penultimate_attr.name());
671 if (old_attr != nullptr) continue; // attr wasn't added
672 const OpDef::AttrDef* new_attr =
673 gtl::FindPtrOrNull(new_attrs, penultimate_attr.name());
674
675 // These shouldn't happen if the op passed OpDefCompatible().
676 if (new_attr == nullptr) {
677 return errors::InvalidArgument("Missing attr '", penultimate_attr.name(),
678 "' in op: ", SummarizeOpDef(new_op));
679 }
680 if (!penultimate_attr.has_default_value() ||
681 !new_attr->has_default_value()) {
682 return errors::InvalidArgument("Missing default for attr '",
683 penultimate_attr.name(),
684 "' in op: ", SummarizeOpDef(new_op));
685 }
686
687 // Actually test that the attr's default value hasn't changed.
688 if (!AreAttrValuesEqual(penultimate_attr.default_value(),
689 new_attr->default_value())) {
690 return errors::InvalidArgument(
691 "Can't change default value for attr '", penultimate_attr.name(),
692 "' from ", SummarizeAttrValue(penultimate_attr.default_value()),
693 " in op: ", SummarizeOpDef(new_op));
694 }
695 }
696
697 return Status::OK();
698 }
699
OpDefAttrDefaultsUnchanged(const OpDef & old_op,const OpDef & new_op)700 Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
701 AttrMap new_attrs, old_attrs;
702 FillAttrMap(old_op, &old_attrs);
703 FillAttrMap(new_op, &new_attrs);
704
705 for (const auto& old_attr : old_op.attr()) {
706 const OpDef::AttrDef* new_attr =
707 gtl::FindPtrOrNull(new_attrs, old_attr.name());
708 if (new_attr == nullptr) continue;
709 if (old_attr.has_default_value() != new_attr->has_default_value()) {
710 return errors::InvalidArgument(
711 "Attr '", old_attr.name(), "' has added/removed it's default; ",
712 "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
713 }
714 if (old_attr.has_default_value() &&
715 !AreAttrValuesEqual(old_attr.default_value(),
716 new_attr->default_value())) {
717 return errors::InvalidArgument(
718 "Attr '", old_attr.name(), "' has changed it's default value; ",
719 "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
720 }
721 }
722
723 return Status::OK();
724 }
725
RemoveNonDeprecationDescriptionsFromOpDef(OpDef * op_def)726 void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
727 for (int i = 0; i < op_def->input_arg_size(); ++i) {
728 op_def->mutable_input_arg(i)->clear_description();
729 }
730 for (int i = 0; i < op_def->output_arg_size(); ++i) {
731 op_def->mutable_output_arg(i)->clear_description();
732 }
733 for (int i = 0; i < op_def->attr_size(); ++i) {
734 op_def->mutable_attr(i)->clear_description();
735 }
736 op_def->clear_summary();
737 op_def->clear_description();
738 }
739
RemoveDescriptionsFromOpDef(OpDef * op_def)740 void RemoveDescriptionsFromOpDef(OpDef* op_def) {
741 RemoveNonDeprecationDescriptionsFromOpDef(op_def);
742 if (op_def->has_deprecation()) {
743 op_def->mutable_deprecation()->clear_explanation();
744 }
745 }
746
RemoveDescriptionsFromOpList(OpList * op_list)747 void RemoveDescriptionsFromOpList(OpList* op_list) {
748 for (int i = 0; i < op_list->op_size(); ++i) {
749 OpDef* op_def = op_list->mutable_op(i);
750 RemoveDescriptionsFromOpDef(op_def);
751 }
752 }
753
AttrDefEqual(const OpDef::AttrDef & a1,const OpDef::AttrDef & a2)754 bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) {
755 #ifndef TENSORFLOW_LITE_PROTOS
756 DCHECK_EQ(7, a1.GetDescriptor()->field_count())
757 << "Please modify these equality and hash functions to reflect the "
758 "changes to the AttrDef protobuf";
759 #endif // TENSORFLOW_LITE_PROTOS
760
761 if (a1.name() != a2.name()) return false;
762 if (a1.type() != a2.type()) return false;
763 if (a1.description() != a2.description()) return false;
764 if (a1.has_minimum() != a2.has_minimum()) return false;
765 if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false;
766 if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false;
767 if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values()))
768 return false;
769 return true;
770 }
771
AttrDefHash(const OpDef::AttrDef & a)772 uint64 AttrDefHash(const OpDef::AttrDef& a) {
773 uint64 h = Hash64(a.name());
774 h = Hash64(a.type().data(), a.type().size(), h);
775 h = Hash64Combine(AttrValueHash(a.default_value()), h);
776 h = Hash64(a.description().data(), a.description().size(), h);
777 h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h);
778 h = Hash64Combine(static_cast<uint64>(a.minimum()), h);
779 h = Hash64Combine(AttrValueHash(a.allowed_values()), h);
780 return h;
781 }
782
RepeatedAttrDefEqual(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a1,const protobuf::RepeatedPtrField<OpDef::AttrDef> & a2)783 bool RepeatedAttrDefEqual(
784 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1,
785 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2) {
786 std::unordered_map<string, const OpDef::AttrDef*> a1_set;
787 for (const OpDef::AttrDef& def : a1) {
788 DCHECK(a1_set.find(def.name()) == a1_set.end())
789 << "AttrDef names must be unique, but '" << def.name()
790 << "' appears more than once";
791 a1_set[def.name()] = &def;
792 }
793 for (const OpDef::AttrDef& def : a2) {
794 auto iter = a1_set.find(def.name());
795 if (iter == a1_set.end()) return false;
796 if (!AttrDefEqual(*iter->second, def)) return false;
797 a1_set.erase(iter);
798 }
799 if (!a1_set.empty()) return false;
800 return true;
801 }
802
RepeatedAttrDefHash(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a)803 uint64 RepeatedAttrDefHash(
804 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) {
805 // Insert AttrDefs into map to deterministically sort by name
806 std::map<string, const OpDef::AttrDef*> a_set;
807 for (const OpDef::AttrDef& def : a) {
808 a_set[def.name()] = &def;
809 }
810 // Iterate and combines hashes of keys and values
811 uint64 h = 0xDECAFCAFFE;
812 for (const auto& pair : a_set) {
813 h = Hash64(pair.first.data(), pair.first.size(), h);
814 h = Hash64Combine(AttrDefHash(*pair.second), h);
815 }
816 return h;
817 }
818
OpDefEqual(const OpDef & o1,const OpDef & o2)819 bool OpDefEqual(const OpDef& o1, const OpDef& o2) {
820 // attr order doesn't matter.
821 // Compare it separately here instead of serializing below.
822 if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false;
823
824 // Clear attr field, serialize, and compare serialized strings
825 OpDef o1_copy = o1;
826 OpDef o2_copy = o2;
827 o1_copy.clear_attr();
828 o2_copy.clear_attr();
829 string s1, s2;
830 SerializeToStringDeterministic(o1_copy, &s1);
831 SerializeToStringDeterministic(o2_copy, &s2);
832 if (s1 != s2) return false;
833 return true;
834 }
835
OpDefHash(const OpDef & o)836 uint64 OpDefHash(const OpDef& o) {
837 uint64 h = RepeatedAttrDefHash(o.attr());
838 OpDef o_copy = o;
839 o_copy.clear_attr();
840 string s;
841 SerializeToStringDeterministic(o_copy, &s);
842 return Hash64(s.data(), s.size(), h);
843 }
844
845 } // namespace tensorflow
846