1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"

#include <map>
#include <memory>
#include <set>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
35
36 namespace xla {
37 namespace {
38
39 using ::testing::UnorderedElementsAre;
40
41 class HloAliasAnalysisTest : public HloTestBase {
42 protected:
HloAliasAnalysisTest()43 HloAliasAnalysisTest() : HloTestBase() {
44 module_ = CreateNewVerifiedModule();
45 }
46
47 // Run alias analysis on the member module. For convenience returns a
48 // reference to the generated analysis stored in analysis_.
RunAnalysis()49 HloAliasAnalysis& RunAnalysis() {
50 analysis_ = HloAliasAnalysis::Run(module_.get(),
51 /*fusion_can_share_buffer=*/nullptr)
52 .ConsumeValueOrDie();
53 return *analysis_;
54 }
55
56 // Return a vector of the buffers in the buffer set at the current position
57 // sorted by buffer id.
GetBuffersAt(const HloInstruction * instruction,const ShapeIndex & index={}) const58 std::vector<HloBuffer> GetBuffersAt(const HloInstruction* instruction,
59 const ShapeIndex& index = {}) const {
60 std::set<HloBuffer::Id> buffer_ids;
61 for (const HloValue* value : analysis_->dataflow_analysis()
62 .GetValueSet(instruction, index)
63 .values()) {
64 buffer_ids.insert(analysis_->GetBufferContainingValue(*value).id());
65 }
66
67 std::vector<HloBuffer> buffers;
68 for (HloBuffer::Id id : buffer_ids) {
69 buffers.push_back(analysis_->GetBuffer(id));
70 }
71 return buffers;
72 }
73
74 // Return a vector containing all of the HloValues in the given buffer.
GetValuesInBuffer(const HloBuffer & buffer)75 std::vector<HloValue> GetValuesInBuffer(const HloBuffer& buffer) {
76 std::vector<HloValue> values;
77 for (const HloValue* value : buffer.values()) {
78 values.push_back(*value);
79 }
80 return values;
81 }
82
83 // Return the HloValue defined at the given position.
GetValueDefinedAt(const HloInstruction * instruction,const ShapeIndex & index={}) const84 const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
85 const ShapeIndex& index = {}) const {
86 return analysis_->dataflow_analysis().GetValueDefinedAt(instruction, index);
87 }
88
89 // Returns true if any values held in the same buffer interfere. Generally, in
90 // the compiler pipeline copy-insertion will guarantee that this interference
91 // never occurs, but HLO graphs with interference can be explicitly
92 // constructed.
AnyValuesInSameBufferInterfere()93 bool AnyValuesInSameBufferInterfere() {
94 DependencyHloOrdering ordering(module_.get());
95 for (const HloBuffer& buffer : analysis_->buffers()) {
96 for (const HloValue* value_a : buffer.values()) {
97 for (const HloValue* value_b : buffer.values()) {
98 if (*value_a != *value_b &&
99 ordering.MayInterfere(*value_a, *value_b,
100 analysis_->dataflow_analysis())) {
101 VLOG(1) << *value_a << " interferes with " << *value_b
102 << " in buffer: " << buffer;
103 return true;
104 }
105 }
106 }
107 }
108 return false;
109 }
110
111 std::unique_ptr<HloModule> module_;
112 std::unique_ptr<HloAliasAnalysis> analysis_;
113
114 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
115 };
116
TEST_F(HloAliasAnalysisTest, BinaryOperation) {
  // A single Add of two constants. Nothing is aliased, so each instruction is
  // expected to own exactly one buffer holding exactly the value it defines.
  HloComputation::Builder builder(TestName());
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* const sum = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, lhs, rhs));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  EXPECT_EQ(analysis.buffers().size(), 3);

  for (const HloInstruction* inst : {lhs, rhs, sum}) {
    const HloBuffer& buffer = analysis.GetUniqueBufferAt(inst);
    EXPECT_EQ(buffer.GetUniqueValue(), GetValueDefinedAt(inst));
  }

  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(sum));
  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(sum));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
