1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
17 
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
20 #include "tensorflow/cc/ops/function_ops.h"
21 #include "tensorflow/cc/ops/functional_ops.h"
22 #include "tensorflow/cc/ops/resource_variable_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
25 #include "tensorflow/compiler/tf2xla/test_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/graph/graph_constructor.h"
32 #include "tensorflow/core/graph/graph_def_builder.h"
33 #include "tensorflow/core/graph/validate.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/util/equal_graph_def.h"
37 
38 namespace tensorflow {
39 namespace {
40 
41 // Returns the names of the "then" and "else" functions for the If node in a
42 // graph.
FindIfThenAndElse(const GraphDef & graph,string * op_name,NameAttrList * then_fn,NameAttrList * else_fn)43 Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
44                          NameAttrList* then_fn, NameAttrList* else_fn) {
45   for (const NodeDef& node : graph.node()) {
46     if (node.op() == "If") {
47       *op_name = node.name();
48       const NameAttrList* result;
49       TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
50       *then_fn = *result;
51       TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result));
52       *else_fn = *result;
53       return Status::OK();
54     }
55   }
56   return errors::NotFound("No If node found in graph");
57 }
58 
59 // Graph:
60 // x = array_ops.placeholder(dtypes.int32)
61 // y = array_ops.placeholder(dtypes.int32)
62 // z = control_flow_ops.cond(
63 //     math_ops.less(y, x), lambda: math_ops.multiply(y, 17),
64 //     lambda: math_ops.add(x, 23))
TEST(FunctionalizeControlFlow,Conditional)65 TEST(FunctionalizeControlFlow, Conditional) {
66   Graph graph(OpRegistry::Global());
67   {
68     Scope scope = Scope::NewRootScope().ExitOnError();
69 
70     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
71     auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
72     auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
73     auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less);
74 
75     auto identity_t =
76         ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true);
77     auto seventeen = ops::Const<int32>(
78         scope.WithOpName("cond").WithControlDependencies(identity_t), 17);
79     auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less);
80     auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true,
81                              seventeen);
82 
83     auto identity_f =
84         ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false);
85     auto twenty_three = ops::Const<int32>(
86         scope.WithOpName("cond").WithControlDependencies(identity_f), 23);
87     auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
88     auto add = ops::Add(scope.WithOpName("cond/false/add"),
89                         switch_3.output_false, twenty_three);
90 
91     auto merge = ops::Merge(scope.WithOpName("cond/Merge"),
92                             std::initializer_list<Input>{add, mul});
93 
94     TF_EXPECT_OK(scope.ToGraph(&graph));
95   }
96 
97   FunctionLibraryDefinition library(OpRegistry::Global(), {});
98   GraphDef optimized_graph_def;
99   graph.ToGraphDef(&optimized_graph_def);
100   TF_ASSERT_OK(
101       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
102   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
103   GraphDef converted_graph_def;
104   graph.ToGraphDef(&converted_graph_def);
105 
106   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
107     string op_name;
108     NameAttrList then_fn;
109     NameAttrList else_fn;
110     TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
111     InstantiationResultForTest else_result;
112     TF_EXPECT_OK(
113         InstantiateFunctionForTest(else_fn.name(), library, &else_result));
114 
115     // Outer graph
116     {
117       Scope scope = Scope::NewRootScope().ExitOnError();
118       auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
119       auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
120       auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
121       auto if_op = ops::If(scope.WithOpName(op_name), less,
122                            std::initializer_list<Input>{less, y, x}, {DT_INT32},
123                            then_fn, else_fn);
124       auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
125       GraphDef expected;
126       TF_EXPECT_OK(scope.ToGraphDef(&expected));
127       TF_EXPECT_GRAPH_EQ(expected, graph_def);
128     }
129 
130     // then body.
131     {
132       Scope scope = Scope::NewRootScope().ExitOnError();
133       auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
134       auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
135       auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
136       auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
137       auto cond = ops::Const(
138           scope.WithOpName("cond").WithControlDependencies(identity), 17);
139       auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
140       auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0);
141 
142       GraphDef expected;
143       TF_EXPECT_OK(scope.ToGraphDef(&expected));
144 
145       InstantiationResultForTest result;
146       TF_EXPECT_OK(
147           InstantiateFunctionForTest(then_fn.name(), library, &result));
148 
149       EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
150       EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}),
151                 result.arg_types);
152       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
153     }
154 
155     // else body.
156     {
157       Scope scope = Scope::NewRootScope().ExitOnError();
158       auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
159       auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
160       auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
161       auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
162       auto cond_1 = ops::Const(
163           scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
164       auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
165       auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
166 
167       GraphDef expected;
168       TF_EXPECT_OK(scope.ToGraphDef(&expected));
169 
170       InstantiationResultForTest result;
171       TF_EXPECT_OK(
172           InstantiateFunctionForTest(else_fn.name(), library, &result));
173 
174       EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
175       EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}),
176                 result.arg_types);
177       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
178     }
179   }
180 }
181 
182 // Returns the names of the "cond" and "body" functions for the While node
183 // in a graph.
FindWhileCondAndBody(const GraphDef & graph,NameAttrList * cond,NameAttrList * body)184 Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
185                             NameAttrList* body) {
186   for (const NodeDef& node : graph.node()) {
187     if (node.op() == "While") {
188       const NameAttrList* result;
189       TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
190       *cond = *result;
191       TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result));
192       *body = *result;
193       return Status::OK();
194     }
195   }
196   return errors::NotFound("No While node found in graph");
197 }
198 
199 // Graph:
200 // x = array_ops.placeholder(dtypes.int32)
201 // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVar)202 TEST(FunctionalizeControlFlow, OneLoopVar) {
203   Graph graph(OpRegistry::Global());
204   {
205     Scope scope = Scope::NewRootScope().ExitOnError();
206 
207     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
208 
209     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
210     auto enter =
211         ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
212     // Add an unused Enter node. These should be ignored.
213     auto enter2 =
214         ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
215     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
216                             std::initializer_list<Input>{enter, dummy});
217     auto ten = ops::Const<int32>(
218         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
219         10);
220     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
221     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
222     auto switch_ =
223         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
224     auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
225                                     switch_.output_false);
226     auto identity =
227         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
228     auto one = ops::Const<int32>(
229         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
230     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
231     auto next_iteration =
232         ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
233 
234     auto sink = ops::Identity(scope.WithOpName("sink"), exit);
235 
236     // Remove the dummy node and add the loop backedge.
237     scope.graph()->RemoveNode(dummy.node());
238     scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
239 
240     TF_EXPECT_OK(scope.ToGraph(&graph));
241   }
242 
243   // Regression test: control edges from an Enter node to the graph sink should
244   // be ignored.
245   for (Node* n : graph.nodes()) {
246     if (n->name() == "while/Enter") {
247       graph.AddControlEdge(n, graph.sink_node());
248     }
249   }
250 
251   FunctionLibraryDefinition library(OpRegistry::Global(), {});
252   GraphDef optimized_graph_def;
253   graph.ToGraphDef(&optimized_graph_def);
254   TF_ASSERT_OK(
255       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
256   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
257   GraphDef converted_graph_def;
258   graph.ToGraphDef(&converted_graph_def);
259 
260   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
261     NameAttrList cond_fn, body_fn;
262     TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
263 
264     // Outer graph
265     {
266       Scope scope = Scope::NewRootScope().ExitOnError();
267       auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
268       auto while_op =
269           ops::While(scope.WithOpName("while/LoopCond"),
270                      std::initializer_list<Input>{source}, cond_fn, body_fn);
271       auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
272       GraphDef expected;
273       TF_EXPECT_OK(scope.ToGraphDef(&expected));
274       TF_EXPECT_GRAPH_EQ(expected, graph_def);
275     }
276 
277     // Condition graph
278     {
279       Scope scope = Scope::NewRootScope().ExitOnError();
280       auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
281       auto ten = ops::Const<int32>(
282           scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
283       auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
284       auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
285 
286       GraphDef expected;
287       TF_EXPECT_OK(scope.ToGraphDef(&expected));
288 
289       InstantiationResultForTest result;
290       TF_EXPECT_OK(
291           InstantiateFunctionForTest(cond_fn.name(), library, &result));
292 
293       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
294       EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
295       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
296     }
297 
298     // Body graph.
299     {
300       Scope scope = Scope::NewRootScope().ExitOnError();
301       auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
302       auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
303       auto one = ops::Const<int32>(
304           scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
305       auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
306       auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
307 
308       GraphDef expected;
309       TF_EXPECT_OK(scope.ToGraphDef(&expected));
310 
311       InstantiationResultForTest result;
312       TF_EXPECT_OK(
313           InstantiateFunctionForTest(body_fn.name(), library, &result));
314 
315       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
316       EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
317       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
318     }
319   }
320 }
321 
GetNoinlineFunctionDef()322 FunctionDef GetNoinlineFunctionDef() {
323   FunctionDef fdef = FunctionDefHelper::Create(
324       "increment_fn", {"x:int32"}, {"add:int32"}, {},
325       {
326           {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}},
327           {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}},
328       },
329       {{"add", "add_0:z:0"}});
330   (*fdef.mutable_attr())["_noinline"].set_b(true);
331   return fdef;
332 }
333 
334 // @function.Defun(noinline=True)
335 // def increment_fn(x):
336 //   return [x + 1]
337 // Define the above function, and add it to the given graph. It's used as the
338 // while loop body in NoinlineLoopBody test.
AddNoinlineFunctionToGraph(const string & node_name,Graph * graph)339 Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) {
340   FunctionDefLibrary fdef_lib;
341   *(fdef_lib.add_function()) = GetNoinlineFunctionDef();
342   TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib));
343   NodeDef increment_fn;
344   increment_fn.set_name(node_name);
345   increment_fn.set_op("increment_fn");
346   *increment_fn.add_input() = "while/Identity";
347   *increment_fn.add_input() = "^while/Identity";
348   Status status;
349   graph->AddNode(increment_fn, &status);
350   return status;
351 }
352 
353 // Graph:
354 // x = array_ops.placeholder(dtypes.int32)
355 // y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x])
TEST(FunctionalizeControlFlow,NoinlineLoopBody)356 TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
357   const string& noinline_node_name = "while/increment_fn";
358   Graph graph(OpRegistry::Global());
359   {
360     Scope scope = Scope::NewRootScope().ExitOnError();
361     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
362     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
363     auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source,
364                                       "while/while_context");
365     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
366                             std::initializer_list<Input>{enter, dummy});
367     auto ten = ops::Const<int32>(
368         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
369         10);
370     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
371     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
372     auto switch_ =
373         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
374     auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
375                                     switch_.output_false);
376     auto identity =
377         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
378 
379     TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
380 
381     NodeDef next_iter;
382     next_iter.set_name("while/NextIteration");
383     next_iter.set_op("NextIteration");
384     *next_iter.add_input() = noinline_node_name;
385     (*next_iter.mutable_attr())["T"].set_type(DT_INT32);
386 
387     Status status;
388     Node* n = scope.graph()->AddNode(next_iter, &status);
389     TF_ASSERT_OK(status);
390 
391     // Remove the dummy node and add the loop backedge.
392     scope.graph()->RemoveNode(dummy.node());
393     scope.graph()->AddEdge(n, 0, merge.output.node(), 1);
394     TF_ASSERT_OK(scope.ToGraph(&graph));
395   }
396 
397   FunctionLibraryDefinition lookup_lib(graph.flib_def());
398   FunctionLibraryDefinition library(OpRegistry::Global(), {});
399   // Function increment_fn will be copied from lookup_lib to library.
400   GraphDef optimized_graph_def;
401   graph.ToGraphDef(&optimized_graph_def);
402 
403   *(optimized_graph_def.mutable_library()->add_function()) =
404       GetNoinlineFunctionDef();
405 
406   TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(
407       &lookup_lib, &optimized_graph_def, &library));
408   TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library));
409   GraphDef converted_graph_def;
410   graph.ToGraphDef(&converted_graph_def);
411 
412   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
413     NameAttrList cond_fn, body_fn;
414     TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
415 
416     // Outer graph
417     {
418       Scope scope = Scope::NewRootScope().ExitOnError();
419       auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
420       auto while_op =
421           ops::While(scope.WithOpName("while/LoopCond"),
422                      std::initializer_list<Input>{source}, cond_fn, body_fn);
423       GraphDef expected;
424       TF_ASSERT_OK(scope.ToGraphDef(&expected));
425       TF_EXPECT_GRAPH_EQ(expected, graph_def);
426     }
427 
428     // Body graph.
429     {
430       Scope scope = Scope::NewRootScope().ExitOnError();
431       auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
432       TF_ASSERT_OK(
433           AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
434       auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
435       NodeDef retval;
436       retval.set_name("_retval0_RetVal");
437       retval.set_op(FunctionLibraryDefinition::kRetOp);
438       *retval.add_input() = noinline_node_name;
439       (*retval.mutable_attr())["T"].set_type(DT_INT32);
440       (*retval.mutable_attr())["index"].set_i(0);
441       Status status;
442       scope.graph()->AddNode(retval, &status);
443       TF_ASSERT_OK(status);
444 
445       GraphDef expected;
446       TF_ASSERT_OK(scope.ToGraphDef(&expected));
447 
448       InstantiationResultForTest result;
449       // Verify that increment_fn has been copied to library.
450       TF_EXPECT_OK(
451           InstantiateFunctionForTest(body_fn.name(), library, &result));
452 
453       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
454       EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
455       // Ignore the function library when comparing the graphs.
456       expected.clear_library();
457       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
458     }
459   }
460 }
461 
// Functionalizing a GraphDef that still contains a call to increment_fn, but
// whose function library has been cleared, is expected to fail with
// NOT_FOUND.
TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) {
  // Held by value: a const reference here would only bind to a temporary
  // std::string materialized from the literal.
  const string noinline_node_name = "while/increment_fn";
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    // AddNoinlineFunctionToGraph wires its call node to "while/Identity".
    auto identity = ops::Identity(scope.WithOpName("while/Identity"), source);
    TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
    TF_ASSERT_OK(scope.ToGraph(&graph));
  }

  FunctionLibraryDefinition lookup_lib(graph.flib_def());
  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  // Drop the serialized FunctionDefs so the GraphDef being functionalized no
  // longer carries increment_fn.
  graph_def.clear_library();

  Status status =
      FunctionalizeControlFlowForGraphDef(&lookup_lib, &graph_def, &library);
  EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code()) << status.ToString();
}
483 
484 // Tests functionalizing OneLoopVar where the loop value is not used post the
485 // loop.
486 // Graph:
487 // x = array_ops.placeholder(dtypes.int32)
488 // control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVarWithoutExit)489 TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
490   Graph graph(OpRegistry::Global());
491   {
492     Scope scope = Scope::NewRootScope().ExitOnError();
493 
494     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
495 
496     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
497     auto enter =
498         ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
499     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
500                             std::initializer_list<Input>{enter, dummy});
501     auto ten = ops::Const<int32>(
502         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
503         10);
504     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
505     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
506     auto switch_ =
507         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
508     auto identity =
509         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
510     auto one = ops::Const<int32>(
511         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
512     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
513     auto next_iteration =
514         ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
515 
516     // Remove the dummy node and add the loop backedge.
517     scope.graph()->RemoveNode(dummy.node());
518     scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
519 
520     TF_EXPECT_OK(scope.ToGraph(&graph));
521   }
522 
523   FunctionLibraryDefinition library(OpRegistry::Global(), {});
524   GraphDef optimized_graph_def;
525   graph.ToGraphDef(&optimized_graph_def);
526   TF_ASSERT_OK(
527       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
528   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
529   GraphDef converted_graph_def;
530   graph.ToGraphDef(&converted_graph_def);
531 
532   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
533     NameAttrList cond_fn, body_fn;
534     TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
535 
536     // Outer graph
537     {
538       Scope scope = Scope::NewRootScope().ExitOnError();
539       auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
540       auto while_op =
541           ops::While(scope.WithOpName("while/LoopCond"),
542                      std::initializer_list<Input>{source}, cond_fn, body_fn);
543       GraphDef expected;
544       TF_EXPECT_OK(scope.ToGraphDef(&expected));
545       TF_EXPECT_GRAPH_EQ(expected, graph_def);
546     }
547 
548     // Condition graph
549     {
550       Scope scope = Scope::NewRootScope().ExitOnError();
551       auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
552       auto ten = ops::Const<int32>(
553           scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
554       auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
555       auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
556 
557       GraphDef expected;
558       TF_EXPECT_OK(scope.ToGraphDef(&expected));
559 
560       InstantiationResultForTest result;
561       TF_EXPECT_OK(
562           InstantiateFunctionForTest(cond_fn.name(), library, &result));
563 
564       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
565       EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
566       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
567     }
568 
569     // Body graph.
570     {
571       Scope scope = Scope::NewRootScope().ExitOnError();
572       auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
573       auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
574       auto one = ops::Const<int32>(
575           scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
576       auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
577       auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
578 
579       GraphDef expected;
580       TF_EXPECT_OK(scope.ToGraphDef(&expected));
581 
582       InstantiationResultForTest result;
583       TF_EXPECT_OK(
584           InstantiateFunctionForTest(body_fn.name(), library, &result));
585 
586       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
587       EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
588       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
589     }
590   }
591 }
592 
593 // Graph:
594 // x = array_ops.placeholder(dtypes.int32)
595 // y = array_ops.placeholder(dtypes.int32)
596 // cond = lambda (i, j): i + 3 < 10
// body = lambda (i, j): (i + 1, j + 2)
598 // z = control_flow_ops.while_loop(cond, body, [x, y])
TEST(FunctionalizeControlFlow,TwoLoopVars)599 TEST(FunctionalizeControlFlow, TwoLoopVars) {
600   Graph graph(OpRegistry::Global());
601   {
602     Scope scope = Scope::NewRootScope().ExitOnError();
603 
604     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
605 
606     auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
607     auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
608     auto enter_x =
609         ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop");
610     auto enter_y =
611         ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop");
612     auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"),
613                               std::initializer_list<Input>{enter_x, dummy});
614     auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"),
615                               std::initializer_list<Input>{enter_y, dummy});
616 
617     // Loop condition
618     auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
619                                        .WithControlDependencies(merge_x.output),
620                                    3);
621     auto cond_add =
622         ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
623     auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
624                                      .WithControlDependencies(merge_x.output),
625                                  10);
626     auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
627     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
628 
629     auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"),
630                                 merge_x.output, loop_cond);
631     auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"),
632                                 merge_y.output, loop_cond);
633 
634     auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"),
635                                       switch_x.output_false);
636     auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"),
637                                       switch_y.output_false);
638 
639     auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"),
640                                     switch_x.output_true);
641     auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
642                                     switch_y.output_true);
643 
644     auto one = ops::Const<int32>(
645         scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
646         1);
647     auto two = ops::Const<int32>(
648         scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
649         2);
650 
651     auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
652     auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
653     auto next_iteration_x =
654         ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add);
655     auto next_iteration_y =
656         ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul);
657 
658     auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x);
659     auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y);
660 
661     // Remove the dummy node and add the loop backedges.
662     scope.graph()->RemoveNode(dummy.node());
663     scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(),
664                            1);
665     scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(),
666                            1);
667 
668     TF_EXPECT_OK(scope.ToGraph(&graph));
669   }
670 
671   FunctionLibraryDefinition library(OpRegistry::Global(), {});
672   GraphDef optimized_graph_def;
673   graph.ToGraphDef(&optimized_graph_def);
674   TF_ASSERT_OK(
675       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
676   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
677   GraphDef converted_graph_def;
678   graph.ToGraphDef(&converted_graph_def);
679 
680   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
681     NameAttrList cond_fn, body_fn;
682     TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
683 
684     // Outer graph.
685     {
686       Scope scope = Scope::NewRootScope().ExitOnError();
687       auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
688       auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
689       auto while_op =
690           ops::While(scope.WithOpName("while/LoopCond"),
691                      std::initializer_list<Input>{x, y}, cond_fn, body_fn);
692       auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
693       auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
694       GraphDef expected;
695       TF_EXPECT_OK(scope.ToGraphDef(&expected));
696       TF_EXPECT_GRAPH_EQ(expected, graph_def);
697     }
698 
699     // Condition graph.
700     {
701       Scope scope = Scope::NewRootScope().ExitOnError();
702       auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
703       auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
704       auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
705                                          .WithControlDependencies(arg0.output),
706                                      3);
707       auto cond_add =
708           ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
709       auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
710                                        .WithControlDependencies(arg0.output),
711                                    10);
712       auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
713       auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
714 
715       GraphDef expected;
716       TF_EXPECT_OK(scope.ToGraphDef(&expected));
717 
718       InstantiationResultForTest result;
719       TF_EXPECT_OK(
720           InstantiateFunctionForTest(cond_fn.name(), library, &result));
721 
722       EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
723       EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
724       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
725     }
726 
727     // Body graph.
728     {
729       Scope scope = Scope::NewRootScope().ExitOnError();
730       auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
731       auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
732 
733       auto identity_x =
734           ops::Identity(scope.WithOpName("while/Identity/x"), arg0);
735       auto identity_y =
736           ops::Identity(scope.WithOpName("while/Identity/y"), arg1);
737 
738       auto one = ops::Const<int32>(
739           scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
740           1);
741       auto two = ops::Const<int32>(
742           scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
743           2);
744 
745       auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
746       auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
747       auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
748       auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1);
749 
750       GraphDef expected;
751       TF_EXPECT_OK(scope.ToGraphDef(&expected));
752 
753       InstantiationResultForTest result;
754       TF_EXPECT_OK(
755           InstantiateFunctionForTest(body_fn.name(), library, &result));
756 
757       EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
758       EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types);
759       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
760     }
761   }
762 }
763 
// Example with nesting, loop-invariant arguments, and resource variables.
//
// accum = resource_variable_ops.ResourceVariable(1)
// x = array_ops.placeholder(dtype=dtypes.int32)
// y = 3 + x
//
// def inner_body(j, k):
//   add = state_ops.assign_add(accum, k * j + x)
//   with ops.control_dependencies([add]):
//     return [j + 1, k]
//
// def body(i):
//   m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
//                                   [1, y], name="inner")
//   with ops.control_dependencies(m):
//     return [i + 1]
//
// z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
TEST(FunctionalizeControlFlow, Complex) {
  // Hand-build the raw dataflow form (Enter/Merge/Switch/Exit/NextIteration)
  // of two nested while loops in frames "outer" and "inner", then check that
  // functionalization turns them into nested While ops with the expected
  // cond/body functions.
  //
  // NOTE: every expectation below compares graphs by node name, and Scope
  // derives names from creation order, so statement order here matters.
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();

    // Stand-in for the second Merge input of each loop. The NextIteration
    // backedges cannot exist yet when the Merges are created. This node is
    // removed and replaced by real backedges at the end of this scope.
    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);

    // Values computed outside both loops: x, y = x + 3, and a resource
    // variable that the inner loop updates.
    auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
    auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
    auto y = ops::Add(scope.WithOpName("y"), x, three);

    auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
                                TensorShape({}));

    // Outer loop
    // Counter i starts at 0; the loop continues while i < 10.
    auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
    auto enter_i =
        ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
    auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
                              std::initializer_list<Input>{enter_i, dummy});
    auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
                                     .WithControlDependencies(merge_i.output),
                                 10);
    auto less_i =
        ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
    auto outer_loop_cond =
        ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i);
    auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"),
                                merge_i.output, outer_loop_cond);
    auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"),
                                      switch_i.output_false);
    auto identity_i =
        ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true);

    // Bring x, y and the variable handle into the "outer" frame as constant
    // (loop-invariant) inputs.
    auto enter_x_outer =
        ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer",
                             ops::internal::Enter::Attrs().IsConstant(true));
    auto enter_k_outer =
        ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer",
                             ops::internal::Enter::Attrs().IsConstant(true));
    auto enter_var_outer =
        ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer",
                             ops::internal::Enter::Attrs().IsConstant(true));

    // Inner loop
    // j starts at 1. k starts from y forwarded out of the outer frame. k's
    // Enter is not constant, so k is a loop-carried value of the inner loop
    // (it has its own Merge/NextIteration). x and the variable handle enter
    // as constants.
    auto one_j = ops::Const<int32>(
        scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
    auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
                                        one_j, "inner");
    auto enter_k =
        ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k")
                                 .WithControlDependencies(identity_i),
                             enter_k_outer, "inner");
    auto enter_x = ops::internal::Enter(
        scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner",
        ops::internal::Enter::Attrs().IsConstant(true));
    auto enter_var = ops::internal::Enter(
        scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner",
        ops::internal::Enter::Attrs().IsConstant(true));

    auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"),
                              std::initializer_list<Input>{enter_j, dummy});
    auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
                              std::initializer_list<Input>{enter_k, dummy});

    // The inner loop continues while j < 5.
    auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
                                      .WithControlDependencies(merge_j.output),
                                  5);
    auto less_j =
        ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
    // This reuses the op name "outer/LoopCond", which the outer LoopCond
    // already took. The Scope is expected to uniquify it to
    // "outer/LoopCond_1", the name the expected inner While op carries in the
    // outer-body check below.
    auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j);

    auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
                                merge_j.output, loop_cond);
    auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"),
                                merge_k.output, loop_cond);
    auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"),
                                      switch_j.output_false);
    auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"),
                                      switch_k.output_false);
    auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"),
                                    switch_j.output_true);
    auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"),
                                    switch_k.output_true);

    // Variable update
    // var += j * k + x.
    auto mul_jk =
        ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
    auto add_jkx =
        ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x);
    auto assign = ops::AssignAddVariableOp(
        scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);

    // j + 1. The constant 1 carries a control dependency on the assign, which
    // mirrors `with ops.control_dependencies([add])` in the Python example.
    auto one = ops::Const<int32>(
        scope.WithOpName("outer/inner/One")
            .WithControlDependencies(
                absl::Span<const Operation>{assign.operation}),
        1);
    auto add_j =
        ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);

    // Inner-loop backedges: j + 1 and k (unchanged).
    auto next_iteration_j = ops::NextIteration(
        scope.WithOpName("outer/inner/NextIteration_j"), add_j);
    auto next_iteration_k = ops::NextIteration(
        scope.WithOpName("outer/inner/NextIteration_k"), identity_k);

    // Body and backedge for outer loop.
    // i + 1 is control-dependent on both inner-loop exits, so it waits for
    // the inner loop to finish.
    auto one_outer = ops::Const<int32>(
        scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
    auto add_i =
        ops::Add(scope.WithOpName("outer/add")
                     .WithControlDependencies(absl::Span<const Operation>{
                         exit_j.output.op(), exit_k.output.op()}),
                 identity_i, one_outer);
    auto next_iteration_i =
        ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i);

    auto sink = ops::Identity(scope.WithOpName("sink"), exit_i);

    // Remove the dummy node and add the three loop backedges. Each
    // NextIteration output 0 feeds input 1 of its loop's Merge, the slot the
    // dummy occupied.
    scope.graph()->RemoveNode(dummy.node());
    scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(),
                           1);
    scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(),
                           1);
    scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(),
                           1);

    TF_EXPECT_OK(scope.ToGraph(&graph));
  }

  // Functionalize through both entry points:
  //  - optimized_graph_def: a GraphDef snapshot of the raw graph, passed to
  //    FunctionalizeControlFlowForGraphDef;
  //  - converted_graph_def: `graph` itself, functionalized in place by
  //    FunctionalizeControlFlow, then exported.
  // Both share `library`, in which the generated cond/body functions are
  // later looked up by name.
  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  GraphDef optimized_graph_def;
  graph.ToGraphDef(&optimized_graph_def);
  TF_ASSERT_OK(
      FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
  GraphDef converted_graph_def;
  graph.ToGraphDef(&converted_graph_def);

  // Both results must match the same expectations.
  for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
    NameAttrList outer_cond_fn, outer_body_fn;
    TF_EXPECT_OK(
        FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));

    // Outer graph.
    // The whole outer loop becomes a single While op named
    // "outer/LoopCond", with inputs (i = 0, y, x, var). The sink reads
    // output 0 (the final i).
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
      auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
      auto y = ops::Add(scope.WithOpName("y"), x, three);

      auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
                                  TensorShape({}));

      auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);

      auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
                                 std::initializer_list<Input>{zero, y, x, var},
                                 outer_cond_fn, outer_body_fn);
      auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));
      TF_EXPECT_GRAPH_EQ(expected, graph_def);
    }

    // Outer condition graph.
    // Receives all four loop values but only reads arg0 (i): returns i < 10.
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
      auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
      auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
      auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

      auto ten = ops::Const<int32>(
          scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
          10);
      auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
      auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));

      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
                result.arg_types);
      EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }

    // Outer body graph.
    // The inner loop appears here as a nested While op. Its cond/body
    // function names are read back from the instantiated outer body so the
    // two inner checks below can look them up.
    NameAttrList inner_cond_fn, inner_body_fn;
    {
      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(outer_body_fn.name(), library, &result));

      // Find the inner condition and body names.
      TF_EXPECT_OK(
          FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));

      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
      auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
      auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
      auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

      auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
      auto one_j = ops::Const<int32>(
          scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
      // Inner While inputs: (j = 1, k = arg1, x = arg2, var = arg3).
      auto while_op =
          ops::While(scope.WithOpName("outer/LoopCond_1"),
                     std::initializer_list<Input>{one_j, arg1, arg2, arg3},
                     inner_cond_fn, inner_body_fn);

      // In the raw graph, i + 1 depended on exit_j and exit_k. Both now map
      // to outputs of the single inner While op.
      auto one_outer = ops::Const<int32>(
          scope.WithOpName("outer/add/y").WithControlDependencies(identity_i),
          1);
      auto add_i =
          ops::Add(scope.WithOpName("outer/add")
                       .WithControlDependencies(absl::Span<const Operation>{
                           while_op[0].op(), while_op[1].op()}),
                   identity_i, one_outer);

      // Returns (i + 1, arg1, arg2). The resource argument (arg3) has no
      // matching _Retval, so there are three results for four arguments.
      auto retval0 =
          ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0);
      auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1);
      auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
                result.arg_types);
      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}),
                result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }

    // Inner condition graph.
    // Reads only arg0 (j): returns j < 5.
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
      auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
      auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
      auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

      auto five = ops::Const<int32>(
          scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0),
          5);
      auto less_j =
          ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
      auto retval =
          ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));

      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
                result.arg_types);
      EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }

    // Inner body graph.
    // Applies var += j * k + x through the resource argument (arg3) and
    // returns (j + 1, k, x). The resource argument is not returned.
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
      auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
      auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
      auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);

      auto identity_j =
          ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
      auto identity_k =
          ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);

      auto mul_jk =
          ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
      auto add_jkx =
          ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
      auto assign = ops::AssignAddVariableOp(
          scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);

      auto one = ops::Const<int32>(
          scope.WithOpName("outer/inner/One")
              .WithControlDependencies(
                  absl::Span<const Operation>{assign.operation}),
          1);
      auto add_j =
          ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);

      auto retval0 =
          ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0);
      auto retval1 =
          ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1);
      auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(inner_body_fn.name(), library, &result));

      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
                result.arg_types);
      EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}),
                result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }
  }
}
1102 
1103 }  // namespace
1104 }  // namespace tensorflow
1105