1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_util.h"
17 
18 #include <algorithm>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/graph.pb_text.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_def.pb_text.h"
27 #include "tensorflow/core/framework/op_def_util.h"
28 #include "tensorflow/core/framework/tensor.pb_text.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/graph/graph.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/gtl/map_util.h"
33 #include "tensorflow/core/lib/strings/scanner.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/platform/protobuf.h"
37 
38 namespace tensorflow {
39 
// Name of the node attr that carries colocation information ("_class").
// Being "_"-prefixed, it is an internal attr that ValidateNodeDef() does not
// require the OpDef to declare.
const char* const kColocationAttrName = "_class";
// Prefix used for colocation-group entries ("loc:@"); presumably followed by
// the name of the node to colocate with — confirm against callers.
const char* const kColocationGroupPrefix = "loc:@";
42 
AttrSlice()43 AttrSlice::AttrSlice() : ndef_(nullptr) {
44   static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
45   attrs_ = kEmptyAttrValueMap;
46 }
47 
AttrSlice(const NodeDef & node_def)48 AttrSlice::AttrSlice(const NodeDef& node_def)
49     : ndef_(&node_def), attrs_(&ndef_->attr()) {}
50 
AttrSlice(const AttrValueMap * a)51 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
52 
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)53 static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
54   string ret;
55 
56   // We sort the attrs so the output is deterministic.
57   std::vector<string> attr_names;
58   attr_names.reserve(attrs.size());
59   for (const auto& attr : attrs) {
60     attr_names.push_back(attr.first);
61   }
62   std::sort(attr_names.begin(), attr_names.end());
63   bool first = true;
64   for (const string& attr_name : attr_names) {
65     if (!first) strings::StrAppend(&ret, ", ");
66     first = false;
67     strings::StrAppend(&ret, attr_name, "=",
68                        SummarizeAttrValue(*attrs.Find(attr_name)));
69   }
70 
71   // Consider the device to be a final attr with name "_device".
72   if (!device.empty()) {
73     if (!first) strings::StrAppend(&ret, ", ");
74     first = false;
75     strings::StrAppend(&ret, "_device=\"", device, "\"");
76   }
77   return ret;
78 }
79 
SummarizeNode() const80 string AttrSlice::SummarizeNode() const {
81   return ndef_ ? SummarizeNodeDef(*ndef_)
82                : strings::StrCat(
83                      "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
84 }
85 
SummarizeNode(const Node & node)86 string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); }
87 
SummarizeNodeDef(const NodeDef & node_def)88 string SummarizeNodeDef(const NodeDef& node_def) {
89   string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
90                                " = ", node_def.op(), "[");
91   strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
92   strings::StrAppend(&ret, "](");
93 
94   // Output inputs, including control inputs, verbatim.
95   bool first = true;
96   for (const string& input : node_def.input()) {
97     if (!first) strings::StrAppend(&ret, ", ");
98     first = false;
99     strings::StrAppend(&ret, input);
100   }
101   strings::StrAppend(&ret, ")");
102   return ret;
103 }
104 
SummarizeAttrs(const NodeDef & node_def)105 string SummarizeAttrs(const NodeDef& node_def) {
106   return SummarizeAttrsHelper(node_def, node_def.device());
107 }
108 
FormatNodeForError(const NodeDebugInfo & debug_info)109 string FormatNodeForError(const NodeDebugInfo& debug_info) {
110   return debug_info.original_node_names.empty()
111              ? errors::FormatNodeNameForError(debug_info.name)
112              : errors::FormatNodeNamesForError(debug_info.original_node_names);
113 }
114 
FormatNodeForError(const Node & node)115 string FormatNodeForError(const Node& node) {
116   return FormatNodeForError(NodeDebugInfo(node));
117 }
118 
FormatNodeDefForError(const NodeDef & node_def)119 string FormatNodeDefForError(const NodeDef& node_def) {
120   return FormatNodeForError(NodeDebugInfo(node_def));
121 }
122 
GetMergedOriginalNodeNames(const NodeDebugInfo & from,const NodeDebugInfo & to,std::set<string> * names)123 void GetMergedOriginalNodeNames(const NodeDebugInfo& from,
124                                 const NodeDebugInfo& to,
125                                 std::set<string>* names) {
126   if (!from.original_node_names.empty()) {
127     names->insert(from.original_node_names.begin(),
128                   from.original_node_names.end());
129   } else {
130     names->insert(from.name);
131   }
132   names->insert(to.original_node_names.begin(), to.original_node_names.end());
133 }
134 
MergeDebugInfo(const NodeDebugInfo & from,Node * to)135 void MergeDebugInfo(const NodeDebugInfo& from, Node* to) {
136   std::set<string> names;
137   GetMergedOriginalNodeNames(from, NodeDebugInfo(*to), &names);
138   to->set_original_node_names({names.begin(), names.end()});
139 }
140 
MergeDebugInfo(const NodeDebugInfo & from,NodeDef * to)141 void MergeDebugInfo(const NodeDebugInfo& from, NodeDef* to) {
142   std::set<string> names;
143   GetMergedOriginalNodeNames(from, NodeDebugInfo(*to), &names);
144   to->mutable_experimental_debug_info()->clear_original_node_names();
145   if (!names.empty()) {
146     *to->mutable_experimental_debug_info()->mutable_original_node_names() = {
147         names.begin(), names.end()};
148   }
149 }
150 
MergeDebugInfo(const NodeDef & from,NodeDef * to)151 void MergeDebugInfo(const NodeDef& from, NodeDef* to) {
152   MergeDebugInfo(NodeDebugInfo(from), to);
153 }
154 
Find(StringPiece attr_name) const155 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
156   // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
157   // requires that the keys used for lookups have type 'const string&'. Because
158   // this method takes a StringPiece, it is necessary to allocate a temporary
159   // string, copy attr_name to it, and then use that temporary string for the
160   // lookup. This causes an excessive number of short-lived allocations, and for
161   // large graphs, this can be a significant cost.
162   //
163   // Because most nodes have a small number of attributes, a simple linear scan
164   // is generally more efficient than a hashed lookup.  If google::protobuf::Map
165   // changes so that it supports efficient lookups using StringPiece instead of
166   // const string&, then this code could be changed to use attrs_->find() again.
167 
168   for (const auto& attr : *attrs_) {
169     if (attr.first == attr_name) {
170       return &attr.second;
171     }
172   }
173   return nullptr;
174 }
175 
Find(StringPiece attr_name,const AttrValue ** attr_value) const176 Status AttrSlice::Find(StringPiece attr_name,
177                        const AttrValue** attr_value) const {
178   *attr_value = Find(attr_name);
179   if (*attr_value != nullptr) {
180     return Status::OK();
181   }
182   Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
183   // Skip AttachDef for internal attrs since it is a little bit
184   // expensive and it is common for them to correctly not be included
185   // in a NodeDef.
186   if (!str_util::StartsWith(attr_name, "_") && ndef_ != nullptr) {
187     s = AttachDef(s, *ndef_);
188   }
189   return s;
190 }
191 
EqualAttrs(AttrSlice other,Scratch * scratch) const192 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
193   if (size() != other.size()) return false;
194 
195   for (const auto& attr : *other.attrs_) {
196     auto iter = attrs_->find(attr.first);
197     if (iter == attrs_->end()) return false;
198     // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
199     // TensorProto is a nonunique representation of Tensor.  This bug will go
200     // away once AttrSlice switches over to NodeInfo.
201     iter->second.SerializeToString(&scratch->a);
202     attr.second.SerializeToString(&scratch->b);
203     if (scratch->a != scratch->b) return false;
204   }
205   return true;
206 }
207 
// DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) defines two
// GetNodeAttr() overloads for C++ type TYPE:
//  * a scalar form that reads attr `attr_name` into *value, and
//  * a list form that appends every element of a "list(ATTR_TYPE)" attr to
//    *value (the vector is NOT cleared first).
// Missing attrs and type mismatches are reported through the error Status
// returned by AttrSlice::Find() / AttrValueHasType().
//   FIELD     - AttrValue (and AttrValue::ListValue) accessor, e.g. `s`, `i`.
//   ATTR_TYPE - attr type string passed to AttrValueHasType(), e.g. "int".
//   APPEND_OP - std::vector member used to append in the list form.
//   CAST      - expression turning the raw field value `v` into a TYPE.
// The ... is to allow the caller to inject some value validation code.  Use
// just ; if no additional validation code is needed. The injected code sees
// `v`, `attr_name` and `attr_value`, and may return an error Status.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     TYPE* value) {                                           \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));             \
    const auto& v = attr_value->FIELD();                                      \
    __VA_ARGS__;                                                              \
    *value = CAST;                                                            \
    return Status::OK();                                                      \
  }                                                                           \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     std::vector<TYPE>* value) {                              \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
    for (const auto& v : attr_value->list().FIELD()) {                        \
      __VA_ARGS__;                                                            \
      value->APPEND_OP(CAST);                                                 \
    }                                                                         \
    return Status::OK();                                                      \
  }

