1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/deadness_analysis.h"
17
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/sendrecv_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
25 #include "tensorflow/compiler/jit/defs.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/graph/algorithm.h"
32 #include "tensorflow/core/graph/graph_def_builder.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace tensorflow {
37 namespace {
38
HasInputsWithMismatchingDeadness(const DeadnessAnalysis & deadness_analysis,const Node & n)39 se::port::StatusOr<bool> HasInputsWithMismatchingDeadness(
40 const DeadnessAnalysis& deadness_analysis, const Node& n) {
41 absl::optional<DeadnessAnalysis::DeadnessPredicate> pred;
42 for (const Edge* edge : n.in_edges()) {
43 TF_ASSIGN_OR_RETURN(
44 DeadnessAnalysis::DeadnessPredicate this_pred,
45 deadness_analysis.GetPredicateFor(edge->src(), edge->src_output()));
46 if (pred && *pred != this_pred) {
47 return true;
48 }
49 pred = this_pred;
50 }
51
52 return false;
53 }
54
55 using deadness_analysis_internal::ComputePredicates;
56 using deadness_analysis_internal::PredicateMapTy;
57
AnalyzeDeadness(Graph * graph,std::unique_ptr<DeadnessAnalysis> * result)58 Status AnalyzeDeadness(Graph* graph,
59 std::unique_ptr<DeadnessAnalysis>* result) {
60 FixupSourceAndSinkEdges(graph);
61 return DeadnessAnalysis::Run(*graph, result);
62 }
63
CreateSwitch(const Scope & root,const string & prefix)64 ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
65 Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
66 Output predicate =
67 ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
68 return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
69 }
70
ControlOutputFor(const Output & o)71 TensorId ControlOutputFor(const Output& o) {
72 return {o.node()->name(), Graph::kControlSlot};
73 }
74
VLogGraphIfAsked(const Graph & graph)75 void VLogGraphIfAsked(const Graph& graph) {
76 if (VLOG_IS_ON(3)) {
77 GraphDef graph_def;
78 graph.ToGraphDef(&graph_def);
79 string serialized;
80 ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized);
81 LOG(INFO) << serialized;
82 }
83 }
84
// Interesting outputs of the loop built by CreateInductionVariable below.
struct InductionVarInfo {
  // Output of the loop's Merge node, carrying the induction variable value.
  Output induction_var;
  // The boolean loop condition (the `Less` node's output).
  Output loop_cond;
};
89
90 // Creates an induction variable with the following structure (simplified for
91 // brevity):
92 //
93 // +---------------+
94 // | initial_value |
95 // +---------------+
96 // |
97 // |
98 // v
99 // +---------------+
100 // | Enter |
101 // +---------------+
102 // |
103 // |
104 // v
105 // +---------------+
106 // +> | Merge | -+
107 // | +---------------+ |
108 // | | |
109 // | | |
110 // | v |
111 // | +---------------+ |
112 // | | LessThan10 | |
113 // | +---------------+ |
114 // | | |
115 // | | |
116 // | v |
117 // | +---------------+ |
118 // +----+- | Switch | <+
119 // | | +---------------+
120 // | | |
121 // | | |
122 // | | v
123 // | | +---------------+
124 // | +- | AddOne |
125 // | +---------------+
126 // | +---------------+
127 // +-----> | Exit |
128 // +---------------+
// Builds the loop shown in the diagram above inside frame `frame_name`, with
// all node names prefixed by `prefix`.  Returns the Merge output (the
// induction variable) and the loop condition.
InductionVarInfo CreateInductionVariable(const Scope& root,
                                         const string& prefix,
                                         const string& frame_name,
                                         const Output& initial_value) {
  // Bring `initial_value` into the loop frame.
  Output enter_initial_value = ops::internal::Enter(
      root.WithOpName(prefix + "/enter"), initial_value, frame_name);

  // The Merge is created with its input duplicated; input 1 is rewritten
  // below to be the NextIteration backedge.
  ops::Merge iv(root.WithOpName(prefix + "/iv"),
                {enter_initial_value, enter_initial_value});
  Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
  Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
  Output loop_cond_expr =
      ops::Less(root.WithOpName(prefix + "/cond"), iv.output, final_value);
  ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output,
                    loop_cond_expr);
  // The false branch of the latch exits the loop; the true branch is
  // incremented and fed back.
  ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
                           latch.output_false);
  Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
                            latch.output_true, increment_by);
  Output next_iteration =
      ops::NextIteration(root.WithOpName(prefix + "/next_iteration"), iv_next);

  // Install the backedge: Merge input 1 now comes from NextIteration.
  CHECK(root.graph()
            ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
            .ok());
  // NOTE(review): these control edges make the constants depend on the loop's
  // Merge, presumably to place them inside the loop frame -- confirm against
  // DeadnessAnalysis frame handling.
  root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
  root.graph()->AddControlEdge(iv.output.node(), final_value.node());

  return {iv.output, loop_cond_expr};
}
159
CreateInductionVariable(const Scope & root,const string & prefix,const string & frame_name,int32 init)160 InductionVarInfo CreateInductionVariable(const Scope& root,
161 const string& prefix,
162 const string& frame_name, int32 init) {
163 return CreateInductionVariable(
164 root, prefix, frame_name,
165 ops::Const(root.WithOpName(prefix + "/init"), init));
166 }
167
168 // Creates an induction variable with the following structure:
169 //
170 // +---------------+
171 // | initial_value |
172 // +---------------+
173 // |
174 // |
175 // v
176 // +---------------+
177 // | Enter |
178 // +---------------+
179 // |
180 // |
181 // v
182 // +---------------+
183 // | Merge | <+
184 // +---------------+ |
185 // | |
186 // | |
187 // v |
188 // +-----------+ +---------------+ |
189 // | loop_cond | --> | Switch | -+
190 // +-----------+ +---------------+
191 // |
192 // |
193 // v
194 // +---------------+
195 // | Exit |
196 // +---------------+
// Outputs of the loop built by CreateDependentLoopInvariantValue below.
struct DependentInductionVar {
  // Output of the Merge node carrying the loop-invariant value.
  Output induction_var;
  // The Switch gated by the externally supplied loop condition.
  ops::Switch latch;
};
201
// Builds the loop shown in the diagram above: a value circulating through
// frame `frame_name` whose latch is driven by the caller-supplied
// `loop_cond` rather than a condition computed inside this loop.
DependentInductionVar CreateDependentLoopInvariantValue(
    const Scope& root, const string& prefix, const string& frame_name,
    const Output& loop_cond, const Output& value) {
  Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"),
                                            value, frame_name);
  // The Merge's second input is a placeholder; it is rewired below to the
  // NextIteration backedge.
  ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
  ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
  // False branch leaves the loop; true branch circulates unchanged.
  ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
                           latch.output_false);
  Output next_iteration = ops::NextIteration(
      root.WithOpName(prefix + "/next_iteration"), latch.output_true);
  // Install the backedge into Merge input 1.
  CHECK(root.graph()
            ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
            .ok());
  return {iv.output, latch};
}
218
CreateDependentLoopInvariantValue(const Scope & root,const string & prefix,const string & frame_name,const Output & loop_cond,int32 value)219 DependentInductionVar CreateDependentLoopInvariantValue(
220 const Scope& root, const string& prefix, const string& frame_name,
221 const Output& loop_cond, int32 value) {
222 return CreateDependentLoopInvariantValue(
223 root, prefix, frame_name, loop_cond,
224 ops::Const(root.WithOpName(prefix + "/init"), value));
225 }
226
// The two outputs of a Switch are never live at the same time, so an Add fed
// by both branches must be reported as having mismatching input deadness.
TEST(DeadnessAnalysisTest, BasicPositive) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw = CreateSwitch(root, "0");
  Output add =
      ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_TRUE(mismatch);
}
242
// Two plain placeholders are always live, so an Add over them has no
// deadness mismatch.
TEST(DeadnessAnalysisTest, BasicNegative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), a, b);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(mismatch);
}
258
// The deadness conjunction created by multi-input nodes must be commutative:
// A & B and B & A get equal predicates.
TEST(DeadnessAnalysisTest, AndIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  // a0/a1: same operand pair, opposite order.
  Output a0 =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output a1 =
      ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false);

  // b0/b1: likewise, mixing a false and a true branch.
  Output b0 =
      ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true);
  Output b1 =
      ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false);

  // live0/live1 combine operands with equal predicates; halfdead0/halfdead1
  // combine operands with differing ones.
  Output live0 = ops::Add(root.WithOpName("live0"), a0, a1);
  Output live1 = ops::Add(root.WithOpName("live1"), b0, b1);

  Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0);
  Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  bool has_inputs_with_mismatching_deadness;

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live0.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live1.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
306
// The deadness conjunction must be associative: (A & B) & C equals
// A & (B & C), so the final Add sees matching predicates.
TEST(DeadnessAnalysisTest, AndIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // a1 associates to the left: (sw_0 & sw_1) & sw_2.
  Output a0 =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output a1 = ops::Add(root.WithOpName("a1"), a0, sw_2.output_false);

  // b1 associates to the right: sw_0 & (sw_1 & sw_2).
  Output b0 =
      ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false);
  Output b1 = ops::Add(root.WithOpName("b1"), sw_0.output_false, b0);

  Output add = ops::Add(root.WithOpName("add"), a1, b1);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
