1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/deadness_analysis.h"
17 
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/sendrecv_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
25 #include "tensorflow/compiler/jit/defs.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/framework/node_def_util.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/graph/algorithm.h"
32 #include "tensorflow/core/graph/graph_def_builder.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace tensorflow {
37 namespace {
38 
HasInputsWithMismatchingDeadness(const DeadnessAnalysis & deadness_analysis,const Node & n)39 se::port::StatusOr<bool> HasInputsWithMismatchingDeadness(
40     const DeadnessAnalysis& deadness_analysis, const Node& n) {
41   absl::optional<DeadnessAnalysis::DeadnessPredicate> pred;
42   for (const Edge* edge : n.in_edges()) {
43     TF_ASSIGN_OR_RETURN(
44         DeadnessAnalysis::DeadnessPredicate this_pred,
45         deadness_analysis.GetPredicateFor(edge->src(), edge->src_output()));
46     if (pred && *pred != this_pred) {
47       return true;
48     }
49     pred = this_pred;
50   }
51 
52   return false;
53 }
54 
55 using deadness_analysis_internal::ComputePredicates;
56 using deadness_analysis_internal::PredicateMapTy;
57 
// Runs deadness analysis over `graph`, storing the result in `*result`.
// Source/sink edges are fixed up first because the analysis walks the full
// graph starting from the source node.
Status AnalyzeDeadness(Graph* graph,
                       std::unique_ptr<DeadnessAnalysis>* result) {
  FixupSourceAndSinkEdges(graph);
  return DeadnessAnalysis::Run(*graph, result);
}
63 
CreateSwitch(const Scope & root,const string & prefix)64 ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
65   Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
66   Output predicate =
67       ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
68   return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
69 }
70 
ControlOutputFor(const Output & o)71 TensorId ControlOutputFor(const Output& o) {
72   return {o.node()->name(), Graph::kControlSlot};
73 }
74 
VLogGraphIfAsked(const Graph & graph)75 void VLogGraphIfAsked(const Graph& graph) {
76   if (VLOG_IS_ON(3)) {
77     GraphDef graph_def;
78     graph.ToGraphDef(&graph_def);
79     string serialized;
80     ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized);
81     LOG(INFO) << serialized;
82   }
83 }
84 
// Handles to the interesting outputs of a loop built by
// CreateInductionVariable.
struct InductionVarInfo {
  // Output of the loop's Merge node, i.e. the induction variable's
  // per-iteration value.
  Output induction_var;
  // The boolean loop-condition output feeding the latch Switch.
  Output loop_cond;
};
89 
90 // Creates an induction variable with the following structure (simplified for
91 // brevity):
92 //
93 //            +---------------+
94 //            | initial_value |
95 //            +---------------+
96 //              |
97 //              |
98 //              v
99 //            +---------------+
100 //            |     Enter     |
101 //            +---------------+
102 //              |
103 //              |
104 //              v
105 //            +---------------+
106 //         +> |     Merge     | -+
107 //         |  +---------------+  |
108 //         |    |                |
109 //         |    |                |
110 //         |    v                |
111 //         |  +---------------+  |
112 //         |  |  LessThan10   |  |
113 //         |  +---------------+  |
114 //         |    |                |
115 //         |    |                |
116 //         |    v                |
117 //         |  +---------------+  |
118 //    +----+- |    Switch     | <+
119 //    |    |  +---------------+
120 //    |    |    |
121 //    |    |    |
122 //    |    |    v
123 //    |    |  +---------------+
124 //    |    +- |    AddOne     |
125 //    |       +---------------+
126 //    |       +---------------+
127 //    +-----> |     Exit      |
128 //            +---------------+
CreateInductionVariable(const Scope & root,const string & prefix,const string & frame_name,const Output & initial_value)129 InductionVarInfo CreateInductionVariable(const Scope& root,
130                                          const string& prefix,
131                                          const string& frame_name,
132                                          const Output& initial_value) {
133   Output enter_initial_value = ops::internal::Enter(
134       root.WithOpName(prefix + "/enter"), initial_value, frame_name);
135 
136   ops::Merge iv(root.WithOpName(prefix + "/iv"),
137                 {enter_initial_value, enter_initial_value});
138   Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
139   Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
140   Output loop_cond_expr =
141       ops::Less(root.WithOpName(prefix + "/cond"), iv.output, final_value);
142   ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output,
143                     loop_cond_expr);
144   ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
145                            latch.output_false);
146   Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"),
147                             latch.output_true, increment_by);
148   Output next_iteration =
149       ops::NextIteration(root.WithOpName(prefix + "/next_iteration"), iv_next);
150 
151   CHECK(root.graph()
152             ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
153             .ok());
154   root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
155   root.graph()->AddControlEdge(iv.output.node(), final_value.node());
156 
157   return {iv.output, loop_cond_expr};
158 }
159 
CreateInductionVariable(const Scope & root,const string & prefix,const string & frame_name,int32 init)160 InductionVarInfo CreateInductionVariable(const Scope& root,
161                                          const string& prefix,
162                                          const string& frame_name, int32 init) {
163   return CreateInductionVariable(
164       root, prefix, frame_name,
165       ops::Const(root.WithOpName(prefix + "/init"), init));
166 }
167 
168 // Creates an induction variable with the following structure:
169 //
170 //                           +---------------+
171 //                           | initial_value |
172 //                           +---------------+
173 //                             |
174 //                             |
175 //                             v
176 //                           +---------------+
177 //                           |     Enter     |
178 //                           +---------------+
179 //                             |
180 //                             |
181 //                             v
182 //                           +---------------+
183 //                           |     Merge     | <+
184 //                           +---------------+  |
185 //                             |                |
186 //                             |                |
187 //                             v                |
188 //         +-----------+     +---------------+  |
189 //         | loop_cond | --> |    Switch     | -+
190 //         +-----------+     +---------------+
191 //                             |
192 //                             |
193 //                             v
194 //                           +---------------+
195 //                           |     Exit      |
196 //                           +---------------+
// Handles to the interesting pieces of a loop-invariant value built by
// CreateDependentLoopInvariantValue.
struct DependentInductionVar {
  // Output of the loop's Merge node carrying the value inside the loop.
  Output induction_var;
  // The Switch gated by the caller-supplied loop condition.
  ops::Switch latch;
};
201 
CreateDependentLoopInvariantValue(const Scope & root,const string & prefix,const string & frame_name,const Output & loop_cond,const Output & value)202 DependentInductionVar CreateDependentLoopInvariantValue(
203     const Scope& root, const string& prefix, const string& frame_name,
204     const Output& loop_cond, const Output& value) {
205   Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"),
206                                             value, frame_name);
207   ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value});
208   ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
209   ops::internal::Exit exit(root.WithOpName(prefix + "/exit"),
210                            latch.output_false);
211   Output next_iteration = ops::NextIteration(
212       root.WithOpName(prefix + "/next_iteration"), latch.output_true);
213   CHECK(root.graph()
214             ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)
215             .ok());
216   return {iv.output, latch};
217 }
218 
CreateDependentLoopInvariantValue(const Scope & root,const string & prefix,const string & frame_name,const Output & loop_cond,int32 value)219 DependentInductionVar CreateDependentLoopInvariantValue(
220     const Scope& root, const string& prefix, const string& frame_name,
221     const Output& loop_cond, int32 value) {
222   return CreateDependentLoopInvariantValue(
223       root, prefix, frame_name, loop_cond,
224       ops::Const(root.WithOpName(prefix + "/init"), value));
225 }
226 
// The two outputs of a single Switch carry opposite deadness, so an Add fed
// by both must report mismatching input deadness.
TEST(DeadnessAnalysisTest, BasicPositive) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw = CreateSwitch(root, "0");
  Output add =
      ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_TRUE(mismatch);
}
242 
// Two plain placeholders feed an Add; no Switch is involved, so no deadness
// mismatch is expected.
TEST(DeadnessAnalysisTest, BasicNegative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), a, b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_FALSE(mismatch);
}
258 
// Deadness conjunction is commutative: swapping the operand order of an Add
// must not change whether its inputs' deadness predicates agree.
TEST(DeadnessAnalysisTest, AndIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  Output a0 =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output a1 =
      ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false);

  Output b0 =
      ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true);
  Output b1 =
      ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false);

  Output live0 = ops::Add(root.WithOpName("live0"), a0, a1);
  Output live1 = ops::Add(root.WithOpName("live1"), b0, b1);

  Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0);
  Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  // live0/live1 combine operand-swapped but equal conjunctions; halfdead0
  // and halfdead1 mix predicates that differ on sw_1's branch.
  for (Node* expect_match : {live0.node(), live1.node()}) {
    TF_ASSERT_OK_AND_ASSIGN(
        bool mismatch,
        HasInputsWithMismatchingDeadness(*analysis, *expect_match));
    EXPECT_FALSE(mismatch);
  }
  for (Node* expect_mismatch : {halfdead0.node(), halfdead1.node()}) {
    TF_ASSERT_OK_AND_ASSIGN(
        bool mismatch,
        HasInputsWithMismatchingDeadness(*analysis, *expect_mismatch));
    EXPECT_TRUE(mismatch);
  }
}
306 
// Deadness conjunction is associative: (s0 & s1) & s2 equals s0 & (s1 & s2).
TEST(DeadnessAnalysisTest, AndIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // a1 = (s0 & s1) & s2, associated to the left.
  Output a0 =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output a1 = ops::Add(root.WithOpName("a1"), a0, sw_2.output_false);

  // b1 = s0 & (s1 & s2), associated to the right.
  Output b0 =
      ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false);
  Output b1 = ops::Add(root.WithOpName("b1"), sw_0.output_false, b0);

  Output add = ops::Add(root.WithOpName("add"), a1, b1);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_FALSE(mismatch);
}
332 
// Deadness disjunction (Merge) is commutative: merging the same inputs in
// either order yields equal predicates.
TEST(DeadnessAnalysisTest, OrIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false});
  ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true});
  ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false});

  Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output);
  Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output);

  Output halfdead0 =
      ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output);
  Output halfdead1 =
      ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  // live0/live1 combine operand-swapped but equal disjunctions; halfdead0
  // and halfdead1 mix disjunctions that differ on sw_1's branch.
  for (Node* expect_match : {live0.node(), live1.node()}) {
    TF_ASSERT_OK_AND_ASSIGN(
        bool mismatch,
        HasInputsWithMismatchingDeadness(*analysis, *expect_match));
    EXPECT_FALSE(mismatch);
  }
  for (Node* expect_mismatch : {halfdead0.node(), halfdead1.node()}) {
    TF_ASSERT_OK_AND_ASSIGN(
        bool mismatch,
        HasInputsWithMismatchingDeadness(*analysis, *expect_mismatch));
    EXPECT_TRUE(mismatch);
  }
}
377 
// Deadness disjunction is associative: (s0 | s1) | s2 equals s0 | (s1 | s2).
TEST(DeadnessAnalysisTest, OrIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // m1 = (s0 | s1) | s2, associated to the left.
  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {m0.output, sw_2.output_false});
  // m3 = s0 | (s1 | s2), associated to the right.
  ops::Merge m2(root.WithOpName("m2"), {sw_1.output_false, sw_2.output_false});
  ops::Merge m3(root.WithOpName("m3"), {sw_0.output_false, m2.output});

  Output add = ops::Add(root.WithOpName("add"), m1.output, m3.output);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_FALSE(mismatch);
}
400 
// (s0|s1) & (s2|s3) built twice from the same merges must yield identical
// predicates on both copies.
TEST(DeadnessAnalysisTest, AndOfOr) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {sw_2.output_false, sw_3.output_false});

  Output add0 = ops::Add(root.WithOpName("add0"), m0.output, m1.output);
  Output add1 = ops::Add(root.WithOpName("add1"), m0.output, m1.output);

  Output add2 = ops::Add(root.WithOpName("add2"), add0, add1);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add2.node()));
  EXPECT_FALSE(mismatch);
}
425 
// (s0&s1) | (s2&s3) built twice from the same conjunctions must yield
// identical predicates on both copies.
TEST(DeadnessAnalysisTest, OrOfAnd) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  Output add0 =
      ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false);
  Output add1 =
      ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false);

  ops::Merge m0(root.WithOpName("m0"), {add0, add1});
  ops::Merge m1(root.WithOpName("m1"), {add0, add1});

  Output add2 = ops::Add(root.WithOpName("add2"), m0.output, m1.output);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add2.node()));
  EXPECT_FALSE(mismatch);
}
452 
// Exercises the predicate simplifier on a tautology:
// (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) must fold to the constant #true.
TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "A");
  ops::Switch sw_1 = CreateSwitch(root, "B");
  Output and0 =
      ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true);
  Output and1 =
      ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false);
  ops::Merge or2(root.WithOpName("or2"), {and0, and1});
  Output and3 =
      ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false);
  ops::Merge or4(root.WithOpName("or4"), {and3, sw_0.output_true});

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
  EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true");
}
475 
// Distributivity: (A|B)&C must equal (A&C)|(B&C).
TEST(DeadnessAnalysisTest, AndOrDistributive) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  // Left-hand side: (s0 | s1) & s2.
  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  Output add0 = ops::Add(root.WithOpName("add0"), m0.output, sw_2.output_false);

  // Right-hand side: (s0 & s2) | (s1 & s2).
  Output add1 =
      ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false);
  Output add2 =
      ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false);
  ops::Merge m1(root.WithOpName("m1"), {add1, add2});

  Output add3 = ops::Add(root.WithOpName("add3"), add0, m1.output);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add3.node()));
  EXPECT_FALSE(mismatch);
}
503 
// A classic ternary: `predicate ? true_value : false_value` expressed as two
// Switches feeding one Merge.  Exactly one Merge input is live for any value
// of the predicate, so the downstream Add sees no deadness mismatch.
TEST(DeadnessAnalysisTest, Ternary) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL);
  Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT);
  Output false_value =
      ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT);

  ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value,
                              predicate);

  // Fixed: the false arm now routes `false_value` (it previously fed
  // `true_value` through both Switches, leaving `false_value` unused).
  ops::Switch predicated_false(root.WithOpName("predicated_false"), false_value,
                               predicate);
  ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true,
                                                predicated_false.output_false});
  Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), merge.output, addend);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
