1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_util.h"
17
18 #include "tensorflow/core/framework/attr_value.pb.h" // NOLINT
19 #include "tensorflow/core/framework/fake_input.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op_def_builder.h"
22 #include "tensorflow/core/framework/op_def_util.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/graph/node_builder.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/test.h"
30
31 namespace tensorflow {
32 namespace {
33
ToOpDef(const OpDefBuilder & builder)34 OpDef ToOpDef(const OpDefBuilder& builder) {
35 OpRegistrationData op_reg_data;
36 TF_EXPECT_OK(builder.Finalize(&op_reg_data));
37 return op_reg_data.op_def;
38 }
39
ToNodeDef(const string & text)40 NodeDef ToNodeDef(const string& text) {
41 NodeDef node_def;
42 EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
43 return node_def;
44 }
45
ToNodeDef(const NodeDefBuilder & builder)46 NodeDef ToNodeDef(const NodeDefBuilder& builder) {
47 NodeDef node_def;
48 TF_EXPECT_OK(builder.Finalize(&node_def));
49 return node_def;
50 }
51
ExpectSuccess(const NodeDef & good,const OpDef & op_def)52 void ExpectSuccess(const NodeDef& good, const OpDef& op_def) {
53 EXPECT_EQ(Status::OK(), ValidateNodeDef(good, op_def))
54 << "NodeDef: " << SummarizeNodeDef(good)
55 << "; OpDef: " << SummarizeOpDef(op_def);
56 }
57
ExpectFailure(const NodeDef & bad,const OpDef & op_def,const string & message)58 void ExpectFailure(const NodeDef& bad, const OpDef& op_def,
59 const string& message) {
60 Status status = ValidateNodeDef(bad, op_def);
61
62 EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad)
63 << "; OpDef: " << SummarizeOpDef(op_def);
64 if (status.ok()) return;
65
66 EXPECT_TRUE(errors::IsInvalidArgument(status))
67 << status << "; NodeDef: " << SummarizeNodeDef(bad)
68 << "; OpDef: " << SummarizeOpDef(op_def);
69
70 LOG(INFO) << "Message: " << status.error_message();
71 EXPECT_TRUE(str_util::StrContains(status.ToString(), message))
72 << "NodeDef: " << SummarizeNodeDef(bad)
73 << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
74 << "\nDoes not contain: " << message;
75 }
76
// Validates NodeDefs against a one-input op whose input type comes from the
// type attr T. Covers the valid case and each distinct validation error:
// op-name mismatch, missing/extra/mistyped attrs, wrong input count, and
// control-input placement.
TEST(NodeDefUtilTest, In) {
  const OpDef op = ToOpDef(OpDefBuilder("In").Input("i: T").Attr("T: type"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'In' input:'a' attr { key:'T' value { type:DT_FLOAT } }
  )proto");
  ExpectSuccess(node_def, op);

  EXPECT_EQ("{{node n}} = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def));

  // Mismatching Op names.
  NodeDef bad = node_def;
  bad.set_op("Wrong");
  ExpectFailure(bad, op, "NodeDef op 'Wrong' does not match Op<name=In;");

  // Missing attr
  bad = node_def;
  bad.clear_attr();
  ExpectFailure(bad, op, "NodeDef missing attr 'T' from Op<name=In;");

  // Extra attr
  bad = node_def;
  AddNodeAttr("EXTRA", 17, &bad);
  ExpectFailure(bad, op, "NodeDef mentions attr 'EXTRA' not in Op<name=In;");

  // Attr has wrong type
  bad = node_def;
  bad.clear_attr();
  AddNodeAttr("T", 17, &bad);
  ExpectFailure(
      bad, op,
      "AttrValue had value with type 'int' when 'type' expected\n\t for attr "
      "'T'\n\t; NodeDef: ");

  // Wrong number of inputs
  bad = node_def;
  bad.add_input("b");
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 2 inputs specified;");

  bad = node_def;
  bad.clear_input();
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 0 inputs specified;");

  // A control input placed after the data inputs is expected to validate.
  // Validate `good` itself (not `node_def`) so the trailing control input is
  // actually exercised.
  NodeDef good = node_def;
  good.add_input("^b");
  ExpectSuccess(good, op);

  // Control inputs must appear after data inputs
  bad = node_def;
  bad.clear_input();
  bad.add_input("^b");
  bad.add_input("a");
  ExpectFailure(bad, op,
                "Invalid argument: Non-control input 'a' after control input "
                "in NodeDef:");

  // Control inputs must not name an output port.
  bad = node_def;
  bad.add_input("^b:0");
  ExpectFailure(bad, op, "Control input '^b:0' must not have ':' in NodeDef:");
}
140
// An op whose single output type is restricted to number types. An int32
// instantiation validates; a string instantiation is rejected with the full
// list of allowed number types.
TEST(NodeDefUtilTest, Out) {
  const OpDef op =
      ToOpDef(OpDefBuilder("Out").Output("o: T").Attr("T: numbertype"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'Out' attr { key:'T' value { type:DT_INT32 } }
  )proto");
  ExpectSuccess(node_def, op);
  EXPECT_EQ("{{node n}} = Out[T=DT_INT32]()", SummarizeNodeDef(node_def));

  // Replace T with a non-number type.
  NodeDef string_typed = node_def;
  string_typed.clear_attr();
  AddNodeAttr("T", DT_STRING, &string_typed);
  const string expected_error =
      "Value for attr 'T' of string is not in the list of allowed "
      "values: float, double, int32, uint8, int16, int8, complex64, "
      "int64, qint8, quint8, qint32, bfloat16, uint16, complex128, "
      "half, uint32, uint64";
  ExpectFailure(string_typed, op, expected_error);
}
161
// An op with a string attr restricted to {'apple', 'orange'}. Both listed
// values validate; any other value is rejected.
TEST(NodeDefUtilTest, Enum) {
  const OpDef op = ToOpDef(OpDefBuilder("Enum").Attr("e: {'apple','orange'}"));
  const NodeDef apple = ToNodeDef(R"proto(
    name:'n' op:'Enum' attr { key:'e' value { s:'apple' } }
  )proto");
  ExpectSuccess(apple, op);
  EXPECT_EQ("{{node n}} = Enum[e=\"apple\"]()", SummarizeNodeDef(apple));

  NodeDef orange = apple;
  orange.clear_attr();
  AddNodeAttr("e", "orange", &orange);
  ExpectSuccess(orange, op);

  // A value outside the allowed set.
  NodeDef unlisted = apple;
  unlisted.clear_attr();
  AddNodeAttr("e", "foo", &unlisted);
  ExpectFailure(unlisted, op,
                "Value for attr 'e' of \"foo\" is not in the list of allowed "
                "values: \"apple\", \"orange\"");
}
184
// An op taking N (>= 2) inputs that all share type T from {float, double}.
// Checks the valid case, a disallowed T, and an N below its minimum.
TEST(NodeDefUtilTest, SameIn) {
  const OpDef op = ToOpDef(OpDefBuilder("SameIn")
                               .Input("i: N * T")
                               .Attr("N: int >= 2")
                               .Attr("T: {float,double}"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_DOUBLE } }
  )proto");
  ExpectSuccess(node_def, op);
  EXPECT_EQ("{{node n}} = SameIn[N=2, T=DT_DOUBLE](a, b)",
            SummarizeNodeDef(node_def));

  // T outside of {float, double}.
  const NodeDef string_typed = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_STRING } }
  )proto");
  ExpectFailure(string_typed, op,
                "Value for attr 'T' of string is not in the list of allowed "
                "values: float, double");

  // N below its declared minimum of 2.
  const NodeDef n_too_small = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:1 } } attr { key:'T' value { type:DT_FLOAT } }
  )proto");
  ExpectFailure(n_too_small, op,
                "Value for attr 'N' of 1 must be at least minimum 2");
}
215
// An op whose input is a type list T with at least one element. A
// two-element list validates; an empty list (spelled either way) is rejected.
TEST(NodeDefUtilTest, AnyIn) {
  const OpDef op =
      ToOpDef(OpDefBuilder("AnyIn").Input("i: T").Attr("T: list(type) >= 1"));

  const NodeDef two_types = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectSuccess(two_types, op);
  EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a, b)",
            SummarizeNodeDef(two_types));

  const NodeDef empty_list = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' attr { key:'T' value { list { } } }
  )proto");
  ExpectFailure(empty_list, op,
                "Length for attr 'T' of 0 must be at least minimum 1");

  // With proto3 semantics, an empty value {} cannot be told apart from a
  // value holding an empty list, so it gets the same empty-list complaint.
  const NodeDef empty_value = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' attr { key:'T' value { } }
  )proto");
  ExpectFailure(empty_value, op,
                "Length for attr 'T' of 0 must be at least minimum 1");
}
243
// A device set through NodeDefBuilder::Device() appears in the summary as a
// `_device` attr. When the op also has a regular attr, it follows `v` here.
TEST(NodeDefUtilTest, Device) {
  const OpDef no_attr_op = ToOpDef(OpDefBuilder("None"));
  const NodeDef no_attr_node =
      ToNodeDef(NodeDefBuilder("d", &no_attr_op).Device("/cpu:17"));
  ExpectSuccess(no_attr_node, no_attr_op);
  EXPECT_EQ("{{node d}} = None[_device=\"/cpu:17\"]()",
            SummarizeNodeDef(no_attr_node));

  const OpDef int_attr_op = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int"));
  const NodeDef int_attr_node = ToNodeDef(
      NodeDefBuilder("d", &int_attr_op).Attr("v", 7).Device("/cpu:5"));
  ExpectSuccess(int_attr_node, int_attr_op);
  EXPECT_EQ("{{node d}} = WithAttr[v=7, _device=\"/cpu:5\"]()",
            SummarizeNodeDef(int_attr_node));
}
259
ExpectValidSyntax(const NodeDef & good)260 void ExpectValidSyntax(const NodeDef& good) {
261 EXPECT_EQ(Status::OK(), ValidateExternalNodeDefSyntax(good))
262 << "NodeDef: " << SummarizeNodeDef(good);
263 }
264
ExpectInvalidSyntax(const NodeDef & bad,const string & message)265 void ExpectInvalidSyntax(const NodeDef& bad, const string& message) {
266 Status status = ValidateExternalNodeDefSyntax(bad);
267
268 ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad);
269
270 EXPECT_TRUE(errors::IsInvalidArgument(status))
271 << status << "; NodeDef: " << SummarizeNodeDef(bad);
272
273 EXPECT_TRUE(str_util::StrContains(StringPiece(status.ToString()), message))
274 << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", "
275 << message;
276 }
277
// Exercises ValidateExternalNodeDefSyntax on a series of NodeDefs.
// The first four cases are expected to be syntactically valid. Every later
// case is expected to be rejected, with an error containing the given
// substring.
TEST(NodeDefUtilTest, ValidSyntax) {
  // Valid: plain data inputs without explicit ports.
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectValidSyntax(node_def);

  // Valid: data inputs with explicit numeric output ports.
  const NodeDef node_def_explicit_inputs = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a:0' input:'b:123'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectValidSyntax(node_def_explicit_inputs);

  // Explicit ports are kept verbatim in the summary.
  EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
            SummarizeNodeDef(node_def_explicit_inputs));

  // Valid: a shape attr whose dims are -1 and 0.
  const NodeDef node_def_partial_shape = ToNodeDef(R"proto(
    name:'n' op:'AnyIn'
    attr { key:'shp' value { shape { dim { size: -1 } dim { size: 0 } } } }
  )proto");
  ExpectValidSyntax(node_def_partial_shape);

  // Valid: a '-' in the node name and a control input after a data input.
  const NodeDef node_def_control_input = ToNodeDef(R"proto(
    name:'n-' op:'AnyIn' input:'a' input:'^b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectValidSyntax(node_def_control_input);

  // Invalid node name: contains ':'.
  const NodeDef node_def_invalid_name = ToNodeDef(R"proto(
    name:'n:0' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectInvalidSyntax(node_def_invalid_name, "Illegal op name 'n:0'");

  // Invalid node name: leading underscore.
  const NodeDef node_def_internal_name = ToNodeDef(R"proto(
    name:'_n' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'");

  // Invalid node name: trailing backslash ('\\' in the text proto).
  const NodeDef node_def_slash_in_name = ToNodeDef(R"proto(
    name:'n\\' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectInvalidSyntax(node_def_slash_in_name, "Illegal op name 'n\\'");

  // Invalid data input name: leading underscore.
  const NodeDef node_def_internal_input_name = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'_a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectInvalidSyntax(node_def_internal_input_name,
                      "Illegal op input name '_a'");

  // Invalid data input name: trailing backslash.
  const NodeDef node_def_input_name_slash = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a\\' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectInvalidSyntax(node_def_input_name_slash, "Illegal op input name 'a\\'");

  // Invalid control input: control inputs must not carry a port.
  const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' input:'^b:0'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectInvalidSyntax(node_def_invalid_control_input_name,
                      "Illegal op input name '^b:0'");

  // Invalid control input name: trailing backslash.
  const NodeDef node_def_control_input_name_slash = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' input:'^b\\'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectInvalidSyntax(node_def_control_input_name_slash,
                      "Illegal op input name '^b\\'");

  // Invalid ordering: a data input after a control input.
  const NodeDef node_def_data_input_after_control = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'^a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectInvalidSyntax(node_def_data_input_after_control,
                      "All control inputs must follow all data inputs");

  // Invalid port: non-numeric. The expected text is a prefix, so it has no
  // closing quote.
  const NodeDef node_def_data_input_invalid_port = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a:b' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectInvalidSyntax(node_def_data_input_invalid_port,
                      "Illegal op input name 'a:b");

  // Invalid port: a numeric port with a leading zero ('00').
  const NodeDef node_def_data_input_invalid_port2 = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a:00' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
  )proto");
  ExpectInvalidSyntax(node_def_data_input_invalid_port2,
                      "Illegal op input name 'a:00");
}
372
// Checks that the input types of a node with fixed-type inputs resolve both
// all at once and by index, and that an out-of-range index is an error.
TEST(InputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
  DataTypeVector types;
  // These must be fatal: indexing `types` below is only safe once the call
  // has succeeded and produced one entry per declared input.
  TF_ASSERT_OK(InputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(size_t{2}, types.size());
  EXPECT_EQ(DT_FLOAT, types[0]);
  EXPECT_EQ(DT_INT32, types[1]);

  // Initialized so a failed lookup never leaves `type` indeterminate.
  DataType type = DT_INVALID;
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(DT_FLOAT, type);
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(DT_INT32, type);
  // Index 2 is past the two declared inputs.
  EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
}
393
// Checks that the output types of a node with fixed-type outputs resolve both
// all at once and by index, and that an out-of-range index is an error.
TEST(OutputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
  DataTypeVector types;
  // These must be fatal: indexing `types` below is only safe once the call
  // has succeeded and produced one entry per declared output.
  TF_ASSERT_OK(OutputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(size_t{2}, types.size());
  EXPECT_EQ(DT_STRING, types[0]);
  EXPECT_EQ(DT_BOOL, types[1]);

  // Initialized so a failed lookup never leaves `type` indeterminate.
  DataType type = DT_INVALID;
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(DT_STRING, type);
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(DT_BOOL, type);
  // Index 2 is past the two declared outputs.
  EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
}
414
// With only fixed, non-list arguments, each argument name maps to a range of
// exactly one slot. Clearing an input arg's type makes NameRangesForNode fail.
TEST(NameRangesForNodeTest, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));

  NameRangeMap input_ranges;
  NameRangeMap output_ranges;
  TF_EXPECT_OK(
      NameRangesForNode(node_def, op_def, &input_ranges, &output_ranges));
  EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), input_ranges);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), output_ranges);
  EXPECT_EQ("{{node simple}} = Simple[](a, b)", SummarizeNodeDef(node_def));

  OpDef untyped_op_def = op_def;
  untyped_op_def.mutable_input_arg(0)->clear_type();
  const Status untyped_status = NameRangesForNode(
      node_def, untyped_op_def, &input_ranges, &output_ranges);
  EXPECT_FALSE(untyped_status.ok());
}
434
// Every argument shares the single type attr T, so the name ranges are the
// same for each instantiation of T; only the node summary changes.
TEST(NameRangesForNodeTest, Polymorphic) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Polymorphic")
                                   .Input("a: T")
                                   .Input("b: T")
                                   .Output("c: T")
                                   .Attr("T: type"));
  NameRangeMap inputs, outputs;

  // Builds a "poly" node with both inputs of `dtype`, then checks its ranges
  // and its summary against `expected_summary`.
  auto check_instantiation = [&](DataType dtype,
                                 const string& expected_summary) {
    const NodeDef node_def = ToNodeDef(NodeDefBuilder("poly", &op_def)
                                           .Input(FakeInput(dtype))
                                           .Input(FakeInput(dtype)));
    TF_EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs));
    EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
    EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
    EXPECT_EQ(expected_summary, SummarizeNodeDef(node_def));
  };

  check_instantiation(DT_INT32,
                      "{{node poly}} = Polymorphic[T=DT_INT32](a, b)");
  check_instantiation(DT_BOOL, "{{node poly}} = Polymorphic[T=DT_BOOL](a, b)");
}
460
// Arguments repeated N or M times get ranges of that length, laid out
// consecutively. Without the N/M/T attrs the ranges cannot be computed.
TEST(NameRangesForNodeTest, NRepeats) {
  const OpDef op_def = ToOpDef(OpDefBuilder("NRepeats")
                                   .Input("a: N * int32")
                                   .Input("b: N * T")
                                   .Output("c: T")
                                   .Output("d: N * string")
                                   .Output("e: M * bool")
                                   .Attr("N: int")
                                   .Attr("M: int")
                                   .Attr("T: type"));
  NameRangeMap inputs, outputs;

  // N=4 (inferred from the inputs), M=3.
  const NodeDef n4_m3 = ToNodeDef(NodeDefBuilder("nr", &op_def)
                                      .Input(FakeInput(4, DT_INT32))
                                      .Input(FakeInput(4, DT_FLOAT))
                                      .Attr("M", 3));
  TF_EXPECT_OK(NameRangesForNode(n4_m3, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}),
            outputs);
  EXPECT_EQ(
      "{{node nr}} = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, "
      "b:2, b:3)",
      SummarizeNodeDef(n4_m3));

  // N=2 (inferred from the inputs), M=7.
  const NodeDef n2_m7 = ToNodeDef(NodeDefBuilder("nr", &op_def)
                                      .Input(FakeInput(2, DT_INT32))
                                      .Input(FakeInput(2, DT_DOUBLE))
                                      .Attr("M", 7));
  TF_EXPECT_OK(NameRangesForNode(n2_m7, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            outputs);
  EXPECT_EQ("{{node nr}} = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)",
            SummarizeNodeDef(n2_m7));

  // Stripping every attr leaves the repeat counts unknown.
  NodeDef attrless = n2_m7;
  attrless.clear_attr();
  const Status attrless_status =
      NameRangesForNode(attrless, op_def, &inputs, &outputs);
  EXPECT_FALSE(attrless_status.ok());
}
500
// Arguments typed by list(type) attrs get ranges as long as their type
// lists, laid out consecutively. Without those attrs the ranges cannot be
// computed.
TEST(NameRangesForNodeTest, TypeList) {
  const OpDef op_def = ToOpDef(OpDefBuilder("TypeList")
                                   .Input("a: T1")
                                   .Input("b: T2")
                                   .Output("c: T2")
                                   .Output("d: T3")
                                   .Output("e: T1")
                                   .Attr("T1: list(type)")
                                   .Attr("T2: list(type)")
                                   .Attr("T3: list(type)"));
  NameRangeMap inputs, outputs;

  // |T1| = 2, |T2| = 4, |T3| = 3.
  const NodeDef lengths_2_4_3 =
      ToNodeDef(NodeDefBuilder("tl", &op_def)
                    .Input(FakeInput({DT_BOOL, DT_FLOAT}))
                    .Input(FakeInput(4, DT_FLOAT))
                    .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING}));
  TF_EXPECT_OK(NameRangesForNode(lengths_2_4_3, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}),
            outputs);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_BOOL, DT_FLOAT],"
      " T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT],"
      " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)",
      SummarizeNodeDef(lengths_2_4_3));

  // |T1| = 7, |T2| = 1, |T3| = 2.
  const NodeDef lengths_7_1_2 = ToNodeDef(NodeDefBuilder("tl", &op_def)
                                              .Input(FakeInput(7, DT_INT32))
                                              .Input(FakeInput({DT_DOUBLE}))
                                              .Attr("T3", {DT_DOUBLE, DT_STRING}));
  TF_EXPECT_OK(NameRangesForNode(lengths_7_1_2, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            outputs);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, "
      "DT_INT32,"
      " DT_INT32, DT_INT32], T2=[DT_DOUBLE], T3=[DT_DOUBLE, DT_STRING]]"
      "(a, a:1, a:2, a:3, a:4, a:5, a:6, b)",
      SummarizeNodeDef(lengths_7_1_2));

  // Stripping every attr leaves the list lengths unknown.
  NodeDef attrless = lengths_7_1_2;
  attrless.clear_attr();
  const Status attrless_status =
      NameRangesForNode(attrless, op_def, &inputs, &outputs);
  EXPECT_FALSE(attrless_status.ok());
}
546
// For an Enter node, AddPrefixAndSuffixToNode wraps both the node name and
// its frame_name attr with the given prefix and suffix.
TEST(AddPrefixAndSuffixToNode, Enter) {
  const string prefix = "prefix/";
  const string suffix = "/suffix";

  NodeDef enter;
  enter.set_name("enter");
  enter.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &enter);

  TF_ASSERT_OK(AddPrefixAndSuffixToNode(prefix, suffix, &enter));
  EXPECT_EQ("prefix/enter/suffix", enter.name());

  string rewritten_frame;
  TF_ASSERT_OK(GetNodeAttr(enter, "frame_name", &rewritten_frame));
  EXPECT_EQ("prefix/test_frame/suffix", rewritten_frame);
}
560
// FormatNodeForError renders a graph Node as "{{node <name>}}".
TEST(FormatNodeForErrorTest, Node) {
  Graph g(OpRegistry::Global());
  Node* node = nullptr;
  // TF_ASSERT_OK fails only this test on error. TF_CHECK_OK would abort the
  // whole test binary and hide the results of every other test.
  TF_ASSERT_OK(NodeBuilder("enter", "NoOp").Finalize(&g, &node));
  EXPECT_EQ("{{node enter}}", FormatNodeForError(*node));
}
567
// FormatNodeDefForError renders a NodeDef as "{{node <name>}}". Here the op
// and attrs do not appear in the output.
TEST(FormatNodeForErrorTest, NodeDef) {
  NodeDef enter;
  enter.set_name("enter");
  enter.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &enter);

  const string formatted = FormatNodeDefForError(enter);
  EXPECT_EQ("{{node enter}}", formatted);
}
575
// With AttachDef's third argument set to true, every attached node is
// appended to the message in "{{node <name>}}" form, including the second.
TEST(AttachDef, AllowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");

  const Status original(error::CANCELLED, "Error");
  const Status with_a = AttachDef(original, first, true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", with_a.error_message());

  const Status with_a_and_b = AttachDef(with_a, second, true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[{{node b}}]]",
            with_a_and_b.error_message());
}
587
// With AttachDef's third argument set to false, only the first attached node
// uses the "{{node <name>}}" form. The second node is appended as a plain
// name.
TEST(AttachDef, DisallowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");

  const Status original(error::CANCELLED, "Error");
  const Status with_a = AttachDef(original, first, false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", with_a.error_message());

  const Status with_a_and_b = AttachDef(with_a, second, false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[b]]",
            with_a_and_b.error_message());
}
599
600 } // namespace
601 } // namespace tensorflow
602