332
// The deadness disjunction created by Merge nodes must be commutative:
// A | B and B | A get equal predicates.
TEST(DeadnessAnalysisTest, OrIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  // m0/m1 are the same OR with swapped operands; likewise m2/m3.
  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false});
  ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true});
  ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false});

  Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output);
  Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output);

  Output halfdead0 =
      ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output);
  Output halfdead1 =
      ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  bool has_inputs_with_mismatching_deadness;

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live0.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *live1.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);

  TF_ASSERT_OK_AND_ASSIGN(
      has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
377
// The deadness disjunction must be associative: (A | B) | C equals
// A | (B | C).
TEST(DeadnessAnalysisTest, OrIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // m1 associates to the left, m3 to the right.
  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {m0.output, sw_2.output_false});
  ops::Merge m2(root.WithOpName("m2"), {sw_1.output_false, sw_2.output_false});
  ops::Merge m3(root.WithOpName("m3"), {sw_0.output_false, m2.output});

  Output add = ops::Add(root.WithOpName("add"), m1.output, m3.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
400
// Two structurally identical AND-of-OR expressions must compare equal:
// add0 and add1 both compute (sw_0 | sw_1) & (sw_2 | sw_3).
TEST(DeadnessAnalysisTest, AndOfOr) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {sw_2.output_false, sw_3.output_false});

  Output add0 = ops::Add(root.WithOpName("add0"), m0.output, m1.output);
  Output add1 = ops::Add(root.WithOpName("add1"), m0.output, m1.output);

  Output add2 = ops::Add(root.WithOpName("add2"), add0, add1);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add2.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
425
// Two structurally identical OR-of-AND expressions must compare equal:
// m0 and m1 both compute (sw_0 & sw_1) | (sw_2 & sw_3).
TEST(DeadnessAnalysisTest, OrOfAnd) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  Output add0 =
      ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false);
  Output add1 =
      ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false);

  ops::Merge m0(root.WithOpName("m0"), {add0, add1});
  ops::Merge m1(root.WithOpName("m1"), {add0, add1});

  Output add2 = ops::Add(root.WithOpName("add2"), m0.output, m1.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add2.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
452
// The simplifier must reduce the expression below all the way to the literal
// "#true" -- checked via the predicate map rather than a mismatch query.
TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
  // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "A");
  ops::Switch sw_1 = CreateSwitch(root, "B");
  // and0 = ~A & B, and1 = ~A & ~B.
  Output add0 =
      ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true);
  Output add1 =
      ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false);
  ops::Merge or2(root.WithOpName("or2"), {add0, add1});
  Output add3 =
      ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false);
  // or4 = A | (~A & ...), which must simplify to #true.
  ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true});

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
  EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true");
}
475
// AND must distribute over OR: add0 computes (A|B)&C while m1 computes
// (A&C)|(B&C); the two must get equal predicates.
TEST(DeadnessAnalysisTest, AndOrDistributive) {
  // (A|B)&C == (A&C)|(B&C)
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // Left-hand side: (A|B)&C.
  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  Output add0 = ops::Add(root.WithOpName("add0"), m0.output, sw_2.output_false);

  // Right-hand side: (A&C)|(B&C).
  Output add1 =
      ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false);
  Output add2 =
      ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false);
  ops::Merge m1(root.WithOpName("m1"), {add1, add2});

  Output add3 = ops::Add(root.WithOpName("add3"), add0, m1.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add3.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
503
// A full ternary: value switched on `predicate` and merged back together is
// always live, so combining it with a plain placeholder is not a mismatch.
TEST(DeadnessAnalysisTest, Ternary) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL);
  Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT);
  Output false_value =
      ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT);

  // Value selected when `predicate` is true.
  ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value,
                              predicate);

  // Value selected when `predicate` is false.  BUGFIX: this switch previously
  // consumed `true_value`, leaving `false_value` unused; the false arm of the
  // ternary must carry `false_value`.  Deadness predicates are unchanged by
  // this fix (they depend only on `predicate`).
  ops::Switch predicated_false(root.WithOpName("predicated_false"), false_value,
                               predicate);
  // The merge is live whenever `predicate` produced either branch, i.e.
  // always (p | ~p).
  ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true,
                                                predicated_false.output_false});
  Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), merge.output, addend);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
