1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_util.h"
17
18 #include <algorithm>
19 #include <unordered_map>
20 #include <vector>
21
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/graph.pb_text.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_def.pb_text.h"
27 #include "tensorflow/core/framework/op_def_util.h"
28 #include "tensorflow/core/framework/tensor.pb_text.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/graph/graph.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/gtl/map_util.h"
33 #include "tensorflow/core/lib/strings/scanner.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/protobuf.h"
36
37 namespace tensorflow {
38
// Attr name "_class".
// NOTE(review): judging by the identifier this is the node attr that carries
// colocation information; confirm against the declaration in the header.
const char* const kColocationAttrName = "_class";
// String prefix "loc:@".
// NOTE(review): presumably prepended to colocation group names stored under
// the attr above; confirm against callers.
const char* const kColocationGroupPrefix = "loc:@";
41
AttrSlice()42 AttrSlice::AttrSlice() : ndef_(nullptr) {
43 static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
44 attrs_ = kEmptyAttrValueMap;
45 }
46
AttrSlice(const NodeDef & node_def)47 AttrSlice::AttrSlice(const NodeDef& node_def)
48 : ndef_(&node_def), attrs_(&ndef_->attr()) {}
49
AttrSlice(const AttrValueMap * a)50 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
51
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)52 static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
53 string ret;
54
55 // We sort the attrs so the output is deterministic.
56 std::vector<string> attr_names;
57 attr_names.reserve(attrs.size());
58 for (const auto& attr : attrs) {
59 attr_names.push_back(attr.first);
60 }
61 std::sort(attr_names.begin(), attr_names.end());
62 bool first = true;
63 for (const string& attr_name : attr_names) {
64 if (!first) strings::StrAppend(&ret, ", ");
65 first = false;
66 strings::StrAppend(&ret, attr_name, "=",
67 SummarizeAttrValue(*attrs.Find(attr_name)));
68 }
69
70 // Consider the device to be a final attr with name "_device".
71 if (!device.empty()) {
72 if (!first) strings::StrAppend(&ret, ", ");
73 first = false;
74 strings::StrAppend(&ret, "_device=\"", device, "\"");
75 }
76 return ret;
77 }
78
SummarizeNode() const79 string AttrSlice::SummarizeNode() const {
80 return ndef_ ? SummarizeNodeDef(*ndef_)
81 : strings::StrCat(
82 "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
83 }
84
SummarizeNode(const Node & node)85 string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); }
86
SummarizeNodeDef(const NodeDef & node_def)87 string SummarizeNodeDef(const NodeDef& node_def) {
88 string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
89 strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
90 strings::StrAppend(&ret, "](");
91
92 // Output inputs, including control inputs, verbatim.
93 bool first = true;
94 for (const string& input : node_def.input()) {
95 if (!first) strings::StrAppend(&ret, ", ");
96 first = false;
97 strings::StrAppend(&ret, input);
98 }
99 strings::StrAppend(&ret, ")");
100 return ret;
101 }
102
Find(StringPiece attr_name) const103 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
104 // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
105 // requires that the keys used for lookups have type 'const string&'. Because
106 // this method takes a StringPiece, it is necessary to allocate a temporary
107 // string, copy attr_name to it, and then use that temporary string for the
108 // lookup. This causes an excessive number of short-lived allocations, and for
109 // large graphs, this can be a significant cost.
110 //
111 // Because most nodes have a small number of attributes, a simple linear scan
112 // is generally more efficient than a hashed lookup. If google::protobuf::Map
113 // changes so that it supports efficient lookups using StringPiece instead of
114 // const string&, then this code could be changed to use attrs_->find() again.
115
116 for (const auto& attr : *attrs_) {
117 if (attr.first == attr_name) {
118 return &attr.second;
119 }
120 }
121 return nullptr;
122 }
123
Find(StringPiece attr_name,const AttrValue ** attr_value) const124 Status AttrSlice::Find(StringPiece attr_name,
125 const AttrValue** attr_value) const {
126 *attr_value = Find(attr_name);
127 if (*attr_value != nullptr) {
128 return Status::OK();
129 }
130 Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
131 // Skip AttachDef for internal attrs since it is a little bit
132 // expensive and it is common for them to correctly not be included
133 // in a NodeDef.
134 if (!attr_name.starts_with("_") && ndef_ != nullptr) {
135 s = AttachDef(s, *ndef_);
136 }
137 return s;
138 }
139
EqualAttrs(AttrSlice other,Scratch * scratch) const140 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
141 if (size() != other.size()) return false;
142
143 for (const auto& attr : *other.attrs_) {
144 auto iter = attrs_->find(attr.first);
145 if (iter == attrs_->end()) return false;
146 // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
147 // TensorProto is a nonunique representation of Tensor. This bug will go
148 // away once AttrSlice switches over to NodeInfo.
149 iter->second.SerializeToString(&scratch->a);
150 attr.second.SerializeToString(&scratch->b);
151 if (scratch->a != scratch->b) return false;
152 }
153 return true;
154 }
155
// Defines two GetNodeAttr() overloads for TYPE: one that reads a scalar attr
// into `*value`, and one that reads a list attr and appends each element to
// `*value` (the vector is not cleared first).
//
//   TYPE       C++ type written to the output.
//   FIELD      AttrValue (and AttrValue::ListValue) accessor to read from.
//   ATTR_TYPE  type string checked with AttrValueHasType(); the list overload
//              checks "list(" ATTR_TYPE ")".
//   APPEND_OP  vector member function used to append in the list overload.
//   CAST       expression yielding the output value; it may refer to `v`, the
//              value read from FIELD.
//   ...        lets the caller inject value validation code that runs before
//              CAST is evaluated. Use just ; if no additional validation code
//              is needed.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)        \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,          \
                     TYPE* value) {                                          \
    const AttrValue* found;                                                  \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &found));                       \
    TF_RETURN_IF_ERROR(AttrValueHasType(*found, ATTR_TYPE));                 \
    const auto& v = found->FIELD();                                          \
    __VA_ARGS__;                                                             \
    *value = CAST;                                                           \
    return Status::OK();                                                     \
  }                                                                          \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,          \
                     std::vector<TYPE>* value) {                             \
    const AttrValue* found;                                                  \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &found));                       \
    TF_RETURN_IF_ERROR(AttrValueHasType(*found, "list(" ATTR_TYPE ")"));     \
    for (const auto& v : found->list().FIELD()) {                            \
      __VA_ARGS__;                                                           \
      value->APPEND_OP(CAST);                                                \
    }                                                                        \
    return Status::OK();                                                     \
  }