530 
// Two distinct _Recv nodes feeding one Add: the analysis reports mismatching
// input deadness, i.e. it does not treat different Recvs as equivalent.
TEST(DeadnessAnalysisTest, Recv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b",
                             "sender", 0, "receiver");
  Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_TRUE(mismatch);
}
548 
// Same as the Recv test but with _HostRecv nodes: two distinct host receives
// feeding one Add report mismatching input deadness.
TEST(DeadnessAnalysisTest, HostRecv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT,
                                 "tensor_a", "sender", 0, "receiver");
  Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT,
                                 "tensor_b", "sender", 0, "receiver");
  Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *add.node()));
  EXPECT_TRUE(mismatch);
}
566 
// Three loops in the same frame "fr0".  iv0 and iv1 are structurally
// identical (same initial value); iv2 starts from a different constant.  The
// expected predicate strings below are coupled to the node names created by
// CreateInductionVariable.
TEST(DeadnessAnalysisTest, Loop) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0).induction_var;
  Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0).induction_var;
  Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1).induction_var;
  Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
  Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);

  // NB!  iv0 and iv1 are equivalent and a smarter deadness analysis would have
  // noticed that.  Today we are pessimistic here because we assign an
  // uninterpreted symbol to merges with backedges.

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    bool has_inputs_with_mismatching_deadness;

    // Both adds mix induction variables from different loops, so both are
    // reported as mismatching (even add0, per the NB above).
    TF_ASSERT_OK_AND_ASSIGN(
        has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);

    TF_ASSERT_OK_AND_ASSIGN(
        has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add1.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0
    // produce the same deadness.  But we're not that smart today.
    EXPECT_EQ(predicate_map[ControlOutputFor(iv0)],
              "{#true,&,*iv0/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv1)],
              "{#true,&,*iv1/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv2)],
              "{#true,&,*iv2/cond:0}<fr0>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "({#true,&,*iv0/cond:0}<fr0> & {#true,&,*iv1/cond:0}<fr0>)");
    EXPECT_EQ(predicate_map[ControlOutputFor(add1)],
              "({#true,&,*iv1/cond:0}<fr0> & {#true,&,*iv2/cond:0}<fr0>)");
  }
}
615 
// Two loop-invariant values gated by the same loop condition as the
// induction variable.  With optimistic mode enabled the analysis proves all
// of them control-equivalent; pessimistically their predicates pick up the
// uninterpreted backedge symbol iv0/iv:0.
TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
  Output dependent_iv0 =
      CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0)
          .induction_var;
  Output dependent_iv1 =
      CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0)
          .induction_var;
  Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_FALSE(has_inputs_with_mismatching_deadness);
  }
  {
    // Optimistic mode: all values share the induction variable's predicate.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(iv.induction_var)]);
  }
  {
    // Pessimistic mode: the dependent values' predicates mention the
    // uninterpreted symbol iv0/iv:0 assigned to the backedge Merge.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
  }
}
667 
// Create a merge that "looks like" a loop but isn't really.  It has a value
// that does not depend on the merge on its backedge.
TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0);
  DependentInductionVar dependent_iv =
      CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
  FixupSourceAndSinkEdges(root.graph());

  // Rewire div0's backedge so it carries iv0's induction variable instead of
  // div0's own latch output, making the "backedge" loop-invariant.
  TF_ASSERT_OK(root.graph()->UpdateEdge(
      iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));

  VLogGraphIfAsked(*root.graph());

  {
    // Optimistic mode resolves the merge to the induction variable's
    // predicate.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
              "{#true,&,*iv0/cond:0}<frame>");
  }
  {
    // Pessimistic mode falls back to the uninterpreted symbol for the merge.
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
              "div0/iv:0");
  }
}
699 
// Two nested while loops where the inner loop is gated on the outer loop's
// condition, making the loop bodies control-equivalent.  Loop-invariant
// values threaded through both loops must receive identical deadness
// predicates, so the final Add sees no mismatch.
TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_outer =
      CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
  // A loop-invariant constant entering the outer frame.
  Output enter_constant_outer_loop = ops::internal::Enter(
      root.WithOpName("constant_enter_outer_loop"),
      ops::Const(root.WithOpName("constant"), 5), "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  // Gate the inner loop's input on the outer loop condition so the inner
  // loop only runs while the outer loop is live.
  ops::Switch inner_value(root.WithOpName("outer_is_live"),
                          enter_constant_outer_loop, iv_outer.loop_cond);
  InductionVarInfo iv_inner = CreateInductionVariable(
      root, "iv_inner", "inner_loop", inner_value.output_true);

  // Two loop-invariant values carried through the outer loop...
  Output dependent_outer_iv0 =
      CreateDependentLoopInvariantValue(root, "dependent_outer_iv0",
                                        "outer_loop", iv_outer.loop_cond, 0)
          .induction_var;
  Output dependent_outer_iv1 =
      CreateDependentLoopInvariantValue(root, "dependent_outer_iv1",
                                        "outer_loop", iv_outer.loop_cond, 0)
          .induction_var;

  // ...each then threaded through the inner loop as well.
  Output dependent_inner_iv0 = CreateDependentLoopInvariantValue(
                                   root, "dependent_inner_iv0", "inner_loop",
                                   iv_inner.loop_cond, dependent_outer_iv0)
                                   .induction_var;
  Output dependent_inner_iv1 = CreateDependentLoopInvariantValue(
                                   root, "dependent_inner_iv1", "inner_loop",
                                   iv_inner.loop_cond, dependent_outer_iv1)
                                   .induction_var;

  Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0,
                         dependent_inner_iv1);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    // Both of add0's inputs carry the same predicate, so no mismatch.
    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_FALSE(has_inputs_with_mismatching_deadness);
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");

    // enable_optimistic = true or not should produce the same results because
    // of fallback.  However, note that the order of iv_inner/cond:0 and
    // iv_inner/iv:0 is different because the optimistic approach does not
    // create predicates for all merges and it can change the predicate id and
    // hence the symbol order.
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
              "iv_inner/iv:0)}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");

    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
              "{{#true,&,(iv_outer/iv:0 & "
              "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
              "*iv_inner/cond:0)}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              predicate_map[ControlOutputFor(dependent_inner_iv0)]);
  }
}
793 
// Two structurally identical but *independent* nested loop pairs: each
// iteration of the `for` below builds its own outer/inner loop with its own
// Switch predicate.  The two inner induction variables therefore get
// different predicates, and adding them must be flagged as a mismatch.
TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
  Scope root = Scope::NewRootScope().ExitOnError();

  std::array<Output, 2> outer_iv;
  std::array<Output, 2> inner_iv;

  for (int i : {0, 1}) {
    InductionVarInfo iv_outer =
        CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
    Output enter_constant_outer_loop = ops::internal::Enter(
        root.WithOpName("constant_enter_outer_loop"),
        ops::Const(root.WithOpName("constant"), 5), "outer_loop",
        ops::internal::Enter::Attrs().IsConstant(true));
    ops::Switch inner_value(root.WithOpName("outer_is_live"),
                            enter_constant_outer_loop, iv_outer.loop_cond);
    InductionVarInfo iv_inner = CreateInductionVariable(
        root, "iv_inner", "inner_loop", inner_value.output_true);

    outer_iv[i] = iv_outer.induction_var;
    inner_iv[i] = iv_inner.induction_var;
  }

  Output add0 = ops::Add(root.WithOpName("add0"), inner_iv[0], inner_iv[1]);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // The second copy's ops get "_1"-suffixed names (e.g. iv_outer/cond_1),
    // which appear as distinct symbols in the predicates below.
    EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])],
              "{#true,&,*iv_outer/cond:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])],
              "{(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])],
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[1])],
              "{(*iv_outer/cond_1:0 & "
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
              "cond_1:0}<inner_loop;outer_loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
              "({(*iv_outer/cond:0 & "
              "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
              "cond:0}<inner_loop;outer_loop> & {(*iv_outer/cond_1:0 & "
              "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
              "cond_1:0}<inner_loop;outer_loop>)");
  }
}
854 
// Nested loops with two loop-invariant values, where one of them (div1) is
// rewired so that its backedge multiplies in a value captured from outside
// the loops.  After that rewrite div1 is no longer equivalent to div0, so
// adding them must be reported as a deadness mismatch.
TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_outer =
      CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
  Output enter_constant_outer_loop = ops::internal::Enter(
      root.WithOpName("constant_enter_outer_loop"),
      ops::Const(root.WithOpName("constant"), 5), "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  ops::Switch inner_value(root.WithOpName("outer_is_live"),
                          enter_constant_outer_loop, iv_outer.loop_cond);
  InductionVarInfo iv_inner = CreateInductionVariable(
      root, "iv_inner", "inner_loop", inner_value.output_true);

  DependentInductionVar div0_outer = CreateDependentLoopInvariantValue(
      root, "div0_outer", "outer_loop", iv_outer.loop_cond, 0);
  DependentInductionVar div1_outer = CreateDependentLoopInvariantValue(
      root, "div1_outer", "outer_loop", iv_outer.loop_cond, 0);

  DependentInductionVar div0_inner = CreateDependentLoopInvariantValue(
      root, "div0_inner", "inner_loop", iv_inner.loop_cond,
      div0_outer.induction_var);
  DependentInductionVar div1_inner = CreateDependentLoopInvariantValue(
      root, "div1_inner", "inner_loop", iv_inner.loop_cond,
      div1_outer.induction_var);

  // A value captured from outside the loop nest, entered into both frames
  // as a loop-invariant constant.
  Output captured = ops::_Recv(root.WithOpName("captured"), DT_INT32,
                               "tensor_a", "sender", 0, "receiver");
  Output capture_enter_outer = ops::internal::Enter(
      root.WithOpName("capture_enter_outer"), captured, "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  Output capture_enter_inner = ops::internal::Enter(
      root.WithOpName("capture_enter_inner"), capture_enter_outer, "inner_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  Output mul0 = ops::Mul(root.WithOpName("mul0"), div1_inner.induction_var,
                         capture_enter_inner);
  // Reroute div1_inner's latch input through mul0 so the capture feeds its
  // backedge.
  TF_ASSERT_OK(root.graph()->UpdateEdge(
      mul0.node(), 0, div1_inner.latch.output_true.node(), 0));

  Output add0 = ops::Add(root.WithOpName("add0"), div0_inner.induction_var,
                         div1_inner.induction_var);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }
}
908 
// Two dependent loop-invariant values whose backedges feed *each other*,
// forming a cycle between the two merges.  The optimistic analysis proves
// both equal to the loop's own recurrence predicate; the pessimistic
// (non-optimistic) analysis falls back to opaque per-merge symbols.
TEST(DeadnessAnalysisTest, CyclicRecurrence) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
  DependentInductionVar div0 =
      CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0);
  DependentInductionVar div1 =
      CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0);
  FixupSourceAndSinkEdges(root.graph());
  // Cross-wire the latches: div1 feeds div0's backedge and vice versa.
  TF_ASSERT_OK(root.graph()->UpdateEdge(div1.induction_var.node(), 0,
                                        div0.latch.output_true.node(), 0));
  TF_ASSERT_OK(root.graph()->UpdateEdge(div0.induction_var.node(), 0,
                                        div1.latch.output_true.node(), 0));

  VLogGraphIfAsked(*root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");

    // This tests the rule {S,&,X} & ~X => S.
    TensorId switch_false_out = {div1.latch.output_false.node()->name(),
                                 div1.latch.output_false.index()};
    EXPECT_EQ(predicate_map[switch_false_out], "(#true)");
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*enable_optimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    // Without the optimistic pass the cyclic merges keep symbolic predicates.
    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], "div0/iv:0");
    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], "div1/iv:0");
  }
}
952 
// Builds the *same* and-recurrence (same init, same step) inside two
// different frames.  The frame name is part of the recurrence predicate, so
// the two exits / next-iterations must not be assigned equal predicates.
TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);
  InductionVarInfo iv_1 = CreateInductionVariable(root, "iv_1", "frame_1", 9);

  // Shared init and step values used by both frames.
  Output init = CreateSwitch(root, "init").output_true;
  Output step = CreateSwitch(root, "step").output_true;

  std::array<Output, 2> exits;
  std::array<Output, 2> next_iterations;

  for (int i : {0, 1}) {
    Output init_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("init_enter_frame_", i)), init,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));
    Output step_enter = ops::internal::Enter(
        root.WithOpName(absl::StrCat("step_enter_frame_", i)), step,
        absl::StrCat("frame_", i),
        ops::internal::Enter::Attrs().IsConstant(true));

    // The second merge input is a placeholder; it is replaced with the
    // NextIteration output via UpdateEdge below to close the loop.
    ops::Merge iv(root.WithOpName(absl::StrCat("expr_", i)),
                  {init_enter, init_enter});
    Output add = ops::Add(root.WithOpName(absl::StrCat("add_", i)), iv.output,
                          step_enter);
    next_iterations[i] = ops::NextIteration(
        root.WithOpName(absl::StrCat("expr_", i, "_next_iteration")), add);
    EXPECT_TRUE(
        root.graph()
            ->UpdateEdge(next_iterations[i].node(), 0, iv.output.node(), 1)
            .ok());
    exits[i] = ops::internal::Exit(root.WithOpName(absl::StrCat("exit_", i)),
                                   iv.output);
  }

  FixupSourceAndSinkEdges(root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

    // Same recurrence, different frames => different (non-empty) predicates.
    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])],
              predicate_map[ControlOutputFor(exits[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(exits[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(exits[1])], "");

    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])],
              predicate_map[ControlOutputFor(next_iterations[1])]);
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[0])], "");
    EXPECT_NE(predicate_map[ControlOutputFor(next_iterations[1])], "");
  }
}
1005 
// Constants that are control-dependent on opposite arms of a Switch inherit
// different deadness, so the Add consuming both must report a mismatch.
TEST(DeadnessAnalysisTest, ControlInputs) {
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch branch = CreateSwitch(root, "0");

  Output false_side =
      ops::Identity(root.WithOpName("id0"), branch.output_false);
  Output true_side = ops::Identity(root.WithOpName("id1"), branch.output_true);

  Output lhs = ops::Const(root.WithOpName("const0"), 1);
  Output rhs = ops::Const(root.WithOpName("const1"), 2);
  Output sum = ops::Add(root.WithOpName("add"), lhs, rhs);

  // Tie each constant to a different switch arm via control edges only.
  Graph* graph = root.graph();
  graph->AddControlEdge(false_side.node(), lhs.node());
  graph->AddControlEdge(true_side.node(), rhs.node());

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(graph, &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *sum.node()));
  EXPECT_TRUE(mismatch);
}
1029 
// Same shape as ControlInputs, but each control path is routed through a
// ControlTrigger.  The analysis must treat the constants downstream of the
// triggers as having matching deadness, so no mismatch is reported.
TEST(DeadnessAnalysisTest, ControlTrigger) {
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch branch = CreateSwitch(root, "0");

  Output false_side =
      ops::Identity(root.WithOpName("id0"), branch.output_false);
  Output true_side = ops::Identity(root.WithOpName("id1"), branch.output_true);

  ops::ControlTrigger trigger0(root.WithOpName("ctrl_trigger0"));
  ops::ControlTrigger trigger1(root.WithOpName("ctrl_trigger1"));

  Output lhs = ops::Const(root.WithOpName("const0"), 1);
  Output rhs = ops::Const(root.WithOpName("const1"), 2);
  Output sum = ops::Add(root.WithOpName("add"), lhs, rhs);

  // Insert a ControlTrigger between each switch arm and its constant.
  Graph* graph = root.graph();
  graph->AddControlEdge(false_side.node(), trigger0.operation.node());
  graph->AddControlEdge(trigger0.operation.node(), lhs.node());

  graph->AddControlEdge(true_side.node(), trigger1.operation.node());
  graph->AddControlEdge(trigger1.operation.node(), rhs.node());

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(graph, &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch, HasInputsWithMismatchingDeadness(*analysis, *sum.node()));
  EXPECT_FALSE(mismatch);
}
1059 
// Two Merges fed by the same always-alive constant, each additionally given a
// control input from a different switch arm.  Their outputs must still have
// matching deadness at the Add.
TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);

  Output constant = ops::Const(root.WithOpName("constant"), 5);
  ops::Merge m0(root.WithOpName("m0"), {constant});
  // Fixed a copy-paste typo: this merge was previously also named "m0".
  ops::Merge m1(root.WithOpName("m1"), {constant});
  Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output);

  root.graph()->AddControlEdge(id0.node(), m0.output.node());
  root.graph()->AddControlEdge(id1.node(), m1.output.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*result, *add.node()));
  EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