530
// Distinct _Recv nodes are treated as independently (possibly) dead, so an
// Add over two of them is reported as a mismatch.
TEST(DeadnessAnalysisTest, Recv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b",
                             "sender", 0, "receiver");
  Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_TRUE(mismatch);
}
548
// Same as the Recv test above, but exercising the _HostRecv op.
TEST(DeadnessAnalysisTest, HostRecv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT,
                                 "tensor_a", "sender", 0, "receiver");
  Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT,
                                 "tensor_b", "sender", 0, "receiver");
  Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_TRUE(mismatch);
}
566
// Each loop's Merge gets its own uninterpreted symbol, so even identical
// loops (iv0 and iv1) are conservatively reported as mismatching.  The exact
// predicate strings depend on node creation order; do not reorder.
TEST(DeadnessAnalysisTest, Loop) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0).induction_var;
  Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0).induction_var;
  Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1).induction_var;
  Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
  Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);

  // NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have
  // noticed that. Today we are pessimistic here because we assign an
  // uninterpreted symbol to merges with backedges.

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    bool has_inputs_with_mismatching_deadness;

    TF_ASSERT_OK_AND_ASSIGN(
        has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);

    TF_ASSERT_OK_AND_ASSIGN(
        has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add1.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0
    // produce the same deadness. But we're not that smart today.
    EXPECT_EQ(predicate_map[ControlOutputFor(iv0)],
              "{#true,&,*iv0/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv1)],
              "{#true,&,*iv1/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv2)],
              "{#true,&,*iv2/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "({#true,&,*iv0/cond:0}<fr0> & {#true,&,*iv1/cond:0}<fr0>)");
    EXPECT_EQ(predicate_map[ControlOutputFor(add1)],
              "({#true,&,*iv1/cond:0}<fr0> & {#true,&,*iv2/cond:0}<fr0>)");
  }
}
615
// Two loop-invariant values latched by the SAME loop condition are
// control-equivalent: no mismatch, and with the optimistic analysis they
// share the induction variable's exact predicate.  The pessimistic analysis
// instead folds the merge symbol (iv0/iv:0) into the predicates.
TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
  Output dependent_iv0 =
      CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0)
          .induction_var;
  Output dependent_iv1 =
      CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0)
          .induction_var;
  Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_FALSE(has_inputs_with_mismatching_deadness);
  }
  {
    // Optimistic mode: all three values share one predicate.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
  }
  {
    // Pessimistic mode: the merge symbol iv0/iv:0 appears in the predicates.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
  }
}
667
// A merge whose backedge value does NOT depend on the merge itself is not a
// real loop; the optimistic analysis still assigns the structured loop
// predicate, while the pessimistic one degrades to the bare merge symbol.
TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
  // Create a merge that "looks like" a loop but isn't really. It has a value
  // that does not depend on the merge on its backedge.
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0);
  DependentInductionVar dependent_iv =
      CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
  FixupSourceAndSinkEdges(root.graph());

  // Reroute the latch's data input (slot 0) to come from iv0's induction
  // variable instead of div0's own merge, breaking the circular dependency.
  TF_ASSERT_OK(root.graph()->UpdateEdge(
      iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));

  VLogGraphIfAsked(*root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
              "{#true,&,*iv0/cond:0}<frame>");
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
              "div0/iv:0");
  }
}
699
// Loop-invariant values flowing through control-equivalent NESTED loops must
// compare equal (no mismatch), in both optimistic and pessimistic modes.
// Exact predicate strings depend on node creation order; do not reorder.
TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_outer =
      CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
  // A constant brought into the outer frame; the inner loop is seeded from it
  // only while the outer loop is live (via the Switch below).
  Output enter_constant_outer_loop = ops::internal::Enter(
      root.WithOpName("constant_enter_outer_loop"),
      ops::Const(root.WithOpName("constant"), 5), "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  ops::Switch inner_value(root.WithOpName("outer_is_live"),
                          enter_constant_outer_loop, iv_outer.loop_cond);
  InductionVarInfo iv_inner = CreateInductionVariable(
      root, "iv_inner", "inner_loop", inner_value.output_true);

  // Two invariant values in the outer loop...
  Output dependent_outer_iv0 =
      CreateDependentLoopInvariantValue(root, "dependent_outer_iv0",
                                        "outer_loop", iv_outer.loop_cond, 0)
          .induction_var;
  Output dependent_outer_iv1 =
      CreateDependentLoopInvariantValue(root, "dependent_outer_iv1",
                                        "outer_loop", iv_outer.loop_cond, 0)
          .induction_var;

  // ...each threaded through the inner loop as well.
  Output dependent_inner_iv0 = CreateDependentLoopInvariantValue(
                                   root, "dependent_inner_iv0", "inner_loop",
                                   iv_inner.loop_cond, dependent_outer_iv0)
                                   .induction_var;
  Output dependent_inner_iv1 = CreateDependentLoopInvariantValue(
                                   root, "dependent_inner_iv1", "inner_loop",
                                   iv_inner.loop_cond, dependent_outer_iv1)
                                   .induction_var;

  Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0,
                         dependent_inner_iv1);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_FALSE(has_inputs_with_mismatching_deadness);
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");

    // enable_optimistic = true or not should produce the same results because
    // of fallback. However, note that the order of iv_inner/cond:0 and
    // iv_inner/iv:0 is different because the optimistic approach does not
    // create predicates for all merges and it can change the predicate id and
    // hence the symbol order.
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
              "iv_inner/iv:0)}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
              "*iv_inner/cond:0)}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
  }
}
793
// Two structurally identical but INDEPENDENT nested loops: their induction
// variables get distinct symbols, so combining them is a mismatch.  The
// second copy's nodes are uniquified with a "_1" suffix (see the cond_1
// expectations below), so creation order must not change.
TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();

  std::array<Output, 2> outer_iv;
  std::array<Output, 2> inner_iv;

  // Build the same outer/inner loop nest twice.
  for (int i : {0, 1}) {
    InductionVarInfo iv_outer =
        CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
    Output enter_constant_outer_loop = ops::internal::Enter(
        root.WithOpName("constant_enter_outer_loop"),
        ops::Const(root.WithOpName("constant"), 5), "outer_loop",
        ops::internal::Enter::Attrs().IsConstant(true));
    ops::Switch inner_value(root.WithOpName("outer_is_live"),
                            enter_constant_outer_loop, iv_outer.loop_cond);
    InductionVarInfo iv_inner = CreateInductionVariable(
        root, "iv_inner", "inner_loop", inner_value.output_true);

    outer_iv[i] = iv_outer.induction_var;
    inner_iv[i] = iv_inner.induction_var;
  }

  Output add0 = ops::Add(root.WithOpName("add0"), inner_iv[0], inner_iv[1]);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])],
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[1])],
              "{(*iv_outer/cond_1:0 & "
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
              "cond_1:0}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "({(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop> & {(*iv_outer/cond_1:0 & "
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
              "cond_1:0}<inner_loop;outer_loop>)");
  }
}
854
TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) {
  Scope root = Scope::NewRootScope().ExitOnError();
  // Outer loop with an inner loop whose entry is gated on the outer loop's
  // condition.
  InductionVarInfo iv_outer =
      CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
  Output enter_constant_outer_loop = ops::internal::Enter(
      root.WithOpName("constant_enter_outer_loop"),
      ops::Const(root.WithOpName("constant"), 5), "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  ops::Switch inner_value(root.WithOpName("outer_is_live"),
                          enter_constant_outer_loop, iv_outer.loop_cond);
  InductionVarInfo iv_inner = CreateInductionVariable(
      root, "iv_inner", "inner_loop", inner_value.output_true);

  // Two pairs of loop-invariant values, one per frame, chained
  // outer -> inner.
  DependentInductionVar div0_outer = CreateDependentLoopInvariantValue(
      root, "div0_outer", "outer_loop", iv_outer.loop_cond, 0);
  DependentInductionVar div1_outer = CreateDependentLoopInvariantValue(
      root, "div1_outer", "outer_loop", iv_outer.loop_cond, 0);

  DependentInductionVar div0_inner = CreateDependentLoopInvariantValue(
      root, "div0_inner", "inner_loop", iv_inner.loop_cond,
      div0_outer.induction_var);
  DependentInductionVar div1_inner = CreateDependentLoopInvariantValue(
      root, "div1_inner", "inner_loop", iv_inner.loop_cond,
      div1_outer.induction_var);

  // A value captured from outside the loops, forwarded into both frames via
  // constant Enter nodes.
  Output captured = ops::_Recv(root.WithOpName("captured"), DT_INT32,
                               "tensor_a", "sender", 0, "receiver");
  Output capture_enter_outer = ops::internal::Enter(
      root.WithOpName("capture_enter_outer"), captured, "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  Output capture_enter_inner = ops::internal::Enter(
      root.WithOpName("capture_enter_inner"), capture_enter_outer, "inner_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  // Rewire div1_inner's latch so its back edge now depends on the capture,
  // making div1_inner's recurrence differ from div0_inner's.
  Output mul0 = ops::Mul(root.WithOpName("mul0"), div1_inner.induction_var,
                         capture_enter_inner);
  TF_ASSERT_OK(root.graph()->UpdateEdge(
      mul0.node(), 0, div1_inner.latch.output_true.node(), 0));

  Output add0 = ops::Add(root.WithOpName("add0"), div0_inner.induction_var,
                         div1_inner.induction_var);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    // Because of the capture in div1_inner's recurrence the two operands of
    // add0 no longer have provably equal predicates.
    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }
}
908
TEST(DeadnessAnalysisTest, CyclicRecurrence) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
  DependentInductionVar div0 =
      CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0);
  DependentInductionVar div1 =
      CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0);
  FixupSourceAndSinkEdges(root.graph());
  // Cross-wire the two loop-invariant values so each one's back edge feeds
  // the other, forming a cyclic recurrence div0 <-> div1.
  TF_ASSERT_OK(root.graph()->UpdateEdge(div1.induction_var.node(), 0,
                                        div0.latch.output_true.node(), 0));
  TF_ASSERT_OK(root.graph()->UpdateEdge(div0.induction_var.node(), 0,
                                        div1.latch.output_true.node(), 0));

  VLogGraphIfAsked(*root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    // The optimistic pass can prove all three merges share the loop's
    // predicate despite the cycle.
    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");

    // This tests the rule {S,&,X} & ~X => S.
    TensorId switch_false_out = {div1.latch.output_false.node()->name(),
                                 div1.latch.output_false.index()};
    EXPECT_EQ(predicate_map[switch_false_out], "(#true)");
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    // The pessimistic pass cannot resolve the cycle, so div0/div1 keep
    // opaque symbolic predicates named after their merge nodes.
    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], "div0/iv:0");
    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], "div1/iv:0");
  }
}
952
TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);
  InductionVarInfo iv_1 = CreateInductionVariable(root, "iv_1", "frame_1", 9);

  // The same init/step values are threaded into two different frames; the
  // resulting recurrences differ only by frame name.
  Output init = CreateSwitch(root, "init").output_true;
  Output step = CreateSwitch(root, "step").output_true;

  std::array<Output, 2> exits;
  std::array<Output, 2> next_iterations;

  for (int i : {0, 1}) {
    Output init_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("init_enter_frame_", i)), init,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));
    Output step_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("step_enter_frame_", i)), step,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));

    // The merge's second input is a placeholder; it is rewired below to the
    // NextIteration node to close the loop.
    ops::Merge iv(root.WithOpName(absl::StrCat("expr_", i)),
                  {init_enter, init_enter});
    Output add = ops::Add(root.WithOpName(absl::StrCat("add_", i)), iv.output,
                          step_enter);
    next_iterations[i] = ops::NextIteration(
        root.WithOpName(absl::StrCat("expr_", i, "_next_iteration")), add);
    EXPECT_TRUE(
        root.graph()
            ->UpdateEdge(next_iterations[i].node(), 0, iv.output.node(), 1)
            .ok());
    exits[i] = ops::internal::Exit(root.WithOpName(absl::StrCat("exit_", i)),
                                   iv.output);
  }

  FixupSourceAndSinkEdges(root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // Even though both recurrences are built from the same symbols, their
    // frame names differ, so their predicates must not collapse into one.
    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])],
              predicate_map[ControlOutputFor(exits[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(exits[1])], "");

    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])],
              predicate_map[ControlOutputFor(next_iterations[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[1])], "");
  }
}
1005
TEST(DeadnessAnalysisTest, ControlInputs) {
  // Two constants are control-gated on opposite branches of the same Switch.
  // An Add consuming both must therefore report mismatching input deadness:
  // deadness has to propagate through control edges, not just data edges.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch branch = CreateSwitch(root, "0");

  Output false_side = ops::Identity(root.WithOpName("id0"), branch.output_false);
  Output true_side = ops::Identity(root.WithOpName("id1"), branch.output_true);

  Output lhs = ops::Const(root.WithOpName("const0"), 1);
  Output rhs = ops::Const(root.WithOpName("const1"), 2);

  Output sum = ops::Add(root.WithOpName("add"), lhs, rhs);

  // Gate each constant on one side of the switch via a control edge.
  root.graph()->AddControlEdge(false_side.node(), lhs.node());
  root.graph()->AddControlEdge(true_side.node(), rhs.node());

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatching_deadness,
      HasInputsWithMismatchingDeadness(*analysis, *sum.node()));
  EXPECT_TRUE(mismatching_deadness);
}
1029
TEST(DeadnessAnalysisTest, ControlTrigger) {
  // Like ControlInputs, but each switch branch is routed through a
  // ControlTrigger before gating the constants.  ControlTrigger fires
  // regardless of the deadness of its inputs, so the Add's inputs must NOT
  // be reported as mismatching.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch branch = CreateSwitch(root, "0");

  Output false_side = ops::Identity(root.WithOpName("id0"), branch.output_false);
  Output true_side = ops::Identity(root.WithOpName("id1"), branch.output_true);

  ops::ControlTrigger trigger_for_false(root.WithOpName("ctrl_trigger0"));
  ops::ControlTrigger trigger_for_true(root.WithOpName("ctrl_trigger1"));

  Output lhs = ops::Const(root.WithOpName("const0"), 1);
  Output rhs = ops::Const(root.WithOpName("const1"), 2);

  Output sum = ops::Add(root.WithOpName("add"), lhs, rhs);

  // Chain: switch branch -> control trigger -> constant.
  root.graph()->AddControlEdge(false_side.node(),
                               trigger_for_false.operation.node());
  root.graph()->AddControlEdge(trigger_for_false.operation.node(), lhs.node());

  root.graph()->AddControlEdge(true_side.node(),
                               trigger_for_true.operation.node());
  root.graph()->AddControlEdge(trigger_for_true.operation.node(), rhs.node());

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatching_deadness,
      HasInputsWithMismatchingDeadness(*analysis, *sum.node()));
  EXPECT_FALSE(mismatching_deadness);
}
1059
TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
  // A Merge is live if any of its inputs is live, so merges that are only
  // control-gated on opposite switch branches still both produce live
  // outputs; the downstream Add must not see mismatching deadness.
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);

  Output constant = ops::Const(root.WithOpName("constant"), 5);
  ops::Merge m0(root.WithOpName("m0"), {constant});
  // Fixed: this op was accidentally also named "m0" (copy-paste), silently
  // relying on the scope's automatic name uniquification.
  ops::Merge m1(root.WithOpName("m1"), {constant});
  Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output);

  // Gate each merge on one side of the switch via a control edge.
  root.graph()->AddControlEdge(id0.node(), m0.output.node());
  root.graph()->AddControlEdge(id1.node(), m1.output.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
1083
TEST(DeadnessAnalysisTest, RecvVsSwitch) {
  // Demonstrates why we need the must_be_true bit on SymbolP: the raw recv
  // output and the switch's true output are gated on the same tensor, but
  // only the latter additionally requires the value to be true.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
                           0, "receiver");
  Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch sw(root.WithOpName("switch"), value, recv);
  Output logical_and =
      ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatching_deadness,
      HasInputsWithMismatchingDeadness(*analysis, *logical_and.node()));
  EXPECT_TRUE(mismatching_deadness);
}
1103
TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
  // Demonstrates why we need the must_be_true bit on SymbolP, this time by
  // checking the textual form of the combined predicate: "recv:0" (alive)
  // conjoined with "*recv:0" (alive AND true) must stay distinct symbols.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
                           0, "receiver");
  Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch sw(root.WithOpName("switch"), value, recv);
  Output logical_and =
      ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  // Inspect the control output of the LogicalAnd node.
  TensorId and_control_output = {logical_and.node()->name(),
                                 Graph::kControlSlot};
  EXPECT_EQ(predicate_map[and_control_output], "(recv:0 & *recv:0)");
}
1125
TEST(DeadnessAnalysisTest, DeMorgan) {
  // Checks that the predicate simplifier applies De Morgan's laws: the merge
  // of the two false branches is recognized as the negation of the
  // conjunction of the two true branches.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_0 = ops::Placeholder(root.WithOpName("cond_0"), DT_BOOL);
  Output cond_1 = ops::Placeholder(root.WithOpName("cond_1"), DT_BOOL);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch sw_0(root.WithOpName("switch_0"), value, cond_0);
  ops::Switch sw_1(root.WithOpName("switch_1"), value, cond_1);

  // Alive iff A & B.
  Output and_0_1 =
      ops::Add(root.WithOpName("and_0_1"), sw_0.output_true, sw_1.output_true);

  // Alive iff ~A | ~B (merge is live if either input is live).
  Output or_not0_not1 = ops::Merge(root.WithOpName("or_not0_not1"),
                                   {sw_0.output_false, sw_1.output_false})
                            .output;

  // Predicate(should_always_be_dead) =
  //  (A & B) & (~A | ~B) = (A & B) & ~(A & B) = False
  Output should_always_be_dead =
      ops::Add(root.WithOpName("should_always_be_dead"), and_0_1, or_not0_not1);

  // Predicate(should_always_be_alive) =
  //  (A & B) | (~A | ~B) = (A & B) | ~(A & B) = True
  Output should_always_be_alive =
      ops::Merge(root.WithOpName("should_always_be_alive"),
                 {and_0_1, or_not0_not1})
          .output;

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_dead)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_alive)], "#true");
}
1164
TEST(DeadnessAnalysisTest, ConstantTrueSwitchCondition) {
  // A Switch whose predicate is the constant `true` has a statically-known
  // outcome: its false branch is always dead, its true branch always alive.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output always_true = ops::Const(root.WithOpName("const_true"), true);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch gated(root.WithOpName("switch"), data, always_true);

  Output on_false = ops::Identity(root.WithOpName("id_false"), gated.output_false);
  Output on_true = ops::Identity(root.WithOpName("id_true"), gated.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(on_false)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(on_true)], "#true");
}
1183
TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) {
  // Mirror of ConstantTrueSwitchCondition: a constant `false` predicate
  // makes the false branch always alive and the true branch always dead.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output always_false = ops::Const(root.WithOpName("const_false"), false);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch gated(root.WithOpName("switch"), data, always_false);

  Output on_false = ops::Identity(root.WithOpName("id_false"), gated.output_false);
  Output on_true = ops::Identity(root.WithOpName("id_true"), gated.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(on_false)], "#true");
  EXPECT_EQ(predicate_map[ControlOutputFor(on_true)], "#false");
}
1202
TEST(DeadnessAnalysisTest, RefBoolSwitchCondition) {
  // A ref-typed boolean variable as switch predicate cannot be folded, so
  // each branch gets a symbolic predicate on the variable's value.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_variable =
      ops::Variable(root.WithOpName("cond_ref"), TensorShape({}), DT_BOOL);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch gated(root.WithOpName("switch"), data, cond_variable);

  Output on_false = ops::Identity(root.WithOpName("id_false"), gated.output_false);
  Output on_true = ops::Identity(root.WithOpName("id_true"), gated.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  // "*cond_ref:0" denotes "cond_ref is alive and true".
  EXPECT_EQ(predicate_map[ControlOutputFor(on_false)], "~*cond_ref:0");
  EXPECT_EQ(predicate_map[ControlOutputFor(on_true)], "*cond_ref:0");
}
1222
CreateSwitchN(const Scope & scope,Input data,Input output_index,int64 num_outs,OutputList * outputs)1223 void CreateSwitchN(const Scope& scope, Input data, Input output_index,
1224 int64 num_outs, OutputList* outputs) {
1225 if (!scope.ok()) return;
1226 auto _data = ops::AsNodeOut(scope, data);
1227 if (!scope.ok()) return;
1228 auto _output_index = ops::AsNodeOut(scope, output_index);
1229 if (!scope.ok()) return;
1230 Node* ret;
1231 const auto unique_name = scope.GetUniqueNameForOp("_SwitchN");
1232 auto builder = NodeBuilder(unique_name, "_SwitchN")
1233 .Input(_data)
1234 .Input(_output_index)
1235 .Attr("num_outs", num_outs);
1236 scope.UpdateBuilder(&builder);
1237 scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
1238 if (!scope.ok()) return;
1239 scope.UpdateStatus(scope.DoShapeInference(ret));
1240 for (int32 i = 0; i < ret->num_outputs(); ++i) {
1241 outputs->push_back(Output(ret, i));
1242 }
1243 }
1244
TEST(DeadnessAnalysisTest, Constant1_SwitchN_2Branches_DoesNotFail) {
  // A 2-output _SwitchN with the constant index 1 statically selects branch
  // 1; branch 0 is provably dead.  Also guards against crashes on _SwitchN.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index = ops::Const(root.WithOpName("const_1"), 1);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), data, branch_index, 2, &branches);

  Output branch0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output branch1 = ops::Identity(root.WithOpName("id_1"), branches[1]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(branch0)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(branch1)], "#true");
}
1264
TEST(DeadnessAnalysisTest, Constant7_SwitchN_3Branches) {
  // An out-of-range constant index (7 with only 3 branches) selects the last
  // branch; all earlier branches are statically dead.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index = ops::Const(root.WithOpName("const_7"), 7);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), data, branch_index, 3, &branches);

  Output branch0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output branch1 = ops::Identity(root.WithOpName("id_1"), branches[1]);
  Output branch2 = ops::Identity(root.WithOpName("id_2"), branches[2]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(branch0)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(branch1)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(branch2)], "#true");
}
1286
TEST(DeadnessAnalysisTest, RefInt_SwitchN_3Branches) {
  // With a non-constant branch index, each _SwitchN output gets a symbolic
  // equality predicate on the index value; the last branch is the catch-all
  // (negation of all earlier comparisons).
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index =
      ops::Variable(root.WithOpName("bidx"), TensorShape({}), DT_INT32);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), data, branch_index, 3,
                &branches);

  Output branch0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output branch1 = ops::Identity(root.WithOpName("id_1"), branches[1]);
  Output branch2 = ops::Identity(root.WithOpName("id_2"), branches[2]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(branch0)], "bidx:0=0");
  EXPECT_EQ(predicate_map[ControlOutputFor(branch1)], "(~bidx:0=0 & bidx:0=1)");
  EXPECT_EQ(predicate_map[ControlOutputFor(branch2)], "(~bidx:0=0 & ~bidx:0=1)");
}
1310
1311 } // namespace
1312 } // namespace tensorflow
1313