// DEFINE_GET_ATTR_SIMPLE takes the same arguments as DEFINE_GET_ATTR but
// defines GetNodeAttrSimple() overloads that return false (instead of an error
// Status) when the attr is missing or has the wrong type, and true on success.
// Because these functions return bool, injected code must not return a Status.
#define DEFINE_GET_ATTR_SIMPLE(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,      \
                         TYPE* value) {                                      \
    const AttrValue* attr_value = attrs.Find(attr_name);                     \
    if (attr_value == nullptr) {                                             \
      return false;                                                          \
    }                                                                        \
    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                     \
    if (!s.ok()) {                                                           \
      return false;                                                          \
    }                                                                        \
    const auto& v = attr_value->FIELD();                                     \
    __VA_ARGS__;                                                             \
    *value = CAST;                                                           \
    return true;                                                             \
  }                                                                          \
  bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,      \
                         std::vector<TYPE>* value) {                         \
    const AttrValue* attr_value = attrs.Find(attr_name);                     \
    if (attr_value == nullptr) {                                             \
      return false;                                                          \
    }                                                                        \
    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");         \
    if (!s.ok()) {                                                           \
      return false;                                                          \
    }                                                                        \
    for (const auto& v : attr_value->list().FIELD()) {                       \
      __VA_ARGS__;                                                           \
      value->APPEND_OP(CAST);                                                \
    }                                                                        \
    return true;                                                             \
  }
265 
266 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
267 DEFINE_GET_ATTR_SIMPLE(string, s, "string", emplace_back, v, ;)
268 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
269 DEFINE_GET_ATTR(int32, i, "int", emplace_back, static_cast<int32>(v),
270                 if (static_cast<int64>(static_cast<int32>(v)) != v) {
271                   return errors::InvalidArgument("Attr ", attr_name,
272                                                  " has value ", v,
273                                                  " out of range for an int32");
274                 })
275 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
276 // std::vector<bool> specialization does not have emplace_back until
277 // c++14, so we have to use push_back (see
278 // http://en.cppreference.com/w/cpp/container/vector/emplace_back)
279 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
280 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
281                 ;)
282 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
283 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
284                 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
285 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
286                 PartialTensorShape(v),
287                 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
288 DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t;
289                 if (!t.FromProto(v)) {
290                   return errors::InvalidArgument(
291                       "Attr ", attr_name, " has value ",
292                       ProtoShortDebugString(v),
293                       " that can't be converted to a Tensor");
294                 })
295 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
296 #undef DEFINE_GET_ATTR
297 
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)298 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
299   return node_def.attr().find(string(attr_name)) != node_def.attr().end();
300 }
301 
// Shared empty string returned by GetNodeAttrString() when the attr is missing
// or is not a string. It is heap-allocated and never freed, so references to
// it remain valid for the life of the process (presumably this also avoids
// static-destruction-order issues — confirm).
static const string& kEmptyString = *new string();
303 
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)304 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
305   const AttrValue* attr_value = attrs.Find(attr_name);
306   if (attr_value == nullptr) {
307     return kEmptyString;
308   }
309   Status s = AttrValueHasType(*attr_value, "string");
310   if (!s.ok()) {
311     return kEmptyString;
312   }
313   return attr_value->s();
314 }
315 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)316 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
317                    DataTypeVector* value) {
318   const AttrValue* attr_value;
319   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
320   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
321   for (const auto& v : attr_value->list().type()) {
322     value->push_back(static_cast<DataType>(v));
323   }
324   return Status::OK();
325 }
326 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)327 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
328                    const TensorProto** value) {
329   const AttrValue* attr_value;
330   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
331   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
332   *value = &attr_value->tensor();
333   return Status::OK();
334 }
335 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)336 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
337                    const NameAttrList** value) {
338   const AttrValue* attr_value;
339   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
340   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
341   *value = &attr_value->func();
342   return Status::OK();
343 }
344 
345 namespace {  // Helper for InOutTypesForNode().
346 
AddArgToSig(const NodeDef & node_def,const OpDef::ArgDef & arg_def,DataTypeVector * sig)347 Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
348                    DataTypeVector* sig) {
349   const int original_size = sig->size();
350   if (!arg_def.number_attr().empty()) {
351     // Same type repeated "repeats" times.
352     int32 repeats = -1;
353     TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.number_attr(), &repeats));
354     if (repeats < 0) {
355       return errors::InvalidArgument("Value for number_attr() ", repeats,
356                                      " < 0");
357     }
358 
359     if (!arg_def.type_attr().empty()) {
360       DataType dtype;
361       TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.type_attr(), &dtype));
362       for (int i = 0; i < repeats; ++i) {
363         sig->push_back(dtype);
364       }
365     } else if (arg_def.type() != DT_INVALID) {
366       for (int i = 0; i < repeats; ++i) {
367         sig->push_back(arg_def.type());
368       }
369     } else {
370       return errors::InvalidArgument("Missing type or type_attr field in ",
371                                      ProtoShortDebugString(arg_def));
372     }
373   } else if (!arg_def.type_attr().empty()) {
374     const AttrValue* attr_value;
375     TF_RETURN_IF_ERROR(
376         AttrSlice(node_def).Find(arg_def.type_attr(), &attr_value));
377     sig->push_back(attr_value->type());
378   } else if (!arg_def.type_list_attr().empty()) {
379     const AttrValue* attr_value;
380     TF_RETURN_IF_ERROR(
381         AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
382     for (int dtype : attr_value->list().type()) {
383       sig->push_back(static_cast<DataType>(dtype));
384     }
385   } else if (arg_def.type() != DT_INVALID) {
386     sig->push_back(arg_def.type());
387   } else {
388     return errors::InvalidArgument("No type fields in ",
389                                    ProtoShortDebugString(arg_def));
390   }
391   if (arg_def.is_ref()) {
392     // For all types that were added by this function call, make them refs.
393     for (size_t i = original_size; i < sig->size(); ++i) {
394       (*sig)[i] = MakeRefType((*sig)[i]);
395     }
396   }
397   return Status::OK();
398 }
399 
400 }  // namespace
401 
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)402 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
403                         int input_port, DataType* input_type) {
404   DataTypeVector input_types;
405   for (const auto& arg : op_def.input_arg()) {
406     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
407     if (input_types.size() > input_port) {
408       const DataType dtype = input_types[input_port];
409       *input_type = dtype;
410       return Status::OK();
411     }
412   }
413   return errors::InvalidArgument("Input ", input_port, " not found for node ",
414                                  node_def.name());
415 }
416 
InputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs)417 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
418                          DataTypeVector* inputs) {
419   for (const auto& arg : op_def.input_arg()) {
420     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
421   }
422   return Status::OK();
423 }
424 
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)425 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
426                          int output_port, DataType* output_type) {
427   DataTypeVector output_types;
428   for (const auto& arg : op_def.output_arg()) {
429     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
430     if (output_types.size() > output_port) {
431       const DataType dtype = output_types[output_port];
432       *output_type = dtype;
433       return Status::OK();
434     }
435   }
436   return errors::InvalidArgument("Output ", output_port, " not found for node ",
437                                  node_def.name());
438 }
439 
OutputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * outputs)440 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
441                           DataTypeVector* outputs) {
442   for (const auto& arg : op_def.output_arg()) {
443     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
444   }
445   return Status::OK();
446 }
447 
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)448 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
449                          DataTypeVector* inputs, DataTypeVector* outputs) {
450   TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
451   return OutputTypesForNode(node_def, op_def, outputs);
452 }
453 
NumOutputsForNode(const NodeDef & node_def,const OpDef & op_def,int * num_outputs)454 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
455                          int* num_outputs) {
456   DataTypeVector outputs;
457   TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
458   *num_outputs = outputs.size();
459   return Status::OK();
460 }
461 
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)462 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
463   if (node_def.op() != op_def.name()) {
464     return errors::InvalidArgument(
465         "NodeDef op '", node_def.op(), "' does not match ",
466         SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
467   }
468 
469   bool seen_control = false;
470   size_t num_inputs = 0;
471   // TODO(josh11b): Unify the input field validation.
472   for (const string& input : node_def.input()) {
473     if (str_util::StartsWith(input, "^")) {
474       seen_control = true;
475       if (input.find(':') != string::npos) {
476         return errors::InvalidArgument("Control input '", input,
477                                        "' must not have ':' in NodeDef: ",
478                                        FormatNodeDefForError(node_def));
479       }
480     } else if (seen_control) {
481       return errors::InvalidArgument("Non-control input '", input,
482                                      "' after control input in NodeDef: ",
483                                      FormatNodeDefForError(node_def));
484     } else {
485       ++num_inputs;
486     }
487   }
488 
489   std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
490   for (const auto& attr : op_def.attr()) {
491     if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
492       return errors::InvalidArgument("OpDef has duplicate attr name '",
493                                      attr.name(),
494                                      "': ", SummarizeOpDef(op_def));
495     }
496   }
497   for (const auto& attr : node_def.attr()) {
498     // Allow internal optional attributes with names starting with "_".
499     if (str_util::StartsWith(attr.first, "_")) {
500       continue;
501     }
502     auto iter = op_attrs.find(attr.first);
503     if (iter == op_attrs.end()) {
504       // A common cause of this error is that TensorFlow has made a
505       // backwards-compatible change to the NodeDef (e.g., adding a
506       // new attr with a default value), but the binary consuming the
507       // NodeDef does not know about the new attribute; the solution
508       // in these cases is to ensure that the binary consuming the
509       // NodeDef is built with a version of TensorFlow no earlier than
510       // the binary producing it.
511       return errors::InvalidArgument(
512           "NodeDef mentions attr '", attr.first, "' not in ",
513           SummarizeOpDef(op_def),
514           "; NodeDef: ", FormatNodeDefForError(node_def),
515           ". (Check whether your GraphDef-interpreting binary is up to date "
516           "with your GraphDef-generating binary.).");
517     }
518     // If attr value is placeholder, do not check it.
519     if (attr.second.placeholder().empty()) {
520       TF_RETURN_WITH_CONTEXT_IF_ERROR(
521           ValidateAttrValue(attr.second, *iter->second),
522           "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
523           SummarizeOpDef(op_def));
524     }
525     // Keep track of which attr names have (not) been found in the NodeDef.
526     op_attrs.erase(iter);
527   }
528 
529   // Were all attrs in the OpDef found in the NodeDef?
530   if (!op_attrs.empty()) {
531     string attrs;
532     for (const auto& attr_pair : op_attrs) {
533       if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
534       strings::StrAppend(&attrs, attr_pair.first);
535     }
536     return errors::InvalidArgument(
537         "NodeDef missing attr", op_attrs.size() == 1 ? " '" : "s '", attrs,
538         "' from ", SummarizeOpDef(op_def),
539         "; NodeDef: ", FormatNodeDefForError(node_def));
540   }
541 
542   // Validate the number of inputs.
543   DataTypeVector inputs, outputs;
544   TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
545 
546   if (num_inputs != inputs.size()) {
547     return errors::InvalidArgument(
548         "NodeDef expected inputs '", DataTypeVectorString(inputs),
549         "' do not match ", num_inputs, " inputs specified; ",
550         SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
551   }
552 
553   return Status::OK();
554 }
555 
556 namespace {  // Helpers for NameRangesForNode()
557 
ComputeArgRange(const NodeDef & node_def,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)558 Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
559                        const OpDef& op_def, int* num) {
560   if (!arg_def.number_attr().empty()) {
561     // Same type repeated "num" times.
562     return GetNodeAttr(node_def, arg_def.number_attr(), num);
563   } else if (!arg_def.type_list_attr().empty()) {
564     const AttrValue* attr_value;
565     TF_RETURN_IF_ERROR(
566         AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
567     *num = attr_value->list().type_size();
568   } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
569     *num = 1;
570   } else {
571     return errors::InvalidArgument(
572         "Argument '", arg_def.name(),
573         "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
574   }
575   return Status::OK();
576 }
577 
NameRangesHelper(const NodeDef & node_def,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)578 Status NameRangesHelper(const NodeDef& node_def,
579                         const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
580                         const OpDef& op_def, NameRangeMap* result) {
581   int start = 0;
582   int num;
583   for (const auto& arg : args) {
584     TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num));
585     (*result)[arg.name()] = std::make_pair(start, start + num);
586     start += num;
587   }
588   return Status::OK();
589 }
590 
591 }  // namespace
592 
NameRangesForNode(const NodeDef & node_def,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)593 Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
594                          NameRangeMap* inputs, NameRangeMap* outputs) {
595   if (inputs != nullptr) {
596     TF_RETURN_IF_ERROR(
597         NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs));
598   }
599   if (outputs != nullptr) {
600     return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs);
601   }
602   return Status::OK();
603 }
604 
NameRangesForNode(const Node & node,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)605 Status NameRangesForNode(const Node& node, const OpDef& op_def,
606                          NameRangeMap* inputs, NameRangeMap* outputs) {
607   return NameRangesForNode(node.def(), op_def, inputs, outputs);
608 }
609 
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)610 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
611   for (const auto& attr_def : op_def.attr()) {
612     AttrSlice attrs(*node_def);
613     if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
614       AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
615     }
616   }
617 }
618 
619 namespace {
620 
621 using ::tensorflow::strings::Scanner;
622 
IsValidOpName(StringPiece sp)623 bool IsValidOpName(StringPiece sp) {
624   return Scanner(sp)
625       .One(Scanner::LETTER_DIGIT_DOT)
626       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
627       .Eos()
628       .GetResult();
629 }
630 
IsValidDataInputName(StringPiece sp)631 bool IsValidDataInputName(StringPiece sp) {
632   // Data inputs are op_name, op_name:0, or op_name:12345.
633   Scanner scan(sp);
634   scan.One(Scanner::LETTER_DIGIT_DOT)
635       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
636   if (scan.Peek() == ':') {
637     scan.OneLiteral(":");
638     if (scan.Peek() == '0') {
639       scan.OneLiteral("0");  // :0
640     } else {
641       scan.Many(Scanner::DIGIT);  // :[1-9][0-9]*
642     }
643   }
644   scan.Eos();
645 
646   return scan.GetResult();
647 }
648 
IsValidControlInputName(StringPiece sp)649 bool IsValidControlInputName(StringPiece sp) {
650   return Scanner(sp)
651       .OneLiteral("^")
652       .One(Scanner::LETTER_DIGIT_DOT)
653       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
654       .Eos()
655       .GetResult();
656 }
657 
658 }  // namespace
659 
ValidateOpInput(const string & input_name,bool * is_control_input)660 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
661   *is_control_input = false;
662   if (IsValidDataInputName(input_name)) {
663     return Status::OK();
664   } else if (IsValidControlInputName(input_name)) {
665     *is_control_input = true;
666     return Status::OK();
667   } else {
668     return errors::InvalidArgument("Illegal op input name '", input_name, "'");
669   }
670 }
671 
ValidateOpName(const string & op_name)672 Status ValidateOpName(const string& op_name) {
673   if (IsValidOpName(op_name)) {
674     return Status::OK();
675   } else {
676     return errors::InvalidArgument("Illegal op name '", op_name, "'");
677   }
678 }
679 
ValidateExternalNodeDefSyntax(const NodeDef & node_def)680 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
681   Status s = ValidateOpName(node_def.name());
682   if (!s.ok()) {
683     return AttachDef(s, node_def);
684   }
685   bool in_control_inputs = false;
686   for (const string& input_name : node_def.input()) {
687     bool is_control_input;
688     s = ValidateOpInput(input_name, &is_control_input);
689     if (!s.ok()) {
690       return AttachDef(s, node_def);
691     }
692 
693     if (in_control_inputs && !is_control_input) {
694       return AttachDef(errors::InvalidArgument(
695                            "All control inputs must follow all data inputs"),
696                        node_def);
697     }
698     in_control_inputs = is_control_input;
699   }
700   return Status::OK();
701 }
702 
// Returns a copy of `status` with " [[<node>]]" appended to its message.
// Normally <node> is FormatNodeDefForError(node_def). If
// `allow_multiple_formatted_node` is false and the message already contains a
// "{{node " reference, only the bare node name is appended instead.
Status AttachDef(const Status& status, const NodeDef& node_def,
                 bool allow_multiple_formatted_node) {
  const bool use_bare_name =
      !allow_multiple_formatted_node &&
      status.error_message().find("{{node ") != string::npos;
  const string node_error =
      use_bare_name ? node_def.name() : FormatNodeDefForError(node_def);

  Status annotated = status;
  errors::AppendToMessage(&annotated,
                          strings::StrCat(" [[", node_error, "]]"));
  return annotated;
}
716 
// Node overload: annotates `status` using the NodeDef returned by node.def().
Status AttachDef(const Status& status, const Node& node,
                 bool allow_multiple_formatted_node) {
  const NodeDef& def = node.def();
  return AttachDef(status, def, allow_multiple_formatted_node);
}
721 
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)722 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
723   node_def->mutable_attr()->insert(
724       AttrValueMap::value_type(string(name), value));
725 }
726 
727 #define ADD_NODE_ATTR(T)                                           \
728   void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
729     AttrValue attr_value;                                          \
730     SetAttrValue(value, &attr_value);                              \
731     AddNodeAttr(name, attr_value, node_def);                       \
732   }
733 ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char *)734 ADD_NODE_ATTR(const char*)
735 ADD_NODE_ATTR(int32)
736 ADD_NODE_ATTR(int64)
737 ADD_NODE_ATTR(float)
738 ADD_NODE_ATTR(double)
739 ADD_NODE_ATTR(bool)
740 ADD_NODE_ATTR(DataType)
741 ADD_NODE_ATTR(const PartialTensorShape&)
742 ADD_NODE_ATTR(const Tensor&)
743 ADD_NODE_ATTR(const TensorProto&)
744 ADD_NODE_ATTR(const NameAttrList&)
745 ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
746 ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
747 ADD_NODE_ATTR(gtl::ArraySlice<string>)
748 ADD_NODE_ATTR(gtl::ArraySlice<int32>)
749 ADD_NODE_ATTR(gtl::ArraySlice<int64>)
750 ADD_NODE_ATTR(gtl::ArraySlice<float>)
751 ADD_NODE_ATTR(gtl::ArraySlice<bool>)
752 ADD_NODE_ATTR(const std::vector<bool>&)
753 ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
754 ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
755 ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
756 ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
757 ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
758 ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
759 #undef ADD_NODE_ATTR
760 
761 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
762   map->insert(AttrValueMap::value_type(string(name), value));
763 }
764 
765 #define ADD_ATTR(T)                                            \
766   void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
767     AttrValue attr_value;                                      \
768     SetAttrValue(value, &attr_value);                          \
769     AddAttr(name, attr_value, map);                            \
770   }
ADD_ATTR(bool)771 ADD_ATTR(bool)
772 #undef ADD_ATTR
773 
774 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
775                                 NodeDef* node_def) {
776   node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
777   if (node_def->op() == "Enter" || node_def->op() == "RefEnter") {
778     string frame_name;
779     TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
780     AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
781     frame_name = strings::StrCat(prefix, frame_name, suffix);
782     attr.set_s(frame_name);
783   }
784   return Status::OK();
785 }
786 
787 }  // namespace tensorflow
788