180
// Defines two GetNodeAttrSimple() overloads for TYPE (scalar and list). They
// take the same macro arguments as DEFINE_GET_ATTR, but report failure by
// returning false instead of a Status: false when the attr is missing or has
// the wrong type, true once the value has been written to / appended to
// `*value`.
#define DEFINE_GET_ATTR_SIMPLE(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,      \
                         TYPE* value) {                                      \
    const AttrValue* found = attrs.Find(attr_name);                          \
    if (found == nullptr || !AttrValueHasType(*found, ATTR_TYPE).ok()) {     \
      return false;                                                          \
    }                                                                        \
    const auto& v = found->FIELD();                                          \
    __VA_ARGS__;                                                             \
    *value = CAST;                                                           \
    return true;                                                             \
  }                                                                          \
  bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name,      \
                         std::vector<TYPE>* value) {                         \
    const AttrValue* found = attrs.Find(attr_name);                          \
    if (found == nullptr ||                                                  \
        !AttrValueHasType(*found, "list(" ATTR_TYPE ")").ok()) {             \
      return false;                                                          \
    }                                                                        \
    for (const auto& v : found->list().FIELD()) {                            \
      __VA_ARGS__;                                                           \
      value->APPEND_OP(CAST);                                                \
    }                                                                        \
    return true;                                                             \
  }
