1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
17 
18 #include <map>
19 #include <memory>
20 
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
23 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
26 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/test.h"
29 #include "tensorflow/compiler/xla/test_helpers.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace xla {
37 namespace {
38 
39 using ::testing::UnorderedElementsAre;
40 
41 class HloAliasAnalysisTest : public HloTestBase {
42  protected:
HloAliasAnalysisTest()43   HloAliasAnalysisTest() : HloTestBase() {
44     module_ = CreateNewVerifiedModule();
45   }
46 
47   // Run alias analysis on the member module. For convenience returns a
48   // reference to the generated analysis stored in analysis_.
RunAnalysis()49   HloAliasAnalysis& RunAnalysis() {
50     analysis_ = HloAliasAnalysis::Run(module_.get(),
51                                       /*fusion_can_share_buffer=*/nullptr)
52                     .ConsumeValueOrDie();
53     return *analysis_;
54   }
55 
56   // Return a vector of the buffers in the buffer set at the current position
57   // sorted by buffer id.
GetBuffersAt(const HloInstruction * instruction,const ShapeIndex & index={}) const58   std::vector<HloBuffer> GetBuffersAt(const HloInstruction* instruction,
59                                       const ShapeIndex& index = {}) const {
60     std::set<HloBuffer::Id> buffer_ids;
61     for (const HloValue* value : analysis_->dataflow_analysis()
62                                      .GetValueSet(instruction, index)
63                                      .values()) {
64       buffer_ids.insert(analysis_->GetBufferContainingValue(*value).id());
65     }
66 
67     std::vector<HloBuffer> buffers;
68     for (HloBuffer::Id id : buffer_ids) {
69       buffers.push_back(analysis_->GetBuffer(id));
70     }
71     return buffers;
72   }
73 
74   // Return a vector containing all of the HloValues in the given buffer.
GetValuesInBuffer(const HloBuffer & buffer)75   std::vector<HloValue> GetValuesInBuffer(const HloBuffer& buffer) {
76     std::vector<HloValue> values;
77     for (const HloValue* value : buffer.values()) {
78       values.push_back(*value);
79     }
80     return values;
81   }
82 
83   // Return the HloValue defined at the given position.
GetValueDefinedAt(const HloInstruction * instruction,const ShapeIndex & index={}) const84   const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
85                                     const ShapeIndex& index = {}) const {
86     return analysis_->dataflow_analysis().GetValueDefinedAt(instruction, index);
87   }
88 
89   // Returns true if any values held in the same buffer interfere. Generally, in
90   // the compiler pipeline copy-insertion will guarantee that this interference
91   // never occurs, but HLO graphs with interference can be explicitly
92   // constructed.
AnyValuesInSameBufferInterfere()93   bool AnyValuesInSameBufferInterfere() {
94     DependencyHloOrdering ordering(module_.get());
95     for (const HloBuffer& buffer : analysis_->buffers()) {
96       for (const HloValue* value_a : buffer.values()) {
97         for (const HloValue* value_b : buffer.values()) {
98           if (*value_a != *value_b &&
99               ordering.MayInterfere(*value_a, *value_b,
100                                     analysis_->dataflow_analysis())) {
101             VLOG(1) << *value_a << " interferes with " << *value_b
102                     << " in buffer: " << buffer;
103             return true;
104           }
105         }
106       }
107     }
108     return false;
109   }
110 
111   std::unique_ptr<HloModule> module_;
112   std::unique_ptr<HloAliasAnalysis> analysis_;
113 
114   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
115 };
116 
TEST_F(HloAliasAnalysisTest, BinaryOperation) {
  // A lone Add of two constants. Each of the three instructions should get its
  // own buffer holding exactly the value it defines.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, lhs, rhs));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  // One buffer per instruction.
  EXPECT_EQ(analysis.buffers().size(), 3);

  // Every buffer is trivial: a single buffer holding a single value.
  for (const HloInstruction* instruction : {lhs, rhs, sum}) {
    const HloBuffer& buffer = analysis.GetUniqueBufferAt(instruction);
    EXPECT_EQ(buffer.GetUniqueValue(), GetValueDefinedAt(instruction));
  }

  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(sum));
  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(sum));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
145 
TEST_F(HloAliasAnalysisTest, TupleAndGtes) {
  // Build a Tuple of two parameters and read both elements back with
  // GetTupleElement. Check that each element shares a buffer with its
  // parameter and with the GTE that extracts it.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
  builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  EXPECT_EQ(analysis.buffers().size(), 4);

  // The top-level tuple buffer holds the tuple's own value, while its elements
  // hold the values defined by the parameters.
  EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{}).GetUniqueValue(),
            GetValueDefinedAt(tuple, /*index=*/{}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{0}).GetUniqueValue(),
            GetValueDefinedAt(param0));
  EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{1}).GetUniqueValue(),
            GetValueDefinedAt(param1));

  // The operand, the tuple element, and the GTE result all resolve to a single
  // buffer, which appears at exactly those three positions.
  const HloBuffer& param0_buffer = analysis.GetUniqueBufferAt(param0);
  EXPECT_EQ(param0_buffer, analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
  EXPECT_EQ(param0_buffer, analysis.GetUniqueBufferAt(gte0));
  EXPECT_THAT(param0_buffer.ComputePositions(),
              UnorderedElementsAre(HloPosition{param0, {}},
                                   HloPosition{tuple, {0}},
                                   HloPosition{gte0, {}}));

  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(tuple));
  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(tuple));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