145
TEST_F(HloAliasAnalysisTest, TupleAndGtes) {
  // Two parameters are packed into a Tuple and unpacked again with
  // GetTupleElement. Each element's buffer should be shared by the parameter,
  // the matching tuple slot, and the GTE that extracts it.
  HloComputation::Builder builder(TestName());
  HloInstruction* const p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* const p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* const packed =
      builder.AddInstruction(HloInstruction::CreateTuple({p0, p1}));
  HloInstruction* const first = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, packed, 0));
  HloInstruction* const second = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, packed, 1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, first, second));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  EXPECT_EQ(analysis.buffers().size(), 4);

  // The tuple defines only its top-level value; its elements hold the values
  // defined by the parameters.
  EXPECT_EQ(analysis.GetUniqueBufferAt(packed, /*index=*/{}).GetUniqueValue(),
            GetValueDefinedAt(packed, /*index=*/{}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(packed, /*index=*/{0}).GetUniqueValue(),
            GetValueDefinedAt(p0));
  EXPECT_EQ(analysis.GetUniqueBufferAt(packed, /*index=*/{1}).GetUniqueValue(),
            GetValueDefinedAt(p1));

  // Parameter 0, tuple slot {0}, and the first GTE all name one buffer.
  const HloBuffer& p0_buffer = analysis.GetUniqueBufferAt(p0);
  EXPECT_EQ(p0_buffer, analysis.GetUniqueBufferAt(packed, /*index=*/{0}));
  EXPECT_EQ(p0_buffer, analysis.GetUniqueBufferAt(first));

  EXPECT_THAT(p0_buffer.ComputePositions(),
              UnorderedElementsAre(HloPosition{p0, {}},
                                   HloPosition{packed, {0}},
                                   HloPosition{first, {}}));

  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(packed));
  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(packed));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
194
TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
  // A tuple that lists the same operand twice. That operand's buffer appears
  // at two tuple indices, so the tuple's buffers are not distinct.
  HloComputation::Builder builder(TestName());
  HloInstruction* const p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* const p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  // p0 occupies both index 0 and index 2.
  HloInstruction* const repeated =
      builder.AddInstruction(HloInstruction::CreateTuple({p0, p1, p0}));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  const HloBuffer& p0_buffer = analysis.GetUniqueBufferAt(p0);
  EXPECT_THAT(p0_buffer.ComputePositions(),
              UnorderedElementsAre(HloPosition{p0, {}},
                                   HloPosition{repeated, {0}},
                                   HloPosition{repeated, {2}}));

  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(repeated));
  EXPECT_FALSE(analysis.InstructionBuffersAreDistinct(repeated));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
