1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_util.h"
17 
18 #include "tensorflow/core/framework/attr_value.pb.h"  // NOLINT
19 #include "tensorflow/core/framework/fake_input.h"
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op_def_builder.h"
22 #include "tensorflow/core/framework/op_def_util.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/graph/node_builder.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/test.h"
30 
31 namespace tensorflow {
32 namespace {
33 
ToOpDef(const OpDefBuilder & builder)34 OpDef ToOpDef(const OpDefBuilder& builder) {
35   OpRegistrationData op_reg_data;
36   TF_EXPECT_OK(builder.Finalize(&op_reg_data));
37   return op_reg_data.op_def;
38 }
39 
ToNodeDef(const string & text)40 NodeDef ToNodeDef(const string& text) {
41   NodeDef node_def;
42   EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
43   return node_def;
44 }
45 
ToNodeDef(const NodeDefBuilder & builder)46 NodeDef ToNodeDef(const NodeDefBuilder& builder) {
47   NodeDef node_def;
48   TF_EXPECT_OK(builder.Finalize(&node_def));
49   return node_def;
50 }
51 
ExpectSuccess(const NodeDef & good,const OpDef & op_def)52 void ExpectSuccess(const NodeDef& good, const OpDef& op_def) {
53   EXPECT_EQ(Status::OK(), ValidateNodeDef(good, op_def))
54       << "NodeDef: " << SummarizeNodeDef(good)
55       << "; OpDef: " << SummarizeOpDef(op_def);
56 }
57 
ExpectFailure(const NodeDef & bad,const OpDef & op_def,const string & message)58 void ExpectFailure(const NodeDef& bad, const OpDef& op_def,
59                    const string& message) {
60   Status status = ValidateNodeDef(bad, op_def);
61 
62   EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad)
63                             << "; OpDef: " << SummarizeOpDef(op_def);
64   if (status.ok()) return;
65 
66   EXPECT_TRUE(errors::IsInvalidArgument(status))
67       << status << "; NodeDef: " << SummarizeNodeDef(bad)
68       << "; OpDef: " << SummarizeOpDef(op_def);
69 
70   LOG(INFO) << "Message: " << status.error_message();
71   EXPECT_TRUE(str_util::StrContains(status.ToString(), message))
72       << "NodeDef: " << SummarizeNodeDef(bad)
73       << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
74       << "\nDoes not contain: " << message;
75 }
76 
// Validates NodeDefs for a single-input polymorphic op "In". Covers: op-name
// mismatch, missing, extra and mistyped attrs, input-count mismatches, and
// control-input placement and syntax.
TEST(NodeDefUtilTest, In) {
  const OpDef op = ToOpDef(OpDefBuilder("In").Input("i: T").Attr("T: type"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'In' input:'a' attr { key:'T' value { type:DT_FLOAT } }
    )proto");
  ExpectSuccess(node_def, op);

  EXPECT_EQ("{{node n}} = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def));

  // Mismatching Op names.
  NodeDef bad = node_def;
  bad.set_op("Wrong");
  ExpectFailure(bad, op, "NodeDef op 'Wrong' does not match Op<name=In;");

  // Missing attr
  bad = node_def;
  bad.clear_attr();
  ExpectFailure(bad, op, "NodeDef missing attr 'T' from Op<name=In;");

  // Extra attr
  bad = node_def;
  AddNodeAttr("EXTRA", 17, &bad);
  ExpectFailure(bad, op, "NodeDef mentions attr 'EXTRA' not in Op<name=In;");

  // Attr has wrong type
  bad = node_def;
  bad.clear_attr();
  AddNodeAttr("T", 17, &bad);
  ExpectFailure(
      bad, op,
      "AttrValue had value with type 'int' when 'type' expected\n\t for attr "
      "'T'\n\t; NodeDef: ");

  // Wrong number of inputs
  bad = node_def;
  bad.add_input("b");
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 2 inputs specified;");

  bad = node_def;
  bad.clear_input();
  ExpectFailure(
      bad, op,
      "NodeDef expected inputs 'float' do not match 0 inputs specified;");

  // Control inputs must appear after data inputs. A control input following
  // the data input is valid. Validate the modified copy (`good`), not
  // `node_def`, so this case is actually exercised.
  NodeDef good = node_def;
  good.add_input("^b");
  ExpectSuccess(good, op);

  bad = node_def;
  bad.clear_input();
  bad.add_input("^b");
  bad.add_input("a");
  ExpectFailure(bad, op,
                "Invalid argument: Non-control input 'a' after control input "
                "in NodeDef:");

  // Control inputs must not carry a ":<port>" suffix.
  bad = node_def;
  bad.add_input("^b:0");
  ExpectFailure(bad, op, "Control input '^b:0' must not have ':' in NodeDef:");
}
140 
// An output constrained to `numbertype`: int32 is accepted, while string is
// rejected with the full list of allowed types in the error.
TEST(NodeDefUtilTest, Out) {
  const OpDef out_op =
      ToOpDef(OpDefBuilder("Out").Output("o: T").Attr("T: numbertype"));
  const NodeDef valid = ToNodeDef(R"proto(
    name:'n' op:'Out' attr { key:'T' value { type:DT_INT32 } }
    )proto");
  ExpectSuccess(valid, out_op);
  EXPECT_EQ("{{node n}} = Out[T=DT_INT32]()", SummarizeNodeDef(valid));

  // Replace T with a type outside `numbertype`.
  NodeDef non_numeric = valid;
  non_numeric.clear_attr();
  AddNodeAttr("T", DT_STRING, &non_numeric);
  ExpectFailure(non_numeric, out_op,
                "Value for attr 'T' of string is not in the list of allowed "
                "values: float, double, int32, uint8, int16, int8, complex64, "
                "int64, qint8, quint8, qint32, bfloat16, uint16, complex128, "
                "half, uint32, uint64");
}
161 
// A string attr restricted to {'apple', 'orange'}: both members are accepted,
// and anything else is rejected with the allowed set listed in the error.
TEST(NodeDefUtilTest, Enum) {
  const OpDef enum_op =
      ToOpDef(OpDefBuilder("Enum").Attr("e: {'apple','orange'}"));
  const NodeDef apple = ToNodeDef(R"proto(
    name:'n' op:'Enum' attr { key:'e' value { s:'apple' } }
    )proto");
  ExpectSuccess(apple, enum_op);
  EXPECT_EQ("{{node n}} = Enum[e=\"apple\"]()", SummarizeNodeDef(apple));

  // The other member of the set is also accepted.
  NodeDef orange = apple;
  orange.clear_attr();
  AddNodeAttr("e", "orange", &orange);
  ExpectSuccess(orange, enum_op);

  // A value outside the set is rejected.
  NodeDef outside_set = apple;
  outside_set.clear_attr();
  AddNodeAttr("e", "foo", &outside_set);
  ExpectFailure(outside_set, enum_op,
                "Value for attr 'e' of \"foo\" is not in the list of allowed "
                "values: \"apple\", \"orange\"");
}
184 
// An `N * T` input with N >= 2 and T restricted to {float, double}. Checks a
// valid instance plus a disallowed T, an N below its minimum, and too few
// inputs for the declared N.
TEST(NodeDefUtilTest, SameIn) {
  const OpDef op = ToOpDef(OpDefBuilder("SameIn")
                               .Input("i: N * T")
                               .Attr("N: int >= 2")
                               .Attr("T: {float,double}"));
  const NodeDef node_def = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_DOUBLE } }
    )proto");
  ExpectSuccess(node_def, op);

  EXPECT_EQ("{{node n}} = SameIn[N=2, T=DT_DOUBLE](a, b)",
            SummarizeNodeDef(node_def));

  // Illegal type
  NodeDef bad = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_STRING } }
    )proto");
  ExpectFailure(bad, op,
                "Value for attr 'T' of string is not in the list of allowed "
                "values: float, double");

  // N below its declared minimum. N=1 violates `N: int >= 2`, and that
  // violation is what gets reported.
  bad = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a' input:'b'
    attr { key:'N' value { i:1 } } attr { key:'T' value { type:DT_FLOAT } }
    )proto");
  ExpectFailure(bad, op, "Value for attr 'N' of 1 must be at least minimum 2");

  // Too few inputs: N=2 declares two inputs, but only one is supplied.
  bad = ToNodeDef(R"proto(
    name:'n' op:'SameIn' input:'a'
    attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_DOUBLE } }
    )proto");
  ExpectFailure(bad, op, "do not match 1 inputs specified");
}
215 
// A `list(type) >= 1` input: a two-element type list is accepted, and an
// empty list is rejected.
TEST(NodeDefUtilTest, AnyIn) {
  const OpDef any_in_op =
      ToOpDef(OpDefBuilder("AnyIn").Input("i: T").Attr("T: list(type) >= 1"));

  const NodeDef two_inputs = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto");
  ExpectSuccess(two_inputs, any_in_op);
  EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a, b)",
            SummarizeNodeDef(two_inputs));

  // The first entry is an explicit empty list. The second is an empty value {},
  // which under proto3 semantics is indistinguishable from a value holding an
  // empty list. Both are therefore expected to fail with the same empty-list
  // error.
  const char* const kEmptyListTexts[] = {
      R"proto(
    name:'n' op:'AnyIn' input:'a' attr { key:'T' value { list { } } }
    )proto",
      R"proto(
    name:'n' op:'AnyIn' input:'a' attr { key:'T' value { } }
    )proto"};
  for (const char* text : kEmptyListTexts) {
    ExpectFailure(ToNodeDef(text), any_in_op,
                  "Length for attr 'T' of 0 must be at least minimum 1");
  }
}
243 
// A device set through NodeDefBuilder::Device appears in the summary as a
// `_device` attr.
TEST(NodeDefUtilTest, Device) {
  // Op with no attrs of its own: `_device` is the only attr shown.
  const OpDef attrless_op = ToOpDef(OpDefBuilder("None"));
  const NodeDef attrless_node =
      ToNodeDef(NodeDefBuilder("d", &attrless_op).Device("/cpu:17"));
  ExpectSuccess(attrless_node, attrless_op);
  EXPECT_EQ("{{node d}} = None[_device=\"/cpu:17\"]()",
            SummarizeNodeDef(attrless_node));

  // Op with a regular attr: `_device` is listed after it.
  const OpDef attr_op = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int"));
  const NodeDef attr_node =
      ToNodeDef(NodeDefBuilder("d", &attr_op).Attr("v", 7).Device("/cpu:5"));
  ExpectSuccess(attr_node, attr_op);
  EXPECT_EQ("{{node d}} = WithAttr[v=7, _device=\"/cpu:5\"]()",
            SummarizeNodeDef(attr_node));
}
259 
ExpectValidSyntax(const NodeDef & good)260 void ExpectValidSyntax(const NodeDef& good) {
261   EXPECT_EQ(Status::OK(), ValidateExternalNodeDefSyntax(good))
262       << "NodeDef: " << SummarizeNodeDef(good);
263 }
264 
ExpectInvalidSyntax(const NodeDef & bad,const string & message)265 void ExpectInvalidSyntax(const NodeDef& bad, const string& message) {
266   Status status = ValidateExternalNodeDefSyntax(bad);
267 
268   ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad);
269 
270   EXPECT_TRUE(errors::IsInvalidArgument(status))
271       << status << "; NodeDef: " << SummarizeNodeDef(bad);
272 
273   EXPECT_TRUE(str_util::StrContains(StringPiece(status.ToString()), message))
274       << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", "
275       << message;
276 }
277 
// ValidateExternalNodeDefSyntax: accepted forms first, then a table of
// rejected NodeDefs, each paired with the error substring it must produce.
TEST(NodeDefUtilTest, ValidSyntax) {
  // Plain data inputs.
  ExpectValidSyntax(ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto"));

  // Data inputs with explicit output ports, which are kept in the summary.
  const NodeDef with_ports = ToNodeDef(R"proto(
    name:'n' op:'AnyIn' input:'a:0' input:'b:123'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto");
  ExpectValidSyntax(with_ports);
  EXPECT_EQ("{{node n}} = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
            SummarizeNodeDef(with_ports));

  // A shape attr with -1 and 0 dimension sizes.
  ExpectValidSyntax(ToNodeDef(R"proto(
    name:'n' op:'AnyIn'
    attr { key:'shp' value { shape { dim { size: -1 } dim { size: 0 } } } }
    )proto"));

  // A '-' in the node name, and a control input after the data input.
  ExpectValidSyntax(ToNodeDef(R"proto(
    name:'n-' op:'AnyIn' input:'a' input:'^b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto"));

  struct InvalidCase {
    const char* text;     // NodeDef in protobuf text format.
    const char* message;  // Substring expected in the resulting error.
  };
  const InvalidCase kInvalidCases[] = {
      // Node names containing ':', a leading '_', or a trailing backslash.
      {R"proto(
    name:'n:0' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto",
       "Illegal op name 'n:0'"},
      {R"proto(
    name:'_n' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto",
       "Illegal op name '_n'"},
      {R"proto(
    name:'n\\' op:'AnyIn' input:'a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto",
       "Illegal op name 'n\\'"},
      // Data input names with a leading '_' or a trailing backslash.
      {R"proto(
    name:'n' op:'AnyIn' input:'_a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto",
       "Illegal op input name '_a'"},
      {R"proto(
    name:'n' op:'AnyIn' input:'a\\' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto",
       "Illegal op input name 'a\\'"},
      // Control input names with a port or a trailing backslash.
      {R"proto(
    name:'n' op:'AnyIn' input:'a' input:'^b:0'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto",
       "Illegal op input name '^b:0'"},
      {R"proto(
    name:'n' op:'AnyIn' input:'a' input:'^b\\'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto",
       "Illegal op input name '^b\\'"},
      // A data input placed after a control input.
      {R"proto(
    name:'n' op:'AnyIn' input:'^a' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto",
       "All control inputs must follow all data inputs"},
      // Malformed ports: non-numeric, and with a leading zero.
      {R"proto(
    name:'n' op:'AnyIn' input:'a:b' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto",
       "Illegal op input name 'a:b"},
      {R"proto(
    name:'n' op:'AnyIn' input:'a:00' input:'b'
    attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
    )proto",
       "Illegal op input name 'a:00"},
  };
  for (const InvalidCase& invalid : kInvalidCases) {
    ExpectInvalidSyntax(ToNodeDef(invalid.text), invalid.message);
  }
}
372 
// InputTypesForNode / InputTypeForNode on a monomorphic op return the declared
// input types in order. An out-of-range input index is an error.
TEST(InputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
  DataTypeVector types;
  // These checks are fatal: the indexing below is only safe once `types` holds
  // exactly one entry per input arg.
  TF_ASSERT_OK(InputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(2u, types.size());
  EXPECT_EQ(types[0], DT_FLOAT);
  EXPECT_EQ(types[1], DT_INT32);

  // Initialized so that a failing lookup never leaves `type` indeterminate.
  DataType type = DT_INVALID;
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(type, DT_FLOAT);
  TF_EXPECT_OK(InputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(type, DT_INT32);
  EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
}
393 
// OutputTypesForNode / OutputTypeForNode on a monomorphic op return the
// declared output types in order. An out-of-range output index is an error.
TEST(OutputTypesForNode, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
  DataTypeVector types;
  // These checks are fatal: the indexing below is only safe once `types` holds
  // exactly one entry per output arg.
  TF_ASSERT_OK(OutputTypesForNode(node_def, op_def, &types));
  ASSERT_EQ(2u, types.size());
  EXPECT_EQ(types[0], DT_STRING);
  EXPECT_EQ(types[1], DT_BOOL);

  // Initialized so that a failing lookup never leaves `type` indeterminate.
  DataType type = DT_INVALID;
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 0, &type));
  EXPECT_EQ(type, DT_STRING);
  TF_EXPECT_OK(OutputTypeForNode(node_def, op_def, 1, &type));
  EXPECT_EQ(type, DT_BOOL);
  EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
}
414 
// NameRangesForNode on a monomorphic op: each single-tensor arg occupies one
// slot, in declaration order.
TEST(NameRangesForNodeTest, Simple) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
                                   .Input("a: float")
                                   .Input("b: int32")
                                   .Output("c: string")
                                   .Output("d: bool"));
  const NodeDef node_def = ToNodeDef(
      NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));

  NameRangeMap in_ranges;
  NameRangeMap out_ranges;
  TF_EXPECT_OK(NameRangesForNode(node_def, op_def, &in_ranges, &out_ranges));
  EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), in_ranges);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), out_ranges);
  EXPECT_EQ("{{node simple}} = Simple[](a, b)", SummarizeNodeDef(node_def));

  // Clearing the first input arg's type makes the computation fail.
  OpDef untyped_op_def = op_def;
  untyped_op_def.mutable_input_arg(0)->clear_type();
  EXPECT_FALSE(
      NameRangesForNode(node_def, untyped_op_def, &in_ranges, &out_ranges)
          .ok());
}
434 
// For an op polymorphic in a single type attr T, the name ranges do not
// depend on which T is chosen. Only the summary changes.
TEST(NameRangesForNodeTest, Polymorphic) {
  const OpDef op_def = ToOpDef(OpDefBuilder("Polymorphic")
                                   .Input("a: T")
                                   .Input("b: T")
                                   .Output("c: T")
                                   .Attr("T: type"));
  NameRangeMap inputs, outputs;

  // Builds a node whose two inputs are both `dtype`, then checks its ranges
  // (written into the shared `inputs`/`outputs`) and its summary.
  auto check_with_type = [&](DataType dtype, const char* expected_summary) {
    const NodeDef poly = ToNodeDef(NodeDefBuilder("poly", &op_def)
                                       .Input(FakeInput(dtype))
                                       .Input(FakeInput(dtype)));
    TF_EXPECT_OK(NameRangesForNode(poly, op_def, &inputs, &outputs));
    EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
    EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
    EXPECT_EQ(expected_summary, SummarizeNodeDef(poly));
  };

  check_with_type(DT_INT32, "{{node poly}} = Polymorphic[T=DT_INT32](a, b)");
  check_with_type(DT_BOOL, "{{node poly}} = Polymorphic[T=DT_BOOL](a, b)");
}
460 
// Args repeated N or M times span N or M consecutive slots. Without its attrs,
// the ranges cannot be computed.
TEST(NameRangesForNodeTest, NRepeats) {
  const OpDef op_def = ToOpDef(OpDefBuilder("NRepeats")
                                   .Input("a: N * int32")
                                   .Input("b: N * T")
                                   .Output("c: T")
                                   .Output("d: N * string")
                                   .Output("e: M * bool")
                                   .Attr("N: int")
                                   .Attr("M: int")
                                   .Attr("T: type"));
  NameRangeMap inputs, outputs;

  // N is set implicitly by the input counts; M is set explicitly.
  auto make_node = [&op_def](int n, DataType b_type, int m) {
    return ToNodeDef(NodeDefBuilder("nr", &op_def)
                         .Input(FakeInput(n, DT_INT32))
                         .Input(FakeInput(n, b_type))
                         .Attr("M", m));
  };

  // N=4, M=3, T=float.
  const NodeDef four_by_three = make_node(4, DT_FLOAT, 3);
  TF_EXPECT_OK(NameRangesForNode(four_by_three, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}),
            outputs);
  EXPECT_EQ(
      "{{node nr}} = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, "
      "b:2, b:3)",
      SummarizeNodeDef(four_by_three));

  // N=2, M=7, T=double.
  const NodeDef two_by_seven = make_node(2, DT_DOUBLE, 7);
  TF_EXPECT_OK(NameRangesForNode(two_by_seven, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            outputs);
  EXPECT_EQ("{{node nr}} = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)",
            SummarizeNodeDef(two_by_seven));

  // With every attr removed, N/M/T cannot be resolved.
  NodeDef attrless = two_by_seven;
  attrless.clear_attr();
  EXPECT_FALSE(NameRangesForNode(attrless, op_def, &inputs, &outputs).ok());
}
500 
// Args typed by list(type) attrs span as many slots as their type lists have
// entries. Without the attrs, the ranges cannot be computed.
TEST(NameRangesForNodeTest, TypeList) {
  const OpDef op_def = ToOpDef(OpDefBuilder("TypeList")
                                   .Input("a: T1")
                                   .Input("b: T2")
                                   .Output("c: T2")
                                   .Output("d: T3")
                                   .Output("e: T1")
                                   .Attr("T1: list(type)")
                                   .Attr("T2: list(type)")
                                   .Attr("T3: list(type)"));
  NameRangeMap inputs, outputs;

  // |T1| = 2, |T2| = 4, |T3| = 3.
  const NodeDef first =
      ToNodeDef(NodeDefBuilder("tl", &op_def)
                    .Input(FakeInput({DT_BOOL, DT_FLOAT}))
                    .Input(FakeInput(4, DT_FLOAT))
                    .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING}));
  TF_EXPECT_OK(NameRangesForNode(first, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}),
            outputs);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_BOOL, DT_FLOAT],"
      " T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT],"
      " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)",
      SummarizeNodeDef(first));

  // |T1| = 7, |T2| = 1, |T3| = 2.
  const NodeDef second = ToNodeDef(NodeDefBuilder("tl", &op_def)
                                       .Input(FakeInput(7, DT_INT32))
                                       .Input(FakeInput({DT_DOUBLE}))
                                       .Attr("T3", {DT_DOUBLE, DT_STRING}));
  TF_EXPECT_OK(NameRangesForNode(second, op_def, &inputs, &outputs));
  EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs);
  EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
            outputs);
  EXPECT_EQ(
      "{{node tl}} = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, "
      "DT_INT32,"
      " DT_INT32, DT_INT32], T2=[DT_DOUBLE], T3=[DT_DOUBLE, DT_STRING]]"
      "(a, a:1, a:2, a:3, a:4, a:5, a:6, b)",
      SummarizeNodeDef(second));

  // With the type-list attrs removed, the ranges cannot be computed.
  NodeDef attrless = second;
  attrless.clear_attr();
  EXPECT_FALSE(NameRangesForNode(attrless, op_def, &inputs, &outputs).ok());
}
546 
// For an Enter node, both the node name and its frame_name attr are wrapped
// with the given prefix and suffix.
TEST(AddPrefixAndSuffixToNode, Enter) {
  NodeDef enter;
  enter.set_name("enter");
  enter.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &enter);

  const string kPrefix = "prefix/";
  const string kSuffix = "/suffix";
  TF_ASSERT_OK(AddPrefixAndSuffixToNode(kPrefix, kSuffix, &enter));
  EXPECT_EQ("prefix/enter/suffix", enter.name());

  string wrapped_frame;
  TF_ASSERT_OK(GetNodeAttr(enter, "frame_name", &wrapped_frame));
  EXPECT_EQ("prefix/test_frame/suffix", wrapped_frame);
}
560 
// FormatNodeForError on a graph Node renders as "{{node <name>}}".
TEST(FormatNodeForErrorTest, Node) {
  Graph g(OpRegistry::Global());
  Node* node = nullptr;
  // Use TF_ASSERT_OK rather than TF_CHECK_OK: a construction failure should
  // fail (and stop) only this test, not abort the whole test binary.
  TF_ASSERT_OK(NodeBuilder("enter", "NoOp").Finalize(&g, &node));
  EXPECT_EQ("{{node enter}}", FormatNodeForError(*node));
}
567 
// FormatNodeDefForError renders only "{{node <name>}}". The op and attrs do
// not appear.
TEST(FormatNodeForErrorTest, NodeDef) {
  NodeDef enter;
  enter.set_name("enter");
  enter.set_op("Enter");
  AddNodeAttr("frame_name", "test_frame", &enter);

  const string formatted = FormatNodeDefForError(enter);
  EXPECT_EQ("{{node enter}}", formatted);
}
575 
// With the trailing flag set to true, each AttachDef call appends a
// "[[{{node <name>}}]]" line, even when the message already has one.
TEST(AttachDef, AllowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");

  const Status base(error::CANCELLED, "Error");
  const Status with_first = AttachDef(base, first, true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", with_first.error_message());

  const Status with_both = AttachDef(with_first, second, true);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[{{node b}}]]",
            with_both.error_message());
}
587 
// With the trailing flag set to false, the first attached node is formatted as
// "{{node a}}". A second one, attached to a message that already contains a
// formatted node, appears as a plain "[[b]]".
TEST(AttachDef, DisallowMultipleFormattedNode) {
  NodeDef first;
  first.set_name("a");
  NodeDef second;
  second.set_name("b");

  const Status base(error::CANCELLED, "Error");
  const Status with_first = AttachDef(base, first, false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]", with_first.error_message());

  const Status with_both = AttachDef(with_first, second, false);
  EXPECT_EQ("Error\n\t [[{{node a}}]]\n\t [[b]]", with_both.error_message());
}
599 
600 }  // namespace
601 }  // namespace tensorflow
602