1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_util.h"
17 
#include <algorithm>
#include <atomic>
#include <limits>
#include <unordered_map>
#include <utility>
#include <vector>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op_def.pb.h"
28 #include "tensorflow/core/framework/op_def_util.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/scanner.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/strcat.h"
40 #include "tensorflow/core/platform/stringpiece.h"
41 #include "tensorflow/core/platform/types.h"
42 
43 namespace tensorflow {
44 
// Name of the node attr that carries colocation information.
const char* const kColocationAttrName{"_class"};
// Prefix used for colocation group names.
const char* const kColocationGroupPrefix{"loc:@"};
47 
AttrSlice()48 AttrSlice::AttrSlice() : ndef_(nullptr) {
49   static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
50   attrs_ = kEmptyAttrValueMap;
51 }
52 
AttrSlice(const NodeDef & node_def)53 AttrSlice::AttrSlice(const NodeDef& node_def)
54     : ndef_(&node_def), attrs_(&ndef_->attr()) {}
55 
AttrSlice(const AttrValueMap * a)56 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
57 
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)58 string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
59   string ret;
60 
61   // We sort the attrs so the output is deterministic.
62   std::vector<string> attr_names;
63   attr_names.reserve(attrs.size());
64   for (const auto& attr : attrs) {
65     attr_names.push_back(attr.first);
66   }
67   std::sort(attr_names.begin(), attr_names.end());
68   bool first = true;
69   for (const string& attr_name : attr_names) {
70     if (!first) strings::StrAppend(&ret, ", ");
71     first = false;
72     strings::StrAppend(&ret, attr_name, "=",
73                        SummarizeAttrValue(*attrs.Find(attr_name)));
74   }
75 
76   // Consider the device to be a final attr with name "_device".
77   if (!device.empty()) {
78     if (!first) strings::StrAppend(&ret, ", ");
79     first = false;
80     strings::StrAppend(&ret, "_device=\"", device, "\"");
81   }
82   return ret;
83 }
84 
SummarizeNode() const85 string AttrSlice::SummarizeNode() const {
86   return ndef_ ? SummarizeNodeDef(*ndef_)
87                : strings::StrCat(
88                      "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
89 }
90 
DebugString() const91 string AttrSlice::DebugString() const {
92   std::vector<string> attr_key_vals;
93   attr_key_vals.reserve(attrs_->size());
94   for (const auto& it : *this) {
95     const string& name = it.first;
96     const AttrValue& attr_value = it.second;
97     attr_key_vals.push_back(
98         absl::StrCat(name, "=", SummarizeAttrValue(attr_value)));
99   }
100   return absl::StrJoin(attr_key_vals, ", ");
101 }
102 
SummarizeNodeDef(const NodeDef & node_def,int max_inputs_in_summary)103 string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) {
104   string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
105                                " = ", node_def.op(), "[");
106   strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
107   strings::StrAppend(&ret, "](");
108 
109   // Output inputs, including control inputs, verbatim.
110   bool first = true;
111   for (const string& input : node_def.input()) {
112     if (!first) strings::StrAppend(&ret, ", ");
113     first = false;
114     if (max_inputs_in_summary-- == 0) {
115       strings::StrAppend(&ret, "...");
116       break;
117     }
118     strings::StrAppend(&ret, input);
119   }
120   strings::StrAppend(&ret, ")");
121   return ret;
122 }
123 
SummarizeAttrs(const NodeDef & node_def)124 string SummarizeAttrs(const NodeDef& node_def) {
125   return SummarizeAttrsHelper(node_def, node_def.device());
126 }
127 
FormatNodeDefForError(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)128 string FormatNodeDefForError(
129     StringPiece node_name, bool has_experimental_debug_info,
130     const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
131   return !has_experimental_debug_info ||
132                  experimental_debug_info.original_node_names().empty()
133              ? errors::FormatNodeNameForError(string(node_name))
134              : errors::FormatNodeNamesForError(
135                    experimental_debug_info.original_node_names());
136 }
137 
FormatNodeDefForError(const NodeDef & node_def)138 string FormatNodeDefForError(const NodeDef& node_def) {
139   return FormatNodeDefForError(node_def.name(),
140                                node_def.has_experimental_debug_info(),
141                                node_def.experimental_debug_info());
142 }
143 
Find(StringPiece attr_name) const144 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
145   // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
146   // requires that the keys used for lookups have type 'const string&'. Because
147   // this method takes a StringPiece, it is necessary to allocate a temporary
148   // string, copy attr_name to it, and then use that temporary string for the
149   // lookup. This causes an excessive number of short-lived allocations, and for
150   // large graphs, this can be a significant cost.
151   //
152   // Because most nodes have a small number of attributes, a simple linear scan
153   // is generally more efficient than a hashed lookup.  If google::protobuf::Map
154   // changes so that it supports efficient lookups using StringPiece instead of
155   // const string&, then this code could be changed to use attrs_->find() again.
156 
157   for (const auto& attr : *attrs_) {
158     if (attr.first == attr_name) {
159       return &attr.second;
160     }
161   }
162   return nullptr;
163 }
164 
FindByString(const string & attr_name) const165 const AttrValue* AttrSlice::FindByString(const string& attr_name) const {
166   auto iter = attrs_->find(attr_name);
167   if (iter != attrs_->end()) {
168     return &iter->second;
169   } else {
170     return nullptr;
171   }
172 }
173 
Find(StringPiece attr_name,const AttrValue ** attr_value) const174 Status AttrSlice::Find(StringPiece attr_name,
175                        const AttrValue** attr_value) const {
176   *attr_value = Find(attr_name);
177   if (*attr_value != nullptr) {
178     return Status::OK();
179   }
180   Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
181   // Skip AttachDef for internal attrs since it is a little bit
182   // expensive and it is common for them to correctly not be included
183   // in a NodeDef.
184   if (!absl::StartsWith(attr_name, "_") && ndef_ != nullptr) {
185     s = AttachDef(s, *ndef_);
186   }
187   return s;
188 }
189 
EqualAttrs(AttrSlice other,Scratch * scratch) const190 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
191   if (size() != other.size()) return false;
192 
193   for (const auto& attr : *other.attrs_) {
194     auto iter = attrs_->find(attr.first);
195     if (iter == attrs_->end()) return false;
196     // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
197     // TensorProto is a nonunique representation of Tensor.  This bug will go
198     // away once AttrSlice switches over to NodeInfo.
199     iter->second.SerializeToString(&scratch->a);
200     attr.second.SerializeToString(&scratch->b);
201     if (scratch->a != scratch->b) return false;
202   }
203   return true;
204 }
205 
// DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) defines two
// GetNodeAttr() overloads.
//
// Scalar overload: GetNodeAttr(attrs, attr_name, TYPE* value)
//   * Looks up `attr_name`; a missing attr yields the NotFound status from
//     AttrSlice::Find.
//   * Checks that the attr has type ATTR_TYPE.
//   * Binds `v` to attr_value->FIELD().
//   * Runs the injected code, then stores CAST into *value.
//
// Vector overload: GetNodeAttr(attrs, attr_name, std::vector<TYPE>* value)
//   * Same, but checks for type "list(ATTR_TYPE)".
//   * For each element `v` of list().FIELD(), runs the injected code, then
//     appends CAST via value->APPEND_OP(...).
//   * New elements are appended after whatever `value` already holds.
//   * If the injected code returns early, the elements appended so far
//     remain in `value`.
//
// The trailing ... lets the caller inject per-value validation code. That
// code may refer to `v`, `attr_name` and `attr_value`, and may `return` an
// error Status. Use just ; if no additional validation code is needed.
// Comments must stay outside the macro body: a `//` comment on a
// backslash-continued line would swallow the next line.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     TYPE* value) {                                           \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));             \
    const auto& v = attr_value->FIELD();                                      \
    __VA_ARGS__;                                                              \
    *value = CAST;                                                            \
    return Status::OK();                                                      \
  }                                                                           \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     std::vector<TYPE>* value) {                              \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
    value->reserve(attr_value->list().FIELD().size());                        \
    for (const auto& v : attr_value->list().FIELD()) {                        \
      __VA_ARGS__;                                                            \
      value->APPEND_OP(CAST);                                                 \
    }                                                                         \
    return Status::OK();                                                      \
  }