220
TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) {
  // Element i of the tuple parameter is negated and returned as output element
  // i, and an input/output alias is declared for each such pair. The analysis
  // must then merge each parameter element's buffer with the matching output
  // element's buffer.
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder builder(TestName());
  HloInstruction* const input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* const in0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* const in1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));

  HloInstruction* const neg0 = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, in0));
  HloInstruction* const neg1 = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, in1));

  HloInstruction* const output =
      builder.AddInstruction(HloInstruction::CreateTuple({neg0, neg1}));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  // Ties output element `output_index` to element `param_index` of
  // parameter 0.
  auto alias_output_to_param = [this](const ShapeIndex& output_index,
                                      const ShapeIndex& param_index) {
    return module_->input_output_alias_config().SetUpAlias(
        output_index, /*param_number=*/0, param_index,
        /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias);
  };

  TF_ASSERT_OK(alias_output_to_param(/*output_index=*/{0},
                                     /*param_index=*/{0}));
  TF_ASSERT_OK(alias_output_to_param(/*output_index=*/{1},
                                     /*param_index=*/{1}));

  // Output {1} is already aliased, so a second alias for it is rejected.
  ASSERT_IS_NOT_OK(alias_output_to_param(/*output_index=*/{1},
                                         /*param_index=*/{0}));

  const HloAliasAnalysis& analysis = RunAnalysis();

  EXPECT_EQ(analysis.GetUniqueBufferAt(in0),
            analysis.GetUniqueBufferAt(output, /*index=*/{0}));

  EXPECT_EQ(analysis.GetUniqueBufferAt(in1),
            analysis.GetUniqueBufferAt(output, /*index=*/{1}));
}
263
TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) {
  // The entry parameter is a pair whose elements are forwarded unchanged to
  // the output pair (element 0 -> output {0}, element 1 -> output {1}). The
  // input/output alias config, however, ties them crosswise:
  //
  //   parameter 0, index {0}  <->  output {1}
  //   parameter 0, index {1}  <->  output {0}
  //
  // (The diagram avoids trailing backslashes: a `//` comment ending in '\'
  // splices the next line into the comment and triggers -Wcomment.)
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  auto gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
  auto gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  // Cannot alias an output twice.
  ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  const HloAliasAnalysis& analysis = RunAnalysis();

  // The straight forwarding and the crosswise aliasing chain together, so
  // every op in this graph shares a single buffer.
  EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));

  EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
}
314
TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) {
  // Tests that entry input/output aliasing is propagated through a single
  // while instruction whose results feed the aliased outputs. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %p0 = Parameter(0)
  //   %while = While(%p0, body, condition)
  //   %while_1 = GTE(%while, 0)
  //   %while_2 = GTE(%while, 1)
  //   %negate_1 = Negate(%while_1)
  //   %negate_2 = Negate(%while_2)
  //   return Tuple(%negate_1, %negate_2)
  //
  // Output {0} is aliased with %p0{0} and output {1} with %p0{1}.
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Element 0 passes transparently through the body.
  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto body_element_0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  auto body_element_1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
  auto body_tuple = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_element_0, add}));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // Condition computation trivially returns a constant "false".
  auto cond_builder = HloComputation::Builder("condition");
  auto cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  auto xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, body, param));
  auto while_element_1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0));
  auto while_element_2 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1));
  auto negate_1 = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kNegate, while_element_1));
  auto negate_2 = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kNegate, while_element_2));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2}));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  const HloAliasAnalysis& analysis = RunAnalysis();

  // Because output {1} (%negate_2) is aliased with %p0{1}, the buffer of the
  // while's element 1 also absorbs %negate_2.
  EXPECT_THAT(
      GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
      UnorderedElementsAre(GetValueDefinedAt(param, {1}),
                           GetValueDefinedAt(xla_while, /*index=*/{1}),
                           GetValueDefinedAt(body_param, {1}),
                           GetValueDefinedAt(cond_param, {1}),
                           GetValueDefinedAt(add),
                           GetValueDefinedAt(negate_2)));

  EXPECT_THAT(
      analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(),
      UnorderedElementsAre(
          HloPosition{param, {1}}, HloPosition{xla_while, {1}},
          HloPosition{while_element_2, {}}, HloPosition{body_param, {1}},
          HloPosition{body_element_1, {}}, HloPosition{add, {}},
          HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}},
          HloPosition{cond_param, {1}}, HloPosition{negate_2, {}}));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
409
TEST_F(HloAliasAnalysisTest, SingleCall) {
  // One kCall of a subcomputation that adds its two scalar parameters. Call
  // arguments should alias the callee's parameters, and the callee's root
  // should alias the call itself.
  HloComputation::Builder callee_builder("Subcomputation");
  HloInstruction* const callee_lhs = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* const callee_rhs = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* const callee_sum =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, callee_lhs, callee_rhs));
  HloComputation* const callee =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder builder(TestName());
  HloInstruction* const arg0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* const arg1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* const call = builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {arg0, arg1}, callee));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  // Each argument shares a buffer with the parameter it binds to.
  const HloBuffer& arg0_buffer = analysis.GetUniqueBufferAt(arg0);
  const HloBuffer& arg1_buffer = analysis.GetUniqueBufferAt(arg1);
  EXPECT_THAT(arg0_buffer.ComputePositions(),
              UnorderedElementsAre(HloPosition{arg0, {}},
                                   HloPosition{callee_lhs, {}}));
  EXPECT_THAT(arg1_buffer.ComputePositions(),
              UnorderedElementsAre(HloPosition{arg1, {}},
                                   HloPosition{callee_rhs, {}}));

  // The callee's root and the call instruction share a buffer.
  const HloBuffer& sum_buffer = analysis.GetUniqueBufferAt(callee_sum);
  EXPECT_THAT(sum_buffer.ComputePositions(),
              UnorderedElementsAre(HloPosition{callee_sum, {}},
                                   HloPosition{call, {}}));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
