1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17 
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
20 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
21 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
24 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
25 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/test.h"
29 #include "tensorflow/compiler/xla/test_helpers.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace xla {
37 namespace {
38 
39 using ::testing::ElementsAre;
40 using ::testing::UnorderedElementsAre;
41 
42 // Test is parameterized on a bool which is whether the dataflow analysis is
43 // performed with SSA form.
44 class HloDataflowAnalysisTest : public HloTestBase,
45                                 public ::testing::WithParamInterface<bool> {
46  protected:
HloDataflowAnalysisTest()47   HloDataflowAnalysisTest() : module_(CreateNewVerifiedModule()) {}
48 
49   // Run dataflow analysis on the member module. For convenience returns a
50   // reference to the generated analysis stored in analysis_.
RunAnalysis(bool ssa_form,bool bitcast_defines_value=false)51   const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
52                                          bool bitcast_defines_value = false) {
53     analysis_ =
54         HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value)
55             .ConsumeValueOrDie();
56     return *analysis_;
57   }
58 
59   // Return a vector of the HloValues at the given program position.
HloValuesAt(const HloInstruction * instruction,const ShapeIndex & index={})60   std::vector<HloValue> HloValuesAt(const HloInstruction* instruction,
61                                     const ShapeIndex& index = {}) {
62     CHECK(analysis_ != nullptr);
63     std::vector<HloValue> values;
64     for (const HloValue* value :
65          analysis_->GetValueSet(instruction, index).values()) {
66       values.push_back(*value);
67     }
68     return values;
69   }
70 
71   // Returns true if the top-level values for instructions 'a' and 'b' may
72   // interfere. Precondition: 'a' and 'b' define array-shaped values.
InstructionsMayInterfere(const HloOrdering & ordering,const HloInstruction * a,const HloInstruction * b)73   bool InstructionsMayInterfere(const HloOrdering& ordering,
74                                 const HloInstruction* a,
75                                 const HloInstruction* b) {
76     EXPECT_FALSE(a->shape().IsTuple());
77     EXPECT_FALSE(b->shape().IsTuple());
78     return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
79                                  analysis_->GetValueDefinedAt(b), *analysis_);
80   }
81 
CreateR0F32UnaryOpComputation(HloOpcode opcode)82   std::unique_ptr<HloComputation> CreateR0F32UnaryOpComputation(
83       HloOpcode opcode) {
84     HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode));
85     HloInstruction* param0 = builder.AddInstruction(
86         HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
87     builder.AddInstruction(
88         HloInstruction::CreateUnary(scalar_shape_, opcode, param0));
89     return builder.Build();
90   }
91 
92   std::unique_ptr<HloModule> module_;
93   std::unique_ptr<HloDataflowAnalysis> analysis_;
94 
95   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
96   const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42});
97   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
98       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
99 };
100 
TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
  // Dataflow for a single Add of two scalar constants.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, lhs, rhs));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // One value per instruction: the two constants and the add.
  EXPECT_EQ(analysis.values().size(), 3);
  EXPECT_TRUE(analysis.ValueIsDefinedAt(lhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(rhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sum));

  const HloValue& lhs_value = analysis.GetValueDefinedAt(lhs);
  const HloValue& rhs_value = analysis.GetValueDefinedAt(rhs);
  const HloValue& sum_value = analysis.GetValueDefinedAt(sum);

  // Nothing forwards a value here, so each value sits only at its defining
  // instruction.
  EXPECT_THAT(lhs_value.positions(),
              UnorderedElementsAre(HloPosition{lhs, {}}));
  EXPECT_THAT(rhs_value.positions(),
              UnorderedElementsAre(HloPosition{rhs, {}}));
  EXPECT_THAT(sum_value.positions(),
              UnorderedElementsAre(HloPosition{sum, {}}));

  // Each constant feeds exactly one operand of the add; the add is unused.
  EXPECT_THAT(lhs_value.uses(), UnorderedElementsAre(HloUse{sum, 0, {}}));
  EXPECT_THAT(rhs_value.uses(), UnorderedElementsAre(HloUse{sum, 1, {}}));
  EXPECT_TRUE(sum_value.uses().empty());

  // Only the add (the root) is live out of the module.
  EXPECT_FALSE(lhs_value.live_out_of_module());
  EXPECT_FALSE(rhs_value.live_out_of_module());
  EXPECT_TRUE(sum_value.live_out_of_module());
}
143 
TEST_P(HloDataflowAnalysisTest,TupleAndGtes)144 TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
145   // Verify the dataflow through a Tuple and GetTupleElement instructions.
146   auto builder = HloComputation::Builder(TestName());
147   auto param0 = builder.AddInstruction(
148       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
149   auto param1 = builder.AddInstruction(
150       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
151   auto tuple =
152       builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
153   auto gte0 = builder.AddInstruction(
154       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 0));
155   auto gte1 = builder.AddInstruction(
156       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
157   auto add = builder.AddInstruction(
158       HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1));
159   module_->AddEntryComputation(builder.Build());
160   SCOPED_TRACE(module_->ToString());
161 
162   bool ssa_form = GetParam();
163   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
164 
165   // The two params, tuple, and add should each define one value.
166   EXPECT_EQ(analysis.values().size(), 4);
167 
168   EXPECT_TRUE(analysis.ValueIsDefinedAt(param0));
169   EXPECT_TRUE(analysis.ValueIsDefinedAt(param1));
170   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{}));
171   EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0}));
172   EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1}));
173   EXPECT_FALSE(analysis.ValueIsDefinedAt(gte0));
174   EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1));
175   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
176 
177   // Verify the positions of the values.
178   EXPECT_THAT(
179       analysis.GetValueDefinedAt(param0).positions(),
180       UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
181                            HloPosition{gte0, {}}));
182   EXPECT_THAT(
183       analysis.GetValueDefinedAt(param1).positions(),
184       UnorderedElementsAre(HloPosition{param1, {}}, HloPosition{tuple, {1}},
185                            HloPosition{gte1, {}}));
186   EXPECT_THAT(analysis.GetValueDefinedAt(tuple).positions(),
187               UnorderedElementsAre(HloPosition{tuple, {}}));
188 
189   // Verify uses. Of interest is that a GetTupleElement instruction is only a
190   // use of the top-level value in the tuple operand.
191   EXPECT_THAT(analysis.GetValueDefinedAt(param0).uses(),
192               UnorderedElementsAre(HloUse{add, 0, {}}));
193   EXPECT_THAT(analysis.GetValueDefinedAt(param1).uses(),
194               UnorderedElementsAre(HloUse{add, 1, {}}));
195   EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(),
196               UnorderedElementsAre(HloUse{gte0, 0, {}}, HloUse{gte1, 0, {}}));
197   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
198 }
199 
TEST_P(HloDataflowAnalysisTest, NestedTuple) {
  // Dataflow through a tuple nested inside another tuple:
  //   inner = (c1, c2)
  //   outer = (inner, inner, c1)
  //   root  = outer{1}{0}
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* outer =
      builder.AddInstruction(HloInstruction::CreateTuple({inner, inner, c1}));
  HloInstruction* gte_inner = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner->shape(), outer, 1));
  HloInstruction* gte_leaf = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, gte_inner, 0));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 4);

  // c1 is reachable at every place the tuples and GTEs forward it to.
  EXPECT_THAT(analysis.GetValueDefinedAt(c1).positions(),
              UnorderedElementsAre(
                  HloPosition{c1, {}}, HloPosition{inner, {0}},
                  HloPosition{outer, {0, 0}}, HloPosition{outer, {1, 0}},
                  HloPosition{outer, {2}}, HloPosition{gte_inner, {0}},
                  HloPosition{gte_leaf, {}}));

  // c1 is expected to have exactly one use, at the root GTE; c2 has none.
  EXPECT_THAT(analysis.GetValueDefinedAt(c1, /*index=*/{}).uses(),
              UnorderedElementsAre(HloUse{gte_leaf, 0, {0}}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(c2).uses().empty());

  // The top-level tuple values are consumed by the GTE instructions.
  const HloValue& inner_top = analysis.GetValueDefinedAt(inner, /*index=*/{});
  const HloValue& outer_top = analysis.GetValueDefinedAt(outer, /*index=*/{});
  EXPECT_THAT(inner_top.uses(),
              UnorderedElementsAre(HloUse{gte_leaf, 0, {}}));
  EXPECT_THAT(outer_top.uses(),
              UnorderedElementsAre(HloUse{gte_inner, 0, {}}));

  // Only c1 leaves the module (through the root GTE).
  EXPECT_TRUE(analysis.GetValueDefinedAt(c1).live_out_of_module());
  EXPECT_FALSE(analysis.GetValueDefinedAt(c2).live_out_of_module());
  EXPECT_FALSE(inner_top.live_out_of_module());
  EXPECT_FALSE(outer_top.live_out_of_module());
}
250 
TEST_P(HloDataflowAnalysisTest, SingleCall) {
  // A single call of a subcomputation that adds its two array-shaped
  // parameters.
  HloComputation::Builder callee_builder("Subcomputation");
  HloInstruction* callee_p0 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* callee_p1 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* callee_add =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, callee_p0, callee_p1));
  HloComputation* callee =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, callee));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 3);

  // Neither the callee's parameters nor the call instruction define values;
  // their values flow in from the constants and the callee's add.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c2));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p0));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(callee_add));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(call));

  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);
  const HloValue& c2_value = analysis.GetValueDefinedAt(c2);
  const HloValue& add_value = analysis.GetValueDefinedAt(callee_add);

  EXPECT_EQ(analysis.GetUniqueValueAt(callee_p0), c1_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(callee_p1), c2_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(call), add_value);

  // Each constant is used both by the call and by the add inside the callee.
  EXPECT_THAT(
      c1_value.uses(),
      UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{callee_add, 0, {}}));
  EXPECT_THAT(
      c2_value.uses(),
      UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{callee_add, 1, {}}));

  EXPECT_TRUE(add_value.live_out_of_module());
}
301 
TEST_P(HloDataflowAnalysisTest,ComputationCalledTwiceWithSameArguments)302 TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
303   // Test a subcomputation which is called twice with identical values.
304   auto subbuilder = HloComputation::Builder("Subcomputation");
305   auto subparam0 = subbuilder.AddInstruction(
306       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
307   auto subparam1 = subbuilder.AddInstruction(
308       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
309   auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
310       scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
311   HloComputation* called_computation =
312       module_->AddEmbeddedComputation(subbuilder.Build());
313 
314   auto builder = HloComputation::Builder(TestName());
315   auto constant1 = builder.AddInstruction(
316       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
317   auto constant2 = builder.AddInstruction(
318       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
319   auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
320       scalar_shape_, {constant1, constant2}, called_computation));
321   auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
322       scalar_shape_, {constant1, constant2}, called_computation));
323   auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
324       scalar_shape_, HloOpcode::kSubtract, call1, call2));
325   module_->AddEntryComputation(builder.Build());
326   SCOPED_TRACE(module_->ToString());
327 
328   bool ssa_form = GetParam();
329   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
330 
331   EXPECT_EQ(analysis.values().size(), 4);
332 
333   // Definitions should be identical to the single callsite case.
334   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
335   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
336   EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam0));
337   EXPECT_FALSE(analysis.ValueIsDefinedAt(subparam1));
338   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
339   EXPECT_FALSE(analysis.ValueIsDefinedAt(call1));
340   EXPECT_FALSE(analysis.ValueIsDefinedAt(call2));
341   EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));
342 
343   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
344               UnorderedElementsAre(HloUse{call1, 0, {}}, HloUse{call2, 0, {}},
345                                    HloUse{add, 0, {}}));
346   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
347               UnorderedElementsAre(HloUse{call1, 1, {}}, HloUse{call2, 1, {}},
348                                    HloUse{add, 1, {}}));
349   // The Add from the subcomputation is used as both operands of the Subtract.
350   EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(),
351               UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}}));
352 
353   EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
354   EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module());
355 }
356 
TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) {
  // One subcomputation invoked from two call sites with different operands:
  // the second call takes the first call's result as operand 0.
  HloComputation::Builder callee_builder("Subcomputation");
  HloInstruction* callee_p0 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* callee_p1 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* callee_add =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, callee_p0, callee_p1));
  HloComputation* callee =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* first_call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, callee));
  HloInstruction* second_call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {first_call, c2}, callee));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_FALSE(analysis.ValueIsDefinedAt(first_call));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(second_call));

  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p0));

  const HloValue& add_value = analysis.GetValueDefinedAt(callee_add);

  // Parameter 0 may hold either c1 (first call) or the add result (second
  // call); parameter 1 is c2 at both call sites.
  EXPECT_THAT(
      HloValuesAt(callee_p0),
      UnorderedElementsAre(analysis.GetValueDefinedAt(c1), add_value));
  EXPECT_THAT(HloValuesAt(callee_p1),
              UnorderedElementsAre(analysis.GetValueDefinedAt(c2)));

  EXPECT_TRUE(add_value.live_out_of_module());
}
397 
TEST_P(HloDataflowAnalysisTest,NestedCalls)398 TEST_P(HloDataflowAnalysisTest, NestedCalls) {
399   // Test a module with nested computations. HLO is:
400   //
401   // F32[] inner_computation(F32[] %param0, F32[] %param1):
402   //   %add = Add(%param0, %param1)
403   //
404   // F32[] outer_computation((F32[] %param0, F32[] %param1):
405   //  ;; Note that parameters are interchanged in the call.
406   //   %nested_call = Call(inner_computation, {%param1, %param0})
407   //
408   // F32[] entry:
409   //   %constant1 = Constant(1.0)
410   //   %constant2 = Constant(2.0)
411   //   %call = Call(outer_computation, {%constant1, %constant2})
412   //
413   auto inner_builder = HloComputation::Builder("InnerComputation");
414   auto inner_param0 = inner_builder.AddInstruction(
415       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
416   auto inner_param1 = inner_builder.AddInstruction(
417       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
418   auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
419       scalar_shape_, HloOpcode::kAdd, inner_param0, inner_param1));
420   HloComputation* inner_computation =
421       module_->AddEmbeddedComputation(inner_builder.Build());
422 
423   auto outer_builder = HloComputation::Builder("OuterComputation");
424   auto outer_param0 = outer_builder.AddInstruction(
425       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
426   auto outer_param1 = outer_builder.AddInstruction(
427       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
428   // Swizzle parameters.
429   auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall(
430       scalar_shape_, {outer_param1, outer_param0}, inner_computation));
431   HloComputation* outer_computation =
432       module_->AddEmbeddedComputation(outer_builder.Build());
433 
434   auto builder = HloComputation::Builder(TestName());
435   auto constant1 = builder.AddInstruction(
436       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
437   auto constant2 = builder.AddInstruction(
438       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
439   auto call = builder.AddInstruction(HloInstruction::CreateCall(
440       scalar_shape_, {constant1, constant2}, outer_computation));
441   module_->AddEntryComputation(builder.Build());
442   SCOPED_TRACE(module_->ToString());
443 
444   bool ssa_form = GetParam();
445   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
446 
447   // Only three values should be defined. Most instructions just pass through
448   // their operand values.
449   EXPECT_EQ(analysis.values().size(), 3);
450 
451   // Verify that the uses of the constants are properly swizzled by parameter
452   // permutation in nested_call.
453   EXPECT_THAT(
454       analysis.GetValueDefinedAt(constant1).uses(),
455       UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}},
456                            HloUse{add, 1, {}}));
457   EXPECT_THAT(
458       analysis.GetValueDefinedAt(constant2).uses(),
459       UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}},
460                            HloUse{add, 0, {}}));
461 
462   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
463 }
464 
TEST_P(HloDataflowAnalysisTest,SingleWhile)465 TEST_P(HloDataflowAnalysisTest, SingleWhile) {
466   // Test a simple single while instruction. The while body includes a
467   // pass-through value. HLO:
468   //
469   // body((F32[], F32[]) %tuple_param):
470   //   %add = Add(%tuple_param{0}, %tuple_param{1})
471   //   return Tuple(%tuple_param{0}, %add)
472   //
473   // condition((F32[], F32[]) %tuple_param):
474   //   return Constant(false)
475   //
476   // entry:
477   //   %constant1 = Constant(1.0)
478   //   %constant2 = Constant(2.0)
479   //   %tuple = Tuple(%constant1, %constant2)
480   //   return While(%tuple, body, condition)
481   //
482   const Shape tuple_shape =
483       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
484 
485   // Element 0 passes transparently through the body.
486   auto body_builder = HloComputation::Builder("body");
487   auto body_param = body_builder.AddInstruction(
488       HloInstruction::CreateParameter(0, tuple_shape, "param"));
489   auto body_element_0 = body_builder.AddInstruction(
490       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
491   auto body_element_1 = body_builder.AddInstruction(
492       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
493   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
494       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
495   auto body_root = body_builder.AddInstruction(
496       HloInstruction::CreateTuple({body_element_0, add}));
497   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
498 
499   // Condition computation trivially returns a constant "false".
500   auto cond_builder = HloComputation::Builder("condition");
501   auto cond_param = cond_builder.AddInstruction(
502       HloInstruction::CreateParameter(0, tuple_shape, "param"));
503   auto cond_constant = cond_builder.AddInstruction(
504       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
505   HloComputation* condition =
506       module_->AddEmbeddedComputation(cond_builder.Build());
507 
508   auto builder = HloComputation::Builder(TestName());
509   auto constant1 = builder.AddInstruction(
510       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
511   auto constant2 = builder.AddInstruction(
512       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
513   auto tuple = builder.AddInstruction(
514       HloInstruction::CreateTuple({constant1, constant2}));
515   auto xla_while = builder.AddInstruction(
516       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
517   module_->AddEntryComputation(builder.Build());
518   SCOPED_TRACE(module_->ToString());
519 
520   bool ssa_form = GetParam();
521   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
522 
523   EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
524 
525   if (ssa_form) {
526     // Element 0 of the tuple passed through the body so no phi value is
527     // defined.
528     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
529     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
530     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
531 
532     // Element 1 of the tuple should be a phi value.
533     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
534     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
535     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
536     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
537     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
538     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
539 
540     EXPECT_THAT(
541         analysis.GetValueDefinedAt(constant1).uses(),
542         UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{body_root, 0, {}},
543                              HloUse{xla_while, 0, {0}}));
544 
545     // Constant1 passes through the body and out of the module.
546     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
547     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
548                     .live_out_of_module());
549 
550     EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
551   } else {
552     // While instruction and subcomputation parameters should not define values
553     // in non-ssa form.
554     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
555     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
556     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
557     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
558     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
559     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
560 
561     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
562     EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
563   }
564 }
565 
TEST_P(HloDataflowAnalysisTest,SequentialWhiles)566 TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
567   // Test sequential while instructions. The while body includes a
568   // pass-through value. HLO:
569   //
570   // body((F32[], F32[]) %tuple_param):
571   //   %add = Add(%tuple_param{0}, %tuple_param{1})
572   //   return Tuple(%tuple_param{0}, %add)
573   //
574   // condition((F32[], F32[]) %tuple_param):
575   //   return Constant(false)
576   //
577   // entry:
578   //   %constant1 = Constant(1.0)
579   //   %constant2 = Constant(2.0)
580   //   %tuple = Tuple(%constant1, %constant2)
581   //   %while0 = While(%tuple, body, condition)
582   //   %while1 = While(%while0, body, condition)
583   //   return While(%while1, body, condition)
584   //
585   const Shape tuple_shape =
586       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
587 
588   // Element 0 passes transparently through the body.
589   auto body_builder = HloComputation::Builder("body");
590   auto body_param = body_builder.AddInstruction(
591       HloInstruction::CreateParameter(0, tuple_shape, "param"));
592   auto body_element_0 = body_builder.AddInstruction(
593       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
594   auto body_element_1 = body_builder.AddInstruction(
595       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
596   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
597       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
598   body_builder.AddInstruction(
599       HloInstruction::CreateTuple({body_element_0, add}));
600   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
601 
602   auto cond_builder = HloComputation::Builder("condition");
603   cond_builder.AddInstruction(
604       HloInstruction::CreateParameter(0, tuple_shape, "param"));
605   cond_builder.AddInstruction(
606       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
607   HloComputation* condition =
608       module_->AddEmbeddedComputation(cond_builder.Build());
609 
610   auto builder = HloComputation::Builder(TestName());
611   auto constant1 = builder.AddInstruction(
612       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
613   auto constant2 = builder.AddInstruction(
614       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
615   auto tuple = builder.AddInstruction(
616       HloInstruction::CreateTuple({constant1, constant2}));
617   auto xla_while0 = builder.AddInstruction(
618       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
619   auto xla_while1 = builder.AddInstruction(
620       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0));
621   auto xla_while2 = builder.AddInstruction(
622       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1));
623   module_->AddEntryComputation(builder.Build());
624   SCOPED_TRACE(module_->ToString());
625 
626   bool ssa_form = GetParam();
627   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
628 
629   // Element 0 is passed through all the while instructions and out of the
630   // module..
631   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
632             analysis.GetValueDefinedAt(constant1));
633   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
634             analysis.GetValueDefinedAt(constant1));
635   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
636             analysis.GetValueDefinedAt(constant1));
637   EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
638 }
639 
TEST_P(HloDataflowAnalysisTest,NestedWhiles)640 TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
641   // Test nested while instructions. The inner body passes through element 0 of
642   // its parameter, and the outer body passes through element 1.  HLO:
643   //
644   // inner_body((F32[], F32[]) %tuple_param):
645   //   %add = Add(%tuple_param{0}, %tuple_param{1})
646   //   return Tuple(%tuple_param{0}, %add)
647   //
648   // outer_body((F32[], F32[]) %tuple_param):
649   //   %negate = Negate(%tuple_param{0})
650   //   %tuple = Tuple(%negate, %tuple_param{1})
651   //   return While(%tuple, inner_body, condition)
652   //
653   // entry:
654   //   %constant1 = Constant(1.0)
655   //   %constant2 = Constant(2.0)
656   //   %tuple = Tuple(%constant1, %constant2)
657   //   return While(%tuple, outer_body, condition)
658   //
659   const Shape tuple_shape =
660       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
661 
662   auto cond_builder = HloComputation::Builder("condition");
663   cond_builder.AddInstruction(
664       HloInstruction::CreateParameter(0, tuple_shape, "param"));
665   cond_builder.AddInstruction(
666       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
667   HloComputation* condition =
668       module_->AddEmbeddedComputation(cond_builder.Build());
669 
670   // Element 0 passes transparently through the body.
671   auto inner_builder = HloComputation::Builder("inner_body");
672   auto inner_param = inner_builder.AddInstruction(
673       HloInstruction::CreateParameter(0, tuple_shape, "param"));
674   auto inner_element_0 = inner_builder.AddInstruction(
675       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0));
676   auto inner_element_1 = inner_builder.AddInstruction(
677       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1));
678   auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
679       scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1));
680   inner_builder.AddInstruction(
681       HloInstruction::CreateTuple({inner_element_0, add}));
682   HloComputation* inner_body =
683       module_->AddEmbeddedComputation(inner_builder.Build());
684 
685   // Element 1 passes transparently through the body.
686   auto outer_builder = HloComputation::Builder("outer_body");
687   auto outer_param = outer_builder.AddInstruction(
688       HloInstruction::CreateParameter(0, tuple_shape, "param"));
689   auto outer_element_0 = outer_builder.AddInstruction(
690       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
691   auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
692       scalar_shape_, HloOpcode::kNegate, outer_element_0));
693   auto outer_element_1 = outer_builder.AddInstruction(
694       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
695   auto outer_tuple = outer_builder.AddInstruction(
696       HloInstruction::CreateTuple({negate, outer_element_1}));
697   auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
698       tuple_shape, condition, inner_body, outer_tuple));
699   HloComputation* outer_body =
700       module_->AddEmbeddedComputation(outer_builder.Build());
701 
702   auto builder = HloComputation::Builder(TestName());
703   auto constant1 = builder.AddInstruction(
704       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
705   auto constant2 = builder.AddInstruction(
706       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
707   auto tuple = builder.AddInstruction(
708       HloInstruction::CreateTuple({constant1, constant2}));
709   auto entry_while = builder.AddInstruction(
710       HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple));
711   module_->AddEntryComputation(builder.Build());
712   SCOPED_TRACE(module_->ToString());
713 
714   bool ssa_form = GetParam();
715   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
716 
717   EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
718               UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
719   if (ssa_form) {
720     EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
721     EXPECT_TRUE(
722         analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
723 
724     // Element 0 of the nested while is %negate.
725     EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
726     EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
727                 UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
728     // Element 1 is a phi value (join of %add and %constant2).
729     EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
730     EXPECT_TRUE(
731         analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());
732 
733     EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{0}));
734     EXPECT_TRUE(
735         analysis.GetValueDefinedAt(entry_while, /*index=*/{0}).is_phi());
736 
737     EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{1}));
738     EXPECT_TRUE(
739         analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
740   } else {
741     EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
742                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
743                                      analysis.GetValueDefinedAt(constant2)));
744 
745     EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}),
746                 UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
747     EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{1}),
748                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
749                                      analysis.GetValueDefinedAt(constant2)));
750 
751     EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{0}),
752                 UnorderedElementsAre(analysis.GetValueDefinedAt(negate),
753                                      analysis.GetValueDefinedAt(constant1)));
754     EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{1}),
755                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
756                                      analysis.GetValueDefinedAt(constant2)));
757   }
758 }
759 
TEST_P(HloDataflowAnalysisTest,SwizzlingWhile)760 TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
761   // Test a while instruction with a body which permutes it's tuple parameter
762   // elements. HLO:
763   //
764   // body((F32[], F32[]) %tuple_param):
765   //   return Tuple(%tuple_param{1}, %tuple_param{0})
766   //
767   // condition((F32[], F32[]) %tuple_param):
768   //   return Constant(false)
769   //
770   // entry:
771   //   %constant1 = Constant(1.0)
772   //   %constant2 = Constant(2.0)
773   //   %tuple = Tuple(%constant1, %constant2)
774   //   return While(%tuple, body, condition)
775   //
776   const Shape tuple_shape =
777       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
778 
779   auto body_builder = HloComputation::Builder("body");
780   auto body_param = body_builder.AddInstruction(
781       HloInstruction::CreateParameter(0, tuple_shape, "param"));
782   auto body_element_0 = body_builder.AddInstruction(
783       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
784   auto body_element_1 = body_builder.AddInstruction(
785       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
786   body_builder.AddInstruction(
787       HloInstruction::CreateTuple({body_element_1, body_element_0}));
788   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
789 
790   auto cond_builder = HloComputation::Builder("condition");
791   auto cond_param = cond_builder.AddInstruction(
792       HloInstruction::CreateParameter(0, tuple_shape, "param"));
793   cond_builder.AddInstruction(
794       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
795   HloComputation* condition =
796       module_->AddEmbeddedComputation(cond_builder.Build());
797 
798   auto builder = HloComputation::Builder(TestName());
799   auto constant1 = builder.AddInstruction(
800       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
801   auto constant2 = builder.AddInstruction(
802       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
803   auto tuple = builder.AddInstruction(
804       HloInstruction::CreateTuple({constant1, constant2}));
805   auto xla_while = builder.AddInstruction(
806       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
807   module_->AddEntryComputation(builder.Build());
808   SCOPED_TRACE(module_->ToString());
809 
810   bool ssa_form = GetParam();
811   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
812 
813   if (ssa_form) {
814     // Element 0 and 1 in the while should both be phi values.
815     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
816     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
817     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
818     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
819 
820     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
821     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
822     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
823     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
824 
825     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
826     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
827     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
828     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
829 
830     EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
831     EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
832     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{})
833                     .live_out_of_module());
834     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
835                     .live_out_of_module());
836     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
837                     .live_out_of_module());
838   } else {
839     // Elements 0 and 1 have both constants as reaching definitions.
840     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
841                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
842                                      analysis.GetValueDefinedAt(constant2)));
843     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}),
844                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
845                                      analysis.GetValueDefinedAt(constant2)));
846     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
847     EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
848   }
849 }
850 
TEST_P(HloDataflowAnalysisTest, ArraySelect) {
  // A kSelect over array (scalar) operands: the select defines a value that is
  // live out of the module, while its constant operands' values are not.
  HloComputation::Builder builder(TestName());
  HloInstruction* predicate = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* on_true = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* on_false = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* select =
      builder.AddInstruction(HloInstruction::CreateTernary(
          scalar_shape_, HloOpcode::kSelect, predicate, on_true, on_false));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(select));
  EXPECT_FALSE(analysis.GetValueDefinedAt(on_true).live_out_of_module());
  EXPECT_FALSE(analysis.GetValueDefinedAt(on_false).live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(select).live_out_of_module());
}
874 
TEST_P(HloDataflowAnalysisTest,TupleSelect)875 TEST_P(HloDataflowAnalysisTest, TupleSelect) {
876   // Test a kTupleSelect. Non-top-level element flow through the instruction.
877   auto builder = HloComputation::Builder(TestName());
878   auto pred = builder.AddInstruction(
879       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
880   auto constant1 = builder.AddInstruction(
881       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
882   auto constant2 = builder.AddInstruction(
883       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
884   auto constant3 = builder.AddInstruction(
885       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
886   auto constant4 = builder.AddInstruction(
887       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4.0)));
888   auto tuple1 =
889       builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
890   auto tuple2 =
891       builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
892   auto tuple3 =
893       builder.AddInstruction(HloInstruction::CreateTuple({constant3}));
894   auto tuple4 =
895       builder.AddInstruction(HloInstruction::CreateTuple({constant4}));
896   const Shape tuple_shape = tuple1->shape();
897   auto select11 = builder.AddInstruction(HloInstruction::CreateTernary(
898       tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1));
899   auto select12 = builder.AddInstruction(HloInstruction::CreateTernary(
900       tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
901   auto select34 = builder.AddInstruction(HloInstruction::CreateTernary(
902       tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4));
903   auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary(
904       tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34));
905 
906   module_->AddEntryComputation(builder.Build());
907   SCOPED_TRACE(module_->ToString());
908 
909   bool ssa_form = GetParam();
910   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
911 
912   // Top-level value is always defined by a kTupleSelect.
913   EXPECT_TRUE(analysis.ValueIsDefinedAt(select11));
914   EXPECT_TRUE(analysis.ValueIsDefinedAt(select12));
915   EXPECT_TRUE(analysis.ValueIsDefinedAt(select34));
916   EXPECT_TRUE(analysis.ValueIsDefinedAt(select1234));
917 
918   EXPECT_FALSE(analysis.ValueIsDefinedAt(select11, /*index=*/{0}));
919   EXPECT_FALSE(analysis.ValueIsDefinedAt(select12, /*index=*/{0}));
920   EXPECT_FALSE(analysis.ValueIsDefinedAt(select34, /*index=*/{0}));
921   EXPECT_FALSE(analysis.ValueIsDefinedAt(select1234, /*index=*/{0}));
922 
923   EXPECT_THAT(HloValuesAt(select11, /*index=*/{0}),
924               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1)));
925   EXPECT_THAT(HloValuesAt(select12, /*index=*/{0}),
926               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
927                                    analysis.GetValueDefinedAt(constant2)));
928   EXPECT_THAT(HloValuesAt(select34, /*index=*/{0}),
929               UnorderedElementsAre(analysis.GetValueDefinedAt(constant3),
930                                    analysis.GetValueDefinedAt(constant4)));
931   EXPECT_THAT(HloValuesAt(select1234, /*index=*/{0}),
932               UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
933                                    analysis.GetValueDefinedAt(constant2),
934                                    analysis.GetValueDefinedAt(constant3),
935                                    analysis.GetValueDefinedAt(constant4)));
936 
937   EXPECT_THAT(
938       analysis.GetValueDefinedAt(tuple1, /*index=*/{}).uses(),
939       UnorderedElementsAre(HloUse{select11, 1, {}}, HloUse{select11, 2, {}},
940                            HloUse{select12, 1, {}}));
941 
942   // The two constant values just pass through the Selects and are not
943   // used except at the root. They are live out however.
944   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
945               UnorderedElementsAre(HloUse{select1234, 1, {0}}));
946   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
947               UnorderedElementsAre(HloUse{select1234, 1, {0}}));
948   EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
949   EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
950 }
951 
TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
  // kTupleSelect between two nested tuples of the form (scalar, (scalar,
  // scalar)); the innermost second element is the same constant on both sides.
  HloComputation::Builder builder(TestName());
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));

  // Adds an F32 scalar constant to the entry computation.
  auto add_scalar = [&builder](float value) -> HloInstruction* {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };
  HloInstruction* constant1 = add_scalar(1.0);
  HloInstruction* constant2 = add_scalar(2.0);
  HloInstruction* constant3 = add_scalar(3.0);
  HloInstruction* constant4 = add_scalar(4.0);
  HloInstruction* constant5 = add_scalar(5.0);

  // tuple1 = (constant1, (constant2, constant3))
  HloInstruction* inner_tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant2, constant3}));
  HloInstruction* tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, inner_tuple1}));
  // tuple2 = (constant4, (constant5, constant3))
  HloInstruction* inner_tuple2 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant5, constant3}));
  HloInstruction* tuple2 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant4, inner_tuple2}));
  HloInstruction* select =
      builder.AddInstruction(HloInstruction::CreateTernary(
          tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(select));

  // Every nested index holds the union of the values from the two operands.
  EXPECT_THAT(HloValuesAt(select, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
                                   analysis.GetValueDefinedAt(constant4)));
  EXPECT_THAT(HloValuesAt(select, /*index=*/{1}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(inner_tuple1),
                                   analysis.GetValueDefinedAt(inner_tuple2)));
  EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(constant2),
                                   analysis.GetValueDefinedAt(constant5)));
  // constant3 appears on both sides, so it is the only value at {1, 1}.
  EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 1}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(constant3)));
}
998 
TEST_P(HloDataflowAnalysisTest,TupleSelectToWhile)999 TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
1000   // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO:
1001   //
1002   // body((F32[], F32[]) %tuple_param):
1003   //   %add = Add(%tuple_param{0}, %tuple_param{1})
1004   //   return Tuple(%tuple_param{0}, %add)
1005   //
1006   // condition((F32[], F32[]) %tuple_param):
1007   //   return Constant(false)
1008   //
1009   // entry:
1010   //   %constant1 = Constant(1.0)
1011   //   %constant2 = Constant(2.0)
1012   //   %constant3 = Constant(3.0)
1013   //   %tuple1 = Tuple(%constant1)
1014   //   %tuple2 = Tuple(%constant2)
1015   //   %select = Select(%tuple1, %tuple2)
1016   //   %gte = GetTupleElement(%select, 0)
1017   //   %tuple = Tuple(%gte, %constant3)
1018   //   return While(%tuple, body, condition)
1019   //
1020   auto builder = HloComputation::Builder(TestName());
1021 
1022   const Shape tuple_shape =
1023       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
1024 
1025   // Element 0 passes transparently through the body.
1026   auto body_builder = HloComputation::Builder("body");
1027   auto body_param = body_builder.AddInstruction(
1028       HloInstruction::CreateParameter(0, tuple_shape, "param"));
1029   auto body_element_0 = body_builder.AddInstruction(
1030       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
1031   auto body_element_1 = body_builder.AddInstruction(
1032       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
1033   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
1034       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
1035   body_builder.AddInstruction(
1036       HloInstruction::CreateTuple({body_element_0, add}));
1037   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
1038 
1039   auto cond_builder = HloComputation::Builder("condition");
1040   cond_builder.AddInstruction(
1041       HloInstruction::CreateParameter(0, tuple_shape, "param"));
1042   cond_builder.AddInstruction(
1043       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1044   HloComputation* condition =
1045       module_->AddEmbeddedComputation(cond_builder.Build());
1046 
1047   auto pred = builder.AddInstruction(
1048       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1049   auto constant1 = builder.AddInstruction(
1050       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1051   auto constant2 = builder.AddInstruction(
1052       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
1053   auto constant3 = builder.AddInstruction(
1054       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
1055   auto tuple1 =
1056       builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
1057   auto tuple2 =
1058       builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
1059   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
1060       tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
1061   auto gte = builder.AddInstruction(
1062       HloInstruction::CreateGetTupleElement(scalar_shape_, select, 0));
1063   auto tuple =
1064       builder.AddInstruction(HloInstruction::CreateTuple({gte, constant3}));
1065   auto xla_while = builder.AddInstruction(
1066       HloInstruction::CreateWhile(tuple->shape(), condition, body, tuple));
1067 
1068   module_->AddEntryComputation(builder.Build());
1069   SCOPED_TRACE(module_->ToString());
1070 
1071   bool ssa_form = GetParam();
1072   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
1073 
1074   if (ssa_form) {
1075     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
1076     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
1077     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
1078     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
1079 
1080     EXPECT_FALSE(analysis.ValueIsDefinedAt(select, /*index=*/{0}));
1081 
1082     EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
1083     EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
1084     EXPECT_FALSE(analysis.GetValueDefinedAt(constant3).live_out_of_module());
1085     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
1086                     .live_out_of_module());
1087   } else {
1088     EXPECT_THAT(HloValuesAt(gte),
1089                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
1090                                      analysis.GetValueDefinedAt(constant2)));
1091     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
1092                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
1093                                      analysis.GetValueDefinedAt(constant2)));
1094     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}),
1095                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
1096                                      analysis.GetValueDefinedAt(constant3)));
1097     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
1098     EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
1099     EXPECT_TRUE(analysis.GetValueDefinedAt(constant3).live_out_of_module());
1100   }
1101 }
1102 
TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) {
  // Runs the analysis twice on a constant -> bitcast module, once with the
  // bitcast_defines_value flag set and once with it cleared.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kBitcast,
                                  constant));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  {
    // Flag on: the bitcast defines its own value, and that value (not the
    // constant's) is live out of the module.
    const HloDataflowAnalysis& analysis =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/true);

    EXPECT_EQ(analysis.values().size(), 2);

    EXPECT_TRUE(analysis.ValueIsDefinedAt(constant));
    EXPECT_TRUE(analysis.ValueIsDefinedAt(bitcast));
    EXPECT_FALSE(analysis.GetValueDefinedAt(constant).live_out_of_module());
    EXPECT_TRUE(analysis.GetValueDefinedAt(bitcast).live_out_of_module());
  }
  {
    // Flag off: only the constant defines a value, and it is live out.
    const HloDataflowAnalysis& analysis =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/false);
    EXPECT_EQ(analysis.values().size(), 1);

    EXPECT_TRUE(analysis.ValueIsDefinedAt(constant));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(bitcast));
    EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
  }
}
1136 
TEST_P(HloDataflowAnalysisTest, TupleCopy) {
  // A kCopy of a tuple defines only the top-level value; the tuple's elements
  // remain the values defined by the original parameters.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({lhs, rhs}));
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 4);

  // Parameters define their own values.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(lhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(rhs));
  // The tuple and the copy each define just their top-level value.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(copy, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(copy, /*index=*/{1}));

  // The copy's elements are still the parameter values.
  EXPECT_THAT(HloValuesAt(copy, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(lhs)));
  EXPECT_THAT(HloValuesAt(copy, /*index=*/{1}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(rhs)));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module());
}
1172 
TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
  // A Send forwards its data operand into element {0} of its output tuple
  // rather than defining a new value there.
  HloComputation::Builder builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send = builder.AddInstruction(
      HloInstruction::CreateSend(data, token, /*channel_id=*/0));
  HloInstruction* send_done =
      builder.AddInstruction(HloInstruction::CreateSendDone(send));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 6);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(data));
  // The send defines its top-level value and elements {1} and {2}, but not
  // element {0}.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{2}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
  // Element {0} holds the value defined by the parameter.
  EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(data)));
}
1199 
TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
  // A RecvDone forwards element {0} of its operand tuple to element {0} of its
  // own output instead of defining a new value there.
  HloComputation::Builder builder(TestName());
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv = builder.AddInstruction(
      HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0));
  HloInstruction* recv_done =
      builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 7);

  // The recv defines a value at every index of its output.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{2}));
  // The recv-done defines its top-level value and element {1}, but not {0}.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{1}));
  // Element {0} of the recv-done is the value defined at recv {0}, which is
  // therefore live out of the module.
  EXPECT_THAT(HloValuesAt(recv_done, /*index=*/{0}),
              UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
}
1228 
TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
  // A linear chain of elementwise unary ops. None of the values produced along
  // the chain is expected to interfere with another.
  //
  //   param --> negate --> exp --> log
  //
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate));
  HloInstruction* log = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering order(module_.get());

  // Distinct instructions in the chain never interfere, whichever way round
  // the pair is queried.
  EXPECT_FALSE(InstructionsMayInterfere(order, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(order, param, log));
  EXPECT_FALSE(InstructionsMayInterfere(order, negate, exp));
  EXPECT_FALSE(InstructionsMayInterfere(order, negate, log));
  EXPECT_FALSE(InstructionsMayInterfere(order, exp, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, exp, log));
  EXPECT_FALSE(InstructionsMayInterfere(order, log, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, log, exp));

  // A value always interferes with itself.
  EXPECT_TRUE(InstructionsMayInterfere(order, exp, exp));
}
1264 
TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
  // Two entry parameters that are expected to interfere with each other even
  // though the schedule places param1 after param0's only use.
  //
  // param0 --> negate ---------------\
  //                param1 --> exp --> add
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param0"));
  HloInstruction* param1 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, vector_shape_, "param1"));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param0));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param1));
  HloInstruction* add = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape_, HloOpcode::kAdd, negate,
                                   exp));

  HloComputation* entry =
      module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  // Sequential schedule: param0, negate, param1, exp, add.
  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param0, negate, param1, exp, add});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering order(schedule);

  // Entry parameters behave as if they were all defined together at the very
  // start of the computation, regardless of their position in the schedule.
  EXPECT_TRUE(InstructionsMayInterfere(order, param0, param1));
  EXPECT_FALSE(InstructionsMayInterfere(order, param0, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, param0, exp));
  EXPECT_FALSE(InstructionsMayInterfere(order, param0, add));
  EXPECT_TRUE(InstructionsMayInterfere(order, param1, param0));
  EXPECT_TRUE(InstructionsMayInterfere(order, param1, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, param1, exp));
  EXPECT_FALSE(InstructionsMayInterfere(order, param1, add));

  // Both operands of add are live at once, so negate and exp interfere.
  EXPECT_TRUE(InstructionsMayInterfere(order, negate, exp));
  EXPECT_TRUE(InstructionsMayInterfere(order, exp, negate));

  // Neither operand interferes with add itself.
  EXPECT_FALSE(InstructionsMayInterfere(order, negate, add));
  EXPECT_FALSE(InstructionsMayInterfere(order, add, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, exp, add));
  EXPECT_FALSE(InstructionsMayInterfere(order, add, exp));
}
1312 
TEST_P(HloDataflowAnalysisTest,WhileParameters_Sequential)1313 TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
1314   // Similar to MultipleEntryParameters_Sequential, but the parameter is of
1315   // while body computation. Body computation in the sequential order:
1316   //
1317   //  %constant = Constant(...)
1318   //  %exp = Exp(%constant)
1319   //  %param = Param(0)
1320   //  %add = Add(%param, %exp)  ;; Root of body
1321   //  %dead_constant = Constant(...)
1322   //  %dead_negate = Negate(%dead_constant)
1323   //
1324   // %constant and its only use %exp are ordered before 'param'. However, the
1325   // %constant and %param values still interfere because the parameter is
1326   // considered live into the while body.
1327   //
1328   // Similarly, %dead_constant and %dead_negate are ordered after the root of
1329   // the body computation %add. However, %add is liveout of the computation so
1330   // %dead_constant and %add interfere.
1331   auto body_builder = HloComputation::Builder(TestName());
1332   auto body_param = body_builder.AddInstruction(
1333       HloInstruction::CreateParameter(0, scalar_shape_, "body_param"));
1334   auto constant = body_builder.AddInstruction(
1335       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1336   auto exp = body_builder.AddInstruction(
1337       HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
1338   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
1339       scalar_shape_, HloOpcode::kAdd, exp, body_param));
1340   auto dead_constant = body_builder.AddInstruction(
1341       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1342   auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
1343       scalar_shape_, HloOpcode::kNegate, dead_constant));
1344   HloComputation* body = module_->AddEmbeddedComputation(
1345       body_builder.Build(/*root_instruction=*/add));
1346 
1347   auto cond_builder = HloComputation::Builder("condition");
1348   auto cond_param = cond_builder.AddInstruction(
1349       HloInstruction::CreateParameter(0, scalar_shape_, "cond_param"));
1350   auto cond_constant = cond_builder.AddInstruction(
1351       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1352   HloComputation* condition =
1353       module_->AddEmbeddedComputation(cond_builder.Build());
1354 
1355   auto builder = HloComputation::Builder(TestName());
1356   auto param = builder.AddInstruction(
1357       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
1358   auto xla_while = builder.AddInstruction(
1359       HloInstruction::CreateWhile(scalar_shape_, condition, body, param));
1360 
1361   auto entry = module_->AddEntryComputation(builder.Build());
1362   SCOPED_TRACE(module_->ToString());
1363   bool ssa_form = GetParam();
1364   RunAnalysis(ssa_form);
1365 
1366   HloSchedule schedule(module_.get());
1367   schedule.set_sequence(entry, {param, xla_while});
1368   schedule.set_sequence(condition, {cond_param, cond_constant});
1369   // Construct the order such that 'constant' and its use 'exp' are before
1370   // body_param.
1371   schedule.set_sequence(
1372       body, {constant, exp, body_param, add, dead_constant, dead_negate});
1373   TF_ASSERT_OK(schedule.Verify());
1374 
1375   SequentialHloOrdering ordering(schedule);
1376 
1377   // 'add' is live out of the body and will interfere with an later instructions
1378   // such as 'dead_constant' and 'dead_negate'.
1379   EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant));
1380   EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_negate));
1381 
1382   // The remaining checks test phi values defined by body and condition
1383   // parameters which only occur in the SSA form of the analysis.
1384   if (ssa_form) {
1385     // Though the ordering suggests 'constant' and 'param' should not interfere,
1386     // 'param' is live in and thus interferes with any earlier instruction of
1387     // the computation in the order (eg 'constant')'
1388     EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, constant));
1389     EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, exp));
1390     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
1391 
1392     // The following values end up in the same buffer:
1393     //  (1) the init value: 'param'
1394     //  (2) the body parameter: 'body_param'
1395     //  (3) the condition parameter: 'cond_param'
1396     //  (4) the root value of the while body: 'add'
1397     //  (5) the while value: 'xla_while'
1398     // None should interfere.
1399     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, body_param));
1400     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, cond_param));
1401     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));
1402     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, xla_while));
1403 
1404     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, cond_param));
1405     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
1406     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, xla_while));
1407 
1408     EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, add));
1409     EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, xla_while));
1410 
1411     EXPECT_FALSE(InstructionsMayInterfere(ordering, add, xla_while));
1412   }
1413 }
1414 
TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) {
  // A chain of two elementwise ops followed by a non-elementwise one. An
  // elementwise op must not interfere with its operand, whereas the
  // non-elementwise op must. The entry parameter is expected not to interfere
  // with any instruction further down the chain.
  //
  //   param --> exp --> negate --> reverse
  //
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp));
  HloInstruction* reverse = entry_builder.AddInstruction(
      HloInstruction::CreateReverse(vector_shape_, negate, {0}));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering order(module_.get());

  // The parameter does not interfere with anything downstream.
  EXPECT_FALSE(InstructionsMayInterfere(order, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(order, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, param, reverse));

  // Negate is elementwise and so does not interfere with its operand.
  EXPECT_FALSE(InstructionsMayInterfere(order, exp, negate));
  // Reverse is not elementwise and so does interfere with its operand.
  EXPECT_TRUE(InstructionsMayInterfere(order, negate, reverse));
}
1447 
TEST_P(HloDataflowAnalysisTest, OverlappedValues) {
  // Values that are live at the same time (exp and negate) must interfere.
  //
  // param --> negate -> add
  //     \---> exp -----/
  //
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* add = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape_, HloOpcode::kAdd, negate,
                                   exp));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering order(module_.get());

  // The parameter has two users, so it interferes with each of them, but not
  // with add.
  EXPECT_TRUE(InstructionsMayInterfere(order, param, negate));
  EXPECT_TRUE(InstructionsMayInterfere(order, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(order, param, add));

  // Negate and exp interfere with one another; neither interferes with add.
  EXPECT_TRUE(InstructionsMayInterfere(order, negate, exp));
  EXPECT_TRUE(InstructionsMayInterfere(order, exp, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, negate, add));
  EXPECT_FALSE(InstructionsMayInterfere(order, add, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, exp, add));
  EXPECT_FALSE(InstructionsMayInterfere(order, add, exp));
}
1482 
TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
  // Same graph as the OverlappedValues test, but checked against a sequential
  // ordering of the HLO instructions.
  //
  // param --> negate -> add
  //     \---> exp -----/
  //
  // Sequential order:
  //  param, negate, exp, add
  //
  // The expectations match OverlappedValues except for (param, exp), which is
  // not expected to interfere here.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* add = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape_, HloOpcode::kAdd, negate,
                                   exp));

  HloComputation* entry =
      module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param, negate, exp, add});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering order(schedule);

  // Only the earlier-scheduled user (negate) interferes with the parameter.
  EXPECT_TRUE(InstructionsMayInterfere(order, param, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, param, exp));
  EXPECT_FALSE(InstructionsMayInterfere(order, param, add));

  // Negate and exp interfere with one another; neither interferes with add.
  EXPECT_TRUE(InstructionsMayInterfere(order, negate, exp));
  EXPECT_TRUE(InstructionsMayInterfere(order, exp, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, negate, add));
  EXPECT_FALSE(InstructionsMayInterfere(order, add, negate));
  EXPECT_FALSE(InstructionsMayInterfere(order, exp, add));
  EXPECT_FALSE(InstructionsMayInterfere(order, add, exp));
}
1525 
TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
  // Exercises interference between values that live in different
  // computations, using a call into an embedded computation.
  //
  // embedded_computation:
  //   %embedded_param = Param(0)
  //   %embedded_log = Log(%embedded_param)
  //
  // entry computation:
  //   %param = Param(0)
  //   %negate = Negate(%param)
  //   %exp = Exp(%param)
  //   %call = Call(embedded_computation, {%exp})
  //   %add = Add(%negate, %call)
  //
  // %negate stays live across the call and should therefore interfere with
  // the values inside the embedded computation.
  HloComputation::Builder embedded_builder(TestName() + "_embedded");
  HloInstruction* embedded_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "embedded_param"));
  HloInstruction* embedded_log = embedded_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog,
                                  embedded_param));
  HloComputation* embedded_computation =
      module_->AddEmbeddedComputation(embedded_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(vector_shape_, {exp}, embedded_computation));
  // The add is the entry root; no assertion refers to it directly.
  entry_builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, call));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering order(module_.get());

  // Exp's only use is the call, so it should not interfere with values inside
  // the embedded computation.
  EXPECT_FALSE(InstructionsMayInterfere(order, exp, embedded_log));

  // Negate is live across the call, so it should interfere with values inside
  // the embedded computation.
  EXPECT_TRUE(InstructionsMayInterfere(order, negate, embedded_log));
}
1577 
TEST_P(HloDataflowAnalysisTest,ConditionalWithIdentity)1578 TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
1579   // Test conditional with identity computations in both true and false cases.
1580   //
1581   // true_computation(F32[] %true_param):
1582   //   return %true_param
1583   //
1584   // false_computation(F32[] %false_param):
1585   //   return %false_param
1586   //
1587   // entry:
1588   //   %pred = Constant(true)
1589   //   %constant1 = Constant(56.0)
1590   //   %constant2 = Constant(12.0)
1591   //   return Conditional(%pred, %constant1, true_computation,
1592   //                      %constant2, false_computation)
1593 
1594   auto true_builder = HloComputation::Builder(TestName() + "_true");
1595   auto true_param = true_builder.AddInstruction(
1596       HloInstruction::CreateParameter(0, scalar_shape_, "true_param"));
1597   HloComputation* true_computation =
1598       module_->AddEmbeddedComputation(true_builder.Build());
1599 
1600   auto false_builder = HloComputation::Builder(TestName() + "_false");
1601   auto false_param = false_builder.AddInstruction(
1602       HloInstruction::CreateParameter(0, scalar_shape_, "false_param"));
1603   HloComputation* false_computation =
1604       module_->AddEmbeddedComputation(false_builder.Build());
1605 
1606   auto builder = HloComputation::Builder(TestName());
1607   auto pred = builder.AddInstruction(
1608       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1609   auto constant1 = builder.AddInstruction(
1610       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
1611   auto constant2 = builder.AddInstruction(
1612       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
1613   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
1614       scalar_shape_, pred, constant1, true_computation, constant2,
1615       false_computation));
1616   module_->AddEntryComputation(builder.Build());
1617   SCOPED_TRACE(module_->ToString());
1618 
1619   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
1620 
1621   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
1622   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
1623   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
1624 
1625   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
1626   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
1627 
1628   EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
1629             analysis.GetValueDefinedAt(constant1));
1630   EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
1631             analysis.GetValueDefinedAt(constant2));
1632 
1633   EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
1634               ElementsAre(HloUse{conditional, 0, {}}));
1635   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
1636               ElementsAre(HloUse{conditional, 1, {}}));
1637   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
1638               ElementsAre(HloUse{conditional, 2, {}}));
1639 
1640   bool ssa_form = GetParam();
1641   if (ssa_form) {
1642     EXPECT_EQ(analysis.values().size(), 4);
1643     EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
1644   } else {
1645     EXPECT_EQ(analysis.values().size(), 3);
1646     EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
1647     EXPECT_THAT(HloValuesAt(conditional),
1648                 UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
1649                                      analysis.GetValueDefinedAt(constant2)));
1650   }
1651 }
1652 
TEST_P(HloDataflowAnalysisTest,ConditionalTakingTupleOperand)1653 TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
1654   // Test conditional with true and false computations taking a tuple operand.
1655   //
1656   // true_computation((F32[], F32[]) %true_param):
1657   //   %true_x = GetTupleElement(%true_param, 0)
1658   //   %true_y = GetTupleElement(%true_param, 1)
1659   //   return Add(%true_x, %true_y)
1660   //
1661   // false_computation((F32[], F32[]) %false_param):
1662   //   %false_x = GetTupleElement(%false_param, 0)
1663   //   %false_y = GetTupleElement(%false_param, 1)
1664   //   return Subtract(%false_x, %false_y)
1665   //
1666   // entry:
1667   //   %pred = Constant(true)
1668   //   %constant1 = Constant(56.0)
1669   //   %constant2 = Constant(12.0)
1670   //   %tuple_operand = Tuple(%constant1, %constant2)
1671   //   return Conditional(%pred, %tuple_operand, true_computation,
1672   //                      %tuple_operand, false_computation)
1673 
1674   auto true_builder = HloComputation::Builder(TestName() + "_true");
1675   auto true_param = true_builder.AddInstruction(
1676       HloInstruction::CreateParameter(0, tuple_shape_, "true_param"));
1677   auto true_x = true_builder.AddInstruction(
1678       HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0));
1679   auto true_y = true_builder.AddInstruction(
1680       HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1));
1681   auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
1682       scalar_shape_, HloOpcode::kAdd, true_x, true_y));
1683   HloComputation* true_computation =
1684       module_->AddEmbeddedComputation(true_builder.Build());
1685 
1686   auto false_builder = HloComputation::Builder(TestName() + "_false");
1687   auto false_param = false_builder.AddInstruction(
1688       HloInstruction::CreateParameter(0, tuple_shape_, "false_param"));
1689   auto false_x = false_builder.AddInstruction(
1690       HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0));
1691   auto false_y = false_builder.AddInstruction(
1692       HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1));
1693   auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary(
1694       scalar_shape_, HloOpcode::kSubtract, false_x, false_y));
1695   HloComputation* false_computation =
1696       module_->AddEmbeddedComputation(false_builder.Build());
1697 
1698   auto builder = HloComputation::Builder(TestName());
1699   auto pred = builder.AddInstruction(
1700       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1701   auto constant1 = builder.AddInstruction(
1702       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
1703   auto constant2 = builder.AddInstruction(
1704       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
1705   auto tuple_operand = builder.AddInstruction(
1706       HloInstruction::CreateTuple({constant1, constant2}));
1707   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
1708       scalar_shape_, pred, tuple_operand, true_computation, tuple_operand,
1709       false_computation));
1710   module_->AddEntryComputation(builder.Build());
1711   SCOPED_TRACE(module_->ToString());
1712 
1713   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
1714 
1715   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
1716   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
1717   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
1718   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
1719   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
1720   EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));
1721 
1722   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
1723   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
1724   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x));
1725   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y));
1726   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x));
1727   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y));
1728 
1729   EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
1730             analysis.GetValueDefinedAt(tuple_operand));
1731   EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
1732             analysis.GetValueDefinedAt(tuple_operand));
1733   EXPECT_EQ(analysis.GetUniqueValueAt(true_x),
1734             analysis.GetValueDefinedAt(constant1));
1735   EXPECT_EQ(analysis.GetUniqueValueAt(true_y),
1736             analysis.GetValueDefinedAt(constant2));
1737   EXPECT_EQ(analysis.GetUniqueValueAt(false_x),
1738             analysis.GetValueDefinedAt(constant1));
1739   EXPECT_EQ(analysis.GetUniqueValueAt(false_y),
1740             analysis.GetValueDefinedAt(constant2));
1741 
1742   EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(),
1743               ElementsAre(HloUse{conditional, 0, {}}));
1744   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
1745               UnorderedElementsAre(HloUse{conditional, 1, {0}},
1746                                    HloUse{conditional, 2, {0}},
1747                                    HloUse{add, 0, {}}, HloUse{sub, 0, {}}));
1748   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(),
1749               UnorderedElementsAre(HloUse{conditional, 1, {1}},
1750                                    HloUse{conditional, 2, {1}},
1751                                    HloUse{add, 1, {}}, HloUse{sub, 1, {}}));
1752   EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).uses(),
1753               UnorderedElementsAre(
1754                   HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}},
1755                   HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}},
1756                   HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}}));
1757 
1758   bool ssa_form = GetParam();
1759   if (ssa_form) {
1760     EXPECT_EQ(analysis.values().size(), 7);
1761     EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
1762   } else {
1763     EXPECT_EQ(analysis.values().size(), 6);
1764     EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
1765     EXPECT_THAT(HloValuesAt(conditional),
1766                 UnorderedElementsAre(analysis.GetValueDefinedAt(add),
1767                                      analysis.GetValueDefinedAt(sub)));
1768   }
1769 }
1770 
TEST_P(HloDataflowAnalysisTest,NestedConditionals)1771 TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
1772   // computation1(F32[] %param1):
1773   //   %ceil = Ceil(%param1)
1774   //   return %ceil
1775   //
1776   // computation2(F32[] %param2):
1777   //   %floor = Floor(%param2)
1778   //   return %floor
1779   //
1780   // computation3(F32[] %param3):
1781   //   %negate = Negate(%param3)
1782   //   return %negate
1783   //
1784   // inner_conditional((PRED, F32[], F32[]) %param_cond):
1785   //   %pred_cond = GetTupleElement(%param_cond, 0)
1786   //   %true_operand_cond = GetTupleElement(%param_cond, 1)
1787   //   %false_opearnd_cond = GetTupleElement(%param_cond, 2)
1788   //   return Conditional(%pred_cond, %true_operand_cond, computation1,
1789   //                      %false_operand_cond, computation2)
1790   //
1791   // entry:
1792   //   %pred1 = Constant(true)
1793   //   %pred2 = Constant(false)
1794   //   %constant1 = Constant(1.1);
1795   //   %constant2 = Constant(2.2);
1796   //   %constant3 = Constant(3.3);
1797   //   return Conditional(%pred1, (%pred2, %constant1, %constant2),
1798   //                      inner_conditional, %constant3, computation3)
1799 
1800   auto computation1 = module_->AddEmbeddedComputation(
1801       CreateR0F32UnaryOpComputation(HloOpcode::kCeil));
1802   auto computation2 = module_->AddEmbeddedComputation(
1803       CreateR0F32UnaryOpComputation(HloOpcode::kFloor));
1804   auto computation3 = module_->AddEmbeddedComputation(
1805       CreateR0F32UnaryOpComputation(HloOpcode::kNegate));
1806 
1807   // Build inner_conditional computation.
1808   const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {});
1809   const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
1810       {scalar_bool_shape, scalar_shape_, scalar_shape_});
1811   auto inner_builder =
1812       HloComputation::Builder(TestName() + "_inner_conditional");
1813   auto param_cond = inner_builder.AddInstruction(
1814       HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond"));
1815   auto pred_cond = inner_builder.AddInstruction(
1816       HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0));
1817   auto true_operand_cond = inner_builder.AddInstruction(
1818       HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1));
1819   auto false_operand_cond = inner_builder.AddInstruction(
1820       HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2));
1821   auto inner_conditional =
1822       inner_builder.AddInstruction(HloInstruction::CreateConditional(
1823           scalar_shape_, pred_cond, true_operand_cond, computation1,
1824           false_operand_cond, computation2));
1825   auto inner_conditional_computation =
1826       module_->AddEmbeddedComputation(inner_builder.Build());
1827 
1828   // Build entry computation.
1829   auto builder = HloComputation::Builder(TestName());
1830   auto pred1 = builder.AddInstruction(
1831       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1832   auto pred2 = builder.AddInstruction(
1833       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1834   auto constant1 = builder.AddInstruction(
1835       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
1836   auto constant2 = builder.AddInstruction(
1837       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.2f)));
1838   auto constant3 = builder.AddInstruction(
1839       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.3f)));
1840   auto tuple_operand = builder.AddInstruction(
1841       HloInstruction::CreateTuple({pred2, constant1, constant2}));
1842   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
1843       scalar_shape_, pred1, tuple_operand, inner_conditional_computation,
1844       constant3, computation3));
1845   module_->AddEntryComputation(builder.Build());
1846   SCOPED_TRACE(module_->ToString());
1847 
1848   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
1849 
1850   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1));
1851   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2));
1852   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
1853   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
1854   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3));
1855   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
1856   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction()));
1857   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction()));
1858   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction()));
1859 
1860   auto computation1_param = computation1->parameter_instruction(0);
1861   auto computation2_param = computation2->parameter_instruction(0);
1862   auto computation3_param = computation3->parameter_instruction(0);
1863   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param));
1864   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param));
1865   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param));
1866   EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param),
1867             analysis.GetValueDefinedAt(constant1));
1868   EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param),
1869             analysis.GetValueDefinedAt(constant2));
1870   EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param),
1871             analysis.GetValueDefinedAt(constant3));
1872 
1873   EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond));
1874   EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond));
1875   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond));
1876   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond));
1877   EXPECT_EQ(analysis.GetUniqueValueAt(param_cond),
1878             analysis.GetValueDefinedAt(tuple_operand));
1879   EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond),
1880             analysis.GetValueDefinedAt(pred2));
1881   EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond),
1882             analysis.GetValueDefinedAt(constant1));
1883   EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond),
1884             analysis.GetValueDefinedAt(constant2));
1885 
1886   bool ssa_form = GetParam();
1887   if (ssa_form) {
1888     EXPECT_EQ(analysis.values().size(), 11);
1889     EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_conditional));
1890     EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
1891   } else {
1892     EXPECT_EQ(analysis.values().size(), 9);
1893     EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
1894     EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
1895     EXPECT_THAT(
1896         HloValuesAt(inner_conditional),
1897         UnorderedElementsAre(
1898             analysis.GetValueDefinedAt(computation1->root_instruction()),
1899             analysis.GetValueDefinedAt(computation2->root_instruction())));
1900     EXPECT_THAT(
1901         HloValuesAt(conditional),
1902         UnorderedElementsAre(
1903             analysis.GetValueDefinedAt(computation1->root_instruction()),
1904             analysis.GetValueDefinedAt(computation2->root_instruction()),
1905             analysis.GetValueDefinedAt(computation3->root_instruction())));
1906   }
1907 }
1908 
// Parses a small module whose root is an add-dependency and checks that the
// root does not define a value of its own. The analysis is run with its
// default options; the test parameter is not consulted here.
TEST_P(HloDataflowAnalysisTest, AddDependency) {
  const string hlo_text = R"(
HloModule AddDependency
ENTRY %AddDependency (p: f32[3]) -> f32[3] {
  %p = f32[3] parameter(0)
  %token0 = token[] after-all()
  ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseHloString(hlo_text, GetModuleConfigForTest()));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                          HloDataflowAnalysis::Run(*parsed_module));

  const HloInstruction* add_dep =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_EQ(add_dep->opcode(), HloOpcode::kAddDependency);

  // Exactly two values are expected (the parameter and the after-all token);
  // the add-dependency root must not contribute a third.
  EXPECT_EQ(dataflow->values().size(), 2);
  EXPECT_FALSE(dataflow->ValueIsDefinedAt(add_dep));
}
1932 
1933 INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation,
1934                          HloDataflowAnalysisTest,
1935                          ::testing::Values(false, true));
1936 
1937 class HloDataflowAnalysisTestBase : public HloTestBase {
1938  protected:
BuildModule(std::unique_ptr<HloComputation> computation)1939   void BuildModule(std::unique_ptr<HloComputation> computation) {
1940     module_ = CreateNewUnverifiedModule();
1941     computation_ = module_->AddEntryComputation(std::move(computation));
1942   }
1943 
RunAnalysis(const HloDataflowAnalysis::FusionCanShareBufferFunction & fusion_can_share_buffer=nullptr)1944   void RunAnalysis(const HloDataflowAnalysis::FusionCanShareBufferFunction&
1945                        fusion_can_share_buffer = nullptr) {
1946     CHECK_NOTNULL(module_.get());
1947     dataflow_analysis_ =
1948         HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false,
1949                                  /*bitcast_defines_value=*/false,
1950                                  fusion_can_share_buffer)
1951             .ConsumeValueOrDie();
1952   }
1953 
BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation)1954   void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
1955     BuildModule(std::move(computation));
1956     RunAnalysis();
1957   }
1958 
1959   std::unique_ptr<HloModule> module_;
1960   HloComputation* computation_ = nullptr;
1961   std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
1962 };
1963 
1964 class DoesNotUseOperandBufferTest : public HloDataflowAnalysisTestBase {};
1965 
// Checks which buffers of a two-element tuple parameter are reported as unused
// by get-tuple-element users of that parameter.
TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
  auto builder = HloComputation::Builder(TestName());

  const Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({elem_shape, elem_shape});
  HloInstruction* tuple_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, tuple_param, 0));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, tuple_param, 1));
  builder.AddInstruction(
      HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, first, second));

  BuildModuleAndRunAnalysis(builder.Build());

  // A get-tuple-element touches only the top-level buffer of its operand, so
  // the element buffers ({0}, {1}) count as unused while the top-level buffer
  // ({}) counts as used.
  EXPECT_TRUE(
      dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, first));
  EXPECT_TRUE(
      dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, second));
  EXPECT_FALSE(
      dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {}, first));
  EXPECT_FALSE(
      dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {}, second));
}
1988 
// Fuses a dynamic-update-slice of tuple element 1 into a loop fusion and
// checks which tuple element buffers the fusion is reported as not using.
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
  auto builder = HloComputation::Builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
  HloInstruction* elem1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));

  // Overwrite part of tuple element 1 with a dynamic-update-slice.
  HloInstruction* start_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* new_values =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, elem1, new_values,
          std::initializer_list<HloInstruction*>({start_index})));
  builder.AddInstruction(HloInstruction::CreateTuple({elem0, dus}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dus, start_index, new_values, elem1},
      HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // Element 0 is never touched by the fusion; element 1 is.
  EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
  EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
}
2022 
2023 // Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the
2024 // parameter tuple.
TEST_F(DoesNotUseOperandBufferTest,IndirectUses)2025 TEST_F(DoesNotUseOperandBufferTest, IndirectUses) {
2026   auto builder = HloComputation::Builder(TestName());
2027 
2028   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
2029   auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
2030       0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
2031   auto t0 = builder.AddInstruction(
2032       HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
2033   auto t1 = builder.AddInstruction(
2034       HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));
2035   // Swap the tuple elements.
2036   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0}));
2037 
2038   auto gte0 = builder.AddInstruction(
2039       HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
2040   auto gte1 = builder.AddInstruction(
2041       HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
2042 
2043   // Create a DynamicUpdateSlice instruction of tuple element 1.
2044   auto starts = builder.AddInstruction(
2045       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
2046   auto update = builder.AddInstruction(HloInstruction::CreateConstant(
2047       LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
2048   auto dynamic_update_slice =
2049       builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2050           data_shape, gte1, update,
2051           std::initializer_list<HloInstruction*>({starts})));
2052   builder.AddInstruction(
2053       HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
2054 
2055   BuildModule(builder.Build());
2056   auto fusion = computation_->CreateFusionInstruction(
2057       {dynamic_update_slice, starts, update, gte1},
2058       HloInstruction::FusionKind::kLoop);
2059   RunAnalysis();
2060 
2061   // The fusion instruction never uses tuple element 0, but does use element 1.
2062   EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
2063   EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
2064   // The same holds for the parameter tuple, except that the tuple elements are
2065   // swapped in 'tuple'.
2066   EXPECT_TRUE(
2067       dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
2068   EXPECT_FALSE(
2069       dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion));
2070 }
2071 
2072 class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {};
2073 
// A chain of unary elementwise ops over one shape: each op is expected to be
// able to share its operand's buffer.
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
  auto builder = HloComputation::Builder(TestName());

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  HloInstruction* exp_op = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, input));
  HloInstruction* log_op = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kLog, exp_op));

  BuildModuleAndRunAnalysis(builder.Build());

  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(input, {}, exp_op, {}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(exp_op, {},
                                                                log_op, {}));
}
2092 
// A loop fusion of negate followed by reverse is expected not to be able to
// share the buffer of its operand.
TEST_F(CanShareOperandBufferWithUserTest,
       NonElementwiseLoopFusionCantAliasOperandBuffer) {
  auto builder = HloComputation::Builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "param0"));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kNegate, input));
  HloInstruction* reversed = builder.AddInstruction(
      HloInstruction::CreateReverse(matrix_shape, negated, {0, 1}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {reversed, negated}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(input, {}, fusion, {}));
}
2115 
// Builds a loop fusion whose root is a tuple of two copies (one per
// parameter) and checks that either parameter may share its buffer with
// either output of the fusion.
TEST_F(CanShareOperandBufferWithUserTest,
       MultiOutputFusionCanAliasOperandBuffer) {
  auto builder = HloComputation::Builder(TestName());

  Shape in_shape = ShapeUtil::MakeShape(F32, {8});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, in_shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, in_shape, "param1"));

  auto copy0 = builder.AddInstruction(
      HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0));
  auto copy1 = builder.AddInstruction(
      HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1));

  // Note the order: output 0 is the copy of param1, output 1 the copy of
  // param0.
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0}));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {0}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {1}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
                                                                fusion, {0}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
                                                                fusion, {1}));
}
2150 
// Fuses negate and exp into a loop fusion and checks buffer sharing with the
// fusion's operand.
//
// NOTE(review): despite the "Cant" in the test name, the assertion below
// expects that the operand buffer CAN be shared. The name is kept unchanged so
// that existing test filters keep working.
TEST_F(CanShareOperandBufferWithUserTest,
       ElementwiseLoopFusionCantAliasOperandBuffer) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // `one` is a scalar (rank 0), so it has no dimensions to map into the
  // output shape: the broadcast dimension list must be empty.
  auto operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand));

  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {exp, neg}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
                                                                fusion, {}));
}
2175 
// Fuses a dynamic-slice of `param` together with a dynamic-update-slice that
// writes the slice back into `param` at the same start indices, and checks
// that the fusion may share the buffer of `param`.
TEST_F(CanShareOperandBufferWithUserTest,
       CanShareOperandWhenDynamicUpdateSliceIsFedByDynamicSliceWithSameIndex) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
  Shape slice_shape = ShapeUtil::MakeShape(F32, {1, 2});

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "param0"));
  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(0)));
  // One slice size per operand dimension; the sizes match `slice_shape`.
  auto ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, param, {zero, zero}, {1, 2}));

  auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      data_shape, param, ds, {zero, zero}));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {dus, ds, zero}, HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {}));
}
2200 
// The fused dynamic-slice reads at (p1, p2, p3) while the fused
// dynamic-update-slice writes at (p1, p3, p2); with differing indices the
// fusion is expected not to share the buffer of parameter 0.
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithDifferentIndices) {
  const char* kHloText = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30}
      ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p3, p2)
    }

    ENTRY test {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kHloText));
  HloComputation* entry = module_->entry_computation();
  auto* fusion_root = entry->root_instruction();
  auto* data_param = entry->parameter_instruction(0);

  RunAnalysis();
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      data_param, {}, fusion_root, {}));
}
2230 
// The fused dynamic-slice and dynamic-update-slice both use the indices
// (p1, p2, p3); with identical indices the fusion is expected to be able to
// share the buffer of parameter 0.
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) {
  const char* kHloText = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30}
      ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p2, p3)
    }

    ENTRY test {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kHloText));
  HloComputation* entry = module_->entry_computation();
  auto* fusion_root = entry->root_instruction();
  auto* data_param = entry->parameter_instruction(0);

  RunAnalysis();
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
      data_param, {}, fusion_root, {}));
}
2260 
// An elementwise compare whose PRED result shape differs from its F32 operand
// shape: neither operand buffer is expected to be shareable with the result.
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
  auto builder = HloComputation::Builder(TestName());

  const Shape operand_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape result_shape = ShapeUtil::MakeShape(PRED, {8});
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param0"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, operand_shape, "param1"));
  HloInstruction* compare =
      builder.AddInstruction(HloInstruction::CreateCompare(
          result_shape, lhs, rhs, ComparisonDirection::kEq));

  BuildModuleAndRunAnalysis(builder.Build());

  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(lhs, {}, compare, {}));
  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(rhs, {}, compare, {}));
}
2280 
// An exp followed by a copy, all of one shape: both users are expected to be
// able to share their operand's buffer.
TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
  auto builder = HloComputation::Builder(TestName());

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  HloInstruction* exp_op = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, input));
  HloInstruction* copy_op = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kCopy, exp_op));

  BuildModuleAndRunAnalysis(builder.Build());

  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(input, {}, exp_op, {}));
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(exp_op, {},
                                                                copy_op, {}));
}
2299 
// Fuses a dynamic-update-slice of tuple element 1 into a loop fusion and
// checks which tuple element buffer the fusion may share.
TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
  auto builder = HloComputation::Builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
  HloInstruction* elem1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));

  // Overwrite part of tuple element 1 with a dynamic-update-slice.
  HloInstruction* start_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* new_values =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, elem1, new_values,
          std::initializer_list<HloInstruction*>({start_index})));
  builder.AddInstruction(HloInstruction::CreateTuple({elem0, dus}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {dus, start_index, new_values, elem1},
      HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  // Only tuple element 1 (the one being updated) may be shared with the
  // fusion; element 0 may not.
  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {0}, fusion, {}));
  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {1}, fusion, {}));
}
2335 
// Like FusedDynamicUpdateSlice, but the updated element is converted to BF16
// before the dynamic-update-slice and back to F32 afterwards, with both
// converts inside the fusion. The fusion is still expected to be able to share
// the buffer of the F32 input.
TEST_F(CanShareOperandBufferWithUserTest,
       FusedDynamicUpdateSliceWithConvertCanShare) {
  auto builder = HloComputation::Builder(TestName());

  const Shape f32_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape bf16_shape = ShapeUtil::MakeShape(BF16, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({f32_shape, f32_shape});
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0));
  HloInstruction* elem1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, tuple, 1));

  HloInstruction* to_bf16 = builder.AddInstruction(
      HloInstruction::CreateConvert(bf16_shape, elem1));

  // Overwrite part of the converted element 1 with a dynamic-update-slice.
  // NOTE(review): `new_values` is an F32 literal while the updated operand is
  // BF16 -- confirm this element-type mismatch is intentional.
  HloInstruction* start_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
  HloInstruction* new_values =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          bf16_shape, to_bf16, new_values,
          std::initializer_list<HloInstruction*>({start_index})));

  HloInstruction* to_f32 =
      builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, dus));
  builder.AddInstruction(HloInstruction::CreateTuple({elem0, to_f32}));

  BuildModule(builder.Build());
  HloInstruction* fusion = computation_->CreateFusionInstruction(
      {to_f32, dus, start_index, new_values, to_bf16},
      HloInstruction::FusionKind::kLoop);
  RunAnalysis();

  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(elem1, {}, fusion, {}));
}
2375 
// An unfused dynamic-update-slice: only the operand being updated is expected
// to be shareable with the result, not the update or the start indices.
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
  auto builder = HloComputation::Builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape update_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape starts_shape = ShapeUtil::MakeShape(S32, {1});
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* update = builder.AddInstruction(
      HloInstruction::CreateParameter(1, update_shape, "update"));
  HloInstruction* start0 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, starts_shape, "start0"));
  HloInstruction* start1 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, starts_shape, "start1"));

  // NOTE(review): two start operands are supplied although `data` is rank 1;
  // confirm this is intentional.
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, data, update, {start0, start1}));

  BuildModuleAndRunAnalysis(builder.Build());

  // Sharing is allowed with `data` only.
  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {}));
  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {}));
  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(start0, {}, dus, {}));
  EXPECT_FALSE(
      dataflow_analysis_->CanShareOperandBufferWithUser(start1, {}, dus, {}));
}
2407 
// A scatter parsed from HLO text: only the `operand` parameter is expected to
// be shareable with the scatter result, not `indices` or `updates`.
TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
  const char* kHloText = R"(
    HloModule TensorFlowScatterV1

    update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
      lhs = s32[] parameter(0)
      ROOT rhs = s32[] parameter(1)
    }

    ENTRY main {
      operand = s32[3,3] parameter(0)
      indices = s32[2] parameter(1)
      updates = s32[2,3] parameter(2)
      ROOT scatter = s32[3,3] scatter(operand, indices, updates),
          to_apply=update_s32,
          update_window_dims={1},
          inserted_window_dims={0},
          scatter_dims_to_operand_dims={0},
          index_vector_dim=1
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(kHloText));
  computation_ = module_->entry_computation();
  RunAnalysis();

  HloInstruction* scatter_root = computation_->root_instruction();
  HloInstruction* operand = computation_->parameter_instruction(0);
  HloInstruction* indices = computation_->parameter_instruction(1);
  HloInstruction* updates = computation_->parameter_instruction(2);

  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
      operand, {}, scatter_root, {}));
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      indices, {}, scatter_root, {}));
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      updates, {}, scatter_root, {}));
}
2445 
// A single-operand sort is expected to be able to share its keys buffer.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
  module_ = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  TF_ASSERT_OK_AND_ASSIGN(
      auto* sort_op, MakeSortHlo(keys_shape, {keys}, -1, /*is_stable=*/false,
                                 &builder, module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();

  EXPECT_TRUE(
      dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort_op, {}));
}
2463 
// A two-operand sort producing a tuple: each operand is expected to be
// shareable only with the tuple entry at its own position.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
  module_ = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape values_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values"));
  TF_ASSERT_OK_AND_ASSIGN(
      auto* sort_op,
      MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}),
                  {keys, values}, 0, /*is_stable=*/false, &builder,
                  module_.get()));

  computation_ = module_->AddEntryComputation(builder.Build());
  RunAnalysis();

  // Keys pair with tuple entry 0.
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(keys, {},
                                                                sort_op, {0}));
  // Values pair with tuple entry 1.
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(values, {},
                                                                sort_op, {1}));
  // The crossed pairings must be rejected.
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(keys, {},
                                                                 sort_op, {1}));
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      values, {}, sort_op, {0}));
}
2495 
// Fuses a dot and the add that consumes it into an output fusion and checks
// that the fusion may share the buffer of the add's other operand.
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto a = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  auto b = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  // Contract dimension 1 of the lhs with dimension 0 of the rhs.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  // One default precision entry per dot operand.
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  auto dot = builder.AddInstruction(
      HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // `one` is a scalar (rank 0), so it has no dimensions to map into the
  // output shape: the broadcast dimension list must be empty.
  auto add_operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, dot, add_operand));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {add, dot}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // Output fused dot add should be able to share buffer with 'add_operand'.
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(add_operand, {},
                                                                fusion, {}));
}
2531 
// Builds add(reverse(operand), two), output-fuses {add, two, reverse}, and
// checks that the fusion is not allowed to reuse the buffer of 'operand'.
TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // 'one' is a rank-0 constant, so it has no dimensions to map into the
  // result: the broadcast dimension list must be empty.
  auto operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(data_shape, operand, {0, 1}));

  auto two = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {add, two, reverse}, HloInstruction::FusionKind::kOutput);
  RunAnalysis();

  // Output fused operand->reverse->add cannot alias operand buffer 'operand'.
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
                                                                 fusion, {}));
}
2559 
// Runs the analysis with a caller-supplied fusion_can_share_buffer callback
// and checks that the query on a kInput fusion comes back false when the
// callback only returns true for kLoop fusions.
TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // 'one' is a rank-0 constant, so it has no dimensions to map into the
  // result: the broadcast dimension list must be empty.
  auto operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kMultiply, operand, operand));
  auto two = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two));

  BuildModule(builder.Build());
  auto fusion = computation_->CreateFusionInstruction(
      {add, two, mul}, HloInstruction::FusionKind::kInput);
  // The callback's first parameter is named differently from the local
  // 'fusion' above to avoid shadowing it.
  RunAnalysis(/*fusion_can_share_buffer=*/[](const HloInstruction* fusion_instr,
                                             const HloInstruction*) {
    return fusion_instr->fusion_kind() == HloInstruction::FusionKind::kLoop;
  });

  // 'fusion' was created as kInput, for which the callback returns false.
  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
                                                                 fusion, {}));
}
2586 
// A while instruction whose init value is the entry parameter: the while is
// expected to be able to share the buffer of that operand.
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});

  module_ = CreateNewUnverifiedModule();

  // Condition computation: compares its parameter with itself for equality.
  HloComputation* cond_computation = nullptr;
  {
    auto cond_builder = HloComputation::Builder(TestName() + ".Cond");
    HloInstruction* cond_param = cond_builder.AddInstruction(
        HloInstruction::CreateParameter(0, data_shape, "data"));
    cond_builder.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {}), cond_param, cond_param,
        ComparisonDirection::kEq));
    cond_computation = module_->AddEmbeddedComputation(cond_builder.Build());
  }

  // Body computation: adds its parameter to itself.
  HloComputation* body_computation = nullptr;
  {
    auto body_builder = HloComputation::Builder(TestName() + ".Body");
    HloInstruction* body_param = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, data_shape, "data"));
    body_builder.AddInstruction(HloInstruction::CreateBinary(
        data_shape, HloOpcode::kAdd, body_param, body_param));
    body_computation = module_->AddEmbeddedComputation(body_builder.Build());
  }

  // Entry computation: while(data) using the two computations above.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* data = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* while_instr =
      entry_builder.AddInstruction(HloInstruction::CreateWhile(
          data_shape, cond_computation, body_computation, data));
  computation_ = module_->AddEntryComputation(entry_builder.Build());

  RunAnalysis();

  // The While instruction can share with the data operand.
  const ShapeIndex top_level;
  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(
      data, top_level, while_instr, top_level));
}
2627 
2628 // Tests that Call can alias operand buffer if the only use of the operand
2629 // in the called computation is an elementwise instruction.
TEST_F(CanShareOperandBufferWithUserTest,CallToComputationWithFusionRoot)2630 TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
2631   Shape shape = ShapeUtil::MakeShape(F32, {8});
2632   // Build sub-computation with fusion root.
2633   auto sub_builder = HloComputation::Builder(TestName() + "_sub");
2634   auto sub_param = sub_builder.AddInstruction(
2635       HloInstruction::CreateParameter(0, shape, "sub_param"));
2636   auto one = sub_builder.AddInstruction(
2637       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2638   auto ones = sub_builder.AddInstruction(
2639       HloInstruction::CreateBroadcast(shape, one, {1}));
2640   auto add = sub_builder.AddInstruction(
2641       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
2642 
2643   module_ = CreateNewUnverifiedModule();
2644   auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
2645   sub_computation->CreateFusionInstruction({add, ones},
2646                                            HloInstruction::FusionKind::kLoop);
2647 
2648   // Build entry-computation with kCall which calls 'sub_computation'.
2649   auto builder = HloComputation::Builder(TestName());
2650 
2651   auto param = builder.AddInstruction(
2652       HloInstruction::CreateParameter(0, shape, "param"));
2653   auto reverse =
2654       builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
2655   auto call = builder.AddInstruction(
2656       HloInstruction::CreateCall(shape, {reverse}, sub_computation));
2657   computation_ = module_->AddEntryComputation(builder.Build());
2658 
2659   RunAnalysis();
2660 
2661   EXPECT_TRUE(
2662       dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {}));
2663 }
2664 
2665 }  // namespace
2666 }  // namespace xla
2667