1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
17
18 #include <memory>
19 #include <string>
20
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31
32 namespace xla {
33 namespace {
34
35 class BufferLivenessTest : public HloTestBase {
36 protected:
37 // Returns the LogicalBuffer defined at the given instruction and
38 // index. CHECKs if no buffer is defined at that point.
GetBuffer(const BufferLiveness & liveness,const HloInstruction * instruction,const ShapeIndex & index)39 const LogicalBuffer& GetBuffer(const BufferLiveness& liveness,
40 const HloInstruction* instruction,
41 const ShapeIndex& index) {
42 const auto& pointed_to = liveness.points_to_analysis()
43 .GetPointsToSet(instruction)
44 .element(index);
45 CHECK_EQ(1, pointed_to.size());
46 CHECK_EQ(instruction, pointed_to[0]->instruction());
47 CHECK(index == pointed_to[0]->index());
48 return *pointed_to[0];
49 }
50
51 // Returns true if the top-level buffers for instructions 'a' and 'b' may
52 // interfere. Precondition: 'a' and 'b' are array-shaped.
InstructionsMayInterfere(const BufferLiveness & liveness,HloInstruction * a,HloInstruction * b)53 bool InstructionsMayInterfere(const BufferLiveness& liveness,
54 HloInstruction* a, HloInstruction* b) {
55 EXPECT_FALSE(a->shape().IsTuple());
56 EXPECT_FALSE(b->shape().IsTuple());
57 return liveness.MayInterfere(
58 GetBuffer(liveness, /*instruction=*/a, /*index=*/{}),
59 GetBuffer(liveness, /*instruction=*/b, /*index=*/{}));
60 }
61
62 // Returns true if the tuple elements at 'index' for instructions 'a' and 'b'
63 // may interfere. Precondition: 'a' and 'b' are tuple-shaped, with equal
64 // tuple element sub-shapes.
TupleElementsMayInterfere(const BufferLiveness & liveness,HloInstruction * a,HloInstruction * b,const ShapeIndex & index)65 bool TupleElementsMayInterfere(const BufferLiveness& liveness,
66 HloInstruction* a, HloInstruction* b,
67 const ShapeIndex& index) {
68 // Check that top-level shapes are tuple and tuple element shapes are equal.
69 EXPECT_TRUE(a->shape().IsTuple());
70 EXPECT_TRUE(b->shape().IsTuple());
71 EXPECT_TRUE(
72 ShapeUtil::Compatible(ShapeUtil::GetSubshape(a->shape(), index),
73 ShapeUtil::GetSubshape(b->shape(), index)));
74 // Lookup PointsTo set for instructions 'a' and 'b'.
75 auto& points_to_analysis = liveness.points_to_analysis();
76 const auto& points_to_a =
77 points_to_analysis.GetPointsToSet(a).element(index);
78 const auto& points_to_b =
79 points_to_analysis.GetPointsToSet(b).element(index);
80 // Make sure PointsTo sets for 'a' and 'b' are unambiguous.
81 EXPECT_EQ(1, points_to_a.size());
82 EXPECT_EQ(points_to_a.size(), points_to_b.size());
83 // Check interference.
84 return liveness.MayInterfere(*points_to_a[0], *points_to_b[0]);
85 }
86
87 // Returns true if the top-level buffers for the given instruction maybe
88 // liveout of the entry computation.
89 // Precondition: instruction is array-shaped.
InstructionMaybeLiveOut(const BufferLiveness & liveness,HloInstruction * instruction)90 bool InstructionMaybeLiveOut(const BufferLiveness& liveness,
91 HloInstruction* instruction) {
92 return liveness.MaybeLiveOut(
93 GetBuffer(liveness, instruction, /*index=*/{}));
94 }
95
BuildDummyComputation()96 std::unique_ptr<HloComputation> BuildDummyComputation() {
97 auto builder = HloComputation::Builder(TestName() + "_dummy");
98 builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
99 return builder.Build();
100 }
101
102 const Shape vec_ = ShapeUtil::MakeShape(xla::F32, {42});
103 };
104
TEST_F(BufferLivenessTest, ElementwiseChain) {
  // Straight-line chain of unary elementwise ops; distinct buffers in the
  // chain are expected not to interfere with one another:
  //
  //   param --> negate -> exp -> log
  //
  HloComputation::Builder chain_builder(TestName());
  HloInstruction* param = chain_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* negate = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
  HloInstruction* exp = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
  HloInstruction* log = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(chain_builder.Build());

  auto liveness_or = BufferLiveness::Run(
      module.get(), absl::make_unique<DependencyHloOrdering>(module.get()));
  auto liveness = liveness_or.ConsumeValueOrDie();

  // Shorthands over the fixture helpers, bound to this liveness result.
  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(*liveness, a, b);
  };
  const auto live_out = [&](HloInstruction* instruction) {
    return InstructionMaybeLiveOut(*liveness, instruction);
  };

  // The parameter does not interfere with anything downstream of it.
  EXPECT_FALSE(interferes(param, negate));
  EXPECT_FALSE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, log));

  // Nor do any two distinct ops in the chain, in either argument order.
  EXPECT_FALSE(interferes(negate, exp));
  EXPECT_FALSE(interferes(negate, log));
  EXPECT_FALSE(interferes(exp, negate));
  EXPECT_FALSE(interferes(exp, log));
  EXPECT_FALSE(interferes(log, negate));
  EXPECT_FALSE(interferes(log, exp));

  // A buffer always interferes with itself.
  EXPECT_TRUE(interferes(exp, exp));

  // The final op in the chain (log) is the only live-out buffer.
  EXPECT_FALSE(live_out(param));
  EXPECT_FALSE(live_out(negate));
  EXPECT_FALSE(live_out(exp));
  EXPECT_TRUE(live_out(log));
}
149
TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
  // Two entry parameters feeding independent unary ops whose results are
  // added, analyzed under an explicit sequential schedule:
  //
  //   param0 --> negate ---------------\
  //                param1 --> exp --> add
  HloComputation::Builder graph_builder(TestName());
  HloInstruction* param0 = graph_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_, "param0"));
  HloInstruction* param1 = graph_builder.AddInstruction(
      HloInstruction::CreateParameter(1, vec_, "param1"));
  HloInstruction* negate = graph_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param0));
  HloInstruction* exp = graph_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param1));
  HloInstruction* add = graph_builder.AddInstruction(
      HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(graph_builder.Build());

  // Note that param1 is scheduled after negate.
  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {param0, negate, param1, exp, add});
  auto liveness_or = BufferLiveness::Run(
      module.get(), absl::make_unique<SequentialHloOrdering>(schedule));
  auto liveness = liveness_or.ConsumeValueOrDie();

  // Shorthand over the fixture helper, bound to this liveness result.
  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(*liveness, a, b);
  };

  // Entry parameters are treated as if they were all defined simultaneously
  // at the very start, so they interfere with each other, and param1
  // interferes with negate despite being scheduled after it.
  EXPECT_TRUE(interferes(param0, param1));
  EXPECT_FALSE(interferes(param0, negate));
  EXPECT_FALSE(interferes(param0, exp));
  EXPECT_FALSE(interferes(param0, add));
  EXPECT_TRUE(interferes(param1, param0));
  EXPECT_TRUE(interferes(param1, negate));
  EXPECT_FALSE(interferes(param1, exp));
  EXPECT_FALSE(interferes(param1, add));

  // The two operands of add interfere with each other.
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));

  // Neither operand interferes with add itself.
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
198
TEST_F(BufferLivenessTest, NonElementwiseOperand) {
  // Chain with two elementwise ops followed by a non-elementwise one. The
  // elementwise op is expected not to interfere with its operand, whereas the
  // non-elementwise op (reverse) is expected to interfere with its operand.
  //
  //   param --> exp -> negate -> reverse
  //
  HloComputation::Builder chain_builder(TestName());
  HloInstruction* param = chain_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* exp = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
  HloInstruction* negate = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, exp));
  HloInstruction* reverse = chain_builder.AddInstruction(
      HloInstruction::CreateReverse(vec_, negate, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(chain_builder.Build());

  auto liveness_or = BufferLiveness::Run(
      module.get(), absl::make_unique<DependencyHloOrdering>(module.get()));
  auto liveness = liveness_or.ConsumeValueOrDie();

  // Shorthand over the fixture helper, bound to this liveness result.
  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(*liveness, a, b);
  };

  // The parameter does not interfere with any downstream buffer.
  EXPECT_FALSE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, negate));
  EXPECT_FALSE(interferes(param, reverse));

  // Elementwise negate vs. its operand: no interference.
  EXPECT_FALSE(interferes(exp, negate));
  // Non-elementwise reverse vs. its operand: interference.
  EXPECT_TRUE(interferes(negate, reverse));
}
233
TEST_F(BufferLivenessTest, OverlappedBuffers) {
  // Buffers that are live at the same time (negate and exp) must be reported
  // as interfering.
  //
  //   param --> negate -> add
  //       \---> exp -----/
  //
  HloComputation::Builder diamond_builder(TestName());
  HloInstruction* param = diamond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* negate = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
  HloInstruction* exp = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
  HloInstruction* add = diamond_builder.AddInstruction(
      HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(diamond_builder.Build());

  auto liveness_or = BufferLiveness::Run(
      module.get(), absl::make_unique<DependencyHloOrdering>(module.get()));
  auto liveness = liveness_or.ConsumeValueOrDie();

  // Shorthand over the fixture helper, bound to this liveness result.
  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(*liveness, a, b);
  };

  // With two users of param, param interferes with both of them, but not
  // with add.
  EXPECT_TRUE(interferes(param, negate));
  EXPECT_TRUE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, add));

  // negate and exp interfere with each other; neither interferes with add.
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
270
TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
  // Same graph as the OverlappedBuffers test, but analyzed with a sequential
  // ordering of the HLO instructions instead of a dependency ordering.
  //
  //   param --> negate -> add
  //       \---> exp -----/
  //
  // Sequential order:
  //   param, negate, exp, add
  //
  // With exp scheduled last among param's users, param is expected to
  // interfere with negate but not with exp.
  HloComputation::Builder diamond_builder(TestName());
  HloInstruction* param = diamond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* negate = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
  HloInstruction* exp = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
  HloInstruction* add = diamond_builder.AddInstruction(
      HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(diamond_builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {param, negate, exp, add});
  auto liveness_or = BufferLiveness::Run(
      module.get(), absl::make_unique<SequentialHloOrdering>(schedule));
  auto liveness = liveness_or.ConsumeValueOrDie();

  // Shorthand over the fixture helper, bound to this liveness result.
  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(*liveness, a, b);
  };

  EXPECT_TRUE(interferes(param, negate));
  EXPECT_FALSE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, add));

  // negate and exp interfere with each other; neither interferes with add.
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
314
TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
  // When the root instruction is scheduled before other instructions, the
  // live range of its buffer must still interfere with the buffers of the
  // instructions scheduled after it.
  //
  // The computation holds two independent instruction groups:
  //   param --> add (root)
  //   recv --> recv-done --> send --> send-done
  //
  // Sequential order:
  //   param, add (root), token, recv, recv-done, send, send-done
  HloComputation::Builder graph_builder(TestName());
  HloInstruction* param = graph_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* add = graph_builder.AddInstruction(
      HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param));
  HloInstruction* token =
      graph_builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv = graph_builder.AddInstruction(
      HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0));
  HloInstruction* recv_done =
      graph_builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
  HloInstruction* send = graph_builder.AddInstruction(
      HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1));
  HloInstruction* send_done =
      graph_builder.AddInstruction(HloInstruction::CreateSendDone(send));

  auto module = CreateNewVerifiedModule();
  // 'add' is passed explicitly as the root, although it is not the last
  // instruction added to the builder.
  HloComputation* entry =
      module->AddEntryComputation(graph_builder.Build(add));

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry,
                        {param, add, token, recv, recv_done, send, send_done});
  TF_ASSERT_OK(schedule.Verify());
  auto liveness_or = BufferLiveness::Run(
      module.get(), absl::make_unique<SequentialHloOrdering>(schedule));
  auto liveness = liveness_or.ConsumeValueOrDie();

  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));

  // The root (add) buffer must interfere with the recv buffer at index {0}.
  const LogicalBuffer& root_buffer = GetBuffer(*liveness, add, /*index=*/{});
  const LogicalBuffer& recv_buffer = GetBuffer(*liveness, recv, /*index=*/{0});
  EXPECT_TRUE(liveness->MayInterfere(root_buffer, recv_buffer));
}
357
TEST_F(BufferLivenessTest, TupleLiveOut) {
  // Exercises MaybeLiveOut on a nested tuple result. The computation's result
  // has the form:
  //
  //   Tuple({Tuple({Negate(Param)}, Exp(Negate(Param)))})
  //
  // Every value except Param is expected to be live out.
  HloComputation::Builder tuple_builder(TestName());
  HloInstruction* param = tuple_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* negate = tuple_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
  HloInstruction* inner_tuple =
      tuple_builder.AddInstruction(HloInstruction::CreateTuple({negate}));
  HloInstruction* exp = tuple_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
  HloInstruction* outer_tuple = tuple_builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, exp}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(tuple_builder.Build());

  auto liveness_or = BufferLiveness::Run(
      module.get(), absl::make_unique<DependencyHloOrdering>(module.get()));
  auto liveness = liveness_or.ConsumeValueOrDie();

  // Shorthand over the fixture helper, bound to this liveness result.
  const auto live_out = [&](HloInstruction* instruction) {
    return InstructionMaybeLiveOut(*liveness, instruction);
  };

  // Only the parameter is not live out.
  EXPECT_FALSE(live_out(param));
  EXPECT_TRUE(live_out(negate));
  EXPECT_TRUE(live_out(inner_tuple));
  EXPECT_TRUE(live_out(exp));
  EXPECT_TRUE(live_out(outer_tuple));
}
391
392 // bitcast liveout.
393
TEST_F(BufferLivenessTest, EmbeddedComputation) {
  // Exercises MaybeLiveOut and MayInterfere when the entry computation calls
  // an embedded computation.
  auto module = CreateNewVerifiedModule();

  // Embedded computation: log(embedded_param).
  HloComputation::Builder callee_builder(TestName() + "_embedded");
  HloInstruction* embedded_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_, "embedded_param"));
  HloInstruction* embedded_log = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kLog, embedded_param));
  HloComputation* embedded_computation =
      module->AddEmbeddedComputation(callee_builder.Build());

  // Entry computation: call(param) into the embedded computation.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(vec_, {param}, embedded_computation));
  module->AddEntryComputation(entry_builder.Build());

  auto liveness_or = BufferLiveness::Run(
      module.get(), absl::make_unique<DependencyHloOrdering>(module.get()));
  auto liveness = liveness_or.ConsumeValueOrDie();

  // Shorthands over the fixture helpers, bound to this liveness result.
  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(*liveness, a, b);
  };
  const auto live_out = [&](HloInstruction* instruction) {
    return InstructionMaybeLiveOut(*liveness, instruction);
  };

  // Buffers belonging to different computations always interfere; buffers
  // within the embedded computation follow the usual rules.
  EXPECT_TRUE(interferes(embedded_log, call));
  EXPECT_TRUE(interferes(embedded_param, param));
  EXPECT_FALSE(interferes(embedded_param, embedded_log));

  // MaybeLiveOut is true only for buffers live out of the entry computation;
  // a buffer that is live out of an embedded computation reports false.
  EXPECT_FALSE(live_out(embedded_log));
  EXPECT_TRUE(live_out(call));
}
432
TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
  // Verify that elements of a nested tuple constant are properly marked as
  // liveout. Computation:
  //
  //   GetTupleElement(0, TupleConstant({{0, 1}, {3}}))
  //
  // The root selects element 0 of the constant, so the buffers at indices
  // {0}, {0, 0} and {0, 1} are expected to be liveout, while the top-level
  // tuple buffer and the buffers at {1} and {1, 0} (holding {3} and 3) are
  // expected to be dead.
  auto builder = HloComputation::Builder(TestName());
  Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
                         LiteralUtil::CreateR0<int64>(1)};
  auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
  Literal element1 = LiteralUtil::CreateR0<int64>(3);
  auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
  auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
  builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      inner_tuple0.shape(), tuple_constant, 0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto liveness =
      BufferLiveness::Run(
          module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
          .ConsumeValueOrDie();

  // Each of the six buffers of the tuple constant is checked exactly once:
  // only those reachable through the GetTupleElement root should be liveout.
  EXPECT_FALSE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{})));
  EXPECT_TRUE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{0})));
  EXPECT_TRUE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{0, 0})));
  EXPECT_TRUE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{0, 1})));
  EXPECT_FALSE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{1})));
  EXPECT_FALSE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0})));
}
478
TEST_F(BufferLivenessTest, IndependentTupleElements) {
  HloComputation::Builder tuple_builder(TestName());
  // Tuple-shaped parameter with an F32[8] element and an S32[4] element.
  HloInstruction* tuple_param0 =
      tuple_builder.AddInstruction(HloInstruction::CreateParameter(
          0,
          ShapeUtil::MakeTupleShape(
              {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}),
          "param0"));
  // Each tuple element gets its own computation, independent of the other.

  // Element 0:
  //   Add(GetTupleElement(tuple_param0, 0), const0)
  const Shape element0_shape =
      ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
  HloInstruction* tuple_element0 = tuple_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(element0_shape, tuple_param0, 0));
  HloInstruction* const0 =
      tuple_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>(
              {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
  HloInstruction* add0 = tuple_builder.AddInstruction(
      HloInstruction::CreateBinary(element0_shape, HloOpcode::kAdd,
                                   tuple_element0, const0));

  // Element 1:
  //   Add(GetTupleElement(tuple_param0, 1), const1)
  const Shape element1_shape =
      ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
  HloInstruction* tuple_element1 = tuple_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(element1_shape, tuple_param0, 1));
  HloInstruction* const1 =
      tuple_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>(
              {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
  HloInstruction* add1 = tuple_builder.AddInstruction(
      HloInstruction::CreateBinary(element1_shape, HloOpcode::kAdd,
                                   tuple_element1, const1));

  // Root: tuple of the two sums.
  HloInstruction* tuple_root =
      tuple_builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));

  // This test uses an unverified module. NOTE(review): presumably because
  // add1 combines the S32[4] tuple element with an F32[8] constant -- confirm.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(BuildDummyComputation());
  module->AddEmbeddedComputation(tuple_builder.Build());

  auto liveness_or = BufferLiveness::Run(
      module.get(), absl::make_unique<DependencyHloOrdering>(module.get()));
  auto liveness = liveness_or.ConsumeValueOrDie();

  // The pairs compared are (input tuple element, output tuple element):
  //   1) ('tuple_element0', 'add0')
  //   2) ('tuple_element1', 'add1')
  //
  // Output 'add0' has no dependency on input 'tuple_element1', and output
  // 'add1' has no dependency on input 'tuple_element0'.
  //
  // Neither pair interferes: nothing else depends on a pair's input tuple
  // element, so liveness can establish that every user of the input element
  // executes before the corresponding output element.
  const bool element0_interferes =
      TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0});
  EXPECT_FALSE(element0_interferes);
  const bool element1_interferes =
      TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
  EXPECT_FALSE(element1_interferes);
}
542
TEST_F(BufferLivenessTest, DependentTupleElements) {
  HloComputation::Builder tuple_builder(TestName());
  // Tuple-shaped parameter with two F32[8] elements.
  HloInstruction* tuple_param0 =
      tuple_builder.AddInstruction(HloInstruction::CreateParameter(
          0,
          ShapeUtil::MakeTupleShape(
              {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}),
          "param0"));
  // The computation for element 1 additionally depends on element 0.

  // Element 0:
  //   Add(GetTupleElement(tuple_param0, 0), const0)
  const Shape element0_shape =
      ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
  HloInstruction* tuple_element0 = tuple_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(element0_shape, tuple_param0, 0));
  HloInstruction* const0 =
      tuple_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>(
              {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
  HloInstruction* add0 = tuple_builder.AddInstruction(
      HloInstruction::CreateBinary(element0_shape, HloOpcode::kAdd,
                                   tuple_element0, const0));

  // Element 1:
  //   Add(GetTupleElement(tuple_param0, 0), GetTupleElement(tuple_param0, 1))
  const Shape element1_shape =
      ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
  HloInstruction* tuple_element1 = tuple_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(element1_shape, tuple_param0, 1));
  HloInstruction* add1 = tuple_builder.AddInstruction(
      HloInstruction::CreateBinary(element1_shape, HloOpcode::kAdd,
                                   tuple_element0, tuple_element1));

  // Root: tuple of the two sums.
  HloInstruction* tuple_root =
      tuple_builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(BuildDummyComputation());
  module->AddEmbeddedComputation(tuple_builder.Build());

  auto liveness_or = BufferLiveness::Run(
      module.get(), absl::make_unique<DependencyHloOrdering>(module.get()));
  auto liveness = liveness_or.ConsumeValueOrDie();

  // The pairs compared are (input tuple element, output tuple element):
  //   1) ('tuple_element0', 'add0')
  //   2) ('tuple_element1', 'add1')
  //
  // Output 'add0' of the first pair has no dependency on the second pair's
  // input 'tuple_element1'. Output 'add1' of the second pair does depend on
  // the first pair's input 'tuple_element0'.

  // First pair: interferes. Because 'add1' also depends on 'tuple_element0',
  // liveness cannot establish that every reference to 'tuple_element0'
  // executes before 'add0'.
  const bool element0_interferes =
      TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0});
  EXPECT_TRUE(element0_interferes);

  // Second pair: does not interfere. 'tuple_element1' has no other
  // dependency, so liveness can establish that all of its users execute
  // before 'add1'.
  const bool element1_interferes =
      TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
  EXPECT_FALSE(element1_interferes);
}
610
611 class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
612 protected:
613 // Builds and runs a computation (see test case computation graphs below).
BuildModule(const bool update_uses_tuple_element1,const bool fuse_gte0)614 std::unique_ptr<VerifiedHloModule> BuildModule(
615 const bool update_uses_tuple_element1, const bool fuse_gte0) {
616 auto builder = HloComputation::Builder(TestName());
617 // Create param0 Tuple.
618 Shape data_shape = ShapeUtil::MakeShape(F32, {8});
619 Shape update_shape = ShapeUtil::MakeShape(F32, {3});
620 auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
621 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
622
623 auto gte0 = builder.AddInstruction(
624 HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
625
626 auto gte1 = builder.AddInstruction(
627 HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
628
629 auto update = builder.AddInstruction(HloInstruction::CreateConstant(
630 LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
631 HloInstruction* slice = nullptr;
632 if (update_uses_tuple_element1) {
633 // Create a slice instruction as an additional user of 'gte1'.
634 slice = builder.AddInstruction(
635 HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}, {1}));
636 update = builder.AddInstruction(HloInstruction::CreateBinary(
637 update_shape, HloOpcode::kAdd, update, slice));
638 }
639 // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
640 auto starts = builder.AddInstruction(
641 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
642 auto dynamic_update_slice =
643 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
644 data_shape, gte1, update, {starts}));
645 // Create output tuple.
646 builder.AddInstruction(
647 HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
648 // Build module and get reference to entry computation.
649 auto module = CreateNewVerifiedModule();
650 module->AddEntryComputation(builder.Build());
651 auto* computation = module->entry_computation();
652 // Create fusion instruction based on number of tuple element 1 users.
653 if (update_uses_tuple_element1) {
654 computation->CreateFusionInstruction(
655 {dynamic_update_slice, starts, update, CHECK_NOTNULL(slice), gte1},
656 HloInstruction::FusionKind::kLoop);
657 } else {
658 computation->CreateFusionInstruction(
659 {dynamic_update_slice, starts, update, gte1},
660 HloInstruction::FusionKind::kLoop);
661 }
662 // Create fusion instruction for tuple element 0 (if requested).
663 if (fuse_gte0) {
664 computation->CreateFusionInstruction({gte0},
665 HloInstruction::FusionKind::kLoop);
666 }
667 return module;
668 }
669
670 // Returns whether buffer interference is detected between tuple-shaped
671 // parameter and root instructions at tuple element 1.
Run(const bool update_uses_tuple_element1,const bool fuse_gte0=false)672 bool Run(const bool update_uses_tuple_element1,
673 const bool fuse_gte0 = false) {
674 auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
675 // Run BufferLiveness on 'module'.
676 auto liveness = BufferLiveness::Run(
677 module.get(),
678 absl::make_unique<DependencyHloOrdering>(module.get()))
679 .ConsumeValueOrDie();
680 // Return whether or not buffers interference is detected between
681 // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
682 auto tuple_param0 = FindInstruction(module.get(), "param0");
683 auto tuple_root = module->entry_computation()->root_instruction();
684 return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
685 }
RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1,const bool fuse_gte0=false)686 bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1,
687 const bool fuse_gte0 = false) {
688 auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
689 // Run BufferLiveness on 'module'.
690 auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie();
691 auto hlo_ordering = absl::make_unique<DependencyHloOrdering>(module.get());
692 // Return whether or not buffers interference is detected between
693 // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
694 auto tuple_param0 = FindInstruction(module.get(), "param0");
695 auto tuple_root = module->entry_computation()->root_instruction();
696 return hlo_ordering->MayInterfere(
697 dataflow->GetUniqueValueAt(tuple_param0, {1}),
698 dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow);
699 }
700 };
701
702 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
703 // do not overlap with the following computation:
704 //
705 // Param0
706 // / \
707 // GTE(0) Fusion -----------> FusionParam
708 // | | |
709 // | | GTE(1) Const Const
710 // | | \ | /
711 // | | DynamicUpdateSlice // fused root
712 // \ /
713 // Tuple // computation root
714 //
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
  // With 'update_uses_tuple_element1' disabled, neither analysis is expected
  // to report interference.
  const bool liveness_interferes = Run(/*update_uses_tuple_element1=*/false);
  const bool dataflow_interferes =
      RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false);
  EXPECT_FALSE(liveness_interferes);
  EXPECT_FALSE(dataflow_interferes);
}
720
721 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases
722 // 'fusion1') do not overlap in the presence of another fusion instruction
723 // (which is a user of 'param0' at a different tuple index).
724 // BufferLiveness should detect no uses of Param0 at index {1} in Fusion0
725 // (because Fusion0 only uses Param0 at index {0}).
726 //
727 // Param0
728 // / \
729 // FusionParam <----- Fusion0 Fusion1 ------> FusionParam
730 // | | | |
731 // GTE(0) | | GTE(1) Const Const
732 // | | \ | /
733 // \ / DynamicUpdateSlice
734 // Tuple
735 //
TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
  // Fusing the tuple element 0 access as well ('fuse_gte0') is still expected
  // to yield no interference from either analysis.
  const bool liveness_interferes =
      Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true);
  const bool dataflow_interferes = RunWithHloDataflowAnalysis(
      /*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true);
  EXPECT_FALSE(liveness_interferes);
  EXPECT_FALSE(dataflow_interferes);
}
741
742 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
743 // do overlap because GTE(1) has two users:
744 // 1) DynamicUpdateSlice at operand 0.
745 // 2) Slice at operand 0.
746 //
747 // Param0
748 // / \ Const
749 // / \ /
750 // GTE(0) Fusion -----------> FusionParam FusionParam
751 // | | | |
752 // | | GTE(1) /
753 // | | | \ /
754 // | | | Slice /
755 // | | | \ /
756 // | | | Add Const
757 // | | | | |
758 // | | DynamicUpdateSlice // fused root
759 // \ /
760 // Tuple // computation root
761 //
TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
  // With 'update_uses_tuple_element1' enabled, both analyses are expected to
  // report interference.
  const bool liveness_interferes = Run(/*update_uses_tuple_element1=*/true);
  const bool dataflow_interferes =
      RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true);
  EXPECT_TRUE(liveness_interferes);
  EXPECT_TRUE(dataflow_interferes);
}
766
767 class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
768 protected:
769 // Builds and runs a computation (see test case computation graphs below).
770 // Runs BufferLiveness on this computation.
771 // Returns whether buffer interference is detected between tuple-shaped
772 // parameter and root instructions at tuple element 1.
Run(const bool tuple_element1_has_two_uses)773 bool Run(const bool tuple_element1_has_two_uses) {
774 auto builder = HloComputation::Builder(TestName());
775 // Create param0 Tuple.
776 Shape data_shape = ShapeUtil::MakeShape(F32, {8});
777 Shape update_shape = ShapeUtil::MakeShape(F32, {3});
778 auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
779 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
780
781 auto gte0 = builder.AddInstruction(
782 HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
783
784 auto gte1 = builder.AddInstruction(
785 HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
786
787 auto update = builder.AddInstruction(HloInstruction::CreateConstant(
788 LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
789
790 if (tuple_element1_has_two_uses) {
791 // Add 'gte0' and 'gte1' to create another user of 'gte1'.
792 gte0 = builder.AddInstruction(HloInstruction::CreateBinary(
793 data_shape, HloOpcode::kAdd, gte0, gte1));
794 }
795 // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
796 auto starts = builder.AddInstruction(
797 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
798 auto dynamic_update_slice =
799 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
800 data_shape, gte1, update, {starts}));
801 // Create output tuple.
802 auto tuple_root = builder.AddInstruction(
803 HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
804 // Build module and get reference to entry computation.
805 auto module = CreateNewVerifiedModule();
806 module->AddEntryComputation(BuildDummyComputation());
807 module->AddEmbeddedComputation(builder.Build());
808 // Run BufferLiveness on 'module'.
809 auto liveness = BufferLiveness::Run(
810 module.get(),
811 absl::make_unique<DependencyHloOrdering>(module.get()))
812 .ConsumeValueOrDie();
813 // Return whether or not buffers interference is detected between
814 // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
815 return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
816 }
817 };
818
819 // Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in
820 // the following computation (because DynamicUpdateSlice (at operand 0) is the
821 // unique user):
822 //
823 // Parameter0
824 // | |
825 // GTE(0) GTE(1) Const Const
826 // | \ | /
827 // | DynamicUpdateSlice
828 // \ /
829 // Tuple
830 //
TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) {
  // Tuple element 1 is not given a second user, so no interference is
  // expected.
  const bool interferes = Run(/*tuple_element1_has_two_uses=*/false);
  EXPECT_FALSE(interferes);
}
834
835 // Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because
836 // GTE(1) has two users:
837 // 1) DynamicUpdateSlice at operand 0.
838 // 2) Add at operand 1.
839 //
840 // Parameter0
841 // | |
842 // GTE(0) GTE(1)
843 // | / |
844 // | / |
845 // Add | Const Const
846 // | | | |
847 // | DynamicUpdateSlice
848 // \ /
849 // Tuple
850 //
TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
  // Tuple element 1 is given a second user (the Add), so interference is
  // expected.
  const bool interferes = Run(/*tuple_element1_has_two_uses=*/true);
  EXPECT_TRUE(interferes);
}
854
855 } // namespace
856
857 } // namespace xla
858