450
TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
  // The same Add subcomputation is called twice with different arguments: the
  // result of the first call becomes the first argument of the second call.
  HloComputation::Builder callee_builder("Subcomputation");
  HloInstruction* const lhs_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* const rhs_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* const sum =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, lhs_param, rhs_param));
  HloComputation* const adder =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder builder(TestName());
  HloInstruction* const one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* const two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* const first_call = builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {one, two}, adder));
  HloInstruction* const second_call = builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {first_call, two}, adder));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  const HloBuffer& one_buffer = analysis.GetUniqueBufferAt(one);
  const HloBuffer& two_buffer = analysis.GetUniqueBufferAt(two);
  EXPECT_THAT(one_buffer.ComputePositions(),
              UnorderedElementsAre(HloPosition{one, {}},
                                   HloPosition{lhs_param, {}}));
  EXPECT_THAT(two_buffer.ComputePositions(),
              UnorderedElementsAre(HloPosition{two, {}},
                                   HloPosition{rhs_param, {}}));

  // The subcomputation root shares a buffer with both call instructions. It
  // also shares one with the first parameter, because `first_call` is passed
  // as that parameter in `second_call`.
  const HloBuffer& sum_buffer = analysis.GetUniqueBufferAt(sum);
  EXPECT_THAT(sum_buffer.ComputePositions(),
              UnorderedElementsAre(HloPosition{sum, {}},
                                   HloPosition{first_call, {}},
                                   HloPosition{lhs_param, {}},
                                   HloPosition{second_call, {}}));

  // The first parameter can therefore hold either of two buffers.
  EXPECT_THAT(GetBuffersAt(lhs_param),
              UnorderedElementsAre(one_buffer, sum_buffer));
  EXPECT_THAT(GetBuffersAt(rhs_param), UnorderedElementsAre(two_buffer));

  EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(lhs_param));
  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(rhs_param));
  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(lhs_param));
  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(rhs_param));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
505
TEST_F(HloAliasAnalysisTest, SingleWhile) {
  // A single while loop whose body forwards element 0 unchanged and replaces
  // element 1. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   return While(%tuple, body, condition)
  const Shape loop_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: element 0 passes through untouched.
  HloComputation::Builder body_builder("body");
  HloInstruction* const state = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_shape, "param"));
  HloInstruction* const state0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 0));
  HloInstruction* const state1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 1));
  HloInstruction* const next1 = body_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, state0,
                                   state1));
  HloInstruction* const next_state = body_builder.AddInstruction(
      HloInstruction::CreateTuple({state0, next1}));
  HloComputation* const body =
      module_->AddEmbeddedComputation(body_builder.Build());

  // Loop condition: always false.
  HloComputation::Builder cond_builder("condition");
  HloInstruction* const cond_state = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* const condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  HloComputation::Builder builder(TestName());
  HloInstruction* const init0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* const init1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* const init =
      builder.AddInstruction(HloInstruction::CreateTuple({init0, init1}));
  HloInstruction* const loop = builder.AddInstruction(
      HloInstruction::CreateWhile(loop_shape, condition, body, init));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  const HloBuffer& top_buffer = analysis.GetUniqueBufferAt(loop, /*index=*/{});
  const HloBuffer& elem0_buffer =
      analysis.GetUniqueBufferAt(loop, /*index=*/{0});
  const HloBuffer& elem1_buffer =
      analysis.GetUniqueBufferAt(loop, /*index=*/{1});

  // Every position each loop-carried buffer occupies.
  EXPECT_THAT(top_buffer.ComputePositions(),
              UnorderedElementsAre(HloPosition{init, {}},
                                   HloPosition{loop, {}},
                                   HloPosition{state, {}},
                                   HloPosition{next_state, {}},
                                   HloPosition{cond_state, {}}));
  EXPECT_THAT(elem0_buffer.ComputePositions(),
              UnorderedElementsAre(
                  HloPosition{init0, {}}, HloPosition{init, {0}},
                  HloPosition{loop, {0}}, HloPosition{state, {0}},
                  HloPosition{state0, {}}, HloPosition{next_state, {0}},
                  HloPosition{cond_state, {0}}));
  EXPECT_THAT(elem1_buffer.ComputePositions(),
              UnorderedElementsAre(
                  HloPosition{init1, {}}, HloPosition{init, {1}},
                  HloPosition{loop, {1}}, HloPosition{state, {1}},
                  HloPosition{state1, {}}, HloPosition{next1, {}},
                  HloPosition{next_state, {1}}, HloPosition{cond_state, {1}}));

  // The pass-through element carries only its initial value; element 1 also
  // carries the loop-defined values.
  EXPECT_THAT(GetValuesInBuffer(elem0_buffer),
              UnorderedElementsAre(GetValueDefinedAt(init0)));
  EXPECT_THAT(GetValuesInBuffer(elem1_buffer),
              UnorderedElementsAre(GetValueDefinedAt(init1),
                                   GetValueDefinedAt(loop, /*index=*/{1}),
                                   GetValueDefinedAt(state, {1}),
                                   GetValueDefinedAt(cond_state, {1}),
                                   GetValueDefinedAt(next1)));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