194 
TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
  // Test an expression with a non-distinct buffer set: the same parameter is
  // placed at two different indices of one tuple.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  // param0 feeds both element 0 and element 2.
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({param0, param1, param0}));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  EXPECT_THAT(analysis.GetUniqueBufferAt(param0).ComputePositions(),
              UnorderedElementsAre(HloPosition{param0, {}},
                                   HloPosition{tuple, {0}},
                                   HloPosition{tuple, {2}}));

  // The tuple's buffers are expected to be unambiguous but not distinct,
  // since param0's buffer shows up twice.
  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(tuple));
  EXPECT_FALSE(analysis.InstructionBuffersAreDistinct(tuple));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
220 
TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) {
  // The entry computation negates both elements of its tuple parameter and
  // returns the results in a new tuple. Each output element is aliased with
  // the parameter element at the same index, so each GTE must end up sharing
  // a buffer with the matching output element.
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));

  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));

  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  auto& alias_config = module_->input_output_alias_config();
  const auto kUserAlias = HloInputOutputAliasConfig::AliasKind::kUserAlias;
  TF_ASSERT_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/kUserAlias));
  TF_ASSERT_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
      /*kind=*/kUserAlias));

  // Output {1} already has an alias, so a second one must be rejected.
  ASSERT_IS_NOT_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/kUserAlias));

  const HloAliasAnalysis& analysis = RunAnalysis();

  EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
}
263 
TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) {
  // The entry computation returns the elements of its single tuple parameter
  // unchanged. The output aliasing crosses the elements over:
  //
  //   param 0, index {0}  ---->  output index {1}
  //   param 0, index {1}  ---->  output index {0}
  //
  // (The diagram deliberately avoids lines ending in a backslash, which would
  // splice the next line into the // comment and trigger -Wcomment.)
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  auto gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
  auto gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  // Output {1} already has an alias, so a second one must be rejected.
  ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  const HloAliasAnalysis& analysis = RunAnalysis();

  // The direct flow (gteN -> output {N}) and the cross aliasing together merge
  // every op in this graph into a single buffer.
  EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));

  EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
}
314 
TEST_F(HloAliasAnalysisTest,InputOutputAliasingWithWhile)315 TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) {
316   // Test a simple single while instruction can be aliased with input and output
317   // of the computation.
318   //
319   // body((F32[], F32[]) %tuple_param):
320   //   %add = Add(%tuple_param{0}, %tuple_param{1})
321   //   return Tuple(%tuple_param{0}, %add)
322   //
323   // condition((F32[], F32[]) %tuple_param):
324   //   return Constant(false)
325   //
326   // entry:
327   //   %param1 = param1
328   //   %while = While(%param1, body, condition)
329   //   %while_1 = GTE(%while, 0)
330   //   %while_2 = GTE(%while, 1)
331   //   %negate_1 = Negate(%while_1)
332   //   %negate_2 = Negate(%while_2)
333   //   return Tuple(negate_1, negate_2)
334   //
335   const Shape tuple_shape =
336       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
337 
338   // Element 0 passes transparently through the body.
339   auto body_builder = HloComputation::Builder("body");
340   auto body_param = body_builder.AddInstruction(
341       HloInstruction::CreateParameter(0, tuple_shape, "param"));
342   auto body_element_0 = body_builder.AddInstruction(
343       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
344   auto body_element_1 = body_builder.AddInstruction(
345       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
346   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
347       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
348   auto body_tuple = body_builder.AddInstruction(
349       HloInstruction::CreateTuple({body_element_0, add}));
350   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
351 
352   // Condition computation trivially returns a constant "false".
353   auto cond_builder = HloComputation::Builder("condition");
354   auto cond_param = cond_builder.AddInstruction(
355       HloInstruction::CreateParameter(0, tuple_shape, "param"));
356   cond_builder.AddInstruction(
357       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
358   HloComputation* condition =
359       module_->AddEmbeddedComputation(cond_builder.Build());
360 
361   auto builder = HloComputation::Builder(TestName());
362   auto param = builder.AddInstruction(
363       HloInstruction::CreateParameter(0, tuple_shape, "p0"));
364 
365   auto xla_while = builder.AddInstruction(
366       HloInstruction::CreateWhile(tuple_shape, condition, body, param));
367   auto while_element_1 = builder.AddInstruction(
368       HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0));
369   auto while_element_2 = builder.AddInstruction(
370       HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1));
371   auto negate_1 = builder.AddInstruction(HloInstruction::CreateUnary(
372       scalar_shape_, HloOpcode::kNegate, while_element_1));
373   auto negate_2 = builder.AddInstruction(HloInstruction::CreateUnary(
374       scalar_shape_, HloOpcode::kNegate, while_element_2));
375   auto tuple =
376       builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2}));
377   module_->AddEntryComputation(builder.Build());
378   SCOPED_TRACE(module_->ToString());
379 
380   TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
381       /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
382       /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
383   TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
384       /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
385       /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
386 
387   const HloAliasAnalysis& analysis = RunAnalysis();
388 
389   EXPECT_THAT(
390       GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
391       UnorderedElementsAre(GetValueDefinedAt(param, {1}),
392                            GetValueDefinedAt(xla_while, /*index=*/{1}),
393                            GetValueDefinedAt(body_param, {1}),
394                            GetValueDefinedAt(cond_param, {1}),
395                            GetValueDefinedAt(add),
396                            GetValueDefinedAt(negate_2)));
397 
398   EXPECT_THAT(
399       analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(),
400       UnorderedElementsAre(
401           HloPosition{param, {1}}, HloPosition{xla_while, {1}},
402           HloPosition{while_element_2, {}}, HloPosition{body_param, {1}},
403           HloPosition{body_element_1, {}}, HloPosition{add, {}},
404           HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}},
405           HloPosition{cond_param, {1}}, HloPosition{negate_2, {}}));
406 
407   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
408 }
409 
TEST_F(HloAliasAnalysisTest, SingleCall) {
  // One kCall of a subcomputation that adds its two scalar parameters. Each
  // call operand should share a buffer with the corresponding subcomputation
  // parameter, and the call should share a buffer with the subcomputation
  // root.
  HloComputation::Builder subbuilder("Subcomputation");
  HloInstruction* subparam0 = subbuilder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* subparam1 = subbuilder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
  HloComputation* called_computation =
      module_->AddEmbeddedComputation(subbuilder.Build());

  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* call = builder.AddInstruction(HloInstruction::CreateCall(
      scalar_shape_, {constant1, constant2}, called_computation));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  // Arguments flow into the matching subcomputation parameters.
  EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).ComputePositions(),
              UnorderedElementsAre(HloPosition{constant1, {}},
                                   HloPosition{subparam0, {}}));
  EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).ComputePositions(),
              UnorderedElementsAre(HloPosition{constant2, {}},
                                   HloPosition{subparam1, {}}));

  // The subcomputation's root result becomes the call's result.
  EXPECT_THAT(analysis.GetUniqueBufferAt(add).ComputePositions(),
              UnorderedElementsAre(HloPosition{add, {}},
                                   HloPosition{call, {}}));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