1083 
TEST(DeadnessAnalysisTest, RecvVsSwitch) {
  // Demonstrates why we need the must_be_true bit on SymbolP.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output received = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor",
                               "sender", 0, "receiver");
  Output guarded_value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch guard(root.WithOpName("switch"), guarded_value, received);

  // One input is the raw recv; the other is only alive when the received
  // boolean is true, so the two inputs' deadness can disagree.
  Output conjunction =
      ops::LogicalAnd(root.WithOpName("and"), received, guard.output_true);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool mismatch,
      HasInputsWithMismatchingDeadness(*analysis, *conjunction.node()));
  EXPECT_TRUE(mismatch);
}
1103 
TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
  // Demonstrates why we need the must_be_true bit on SymbolP.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output received = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor",
                               "sender", 0, "receiver");
  Output guarded_value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch guard(root.WithOpName("switch"), guarded_value, received);
  Output conjunction =
      ops::LogicalAnd(root.WithOpName("and"), received, guard.output_true);

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &analysis));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  // The plain symbol recv:0 and the must-be-true symbol *recv:0 stay
  // distinct in the printed predicate.
  TensorId and_control_out = {conjunction.node()->name(),
                              Graph::kControlSlot};
  EXPECT_EQ(predicate_map[and_control_out], "(recv:0 & *recv:0)");
}
1125 
// Verifies that the predicate simplifier applies De Morgan's laws:
// (A & B) combined with (~A | ~B) folds to #false / #true as appropriate.
TEST(DeadnessAnalysisTest, DeMorgan) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_0 = ops::Placeholder(root.WithOpName("cond_0"), DT_BOOL);
  Output cond_1 = ops::Placeholder(root.WithOpName("cond_1"), DT_BOOL);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch sw_0(root.WithOpName("switch_0"), value, cond_0);
  ops::Switch sw_1(root.WithOpName("switch_1"), value, cond_1);

  // Alive iff A & B (Add is alive only when both inputs are alive).
  Output and_0_1 =
      ops::Add(root.WithOpName("and_0_1"), sw_0.output_true, sw_1.output_true);

  // Alive iff ~A | ~B (Merge is alive when either input is alive).
  Output or_not0_not1 = ops::Merge(root.WithOpName("or_not0_not1"),
                                   {sw_0.output_false, sw_1.output_false})
                            .output;

  // Predicate(should_always_be_dead) =
  // (A & B) & (~A | ~B) = (A & B) & ~(A & B) = False
  Output should_always_be_dead =
      ops::Add(root.WithOpName("should_always_be_dead"), and_0_1, or_not0_not1);

  // Predicate(should_always_be_alive) =
  // (A & B) | (~A | ~B) = (A & B) | ~(A & B) = True
  Output should_always_be_alive =
      ops::Merge(root.WithOpName("should_always_be_alive"),
                 {and_0_1, or_not0_not1})
          .output;

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_dead)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(should_always_be_alive)], "#true");
}
1164 
// A Switch predicated on a constant `true` statically kills the false branch
// and keeps the true branch always alive.
TEST(DeadnessAnalysisTest, ConstantTrueSwitchCondition) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output always_true = ops::Const(root.WithOpName("const_true"), true);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch sw(root.WithOpName("switch"), data, always_true);

  Output dead_branch =
      ops::Identity(root.WithOpName("id_false"), sw.output_false);
  Output live_branch =
      ops::Identity(root.WithOpName("id_true"), sw.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(live_branch)], "#true");
  EXPECT_EQ(predicate_map[ControlOutputFor(dead_branch)], "#false");
}
1183 
// Mirror of ConstantTrueSwitchCondition: with a constant `false` predicate,
// the false branch is always alive and the true branch is statically dead.
TEST(DeadnessAnalysisTest, ConstantFalseSwitchCondition) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output always_false = ops::Const(root.WithOpName("const_false"), false);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch sw(root.WithOpName("switch"), data, always_false);

  Output live_branch =
      ops::Identity(root.WithOpName("id_false"), sw.output_false);
  Output dead_branch =
      ops::Identity(root.WithOpName("id_true"), sw.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(live_branch)], "#true");
  EXPECT_EQ(predicate_map[ControlOutputFor(dead_branch)], "#false");
}
1202 
// A Switch whose predicate is a (ref-typed) variable is not constant-folded:
// both branches get symbolic predicates derived from the variable.
TEST(DeadnessAnalysisTest, RefBoolSwitchCondition) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_var =
      ops::Variable(root.WithOpName("cond_ref"), TensorShape({}), DT_BOOL);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  ops::Switch sw(root.WithOpName("switch"), data, cond_var);

  Output taken_false =
      ops::Identity(root.WithOpName("id_false"), sw.output_false);
  Output taken_true =
      ops::Identity(root.WithOpName("id_true"), sw.output_true);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(taken_true)], "*cond_ref:0");
  EXPECT_EQ(predicate_map[ControlOutputFor(taken_false)], "~*cond_ref:0");
}
1222 
CreateSwitchN(const Scope & scope,Input data,Input output_index,int64 num_outs,OutputList * outputs)1223 void CreateSwitchN(const Scope& scope, Input data, Input output_index,
1224                    int64 num_outs, OutputList* outputs) {
1225   if (!scope.ok()) return;
1226   auto _data = ops::AsNodeOut(scope, data);
1227   if (!scope.ok()) return;
1228   auto _output_index = ops::AsNodeOut(scope, output_index);
1229   if (!scope.ok()) return;
1230   Node* ret;
1231   const auto unique_name = scope.GetUniqueNameForOp("_SwitchN");
1232   auto builder = NodeBuilder(unique_name, "_SwitchN")
1233                      .Input(_data)
1234                      .Input(_output_index)
1235                      .Attr("num_outs", num_outs);
1236   scope.UpdateBuilder(&builder);
1237   scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
1238   if (!scope.ok()) return;
1239   scope.UpdateStatus(scope.DoShapeInference(ret));
1240   for (int32 i = 0; i < ret->num_outputs(); ++i) {
1241     outputs->push_back(Output(ret, i));
1242   }
1243 }
1244 
// A constant branch index of 1 into a 2-way _SwitchN statically kills
// branch 0 and keeps branch 1 alive.
TEST(DeadnessAnalysisTest, Constant1_SwitchN_2Branches_DoesNotFail) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index = ops::Const(root.WithOpName("const_1"), 1);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), data, branch_index, 2, &branches);

  Output id_0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output id_1 = ops::Identity(root.WithOpName("id_1"), branches[1]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(id_1)], "#true");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_0)], "#false");
}
1264 
// A constant index of 7 into a 3-way _SwitchN leaves only the last branch
// alive (the expectations below pin branch 2 as the live one).
TEST(DeadnessAnalysisTest, Constant7_SwitchN_3Branches) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index = ops::Const(root.WithOpName("const_7"), 7);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), data, branch_index, 3, &branches);

  Output id_0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output id_1 = ops::Identity(root.WithOpName("id_1"), branches[1]);
  Output id_2 = ops::Identity(root.WithOpName("id_2"), branches[2]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(id_2)], "#true");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_1)], "#false");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_0)], "#false");
}
1286 
// With a non-constant (variable) branch index, each _SwitchN branch gets a
// symbolic predicate of the form "index == k", conjoined with the negations
// of the preceding branches.
TEST(DeadnessAnalysisTest, RefInt_SwitchN_3Branches) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output branch_index_var =
      ops::Variable(root.WithOpName("bidx"), TensorShape({}), DT_INT32);
  Output data = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  OutputList branches;
  CreateSwitchN(root.WithOpName("switchn"), data, branch_index_var, 3,
                &branches);

  Output id_0 = ops::Identity(root.WithOpName("id_0"), branches[0]);
  Output id_1 = ops::Identity(root.WithOpName("id_1"), branches[1]);
  Output id_2 = ops::Identity(root.WithOpName("id_2"), branches[2]);

  FixupSourceAndSinkEdges(root.graph());

  PredicateMapTy predicate_map;
  TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));

  EXPECT_EQ(predicate_map[ControlOutputFor(id_0)], "bidx:0=0");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_1)], "(~bidx:0=0 & bidx:0=1)");
  EXPECT_EQ(predicate_map[ControlOutputFor(id_2)], "(~bidx:0=0 & ~bidx:0=1)");
}
1310 
1311 }  // namespace
1312 }  // namespace tensorflow
1313