598
TEST_F(HloAliasAnalysisTest, SequentialWhiles) {
  // Three while loops chained back to back, all sharing one body and one
  // condition (the call graph is flattened before analysis). HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   %while0 = While(%tuple, body, condition)
  //   %while1 = While(%while0, body, condition)
  //   return While(%while1, body, condition)
  const Shape loop_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: element 0 passes through untouched.
  HloComputation::Builder body_builder("body");
  HloInstruction* const state = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_shape, "param"));
  HloInstruction* const state0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 0));
  HloInstruction* const state1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, state, 1));
  HloInstruction* const next1 = body_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, state0,
                                   state1));
  body_builder.AddInstruction(HloInstruction::CreateTuple({state0, next1}));
  HloComputation* const body =
      module_->AddEmbeddedComputation(body_builder.Build());

  // Loop condition: always false.
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* const condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  HloComputation::Builder builder(TestName());
  HloInstruction* const init0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* const init1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* const init =
      builder.AddInstruction(HloInstruction::CreateTuple({init0, init1}));

  // Each while consumes the previous while's result.
  HloInstruction* loop_state = init;
  for (int i = 0; i < 3; ++i) {
    loop_state = builder.AddInstruction(
        HloInstruction::CreateWhile(loop_shape, condition, body, loop_state));
  }
  HloInstruction* const last_loop = loop_state;
  module_->AddEntryComputation(builder.Build());

  FlattenCallGraph flattener;
  TF_ASSERT_OK(flattener.Run(module_.get()).status());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  // The initial tuple and its elements flow all the way to the last loop.
  EXPECT_EQ(analysis.GetUniqueBufferAt(init, /*index=*/{}),
            analysis.GetUniqueBufferAt(last_loop, /*index=*/{}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(init0),
            analysis.GetUniqueBufferAt(last_loop, /*index=*/{0}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(init1),
            analysis.GetUniqueBufferAt(last_loop, /*index=*/{1}));
}
671
TEST_F(HloAliasAnalysisTest,NestedWhiles)672 TEST_F(HloAliasAnalysisTest, NestedWhiles) {
673 // Test nested while instructions. The inner body passes through element 0 of
674 // its parameter, and the outer body passes through element 1. HLO:
675 //
676 // inner_body((F32[], F32[]) %tuple_param):
677 // %add = Add(%tuple_param{0}, %tuple_param{1})
678 // return Tuple(%tuple_param{0}, %add)
679 //
680 // outer_body((F32[], F32[]) %tuple_param):
681 // %negate = Negate(%tuple_param{0})
682 // %tuple = Tuple(%negate, %tuple_param{1})
683 // return While(%tuple, inner_body, condition)
684 //
685 // entry:
686 // %constant1 = Constant(1.0)
687 // %constant2 = Constant(2.0)
688 // %tuple = Tuple(%constant1, %constant2)
689 // return While(%tuple, outer_body, condition)
690 //
691 const Shape tuple_shape =
692 ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
693
694 auto build_cond_computation = [&tuple_shape]() {
695 auto cond_builder = HloComputation::Builder("condition");
696 cond_builder.AddInstruction(
697 HloInstruction::CreateParameter(0, tuple_shape, "param"));
698 cond_builder.AddInstruction(
699 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
700 return cond_builder.Build();
701 };
702 // Build separate condition computations so the call graph is flat. The
703 // callgraph is always flattened in the compiler pipeline, and the flattened
704 // callgraph enables representative interference analysis.
705 HloComputation* condition1 =
706 module_->AddEmbeddedComputation(build_cond_computation());
707 HloComputation* condition2 =
708 module_->AddEmbeddedComputation(build_cond_computation());
709
710 // Element 0 passes transparently through the body.
711 auto inner_builder = HloComputation::Builder("inner_body");
712 auto inner_param = inner_builder.AddInstruction(
713 HloInstruction::CreateParameter(0, tuple_shape, "param"));
714 auto inner_element_0 = inner_builder.AddInstruction(
715 HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0));
716 auto inner_element_1 = inner_builder.AddInstruction(
717 HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1));
718 auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
719 scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1));
720 inner_builder.AddInstruction(
721 HloInstruction::CreateTuple({inner_element_0, add}));
722 HloComputation* inner_body =
723 module_->AddEmbeddedComputation(inner_builder.Build());
724
725 // Element 1 passes transparently through the body.
726 auto outer_builder = HloComputation::Builder("outer_body");
727 auto outer_param = outer_builder.AddInstruction(
728 HloInstruction::CreateParameter(0, tuple_shape, "param"));
729 auto outer_element_0 = outer_builder.AddInstruction(
730 HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
731 auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
732 scalar_shape_, HloOpcode::kNegate, outer_element_0));
733 auto outer_element_1 = outer_builder.AddInstruction(
734 HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
735 auto outer_tuple = outer_builder.AddInstruction(
736 HloInstruction::CreateTuple({negate, outer_element_1}));
737 auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
738 tuple_shape, condition1, inner_body, outer_tuple));
739 HloComputation* outer_body =
740 module_->AddEmbeddedComputation(outer_builder.Build());
741
742 auto builder = HloComputation::Builder(TestName());
743 auto constant1 = builder.AddInstruction(
744 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
745 auto constant2 = builder.AddInstruction(
746 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
747 auto tuple = builder.AddInstruction(
748 HloInstruction::CreateTuple({constant1, constant2}));
749 auto entry_while = builder.AddInstruction(
750 HloInstruction::CreateWhile(tuple_shape, condition2, outer_body, tuple));
751 module_->AddEntryComputation(builder.Build());
752 SCOPED_TRACE(module_->ToString());
753
754 const HloAliasAnalysis& analysis = RunAnalysis();
755
756 EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
757 analysis.GetUniqueBufferAt(entry_while, /*index=*/{0}));
758 EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
759 analysis.GetUniqueBufferAt(nested_while, /*index=*/{0}));
760 EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
761 analysis.GetUniqueBufferAt(inner_element_0));
762
763 EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
764 analysis.GetUniqueBufferAt(entry_while, /*index=*/{1}));
765 EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
766 analysis.GetUniqueBufferAt(nested_while, /*index=*/{1}));
767 EXPECT_EQ(analysis.GetUniqueBufferAt(constant2),
768 analysis.GetUniqueBufferAt(inner_element_1));
769
770 EXPECT_FALSE(AnyValuesInSameBufferInterfere());
771 }
772
TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
  // Test a while instruction with a body which permutes its tuple parameter
  // elements. HLO:
  //
  // body((F32[], F32[], F32[]) %tuple_param):
  //   return Tuple(%tuple_param{1}, %tuple_param{2}, %tuple_param{0})
  //
  // condition((F32[], F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %constant3 = Constant(3.0)
  //   %tuple = Tuple(%constant1, %constant2, %constant3)
  //   return While(%tuple, body, condition)
  //
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_});

  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto body_element_0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  auto body_element_1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  auto body_element_2 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 2));
  // Rotate the loop state: output element {i} is input element {(i + 1) % 3}.
  body_builder.AddInstruction(HloInstruction::CreateTuple(
      {body_element_1, body_element_2, body_element_0}));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  auto cond_builder = HloComputation::Builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto cond_constant = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2, constant3}));
  auto xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  // The swizzling while makes most positions in the module alias, leaving only
  // 3 HloBuffers: the one shared by the loop-state elements (represented by
  // constant1), the one at the entry tuple's top-level index, and the one for
  // the condition's constant.
  EXPECT_THAT(
      analysis.buffers(),
      UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
                           analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
                           analysis.GetUniqueBufferAt(cond_constant)));

  // The tuple elements of the while and the three constant inputs should all be
  // smooshed into the same buffer.
  EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
            analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
            analysis.GetUniqueBufferAt(xla_while, /*index=*/{2}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}),
            analysis.GetUniqueBufferAt(constant1));
  EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
            analysis.GetUniqueBufferAt(constant2));
  EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
            analysis.GetUniqueBufferAt(constant3));

  // All elements of the loop state tuple are forced into the same buffer,
  // resulting in liveness interference.
  EXPECT_TRUE(AnyValuesInSameBufferInterfere());
}
855
TEST_F(HloAliasAnalysisTest, TupleSelect) {
  // Test a kTupleSelect. Non-top-level elements flow through the instruction,
  // so element {0} of each select may hold the buffer of element {0} of
  // either operand.
  auto builder = HloComputation::Builder(TestName());

  // Small helpers that append instructions to the entry builder. They only
  // shorten the construction code; instructions are still created in the same
  // order: predicate, four constants, four tuples, then four selects.
  auto add_f32_constant = [&builder](float value) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };
  auto add_singleton_tuple = [&builder](HloInstruction* element) {
    return builder.AddInstruction(HloInstruction::CreateTuple({element}));
  };

  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* c1 = add_f32_constant(1.0);
  HloInstruction* c2 = add_f32_constant(2.0);
  HloInstruction* c3 = add_f32_constant(3.0);
  HloInstruction* c4 = add_f32_constant(4.0);
  HloInstruction* t1 = add_singleton_tuple(c1);
  HloInstruction* t2 = add_singleton_tuple(c2);
  HloInstruction* t3 = add_singleton_tuple(c3);
  HloInstruction* t4 = add_singleton_tuple(c4);

  const Shape tuple_shape = t1->shape();
  auto add_tuple_select = [&builder, &tuple_shape, pred](
                              HloInstruction* on_true,
                              HloInstruction* on_false) {
    return builder.AddInstruction(HloInstruction::CreateTernary(
        tuple_shape, HloOpcode::kTupleSelect, pred, on_true, on_false));
  };
  // select11 picks between two identical operands, so it is not ambiguous.
  HloInstruction* select11 = add_tuple_select(t1, t1);
  HloInstruction* select12 = add_tuple_select(t1, t2);
  HloInstruction* select34 = add_tuple_select(t3, t4);
  // A select of selects: element {0} can come from any of the four constants.
  HloInstruction* select1234 = add_tuple_select(select12, select34);

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();
  const auto& buffer1 = analysis.GetUniqueBufferAt(c1);
  const auto& buffer2 = analysis.GetUniqueBufferAt(c2);
  const auto& buffer3 = analysis.GetUniqueBufferAt(c3);
  const auto& buffer4 = analysis.GetUniqueBufferAt(c4);

  // Each select's element {0} may be any operand's element {0}.
  EXPECT_THAT(GetBuffersAt(select11, /*index=*/{0}),
              UnorderedElementsAre(buffer1));
  EXPECT_THAT(GetBuffersAt(select12, /*index=*/{0}),
              UnorderedElementsAre(buffer1, buffer2));
  EXPECT_THAT(GetBuffersAt(select34, /*index=*/{0}),
              UnorderedElementsAre(buffer3, buffer4));
  EXPECT_THAT(GetBuffersAt(select1234, /*index=*/{0}),
              UnorderedElementsAre(buffer1, buffer2, buffer3, buffer4));

  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select11));
  EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select12));
  EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select34));
  EXPECT_TRUE(analysis.InstructionBuffersAreAmbiguous(select1234));

  for (HloInstruction* select : {select11, select12, select34, select1234}) {
    EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select));
  }

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
919
TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) {
  // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO:
  //
  // body((F32[]) %tuple_param):
  //   %negate = Negate(%tuple_param{0})
  //   return Tuple(%negate)
  //
  // condition((F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %pred = Constant(false)
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple1 = Tuple(%constant1)
  //   %tuple2 = Tuple(%constant2)
  //   %select = TupleSelect(%pred, %tuple1, %tuple2)
  //   return While(%select, body, condition)
  //
  auto builder = HloComputation::Builder(TestName());

  const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_});

  // The body negates element 0 of the loop state, defining a new value each
  // iteration rather than passing the element through.
  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto body_element = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kNegate, body_element));
  body_builder.AddInstruction(HloInstruction::CreateTuple({negate}));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  auto cond_builder = HloComputation::Builder("condition");
  auto cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto tuple1 =
      builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
  auto tuple2 =
      builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
  auto xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, body, select));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  // The while should flatten the ambiguous select buffer set so that the buffer
  // set contents (constant1 and constant2) become a single buffer.
  EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
            analysis.GetUniqueBufferAt(constant2));
  EXPECT_EQ(analysis.GetUniqueBufferAt(constant1),
            analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}));

  EXPECT_THAT(GetValuesInBuffer(analysis.GetUniqueBufferAt(constant1)),
              UnorderedElementsAre(GetValueDefinedAt(constant1),
                                   GetValueDefinedAt(constant2),
                                   GetValueDefinedAt(xla_while, /*index=*/{0}),
                                   GetValueDefinedAt(body_param, /*index=*/{0}),
                                   GetValueDefinedAt(cond_param, /*index=*/{0}),
                                   GetValueDefinedAt(negate)));
  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(select));
  EXPECT_FALSE(analysis.InstructionBuffersAreAmbiguous(xla_while));

  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(select));
  EXPECT_TRUE(analysis.InstructionBuffersAreDistinct(xla_while));

  // The two operands of the select get flattened into the same buffer resulting
  // in liveness interference.
  EXPECT_TRUE(AnyValuesInSameBufferInterfere());
}
1005
TEST_F(HloAliasAnalysisTest, Bitcast) {
  // Bitcasting a value should not produce a new buffer: the module contains
  // only a constant and its bitcast, so exactly one buffer is expected.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kBitcast, constant));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  // Unsigned literal: buffers().size() is unsigned, and comparing it with a
  // signed int inside EXPECT_EQ triggers -Wsign-compare.
  EXPECT_EQ(analysis.buffers().size(), 1u);

  EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
            analysis.GetUniqueBufferAt(bitcast));
}
1024
TEST_F(HloAliasAnalysisTest, BitcastInterference) {
  // A bitcast that is live at the same time as its operand must not be
  // reported as live-range interference. Here both values feed the same
  // tuple, so their live ranges overlap.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* bitcast_of_operand =
      builder.AddInstruction(HloInstruction::CreateUnary(
          scalar_shape_, HloOpcode::kBitcast, operand));
  // Using both values in one tuple keeps them simultaneously live.
  builder.AddInstruction(
      HloInstruction::CreateTuple({operand, bitcast_of_operand}));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& alias_analysis = RunAnalysis();

  DependencyHloOrdering dependency_ordering(module_.get());
  EXPECT_FALSE(alias_analysis.HasLiveRangeInterference(dependency_ordering));
}
1043
TEST_F(HloAliasAnalysisTest, WhileInterference) {
  // Build a while loop which has a parallel use of the init value. Depending on
  // ordering there may be interference between the update-in-place while and
  // the other use of the init.
  auto builder = HloComputation::Builder(TestName());
  auto init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  auto cond_builder = HloComputation::Builder("condition");
  auto cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, init->shape(), "param"));
  auto cond_root = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, init->shape(), "param"));
  auto body_root = body_builder.AddInstruction(
      HloInstruction::CreateUnary(init->shape(), HloOpcode::kExp, body_param));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  auto xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(init->shape(), condition, body, init));

  // A second use of `init` that does not depend on the while.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(init->shape(), HloOpcode::kNegate, init));
  auto entry_root =
      builder.AddInstruction(HloInstruction::CreateTuple({negate, xla_while}));

  HloComputation* entry = module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloAliasAnalysis& analysis = RunAnalysis();

  {
    // Dependency ordering should interfere because the negate and while are
    // unordered.
    DependencyHloOrdering ordering(module_.get());
    EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
  }

  // For a sequential order, there is interference iff the negate is after the
  // while.
  HloSchedule schedule(module_.get());
  schedule.set_sequence(body, {body_param, body_root});
  schedule.set_sequence(condition, {cond_param, cond_root});
  {
    // negate after while: `init` must outlive the while, which overwrites it.
    schedule.set_sequence(entry, {init, xla_while, negate, entry_root});
    TF_ASSERT_OK(schedule.Verify());
    SequentialHloOrdering ordering(schedule);
    EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
  }

  {
    // negate before while: the last other use of `init` precedes the while.
    schedule.set_sequence(entry, {init, negate, xla_while, entry_root});
    TF_ASSERT_OK(schedule.Verify());
    SequentialHloOrdering ordering(schedule);
    EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
  }
}
1106
1107 } // namespace
1108 } // namespace xla
1109