1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_util.h"
17 
18 #include <algorithm>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/graph.pb_text.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_def.pb_text.h"
27 #include "tensorflow/core/framework/op_def_util.h"
28 #include "tensorflow/core/framework/tensor.pb_text.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/graph/graph.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/gtl/map_util.h"
33 #include "tensorflow/core/lib/strings/scanner.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 
37 namespace tensorflow {
38 
// Name of the NodeDef attr that carries colocation information. Because it
// starts with "_", ValidateNodeDef treats it as an internal attr that need not
// be declared in the OpDef.
const char* const kColocationAttrName = "_class";
// Prefix used for colocation group names (presumably the entries stored in the
// kColocationAttrName attr — confirm against the header/callers).
const char* const kColocationGroupPrefix = "loc:@";
41 
AttrSlice()42 AttrSlice::AttrSlice() : ndef_(nullptr) {
43   static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
44   attrs_ = kEmptyAttrValueMap;
45 }
46 
AttrSlice(const NodeDef & node_def)47 AttrSlice::AttrSlice(const NodeDef& node_def)
48     : ndef_(&node_def), attrs_(&ndef_->attr()) {}
49 
AttrSlice(const AttrValueMap * a)50 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
51 
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)52 static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
53   string ret;
54 
55   // We sort the attrs so the output is deterministic.
56   std::vector<string> attr_names;
57   attr_names.reserve(attrs.size());
58   for (const auto& attr : attrs) {
59     attr_names.push_back(attr.first);
60   }
61   std::sort(attr_names.begin(), attr_names.end());
62   bool first = true;
63   for (const string& attr_name : attr_names) {
64     if (!first) strings::StrAppend(&ret, ", ");
65     first = false;
66     strings::StrAppend(&ret, attr_name, "=",
67                        SummarizeAttrValue(*attrs.Find(attr_name)));
68   }
69 
70   // Consider the device to be a final attr with name "_device".
71   if (!device.empty()) {
72     if (!first) strings::StrAppend(&ret, ", ");
73     first = false;
74     strings::StrAppend(&ret, "_device=\"", device, "\"");
75   }
76   return ret;
77 }
78 
SummarizeNode() const79 string AttrSlice::SummarizeNode() const {
80   return ndef_ ? SummarizeNodeDef(*ndef_)
81                : strings::StrCat(
82                      "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
83 }
84 
SummarizeNode(const Node & node)85 string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); }
86 
SummarizeNodeDef(const NodeDef & node_def)87 string SummarizeNodeDef(const NodeDef& node_def) {
88   string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
89   strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
90   strings::StrAppend(&ret, "](");
91 
92   // Output inputs, including control inputs, verbatim.
93   bool first = true;
94   for (const string& input : node_def.input()) {
95     if (!first) strings::StrAppend(&ret, ", ");
96     first = false;
97     strings::StrAppend(&ret, input);
98   }
99   strings::StrAppend(&ret, ")");
100   return ret;
101 }
102 
Find(StringPiece attr_name) const103 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
104   // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
105   // requires that the keys used for lookups have type 'const string&'. Because
106   // this method takes a StringPiece, it is necessary to allocate a temporary
107   // string, copy attr_name to it, and then use that temporary string for the
108   // lookup. This causes an excessive number of short-lived allocations, and for
109   // large graphs, this can be a significant cost.
110   //
111   // Because most nodes have a small number of attributes, a simple linear scan
112   // is generally more efficient than a hashed lookup.  If google::protobuf::Map
113   // changes so that it supports efficient lookups using StringPiece instead of
114   // const string&, then this code could be changed to use attrs_->find() again.
115 
116   for (const auto& attr : *attrs_) {
117     if (attr.first == attr_name) {
118       return &attr.second;
119     }
120   }
121   return nullptr;
122 }
123 
Find(StringPiece attr_name,const AttrValue ** attr_value) const124 Status AttrSlice::Find(StringPiece attr_name,
125                        const AttrValue** attr_value) const {
126   *attr_value = Find(attr_name);
127   if (*attr_value != nullptr) {
128     return Status::OK();
129   }
130   Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
131   // Skip AttachDef for internal attrs since it is a little bit
132   // expensive and it is common for them to correctly not be included
133   // in a NodeDef.
134   if (!attr_name.starts_with("_") && ndef_ != nullptr) {
135     s = AttachDef(s, *ndef_);
136   }
137   return s;
138 }
139 
EqualAttrs(AttrSlice other,Scratch * scratch) const140 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
141   if (size() != other.size()) return false;
142 
143   for (const auto& attr : *other.attrs_) {
144     auto iter = attrs_->find(attr.first);
145     if (iter == attrs_->end()) return false;
146     // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
147     // TensorProto is a nonunique representation of Tensor.  This bug will go
148     // away once AttrSlice switches over to NodeInfo.
149     iter->second.SerializeToString(&scratch->a);
150     attr.second.SerializeToString(&scratch->b);
151     if (scratch->a != scratch->b) return false;
152   }
153   return true;
154 }
155 
// DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) defines two
// GetNodeAttr overloads for TYPE:
//  * GetNodeAttr(attrs, attr_name, TYPE* value): looks up `attr_name` in
//    `attrs` (NotFound if absent), checks that it has type ATTR_TYPE, binds
//    `v` to attr_value->FIELD(), runs the injected code, then stores CAST in
//    *value.
//  * GetNodeAttr(attrs, attr_name, std::vector<TYPE>* value): checks for type
//    "list(ATTR_TYPE)". For each element `v` of list().FIELD() it runs the
//    injected code and calls value->APPEND_OP(CAST). Elements are appended;
//    *value is not cleared first.
//
// The ... is to allow the caller to inject some value validation code.  Use
// just ; if no additional validation code is needed.  The injected code may
// refer to `v` and `attr_name`, and may `return` an error Status.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     TYPE* value) {                                           \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));             \
    const auto& v = attr_value->FIELD();                                      \
    __VA_ARGS__;                                                              \
    *value = CAST;                                                            \
    return Status::OK();                                                      \
  }                                                                           \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     std::vector<TYPE>* value) {                              \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
    for (const auto& v : attr_value->list().FIELD()) {                        \
      __VA_ARGS__;                                                            \
      value->APPEND_OP(CAST);                                                 \
    }                                                                         \
    return Status::OK();                                                      \
  }