213
214 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
215 DEFINE_GET_ATTR_SIMPLE(string, s, "string", emplace_back, v, ;)
216 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
217 DEFINE_GET_ATTR(int32, i, "int", emplace_back, static_cast<int32>(v),
218 if (static_cast<int64>(static_cast<int32>(v)) != v) {
219 return errors::InvalidArgument("Attr ", attr_name,
220 " has value ", v,
221 " out of range for an int32");
222 })
223 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
224 // std::vector<bool> specialization does not have emplace_back until
225 // c++14, so we have to use push_back (see
226 // http://en.cppreference.com/w/cpp/container/vector/emplace_back)
227 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
228 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
229 ;)
230 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
231 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
232 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
233 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
234 PartialTensorShape(v),
235 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
236 DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t;
237 if (!t.FromProto(v)) {
238 return errors::InvalidArgument(
239 "Attr ", attr_name, " has value ",
240 ProtoShortDebugString(v),
241 " that can't be converted to a Tensor");
242 })
243 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
244 #undef DEFINE_GET_ATTR
245
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)246 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
247 return node_def.attr().find(attr_name.ToString()) != node_def.attr().end();
248 }
249
250 static const string& kEmptyString = *new string();
251
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)252 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
253 const AttrValue* attr_value = attrs.Find(attr_name);
254 if (attr_value == nullptr) {
255 return kEmptyString;
256 }
257 Status s = AttrValueHasType(*attr_value, "string");
258 if (!s.ok()) {
259 return kEmptyString;
260 }
261 return attr_value->s();
262 }
263
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)264 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
265 DataTypeVector* value) {
266 const AttrValue* attr_value;
267 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
268 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
269 for (const auto& v : attr_value->list().type()) {
270 value->push_back(static_cast<DataType>(v));
271 }
272 return Status::OK();
273 }
274
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)275 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
276 const TensorProto** value) {
277 const AttrValue* attr_value;
278 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
279 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
280 *value = &attr_value->tensor();
281 return Status::OK();
282 }
283
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)284 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
285 const NameAttrList** value) {
286 const AttrValue* attr_value;
287 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
288 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
289 *value = &attr_value->func();
290 return Status::OK();
291 }
292
293 namespace { // Helper for InOutTypesForNode().
294
AddArgToSig(const NodeDef & node_def,const OpDef::ArgDef & arg_def,DataTypeVector * sig)295 Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
296 DataTypeVector* sig) {
297 const int original_size = sig->size();
298 if (!arg_def.number_attr().empty()) {
299 // Same type repeated "repeats" times.
300 int32 repeats = -1;
301 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.number_attr(), &repeats));
302 if (repeats < 0) {
303 return errors::InvalidArgument("Value for number_attr() ", repeats,
304 " < 0");
305 }
306
307 if (!arg_def.type_attr().empty()) {
308 DataType dtype;
309 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.type_attr(), &dtype));
310 for (int i = 0; i < repeats; ++i) {
311 sig->push_back(dtype);
312 }
313 } else if (arg_def.type() != DT_INVALID) {
314 for (int i = 0; i < repeats; ++i) {
315 sig->push_back(arg_def.type());
316 }
317 } else {
318 return errors::InvalidArgument("Missing type or type_attr field in ",
319 ProtoShortDebugString(arg_def));
320 }
321 } else if (!arg_def.type_attr().empty()) {
322 const AttrValue* attr_value;
323 TF_RETURN_IF_ERROR(
324 AttrSlice(node_def).Find(arg_def.type_attr(), &attr_value));
325 sig->push_back(attr_value->type());
326 } else if (!arg_def.type_list_attr().empty()) {
327 const AttrValue* attr_value;
328 TF_RETURN_IF_ERROR(
329 AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
330 for (int dtype : attr_value->list().type()) {
331 sig->push_back(static_cast<DataType>(dtype));
332 }
333 } else if (arg_def.type() != DT_INVALID) {
334 sig->push_back(arg_def.type());
335 } else {
336 return errors::InvalidArgument("No type fields in ",
337 ProtoShortDebugString(arg_def));
338 }
339 if (arg_def.is_ref()) {
340 // For all types that were added by this function call, make them refs.
341 for (size_t i = original_size; i < sig->size(); ++i) {
342 (*sig)[i] = MakeRefType((*sig)[i]);
343 }
344 }
345 return Status::OK();
346 }
347
348 } // namespace
349
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)350 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
351 int input_port, DataType* input_type) {
352 DataTypeVector input_types;
353 for (const auto& arg : op_def.input_arg()) {
354 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
355 if (input_types.size() > input_port) {
356 const DataType dtype = input_types[input_port];
357 *input_type = dtype;
358 return Status::OK();
359 }
360 }
361 return errors::InvalidArgument("Input ", input_port, " not found for node ",
362 node_def.name());
363 }
364
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)365 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
366 int output_port, DataType* output_type) {
367 DataTypeVector output_types;
368 for (const auto& arg : op_def.output_arg()) {
369 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
370 if (output_types.size() > output_port) {
371 const DataType dtype = output_types[output_port];
372 *output_type = dtype;
373 return Status::OK();
374 }
375 }
376 return errors::InvalidArgument("Output ", output_port, " not found for node ",
377 node_def.name());
378 }
379
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)380 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
381 DataTypeVector* inputs, DataTypeVector* outputs) {
382 for (const auto& arg : op_def.input_arg()) {
383 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
384 }
385 for (const auto& arg : op_def.output_arg()) {
386 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
387 }
388 return Status::OK();
389 }
390
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)391 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
392 if (node_def.op() != op_def.name()) {
393 return errors::InvalidArgument("NodeDef op '", node_def.op(),
394 "' does not match ", SummarizeOpDef(op_def),
395 "; NodeDef: ", SummarizeNodeDef(node_def));
396 }
397
398 bool seen_control = false;
399 size_t num_inputs = 0;
400 // TODO(josh11b): Unify the input field validation.
401 for (const string& input : node_def.input()) {
402 if (StringPiece(input).starts_with("^")) {
403 seen_control = true;
404 if (input.find(':') != string::npos) {
405 return errors::InvalidArgument(
406 "Control input '", input,
407 "' must not have ':' in NodeDef: ", SummarizeNodeDef(node_def));
408 }
409 } else if (seen_control) {
410 return errors::InvalidArgument(
411 "Non-control input '", input,
412 "' after control input in NodeDef: ", SummarizeNodeDef(node_def));
413 } else {
414 ++num_inputs;
415 }
416 }
417
418 std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
419 for (const auto& attr : op_def.attr()) {
420 if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
421 return errors::InvalidArgument("OpDef has duplicate attr name '",
422 attr.name(),
423 "': ", SummarizeOpDef(op_def));
424 }
425 }
426 for (const auto& attr : node_def.attr()) {
427 // Allow internal optional attributes with names starting with "_".
428 if (StringPiece(attr.first).starts_with("_")) {
429 continue;
430 }
431 auto iter = op_attrs.find(attr.first);
432 if (iter == op_attrs.end()) {
433 // A common cause of this error is that TensorFlow has made a
434 // backwards-compatible change to the NodeDef (e.g., adding a
435 // new attr with a default value), but the binary consuming the
436 // NodeDef does not know about the new attribute; the solution
437 // in these cases is to ensure that the binary consuming the
438 // NodeDef is built with a version of TensorFlow no earlier than
439 // the binary producing it.
440 return errors::InvalidArgument(
441 "NodeDef mentions attr '", attr.first, "' not in ",
442 SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def),
443 ". (Check whether your GraphDef-interpreting binary is up to date "
444 "with your GraphDef-generating binary.).");
445 }
446 TF_RETURN_WITH_CONTEXT_IF_ERROR(
447 ValidateAttrValue(attr.second, *iter->second),
448 "; NodeDef: ", SummarizeNodeDef(node_def), "; ",
449 SummarizeOpDef(op_def));
450 // Keep track of which attr names have (not) been found in the NodeDef.
451 op_attrs.erase(iter);
452 }
453
454 // Were all attrs in the OpDef found in the NodeDef?
455 if (!op_attrs.empty()) {
456 string attrs;
457 for (const auto& attr_pair : op_attrs) {
458 if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
459 strings::StrAppend(&attrs, attr_pair.first);
460 }
461 return errors::InvalidArgument("NodeDef missing attr",
462 op_attrs.size() == 1 ? " '" : "s '", attrs,
463 "' from ", SummarizeOpDef(op_def),
464 "; NodeDef: ", SummarizeNodeDef(node_def));
465 }
466
467 // Validate the number of inputs.
468 DataTypeVector inputs, outputs;
469 TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
470
471 if (num_inputs != inputs.size()) {
472 return errors::InvalidArgument(
473 "NodeDef expected inputs '", DataTypeVectorString(inputs),
474 "' do not match ", num_inputs, " inputs specified; ",
475 SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def));
476 }
477
478 return Status::OK();
479 }
480
481 namespace { // Helpers for NameRangesForNode()
482
ComputeArgRange(const NodeDef & node_def,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)483 Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
484 const OpDef& op_def, int* num) {
485 if (!arg_def.number_attr().empty()) {
486 // Same type repeated "num" times.
487 return GetNodeAttr(node_def, arg_def.number_attr(), num);
488 } else if (!arg_def.type_list_attr().empty()) {
489 const AttrValue* attr_value;
490 TF_RETURN_IF_ERROR(
491 AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
492 *num = attr_value->list().type_size();
493 } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
494 *num = 1;
495 } else {
496 return errors::InvalidArgument(
497 "Argument '", arg_def.name(),
498 "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
499 }
500 return Status::OK();
501 }
502
NameRangesHelper(const NodeDef & node_def,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)503 Status NameRangesHelper(const NodeDef& node_def,
504 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
505 const OpDef& op_def, NameRangeMap* result) {
506 int start = 0;
507 int num;
508 for (const auto& arg : args) {
509 TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num));
510 (*result)[arg.name()] = std::make_pair(start, start + num);
511 start += num;
512 }
513 return Status::OK();
514 }
515
516 } // namespace
517
NameRangesForNode(const NodeDef & node_def,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)518 Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
519 NameRangeMap* inputs, NameRangeMap* outputs) {
520 if (inputs != nullptr) {
521 TF_RETURN_IF_ERROR(
522 NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs));
523 }
524 if (outputs != nullptr) {
525 return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs);
526 }
527 return Status::OK();
528 }
529
NameRangesForNode(const Node & node,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)530 Status NameRangesForNode(const Node& node, const OpDef& op_def,
531 NameRangeMap* inputs, NameRangeMap* outputs) {
532 return NameRangesForNode(node.def(), op_def, inputs, outputs);
533 }
534
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)535 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
536 for (const auto& attr_def : op_def.attr()) {
537 AttrSlice attrs(*node_def);
538 if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
539 AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
540 }
541 }
542 }
543
544 namespace {
545
546 using ::tensorflow::strings::Scanner;
547
IsValidOpName(StringPiece sp)548 bool IsValidOpName(StringPiece sp) {
549 return Scanner(sp)
550 .One(Scanner::LETTER_DIGIT_DOT)
551 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
552 .Eos()
553 .GetResult();
554 }
555
IsValidDataInputName(StringPiece sp)556 bool IsValidDataInputName(StringPiece sp) {
557 // Data inputs are op_name, op_name:0, or op_name:12345.
558 Scanner scan(sp);
559 scan.One(Scanner::LETTER_DIGIT_DOT)
560 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
561 if (scan.Peek() == ':') {
562 scan.OneLiteral(":");
563 if (scan.Peek() == '0') {
564 scan.OneLiteral("0"); // :0
565 } else {
566 scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
567 }
568 }
569 scan.Eos();
570
571 return scan.GetResult();
572 }
573
IsValidControlInputName(StringPiece sp)574 bool IsValidControlInputName(StringPiece sp) {
575 return Scanner(sp)
576 .OneLiteral("^")
577 .One(Scanner::LETTER_DIGIT_DOT)
578 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
579 .Eos()
580 .GetResult();
581 }
582
583 } // namespace
584
ValidateOpInput(const string & input_name,bool * is_control_input)585 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
586 *is_control_input = false;
587 if (IsValidDataInputName(input_name)) {
588 return Status::OK();
589 } else if (IsValidControlInputName(input_name)) {
590 *is_control_input = true;
591 return Status::OK();
592 } else {
593 return errors::InvalidArgument("Illegal op input name '", input_name, "'");
594 }
595 }
596
ValidateOpName(const string & op_name)597 Status ValidateOpName(const string& op_name) {
598 if (IsValidOpName(op_name)) {
599 return Status::OK();
600 } else {
601 return errors::InvalidArgument("Illegal op name '", op_name, "'");
602 }
603 }
604
ValidateExternalNodeDefSyntax(const NodeDef & node_def)605 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
606 Status s = ValidateOpName(node_def.name());
607 if (!s.ok()) {
608 return AttachDef(s, node_def);
609 }
610 bool in_control_inputs = false;
611 for (const string& input_name : node_def.input()) {
612 bool is_control_input;
613 s = ValidateOpInput(input_name, &is_control_input);
614 if (!s.ok()) {
615 return AttachDef(s, node_def);
616 }
617
618 if (in_control_inputs && !is_control_input) {
619 return AttachDef(errors::InvalidArgument(
620 "All control inputs must follow all data inputs"),
621 node_def);
622 }
623 in_control_inputs = is_control_input;
624 }
625 return Status::OK();
626 }
627
// Returns a copy of `status` whose message has " [[Node: <summary>]]" appended,
// where <summary> is SummarizeNodeDef(node_def). `status` itself is not
// modified.
Status AttachDef(const Status& status, const NodeDef& node_def) {
  Status annotated = status;
  const string node_note =
      strings::StrCat(" [[Node: ", SummarizeNodeDef(node_def), "]]");
  errors::AppendToMessage(&annotated, node_note);
  return annotated;
}
634
// Convenience overload: forwards to the NodeDef version using `node.def()`.
Status AttachDef(const Status& status, const Node& node) {
  const auto& def = node.def();
  return AttachDef(status, def);
}
638
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)639 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
640 node_def->mutable_attr()->insert(
641 AttrValueMap::value_type(name.ToString(), value));
642 }
643
644 #define ADD_NODE_ATTR(T) \
645 void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
646 AttrValue attr_value; \
647 SetAttrValue(value, &attr_value); \
648 AddNodeAttr(name, attr_value, node_def); \
649 }
650 ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char *)651 ADD_NODE_ATTR(const char*)
652 ADD_NODE_ATTR(int32)
653 ADD_NODE_ATTR(int64)
654 ADD_NODE_ATTR(float)
655 ADD_NODE_ATTR(double)
656 ADD_NODE_ATTR(bool)
657 ADD_NODE_ATTR(DataType)
658 ADD_NODE_ATTR(const PartialTensorShape&)
659 ADD_NODE_ATTR(const Tensor&)
660 ADD_NODE_ATTR(const TensorProto&)
661 ADD_NODE_ATTR(const NameAttrList&)
662 ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
663 ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
664 ADD_NODE_ATTR(gtl::ArraySlice<string>)
665 ADD_NODE_ATTR(gtl::ArraySlice<int32>)
666 ADD_NODE_ATTR(gtl::ArraySlice<int64>)
667 ADD_NODE_ATTR(gtl::ArraySlice<float>)
668 ADD_NODE_ATTR(gtl::ArraySlice<bool>)
669 ADD_NODE_ATTR(const std::vector<bool>&)
670 ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
671 ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
672 ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
673 ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
674 ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
675 ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
676 #undef ADD_NODE_ATTR
677
678 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
679 map->insert(AttrValueMap::value_type(name.ToString(), value));
680 }
681
682 #define ADD_ATTR(T) \
683 void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
684 AttrValue attr_value; \
685 SetAttrValue(value, &attr_value); \
686 AddAttr(name, attr_value, map); \
687 }
688 ADD_ATTR(bool)
689 #undef ADD_ATTR
690
691 } // namespace tensorflow
692