231 
// DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) is the
// non-failing counterpart of DEFINE_GET_ATTR. It defines two TryGetNodeAttr()
// overloads, for TYPE* and for std::vector<TYPE>*.
//
// Both overloads return false, with no Status, when:
//   * the attr is missing, or
//   * the attr does not have type ATTR_TYPE (resp. "list(ATTR_TYPE)").
// They return true after storing (resp. appending) the value(s).
//
// The injected validation code (the trailing ...) sees `v`, `attr_name` and
// `attr_value`, and must `return false;` to reject a value.
//
// As with DEFINE_GET_ATTR:
//   * the vector overload appends to the existing contents of `value`;
//   * an early return leaves any elements appended so far in place.
#define DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      TYPE* value) {                                      \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                  \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    const auto& v = attr_value->FIELD();                                  \
    __VA_ARGS__;                                                          \
    *value = CAST;                                                        \
    return true;                                                          \
  }                                                                       \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      std::vector<TYPE>* value) {                         \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");      \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    value->reserve(attr_value->list().FIELD().size());                    \
    for (const auto& v : attr_value->list().FIELD()) {                    \
      __VA_ARGS__;                                                        \
      value->APPEND_OP(CAST);                                             \
    }                                                                     \
    return true;                                                          \
  }
265 DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
266 DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
267 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
268 DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
269 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
270 DEFINE_TRY_GET_ATTR(int64, i, "int", emplace_back, v, ;)
271 DEFINE_GET_ATTR(
272     int32, i, "int", emplace_back, static_cast<int32>(v),
273     if (static_cast<int64>(static_cast<int32>(v)) != v) {
274       return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
275                                      " out of range for an int32");
276     })
277 DEFINE_TRY_GET_ATTR(
278     int32, i, "int", emplace_back, static_cast<int32>(v),
279     if (static_cast<int64>(static_cast<int32>(v)) != v) {
280       static int log_counter = 0;
281       if (log_counter < 10) {
282         log_counter++;
283         LOG(WARNING) << "Attr " << attr_name << " has value " << v
284                      << " out of range for an int32";
285       }
286       return false;
287     })
288 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
289 DEFINE_TRY_GET_ATTR(float, f, "float", emplace_back, v, ;)
290 // std::vector<bool> specialization does not have emplace_back until
291 // c++14, so we have to use push_back (see
292 // http://en.cppreference.com/w/cpp/container/vector/emplace_back)
293 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
294 DEFINE_TRY_GET_ATTR(bool, b, "bool", push_back, v, ;)
295 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
296                 ;)
297 DEFINE_TRY_GET_ATTR(DataType, type, "type", emplace_back,
298                     static_cast<DataType>(v),
299                     ;)
300 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
301 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
302                 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
303 DEFINE_TRY_GET_ATTR(
304     TensorShape, shape, "shape", emplace_back, TensorShape(v),
305     if (!TensorShape::IsValidShape(v).ok()) {
306       static int log_counter = 0;
307       if (log_counter < 10) {
308         log_counter++;
309         LOG(WARNING) << "Attr " << attr_name << " has invalid shape value "
310                      << v.DebugString();
311       }
312       return false;
313     })
314 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
315                 PartialTensorShape(v),
316                 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
317 DEFINE_GET_ATTR(
318     Tensor, tensor, "tensor", emplace_back, t, Tensor t; if (!t.FromProto(v)) {
319       return errors::InvalidArgument("Attr ", attr_name, " has value ",
320                                      v.ShortDebugString(),
321                                      " that can't be converted to a Tensor");
322     })
323 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
324 #undef DEFINE_GET_ATTR
325 
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)326 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
327   return node_def.attr().find(string(attr_name)) != node_def.attr().end();
328 }
329 
330 static const string& kEmptyString = *new string();
331 
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)332 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
333   const AttrValue* attr_value = attrs.Find(attr_name);
334   if (attr_value == nullptr) {
335     return kEmptyString;
336   }
337   Status s = AttrValueHasType(*attr_value, "string");
338   if (!s.ok()) {
339     return kEmptyString;
340   }
341   return attr_value->s();
342 }
343 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const string * > * value)344 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
345                     std::vector<const string*>* value) {
346   const AttrValue* attr_value = attrs.Find(attr_name);
347   if (attr_value == nullptr) {
348     return false;
349   }
350   Status s = AttrValueHasType(*attr_value, "list(string)");
351   if (!s.ok()) {
352     return false;
353   }
354   value->reserve(attr_value->list().s().size());
355   for (const auto& v : attr_value->list().s()) {
356     value->push_back(&v);
357   }
358   return true;
359 }
360 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const TensorShapeProto * > * value)361 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
362                     std::vector<const TensorShapeProto*>* value) {
363   const AttrValue* attr_value = attrs.Find(attr_name);
364   if (attr_value == nullptr) {
365     return false;
366   }
367   Status s = AttrValueHasType(*attr_value, "list(shape)");
368   if (!s.ok()) {
369     return false;
370   }
371   value->reserve(attr_value->list().shape().size());
372   for (const auto& v : attr_value->list().shape()) {
373     value->push_back(&v);
374   }
375   return true;
376 }
377 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)378 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
379                    DataTypeVector* value) {
380   const AttrValue* attr_value;
381   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
382   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
383   for (const auto& v : attr_value->list().type()) {
384     value->push_back(static_cast<DataType>(v));
385   }
386   return Status::OK();
387 }
388 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)389 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
390                    const TensorProto** value) {
391   const AttrValue* attr_value;
392   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
393   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
394   *value = &attr_value->tensor();
395   return Status::OK();
396 }
397 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)398 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
399                     const TensorProto** value) {
400   const AttrValue* attr_value = attrs.Find(attr_name);
401   if (attr_value == nullptr) {
402     return false;
403   }
404   Status s = AttrValueHasType(*attr_value, "tensor");
405   if (!s.ok()) {
406     return false;
407   }
408   *value = &attr_value->tensor();
409   return true;
410 }
411 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)412 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
413                    const NameAttrList** value) {
414   const AttrValue* attr_value;
415   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
416   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
417   *value = &attr_value->func();
418   return Status::OK();
419 }
420 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)421 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
422                     const NameAttrList** value) {
423   const AttrValue* attr_value = attrs.Find(attr_name);
424   if (attr_value == nullptr) {
425     return false;
426   }
427   Status s = AttrValueHasType(*attr_value, "func");
428   if (!s.ok()) {
429     return false;
430   }
431   *value = &attr_value->func();
432   return true;
433 }
434 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,Padding * value)435 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
436                    Padding* value) {
437   string str_value;
438   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_name, &str_value));
439   return GetPaddingFromString(str_value, value);
440 }
441 
442 namespace {  // Helper for InOutTypesForNode().
443 
444 template <class NodeDefOrAttrSlice>
AddArgToSig(const NodeDefOrAttrSlice & node_or_attrs,const OpDef::ArgDef & arg_def,DataTypeVector * sig)445 Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
446                    const OpDef::ArgDef& arg_def, DataTypeVector* sig) {
447   const int original_size = sig->size();
448   if (!arg_def.number_attr().empty()) {
449     // Same type repeated "repeats" times.
450     int64 repeats = -1;
451     TF_RETURN_IF_ERROR(
452         GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
453     // We can't handle outputs that are larger than int32 sizes.
454     if (static_cast<int64>(static_cast<int32>(repeats)) != repeats) {
455       return errors::InvalidArgument("Number of outputs is too big: ", repeats);
456     }
457     if (repeats < 0) {
458       return errors::InvalidArgument("Value for number_attr() ", repeats,
459                                      " < 0");
460     }
461 
462     if (!arg_def.type_attr().empty()) {
463       DataType dtype;
464       TF_RETURN_IF_ERROR(
465           GetNodeAttr(node_or_attrs, arg_def.type_attr(), &dtype));
466       for (int i = 0; i < repeats; ++i) {
467         sig->push_back(dtype);
468       }
469     } else if (arg_def.type() != DT_INVALID) {
470       for (int i = 0; i < repeats; ++i) {
471         sig->push_back(arg_def.type());
472       }
473     } else {
474       return errors::InvalidArgument("Missing type or type_attr field in ",
475                                      arg_def.ShortDebugString());
476     }
477   } else if (!arg_def.type_attr().empty()) {
478     const AttrValue* attr_value;
479     TF_RETURN_IF_ERROR(
480         AttrSlice(node_or_attrs).Find(arg_def.type_attr(), &attr_value));
481     sig->push_back(attr_value->type());
482   } else if (!arg_def.type_list_attr().empty()) {
483     const AttrValue* attr_value;
484     TF_RETURN_IF_ERROR(
485         AttrSlice(node_or_attrs).Find(arg_def.type_list_attr(), &attr_value));
486     for (int dtype : attr_value->list().type()) {
487       sig->push_back(static_cast<DataType>(dtype));
488     }
489   } else if (arg_def.type() != DT_INVALID) {
490     sig->push_back(arg_def.type());
491   } else {
492     return errors::InvalidArgument("No type fields in ",
493                                    arg_def.ShortDebugString());
494   }
495   if (arg_def.is_ref()) {
496     // For all types that were added by this function call, make them refs.
497     for (size_t i = original_size; i < sig->size(); ++i) {
498       if (IsRefType((*sig)[i])) {
499         return errors::InvalidArgument(
500             "Requested reference to a reference type: ",
501             arg_def.ShortDebugString());
502       }
503       (*sig)[i] = MakeRefType((*sig)[i]);
504     }
505   }
506   return Status::OK();
507 }
508 
509 }  // namespace
510 
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)511 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
512                         int input_port, DataType* input_type) {
513   DataTypeVector input_types;
514   for (const auto& arg : op_def.input_arg()) {
515     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
516     int input_types_size = input_types.size();
517     if (input_types_size > input_port) {
518       const DataType dtype = input_types[input_port];
519       *input_type = dtype;
520       return Status::OK();
521     }
522   }
523   return errors::InvalidArgument("Input ", input_port, " not found for node ",
524                                  node_def.name());
525 }
526 
InputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs)527 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
528                          DataTypeVector* inputs) {
529   for (const auto& arg : op_def.input_arg()) {
530     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
531   }
532   return Status::OK();
533 }
534 
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)535 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
536                          int output_port, DataType* output_type) {
537   DataTypeVector output_types;
538   for (const auto& arg : op_def.output_arg()) {
539     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
540     int output_types_size = output_types.size();
541     if (output_types_size > output_port) {
542       const DataType dtype = output_types[output_port];
543       *output_type = dtype;
544       return Status::OK();
545     }
546   }
547   return errors::InvalidArgument("Output ", output_port, " not found for node ",
548                                  node_def.name());
549 }
550 
OutputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * outputs)551 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
552                           DataTypeVector* outputs) {
553   for (const auto& arg : op_def.output_arg()) {
554     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
555   }
556   return Status::OK();
557 }
558 
OutputTypesForNode(const AttrSlice & attrs,const OpDef & op_def,DataTypeVector * outputs)559 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
560                           DataTypeVector* outputs) {
561   for (const auto& arg : op_def.output_arg()) {
562     TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs));
563   }
564   return Status::OK();
565 }
566 
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)567 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
568                          DataTypeVector* inputs, DataTypeVector* outputs) {
569   TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
570   return OutputTypesForNode(node_def, op_def, outputs);
571 }
572 
NumOutputsForNode(const NodeDef & node_def,const OpDef & op_def,int * num_outputs)573 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
574                          int* num_outputs) {
575   DataTypeVector outputs;
576   TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
577   *num_outputs = outputs.size();
578   return Status::OK();
579 }
580 
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)581 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
582   if (node_def.op() != op_def.name()) {
583     return errors::InvalidArgument(
584         "NodeDef op '", node_def.op(), "' does not match ",
585         SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
586   }
587 
588   bool seen_control = false;
589   size_t num_inputs = 0;
590   // TODO(josh11b): Unify the input field validation.
591   for (const string& input : node_def.input()) {
592     if (absl::StartsWith(input, "^")) {
593       seen_control = true;
594       if (input.find(':') != string::npos) {
595         return errors::InvalidArgument("Control input '", input,
596                                        "' must not have ':' in NodeDef: ",
597                                        FormatNodeDefForError(node_def));
598       }
599     } else if (seen_control) {
600       return errors::InvalidArgument("Non-control input '", input,
601                                      "' after control input in NodeDef: ",
602                                      FormatNodeDefForError(node_def));
603     } else {
604       ++num_inputs;
605     }
606   }
607 
608   std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
609   for (const auto& attr : op_def.attr()) {
610     if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
611       return errors::InvalidArgument("OpDef has duplicate attr name '",
612                                      attr.name(),
613                                      "': ", SummarizeOpDef(op_def));
614     }
615   }
616   for (const auto& attr : node_def.attr()) {
617     // Allow internal optional attributes with names starting with "_".
618     if (absl::StartsWith(attr.first, "_")) {
619       continue;
620     }
621     auto iter = op_attrs.find(attr.first);
622     if (iter == op_attrs.end()) {
623       // A common cause of this error is that TensorFlow has made a
624       // backwards-compatible change to the NodeDef (e.g., adding a
625       // new attr with a default value), but the binary consuming the
626       // NodeDef does not know about the new attribute; the solution
627       // in these cases is to ensure that the binary consuming the
628       // NodeDef is built with a version of TensorFlow no earlier than
629       // the binary producing it.
630       return errors::InvalidArgument(
631           "NodeDef mentions attr '", attr.first, "' not in ",
632           SummarizeOpDef(op_def),
633           "; NodeDef: ", FormatNodeDefForError(node_def),
634           ". (Check whether your GraphDef-interpreting binary is up to date "
635           "with your GraphDef-generating binary.).");
636     }
637     // If attr value is placeholder, do not check it.
638     if (attr.second.placeholder().empty()) {
639       TF_RETURN_WITH_CONTEXT_IF_ERROR(
640           ValidateAttrValue(attr.second, *iter->second),
641           "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
642           SummarizeOpDef(op_def));
643     }
644     // Keep track of which attr names have (not) been found in the NodeDef.
645     op_attrs.erase(iter);
646   }
647 
648   // Were all attrs in the OpDef found in the NodeDef?
649   if (!op_attrs.empty()) {
650     string attrs;
651     for (const auto& attr_pair : op_attrs) {
652       if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
653       strings::StrAppend(&attrs, attr_pair.first);
654     }
655     return errors::InvalidArgument(
656         "NodeDef missing attr", op_attrs.size() == 1 ? " '" : "s '", attrs,
657         "' from ", SummarizeOpDef(op_def),
658         "; NodeDef: ", FormatNodeDefForError(node_def));
659   }
660 
661   // Validate the number of inputs.
662   DataTypeVector inputs, outputs;
663   TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
664 
665   if (num_inputs != inputs.size()) {
666     return errors::InvalidArgument(
667         "NodeDef expected inputs '", DataTypeVectorString(inputs),
668         "' do not match ", num_inputs, " inputs specified; ",
669         SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
670   }
671 
672   return Status::OK();
673 }
674 
675 namespace {  // Helpers for NameRangesForNode()
676 
ComputeArgRange(const AttrSlice & attrs,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)677 Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def,
678                        const OpDef& op_def, int* num) {
679   if (!arg_def.number_attr().empty()) {
680     // Same type repeated "num" times.
681     return GetNodeAttr(attrs, arg_def.number_attr(), num);
682   } else if (!arg_def.type_list_attr().empty()) {
683     const AttrValue* attr_value;
684     TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value));
685     *num = attr_value->list().type_size();
686   } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
687     *num = 1;
688   } else {
689     return errors::InvalidArgument(
690         "Argument '", arg_def.name(),
691         "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
692   }
693   return Status::OK();
694 }
695 
NameRangesHelper(const AttrSlice & attrs,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)696 Status NameRangesHelper(const AttrSlice& attrs,
697                         const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
698                         const OpDef& op_def, NameRangeMap* result) {
699   int start = 0;
700   int num;
701   for (const auto& arg : args) {
702     TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num));
703     (*result)[arg.name()] = std::make_pair(start, start + num);
704     start += num;
705   }
706   return Status::OK();
707 }
708 
709 }  // namespace
710 
NameRangesForNode(const AttrSlice & attrs,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)711 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
712                          NameRangeMap* inputs, NameRangeMap* outputs) {
713   if (inputs != nullptr) {
714     TF_RETURN_IF_ERROR(
715         NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs));
716   }
717   if (outputs != nullptr) {
718     return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs);
719   }
720   return Status::OK();
721 }
722 
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)723 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
724   for (const auto& attr_def : op_def.attr()) {
725     AttrSlice attrs(*node_def);
726     if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
727       AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
728     }
729   }
730 }
731 
732 namespace {
733 
734 using ::tensorflow::tstring;
735 using ::tensorflow::strings::Scanner;
736 
IsValidNodeName(StringPiece sp)737 bool IsValidNodeName(StringPiece sp) {
738   Scanner scanner(sp);
739   scanner.One(Scanner::LETTER_DIGIT_DOT)
740       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
741 
742   while (true) {
743     if (!scanner.GetResult())  // Some error in previous iteration.
744       return false;
745     if (scanner.empty())  // No error, but nothing left, good.
746       return true;
747 
748     // Absorb another name/namespace, starting with a '>'
749     scanner.One(Scanner::RANGLE)
750         .One(Scanner::LETTER_DIGIT_DOT)
751         .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
752   }
753 }
754 
IsValidDataInputName(StringPiece sp)755 bool IsValidDataInputName(StringPiece sp) {
756   // Data inputs are op_name, op_name:0, or op_name:12345.
757   Scanner scan(sp);
758   scan.One(Scanner::LETTER_DIGIT_DOT)
759       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
760 
761   while (true) {
762     if (!scan.GetResult())  // Some error in previous iteration.
763       return false;
764     if (scan.empty())  // No error, but nothing left, good.
765       return true;
766 
767     if (scan.Peek() == ':') {  // Absorb identifier after the colon
768       scan.OneLiteral(":");
769       if (scan.Peek() == '0') {
770         scan.OneLiteral("0");  // :0
771       } else {
772         scan.Many(Scanner::DIGIT);  // :[1-9][0-9]*
773       }
774     } else {
775       // Absorb another name/namespace, starting with a '>'
776       scan.One(Scanner::RANGLE)
777           .One(Scanner::LETTER_DIGIT_DOT)
778           .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
779     }
780   }
781 }
782 
IsValidControlInputName(StringPiece sp)783 bool IsValidControlInputName(StringPiece sp) {
784   Scanner scan(sp);
785   scan.OneLiteral("^")
786       .One(Scanner::LETTER_DIGIT_DOT)
787       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
788 
789   while (true) {
790     if (!scan.GetResult())  // Some error in previous iteration.
791       return false;
792     if (scan.empty())  // No error, but nothing left, good.
793       return true;
794 
795     // Absorb another name/namespace, starting with a '>'
796     scan.One(Scanner::RANGLE)
797         .One(Scanner::LETTER_DIGIT_DOT)
798         .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
799   }
800 }
801 
802 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
803 
804 }  // namespace
805 
ValidateOpInput(const string & input_name,bool * is_control_input)806 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
807   *is_control_input = false;
808   if (IsValidDataInputName(input_name)) {
809     return Status::OK();
810   } else if (IsValidControlInputName(input_name)) {
811     *is_control_input = true;
812     return Status::OK();
813   } else {
814     return errors::InvalidArgument("Illegal op input name '", input_name, "'");
815   }
816 }
817 
ValidateNodeName(const string & node_name)818 Status ValidateNodeName(const string& node_name) {
819   if (IsValidNodeName(node_name)) {
820     return Status::OK();
821   } else {
822     return errors::InvalidArgument("Illegal op name '", node_name, "'");
823   }
824 }
825 
ValidateExternalNodeDefSyntax(const NodeDef & node_def)826 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
827   Status s = ValidateNodeName(node_def.name());
828   if (!s.ok()) {
829     return AttachDef(s, node_def);
830   }
831   bool in_control_inputs = false;
832   for (const string& input_name : node_def.input()) {
833     bool is_control_input;
834     s = ValidateOpInput(input_name, &is_control_input);
835     if (!s.ok()) {
836       return AttachDef(s, node_def);
837     }
838 
839     if (in_control_inputs && !is_control_input) {
840       return AttachDef(errors::InvalidArgument(
841                            "All control inputs must follow all data inputs"),
842                        node_def);
843     }
844     in_control_inputs = is_control_input;
845   }
846   return Status::OK();
847 }
848 
// Returns a copy of `status` with " [[<node>]]" appended to its message.
// Normally <node> is FormatNodeDefForError(node_def). If the message already
// contains "{{node " and `allow_multiple_formatted_node` is false, only the
// bare node name is appended instead (presumably to avoid a second formatted
// node tag in one message).
Status AttachDef(const Status& status, const NodeDef& node_def,
                 bool allow_multiple_formatted_node) {
  const bool use_plain_name =
      !allow_multiple_formatted_node &&
      status.error_message().find("{{node ") != string::npos;
  const string node_error =
      use_plain_name ? node_def.name() : FormatNodeDefForError(node_def);

  Status annotated(status);
  errors::AppendToMessage(&annotated,
                          strings::StrCat(" [[", node_error, "]]"));
  return annotated;
}
862 
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)863 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
864   node_def->mutable_attr()->insert(
865       AttrValueMap::value_type(string(name), value));
866 }
867 
AddNodeAttr(StringPiece name,AttrValue && value,NodeDef * node_def)868 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
869   (*node_def->mutable_attr())[string(name)] = std::move(value);
870 }
871 
872 #define ADD_NODE_ATTR(T)                                           \
873   void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
874     AttrValue attr_value;                                          \
875     SetAttrValue(value, &attr_value);                              \
876     AddNodeAttr(name, attr_value, node_def);                       \
877   }
878 ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char *)879 ADD_NODE_ATTR(const char*)
880 ADD_NODE_ATTR(int32)
881 ADD_NODE_ATTR(int64)
882 ADD_NODE_ATTR(float)
883 ADD_NODE_ATTR(double)
884 ADD_NODE_ATTR(bool)
885 ADD_NODE_ATTR(DataType)
886 ADD_NODE_ATTR(const PartialTensorShape&)
887 ADD_NODE_ATTR(const Tensor&)
888 ADD_NODE_ATTR(const TensorProto&)
889 ADD_NODE_ATTR(const NameAttrList&)
890 ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
891 ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
892 ADD_NODE_ATTR(gtl::ArraySlice<string>)
893 ADD_NODE_ATTR(gtl::ArraySlice<int32>)
894 ADD_NODE_ATTR(gtl::ArraySlice<int64>)
895 ADD_NODE_ATTR(gtl::ArraySlice<float>)
896 ADD_NODE_ATTR(gtl::ArraySlice<bool>)
897 ADD_NODE_ATTR(const std::vector<bool>&)
898 ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
899 ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
900 ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
901 ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
902 ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
903 ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
904 #undef ADD_NODE_ATTR
905 
906 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
907   map->insert(AttrValueMap::value_type(string(name), value));
908 }
909 
910 #define ADD_ATTR(T)                                            \
911   void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
912     AttrValue attr_value;                                      \
913     SetAttrValue(value, &attr_value);                          \
914     AddAttr(name, attr_value, map);                            \
915   }
ADD_ATTR(bool)916 ADD_ATTR(bool)
917 #undef ADD_ATTR
918 
919 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
920                                 NodeDef* node_def, bool uniquify_frame_name) {
921   node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
922 
923   // Update frame name to avoid multiple LoopCond nodes in one frame.
924   if (uniquify_frame_name &&
925       (node_def->op() == "Enter" || node_def->op() == "RefEnter")) {
926     string frame_name;
927     TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
928     AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
929     frame_name = strings::StrCat(prefix, frame_name, suffix);
930     attr.set_s(frame_name);
931   }
932 
933   return Status::OK();
934 }
935 
MaybeAddPrefixToColocationConstraints(const std::unordered_set<string> & match,StringPiece prefix,NodeDef * node_def)936 Status MaybeAddPrefixToColocationConstraints(
937     const std::unordered_set<string>& match, StringPiece prefix,
938     NodeDef* node_def) {
939   auto attr = node_def->mutable_attr()->find(kColocationAttrName);
940   if (attr == node_def->mutable_attr()->end()) {
941     return Status::OK();
942   }
943   auto constraints_list = attr->second.mutable_list();
944   auto constraints_size = constraints_list->s_size();
945   for (size_t i = 0; i < constraints_size; ++i) {
946     StringPiece original(constraints_list->s(i));
947     if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
948       if (match.find(string(original)) != match.end()) {
949         (*constraints_list->mutable_s(i)) =
950             strings::StrCat(kColocationGroupPrefix, prefix, original);
951       }
952     }
953   }
954   return Status::OK();
955 }
956 
957 }  // namespace tensorflow
958