450 
TEST_F(HloAliasAnalysisTest,ComputationCalledTwice)451 TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
452   // Test a subcomputation which is called twice with different argument values.
453   auto subbuilder = HloComputation::Builder("Subcomputation");
454   auto subparam0 = subbuilder.AddInstruction(
455       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
456   auto subparam1 = subbuilder.AddInstruction(
457       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
458   auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary(
459       scalar_shape_, HloOpcode::kAdd, subparam0, subparam1));
460   HloComputation* called_computation =
461       module_->AddEmbeddedComputation(subbuilder.Build());
462 
463   auto builder = HloComputation::Builder(TestName());
464   auto constant1 = builder.AddInstruction(
465       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
466   auto constant2 = builder.AddInstruction(
467       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
468   auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
469       scalar_shape_, {constant1, constant2}, called_computation));
470   auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
471       scalar_shape_, {call1, constant2}, called_computation));
472   module_->AddEntryComputation(builder.Build());
473   SCOPED_TRACE(module_->ToString());
474 
475   const HloAliasAnalysis& analysis = RunAnalysis();
476 
477   EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).ComputePositions(),
478               UnorderedElementsAre(HloPosition{constant1, {}},
479                                    HloPosition{subparam0, {}}));
480   EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).ComputePositions(),
481               UnorderedElementsAre(HloPosition{constant2, {}},
482                                    HloPosition{subparam1, {}}));
483 
484   // The 'add' (root of the subcomputation) aliases the two call instruction,
485   // and the first parameter of the subcomputation because 'call1' it is passed
486   // as an argument to the subcomputation in 'call2'.
487   EXPECT_THAT(
488       analysis.GetUniqueBufferAt(add).ComputePositions(),
489       UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call1, {}},
490                            HloPosition{subparam0, {}}, HloPosition{call2, {}}));
491 
492   EXPECT_THAT(GetBuffersAt(subparam0),
493               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
494                                    analysis.GetUniqueBufferAt(add)));
495   EXPECT_THAT(GetBuffersAt(subparam1),
496               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant2)));
497 
498   EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(subparam0));
499   EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(subparam1));
500   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(subparam0));
501   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(subparam1));
502 
503   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
504 }
505 
TEST_F(HloAliasAnalysisTest,SingleWhile)506 TEST_F(HloAliasAnalysisTest, SingleWhile) {
507   // Test a simple single while instruction. The while body includes a
508   // pass-through value. HLO:
509   //
510   // body((F32[], F32[]) %tuple_param):
511   //   %add = Add(%tuple_param{0}, %tuple_param{1})
512   //   return Tuple(%tuple_param{0}, %add)
513   //
514   // condition((F32[], F32[]) %tuple_param):
515   //   return Constant(false)
516   //
517   // entry:
518   //   %constant1 = Constant(1.0)
519   //   %constant2 = Constant(2.0)
520   //   %tuple = Tuple(%constant1, %constant2)
521   //   return While(%tuple, body, condition)
522   //
523   const Shape tuple_shape =
524       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
525 
526   // Element 0 passes transparently through the body.
527   auto body_builder = HloComputation::Builder("body");
528   auto body_param = body_builder.AddInstruction(
529       HloInstruction::CreateParameter(0, tuple_shape, "param"));
530   auto body_element_0 = body_builder.AddInstruction(
531       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
532   auto body_element_1 = body_builder.AddInstruction(
533       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
534   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
535       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
536   auto body_tuple = body_builder.AddInstruction(
537       HloInstruction::CreateTuple({body_element_0, add}));
538   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
539 
540   // Condition computation trivially returns a constant "false".
541   auto cond_builder = HloComputation::Builder("condition");
542   auto cond_param = cond_builder.AddInstruction(
543       HloInstruction::CreateParameter(0, tuple_shape, "param"));
544   cond_builder.AddInstruction(
545       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
546   HloComputation* condition =
547       module_->AddEmbeddedComputation(cond_builder.Build());
548 
549   auto builder = HloComputation::Builder(TestName());
550   auto constant1 = builder.AddInstruction(
551       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
552   auto constant2 = builder.AddInstruction(
553       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
554   auto tuple = builder.AddInstruction(
555       HloInstruction::CreateTuple({constant1, constant2}));
556   auto xla_while = builder.AddInstruction(
557       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
558   module_->AddEntryComputation(builder.Build());
559   SCOPED_TRACE(module_->ToString());
560 
561   const HloAliasAnalysis& analysis = RunAnalysis();
562 
563   // Verify the positions of the aliased while buffers.
564   EXPECT_THAT(
565       analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).ComputePositions(),
566       UnorderedElementsAre(HloPosition{tuple, {}}, HloPosition{xla_while, {}},
567                            HloPosition{body_param, {}},
568                            HloPosition{body_tuple, {}},
569                            HloPosition{cond_param, {}}));
570   EXPECT_THAT(
571       analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).ComputePositions(),
572       UnorderedElementsAre(
573           HloPosition{constant1, {}}, HloPosition{tuple, {0}},
574           HloPosition{xla_while, {0}}, HloPosition{body_param, {0}},
575           HloPosition{body_element_0, {}}, HloPosition{body_tuple, {0}},
576           HloPosition{cond_param, {0}}));
577   EXPECT_THAT(
578       analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(),
579       UnorderedElementsAre(
580           HloPosition{constant2, {}}, HloPosition{tuple, {1}},
581           HloPosition{xla_while, {1}}, HloPosition{body_param, {1}},
582           HloPosition{body_element_1, {}}, HloPosition{add, {}},
583           HloPosition{body_tuple, {1}}, HloPosition{cond_param, {1}}));
584 
585   EXPECT_THAT(
586       GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})),
587       UnorderedElementsAre(GetValueDefinedAt(constant1)));
588   EXPECT_THAT(
589       GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
590       UnorderedElementsAre(GetValueDefinedAt(constant2),
591                            GetValueDefinedAt(xla_while, /*index=*/{1}),
592                            GetValueDefinedAt(body_param, {1}),
593                            GetValueDefinedAt(cond_param, {1}),
594                            GetValueDefinedAt(add)));
595 
596   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
597 }
598 
TEST_F(HloAliasAnalysisTest,SequentialWhiles)599 TEST_F(HloAliasAnalysisTest, SequentialWhiles) {
600   // Test sequential while instructions. The while body includes a
601   // pass-through value. HLO:
602   //
603   // body((F32[], F32[]) %tuple_param):
604   //   %add = Add(%tuple_param{0}, %tuple_param{1})
605   //   return Tuple(%tuple_param{0}, %add)
606   //
607   // condition((F32[], F32[]) %tuple_param):
608   //   return Constant(false)
609   //
610   // entry:
611   //   %constant1 = Constant(1.0)
612   //   %constant2 = Constant(2.0)
613   //   %tuple = Tuple(%constant1, %constant2)
614   //   %while0 = While(%tuple, body, condition)
615   //   %while1 = While(%while0, body, condition)
616   //   return While(%while1, body, condition)
617   //
618   const Shape tuple_shape =
619       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
620 
621   // Element 0 passes transparently through the body.
622   auto body_builder = HloComputation::Builder("body");
623   auto body_param = body_builder.AddInstruction(
624       HloInstruction::CreateParameter(0, tuple_shape, "param"));
625   auto body_element_0 = body_builder.AddInstruction(
626       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
627   auto body_element_1 = body_builder.AddInstruction(
628       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
629   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
630       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
631   body_builder.AddInstruction(
632       HloInstruction::CreateTuple({body_element_0, add}));
633   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
634 
635   auto cond_builder = HloComputation::Builder("condition");
636   cond_builder.AddInstruction(
637       HloInstruction::CreateParameter(0, tuple_shape, "param"));
638   cond_builder.AddInstruction(
639       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
640   HloComputation* condition =
641       module_->AddEmbeddedComputation(cond_builder.Build());
642 
643   auto builder = HloComputation::Builder(TestName());
644   auto constant1 = builder.AddInstruction(
645       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
646   auto constant2 = builder.AddInstruction(
647       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
648   auto tuple = builder.AddInstruction(
649       HloInstruction::CreateTuple({constant1, constant2}));
650   auto xla_while0 = builder.AddInstruction(
651       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
652   auto xla_while1 = builder.AddInstruction(
653       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0));
654   auto xla_while2 = builder.AddInstruction(
655       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1));
656   module_->AddEntryComputation(builder.Build());
657 
658   FlattenCallGraph flattener;
659   TF_ASSERT_OK(flattener.Run(module_.get()).status());
660   SCOPED_TRACE(module_->ToString());
661 
662   const HloAliasAnalysis& analysis = RunAnalysis();
663 
664   EXPECT_EQ(analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
665             analysis.GetUniqueBufferAt(xla_while2, /*index=*/{}));
666   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
667             analysis.GetUniqueBufferAt(xla_while2, /*index=*/{0}));
668   EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
669             analysis.GetUniqueBufferAt(xla_while2, /*index=*/{1}));
670 }
671 
TEST_F(HloAliasAnalysisTest,NestedWhiles)672 TEST_F(HloAliasAnalysisTest, NestedWhiles) {
673   // Test nested while instructions. The inner body passes through element 0 of
674   // its parameter, and the outer body passes through element 1.  HLO:
675   //
676   // inner_body((F32[], F32[]) %tuple_param):
677   //   %add = Add(%tuple_param{0}, %tuple_param{1})
678   //   return Tuple(%tuple_param{0}, %add)
679   //
680   // outer_body((F32[], F32[]) %tuple_param):
681   //   %negate = Negate(%tuple_param{0})
682   //   %tuple = Tuple(%negate, %tuple_param{1})
683   //   return While(%tuple, inner_body, condition)
684   //
685   // entry:
686   //   %constant1 = Constant(1.0)
687   //   %constant2 = Constant(2.0)
688   //   %tuple = Tuple(%constant1, %constant2)
689   //   return While(%tuple, outer_body, condition)
690   //
691   const Shape tuple_shape =
692       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
693 
694   auto build_cond_computation = [&tuple_shape]() {
695     auto cond_builder = HloComputation::Builder("condition");
696     cond_builder.AddInstruction(
697         HloInstruction::CreateParameter(0, tuple_shape, "param"));
698     cond_builder.AddInstruction(
699         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
700     return cond_builder.Build();
701   };
702   // Build separate condition computations so the call graph is flat. The
703   // callgraph is always flattened in the compiler pipeline, and the flattened
704   // callgraph enables representative interference analysis.
705   HloComputation* condition1 =
706       module_->AddEmbeddedComputation(build_cond_computation());
707   HloComputation* condition2 =
708       module_->AddEmbeddedComputation(build_cond_computation());
709 
710   // Element 0 passes transparently through the body.
711   auto inner_builder = HloComputation::Builder("inner_body");
712   auto inner_param = inner_builder.AddInstruction(
713       HloInstruction::CreateParameter(0, tuple_shape, "param"));
714   auto inner_element_0 = inner_builder.AddInstruction(
715       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0));
716   auto inner_element_1 = inner_builder.AddInstruction(
717       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1));
718   auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
719       scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1));
720   inner_builder.AddInstruction(
721       HloInstruction::CreateTuple({inner_element_0, add}));
722   HloComputation* inner_body =
723       module_->AddEmbeddedComputation(inner_builder.Build());
724 
725   // Element 1 passes transparently through the body.
726   auto outer_builder = HloComputation::Builder("outer_body");
727   auto outer_param = outer_builder.AddInstruction(
728       HloInstruction::CreateParameter(0, tuple_shape, "param"));
729   auto outer_element_0 = outer_builder.AddInstruction(
730       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
731   auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
732       scalar_shape_, HloOpcode::kNegate, outer_element_0));
733   auto outer_element_1 = outer_builder.AddInstruction(
734       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
735   auto outer_tuple = outer_builder.AddInstruction(
736       HloInstruction::CreateTuple({negate, outer_element_1}));
737   auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
738       tuple_shape, condition1, inner_body, outer_tuple));
739   HloComputation* outer_body =
740       module_->AddEmbeddedComputation(outer_builder.Build());
741 
742   auto builder = HloComputation::Builder(TestName());
743   auto constant1 = builder.AddInstruction(
744       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
745   auto constant2 = builder.AddInstruction(
746       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
747   auto tuple = builder.AddInstruction(
748       HloInstruction::CreateTuple({constant1, constant2}));
749   auto entry_while = builder.AddInstruction(
750       HloInstruction::CreateWhile(tuple_shape, condition2, outer_body, tuple));
751   module_->AddEntryComputation(builder.Build());
752   SCOPED_TRACE(module_->ToString());
753 
754   const HloAliasAnalysis& analysis = RunAnalysis();
755 
756   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
757             analysis.GetUniqueBufferAt(entry_while, /*index=*/{0}));
758   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
759             analysis.GetUniqueBufferAt(nested_while, /*index=*/{0}));
760   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
761             analysis.GetUniqueBufferAt(inner_element_0));
762 
763   EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
764             analysis.GetUniqueBufferAt(entry_while, /*index=*/{1}));
765   EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
766             analysis.GetUniqueBufferAt(nested_while, /*index=*/{1}));
767   EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
768             analysis.GetUniqueBufferAt(inner_element_1));
769 
770   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
771 }
772 
TEST_F(HloAliasAnalysisTest,SwizzlingWhile)773 TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
774   // Test a while instruction with a body which permutes it's tuple parameter
775   // elements. HLO:
776   //
777   // body((F32[], F32[], F32[]) %tuple_param):
778   //   return Tuple(%tuple_param{1}, %tuple_param{2}, %tuple_param{0})
779   //
780   // condition((F32[], F32[]) %tuple_param):
781   //   return Constant(false)
782   //
783   // entry:
784   //   %constant1 = Constant(1.0)
785   //   %constant2 = Constant(2.0)
786   //   %constant3 = Constant(3.0)
787   //   %tuple = Tuple(%constant1, %constant2, %constant3)
788   //   return While(%tuple, body, condition)
789   //
790   const Shape tuple_shape =
791       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_});
792 
793   auto body_builder = HloComputation::Builder("body");
794   auto body_param = body_builder.AddInstruction(
795       HloInstruction::CreateParameter(0, tuple_shape, "param"));
796   auto body_element_0 = body_builder.AddInstruction(
797       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
798   auto body_element_1 = body_builder.AddInstruction(
799       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
800   auto body_element_2 = body_builder.AddInstruction(
801       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 2));
802   body_builder.AddInstruction(HloInstruction::CreateTuple(
803       {body_element_1, body_element_2, body_element_0}));
804   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
805 
806   auto cond_builder = HloComputation::Builder("condition");
807   cond_builder.AddInstruction(
808       HloInstruction::CreateParameter(0, tuple_shape, "param"));
809   auto cond_constant = cond_builder.AddInstruction(
810       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
811   HloComputation* condition =
812       module_->AddEmbeddedComputation(cond_builder.Build());
813 
814   auto builder = HloComputation::Builder(TestName());
815   auto constant1 = builder.AddInstruction(
816       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
817   auto constant2 = builder.AddInstruction(
818       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
819   auto constant3 = builder.AddInstruction(
820       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
821   auto tuple = builder.AddInstruction(
822       HloInstruction::CreateTuple({constant1, constant2, constant3}));
823   auto xla_while = builder.AddInstruction(
824       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
825   module_->AddEntryComputation(builder.Build());
826   SCOPED_TRACE(module_->ToString());
827 
828   const HloAliasAnalysis& analysis = RunAnalysis();
829 
830   // The swizzling while makes most positions in the module alias leaving only 3
831   // HloBuffers.
832   EXPECT_THAT(
833       analysis.buffers(),
834       UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
835                            analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
836                            analysis.GetUniqueBufferAt(cond_constant)));
837 
838   // The tuple elements of the while and the three constant inputs should all be
839   // smooshed into the same buffer.
840   EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
841             analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}));
842   EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
843             analysis.GetUniqueBufferAt(xla_while, /*index=*/{2}));
844   EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
845             analysis.GetUniqueBufferAt(constant1));
846   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
847             analysis.GetUniqueBufferAt(constant2));
848   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
849             analysis.GetUniqueBufferAt(constant3));
850 
851   // All elements in of the loop state tuple are forced into the same buffer
852   // resulting liveness interference.
853   EXPECT_TRUE(AnyValuesInSameBufferInterfere());
854 }
855 
TEST_F(HloAliasAnalysisTest,TupleSelect)856 TEST_F(HloAliasAnalysisTest, TupleSelect) {
857   // Test a kTupleSelect. Non-top-level element flow through the instruction.
858   auto builder = HloComputation::Builder(TestName());
859   auto pred = builder.AddInstruction(
860       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
861   auto constant1 = builder.AddInstruction(
862       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
863   auto constant2 = builder.AddInstruction(
864       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
865   auto constant3 = builder.AddInstruction(
866       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
867   auto constant4 = builder.AddInstruction(
868       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4.0)));
869   auto tuple1 =
870       builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
871   auto tuple2 =
872       builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
873   auto tuple3 =
874       builder.AddInstruction(HloInstruction::CreateTuple({constant3}));
875   auto tuple4 =
876       builder.AddInstruction(HloInstruction::CreateTuple({constant4}));
877   const Shape tuple_shape = tuple1->shape();
878   auto select11 = builder.AddInstruction(HloInstruction::CreateTernary(
879       tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1));
880   auto select12 = builder.AddInstruction(HloInstruction::CreateTernary(
881       tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
882   auto select34 = builder.AddInstruction(HloInstruction::CreateTernary(
883       tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4));
884   auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary(
885       tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34));
886 
887   module_->AddEntryComputation(builder.Build());
888   SCOPED_TRACE(module_->ToString());
889 
890   const HloAliasAnalysis& analysis = RunAnalysis();
891 
892   // Verify the buffer sets of each select.
893   EXPECT_THAT(GetBuffersAt(select11, /*index=*/{0}),
894               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1)));
895   EXPECT_THAT(GetBuffersAt(select12, /*index=*/{0}),
896               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
897                                    analysis.GetUniqueBufferAt(constant2)));
898   EXPECT_THAT(GetBuffersAt(select34, /*index=*/{0}),
899               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant3),
900                                    analysis.GetUniqueBufferAt(constant4)));
901   EXPECT_THAT(GetBuffersAt(select1234, /*index=*/{0}),
902               UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
903                                    analysis.GetUniqueBufferAt(constant2),
904                                    analysis.GetUniqueBufferAt(constant3),
905                                    analysis.GetUniqueBufferAt(constant4)));
906 
907   EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select11));
908   EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select12));
909   EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select34));
910   EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select1234));
911 
912   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select11));
913   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select12));
914   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select34));
915   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select1234));
916 
917   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
918 }
919 
TEST_F(HloAliasAnalysisTest,TupleSelectToWhile)920 TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) {
921   // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO:
922   //
923   // body((F32[], F32[]) %tuple_param):
924   //   %negate = Negate(%tuple_param{0})
925   //   return Tuple(%negate)
926   //
927   // condition((F32[], F32[]) %tuple_param):
928   //   return Constant(false)
929   //
930   // entry:
931   //   %constant1 = Constant(1.0)
932   //   %constant2 = Constant(2.0)
933   //   %tuple1 = Tuple(%constant1)
934   //   %tuple2 = Tuple(%constant2)
935   //   %select = Select(%tuple1, %tuple2)
936   //   return While(%select, body, condition)
937   //
938   auto builder = HloComputation::Builder(TestName());
939 
940   const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_});
941 
942   // Element 0 passes transparently through the body.
943   auto body_builder = HloComputation::Builder("body");
944   auto body_param = body_builder.AddInstruction(
945       HloInstruction::CreateParameter(0, tuple_shape, "param"));
946   auto body_element = body_builder.AddInstruction(
947       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
948   auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
949       scalar_shape_, HloOpcode::kNegate, body_element));
950   body_builder.AddInstruction(HloInstruction::CreateTuple({negate}));
951   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
952 
953   auto cond_builder = HloComputation::Builder("condition");
954   auto cond_param = cond_builder.AddInstruction(
955       HloInstruction::CreateParameter(0, tuple_shape, "param"));
956   cond_builder.AddInstruction(
957       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
958   HloComputation* condition =
959       module_->AddEmbeddedComputation(cond_builder.Build());
960 
961   auto pred = builder.AddInstruction(
962       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
963   auto constant1 = builder.AddInstruction(
964       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
965   auto constant2 = builder.AddInstruction(
966       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
967   auto tuple1 =
968       builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
969   auto tuple2 =
970       builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
971   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
972       tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
973   auto xla_while = builder.AddInstruction(
974       HloInstruction::CreateWhile(tuple_shape, condition, body, select));
975 
976   module_->AddEntryComputation(builder.Build());
977   SCOPED_TRACE(module_->ToString());
978 
979   const HloAliasAnalysis& analysis = RunAnalysis();
980 
981   // The while should flatten the ambiguous select buffer set so that the buffer
982   // set contents (constant1 and constant2) becomes a single buffer.
983   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
984             analysis.GetUniqueBufferAt(constant2));
985   EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
986             analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}));
987 
988   EXPECT_THAT(GetValuesInBuffer(analysis.GetUniqueBufferAt(constant1)),
989               UnorderedElementsAre(GetValueDefinedAt(constant1),
990                                    GetValueDefinedAt(constant2),
991                                    GetValueDefinedAt(xla_while, /*index=*/{0}),
992                                    GetValueDefinedAt(body_param, /*index=*/{0}),
993                                    GetValueDefinedAt(cond_param, /*index=*/{0}),
994                                    GetValueDefinedAt(negate)));
995   EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select));
996   EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(xla_while));
997 
998   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select));
999   EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(xla_while));
1000 
1001   // The two operands of the select get flattened into the same buffer resulting
1002   // in liveness interference.
1003   EXPECT_TRUE(AnyValuesInSameBufferInterfere());
1004 }
1005 
TEST_F(HloAliasAnalysisTest, Bitcast) {
  // A bitcast must reuse its operand's buffer rather than introduce a new one.
  HloComputation::Builder builder(TestName());
  HloInstruction* const constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* const bitcast = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kBitcast,
                                  constant));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  // The whole module contains exactly one buffer...
  EXPECT_EQ(analysis.buffers().size(), 1);

  // ...and both the constant and its bitcast live in it.
  const auto& constant_buffer = analysis.GetUniqueBufferAt(constant);
  EXPECT_EQ(constant_buffer, analysis.GetUniqueBufferAt(bitcast));
}
1024 
TEST_F(HloAliasAnalysisTest, BitcastInterference) {
  // The constant and its bitcast are both operands of the same tuple, so they
  // are live at the same time. That overlap must not be reported as live-range
  // interference.
  HloComputation::Builder builder(TestName());
  HloInstruction* const source = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* const reinterpreted = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kBitcast, source));
  builder.AddInstruction(HloInstruction::CreateTuple({source, reinterpreted}));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  DependencyHloOrdering dependency_order(module_.get());
  EXPECT_FALSE(analysis.HasLiveRangeInterference(dependency_order));
}
1043 
TEST_F(HloAliasAnalysisTest,WhileInterference)1044 TEST_F(HloAliasAnalysisTest, WhileInterference) {
1045   // Build a while loop which has a parallel use of the init value. Depending on
1046   // ordering there may be interference between the update-in-place while and
1047   // the other use of the init.
1048   auto builder = HloComputation::Builder(TestName());
1049   auto init = builder.AddInstruction(
1050       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1051 
1052   auto cond_builder = HloComputation::Builder("condition");
1053   auto cond_param = cond_builder.AddInstruction(
1054       HloInstruction::CreateParameter(0, init->shape(), "param"));
1055   auto cond_root = cond_builder.AddInstruction(
1056       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1057   HloComputation* condition =
1058       module_->AddEmbeddedComputation(cond_builder.Build());
1059 
1060   auto body_builder = HloComputation::Builder("body");
1061   auto body_param = body_builder.AddInstruction(
1062       HloInstruction::CreateParameter(0, init->shape(), "param"));
1063   auto body_root = body_builder.AddInstruction(
1064       HloInstruction::CreateUnary(init->shape(), HloOpcode::kExp, body_param));
1065   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
1066 
1067   auto xla_while = builder.AddInstruction(
1068       HloInstruction::CreateWhile(init->shape(), condition, body, init));
1069 
1070   auto negate = builder.AddInstruction(
1071       HloInstruction::CreateUnary(init->shape(), HloOpcode::kNegate, init));
1072   auto entry_root =
1073       builder.AddInstruction(HloInstruction::CreateTuple({negate, xla_while}));
1074 
1075   HloComputation* entry = module_->AddEntryComputation(builder.Build());
1076   SCOPED_TRACE(module_->ToString());
1077 
1078   const HloAliasAnalysis& analysis = RunAnalysis();
1079 
1080   {
1081     // Dependency ordering should interfere because the negate and while are
1082     // unordered.
1083     DependencyHloOrdering ordering(module_.get());
1084     EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
1085   }
1086 
1087   // For a sequential order, if there is interference iff the negate is after
1088   // the while.
1089   HloSchedule schedule(module_.get());
1090   schedule.set_sequence(body, {body_param, body_root});
1091   schedule.set_sequence(condition, {cond_param, cond_root});
1092   {
1093     schedule.set_sequence(entry, {init, xla_while, negate, entry_root});
1094     TF_ASSERT_OK(schedule.Verify());
1095     SequentialHloOrdering ordering(schedule);
1096     EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
1097   }
1098 
1099   {
1100     schedule.set_sequence(entry, {init, negate, xla_while, entry_root});
1101     TF_ASSERT_OK(schedule.Verify());
1102     SequentialHloOrdering ordering(schedule);
1103     EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
1104   }
1105 }
1106 
1107 }  // namespace
1108 }  // namespace xla
1109