1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 
32 namespace xla {
33 namespace {
34 
35 class BufferLivenessTest : public HloTestBase {
36  protected:
37   // Returns the LogicalBuffer defined at the given instruction and
38   // index. CHECKs if no buffer is defined at that point.
GetBuffer(const BufferLiveness & liveness,const HloInstruction * instruction,const ShapeIndex & index)39   const LogicalBuffer& GetBuffer(const BufferLiveness& liveness,
40                                  const HloInstruction* instruction,
41                                  const ShapeIndex& index) {
42     const auto& pointed_to = liveness.points_to_analysis()
43                                  .GetPointsToSet(instruction)
44                                  .element(index);
45     CHECK_EQ(1, pointed_to.size());
46     CHECK_EQ(instruction, pointed_to[0]->instruction());
47     CHECK(index == pointed_to[0]->index());
48     return *pointed_to[0];
49   }
50 
51   // Returns true if the top-level buffers for instructions 'a' and 'b' may
52   // interfere. Precondition: 'a' and 'b' are array-shaped.
InstructionsMayInterfere(const BufferLiveness & liveness,HloInstruction * a,HloInstruction * b)53   bool InstructionsMayInterfere(const BufferLiveness& liveness,
54                                 HloInstruction* a, HloInstruction* b) {
55     EXPECT_FALSE(a->shape().IsTuple());
56     EXPECT_FALSE(b->shape().IsTuple());
57     return liveness.MayInterfere(
58         GetBuffer(liveness, /*instruction=*/a, /*index=*/{}),
59         GetBuffer(liveness, /*instruction=*/b, /*index=*/{}));
60   }
61 
62   // Returns true if the tuple elements at 'index' for instructions 'a' and 'b'
63   // may interfere. Precondition: 'a' and 'b' are tuple-shaped, with equal
64   // tuple element sub-shapes.
TupleElementsMayInterfere(const BufferLiveness & liveness,HloInstruction * a,HloInstruction * b,const ShapeIndex & index)65   bool TupleElementsMayInterfere(const BufferLiveness& liveness,
66                                  HloInstruction* a, HloInstruction* b,
67                                  const ShapeIndex& index) {
68     // Check that top-level shapes are tuple and tuple element shapes are equal.
69     EXPECT_TRUE(a->shape().IsTuple());
70     EXPECT_TRUE(b->shape().IsTuple());
71     EXPECT_TRUE(
72         ShapeUtil::Compatible(ShapeUtil::GetSubshape(a->shape(), index),
73                               ShapeUtil::GetSubshape(b->shape(), index)));
74     // Lookup PointsTo set for instructions 'a' and 'b'.
75     auto& points_to_analysis = liveness.points_to_analysis();
76     const auto& points_to_a =
77         points_to_analysis.GetPointsToSet(a).element(index);
78     const auto& points_to_b =
79         points_to_analysis.GetPointsToSet(b).element(index);
80     // Make sure PointsTo sets for 'a' and 'b' are unambiguous.
81     EXPECT_EQ(1, points_to_a.size());
82     EXPECT_EQ(points_to_a.size(), points_to_b.size());
83     // Check interference.
84     return liveness.MayInterfere(*points_to_a[0], *points_to_b[0]);
85   }
86 
87   // Returns true if the top-level buffers for the given instruction maybe
88   // liveout of the entry computation.
89   // Precondition: instruction is array-shaped.
InstructionMaybeLiveOut(const BufferLiveness & liveness,HloInstruction * instruction)90   bool InstructionMaybeLiveOut(const BufferLiveness& liveness,
91                                HloInstruction* instruction) {
92     return liveness.MaybeLiveOut(
93         GetBuffer(liveness, instruction, /*index=*/{}));
94   }
95 
BuildDummyComputation()96   std::unique_ptr<HloComputation> BuildDummyComputation() {
97     auto builder = HloComputation::Builder(TestName() + "_dummy");
98     builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
99     return builder.Build();
100   }
101 
102   const Shape vec_ = ShapeUtil::MakeShape(xla::F32, {42});
103 };
104 
TEST_F(BufferLivenessTest, ElementwiseChain) {
  // Straight-line chain of elementwise unary operations:
  //
  //   param --> negate -> exp -> log
  //
  // No two distinct buffers are expected to interfere.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
  HloInstruction* log = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const auto liveness =
      BufferLiveness::Run(
          module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
          .ConsumeValueOrDie();
  const BufferLiveness& live = *liveness;

  HloInstruction* const chain[] = {negate, exp, log};

  // The parameter does not interfere with any op in the chain.
  for (HloInstruction* op : chain) {
    EXPECT_FALSE(InstructionsMayInterfere(live, param, op));
  }

  // Every ordered pair of distinct chain ops is interference-free.
  for (HloInstruction* lhs : chain) {
    for (HloInstruction* rhs : chain) {
      if (lhs == rhs) {
        continue;
      }
      EXPECT_FALSE(InstructionsMayInterfere(live, lhs, rhs));
    }
  }

  // A buffer always interferes with itself.
  EXPECT_TRUE(InstructionsMayInterfere(live, exp, exp));

  // 'log' is the only buffer that is live out.
  HloInstruction* const not_live_out[] = {param, negate, exp};
  for (HloInstruction* op : not_live_out) {
    EXPECT_FALSE(InstructionMaybeLiveOut(live, op));
  }
  EXPECT_TRUE(InstructionMaybeLiveOut(live, log));
}
149 
TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
  // Two entry parameters, each feeding its own unary op; the results are
  // added. The two parameters are expected to interfere with each other.
  //
  //   param0 --> negate --+
  //                       +--> add
  //   param1 --> exp -----+
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, vec_, "param1"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param0));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Explicit schedule: param1 is sequenced after negate.
  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param0, negate, param1, exp, add});
  const auto liveness =
      BufferLiveness::Run(module.get(),
                          absl::make_unique<SequentialHloOrdering>(schedule))
          .ConsumeValueOrDie();

  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(*liveness, a, b);
  };

  // Entry parameters interfere as if they were all defined simultaneously at
  // the very beginning, regardless of their position in the schedule.
  EXPECT_TRUE(interferes(param0, param1));
  EXPECT_FALSE(interferes(param0, negate));
  EXPECT_FALSE(interferes(param0, exp));
  EXPECT_FALSE(interferes(param0, add));
  EXPECT_TRUE(interferes(param1, param0));
  EXPECT_TRUE(interferes(param1, negate));
  EXPECT_FALSE(interferes(param1, exp));
  EXPECT_FALSE(interferes(param1, add));

  // negate and exp are both operands of add, so they still interfere.
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));

  // Neither operand interferes with add itself.
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
198 
TEST_F(BufferLivenessTest, NonElementwiseOperand) {
  // Chain with two elementwise ops followed by one non-elementwise op:
  //
  //   param --> exp -> negate -> reverse
  //
  // The elementwise negate is expected not to interfere with its operand,
  // whereas the non-elementwise reverse is expected to interfere with its
  // operand.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, exp));
  HloInstruction* reverse =
      builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const auto liveness =
      BufferLiveness::Run(
          module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
          .ConsumeValueOrDie();
  const BufferLiveness& live = *liveness;

  // The parameter does not interfere with any downstream op.
  HloInstruction* const downstream[] = {exp, negate, reverse};
  for (HloInstruction* op : downstream) {
    EXPECT_FALSE(InstructionsMayInterfere(live, param, op));
  }

  // Elementwise user: no interference with its operand.
  EXPECT_FALSE(InstructionsMayInterfere(live, exp, negate));
  // Non-elementwise user: interferes with its operand.
  EXPECT_TRUE(InstructionsMayInterfere(live, negate, reverse));
}
233 
TEST_F(BufferLivenessTest, OverlappedBuffers) {
  // Buffers that are live at the same time (negate and exp) must interfere.
  //
  //   param --+--> negate --+
  //           |             +--> add
  //           +--> exp -----+
  //
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const auto liveness =
      BufferLiveness::Run(
          module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
          .ConsumeValueOrDie();

  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(*liveness, a, b);
  };

  // Under the dependency ordering the parameter interferes with both of its
  // users, but not with add.
  EXPECT_TRUE(interferes(param, negate));
  EXPECT_TRUE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, add));

  // negate and exp interfere with each other in both directions...
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));
  // ...but neither of them interferes with add.
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
270 
TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
  // Same graph as the OverlappedBuffers test, but analysed with an explicit
  // sequential ordering of the HLO instructions.
  //
  //   param --+--> negate --+
  //           |             +--> add
  //           +--> exp -----+
  //
  // Sequential order:
  //   param, negate, exp, add
  //
  // The expectations among negate, exp and add match the OverlappedBuffers
  // test; for the parameter, only (param, negate) is expected to interfere
  // here.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {param, negate, exp, add});
  const auto liveness =
      BufferLiveness::Run(module.get(),
                          absl::make_unique<SequentialHloOrdering>(schedule))
          .ConsumeValueOrDie();

  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(*liveness, a, b);
  };

  // Parameter versus the rest of the graph.
  EXPECT_TRUE(interferes(param, negate));
  EXPECT_FALSE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, add));

  // negate and exp interfere with each other in both directions...
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));
  // ...but neither of them interferes with add.
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
314 
TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
  // When the root instruction is scheduled before other instructions, the
  // live range of its buffer must interfere with the buffers of those later
  // instructions.
  //
  // The computation holds two independent groups of instructions:
  //   param --> add (root)
  //   recv --> recv-done --> send --> send-done
  //
  // Sequential order:
  //   param, add (root), token, recv, recv-done, send, send-done
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param));
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv = builder.AddInstruction(
      HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0));
  HloInstruction* recv_done =
      builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
  HloInstruction* send = builder.AddInstruction(
      HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1));
  HloInstruction* send_done =
      builder.AddInstruction(HloInstruction::CreateSendDone(send));

  // 'add' is made the root explicitly even though it is not the last
  // instruction that was added.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build(add));

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry,
                        {param, add, token, recv, recv_done, send, send_done});
  TF_ASSERT_OK(schedule.Verify());
  const auto liveness =
      BufferLiveness::Run(module.get(),
                          absl::make_unique<SequentialHloOrdering>(schedule))
          .ConsumeValueOrDie();

  EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));

  // The root (add) buffer must interfere with the recv buffer at index {0}.
  const LogicalBuffer& root_buffer = GetBuffer(*liveness, add, /*index=*/{});
  const LogicalBuffer& recv_buffer = GetBuffer(*liveness, recv, /*index=*/{0});
  EXPECT_TRUE(liveness->MayInterfere(root_buffer, recv_buffer));
}
357 
TEST_F(BufferLivenessTest, TupleLiveOut) {
  // Exercises MaybeLiveOut on a nested tuple result. The computation returns:
  //
  //   Tuple({Tuple({Negate(Param)}), Exp(Negate(Param))})
  //
  // Every value except Param is expected to be live out.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
  HloInstruction* inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({negate}));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
  HloInstruction* outer_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const auto liveness =
      BufferLiveness::Run(
          module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
          .ConsumeValueOrDie();
  const BufferLiveness& live = *liveness;

  // The parameter is the only buffer that is not live out.
  EXPECT_FALSE(InstructionMaybeLiveOut(live, param));

  HloInstruction* const live_out[] = {negate, inner_tuple, exp, outer_tuple};
  for (HloInstruction* instruction : live_out) {
    EXPECT_TRUE(InstructionMaybeLiveOut(live, instruction));
  }
}
391 
392 // bitcast liveout.
393 
TEST_F(BufferLivenessTest, EmbeddedComputation) {
  // Exercises MaybeLiveOut and MayInterfere across an entry computation that
  // calls an embedded computation.
  auto module = CreateNewVerifiedModule();

  // Embedded computation: log(embedded_param).
  HloComputation::Builder embedded_builder(TestName() + "_embedded");
  HloInstruction* embedded_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_, "embedded_param"));
  HloInstruction* embedded_log = embedded_builder.AddInstruction(
      HloInstruction::CreateUnary(vec_, HloOpcode::kLog, embedded_param));
  HloComputation* embedded_computation =
      module->AddEmbeddedComputation(embedded_builder.Build());

  // Entry computation: call(param) into the embedded computation.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(vec_, {param}, embedded_computation));
  module->AddEntryComputation(builder.Build());

  const auto liveness =
      BufferLiveness::Run(
          module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
          .ConsumeValueOrDie();
  const BufferLiveness& live = *liveness;

  // Buffers that belong to different computations should always interfere.
  EXPECT_TRUE(InstructionsMayInterfere(live, embedded_log, call));
  EXPECT_TRUE(InstructionsMayInterfere(live, embedded_param, param));
  // Within the embedded computation, the parameter and log do not interfere.
  EXPECT_FALSE(InstructionsMayInterfere(live, embedded_param, embedded_log));

  // MaybeLiveOut is true only for buffers live out of the entry computation;
  // buffers live out of an embedded computation report false.
  EXPECT_FALSE(InstructionMaybeLiveOut(live, embedded_log));
  EXPECT_TRUE(InstructionMaybeLiveOut(live, call));
}
432 
TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
  // Verify non top-level elements of a nested tuple constant are properly
  // marked as liveout. Computation:
  //
  //   GetTupleElement(0, TupleConstant({{0, 1}, {3}}))
  //
  // The inner tuple {0, 1} selected by GetTupleElement and its two array
  // elements are expected to be liveout. The top-level tuple buffer and the
  // whole {3} subtree (the inner tuple and the array 3) are expected not to
  // be liveout.
  auto builder = HloComputation::Builder(TestName());
  Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
                         LiteralUtil::CreateR0<int64>(1)};
  auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
  Literal element1 = LiteralUtil::CreateR0<int64>(3);
  auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
  auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
  builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      inner_tuple0.shape(), tuple_constant, 0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  auto liveness =
      BufferLiveness::Run(
          module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
          .ConsumeValueOrDie();

  // Each shape index of the tuple constant is checked exactly once.

  // Top-level tuple buffer: not liveout.
  EXPECT_FALSE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{})));

  // Subtree selected by the GetTupleElement instruction: liveout.
  EXPECT_TRUE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{0})));
  EXPECT_TRUE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{0, 0})));
  EXPECT_TRUE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{0, 1})));

  // Subtree not referenced by the root: not liveout.
  EXPECT_FALSE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{1})));
  EXPECT_FALSE(liveness->MaybeLiveOut(
      GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0})));
}
478 
TEST_F(BufferLivenessTest,IndependentTupleElements)479 TEST_F(BufferLivenessTest, IndependentTupleElements) {
480   auto builder = HloComputation::Builder(TestName());
481   // Create param0 Tuple.
482   auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
483       0,
484       ShapeUtil::MakeTupleShape(
485           {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}),
486       "param0"));
487   // Create independent computations for each tuple elememt.
488 
489   // Tuple element0 computation:
490   //   Add(GetTupleElement(tuple_param0, 0), const0)
491   auto tuple_element0_shape =
492       ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
493   auto tuple_element0 =
494       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
495           tuple_element0_shape, tuple_param0, 0));
496   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
497       LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
498   auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
499       tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
500 
501   // Tuple element1 computation:
502   //   Add(GetTupleElement(tuple_param0, 1), const1)
503   auto tuple_element1_shape =
504       ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
505   auto tuple_element1 =
506       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
507           tuple_element1_shape, tuple_param0, 1));
508   auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
509       LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
510   auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
511       tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1));
512 
513   // Create output tuple.
514   auto tuple_root =
515       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
516 
517   auto module = CreateNewUnverifiedModule();
518   module->AddEntryComputation(BuildDummyComputation());
519   module->AddEmbeddedComputation(builder.Build());
520 
521   auto liveness =
522       BufferLiveness::Run(
523           module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
524           .ConsumeValueOrDie();
525 
526   // We compare tuple element pairs that are input/output to the computation:
527   // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0')
528   // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1')
529 
530   // Tuple output element 'add0' does not depend on input 'tuple_element1'.
531   // Tuple output element 'add1' does not depend on input 'tuple_element0'.
532 
533   // Both element pair does not interfere, because there is no other dependency
534   // on the pairs tuple input element, and so liveness can compute that all
535   // users of the input tuple element execute before the associated output
536   // tuple element.
537   EXPECT_FALSE(
538       TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0}));
539   EXPECT_FALSE(
540       TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}));
541 }
542 
TEST_F(BufferLivenessTest,DependentTupleElements)543 TEST_F(BufferLivenessTest, DependentTupleElements) {
544   auto builder = HloComputation::Builder(TestName());
545   // Create param0 Tuple.
546   auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
547       0,
548       ShapeUtil::MakeTupleShape(
549           {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}),
550       "param0"));
551   // Create dependent computations for each tuple elememt.
552 
553   // Tuple element0 computation:
554   //   Add(GetTupleElement(tuple_param0, 0), const0)
555   auto tuple_element0_shape =
556       ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
557   auto tuple_element0 =
558       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
559           tuple_element0_shape, tuple_param0, 0));
560   auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
561       LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
562   auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
563       tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
564 
565   // Tuple element1 computation:
566   //   Add(GetTupleElement(tuple_param0, 0), GetTupleElement(tuple_param0, 1))
567   auto tuple_element1_shape =
568       ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
569   auto tuple_element1 =
570       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
571           tuple_element1_shape, tuple_param0, 1));
572   auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
573       tuple_element1_shape, HloOpcode::kAdd, tuple_element0, tuple_element1));
574 
575   // Create output tuple.
576   auto tuple_root =
577       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
578 
579   auto module = CreateNewVerifiedModule();
580   module->AddEntryComputation(BuildDummyComputation());
581   module->AddEmbeddedComputation(builder.Build());
582 
583   auto liveness =
584       BufferLiveness::Run(
585           module.get(), absl::make_unique<DependencyHloOrdering>(module.get()))
586           .ConsumeValueOrDie();
587 
588   // We compare tuple element pairs that are input/output to the computation:
589   // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0')
590   // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1')
591 
592   // The first tuple element pair output 'add0', has no dependency on second
593   // tuple element pairs input 'tuple_element1'.
594 
595   // The second tuple element pair output 'add1', has a dependency on first
596   // tuple element pairs input 'tuple_element0'.
597 
598   // The first tuple element pair does interfere, because liveness cannot
599   // compute that all references to 'tuple_element0' are executed before 'add0'
600   // (because of the depenency of 'add1' on 'tuple_element0').
601   EXPECT_TRUE(
602       TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0}));
603 
604   // The second tuple element pair does not interfere, because there is no
605   // other dependency on 'tuple_element1', and so liveness can compute that
606   // all users execute before 'add1'.
607   EXPECT_FALSE(
608       TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}));
609 }
610 
611 class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
612  protected:
613   // Builds and runs a computation (see test case computation graphs below).
BuildModule(const bool update_uses_tuple_element1,const bool fuse_gte0)614   std::unique_ptr<VerifiedHloModule> BuildModule(
615       const bool update_uses_tuple_element1, const bool fuse_gte0) {
616     auto builder = HloComputation::Builder(TestName());
617     // Create param0 Tuple.
618     Shape data_shape = ShapeUtil::MakeShape(F32, {8});
619     Shape update_shape = ShapeUtil::MakeShape(F32, {3});
620     auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
621         0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
622 
623     auto gte0 = builder.AddInstruction(
624         HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
625 
626     auto gte1 = builder.AddInstruction(
627         HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
628 
629     auto update = builder.AddInstruction(HloInstruction::CreateConstant(
630         LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
631     HloInstruction* slice = nullptr;
632     if (update_uses_tuple_element1) {
633       // Create a slice instruction as an additional user of 'gte1'.
634       slice = builder.AddInstruction(
635           HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}, {1}));
636       update = builder.AddInstruction(HloInstruction::CreateBinary(
637           update_shape, HloOpcode::kAdd, update, slice));
638     }
639     // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
640     auto starts = builder.AddInstruction(
641         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
642     auto dynamic_update_slice =
643         builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
644             data_shape, gte1, update, {starts}));
645     // Create output tuple.
646     builder.AddInstruction(
647         HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
648     // Build module and get reference to entry computation.
649     auto module = CreateNewVerifiedModule();
650     module->AddEntryComputation(builder.Build());
651     auto* computation = module->entry_computation();
652     // Create fusion instruction based on number of tuple element 1 users.
653     if (update_uses_tuple_element1) {
654       computation->CreateFusionInstruction(
655           {dynamic_update_slice, starts, update, CHECK_NOTNULL(slice), gte1},
656           HloInstruction::FusionKind::kLoop);
657     } else {
658       computation->CreateFusionInstruction(
659           {dynamic_update_slice, starts, update, gte1},
660           HloInstruction::FusionKind::kLoop);
661     }
662     // Create fusion instruction for tuple element 0 (if requested).
663     if (fuse_gte0) {
664       computation->CreateFusionInstruction({gte0},
665                                            HloInstruction::FusionKind::kLoop);
666     }
667     return module;
668   }
669 
670   // Returns whether buffer interference is detected between tuple-shaped
671   // parameter and root instructions at tuple element 1.
Run(const bool update_uses_tuple_element1,const bool fuse_gte0=false)672   bool Run(const bool update_uses_tuple_element1,
673            const bool fuse_gte0 = false) {
674     auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
675     // Run BufferLiveness on 'module'.
676     auto liveness = BufferLiveness::Run(
677                         module.get(),
678                         absl::make_unique<DependencyHloOrdering>(module.get()))
679                         .ConsumeValueOrDie();
680     // Return whether or not buffers interference is detected between
681     // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
682     auto tuple_param0 = FindInstruction(module.get(), "param0");
683     auto tuple_root = module->entry_computation()->root_instruction();
684     return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
685   }
RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1,const bool fuse_gte0=false)686   bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1,
687                                   const bool fuse_gte0 = false) {
688     auto module = BuildModule(update_uses_tuple_element1, fuse_gte0);
689     // Run BufferLiveness on 'module'.
690     auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie();
691     auto hlo_ordering = absl::make_unique<DependencyHloOrdering>(module.get());
692     // Return whether or not buffers interference is detected between
693     // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
694     auto tuple_param0 = FindInstruction(module.get(), "param0");
695     auto tuple_root = module->entry_computation()->root_instruction();
696     return hlo_ordering->MayInterfere(
697         dataflow->GetUniqueValueAt(tuple_param0, {1}),
698         dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow);
699   }
700 };
701 
702 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
703 // do not overlap with the following computation:
704 //
705 //         Param0
706 //        /     \
707 //     GTE(0)  Fusion ----------->  FusionParam
708 //        |      |                      |
709 //        |      |                    GTE(1) Const Const
710 //        |      |                      \      |    /
711 //        |      |                    DynamicUpdateSlice  // fused root
712 //         \    /
713 //          Tuple  // computation root
714 //
TEST_F(FusedDynamicUpdateSliceLivenessTest,NoInterference)715 TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) {
716   EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false));
717   EXPECT_FALSE(
718       RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false));
719 }
720 
721 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases
722 // 'fusion1') do not overlap in the presence of another fusion instruction
723 // (which is a user of 'param0' at a different tuple index).
724 // BufferLiveness should detect no uses of Param0 at index {1} in Fusion0
725 // (because Fusion0 only uses Param0 at index {0}).
726 //
727 //                               Param0
728 //                               /    \
729 //      FusionParam  <----- Fusion0  Fusion1 ------>  FusionParam
730 //         |                    |      |                 |
731 //        GTE(0)                |      |               GTE(1) Const Const
732 //                              |      |                  \      |    /
733 //                               \    /                DynamicUpdateSlice
734 //                               Tuple
735 //
TEST_F(FusedDynamicUpdateSliceLivenessTest,NoInterferenceWithUnrelatedFusion)736 TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) {
737   EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true));
738   EXPECT_FALSE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false,
739                                           /*fuse_gte0=*/true));
740 }
741 
742 // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion)
743 // do overlap because GTE(1) has two users:
744 // 1) DynamicUpdateSlice at operand 0.
745 // 2) Slice at operand 0.
746 //
747 //         Param0
748 //        /     \   Const
749 //       /       \  /
750 //     GTE(0)  Fusion ----------->  FusionParam FusionParam
751 //        |      |                      |         |
752 //        |      |                    GTE(1)      /
753 //        |      |                      | \      /
754 //        |      |                      | Slice /
755 //        |      |                      |   \  /
756 //        |      |                      |   Add   Const
757 //        |      |                      |    |      |
758 //        |      |                    DynamicUpdateSlice  // fused root
759 //         \    /
760 //          Tuple  // computation root
761 //
TEST_F(FusedDynamicUpdateSliceLivenessTest,WithInterference)762 TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
763   EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
764   EXPECT_TRUE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true));
765 }
766 
767 class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
768  protected:
769   // Builds and runs a computation (see test case computation graphs below).
770   // Runs BufferLiveness on this computation.
771   // Returns whether buffer interference is detected between tuple-shaped
772   // parameter and root instructions at tuple element 1.
Run(const bool tuple_element1_has_two_uses)773   bool Run(const bool tuple_element1_has_two_uses) {
774     auto builder = HloComputation::Builder(TestName());
775     // Create param0 Tuple.
776     Shape data_shape = ShapeUtil::MakeShape(F32, {8});
777     Shape update_shape = ShapeUtil::MakeShape(F32, {3});
778     auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
779         0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
780 
781     auto gte0 = builder.AddInstruction(
782         HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
783 
784     auto gte1 = builder.AddInstruction(
785         HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
786 
787     auto update = builder.AddInstruction(HloInstruction::CreateConstant(
788         LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
789 
790     if (tuple_element1_has_two_uses) {
791       // Add 'gte0' and 'gte1' to create another user of 'gte1'.
792       gte0 = builder.AddInstruction(HloInstruction::CreateBinary(
793           data_shape, HloOpcode::kAdd, gte0, gte1));
794     }
795     // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
796     auto starts = builder.AddInstruction(
797         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(2)));
798     auto dynamic_update_slice =
799         builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
800             data_shape, gte1, update, {starts}));
801     // Create output tuple.
802     auto tuple_root = builder.AddInstruction(
803         HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
804     // Build module and get reference to entry computation.
805     auto module = CreateNewVerifiedModule();
806     module->AddEntryComputation(BuildDummyComputation());
807     module->AddEmbeddedComputation(builder.Build());
808     // Run BufferLiveness on 'module'.
809     auto liveness = BufferLiveness::Run(
810                         module.get(),
811                         absl::make_unique<DependencyHloOrdering>(module.get()))
812                         .ConsumeValueOrDie();
813     // Return whether or not buffers interference is detected between
814     // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
815     return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
816   }
817 };
818 
819 // Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in
820 // the following computation (because DynamicUpdateSlice (at operand 0) is the
821 // unique user):
822 //
823 //     Parameter0
824 //      |      |
825 //    GTE(0) GTE(1) Const Const
826 //      |      \      |    /
827 //      |    DynamicUpdateSlice
828 //       \    /
829 //        Tuple
830 //
TEST_F(DynamicUpdateSliceLivenessTest,NoInterference)831 TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) {
832   EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false));
833 }
834 
835 // Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because
836 // GTE(1) has two users:
837 // 1) DynamicUpdateSlice at operand 0.
838 // 2) Add at operand 1.
839 //
840 //     Parameter0
841 //      |      |
842 //    GTE(0) GTE(1)
843 //      |   /  |
844 //      |  /   |
845 //      Add    |     Const Const
846 //      |      |      |      |
847 //      |    DynamicUpdateSlice
848 //       \    /
849 //        Tuple
850 //
TEST_F(DynamicUpdateSliceLivenessTest,WithInterference)851 TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
852   EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true));
853 }
854 
855 }  // namespace
856 
857 }  // namespace xla
858