180 
// DEFINE_GET_ATTR_SIMPLE(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) defines
// GetNodeAttrSimple overloads for TYPE* and std::vector<TYPE>*. They follow
// the same steps as the GetNodeAttr overloads from DEFINE_GET_ATTR, but
// return bool instead of Status:
//  * false when the attr is absent or AttrValueHasType rejects its type;
//  * true after the value has been stored (scalar) or appended (list).
// The lookup uses AttrSlice::Find(StringPiece), so no NotFound Status is
// built for a missing attr. The injected code runs inside a function that
// returns bool, so it must not return a Status. List elements are appended
// without clearing *value first.
#define DEFINE_GET_ATTR_SIMPLE(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,      \
                         TYPE* value) {                                      \
    const AttrValue* attr_value = attrs.Find(attr_name);                     \
    if (attr_value == nullptr) {                                             \
      return false;                                                          \
    }                                                                        \
    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                     \
    if (!s.ok()) {                                                           \
      return false;                                                          \
    }                                                                        \
    const auto& v = attr_value->FIELD();                                     \
    __VA_ARGS__;                                                             \
    *value = CAST;                                                           \
    return true;                                                             \
  }                                                                          \
  bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,      \
                         std::vector<TYPE>* value) {                         \
    const AttrValue* attr_value = attrs.Find(attr_name);                     \
    if (attr_value == nullptr) {                                             \
      return false;                                                          \
    }                                                                        \
    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");         \
    if (!s.ok()) {                                                           \
      return false;                                                          \
    }                                                                        \
    for (const auto& v : attr_value->list().FIELD()) {                       \
      __VA_ARGS__;                                                           \
      value->APPEND_OP(CAST);                                                \
    }                                                                        \
    return true;                                                             \
  }
213 
214 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
215 DEFINE_GET_ATTR_SIMPLE(string, s, "string", emplace_back, v, ;)
216 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
217 DEFINE_GET_ATTR(int32, i, "int", emplace_back, static_cast<int32>(v),
218                 if (static_cast<int64>(static_cast<int32>(v)) != v) {
219                   return errors::InvalidArgument("Attr ", attr_name,
220                                                  " has value ", v,
221                                                  " out of range for an int32");
222                 })
223 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
224 // std::vector<bool> specialization does not have emplace_back until
225 // c++14, so we have to use push_back (see
226 // http://en.cppreference.com/w/cpp/container/vector/emplace_back)
227 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
228 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
229                 ;)
230 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
231 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
232                 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
233 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
234                 PartialTensorShape(v),
235                 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
236 DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t;
237                 if (!t.FromProto(v)) {
238                   return errors::InvalidArgument(
239                       "Attr ", attr_name, " has value ",
240                       ProtoShortDebugString(v),
241                       " that can't be converted to a Tensor");
242                 })
243 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
244 #undef DEFINE_GET_ATTR
245 
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)246 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
247   return node_def.attr().find(attr_name.ToString()) != node_def.attr().end();
248 }
249 
250 static const string& kEmptyString = *new string();
251 
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)252 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
253   const AttrValue* attr_value = attrs.Find(attr_name);
254   if (attr_value == nullptr) {
255     return kEmptyString;
256   }
257   Status s = AttrValueHasType(*attr_value, "string");
258   if (!s.ok()) {
259     return kEmptyString;
260   }
261   return attr_value->s();
262 }
263 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)264 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
265                    DataTypeVector* value) {
266   const AttrValue* attr_value;
267   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
268   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
269   for (const auto& v : attr_value->list().type()) {
270     value->push_back(static_cast<DataType>(v));
271   }
272   return Status::OK();
273 }
274 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)275 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
276                    const TensorProto** value) {
277   const AttrValue* attr_value;
278   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
279   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
280   *value = &attr_value->tensor();
281   return Status::OK();
282 }
283 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)284 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
285                    const NameAttrList** value) {
286   const AttrValue* attr_value;
287   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
288   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
289   *value = &attr_value->func();
290   return Status::OK();
291 }
292 
293 namespace {  // Helper for InOutTypesForNode().
294 
AddArgToSig(const NodeDef & node_def,const OpDef::ArgDef & arg_def,DataTypeVector * sig)295 Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
296                    DataTypeVector* sig) {
297   const int original_size = sig->size();
298   if (!arg_def.number_attr().empty()) {
299     // Same type repeated "repeats" times.
300     int32 repeats = -1;
301     TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.number_attr(), &repeats));
302     if (repeats < 0) {
303       return errors::InvalidArgument("Value for number_attr() ", repeats,
304                                      " < 0");
305     }
306 
307     if (!arg_def.type_attr().empty()) {
308       DataType dtype;
309       TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.type_attr(), &dtype));
310       for (int i = 0; i < repeats; ++i) {
311         sig->push_back(dtype);
312       }
313     } else if (arg_def.type() != DT_INVALID) {
314       for (int i = 0; i < repeats; ++i) {
315         sig->push_back(arg_def.type());
316       }
317     } else {
318       return errors::InvalidArgument("Missing type or type_attr field in ",
319                                      ProtoShortDebugString(arg_def));
320     }
321   } else if (!arg_def.type_attr().empty()) {
322     const AttrValue* attr_value;
323     TF_RETURN_IF_ERROR(
324         AttrSlice(node_def).Find(arg_def.type_attr(), &attr_value));
325     sig->push_back(attr_value->type());
326   } else if (!arg_def.type_list_attr().empty()) {
327     const AttrValue* attr_value;
328     TF_RETURN_IF_ERROR(
329         AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
330     for (int dtype : attr_value->list().type()) {
331       sig->push_back(static_cast<DataType>(dtype));
332     }
333   } else if (arg_def.type() != DT_INVALID) {
334     sig->push_back(arg_def.type());
335   } else {
336     return errors::InvalidArgument("No type fields in ",
337                                    ProtoShortDebugString(arg_def));
338   }
339   if (arg_def.is_ref()) {
340     // For all types that were added by this function call, make them refs.
341     for (size_t i = original_size; i < sig->size(); ++i) {
342       (*sig)[i] = MakeRefType((*sig)[i]);
343     }
344   }
345   return Status::OK();
346 }
347 
348 }  // namespace
349 
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)350 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
351                         int input_port, DataType* input_type) {
352   DataTypeVector input_types;
353   for (const auto& arg : op_def.input_arg()) {
354     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
355     if (input_types.size() > input_port) {
356       const DataType dtype = input_types[input_port];
357       *input_type = dtype;
358       return Status::OK();
359     }
360   }
361   return errors::InvalidArgument("Input ", input_port, " not found for node ",
362                                  node_def.name());
363 }
364 
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)365 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
366                          int output_port, DataType* output_type) {
367   DataTypeVector output_types;
368   for (const auto& arg : op_def.output_arg()) {
369     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
370     if (output_types.size() > output_port) {
371       const DataType dtype = output_types[output_port];
372       *output_type = dtype;
373       return Status::OK();
374     }
375   }
376   return errors::InvalidArgument("Output ", output_port, " not found for node ",
377                                  node_def.name());
378 }
379 
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)380 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
381                          DataTypeVector* inputs, DataTypeVector* outputs) {
382   for (const auto& arg : op_def.input_arg()) {
383     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
384   }
385   for (const auto& arg : op_def.output_arg()) {
386     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
387   }
388   return Status::OK();
389 }
390 
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)391 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
392   if (node_def.op() != op_def.name()) {
393     return errors::InvalidArgument("NodeDef op '", node_def.op(),
394                                    "' does not match ", SummarizeOpDef(op_def),
395                                    "; NodeDef: ", SummarizeNodeDef(node_def));
396   }
397 
398   bool seen_control = false;
399   size_t num_inputs = 0;
400   // TODO(josh11b): Unify the input field validation.
401   for (const string& input : node_def.input()) {
402     if (StringPiece(input).starts_with("^")) {
403       seen_control = true;
404       if (input.find(':') != string::npos) {
405         return errors::InvalidArgument(
406             "Control input '", input,
407             "' must not have ':' in NodeDef: ", SummarizeNodeDef(node_def));
408       }
409     } else if (seen_control) {
410       return errors::InvalidArgument(
411           "Non-control input '", input,
412           "' after control input in NodeDef: ", SummarizeNodeDef(node_def));
413     } else {
414       ++num_inputs;
415     }
416   }
417 
418   std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
419   for (const auto& attr : op_def.attr()) {
420     if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
421       return errors::InvalidArgument("OpDef has duplicate attr name '",
422                                      attr.name(),
423                                      "': ", SummarizeOpDef(op_def));
424     }
425   }
426   for (const auto& attr : node_def.attr()) {
427     // Allow internal optional attributes with names starting with "_".
428     if (StringPiece(attr.first).starts_with("_")) {
429       continue;
430     }
431     auto iter = op_attrs.find(attr.first);
432     if (iter == op_attrs.end()) {
433       // A common cause of this error is that TensorFlow has made a
434       // backwards-compatible change to the NodeDef (e.g., adding a
435       // new attr with a default value), but the binary consuming the
436       // NodeDef does not know about the new attribute; the solution
437       // in these cases is to ensure that the binary consuming the
438       // NodeDef is built with a version of TensorFlow no earlier than
439       // the binary producing it.
440       return errors::InvalidArgument(
441           "NodeDef mentions attr '", attr.first, "' not in ",
442           SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def),
443           ". (Check whether your GraphDef-interpreting binary is up to date "
444           "with your GraphDef-generating binary.).");
445     }
446     TF_RETURN_WITH_CONTEXT_IF_ERROR(
447         ValidateAttrValue(attr.second, *iter->second),
448         "; NodeDef: ", SummarizeNodeDef(node_def), "; ",
449         SummarizeOpDef(op_def));
450     // Keep track of which attr names have (not) been found in the NodeDef.
451     op_attrs.erase(iter);
452   }
453 
454   // Were all attrs in the OpDef found in the NodeDef?
455   if (!op_attrs.empty()) {
456     string attrs;
457     for (const auto& attr_pair : op_attrs) {
458       if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
459       strings::StrAppend(&attrs, attr_pair.first);
460     }
461     return errors::InvalidArgument("NodeDef missing attr",
462                                    op_attrs.size() == 1 ? " '" : "s '", attrs,
463                                    "' from ", SummarizeOpDef(op_def),
464                                    "; NodeDef: ", SummarizeNodeDef(node_def));
465   }
466 
467   // Validate the number of inputs.
468   DataTypeVector inputs, outputs;
469   TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
470 
471   if (num_inputs != inputs.size()) {
472     return errors::InvalidArgument(
473         "NodeDef expected inputs '", DataTypeVectorString(inputs),
474         "' do not match ", num_inputs, " inputs specified; ",
475         SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def));
476   }
477 
478   return Status::OK();
479 }
480 
481 namespace {  // Helpers for NameRangesForNode()
482 
ComputeArgRange(const NodeDef & node_def,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)483 Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
484                        const OpDef& op_def, int* num) {
485   if (!arg_def.number_attr().empty()) {
486     // Same type repeated "num" times.
487     return GetNodeAttr(node_def, arg_def.number_attr(), num);
488   } else if (!arg_def.type_list_attr().empty()) {
489     const AttrValue* attr_value;
490     TF_RETURN_IF_ERROR(
491         AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
492     *num = attr_value->list().type_size();
493   } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
494     *num = 1;
495   } else {
496     return errors::InvalidArgument(
497         "Argument '", arg_def.name(),
498         "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
499   }
500   return Status::OK();
501 }
502 
NameRangesHelper(const NodeDef & node_def,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)503 Status NameRangesHelper(const NodeDef& node_def,
504                         const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
505                         const OpDef& op_def, NameRangeMap* result) {
506   int start = 0;
507   int num;
508   for (const auto& arg : args) {
509     TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num));
510     (*result)[arg.name()] = std::make_pair(start, start + num);
511     start += num;
512   }
513   return Status::OK();
514 }
515 
516 }  // namespace
517 
NameRangesForNode(const NodeDef & node_def,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)518 Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
519                          NameRangeMap* inputs, NameRangeMap* outputs) {
520   if (inputs != nullptr) {
521     TF_RETURN_IF_ERROR(
522         NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs));
523   }
524   if (outputs != nullptr) {
525     return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs);
526   }
527   return Status::OK();
528 }
529 
NameRangesForNode(const Node & node,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)530 Status NameRangesForNode(const Node& node, const OpDef& op_def,
531                          NameRangeMap* inputs, NameRangeMap* outputs) {
532   return NameRangesForNode(node.def(), op_def, inputs, outputs);
533 }
534 
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)535 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
536   for (const auto& attr_def : op_def.attr()) {
537     AttrSlice attrs(*node_def);
538     if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
539       AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
540     }
541   }
542 }
543 
544 namespace {
545 
546 using ::tensorflow::strings::Scanner;
547 
IsValidOpName(StringPiece sp)548 bool IsValidOpName(StringPiece sp) {
549   return Scanner(sp)
550       .One(Scanner::LETTER_DIGIT_DOT)
551       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
552       .Eos()
553       .GetResult();
554 }
555 
IsValidDataInputName(StringPiece sp)556 bool IsValidDataInputName(StringPiece sp) {
557   // Data inputs are op_name, op_name:0, or op_name:12345.
558   Scanner scan(sp);
559   scan.One(Scanner::LETTER_DIGIT_DOT)
560       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
561   if (scan.Peek() == ':') {
562     scan.OneLiteral(":");
563     if (scan.Peek() == '0') {
564       scan.OneLiteral("0");  // :0
565     } else {
566       scan.Many(Scanner::DIGIT);  // :[1-9][0-9]*
567     }
568   }
569   scan.Eos();
570 
571   return scan.GetResult();
572 }
573 
IsValidControlInputName(StringPiece sp)574 bool IsValidControlInputName(StringPiece sp) {
575   return Scanner(sp)
576       .OneLiteral("^")
577       .One(Scanner::LETTER_DIGIT_DOT)
578       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
579       .Eos()
580       .GetResult();
581 }
582 
583 }  // namespace
584 
ValidateOpInput(const string & input_name,bool * is_control_input)585 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
586   *is_control_input = false;
587   if (IsValidDataInputName(input_name)) {
588     return Status::OK();
589   } else if (IsValidControlInputName(input_name)) {
590     *is_control_input = true;
591     return Status::OK();
592   } else {
593     return errors::InvalidArgument("Illegal op input name '", input_name, "'");
594   }
595 }
596 
ValidateOpName(const string & op_name)597 Status ValidateOpName(const string& op_name) {
598   if (IsValidOpName(op_name)) {
599     return Status::OK();
600   } else {
601     return errors::InvalidArgument("Illegal op name '", op_name, "'");
602   }
603 }
604 
ValidateExternalNodeDefSyntax(const NodeDef & node_def)605 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
606   Status s = ValidateOpName(node_def.name());
607   if (!s.ok()) {
608     return AttachDef(s, node_def);
609   }
610   bool in_control_inputs = false;
611   for (const string& input_name : node_def.input()) {
612     bool is_control_input;
613     s = ValidateOpInput(input_name, &is_control_input);
614     if (!s.ok()) {
615       return AttachDef(s, node_def);
616     }
617 
618     if (in_control_inputs && !is_control_input) {
619       return AttachDef(errors::InvalidArgument(
620                            "All control inputs must follow all data inputs"),
621                        node_def);
622     }
623     in_control_inputs = is_control_input;
624   }
625   return Status::OK();
626 }
627 
// Returns a copy of `status`. " [[Node: <SummarizeNodeDef(node_def)>]]" is
// added to its message via errors::AppendToMessage.
Status AttachDef(const Status& status, const NodeDef& node_def) {
  const string node_summary =
      strings::StrCat(" [[Node: ", SummarizeNodeDef(node_def), "]]");
  Status annotated(status);
  errors::AppendToMessage(&annotated, node_summary);
  return annotated;
}
634 
// Graph-Node overload: attaches a summary of the Node's NodeDef.
Status AttachDef(const Status& status, const Node& node) {
  const NodeDef& def = node.def();
  return AttachDef(status, def);
}
638 
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)639 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
640   node_def->mutable_attr()->insert(
641       AttrValueMap::value_type(name.ToString(), value));
642 }
643 
// ADD_NODE_ATTR(T) defines AddNodeAttr(name, T value, node_def). The overload
// converts `value` to an AttrValue with SetAttrValue and forwards to
// AddNodeAttr(StringPiece, const AttrValue&, NodeDef*).
#define ADD_NODE_ATTR(T)                                           \
  void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
    AttrValue attr_value;                                          \
    SetAttrValue(value, &attr_value);                              \
    AddNodeAttr(name, attr_value, node_def);                       \
  }
ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char*)
ADD_NODE_ATTR(int32)
ADD_NODE_ATTR(int64)
ADD_NODE_ATTR(float)
ADD_NODE_ATTR(double)
ADD_NODE_ATTR(bool)
ADD_NODE_ATTR(DataType)
ADD_NODE_ATTR(const PartialTensorShape&)
ADD_NODE_ATTR(const Tensor&)
ADD_NODE_ATTR(const TensorProto&)
ADD_NODE_ATTR(const NameAttrList&)
ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
ADD_NODE_ATTR(gtl::ArraySlice<string>)
ADD_NODE_ATTR(gtl::ArraySlice<int32>)
ADD_NODE_ATTR(gtl::ArraySlice<int64>)
ADD_NODE_ATTR(gtl::ArraySlice<float>)
ADD_NODE_ATTR(gtl::ArraySlice<bool>)
// Separate overload for std::vector<bool>. Presumably this is needed because
// its packed-bit storage cannot be viewed as a gtl::ArraySlice<bool> — confirm.
ADD_NODE_ATTR(const std::vector<bool>&)
ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
#undef ADD_NODE_ATTR
677 
678 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
679   map->insert(AttrValueMap::value_type(name.ToString(), value));
680 }
681 
// ADD_ATTR(T) defines AddAttr(name, T value, map). The overload converts
// `value` with SetAttrValue and forwards to
// AddAttr(StringPiece, const AttrValue&, AttrValueMap*). Only bool is
// instantiated here.
#define ADD_ATTR(T)                                            \
  void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
    AttrValue attr_value;                                      \
    SetAttrValue(value, &attr_value);                          \
    AddAttr(name, attr_value, map);                            \
  }
ADD_ATTR(bool)
#undef ADD_ATTR
690 
691 }  // namespace tensorflow
692