1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
17
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
20 #include "tensorflow/cc/ops/function_ops.h"
21 #include "tensorflow/cc/ops/resource_variable_ops.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h"
24 #include "tensorflow/compiler/tf2xla/test_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/core/common_runtime/function.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op.h"
30 #include "tensorflow/core/graph/graph_constructor.h"
31 #include "tensorflow/core/graph/graph_def_builder.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/util/equal_graph_def.h"
35
36 namespace tensorflow {
37 namespace {
38
39 // Returns the names of the "then" and "else" functions for the XlaIf node in a
40 // graph.
FindIfThenAndElse(const GraphDef & graph,string * op_name,NameAttrList * then_fn,NameAttrList * else_fn)41 Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
42 NameAttrList* then_fn, NameAttrList* else_fn) {
43 for (const NodeDef& node : graph.node()) {
44 if (node.op() == "XlaIf") {
45 *op_name = node.name();
46 const NameAttrList* result;
47 TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
48 *then_fn = *result;
49 TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result));
50 *else_fn = *result;
51 return Status::OK();
52 }
53 }
54 return errors::NotFound("No XlaIf node found in graph");
55 }
56
57 // Graph:
58 // x = array_ops.placeholder(dtypes.int32)
59 // y = array_ops.placeholder(dtypes.int32)
60 // z = control_flow_ops.cond(
61 // math_ops.less(y, x), lambda: math_ops.multiply(y, 17),
62 // lambda: math_ops.add(x, 23))
TEST(FunctionalizeControlFlow,Conditional)63 TEST(FunctionalizeControlFlow, Conditional) {
64 Graph graph(OpRegistry::Global());
65 {
66 Scope scope = Scope::NewRootScope().ExitOnError();
67
68 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
69 auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
70 auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
71 auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less);
72
73 auto identity_t =
74 ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true);
75 auto seventeen = ops::Const<int32>(
76 scope.WithOpName("cond").WithControlDependencies(identity_t), 17);
77 auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less);
78 auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true,
79 seventeen);
80
81 auto identity_f =
82 ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false);
83 auto twenty_three = ops::Const<int32>(
84 scope.WithOpName("cond").WithControlDependencies(identity_f), 23);
85 auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
86 auto add = ops::Add(scope.WithOpName("cond/false/add"),
87 switch_3.output_false, twenty_three);
88
89 auto merge = ops::Merge(scope.WithOpName("cond/Merge"),
90 std::initializer_list<Input>{add, mul});
91
92 TF_EXPECT_OK(scope.ToGraph(&graph));
93 }
94
95 FunctionLibraryDefinition library(OpRegistry::Global(), {});
96 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
97
98 GraphDef graph_def;
99 graph.ToGraphDef(&graph_def);
100 string op_name;
101 NameAttrList then_fn;
102 NameAttrList else_fn;
103 TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
104 InstantiationResultForTest else_result;
105 TF_EXPECT_OK(
106 InstantiateFunctionForTest(else_fn.name(), library, &else_result));
107
108 // Outer graph
109 {
110 Scope scope = Scope::NewRootScope().ExitOnError();
111 auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
112 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
113 auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
114 auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
115 std::initializer_list<Input>{less, y, x}, then_fn,
116 else_fn, {DT_INT32});
117 GraphDef expected;
118 TF_EXPECT_OK(scope.ToGraphDef(&expected));
119 TF_EXPECT_GRAPH_EQ(expected, graph_def);
120 }
121
122 // then body.
123 {
124 Scope scope = Scope::NewRootScope().ExitOnError();
125 auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
126 auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
127 auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
128 auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
129 auto cond = ops::Const(
130 scope.WithOpName("cond").WithControlDependencies(identity), 17);
131 auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
132 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0);
133
134 GraphDef expected;
135 TF_EXPECT_OK(scope.ToGraphDef(&expected));
136
137 InstantiationResultForTest result;
138 TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result));
139
140 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
141 EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
142 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
143 }
144
145 // else body.
146 {
147 Scope scope = Scope::NewRootScope().ExitOnError();
148 auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0);
149 auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
150 auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
151 auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
152 auto cond_1 = ops::Const(
153 scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
154 auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
155 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
156
157 GraphDef expected;
158 TF_EXPECT_OK(scope.ToGraphDef(&expected));
159
160 InstantiationResultForTest result;
161 TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result));
162
163 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
164 EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
165 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
166 }
167 }
168
169 // Returns the names of the "cond" and "body" functions for the While node
170 // in a graph.
FindWhileCondAndBody(const GraphDef & graph,NameAttrList * cond,NameAttrList * body)171 Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
172 NameAttrList* body) {
173 for (const NodeDef& node : graph.node()) {
174 if (node.op() == "XlaWhile") {
175 const NameAttrList* result;
176 TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
177 *cond = *result;
178 TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result));
179 *body = *result;
180 return Status::OK();
181 }
182 }
183 return errors::NotFound("No XlaWhile node found in graph");
184 }
185
186 // Graph:
187 // x = array_ops.placeholder(dtypes.int32)
188 // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVar)189 TEST(FunctionalizeControlFlow, OneLoopVar) {
190 Graph graph(OpRegistry::Global());
191 {
192 Scope scope = Scope::NewRootScope().ExitOnError();
193
194 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
195
196 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
197 auto enter =
198 ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
199 // Add an unused Enter node. These should be ignored.
200 auto enter2 =
201 ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
202 auto merge = ops::Merge(scope.WithOpName("while/Merge"),
203 std::initializer_list<Input>{enter, dummy});
204 auto ten = ops::Const<int32>(
205 scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
206 10);
207 auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
208 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
209 auto switch_ =
210 ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
211 auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
212 switch_.output_false);
213 auto identity =
214 ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
215 auto one = ops::Const<int32>(
216 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
217 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
218 auto next_iteration =
219 ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
220
221 auto sink = ops::Identity(scope.WithOpName("sink"), exit);
222
223 // Remove the dummy node and add the loop backedge.
224 scope.graph()->RemoveNode(dummy.node());
225 scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
226
227 TF_EXPECT_OK(scope.ToGraph(&graph));
228 }
229
230 // Regression test: control edges from an Enter node to the graph sink should
231 // be ignored.
232 for (Node* n : graph.nodes()) {
233 if (n->name() == "while/Enter") {
234 graph.AddControlEdge(n, graph.sink_node());
235 }
236 }
237
238 FunctionLibraryDefinition library(OpRegistry::Global(), {});
239 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
240
241 GraphDef graph_def;
242 graph.ToGraphDef(&graph_def);
243
244 NameAttrList cond_fn, body_fn;
245 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
246
247 // Outer graph
248 {
249 Scope scope = Scope::NewRootScope().ExitOnError();
250 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
251 auto while_op =
252 ops::XlaWhile(scope.WithOpName("while/LoopCond"),
253 std::initializer_list<Input>{source}, cond_fn, body_fn);
254 auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
255 GraphDef expected;
256 TF_EXPECT_OK(scope.ToGraphDef(&expected));
257 TF_EXPECT_GRAPH_EQ(expected, graph_def);
258 }
259
260 // Condition graph
261 {
262 Scope scope = Scope::NewRootScope().ExitOnError();
263 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
264 auto ten = ops::Const<int32>(
265 scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
266 auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
267 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
268
269 GraphDef expected;
270 TF_EXPECT_OK(scope.ToGraphDef(&expected));
271
272 InstantiationResultForTest result;
273 TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
274
275 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
276 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
277 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
278 }
279
280 // Body graph.
281 {
282 Scope scope = Scope::NewRootScope().ExitOnError();
283 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
284 auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
285 auto one = ops::Const<int32>(
286 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
287 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
288 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
289
290 GraphDef expected;
291 TF_EXPECT_OK(scope.ToGraphDef(&expected));
292
293 InstantiationResultForTest result;
294 TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
295
296 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
297 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
298 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
299 }
300 }
301
302 // Tests functionalizing OneLoopVar where the loop value is not used post the
303 // loop.
304 // Graph:
305 // x = array_ops.placeholder(dtypes.int32)
306 // control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVarWithoutExit)307 TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
308 Graph graph(OpRegistry::Global());
309 {
310 Scope scope = Scope::NewRootScope().ExitOnError();
311
312 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
313
314 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
315 auto enter =
316 ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
317 auto merge = ops::Merge(scope.WithOpName("while/Merge"),
318 std::initializer_list<Input>{enter, dummy});
319 auto ten = ops::Const<int32>(
320 scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
321 10);
322 auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
323 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
324 auto switch_ =
325 ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
326 auto identity =
327 ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
328 auto one = ops::Const<int32>(
329 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
330 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
331 auto next_iteration =
332 ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
333
334 // Remove the dummy node and add the loop backedge.
335 scope.graph()->RemoveNode(dummy.node());
336 scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
337
338 TF_EXPECT_OK(scope.ToGraph(&graph));
339 }
340
341 FunctionLibraryDefinition library(OpRegistry::Global(), {});
342 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
343
344 GraphDef graph_def;
345 graph.ToGraphDef(&graph_def);
346
347 NameAttrList cond_fn, body_fn;
348 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
349
350 // Outer graph
351 {
352 Scope scope = Scope::NewRootScope().ExitOnError();
353 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
354 auto while_op =
355 ops::XlaWhile(scope.WithOpName("while/LoopCond"),
356 std::initializer_list<Input>{source}, cond_fn, body_fn);
357 GraphDef expected;
358 TF_EXPECT_OK(scope.ToGraphDef(&expected));
359 TF_EXPECT_GRAPH_EQ(expected, graph_def);
360 }
361
362 // Condition graph
363 {
364 Scope scope = Scope::NewRootScope().ExitOnError();
365 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
366 auto ten = ops::Const<int32>(
367 scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
368 auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
369 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
370
371 GraphDef expected;
372 TF_EXPECT_OK(scope.ToGraphDef(&expected));
373
374 InstantiationResultForTest result;
375 TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
376
377 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
378 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
379 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
380 }
381
382 // Body graph.
383 {
384 Scope scope = Scope::NewRootScope().ExitOnError();
385 auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
386 auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
387 auto one = ops::Const<int32>(
388 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
389 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
390 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
391
392 GraphDef expected;
393 TF_EXPECT_OK(scope.ToGraphDef(&expected));
394
395 InstantiationResultForTest result;
396 TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
397
398 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
399 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
400 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
401 }
402 }
403
404 // Graph:
405 // x = array_ops.placeholder(dtypes.int32)
406 // y = array_ops.placeholder(dtypes.int32)
407 // cond = lambda (i, j): i + 3 < 10
408 // body = lambda (i, j): (i < 10, j * 2)
409 // z = control_flow_ops.while_loop(cond, body, [x, y])
TEST(FunctionalizeControlFlow,TwoLoopVars)410 TEST(FunctionalizeControlFlow, TwoLoopVars) {
411 Graph graph(OpRegistry::Global());
412 {
413 Scope scope = Scope::NewRootScope().ExitOnError();
414
415 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
416
417 auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
418 auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
419 auto enter_x =
420 ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop");
421 auto enter_y =
422 ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop");
423 auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"),
424 std::initializer_list<Input>{enter_x, dummy});
425 auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"),
426 std::initializer_list<Input>{enter_y, dummy});
427
428 // Loop condition
429 auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
430 .WithControlDependencies(merge_x.output),
431 3);
432 auto cond_add =
433 ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
434 auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
435 .WithControlDependencies(merge_x.output),
436 10);
437 auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
438 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
439
440 auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"),
441 merge_x.output, loop_cond);
442 auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"),
443 merge_y.output, loop_cond);
444
445 auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"),
446 switch_x.output_false);
447 auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"),
448 switch_y.output_false);
449
450 auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"),
451 switch_x.output_true);
452 auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
453 switch_y.output_true);
454
455 auto one = ops::Const<int32>(
456 scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
457 1);
458 auto two = ops::Const<int32>(
459 scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
460 2);
461
462 auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
463 auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
464 auto next_iteration_x =
465 ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add);
466 auto next_iteration_y =
467 ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul);
468
469 auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x);
470 auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y);
471
472 // Remove the dummy node and add the loop backedges.
473 scope.graph()->RemoveNode(dummy.node());
474 scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(),
475 1);
476 scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(),
477 1);
478
479 TF_EXPECT_OK(scope.ToGraph(&graph));
480 }
481
482 FunctionLibraryDefinition library(OpRegistry::Global(), {});
483 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
484
485 GraphDef graph_def;
486 graph.ToGraphDef(&graph_def);
487
488 NameAttrList cond_fn, body_fn;
489 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
490
491 // Outer graph.
492 {
493 Scope scope = Scope::NewRootScope().ExitOnError();
494 auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
495 auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
496 auto while_op =
497 ops::XlaWhile(scope.WithOpName("while/LoopCond"),
498 std::initializer_list<Input>{x, y}, cond_fn, body_fn);
499 auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
500 auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
501 GraphDef expected;
502 TF_EXPECT_OK(scope.ToGraphDef(&expected));
503 TF_EXPECT_GRAPH_EQ(expected, graph_def);
504 }
505
506 // Condition graph.
507 {
508 Scope scope = Scope::NewRootScope().ExitOnError();
509 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
510 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
511 auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
512 .WithControlDependencies(arg0.output),
513 3);
514 auto cond_add =
515 ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
516 auto ten = ops::Const<int32>(
517 scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output),
518 10);
519 auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
520 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
521
522 GraphDef expected;
523 TF_EXPECT_OK(scope.ToGraphDef(&expected));
524
525 InstantiationResultForTest result;
526 TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result));
527
528 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
529 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
530 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
531 }
532
533 // Body graph.
534 {
535 Scope scope = Scope::NewRootScope().ExitOnError();
536 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
537 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
538
539 auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0);
540 auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1);
541
542 auto one = ops::Const<int32>(
543 scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
544 1);
545 auto two = ops::Const<int32>(
546 scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
547 2);
548
549 auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
550 auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
551 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0);
552 auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1);
553
554 GraphDef expected;
555 TF_EXPECT_OK(scope.ToGraphDef(&expected));
556
557 InstantiationResultForTest result;
558 TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result));
559
560 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
561 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types);
562 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
563 }
564 }
565
566 // Example with nesting, loop-invariant arguments, and resource variables.
567 //
568 // accum = resource_variable_ops.ResourceVariable(1)
569 // x = array_ops.placeholder(2, dtype=dtypes.int32)
570 // y = 3 + x
571 //
572 // def inner_body(j, k):
573 // add = state_ops.assign_add(accum, k * j + x)
574 // with ops.control_dependencies([add]):
575 // return [j + 1, k]
576 //
577 // def body(i):
578 // m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
579 // [1, y], name="inner")
580 // with ops.control_dependencies(m):
581 // return [i + 1]
582 //
583 // z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
TEST(FunctionalizeControlFlow,Complex)584 TEST(FunctionalizeControlFlow, Complex) {
585 Graph graph(OpRegistry::Global());
586 {
587 Scope scope = Scope::NewRootScope().ExitOnError();
588
589 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
590
591 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
592 auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
593 auto y = ops::Add(scope.WithOpName("y"), x, three);
594
595 auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
596 TensorShape({}));
597
598 // Outer loop
599 auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
600 auto enter_i =
601 ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
602 auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
603 std::initializer_list<Input>{enter_i, dummy});
604 auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
605 .WithControlDependencies(merge_i.output),
606 10);
607 auto less_i =
608 ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
609 auto outer_loop_cond =
610 ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i);
611 auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"),
612 merge_i.output, outer_loop_cond);
613 auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"),
614 switch_i.output_false);
615 auto identity_i =
616 ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true);
617
618 auto enter_x_outer =
619 ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer",
620 ops::internal::Enter::Attrs().IsConstant(true));
621 auto enter_k_outer =
622 ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer",
623 ops::internal::Enter::Attrs().IsConstant(true));
624 auto enter_var_outer =
625 ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer",
626 ops::internal::Enter::Attrs().IsConstant(true));
627
628 // Inner loop
629 auto one_j = ops::Const<int32>(
630 scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
631 auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
632 one_j, "inner");
633 auto enter_k =
634 ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k")
635 .WithControlDependencies(identity_i),
636 enter_k_outer, "inner");
637 auto enter_x = ops::internal::Enter(
638 scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner",
639 ops::internal::Enter::Attrs().IsConstant(true));
640 auto enter_var = ops::internal::Enter(
641 scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner",
642 ops::internal::Enter::Attrs().IsConstant(true));
643
644 auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"),
645 std::initializer_list<Input>{enter_j, dummy});
646 auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
647 std::initializer_list<Input>{enter_k, dummy});
648
649 auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
650 .WithControlDependencies(merge_j.output),
651 5);
652 auto less_j =
653 ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
654 auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j);
655
656 auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
657 merge_j.output, loop_cond);
658 auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"),
659 merge_k.output, loop_cond);
660 auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"),
661 switch_j.output_false);
662 auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"),
663 switch_k.output_false);
664 auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"),
665 switch_j.output_true);
666 auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"),
667 switch_k.output_true);
668
669 // Variable update
670 auto mul_jk =
671 ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
672 auto add_jkx =
673 ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x);
674 auto assign = ops::AssignAddVariableOp(
675 scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
676
677 auto one =
678 ops::Const<int32>(scope.WithOpName("outer/inner/One")
679 .WithControlDependencies(
680 gtl::ArraySlice<Operation>{assign.operation}),
681 1);
682 auto add_j =
683 ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
684
685 auto next_iteration_j = ops::NextIteration(
686 scope.WithOpName("outer/inner/NextIteration_j"), add_j);
687 auto next_iteration_k = ops::NextIteration(
688 scope.WithOpName("outer/inner/NextIteration_k"), identity_k);
689
690 // Body and backedge for outer loop.
691 auto one_outer = ops::Const<int32>(
692 scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
693 auto add_i =
694 ops::Add(scope.WithOpName("outer/add")
695 .WithControlDependencies(gtl::ArraySlice<Operation>{
696 exit_j.output.op(), exit_k.output.op()}),
697 identity_i, one_outer);
698 auto next_iteration_i =
699 ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i);
700
701 auto sink = ops::Identity(scope.WithOpName("sink"), exit_i);
702
703 // Remove the dummy node and add the loop backedge.
704 scope.graph()->RemoveNode(dummy.node());
705 scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(),
706 1);
707 scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(),
708 1);
709 scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(),
710 1);
711
712 TF_EXPECT_OK(scope.ToGraph(&graph));
713 }
714
715 FunctionLibraryDefinition library(OpRegistry::Global(), {});
716 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
717
718 GraphDef graph_def;
719 graph.ToGraphDef(&graph_def);
720
721 NameAttrList outer_cond_fn, outer_body_fn;
722 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));
723
724 // Outer graph.
725 {
726 Scope scope = Scope::NewRootScope().ExitOnError();
727 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
728 auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
729 auto y = ops::Add(scope.WithOpName("y"), x, three);
730
731 auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
732 TensorShape({}));
733
734 auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
735
736 auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"),
737 std::initializer_list<Input>{zero, y, x, var},
738 outer_cond_fn, outer_body_fn);
739 auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
740 GraphDef expected;
741 TF_EXPECT_OK(scope.ToGraphDef(&expected));
742 TF_EXPECT_GRAPH_EQ(expected, graph_def);
743 }
744
745 // Outer condition graph.
746 {
747 Scope scope = Scope::NewRootScope().ExitOnError();
748 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
749 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
750 auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
751 auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
752
753 auto ten = ops::Const<int32>(
754 scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
755 10);
756 auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
757 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0);
758
759 GraphDef expected;
760 TF_EXPECT_OK(scope.ToGraphDef(&expected));
761
762 InstantiationResultForTest result;
763 TF_EXPECT_OK(
764 InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));
765
766 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
767 result.arg_types);
768 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
769 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
770 }
771
772 // Outer body graph.
773 NameAttrList inner_cond_fn, inner_body_fn;
774 {
775 InstantiationResultForTest result;
776 TF_EXPECT_OK(
777 InstantiateFunctionForTest(outer_body_fn.name(), library, &result));
778
779 // Find the inner condition and body names.
780 TF_EXPECT_OK(
781 FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));
782
783 Scope scope = Scope::NewRootScope().ExitOnError();
784 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
785 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
786 auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
787 auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
788
789 auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
790 auto one_j = ops::Const<int32>(
791 scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
792 auto while_op =
793 ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"),
794 std::initializer_list<Input>{one_j, arg1, arg2, arg3},
795 inner_cond_fn, inner_body_fn);
796
797 auto one_outer = ops::Const<int32>(
798 scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
799 auto add_i =
800 ops::Add(scope.WithOpName("outer/add")
801 .WithControlDependencies(gtl::ArraySlice<Operation>{
802 while_op[0].op(), while_op[1].op()}),
803 identity_i, one_outer);
804
805 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0);
806 auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1);
807 auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);
808
809 GraphDef expected;
810 TF_EXPECT_OK(scope.ToGraphDef(&expected));
811
812 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
813 result.arg_types);
814 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types);
815 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
816 }
817
818 // Inner condition graph.
819 {
820 Scope scope = Scope::NewRootScope().ExitOnError();
821 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
822 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
823 auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
824 auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
825
826 auto five = ops::Const<int32>(
827 scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5);
828 auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
829 auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0);
830
831 GraphDef expected;
832 TF_EXPECT_OK(scope.ToGraphDef(&expected));
833
834 InstantiationResultForTest result;
835 TF_EXPECT_OK(
836 InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));
837
838 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
839 result.arg_types);
840 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
841 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
842 }
843
844 // Inner body graph.
845 {
846 Scope scope = Scope::NewRootScope().ExitOnError();
847 auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0);
848 auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1);
849 auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2);
850 auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3);
851
852 auto identity_j =
853 ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
854 auto identity_k =
855 ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);
856
857 auto mul_jk =
858 ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
859 auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
860 auto assign = ops::AssignAddVariableOp(
861 scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
862
863 auto one =
864 ops::Const<int32>(scope.WithOpName("outer/inner/One")
865 .WithControlDependencies(
866 gtl::ArraySlice<Operation>{assign.operation}),
867 1);
868 auto add_j =
869 ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
870
871 auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0);
872 auto retval1 =
873 ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1);
874 auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2);
875
876 GraphDef expected;
877 TF_EXPECT_OK(scope.ToGraphDef(&expected));
878
879 InstantiationResultForTest result;
880 TF_EXPECT_OK(
881 InstantiateFunctionForTest(inner_body_fn.name(), library, &result));
882
883 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
884 result.arg_types);
885 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types);
886 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
887 }
888 }
889
890 } // namespace
891 } // namespace tensorflow
892