1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
17
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
20 #include "tensorflow/cc/ops/function_ops.h"
21 #include "tensorflow/cc/ops/functional_ops.h"
22 #include "tensorflow/cc/ops/resource_variable_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
25 #include "tensorflow/compiler/tf2xla/test_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/graph/graph_constructor.h"
32 #include "tensorflow/core/graph/graph_def_builder.h"
33 #include "tensorflow/core/graph/validate.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/util/equal_graph_def.h"
37
38 namespace tensorflow {
39 namespace {
40
41 // Returns the names of the "then" and "else" functions for the If node in a
42 // graph.
FindIfThenAndElse(const GraphDef & graph,string * op_name,NameAttrList * then_fn,NameAttrList * else_fn)43 Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
44 NameAttrList* then_fn, NameAttrList* else_fn) {
45 for (const NodeDef& node : graph.node()) {
46 if (node.op() == "If") {
47 *op_name = node.name();
48 const NameAttrList* result;
49 TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
50 *then_fn = *result;
51 TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result));
52 *else_fn = *result;
53 return Status::OK();
54 }
55 }
56 return errors::NotFound("No If node found in graph");
57 }
58
59 // Graph:
60 // x = array_ops.placeholder(dtypes.int32)
61 // y = array_ops.placeholder(dtypes.int32)
62 // z = control_flow_ops.cond(
63 // math_ops.less(y, x), lambda: math_ops.multiply(y, 17),
64 // lambda: math_ops.add(x, 23))
TEST(FunctionalizeControlFlow,Conditional)65 TEST(FunctionalizeControlFlow, Conditional) {
66 Graph graph(OpRegistry::Global());
67 {
68 Scope scope = Scope::NewRootScope().ExitOnError();
69
70 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
71 auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
72 auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
73 auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less);
74
75 auto identity_t =
76 ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true);
77 auto seventeen = ops::Const<int32>(
78 scope.WithOpName("cond").WithControlDependencies(identity_t), 17);
79 auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less);
80 auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true,
81 seventeen);
82
83 auto identity_f =
84 ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false);
85 auto twenty_three = ops::Const<int32>(
86 scope.WithOpName("cond").WithControlDependencies(identity_f), 23);
87 auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
88 auto add = ops::Add(scope.WithOpName("cond/false/add"),
89 switch_3.output_false, twenty_three);
90
91 auto merge = ops::Merge(scope.WithOpName("cond/Merge"),
92 std::initializer_list<Input>{add, mul});
93
94 TF_EXPECT_OK(scope.ToGraph(&graph));
95 }
96
97 FunctionLibraryDefinition library(OpRegistry::Global(), {});
98 GraphDef optimized_graph_def;
99 graph.ToGraphDef(&optimized_graph_def);
100 TF_ASSERT_OK(
101 FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
102 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
103 GraphDef converted_graph_def;
104 graph.ToGraphDef(&converted_graph_def);
105
106 for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
107 string op_name;
108 NameAttrList then_fn;
109 NameAttrList else_fn;
110 TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
111 InstantiationResultForTest else_result;
112 TF_EXPECT_OK(
113 InstantiateFunctionForTest(else_fn.name(), library, &else_result));
114
115 // Outer graph
116 {
117 Scope scope = Scope::NewRootScope().ExitOnError();
118 auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
119 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
120 auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
121 auto if_op = ops::If(scope.WithOpName(op_name), less,
122 std::initializer_list<Input>{less, y, x}, {DT_INT32},
123 then_fn, else_fn);
124 auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
125 GraphDef expected;
126 TF_EXPECT_OK(scope.ToGraphDef(&expected));
127 TF_EXPECT_GRAPH_EQ(expected, graph_def);
128 }
129
130 // then body.
131 {
132 Scope scope = Scope::NewRootScope().ExitOnError();
133 auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
134 auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
135 auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
136 auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
137 auto cond = ops::Const(
138 scope.WithOpName("cond").WithControlDependencies(identity), 17);
139 auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
140 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0);
141
142 GraphDef expected;
143 TF_EXPECT_OK(scope.ToGraphDef(&expected));
144
145 InstantiationResultForTest result;
146 TF_EXPECT_OK(
147 InstantiateFunctionForTest(then_fn.name(), library, &result));
148
149 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
150 EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}),
151 result.arg_types);
152 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
153 }
154
155 // else body.
156 {
157 Scope scope = Scope::NewRootScope().ExitOnError();
158 auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
159 auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
160 auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
161 auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
162 auto cond_1 = ops::Const(
163 scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
164 auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
165 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
166
167 GraphDef expected;
168 TF_EXPECT_OK(scope.ToGraphDef(&expected));
169
170 InstantiationResultForTest result;
171 TF_EXPECT_OK(
172 InstantiateFunctionForTest(else_fn.name(), library, &result));
173
174 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
175 EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}),
176 result.arg_types);
177 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
178 }
179 }
180 }
181
182 // Returns the names of the "cond" and "body" functions for the While node
183 // in a graph.
FindWhileCondAndBody(const GraphDef & graph,NameAttrList * cond,NameAttrList * body)184 Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
185 NameAttrList* body) {
186 for (const NodeDef& node : graph.node()) {
187 if (node.op() == "While") {
188 const NameAttrList* result;
189 TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
190 *cond = *result;
191 TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result));
192 *body = *result;
193 return Status::OK();
194 }
195 }
196 return errors::NotFound("No While node found in graph");
197 }
198
199 // Graph:
200 // x = array_ops.placeholder(dtypes.int32)
201 // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVar)202 TEST(FunctionalizeControlFlow, OneLoopVar) {
203 Graph graph(OpRegistry::Global());
204 {
205 Scope scope = Scope::NewRootScope().ExitOnError();
206
207 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
208
209 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
210 auto enter =
211 ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
212 // Add an unused Enter node. These should be ignored.
213 auto enter2 =
214 ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
215 auto merge = ops::Merge(scope.WithOpName("while/Merge"),
216 std::initializer_list<Input>{enter, dummy});
217 auto ten = ops::Const<int32>(
218 scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
219 10);
220 auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
221 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
222 auto switch_ =
223 ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
224 auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
225 switch_.output_false);
226 auto identity =
227 ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
228 auto one = ops::Const<int32>(
229 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
230 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
231 auto next_iteration =
232 ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
233
234 auto sink = ops::Identity(scope.WithOpName("sink"), exit);
235
236 // Remove the dummy node and add the loop backedge.
237 scope.graph()->RemoveNode(dummy.node());
238 scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
239
240 TF_EXPECT_OK(scope.ToGraph(&graph));
241 }
242
243 // Regression test: control edges from an Enter node to the graph sink should
244 // be ignored.
245 for (Node* n : graph.nodes()) {
246 if (n->name() == "while/Enter") {
247 graph.AddControlEdge(n, graph.sink_node());
248 }
249 }
250
251 FunctionLibraryDefinition library(OpRegistry::Global(), {});
252 GraphDef optimized_graph_def;
253 graph.ToGraphDef(&optimized_graph_def);
254 TF_ASSERT_OK(
255 FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
256 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
257 GraphDef converted_graph_def;
258 graph.ToGraphDef(&converted_graph_def);
259
260 for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
261 NameAttrList cond_fn, body_fn;
262 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
263
264 // Outer graph
265 {
266 Scope scope = Scope::NewRootScope().ExitOnError();
267 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
268 auto while_op =
269 ops::While(scope.WithOpName("while/LoopCond"),
270 std::initializer_list<Input>{source}, cond_fn, body_fn);
271 auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
272 GraphDef expected;
273 TF_EXPECT_OK(scope.ToGraphDef(&expected));
274 TF_EXPECT_GRAPH_EQ(expected, graph_def);
275 }
276
277 // Condition graph
278 {
279 Scope scope = Scope::NewRootScope().ExitOnError();
280 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
281 auto ten = ops::Const<int32>(
282 scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
283 auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
284 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
285
286 GraphDef expected;
287 TF_EXPECT_OK(scope.ToGraphDef(&expected));
288
289 InstantiationResultForTest result;
290 TF_EXPECT_OK(
291 InstantiateFunctionForTest(cond_fn.name(), library, &result));
292
293 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
294 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
295 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
296 }
297
298 // Body graph.
299 {
300 Scope scope = Scope::NewRootScope().ExitOnError();
301 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
302 auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
303 auto one = ops::Const<int32>(
304 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
305 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
306 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
307
308 GraphDef expected;
309 TF_EXPECT_OK(scope.ToGraphDef(&expected));
310
311 InstantiationResultForTest result;
312 TF_EXPECT_OK(
313 InstantiateFunctionForTest(body_fn.name(), library, &result));
314
315 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
316 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
317 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
318 }
319 }
320 }
321
GetNoinlineFunctionDef()322 FunctionDef GetNoinlineFunctionDef() {
323 FunctionDef fdef = FunctionDefHelper::Create(
324 "increment_fn", {"x:int32"}, {"add:int32"}, {},
325 {
326 {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}},
327 {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}},
328 },
329 {{"add", "add_0:z:0"}});
330 (*fdef.mutable_attr())["_noinline"].set_b(true);
331 return fdef;
332 }
333
334 // @function.Defun(noinline=True)
335 // def increment_fn(x):
336 // return [x + 1]
337 // Define the above function, and add it to the given graph. It's used as the
338 // while loop body in NoinlineLoopBody test.
AddNoinlineFunctionToGraph(const string & node_name,Graph * graph)339 Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) {
340 FunctionDefLibrary fdef_lib;
341 *(fdef_lib.add_function()) = GetNoinlineFunctionDef();
342 TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib));
343 NodeDef increment_fn;
344 increment_fn.set_name(node_name);
345 increment_fn.set_op("increment_fn");
346 *increment_fn.add_input() = "while/Identity";
347 *increment_fn.add_input() = "^while/Identity";
348 Status status;
349 graph->AddNode(increment_fn, &status);
350 return status;
351 }
352
353 // Graph:
354 // x = array_ops.placeholder(dtypes.int32)
355 // y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x])
TEST(FunctionalizeControlFlow,NoinlineLoopBody)356 TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
357 const string& noinline_node_name = "while/increment_fn";
358 Graph graph(OpRegistry::Global());
359 {
360 Scope scope = Scope::NewRootScope().ExitOnError();
361 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
362 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
363 auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source,
364 "while/while_context");
365 auto merge = ops::Merge(scope.WithOpName("while/Merge"),
366 std::initializer_list<Input>{enter, dummy});
367 auto ten = ops::Const<int32>(
368 scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
369 10);
370 auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
371 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
372 auto switch_ =
373 ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
374 auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
375 switch_.output_false);
376 auto identity =
377 ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
378
379 TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
380
381 NodeDef next_iter;
382 next_iter.set_name("while/NextIteration");
383 next_iter.set_op("NextIteration");
384 *next_iter.add_input() = noinline_node_name;
385 (*next_iter.mutable_attr())["T"].set_type(DT_INT32);
386
387 Status status;
388 Node* n = scope.graph()->AddNode(next_iter, &status);
389 TF_ASSERT_OK(status);
390
391 // Remove the dummy node and add the loop backedge.
392 scope.graph()->RemoveNode(dummy.node());
393 scope.graph()->AddEdge(n, 0, merge.output.node(), 1);
394 TF_ASSERT_OK(scope.ToGraph(&graph));
395 }
396
397 FunctionLibraryDefinition lookup_lib(graph.flib_def());
398 FunctionLibraryDefinition library(OpRegistry::Global(), {});
399 // Function increment_fn will be copied from lookup_lib to library.
400 GraphDef optimized_graph_def;
401 graph.ToGraphDef(&optimized_graph_def);
402
403 *(optimized_graph_def.mutable_library()->add_function()) =
404 GetNoinlineFunctionDef();
405
406 TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(
407 &lookup_lib, &optimized_graph_def, &library));
408 TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library));
409 GraphDef converted_graph_def;
410 graph.ToGraphDef(&converted_graph_def);
411
412 for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
413 NameAttrList cond_fn, body_fn;
414 TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
415
416 // Outer graph
417 {
418 Scope scope = Scope::NewRootScope().ExitOnError();
419 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
420 auto while_op =
421 ops::While(scope.WithOpName("while/LoopCond"),
422 std::initializer_list<Input>{source}, cond_fn, body_fn);
423 GraphDef expected;
424 TF_ASSERT_OK(scope.ToGraphDef(&expected));
425 TF_EXPECT_GRAPH_EQ(expected, graph_def);
426 }
427
428 // Body graph.
429 {
430 Scope scope = Scope::NewRootScope().ExitOnError();
431 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
432 TF_ASSERT_OK(
433 AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
434 auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
435 NodeDef retval;
436 retval.set_name("_retval0_RetVal");
437 retval.set_op(FunctionLibraryDefinition::kRetOp);
438 *retval.add_input() = noinline_node_name;
439 (*retval.mutable_attr())["T"].set_type(DT_INT32);
440 (*retval.mutable_attr())["index"].set_i(0);
441 Status status;
442 scope.graph()->AddNode(retval, &status);
443 TF_ASSERT_OK(status);
444
445 GraphDef expected;
446 TF_ASSERT_OK(scope.ToGraphDef(&expected));
447
448 InstantiationResultForTest result;
449 // Verify that increment_fn has been copied to library.
450 TF_EXPECT_OK(
451 InstantiateFunctionForTest(body_fn.name(), library, &result));
452
453 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
454 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
455 // Ignore the function library when comparing the graphs.
456 expected.clear_library();
457 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
458 }
459 }
460 }
461
// Checks that FunctionalizeControlFlowForGraphDef returns NOT_FOUND when the
// GraphDef contains a node whose op is increment_fn but the GraphDef's own
// function library has been cleared.
TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) {
  // Held by value: the name is built from a literal, so a reference would only
  // bind to a lifetime-extended temporary.
  const string noinline_node_name = "while/increment_fn";
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    auto identity = ops::Identity(scope.WithOpName("while/Identity"), source);
    TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
    TF_ASSERT_OK(scope.ToGraph(&graph));
  }

  FunctionLibraryDefinition lookup_lib(graph.flib_def());
  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  // Drop the function definitions from the GraphDef itself; `lookup_lib` was
  // copied from the graph before this point.
  graph_def.clear_library();

  Status status =
      FunctionalizeControlFlowForGraphDef(&lookup_lib, &graph_def, &library);
  // Include the status text so an unexpected error code can be diagnosed.
  EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code()) << status.ToString();
}
483
484 // Tests functionalizing OneLoopVar where the loop value is not used post the
485 // loop.
486 // Graph:
487 // x = array_ops.placeholder(dtypes.int32)
488 // control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVarWithoutExit)489 TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
490 Graph graph(OpRegistry::Global());
491 {
492 Scope scope = Scope::NewRootScope().ExitOnError();
493
494 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
495
496 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
497 auto enter =
498 ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
499 auto merge = ops::Merge(scope.WithOpName("while/Merge"),
500 std::initializer_list<Input>{enter, dummy});
501 auto ten = ops::Const<int32>(
502 scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
503 10);
504 auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
505 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
506 auto switch_ =
507 ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
508 auto identity =
509 ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
510 auto one = ops::Const<int32>(
511 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
512 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
513 auto next_iteration =
514 ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
515
516 // Remove the dummy node and add the loop backedge.
517 scope.graph()->RemoveNode(dummy.node());
518 scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
519
520 TF_EXPECT_OK(scope.ToGraph(&graph));
521 }
522
523 FunctionLibraryDefinition library(OpRegistry::Global(), {});
524 GraphDef optimized_graph_def;
525 graph.ToGraphDef(&optimized_graph_def);
526 TF_ASSERT_OK(
527 FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
528 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
529 GraphDef converted_graph_def;
530 graph.ToGraphDef(&converted_graph_def);
531
532 for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
533 NameAttrList cond_fn, body_fn;
534 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
535
536 // Outer graph
537 {
538 Scope scope = Scope::NewRootScope().ExitOnError();
539 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
540 auto while_op =
541 ops::While(scope.WithOpName("while/LoopCond"),
542 std::initializer_list<Input>{source}, cond_fn, body_fn);
543 GraphDef expected;
544 TF_EXPECT_OK(scope.ToGraphDef(&expected));
545 TF_EXPECT_GRAPH_EQ(expected, graph_def);
546 }
547
548 // Condition graph
549 {
550 Scope scope = Scope::NewRootScope().ExitOnError();
551 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
552 auto ten = ops::Const<int32>(
553 scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
554 auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
555 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
556
557 GraphDef expected;
558 TF_EXPECT_OK(scope.ToGraphDef(&expected));
559
560 InstantiationResultForTest result;
561 TF_EXPECT_OK(
562 InstantiateFunctionForTest(cond_fn.name(), library, &result));
563
564 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
565 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
566 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
567 }
568
569 // Body graph.
570 {
571 Scope scope = Scope::NewRootScope().ExitOnError();
572 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
573 auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
574 auto one = ops::Const<int32>(
575 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
576 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
577 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
578
579 GraphDef expected;
580 TF_EXPECT_OK(scope.ToGraphDef(&expected));
581
582 InstantiationResultForTest result;
583 TF_EXPECT_OK(
584 InstantiateFunctionForTest(body_fn.name(), library, &result));
585
586 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
587 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
588 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
589 }
590 }
591 }
592
593 // Graph:
594 // x = array_ops.placeholder(dtypes.int32)
595 // y = array_ops.placeholder(dtypes.int32)
596 // cond = lambda (i, j): i + 3 < 10
597 // body = lambda (i, j): (i < 10, j * 2)
598 // z = control_flow_ops.while_loop(cond, body, [x, y])
TEST(FunctionalizeControlFlow,TwoLoopVars)599 TEST(FunctionalizeControlFlow, TwoLoopVars) {
600 Graph graph(OpRegistry::Global());
601 {
602 Scope scope = Scope::NewRootScope().ExitOnError();
603
604 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
605
606 auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
607 auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
608 auto enter_x =
609 ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop");
610 auto enter_y =
611 ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop");
612 auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"),
613 std::initializer_list<Input>{enter_x, dummy});
614 auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"),
615 std::initializer_list<Input>{enter_y, dummy});
616
617 // Loop condition
618 auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
619 .WithControlDependencies(merge_x.output),
620 3);
621 auto cond_add =
622 ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
623 auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
624 .WithControlDependencies(merge_x.output),
625 10);
626 auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
627 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
628
629 auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"),
630 merge_x.output, loop_cond);
631 auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"),
632 merge_y.output, loop_cond);
633
634 auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"),
635 switch_x.output_false);
636 auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"),
637 switch_y.output_false);
638
639 auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"),
640 switch_x.output_true);
641 auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
642 switch_y.output_true);
643
644 auto one = ops::Const<int32>(
645 scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
646 1);
647 auto two = ops::Const<int32>(
648 scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
649 2);
650
651 auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
652 auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
653 auto next_iteration_x =
654 ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add);
655 auto next_iteration_y =
656 ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul);
657
658 auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x);
659 auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y);
660
661 // Remove the dummy node and add the loop backedges.
662 scope.graph()->RemoveNode(dummy.node());
663 scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(),
664 1);
665 scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(),
666 1);
667
668 TF_EXPECT_OK(scope.ToGraph(&graph));
669 }
670
671 FunctionLibraryDefinition library(OpRegistry::Global(), {});
672 GraphDef optimized_graph_def;
673 graph.ToGraphDef(&optimized_graph_def);
674 TF_ASSERT_OK(
675 FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
676 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
677 GraphDef converted_graph_def;
678 graph.ToGraphDef(&converted_graph_def);
679
680 for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
681 NameAttrList cond_fn, body_fn;
682 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
683
684 // Outer graph.
685 {
686 Scope scope = Scope::NewRootScope().ExitOnError();
687 auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
688 auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
689 auto while_op =
690 ops::While(scope.WithOpName("while/LoopCond"),
691 std::initializer_list<Input>{x, y}, cond_fn, body_fn);
692 auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
693 auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
694 GraphDef expected;
695 TF_EXPECT_OK(scope.ToGraphDef(&expected));
696 TF_EXPECT_GRAPH_EQ(expected, graph_def);
697 }
698
699 // Condition graph.
700 {
701 Scope scope = Scope::NewRootScope().ExitOnError();
702 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
703 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
704 auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
705 .WithControlDependencies(arg0.output),
706 3);
707 auto cond_add =
708 ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
709 auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
710 .WithControlDependencies(arg0.output),
711 10);
712 auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
713 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
714
715 GraphDef expected;
716 TF_EXPECT_OK(scope.ToGraphDef(&expected));
717
718 InstantiationResultForTest result;
719 TF_EXPECT_OK(
720 InstantiateFunctionForTest(cond_fn.name(), library, &result));
721
722 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
723 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
724 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
725 }
726
727 // Body graph.
728 {
729 Scope scope = Scope::NewRootScope().ExitOnError();
730 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
731 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
732
733 auto identity_x =
734 ops::Identity(scope.WithOpName("while/Identity/x"), arg0);
735 auto identity_y =
736 ops::Identity(scope.WithOpName("while/Identity/y"), arg1);
737
738 auto one = ops::Const<int32>(
739 scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
740 1);
741 auto two = ops::Const<int32>(
742 scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
743 2);
744
745 auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
746 auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
747 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
748 auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1);
749
750 GraphDef expected;
751 TF_EXPECT_OK(scope.ToGraphDef(&expected));
752
753 InstantiationResultForTest result;
754 TF_EXPECT_OK(
755 InstantiateFunctionForTest(body_fn.name(), library, &result));
756
757 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
758 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types);
759 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
760 }
761 }
762 }
763
764 // Example with nesting, loop-invariant arguments, and resource variables.
765 //
766 // accum = resource_variable_ops.ResourceVariable(1)
767 // x = array_ops.placeholder(2, dtype=dtypes.int32)
768 // y = 3 + x
769 //
770 // def inner_body(j, k):
771 // add = state_ops.assign_add(accum, k * j + x)
772 // with ops.control_dependencies([add]):
773 // return [j + 1, k]
774 //
775 // def body(i):
776 // m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
777 // [1, y], name="inner")
778 // with ops.control_dependencies(m):
779 // return [i + 1]
780 //
781 // z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
TEST(FunctionalizeControlFlow, Complex) {
  // Step 1: hand-build the un-functionalized graph. It contains an "outer"
  // frame looping over i and, nested in its body, an "inner" frame looping
  // over (j, k) that accumulates into a resource variable.
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();

    // Temporary second input for every Merge node. It is removed at the end of
    // this scope and replaced by the NextIteration back edges.
    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);

    auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
    auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
    auto y = ops::Add(scope.WithOpName("y"), x, three);

    auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
                                TensorShape({}));

    // Outer loop
    auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
    auto enter_i =
        ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
    auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
                              std::initializer_list<Input>{enter_i, dummy});
    auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
                                     .WithControlDependencies(merge_i.output),
                                 10);
    auto less_i =
        ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
    auto outer_loop_cond =
        ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i);
    auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"),
                                merge_i.output, outer_loop_cond);
    auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"),
                                      switch_i.output_false);
    auto identity_i =
        ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true);

    // x, y and the variable handle enter the outer frame with the
    // is_constant attribute set, i.e. they are not loop-carried.
    auto enter_x_outer =
        ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer",
                             ops::internal::Enter::Attrs().IsConstant(true));
    auto enter_k_outer =
        ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer",
                             ops::internal::Enter::Attrs().IsConstant(true));
    auto enter_var_outer =
        ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer",
                             ops::internal::Enter::Attrs().IsConstant(true));

    // Inner loop
    auto one_j = ops::Const<int32>(
        scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
    auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
                                        one_j, "inner");
    auto enter_k =
        ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k")
                                 .WithControlDependencies(identity_i),
                             enter_k_outer, "inner");
    auto enter_x = ops::internal::Enter(
        scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner",
        ops::internal::Enter::Attrs().IsConstant(true));
    auto enter_var = ops::internal::Enter(
        scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner",
        ops::internal::Enter::Attrs().IsConstant(true));

    auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"),
                              std::initializer_list<Input>{enter_j, dummy});
    auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
                              std::initializer_list<Input>{enter_k, dummy});

    auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
                                      .WithControlDependencies(merge_j.output),
                                  5);
    auto less_j =
        ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
    // NOTE(review): this requests the same op name as the outer LoopCond above.
    // The expected outer body further down names the inner While node
    // "outer/LoopCond_1", so presumably Scope uniquifies the duplicate; do not
    // rename or reorder without updating that expectation.
    auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j);

    auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
                                merge_j.output, loop_cond);
    auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"),
                                merge_k.output, loop_cond);
    auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"),
                                      switch_j.output_false);
    auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"),
                                      switch_k.output_false);
    auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"),
                                    switch_j.output_true);
    auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"),
                                    switch_k.output_true);

    // Variable update
    auto mul_jk =
        ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
    auto add_jkx =
        ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x);
    auto assign = ops::AssignAddVariableOp(
        scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);

    // The increment constant carries a control dependency on the variable
    // update, so j + 1 is ordered after the assign_add.
    auto one = ops::Const<int32>(
        scope.WithOpName("outer/inner/One")
            .WithControlDependencies(
                absl::Span<const Operation>{assign.operation}),
        1);
    auto add_j =
        ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);

    auto next_iteration_j = ops::NextIteration(
        scope.WithOpName("outer/inner/NextIteration_j"), add_j);
    auto next_iteration_k = ops::NextIteration(
        scope.WithOpName("outer/inner/NextIteration_k"), identity_k);

    // Body and backedge for outer loop.
    auto one_outer = ops::Const<int32>(
        scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
    // i + 1 has control dependencies on both inner-loop exits.
    auto add_i =
        ops::Add(scope.WithOpName("outer/add")
                     .WithControlDependencies(absl::Span<const Operation>{
                         exit_j.output.op(), exit_k.output.op()}),
                 identity_i, one_outer);
    auto next_iteration_i =
        ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i);

    auto sink = ops::Identity(scope.WithOpName("sink"), exit_i);

    // Remove the dummy node and add the loop backedge.
    scope.graph()->RemoveNode(dummy.node());
    scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(),
                           1);
    scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(),
                           1);
    scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(),
                           1);

    TF_EXPECT_OK(scope.ToGraph(&graph));
  }

  // Step 2: functionalize the same graph through both entry points, the
  // GraphDef-based one and the Graph-based one. Both results are checked
  // against identical expectations below; `library` receives the generated
  // condition/body functions.
  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  GraphDef optimized_graph_def;
  graph.ToGraphDef(&optimized_graph_def);
  TF_ASSERT_OK(
      FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
  GraphDef converted_graph_def;
  graph.ToGraphDef(&converted_graph_def);

  // Step 3: compare each converted graph, and every function it references,
  // against hand-built expected graphs.
  for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
    NameAttrList outer_cond_fn, outer_body_fn;
    TF_EXPECT_OK(
        FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));

    // Outer graph.
    // The whole outer frame is expected to collapse into one While node whose
    // inputs are the loop counter followed by y, x and the variable handle.
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
      auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
      auto y = ops::Add(scope.WithOpName("y"), x, three);

      auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
                                  TensorShape({}));

      auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);

      auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
                                 std::initializer_list<Input>{zero, y, x, var},
                                 outer_cond_fn, outer_body_fn);
      auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));
      TF_EXPECT_GRAPH_EQ(expected, graph_def);
    }

    // Outer condition graph.
    // Takes all four loop arguments but only reads _arg0: _arg0 < 10.
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
      auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
      auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
      auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

      auto ten = ops::Const<int32>(
          scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
          10);
      auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
      auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));

      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
                result.arg_types);
      EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }

    // Outer body graph.
    // The inner function names are declared outside the block because the
    // inner condition/body checks that follow need them as well.
    NameAttrList inner_cond_fn, inner_body_fn;
    {
      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(outer_body_fn.name(), library, &result));

      // Find the inner condition and body names.
      TF_EXPECT_OK(
          FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));

      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
      auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
      auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
      auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

      auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
      auto one_j = ops::Const<int32>(
          scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
      // The inner frame is expected to appear as a nested While node.
      auto while_op =
          ops::While(scope.WithOpName("outer/LoopCond_1"),
                     std::initializer_list<Input>{one_j, arg1, arg2, arg3},
                     inner_cond_fn, inner_body_fn);

      auto one_outer = ops::Const<int32>(
          scope.WithOpName("outer/add/y").WithControlDependencies(identity_i),
          1);
      auto add_i =
          ops::Add(scope.WithOpName("outer/add")
                       .WithControlDependencies(absl::Span<const Operation>{
                           while_op[0].op(), while_op[1].op()}),
                   identity_i, one_outer);

      // Three return values for four arguments: the resource argument is not
      // returned.
      auto retval0 =
          ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0);
      auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1);
      auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
                result.arg_types);
      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}),
                result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }

    // Inner condition graph.
    // Takes all four loop arguments but only reads _arg0: _arg0 < 5.
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
      auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
      auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
      auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

      auto five = ops::Const<int32>(
          scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0),
          5);
      auto less_j =
          ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
      auto retval =
          ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));

      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
                result.arg_types);
      EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }

    // Inner body graph.
    // Adds _arg0 * _arg1 + _arg2 to the variable behind _arg3, then returns
    // (_arg0 + 1, _arg1, _arg2).
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
      auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
      auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
      auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

      auto identity_j =
          ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
      auto identity_k =
          ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);

      auto mul_jk =
          ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
      auto add_jkx =
          ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
      auto assign = ops::AssignAddVariableOp(
          scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);

      auto one = ops::Const<int32>(
          scope.WithOpName("outer/inner/One")
              .WithControlDependencies(
                  absl::Span<const Operation>{assign.operation}),
          1);
      auto add_j =
          ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);

      auto retval0 =
          ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0);
      auto retval1 =
          ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1);
      auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(inner_body_fn.name(), library, &result));

      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
                result.arg_types);
      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}),
                result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }
  }
}
1102
1103 } // namespace
1104 } // namespace tensorflow
1105