1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17 
18 #include <set>
19 
20 #include "tensorflow/compiler/xla/debug_options_flags.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_runner.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34 
35 namespace op = xla::testing::opcode_matchers;
36 
37 namespace xla {
38 namespace {
39 
40 using ::testing::UnorderedElementsAre;
41 
CountCopies(const HloComputation & computation)42 int64 CountCopies(const HloComputation& computation) {
43   int64 count = 0;
44   for (const auto& instruction : computation.instructions()) {
45     if (instruction->opcode() == HloOpcode::kCopy) {
46       count++;
47     }
48   }
49   return count;
50 }
51 
CountCopies(const HloModule & module)52 int64 CountCopies(const HloModule& module) {
53   int64 count = 0;
54   for (const auto& computation : module.computations()) {
55     count += CountCopies(*computation);
56   }
57   return count;
58 }
59 
CountControlEdges(const HloComputation & computation)60 int64 CountControlEdges(const HloComputation& computation) {
61   int64 count = 0;
62   for (const auto& instruction : computation.instructions()) {
63     count += instruction->control_successors().size();
64   }
65   return count;
66 }
67 
CountControlEdges(const HloModule & module)68 int64 CountControlEdges(const HloModule& module) {
69   int64 count = 0;
70   for (const auto& computation : module.computations()) {
71     count += CountControlEdges(*computation);
72   }
73   return count;
74 }
75 
76 class CopyInsertionTest : public HloTestBase {
77  protected:
InsertCopies(HloModule * module)78   void InsertCopies(HloModule* module) {
79     CopyInsertion copy_insertion;
80     ASSERT_IS_OK(copy_insertion.Run(module).status());
81   }
82 
83   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
84 };
85 
TEST_F(CopyInsertionTest, SingleParameter) {
  // The entry computation is just a parameter wrapped in a tuple. After the
  // pass, the tuple operand is expected to be a copy of the parameter.
  auto builder = HloComputation::Builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(F32, {});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "x"));
  HloInstruction* root_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param}));

  // Sanity check on the graph before the pass runs.
  EXPECT_THAT(param->users(), UnorderedElementsAre(root_tuple));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Tuple(op::Copy(param)));
}
105 
TEST_F(CopyInsertionTest, SingleConstant) {
  // The entry computation is just a constant wrapped in a tuple. After the
  // pass, the tuple operand is expected to be a copy of the constant.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* root_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({one}));

  // Sanity check on the graph before the pass runs.
  EXPECT_THAT(one->users(), UnorderedElementsAre(root_tuple));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Tuple(op::Copy(one)));
}
126 
TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
  // Layout-changing kCopy instructions that are already in the graph before
  // copy insertion must still be there afterwards.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());

  HloInstruction* matrix = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));

  // Passing the minor-to-major order to a major-to-minor factory produces the
  // reverse of the constant's layout.
  auto dim_order = LayoutUtil::MinorToMajor(matrix->shape());
  Shape reversed_shape = matrix->shape();
  *reversed_shape.mutable_layout() =
      LayoutUtil::MakeLayoutFromMajorToMinor(dim_order);

  // Two layout-changing copies of the constant feed an add...
  HloInstruction* lhs_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(reversed_shape, HloOpcode::kCopy, matrix));
  HloInstruction* rhs_copy = builder.AddInstruction(
      HloInstruction::CreateUnary(reversed_shape, HloOpcode::kCopy, matrix));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      matrix->shape(), HloOpcode::kAdd, lhs_copy, rhs_copy));
  // ...and a third copy, with the add's own shape, is the root.
  builder.AddInstruction(
      HloInstruction::CreateUnary(sum->shape(), HloOpcode::kCopy, sum));

  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(CountCopies(*module), 3);

  InsertCopies(module.get());

  // Only the two layout-changing copies are expected to remain, leaving the
  // add as the root.
  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_EQ(new_root, sum);
  EXPECT_THAT(new_root,
              op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
162 
TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
  // Two constants and two parameters are created, but the output tuple refers
  // directly to only one constant and one parameter. Only those two should be
  // copied; the other constant and parameter reach the tuple through an add.
  auto builder = HloComputation::Builder(TestName());
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));

  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "x"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "y"));

  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, c1, p1));

  builder.AddInstruction(HloInstruction::CreateTuple({c2, p0, sum}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              op::Tuple(op::Copy(c2), op::Copy(p0), op::Add(c1, p1)));
}
194 
TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
  // The root is a tuple-select between two tuples, so the points-to set of the
  // computation result is ambiguous. Checks that each result element ends up
  // as a copy of a get-tuple-element of the original root.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* c3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));

  // c2 is shared by both candidate tuples.
  HloInstruction* on_true =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* on_false =
      builder.AddInstruction(HloInstruction::CreateTuple({c3, c2}));

  HloInstruction* predicate = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  builder.AddInstruction(HloInstruction::CreateTernary(
      on_true->shape(), HloOpcode::kTupleSelect, predicate, on_true,
      on_false));

  // Sanity checks on the graph before the pass runs.
  EXPECT_THAT(c1->users(), UnorderedElementsAre(on_true));
  EXPECT_THAT(c2->users(), UnorderedElementsAre(on_true, on_false));
  EXPECT_THAT(c3->users(), UnorderedElementsAre(on_false));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloInstruction* select_root = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              op::Tuple(op::Copy(op::GetTupleElement(select_root)),
                        op::Copy(op::GetTupleElement(select_root))));
}
231 
TEST_F(CopyInsertionTest, BitcastParameter) {
  // A bitcast's output is its operand (same buffer), so when a bitcast of a
  // parameter is the result, a copy has to be added.
  auto builder = HloComputation::Builder(TestName());
  const Shape flat_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape square_shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, flat_shape, "x"));
  HloInstruction* reshaped = builder.AddInstruction(
      HloInstruction::CreateUnary(square_shape, HloOpcode::kBitcast, param));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(param->users(), UnorderedElementsAre(reshaped));

  HloInstruction* bitcast_root = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  // The new root is a copy of the previous root.
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Copy(bitcast_root));
}
253 
TEST_F(CopyInsertionTest, BitcastConstant) {
  // A bitcast's output is its operand (same buffer), so when a bitcast of a
  // constant is the result, a copy has to be added.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* vec = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 42.0})));
  // NOTE(review): the constant has 2 elements while the bitcast result shape
  // is 2x2 -- confirm this mismatch is intentional.
  const Shape square_shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* reshaped = builder.AddInstruction(
      HloInstruction::CreateUnary(square_shape, HloOpcode::kBitcast, vec));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(vec->users(), UnorderedElementsAre(reshaped));

  HloInstruction* bitcast_root = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  // The new root is a copy of the previous root.
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Copy(bitcast_root));
}
276 
TEST_F(CopyInsertionTest, BitcastTupleElementParameter) {
  // Variant of BitcastParameter in which the bitcast is the sole element of
  // the root tuple.
  auto builder = HloComputation::Builder(TestName());
  const Shape flat_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape square_shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, flat_shape, "x"));
  HloInstruction* reshaped = builder.AddInstruction(
      HloInstruction::CreateUnary(square_shape, HloOpcode::kBitcast, param));
  builder.AddInstruction(HloInstruction::CreateTuple({reshaped}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_THAT(param->users(), UnorderedElementsAre(reshaped));

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  // The tuple element is now a copy of the bitcast.
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Tuple(op::Copy(reshaped)));
}
297 
TEST_F(CopyInsertionTest, NestedTupleParameter) {
  // The whole computation is a single nested-tuple parameter, which is also
  // the root. The pass is expected to deep copy it (one copy per leaf) and
  // make the rebuilt tuple the new root.
  auto builder = HloComputation::Builder(TestName());

  // Parameter shape: ((F32[], S32[1,2,3]), F32[42])
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {inner_tuple_shape, ShapeUtil::MakeShape(F32, {42})});
  builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  HloInstruction* param_root = module->entry_computation()->root_instruction();
  EXPECT_EQ(HloOpcode::kParameter, param_root->opcode());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 3);

  HloInstruction* rebuilt_root =
      module->entry_computation()->root_instruction();
  EXPECT_NE(param_root, rebuilt_root);

  // Leaves of the inner tuple are reached through two get-tuple-elements, the
  // outer F32[42] leaf through one.
  EXPECT_THAT(
      rebuilt_root,
      op::Tuple(
          op::Tuple(
              op::Copy(op::GetTupleElement(op::GetTupleElement(param_root))),
              op::Copy(op::GetTupleElement(op::GetTupleElement(param_root)))),
          op::Copy(op::GetTupleElement(param_root))));
}
334 
TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
  // The root of the computation is one tuple element of a nested-tuple
  // parameter.
  auto builder = HloComputation::Builder(TestName());

  // Parameter shape: ((F32[], S32[1,2,3]), F32[42])
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {inner_tuple_shape, ShapeUtil::MakeShape(F32, {42})});
  auto nested_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  // Return element 0 of the parameter, which is itself a tuple.
  auto inner_gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
      ShapeUtil::GetSubshape(nested_param->shape(), {0}), nested_param, 0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(inner_gte, module->entry_computation()->root_instruction());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  // Both leaves of the inner tuple are copied and re-tupled.
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      new_root,
      op::Tuple(
          op::Copy(op::GetTupleElement(op::GetTupleElement(nested_param))),
          op::Copy(op::GetTupleElement(op::GetTupleElement(nested_param)))));
}
367 
TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
  // The root is a get-tuple-element of a tuple-select, so the top-level
  // buffer of the root has an ambiguous points-to set. Checks that a single
  // shallow copy of the root is added.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));

  // The two candidate tuples hold the same constants in opposite order.
  HloInstruction* on_true =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* on_false =
      builder.AddInstruction(HloInstruction::CreateTuple({c2, c1}));

  HloInstruction* predicate = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* chosen = builder.AddInstruction(HloInstruction::CreateTernary(
      on_true->shape(), HloOpcode::kTupleSelect, predicate, on_true,
      on_false));
  HloInstruction* first_element =
      builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(chosen->shape(), {0}), chosen, 0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(first_element, module->entry_computation()->root_instruction());

  HloInstruction* gte_root = module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Copy(gte_root));
}
403 
404 class WhileCopyInsertionTest : public CopyInsertionTest {
405  protected:
// Initializes module_ with a module obtained from CreateNewUnverifiedModule().
// NOTE(review): presumably unverified because some builders in this fixture
// emit HLO that would not pass verification -- confirm.
WhileCopyInsertionTest() : module_(CreateNewUnverifiedModule()) {}
407 
408   // Builds a While condition computation which reads the induction variable
409   // from the tuple parameter, and returns a predicate indicating whether this
410   // value is less than the constant '10'.
411   // The parameter 'nested' specifies the loop state shape from which to
412   // read the induction variable.
BuildConditionComputation(const Shape & loop_state_shape)413   std::unique_ptr<HloComputation> BuildConditionComputation(
414       const Shape& loop_state_shape) {
415     auto builder = HloComputation::Builder(TestName() + ".Condition");
416     auto limit_const = builder.AddInstruction(
417         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10)));
418     auto loop_state = builder.AddInstruction(
419         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
420     auto induction_variable =
421         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
422             limit_const->shape(), loop_state, 0));
423     builder.AddInstruction(HloInstruction::CreateCompare(
424         condition_result_shape_, induction_variable, limit_const,
425         ComparisonDirection::kLt));
426     return builder.Build();
427   }
428 
429   // Builds a While body computation with one output tuple element dependent on
430   // both input tuple elements.
431   // EX:
432   // Body({in0, in1})
433   //   out0 = Add(in0, 1)
434   //   out1 = Add(BCast(in0), in1)
435   //   Tuple(out0, out1)
BuildDependentBodyComputation()436   std::unique_ptr<HloComputation> BuildDependentBodyComputation() {
437     auto builder = HloComputation::Builder(TestName() + ".Body");
438     // Create param instruction to access loop state.
439     auto loop_state = builder.AddInstruction(
440         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
441     // Update the induction variable GTE(0).
442     auto induction_variable =
443         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
444             induction_variable_shape_, loop_state, 0));
445     auto inc = builder.AddInstruction(
446         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
447     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
448         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
449     // Update data GTE(1).
450     auto data = builder.AddInstruction(
451         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
452     // Use 'induction_variable' in computation with no path to output tuple.
453     auto update = builder.AddInstruction(
454         HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
455     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
456         data_shape_, HloOpcode::kAdd, data, update));
457     // Create output Tuple.
458     builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
459     return builder.Build();
460   }
461 
462   // Builds a While body computation with two output tuple elements dependent on
463   // both input tuple elements.
464   //
465   // EX: Body({in0, in1, in2})
466   //   out0 = Add(in0, 1)
467   //   out1 = in1
468   //   out2 = in2
469   //   Tuple(out0, out1, out2)
BuildDependentBodyComputation2()470   std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
471     auto builder = HloComputation::Builder(TestName() + ".Body");
472 
473     const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
474         {induction_variable_shape_, data_shape_, data_shape_});
475 
476     auto loop_state = builder.AddInstruction(
477         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
478 
479     // Update the induction variable GTE(0).
480     auto induction_variable =
481         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
482             induction_variable_shape_, loop_state, 0));
483     auto inc = builder.AddInstruction(
484         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
485 
486     // add0 = Add(in0, 1)
487     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
488         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
489     // data1 = GTE(1).
490     HloInstruction* data1 = builder.AddInstruction(
491         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
492 
493     // data2 = GTE(2).
494     HloInstruction* data2 = builder.AddInstruction(
495         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
496 
497     // Create output Tuple.
498     builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
499 
500     return builder.Build();
501   }
502 
503   // Builds a While body computation with read-only tuple element 0.
504   // EX:
505   // Body({in0, in1})
506   //   out0 = in0
507   //   out1 = Add(BCast(in0), in1)
508   //   Tuple(out0, out1)
BuildDependentBodyOneReadOnlyComputation()509   std::unique_ptr<HloComputation> BuildDependentBodyOneReadOnlyComputation() {
510     auto builder = HloComputation::Builder(TestName() + ".Body");
511     // Create param instruction to access loop state.
512     auto loop_state = builder.AddInstruction(
513         HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
514     // Update the induction variable GTE(0).
515     auto induction_variable =
516         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
517             induction_variable_shape_, loop_state, 0));
518     // Update data GTE(1).
519     auto data = builder.AddInstruction(
520         HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
521 
522     // Use 'induction_variable' in computation with no path to output tuple.
523     auto update = builder.AddInstruction(
524         HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
525     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
526         data_shape_, HloOpcode::kAdd, data, update));
527     // Create output Tuple.
528     builder.AddInstruction(
529         HloInstruction::CreateTuple({induction_variable, add1}));
530     return builder.Build();
531   }
532 
533   // Builds a While body computation with independent outputs.
534   // EX:
535   // Body({in0, in1})
536   //   out0 = Add(in0, 1)
537   //   out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
538   //   Tuple(out0, out1)
BuildIndependentBodyComputation(bool nested=false)539   std::unique_ptr<HloComputation> BuildIndependentBodyComputation(
540       bool nested = false) {
541     auto builder = HloComputation::Builder(TestName() + ".Body");
542     // Create param instruction to access loop state.
543     const Shape& loop_state_shape =
544         nested ? nested_loop_state_shape_ : loop_state_shape_;
545 
546     auto loop_state = builder.AddInstruction(
547         HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
548     // Update the induction variable GTE(0).
549     auto induction_variable =
550         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
551             induction_variable_shape_, loop_state, 0));
552     auto inc = builder.AddInstruction(
553         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
554     // add0 = Add(in0, 1)
555     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
556         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
557     // Update data GTE(1).
558     HloInstruction* data = nullptr;
559     if (nested) {
560       data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
561           nested_tuple_shape_, loop_state, 1));
562       data = builder.AddInstruction(
563           HloInstruction::CreateGetTupleElement(data_shape_, data, 0));
564     } else {
565       data = builder.AddInstruction(
566           HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
567     }
568     auto update = builder.AddInstruction(
569         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
570             {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
571     // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
572     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
573         data_shape_, HloOpcode::kAdd, data, update));
574     // Create output Tuple.
575     if (nested) {
576       auto nested_tuple =
577           builder.AddInstruction(HloInstruction::CreateTuple({add1, add1}));
578       builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple}));
579     } else {
580       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
581     }
582     return builder.Build();
583   }
584 
585   // Builds a While body computation with the following nested tuple
586   // sub-computation:
587   //                            |
588   //                    GTE(loop_state, 1)
589   //                       /           \
590   // GTE(GTE(loop_state, 1), 0)     GTE(GTE(loop_state, 1), 1)
591   //           |                              |
592   //          Add                           Reverse
593   //           |                              |
BuildNestedBodyComputation()594   std::unique_ptr<HloComputation> BuildNestedBodyComputation() {
595     auto builder = HloComputation::Builder(TestName() + ".Body");
596     // Create param instruction to access loop state.
597     auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
598         0, nested_loop_state_shape_, "loop_state"));
599     // Update GTE(0).
600     auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
601         induction_variable_shape_, loop_state, 0));
602     auto inc = builder.AddInstruction(
603         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
604     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
605         gte0->shape(), HloOpcode::kAdd, gte0, inc));
606 
607     // GTE(loop_state, 1)
608     auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
609         nested_tuple_shape_, loop_state, 1));
610     // GTE(GTE(loop_state, 1), 0) -> Add
611     auto gte10 = builder.AddInstruction(
612         HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
613     auto update10 = builder.AddInstruction(
614         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
615             {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
616     auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
617         data_shape_, HloOpcode::kAdd, gte10, update10));
618 
619     // GTE(GTE(loop_state, 1), 1) -> Reverse
620     auto gte11 = builder.AddInstruction(
621         HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1));
622     auto rev11 = builder.AddInstruction(
623         HloInstruction::CreateReverse(data_shape_, gte11, {0}));
624 
625     // Create output Tuple.
626     auto inner_tuple =
627         builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11}));
628     builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple}));
629     return builder.Build();
630   }
631 
632   // Builds a While instruction using 'condition' and 'body' sub-computations.
633   // Init operand is initialized to zeros of appropriate shape.
BuildWhileInstruction(HloComputation * condition,HloComputation * body,bool nested=false)634   HloInstruction* BuildWhileInstruction(HloComputation* condition,
635                                         HloComputation* body,
636                                         bool nested = false) {
637     auto builder = HloComputation::Builder(TestName() + ".While");
638     auto induction_var_init = builder.AddInstruction(
639         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
640 
641     auto data_init = builder.AddInstruction(
642         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
643             {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
644 
645     if (nested) {
646       auto inner_init = builder.AddInstruction(
647           HloInstruction::CreateTuple({data_init, data_init}));
648       auto loop_state_init = builder.AddInstruction(
649           HloInstruction::CreateTuple({induction_var_init, inner_init}));
650       auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
651           loop_state_init->shape(), condition, body, loop_state_init));
652       module_->AddEntryComputation(builder.Build());
653       return while_hlo;
654     }
655 
656     auto loop_state_init = builder.AddInstruction(
657         HloInstruction::CreateTuple({induction_var_init, data_init}));
658     auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
659         loop_state_shape_, condition, body, loop_state_init));
660     module_->AddEntryComputation(builder.Build());
661     return while_hlo;
662   }
663 
BuildWhileInstruction_InitPointsToConstant()664   HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
665     auto builder = HloComputation::Builder(TestName() + ".While");
666     auto data_init = builder.AddInstruction(
667         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
668             {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
669     return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
670                                                &builder);
671   }
672 
BuildWhileInstruction_InitPointsToParameter()673   HloInstruction* BuildWhileInstruction_InitPointsToParameter() {
674     auto builder = HloComputation::Builder(TestName() + ".While");
675     auto data_init = builder.AddInstruction(
676         HloInstruction::CreateParameter(0, data_shape_, "data_init"));
677     return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
678                                                &builder);
679   }
680 
BuildWhileInstruction_InitPointsToAmbiguous()681   HloInstruction* BuildWhileInstruction_InitPointsToAmbiguous() {
682     auto builder = HloComputation::Builder(TestName() + ".While");
683 
684     auto one = builder.AddInstruction(
685         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
686     auto v1 = builder.AddInstruction(
687         HloInstruction::CreateBroadcast(data_shape_, one, {1}));
688     auto zero = builder.AddInstruction(
689         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
690     auto v2 = builder.AddInstruction(
691         HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
692 
693     auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2}));
694     auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
695 
696     auto pred = builder.AddInstruction(
697         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
698     auto data_init = builder.AddInstruction(HloInstruction::CreateTernary(
699         nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
700 
701     return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
702                                                data_init, &builder);
703   }
704 
BuildWhileInstruction_InitPointsToNonDistinct()705   HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() {
706     auto builder = HloComputation::Builder(TestName() + ".While");
707 
708     auto one = builder.AddInstruction(
709         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
710     auto one_vec = builder.AddInstruction(
711         HloInstruction::CreateBroadcast(data_shape_, one, {1}));
712     auto data_init =
713         builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec}));
714 
715     return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
716                                                data_init, &builder);
717   }
718 
BuildWhileInstruction_InitPointsToInterfering()719   HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
720     auto builder = HloComputation::Builder(TestName() + ".While");
721     auto one = builder.AddInstruction(
722         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
723     auto data_init = builder.AddInstruction(
724         HloInstruction::CreateBroadcast(data_shape_, one, {1}));
725     auto one_vec = builder.AddInstruction(
726         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
727             {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
728     // Take a reference to 'data_init' to make it interfere with while result.
729     auto add = builder.AddInstruction(HloInstruction::CreateBinary(
730         data_shape_, HloOpcode::kAdd, data_init, one_vec));
731 
732     auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_,
733                                                          data_init, &builder);
734 
735     // Add an additional binary operation operating on the while and the
736     // interfering add so that neither operation is dead.
737     auto gte = xla_while->parent()->AddInstruction(
738         HloInstruction::CreateGetTupleElement(
739             ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1));
740     auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary(
741         data_shape_, HloOpcode::kSubtract, add, gte));
742     auto gte0 = xla_while->parent()->AddInstruction(
743         HloInstruction::CreateGetTupleElement(
744             ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0));
745     auto tuple = xla_while->parent()->AddInstruction(
746         HloInstruction::CreateTuple({gte0, sub}));
747 
748     xla_while->parent()->set_root_instruction(tuple);
749 
750     return xla_while;
751   }
752 
BuildWhileInstructionWithCustomInit(const Shape & loop_state_shape,HloInstruction * data_init,HloComputation::Builder * builder)753   HloInstruction* BuildWhileInstructionWithCustomInit(
754       const Shape& loop_state_shape, HloInstruction* data_init,
755       HloComputation::Builder* builder) {
756     const bool nested =
757         ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
758     auto induction_var_init = builder->AddInstruction(
759         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
760     auto condition = module_->AddEmbeddedComputation(
761         BuildConditionComputation(loop_state_shape));
762     auto body = module_->AddEmbeddedComputation(
763         BuildIndependentBodyComputation(nested));
764     auto loop_state_init = builder->AddInstruction(
765         HloInstruction::CreateTuple({induction_var_init, data_init}));
766     auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile(
767         loop_state_shape, condition, body, loop_state_init));
768     module_->AddEntryComputation(builder->Build());
769     return while_hlo;
770   }
771 
  // Module under test. The helpers above add their condition, body and entry
  // computations to it. NOTE(review): it is not initialized in this part of
  // the file; presumably done in the fixture's setup - confirm.
  std::unique_ptr<HloModule> module_;
  // S32 scalar: shape of the loop induction variable.
  Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {});
  // F32[8]: shape of a loop data element.
  Shape data_shape_ = ShapeUtil::MakeShape(F32, {8});
  // (S32[], F32[8]): flat loop state of induction variable plus data.
  Shape loop_state_shape_ =
      ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_});
  // (F32[8], F32[8]): pair of data elements used as a nested loop-state member.
  Shape nested_tuple_shape_ =
      ShapeUtil::MakeTupleShape({data_shape_, data_shape_});
  // (S32[], (F32[8], F32[8])): loop state whose data member is itself a tuple.
  Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape(
      {induction_variable_shape_, nested_tuple_shape_});
  // PRED scalar. NOTE(review): presumably the result shape of the while
  // condition; its use is outside this part of the file.
  Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {});
782 };
783 
784 // Tests while body computation with independent tuple elements:
785 //
786 //   While.Body({in0, in1})
787 //     out0 = Add(in0, 1)
788 //     out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
789 //     Tuple(out0, out1)
790 //
791 // CopyInsertion pass should not generate any copies.
792 //
TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
  HloComputation* loop_condition = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* loop_body =
      module_->AddEmbeddedComputation(BuildIndependentBodyComputation());
  HloInstruction* loop = BuildWhileInstruction(loop_condition, loop_body);

  InsertCopies(module_.get());

  // The adds in the body can be performed in place, so the body needs no
  // copies, and no control edges are expected anywhere in the module.
  EXPECT_EQ(CountCopies(*loop_body), 0);
  EXPECT_EQ(CountControlEdges(*module_), 0);

  // The init tuple is built from two constants; each one must be copied.
  const auto copied_constant = op::Copy(op::Constant());
  EXPECT_THAT(loop->operand(0), op::Tuple(copied_constant, copied_constant));
}
810 
// Tests while body computation with dependent tuple elements:
//
//   While.Body({in0, in1})
//     out0 = Add(in0, 1)
//     out1 = Add(in1, BCast(in0))
//     Tuple(out0, out1)
//
// CopyInsertion pass should insert a single copy of in0 that feeds both of its
// users, leaving the root instruction as:
//
//     Tuple(Add(Copy(in0), 1), Add(in1, BCast(Copy(in0))))
//
TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
  HloComputation* loop_condition = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* loop_body =
      module_->AddEmbeddedComputation(BuildDependentBodyComputation());
  HloInstruction* loop = BuildWhileInstruction(loop_condition, loop_body);

  InsertCopies(module_.get());

  // Exactly one copy is expected in the body, and no control edges.
  EXPECT_EQ(CountCopies(*loop_body), 1);
  EXPECT_EQ(CountControlEdges(*loop_body), 0);

  // Coarse structure of the body root: a tuple of two adds, the second of
  // which consumes a broadcast.
  const HloInstruction* body_root = loop_body->root_instruction();
  EXPECT_THAT(
      body_root,
      op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast())));

  const HloInstruction* first_output = body_root->operand(0);
  const HloInstruction* broadcast = body_root->operand(1)->operand(1);
  ASSERT_EQ(first_output->opcode(), HloOpcode::kAdd);
  ASSERT_EQ(broadcast->opcode(), HloOpcode::kBroadcast);

  // Fine structure: both the first add and the broadcast read from a copy.
  EXPECT_THAT(
      loop->while_body()->root_instruction(),
      op::Tuple(op::Add(op::Copy(), op::Constant()),
                op::Add(op::GetTupleElement(), op::Broadcast(op::Copy()))));

  // The init tuple is built from two constants; each one must be copied.
  const auto copied_constant = op::Copy(op::Constant());
  EXPECT_THAT(loop->operand(0), op::Tuple(copied_constant, copied_constant));
}
851 
852 // Tests while body computation with read-only tuple element 0:
853 //
854 //                         PARAMETER
855 //                         /       \
856 //                      GTE(0)     GTE(1)
857 //                        |  \      |
858 //                        |   BCAST |
859 //                        |      \  |
860 //                        |       ADD
861 //                        |        |
862 //                         \      /
863 //                           TUPLE (root)
864 //
865 // CopyInsertion pass should not generate any copies for the while body.
TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) {
  HloComputation* loop_condition = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* loop_body = module_->AddEmbeddedComputation(
      BuildDependentBodyOneReadOnlyComputation());
  BuildWhileInstruction(loop_condition, loop_body);

  InsertCopies(module_.get());

  // The body is already legal: the pass must leave it without copies and
  // without control edges.
  EXPECT_EQ(CountCopies(*loop_body), 0);
  EXPECT_EQ(CountControlEdges(*loop_body), 0);
}
879 
880 // Same as above, but with two while loops, sharing entry parameters.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly_TwoLoops_EntryParams)881 TEST_F(WhileCopyInsertionTest,
882        DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) {
883   auto condition1 = module_->AddEmbeddedComputation(
884       BuildConditionComputation(loop_state_shape_));
885   auto condition2 = module_->AddEmbeddedComputation(
886       BuildConditionComputation(loop_state_shape_));
887   auto body1 = module_->AddEmbeddedComputation(
888       BuildDependentBodyOneReadOnlyComputation());
889   auto body2 = module_->AddEmbeddedComputation(
890       BuildDependentBodyOneReadOnlyComputation());
891 
892   auto builder = HloComputation::Builder(TestName() + ".While");
893   auto iter_param = builder.AddInstruction(
894       HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
895   auto data_param = builder.AddInstruction(
896       HloInstruction::CreateParameter(1, data_shape_, "data"));
897   auto loop_init = builder.AddInstruction(
898       HloInstruction::CreateTuple({iter_param, data_param}));
899 
900   auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
901       loop_state_shape_, condition1, body1, loop_init));
902   auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
903       loop_state_shape_, condition2, body2, loop_init));
904 
905   // Add a couple elements from each of the while so both whiles are live.
906   auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
907       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
908   auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
909       ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
910   builder.AddInstruction(
911       HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
912 
913   auto entry = module_->AddEntryComputation(builder.Build());
914 
915   InsertCopies(module_.get());
916 
917   // Neither body should have any copies or control edges in them.
918   EXPECT_EQ(CountCopies(*body1), 0);
919   EXPECT_EQ(CountCopies(*body2), 0);
920   EXPECT_EQ(CountControlEdges(*body1), 0);
921   EXPECT_EQ(CountControlEdges(*body2), 0);
922 
923   // Only two copies should be necessary. Each of the whiles should have
924   // a copy of tuple element 1 (init value is a parameter, and the element is
925   // not non-read-only) so each of the while bodies gets its own buffer to write
926   // element 1 into.
927   EXPECT_EQ(CountCopies(*entry), 2);
928 
929   EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
930   EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
931 
932   // The two copies of element 1 should be different.
933   EXPECT_NE(while_hlo1->operand(0)->operand(1),
934             while_hlo2->operand(0)->operand(1));
935 }
936 
937 // Same as above, but with two while loops, sharing non-parameters.
TEST_F(WhileCopyInsertionTest,
       DependentTupleElements_OneReadOnly_TwoLoops_NonParams) {
  HloComputation* first_condition = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* second_condition = module_->AddEmbeddedComputation(
      BuildConditionComputation(loop_state_shape_));
  HloComputation* first_body = module_->AddEmbeddedComputation(
      BuildDependentBodyOneReadOnlyComputation());
  HloComputation* second_body = module_->AddEmbeddedComputation(
      BuildDependentBodyOneReadOnlyComputation());

  HloComputation::Builder entry_builder(TestName() + ".While");
  HloInstruction* iter = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
  HloInstruction* data = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "data"));
  // Route each parameter through a dummy op so that the init tuple elements
  // are not entry parameters themselves.
  HloInstruction* iter_exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(iter->shape(), HloOpcode::kExp, iter));
  HloInstruction* data_exp = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(data->shape(), HloOpcode::kExp, data));
  HloInstruction* shared_init = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({iter_exp, data_exp}));

  HloInstruction* first_while =
      entry_builder.AddInstruction(HloInstruction::CreateWhile(
          loop_state_shape_, first_condition, first_body, shared_init));
  HloInstruction* second_while =
      entry_builder.AddInstruction(HloInstruction::CreateWhile(
          loop_state_shape_, second_condition, second_body, shared_init));

  // Combine element 0 of both results so that neither while is dead.
  HloInstruction* first_iter_out =
      entry_builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(first_while->shape(), {0}), first_while, 0));
  HloInstruction* second_iter_out =
      entry_builder.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(second_while->shape(), {0}), second_while, 0));
  entry_builder.AddInstruction(HloInstruction::CreateBinary(
      first_iter_out->shape(), HloOpcode::kAdd, first_iter_out,
      second_iter_out));
  HloComputation* entry = module_->AddEntryComputation(entry_builder.Build());

  InsertCopies(module_.get());

  // Ideally a single copy would suffice: only one of the whiles needs a
  // private buffer for element 1 (the element that is not read-only). The
  // analysis is not that precise, so each while ends up with its own copy of
  // element 1, for two copies in total. Element 0 is not copied.
  EXPECT_EQ(CountCopies(*entry), 2);

  const auto expected_init = op::Tuple(op::Exp(), op::Copy(op::Exp()));
  EXPECT_THAT(first_while->operand(0), expected_init);
  EXPECT_THAT(second_while->operand(0), expected_init);
}
989 
990 // Tests while body computation with nested tuple elements:
991 //
992 //                            |
993 //                    GTE(loop_state, 1)
994 //                       /          \
995 // GTE(GTE(loop_state, 1), 0)     GTE(GTE(loop_state, 1), 1)
996 //           |                              |
997 //          Add                           Reverse
998 //           |                              |
999 //
1000 // CopyInsertion pass will conceptually generate the following, but with the
1001 // actual GTE and Tuple instructions optimized away:
1002 //
1003 //                    Tuple  // old root
1004 //                   /     \
1005 //                  /       \
1006 //                GTE(0)   GTE(1)
1007 //                  |       /  \
1008 //                  |      /    \
1009 //                  |    GTE(0) GTE(1)
1010 //                  |       |    |
1011 //                  |       |   Copy
1012 //                  |       |    |
1013 //                   \      |   /
1014 //                    \    Tuple  // "inner" tuple.
1015 //                     \    /
1016 //                      \  /
1017 //                     Tuple  // new root
1018 //
TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
  HloComputation* loop_condition = module_->AddEmbeddedComputation(
      BuildConditionComputation(nested_loop_state_shape_));
  HloComputation* loop_body =
      module_->AddEmbeddedComputation(BuildNestedBodyComputation());
  BuildWhileInstruction(loop_condition, loop_body, /*nested=*/true);

  InsertCopies(module_.get());

  // Only the kReverse requires a copy in the body, because it cannot share a
  // buffer with its operand. The remaining loop state elements are produced
  // by kAdd instructions, which can be done in-place.
  EXPECT_EQ(CountCopies(*loop_body), 1);

  // Four copies in the whole module: the one in the body plus one for each of
  // the three init elements, all of which are constants.
  EXPECT_EQ(CountCopies(*module_), 4);

  // The pass may copy either the operand of the kReverse or its result; both
  // placements are accepted.
  const HloInstruction* body_root = loop_body->root_instruction();
  const bool reverse_result_copied =
      body_root->operand(1)->operand(1)->opcode() == HloOpcode::kCopy;
  if (!reverse_result_copied) {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy()))));
  } else {
    EXPECT_THAT(
        body_root,
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse()))));
  }
}
1049 
1050 // Tests while init instruction which points-to a constant.
1051 //
1052 //     init = Tuple(Constant(S32, {}), Constant(F32, {8}))
1053 //
1054 // CopyInsertion pass should add copies for both constants.
1055 //
TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
  HloInstruction* loop = BuildWhileInstruction_InitPointsToConstant();

  InsertCopies(module_.get());

  // No copies inside the body; the module's two copies are the ones wrapping
  // the constant init elements.
  EXPECT_EQ(CountCopies(*loop->while_body()), 0);
  EXPECT_EQ(CountCopies(*module_), 2);

  const auto copied_constant = op::Copy(op::Constant());
  EXPECT_THAT(loop->operand(0), op::Tuple(copied_constant, copied_constant));
}
1066 
1067 // Tests while init instruction which points-to a parameter.
1068 //
1069 //     init = Tuple(Constant(S32, {}), Parameter(F32, {8}))
1070 //
1071 // CopyInsertion pass should add copies for both the constant and parameter.
1072 //
TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
  HloInstruction* loop = BuildWhileInstruction_InitPointsToParameter();

  InsertCopies(module_.get());

  // No copies inside the body; the module's two copies wrap the constant
  // induction variable and the parameter used as data init.
  EXPECT_EQ(CountCopies(*loop->while_body()), 0);
  EXPECT_EQ(CountCopies(*module_), 2);

  const HloInstruction* loop_init = loop->operand(0);
  EXPECT_THAT(loop_init,
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter())));
}
1083 
1084 // Tests while init instruction which has an ambiguous points-to set.
1085 //
//     select = TupleSelect(pred, tuple1, tuple2)
//     init = Tuple(Constant(S32, {}), select)
1088 //
1089 // CopyInsertion pass will conceptually generate the following, but with some of
1090 // the actual GTE and Tuple instructions optimized away:
1091 //
1092 //                    Tuple  // old init
1093 //                   /     \
1094 //                  /       \
1095 //                GTE(0)   GTE(1)
1096 //                  |       /  \
1097 //                  |      /    \
1098 //                  |    GTE(0) GTE(1)
1099 //                  |       |    |
1100 //                Copy   Copy   Copy
1101 //                  |       |    |
1102 //                   \      |   /
1103 //                    \    Tuple
1104 //                     \    /
1105 //                      \  /
1106 //                     Tuple  // new init
1107 //
TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) {
  HloInstruction* loop = BuildWhileInstruction_InitPointsToAmbiguous();

  InsertCopies(module_.get());
  EXPECT_EQ(CountCopies(*module_), 4);

  // Three of the copies live in the entry computation: two resolve the
  // ambiguity of the nested init elements, and one wraps the constant that is
  // passed in as an init element.
  EXPECT_EQ(CountCopies(*module_->entry_computation()), 3);
  const auto copied_element = op::Copy(op::GetTupleElement());
  EXPECT_THAT(loop->operand(0),
              op::Tuple(op::Copy(op::Constant()),
                        op::Tuple(copied_element, copied_element)));

  // The remaining copy is in the body, whose buffer set is not distinct: the
  // result of one of the adds is written into two elements of the loop body
  // output. Either of those two elements may be the copied one.
  HloComputation* loop_body = loop->while_body();
  EXPECT_EQ(CountCopies(*loop_body), 1);

  const HloInstruction* body_root = loop_body->root_instruction();
  const bool first_element_copied =
      body_root->operand(1)->operand(0)->opcode() == HloOpcode::kCopy;
  if (!first_element_copied) {
    EXPECT_THAT(
        loop->while_body()->root_instruction(),
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
  } else {
    EXPECT_THAT(
        loop->while_body()->root_instruction(),
        op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
  }
}
1139 
1140 // Tests while init instruction which has a non-distinct points-to set.
1141 //
1142 //     init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one}))
1143 //
1144 // CopyInsertion pass will conceptually generate the following, but with some of
1145 // the actual GTE and Tuple instructions optimized away:
1146 //
1147 //                    Tuple  // old init
1148 //                   /     \
1149 //                  /       \
1150 //                GTE(0)   GTE(1)
1151 //                  |       /  \
1152 //                  |      /    \
1153 //                  |    GTE(0) GTE(1)
1154 //                  |       |    |
1155 //                Copy   Copy   Copy
1156 //                  |       |    |
1157 //                   \      |   /
1158 //                    \    Tuple
1159 //                     \    /
1160 //                      \  /
1161 //                     Tuple  // new init
1162 //
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
  HloInstruction* loop = BuildWhileInstruction_InitPointsToNonDistinct();

  InsertCopies(module_.get());

  // Two copies are expected in the entry computation: one resolves the
  // non-distinctness of the two nested init elements, the other wraps the
  // constant passed in as an init element. Either nested element may be the
  // one that gets copied.
  EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
  const HloInstruction* loop_init = loop->operand(0);
  const bool first_init_copied =
      loop_init->operand(1)->operand(0)->opcode() == HloOpcode::kCopy;
  if (!first_init_copied) {
    EXPECT_THAT(
        loop->operand(0),
        op::Tuple(op::Copy(op::Constant()),
                  op::Tuple(op::Broadcast(), op::Copy(op::Broadcast()))));
  } else {
    EXPECT_THAT(
        loop->operand(0),
        op::Tuple(op::Copy(op::Constant()),
                  op::Tuple(op::Copy(op::Broadcast()), op::Broadcast())));
  }

  // The body also needs one copy because its buffer set is not distinct: the
  // result of one of the adds is written into two elements of the loop body
  // output. Either of those two elements may be the copied one.
  HloComputation* loop_body = loop->while_body();
  EXPECT_EQ(CountCopies(*loop_body), 1);

  const HloInstruction* body_root = loop_body->root_instruction();
  const bool first_output_copied =
      body_root->operand(1)->operand(0)->opcode() == HloOpcode::kCopy;
  if (!first_output_copied) {
    EXPECT_THAT(
        loop->while_body()->root_instruction(),
        op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
  } else {
    EXPECT_THAT(
        loop->while_body()->root_instruction(),
        op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
  }
}
1203 
1204 // Tests while init instruction buffer which interferes with while result
1205 // buffer.
1206 //
1207 //     init_data = Broadcast(...)
1208 //     add_unrelated = Add(init_data) // takes a reference to cause interference
//     init = Tuple(Constant(S32, {}), init_data)
1210 //
1211 // CopyInsertion pass should copy both operands.
1212 //
TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
  HloInstruction* loop = BuildWhileInstruction_InitPointsToInterfering();

  InsertCopies(module_.get());

  // Both init elements are copied; nothing is copied inside the body.
  EXPECT_EQ(CountCopies(*module_), 2);
  EXPECT_EQ(CountCopies(*loop->while_body()), 0);

  const HloInstruction* loop_init = loop->operand(0);
  EXPECT_THAT(loop_init,
              op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast())));
}
1223 
// Tests while init instruction buffer which has a non-distinct points-to set:
//
//     init = Tuple(Parameter(S32, {}), Parameter(F32, {8}),
//                  Parameter(F32, {8}))
//
// where the second and third parameters are identical *and* the tuple is
// shared by another while instruction.
//
// Verifies that the resulting points-to set is distinct in the resulting Tuple
// (non-identical Copys). In other words, verifies that copy sharing does not
// insert identical copies to the resulting tuple.
TEST_F(WhileCopyInsertionTest,InitPointsToNonDistinctUsedByTwoWhileLoops)1235 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
1236   // Loop body that outputs tuple comprises two elements dependent on the init
1237   // tuple.
1238   const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
1239       {induction_variable_shape_, data_shape_, data_shape_});
1240 
1241   auto condition1 = module_->AddEmbeddedComputation(
1242       BuildConditionComputation(loop_state_shape));
1243   auto condition2 = module_->AddEmbeddedComputation(
1244       BuildConditionComputation(loop_state_shape));
1245   auto body1 =
1246       module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
1247   auto body2 =
1248       module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
1249 
1250   auto builder = HloComputation::Builder(TestName() + ".While");
1251 
1252   auto iter_param = builder.AddInstruction(
1253       HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
1254   auto data_param = builder.AddInstruction(
1255       HloInstruction::CreateParameter(1, data_shape_, "data"));
1256 
1257   // Loop init tuple contains two identical parameter buffers.
1258   auto loop_init = builder.AddInstruction(
1259       HloInstruction::CreateTuple({iter_param, data_param, data_param}));
1260 
1261   // Two while loops shares the same loop init tuple.
1262   auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
1263       loop_state_shape, condition1, body1, loop_init));
1264   auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
1265       loop_state_shape, condition2, body2, loop_init));
1266 
1267   // Add add instruction so neither while is dead.
1268   auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1269       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
1270   auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1271       ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0));
1272   builder.AddInstruction(
1273       HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
1274 
1275   module_->AddEntryComputation(builder.Build());
1276 
1277   InsertCopies(module_.get());
1278 
1279   // None of the bodies should have copies or control flow edges.
1280   EXPECT_EQ(CountCopies(*body1), 0);
1281   EXPECT_EQ(CountCopies(*body2), 0);
1282 
1283   // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally
1284   // these should not need to be copied before either while. However, copy
1285   // insertion is not able to reason about the transparency of elements through
1286   // while bodies in all circumstances so extra copies are added (b/xxx).
1287   EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
1288 
1289   EXPECT_THAT(while_hlo1->operand(0),
1290               op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
1291   EXPECT_THAT(while_hlo2->operand(0),
1292               op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
1293 }
1294 
TEST_F(CopyInsertionTest,SwizzlingWhile)1295 TEST_F(CopyInsertionTest, SwizzlingWhile) {
1296   // Test a while instruction with a body which permutes its tuple parameter
1297   // elements.
1298   auto module = CreateNewVerifiedModule();
1299   const Shape loop_state_shape =
1300       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
1301 
1302   // Body simply interchanges the two tuple elements in the loop state.
1303   auto body_builder = HloComputation::Builder("body");
1304   auto body_param = body_builder.AddInstruction(
1305       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1306   auto body_element_0 = body_builder.AddInstruction(
1307       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
1308   auto body_element_1 = body_builder.AddInstruction(
1309       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
1310   body_builder.AddInstruction(
1311       HloInstruction::CreateTuple({body_element_1, body_element_0}));
1312   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
1313 
1314   auto cond_builder = HloComputation::Builder("condition");
1315   cond_builder.AddInstruction(
1316       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1317   auto cond_constant = cond_builder.AddInstruction(
1318       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1319   cond_builder.AddInstruction(HloInstruction::CreateUnary(
1320       cond_constant->shape(), HloOpcode::kNot, cond_constant));
1321   HloComputation* condition =
1322       module->AddEmbeddedComputation(cond_builder.Build());
1323 
1324   auto builder = HloComputation::Builder(TestName());
1325   auto constant1 = builder.AddInstruction(
1326       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1327   auto constant2 = builder.AddInstruction(
1328       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
1329   auto tuple = builder.AddInstruction(
1330       HloInstruction::CreateTuple({constant1, constant2}));
1331   auto xla_while = builder.AddInstruction(
1332       HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
1333   module->AddEntryComputation(builder.Build());
1334 
1335   InsertCopies(module.get());
1336 
1337   EXPECT_EQ(CountCopies(*module), 6);
1338 
1339   // The loop state elements should be copied at the parameter and at the root
1340   // with a control edge in between (see DeepCopyAndAddControlEdges). This is
1341   // technically one more copy than is strictly necessary, but in order to have
1342   // only three copies the copies of different loop state elements must be
1343   // ordered with a control edge.
1344   EXPECT_EQ(CountCopies(*body), 4);
1345   EXPECT_EQ(CountControlEdges(*body), 2);
1346 
1347   EXPECT_THAT(body->root_instruction(),
1348               op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy())));
1349 
1350   EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
1351   EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
1352 }
1353 
TEST_F(CopyInsertionTest, CrossingParameters) {
  // Test a case where the dataflow of the two parameter elements crosses
  // while each output index is aliased with the same parameter index:
  //
  //  (p0 ,  p1)
  //   | \   /|
  //   |  \ / |
  // alias X  alias
  //   |  / \ |
  //   | /   \|
  //  (p1  ,  p0)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  auto builder = HloComputation::Builder(TestName());
  // The parameter is named "p0" like in the other parameter aliasing tests in
  // this file.
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  auto gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
  auto gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
  // The root tuple swaps the two parameter elements.
  builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0}));
  module->AddEntryComputation(builder.Build());

  // Alias output index {0} with parameter index {0} and output index {1} with
  // parameter index {1}.
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
  InsertCopies(module.get());

  // Four copies are expected in the whole module.
  EXPECT_EQ(CountCopies(*module), 4);
}
1388 
TEST_F(CopyInsertionTest, ParametersAliasing) {
  // Each element of the tuple parameter is forwarded to the output element
  // with the same index, and every such input/output pair is aliased. The two
  // dataflows therefore never interfere with each other:
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  // alias   alias
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({element0, element1}));
  module->AddEntryComputation(entry_builder.Build());

  // Alias output index {i} with parameter index {i} for both elements.
  HloInputOutputAliasConfig& alias_config =
      module->input_output_alias_config();
  const auto alias_kind = HloInputOutputAliasConfig::AliasKind::kUserAlias;
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/alias_kind));
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
      /*kind=*/alias_kind));

  InsertCopies(module.get());

  // No copy is expected anywhere in the module.
  EXPECT_EQ(CountCopies(*module), 0);
}
1423 
TEST_F(CopyInsertionTest, ParameterWithNoAliasing) {
  // The parameter elements are forwarded to the output unchanged, but nothing
  // is aliased with the result. Both elements are expected to be copied.
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  //   |      |
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({element0, element1}));
  module->AddEntryComputation(entry_builder.Build());

  InsertCopies(module.get());

  // Each root operand must be a copy of the matching parameter element.
  const HloInstruction* root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::Copy(op::GetTupleElement(tuple_param, 0)),
                              op::Copy(op::GetTupleElement(tuple_param, 1))));

  EXPECT_EQ(CountCopies(*module), 2);
}
1456 
TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
  // Only the first parameter element is aliased with the result; the second
  // one is forwarded without an alias.
  //
  //  (p0 ,  p1)
  //   |      |
  //   |      |
  // alias    |
  //   |      |
  //   |      |
  //  (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({element0, element1}));
  module->AddEntryComputation(entry_builder.Build());

  // Alias only output index {0} with parameter index {0}.
  HloInputOutputAliasConfig& alias_config =
      module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  InsertCopies(module.get());

  // The aliased element is passed through directly; the other one is copied.
  const HloInstruction* root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(tuple_param, 0),
                              op::Copy(op::GetTupleElement(tuple_param, 1))));

  EXPECT_EQ(CountCopies(*module), 1);
}
1492 
TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
  // Each parameter element is negated independently before reaching the
  // output, and only the first element is aliased with the result.
  //
  //   +-- (p0 ,  p1)
  //   |    |      |
  //   |    |      |
  // alias Negate  Negate
  //   |    |      |
  //   |    |      |
  //   +-- (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));

  // Two independent negations, one per element.
  HloInstruction* negated0 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element0));
  HloInstruction* negated1 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element1));

  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({negated0, negated1}));
  module->AddEntryComputation(entry_builder.Build());

  // Alias only output index {0} with parameter index {0}.
  HloInputOutputAliasConfig& alias_config =
      module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  InsertCopies(module.get());

  // No copy is expected anywhere in the module.
  EXPECT_EQ(CountCopies(*module), 0);
}
1530 
TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
  // Both parameter elements are negated. The first output element is the sum
  // of the two negations and the second output element is the second negation.
  // Only the first element is aliased with the result.
  //
  //   +-- (p0 ,  p1)
  //   |    |      |
  //   |    |      |
  // alias Negate  Negate
  //   |    |      |
  //   |    Add----+
  //   |    |      |
  //   +-- (p0 ,  p1)
  auto module = CreateNewVerifiedModule();
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));

  HloInstruction* negated0 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element0));
  HloInstruction* negated1 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element1));

  // The sum consumes both negations; the second negation is also a root
  // operand.
  HloInstruction* sum = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, negated0,
                                   negated1));
  entry_builder.AddInstruction(HloInstruction::CreateTuple({sum, negated1}));
  module->AddEntryComputation(entry_builder.Build());

  // Alias only output index {0} with parameter index {0}.
  HloInputOutputAliasConfig& alias_config =
      module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  InsertCopies(module.get());

  // No copy is expected anywhere in the module.
  EXPECT_EQ(CountCopies(*module), 0);
}
1572 
TEST_F(CopyInsertionTest,SwizzlingWhileWithOneOp)1573 TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
1574   // Test a while instruction with a body which permutes its tuple parameter
1575   // elements and applies one operation to one of the elements. The addition of
1576   // the operation (instruction) on the element makes the live range of the
1577   // respective input and output elements different than if the instruction were
1578   // not there (as in the SwizzlingWhile test above).
1579   auto module = CreateNewVerifiedModule();
1580   const Shape loop_state_shape =
1581       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
1582 
1583   // Body interchanges the two tuple elements in the loop state and negates one
1584   // of them.
1585   auto body_builder = HloComputation::Builder("body");
1586   auto body_param = body_builder.AddInstruction(
1587       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1588   auto body_element_0 = body_builder.AddInstruction(
1589       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
1590   auto body_element_1 = body_builder.AddInstruction(
1591       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
1592   auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
1593       scalar_shape_, HloOpcode::kNegate, body_element_1));
1594   body_builder.AddInstruction(
1595       HloInstruction::CreateTuple({negate, body_element_0}));
1596   HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
1597 
1598   auto cond_builder = HloComputation::Builder("condition");
1599   cond_builder.AddInstruction(
1600       HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1601   auto cond_constant = cond_builder.AddInstruction(
1602       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1603   cond_builder.AddInstruction(HloInstruction::CreateUnary(
1604       cond_constant->shape(), HloOpcode::kNot, cond_constant));
1605   HloComputation* condition =
1606       module->AddEmbeddedComputation(cond_builder.Build());
1607 
1608   auto builder = HloComputation::Builder(TestName());
1609   auto constant1 = builder.AddInstruction(
1610       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1611   auto constant2 = builder.AddInstruction(
1612       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
1613   auto tuple = builder.AddInstruction(
1614       HloInstruction::CreateTuple({constant1, constant2}));
1615   auto xla_while = builder.AddInstruction(
1616       HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
1617   module->AddEntryComputation(builder.Build());
1618 
1619   InsertCopies(module.get());
1620 
1621   EXPECT_EQ(CountCopies(*module), 6);
1622 
1623   // The loop state elements should be copied at the parameter and at the root
1624   // with a control edge in between (see DeepCopyAndAddControlEdges).
1625   EXPECT_EQ(CountCopies(*body), 4);
1626   EXPECT_EQ(CountControlEdges(*body), 2);
1627 
1628   EXPECT_THAT(
1629       body->root_instruction(),
1630       op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy())));
1631 
1632   EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
1633   EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
1634 }
1635 
TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
  // A while loop whose body permutes its tuple parameter elements, similar to
  // the SwizzlingWhile test above. Here, however, both loop state elements of
  // the initial value are the same constant, so interchanging them is a no-op
  // and the body needs no copies.
  auto module = CreateNewVerifiedModule();
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: swap the two loop state elements.
  HloComputation::Builder body_builder("body");
  HloInstruction* body_state = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* state_0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 0));
  HloInstruction* state_1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 1));
  body_builder.AddInstruction(HloInstruction::CreateTuple({state_1, state_0}));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Loop condition: not(false), independent of the loop state.
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* false_constant = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  cond_builder.AddInstruction(HloInstruction::CreateUnary(
      false_constant->shape(), HloOpcode::kNot, false_constant));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  // Entry computation: the single constant feeds both loop state elements and
  // the while instruction is the root.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* shared_constant = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* init_state = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({shared_constant, shared_constant}));
  entry_builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape, condition, body, init_state));
  module->AddEntryComputation(entry_builder.Build());

  InsertCopies(module.get());

  // All copies live in the entry computation; the body has none.
  EXPECT_EQ(CountCopies(*module), 2);
  EXPECT_EQ(CountCopies(*body), 0);

  HloComputation* entry = module->entry_computation();
  EXPECT_EQ(CountCopies(*entry), 2);
  EXPECT_THAT(entry->root_instruction(), op::Tuple(op::Copy(), op::Copy()));
}
1686 
TEST_F(CopyInsertionTest,SequentialWhiles)1687 TEST_F(CopyInsertionTest, SequentialWhiles) {
1688   // Construct a computation with a series of sequential while instructions
1689   // containing four loop state elements:
1690   //
1691   //   element 0 is passed to each while directly from an entry parameter.
1692   //
1693   //   element 1 is passed transparently in series through all the while bodies.
1694   //
1695   //   element 2 is negated in each while body. (in-place possible)
1696   //
1697   //   element 3 is reversed in each while body. (in-place not possible)
1698   //
1699   const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
1700   const Shape loop_state_shape = ShapeUtil::MakeTupleShape(
1701       {element_shape, element_shape, element_shape, element_shape});
1702 
1703   auto module = CreateNewVerifiedModule();
1704   auto builder = HloComputation::Builder(TestName());
1705   auto param_0 = builder.AddInstruction(
1706       HloInstruction::CreateParameter(0, element_shape, "param_0"));
1707   auto param_1 = builder.AddInstruction(
1708       HloInstruction::CreateParameter(1, element_shape, "param_1"));
1709   auto param_2 = builder.AddInstruction(
1710       HloInstruction::CreateParameter(2, element_shape, "param_2"));
1711   auto param_3 = builder.AddInstruction(
1712       HloInstruction::CreateParameter(3, element_shape, "param_3"));
1713 
1714   // The number of sequential kWhile instructions.
1715   const int kNumWhiles = 3;
1716 
1717   HloInstruction* prev_element_1 = param_1;
1718   HloInstruction* prev_element_2 = param_2;
1719   HloInstruction* prev_element_3 = param_3;
1720 
1721   // Vector containing all of the while instructions.
1722   std::vector<const HloInstruction*> whiles;
1723   for (int i = 0; i < kNumWhiles; ++i) {
1724     auto body_builder = HloComputation::Builder("body");
1725     auto body_param = body_builder.AddInstruction(
1726         HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1727     auto body_element_0 = body_builder.AddInstruction(
1728         HloInstruction::CreateGetTupleElement(element_shape, body_param, 0));
1729     auto body_element_1 = body_builder.AddInstruction(
1730         HloInstruction::CreateGetTupleElement(element_shape, body_param, 1));
1731     auto body_element_2 = body_builder.AddInstruction(
1732         HloInstruction::CreateGetTupleElement(element_shape, body_param, 2));
1733     auto body_element_3 = body_builder.AddInstruction(
1734         HloInstruction::CreateGetTupleElement(element_shape, body_param, 3));
1735     auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
1736         element_shape, HloOpcode::kNegate, body_element_2));
1737     auto reverse = body_builder.AddInstruction(
1738         HloInstruction::CreateReverse(element_shape, body_element_3, {0}));
1739     body_builder.AddInstruction(HloInstruction::CreateTuple(
1740         {body_element_0, body_element_1, negate, reverse}));
1741     HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
1742 
1743     auto cond_builder = HloComputation::Builder("condition");
1744     cond_builder.AddInstruction(
1745         HloInstruction::CreateParameter(0, loop_state_shape, "param"));
1746     auto cond_constant = cond_builder.AddInstruction(
1747         HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1748     cond_builder.AddInstruction(HloInstruction::CreateUnary(
1749         cond_constant->shape(), HloOpcode::kNot, cond_constant));
1750     HloComputation* condition =
1751         module->AddEmbeddedComputation(cond_builder.Build());
1752 
1753     auto while_init = builder.AddInstruction(HloInstruction::CreateTuple(
1754         {param_0, prev_element_1, prev_element_2, prev_element_3}));
1755 
1756     auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile(
1757         loop_state_shape, condition, body, while_init));
1758     whiles.push_back(xla_while);
1759     if (i != kNumWhiles - 1) {
1760       prev_element_1 = builder.AddInstruction(
1761           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1));
1762       prev_element_2 = builder.AddInstruction(
1763           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2));
1764       prev_element_3 = builder.AddInstruction(
1765           HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3));
1766     }
1767   }
1768 
1769   module->AddEntryComputation(builder.Build());
1770 
1771   InsertCopies(module.get());
1772 
1773   // Each while body has one copy. And each loop state element is copied once in
1774   // the entry computation.
1775   EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles);
1776 
1777   // Each while body should have exactly one copy for element three which is an
1778   // op (kReverse) which cannot be done in place.
1779   for (const HloInstruction* xla_while : whiles) {
1780     EXPECT_EQ(CountCopies(*xla_while->while_body()), 1);
1781   }
1782 
1783   EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(),
1784                                                op::Copy(), op::Copy()));
1785   EXPECT_THAT(module->entry_computation()->root_instruction(),
1786               op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(),
1787                         op::GetTupleElement()));
1788 }
1789 
TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
  // The while body and the while condition each consist of just a parameter
  // and a constant, with the constant as the computation root. The constant in
  // the body should be copied.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));

  // Body: ignores its parameter and yields the constant 123.0.
  HloComputation::Builder body_builder("body");
  body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Condition: ignores its parameter and yields the constant false.
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module->AddEmbeddedComputation(cond_builder.Build());

  HloInstruction* xla_while = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape_, condition, body, entry_param));

  module->AddEntryComputation(entry_builder.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 2);

  // The while operand and the body root are copied; the condition root stays a
  // bare constant.
  EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter()));
  EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
  EXPECT_THAT(condition->root_instruction(), op::Constant());
}
1826 
TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) {
  // HLO text for a while loop whose loop state is an (s32[], token[]) tuple.
  // The body increments the integer and passes the token through an
  // after-all; the entry computation returns the integer element of the loop
  // result.
  const char* const kHloText = R"(
HloModule TokensShouldNotBeCopied

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %TokensShouldNotBeCopied () -> s32[] {
  %one = s32[] constant(1)
  %negative_one = s32[] negate(%one)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));
  InsertCopies(module.get());

  // Tokens should not be copied, so the pass is expected to add no copies at
  // all.
  EXPECT_EQ(CountCopies(*module), 0);
}
1864 
MakeTrivialCondition(const Shape & shape)1865 std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
1866   auto builder = HloComputation::Builder("trivial_condition");
1867   builder.AddInstruction(
1868       HloInstruction::CreateParameter(0, shape, "loop_state"));
1869   auto constant = builder.AddInstruction(
1870       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1871   builder.AddInstruction(HloInstruction::CreateUnary(
1872       constant->shape(), HloOpcode::kNot, constant));
1873   return builder.Build();
1874 }
1875 
MakeBenchmarkWhileBody()1876 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() {
1877   auto builder = HloComputation::Builder("benchmark_loop_body");
1878   const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
1879   const Shape loop_state_shape =
1880       ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape});
1881   HloInstruction* param = builder.AddInstruction(
1882       HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
1883   HloInstruction* element_0 = builder.AddInstruction(
1884       HloInstruction::CreateGetTupleElement(element_shape, param, 0));
1885   HloInstruction* element_1 = builder.AddInstruction(
1886       HloInstruction::CreateGetTupleElement(element_shape, param, 1));
1887   HloInstruction* element_2 = builder.AddInstruction(
1888       HloInstruction::CreateGetTupleElement(element_shape, param, 2));
1889 
1890   HloInstruction* rev_1 = builder.AddInstruction(
1891       HloInstruction::CreateReverse(element_shape, element_1, {0}));
1892   HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary(
1893       element_shape, HloOpcode::kAdd, element_1, element_2));
1894 
1895   builder.AddInstruction(
1896       HloInstruction::CreateTuple({element_0, rev_1, add_1_2}));
1897   return builder.Build();
1898 }
1899 
BM_SequentialWhiles(int num_iters,int num_whiles)1900 void BM_SequentialWhiles(int num_iters, int num_whiles) {
1901   // This benchmark constructs a chain of sequential while instructions.
1902   tensorflow::testing::StopTiming();
1903   for (int i = 0; i < num_iters; ++i) {
1904     HloModuleConfig config;
1905     config.set_debug_options(GetDebugOptionsFromFlags());
1906     HloModule module("BM_SequentialWhiles", config);
1907 
1908     auto builder = HloComputation::Builder("BM_SequentialWhiles");
1909     HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
1910         0, ShapeUtil::MakeShape(F32, {42}), "x"));
1911     HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
1912         1, ShapeUtil::MakeShape(F32, {42}), "y"));
1913     HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
1914         2, ShapeUtil::MakeShape(F32, {42}), "z"));
1915     HloInstruction* init =
1916         builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
1917 
1918     HloInstruction* prev_loop_state = init;
1919     for (int w = 0; w < num_whiles; ++w) {
1920       HloComputation* condition =
1921           module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
1922       HloComputation* body =
1923           module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
1924       prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile(
1925           init->shape(), condition, body, prev_loop_state));
1926     }
1927     module.AddEntryComputation(builder.Build());
1928 
1929     CopyInsertion copy_insertion;
1930 
1931     tensorflow::testing::StartTiming();
1932     ASSERT_IS_OK(copy_insertion.Run(&module).status());
1933     tensorflow::testing::StopTiming();
1934 
1935     // The entry computation should have three copies, and each body has one.
1936     ASSERT_EQ(CountCopies(module), 3 + num_whiles);
1937   }
1938 }
1939 
BM_ParallelWhiles(int num_iters,int num_whiles)1940 void BM_ParallelWhiles(int num_iters, int num_whiles) {
1941   // This benchmark constructs a fan-out of parallel while instructions.
1942   tensorflow::testing::StopTiming();
1943   for (int i = 0; i < num_iters; ++i) {
1944     HloModuleConfig config;
1945     config.set_debug_options(GetDebugOptionsFromFlags());
1946     HloModule module("BM_SequentialWhiles", config);
1947 
1948     auto builder = HloComputation::Builder("BM_ParallelWhiles");
1949     HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
1950         0, ShapeUtil::MakeShape(F32, {42}), "x"));
1951     HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
1952         1, ShapeUtil::MakeShape(F32, {42}), "y"));
1953     HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
1954         2, ShapeUtil::MakeShape(F32, {42}), "z"));
1955     HloInstruction* init =
1956         builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
1957 
1958     HloInstruction* sum = nullptr;
1959     for (int w = 0; w < num_whiles; ++w) {
1960       HloComputation* condition =
1961           module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
1962       HloComputation* body =
1963           module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
1964 
1965       HloInstruction* xla_while = builder.AddInstruction(
1966           HloInstruction::CreateWhile(init->shape(), condition, body, init));
1967 
1968       if (sum == nullptr) {
1969         sum = builder.AddInstruction(
1970             HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
1971       } else {
1972         HloInstruction* element_0 = builder.AddInstruction(
1973             HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
1974         sum = builder.AddInstruction(HloInstruction::CreateBinary(
1975             x->shape(), HloOpcode::kAdd, sum, element_0));
1976       }
1977     }
1978     module.AddEntryComputation(builder.Build());
1979 
1980     CopyInsertion copy_insertion;
1981 
1982     tensorflow::testing::StartTiming();
1983     ASSERT_IS_OK(copy_insertion.Run(&module).status());
1984     tensorflow::testing::StopTiming();
1985 
1986     // Each body receives of copy of two of the parameters (the corresponding
1987     // elements in the body are modifed), and there is one copy in each body.
1988     ASSERT_EQ(CountCopies(module), 3 * num_whiles);
1989   }
1990 }
1991 
MakeBenchmarkWhileBody(const int num_tuple_inputs)1992 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody(
1993     const int num_tuple_inputs) {
1994   auto builder = HloComputation::Builder("benchmark_loop_body");
1995   const Shape element_shape = ShapeUtil::MakeShape(F32, {});
1996   std::vector<Shape> input_shape(num_tuple_inputs, element_shape);
1997   const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape);
1998   HloInstruction* param = builder.AddInstruction(
1999       HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
2000   std::vector<HloInstruction*> gte_nodes(num_tuple_inputs);
2001   for (int i = 0; i < num_tuple_inputs; ++i) {
2002     gte_nodes[i] = builder.AddInstruction(
2003         HloInstruction::CreateGetTupleElement(element_shape, param, i));
2004   }
2005   builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes));
2006   return builder.Build();
2007 }
2008 
BM_ManyElementTuple(int num_iters,const int num_tuple_inputs)2009 void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) {
2010   tensorflow::testing::StopTiming();
2011   HloModuleConfig config;
2012   config.set_debug_options(GetDebugOptionsFromFlags());
2013   CopyInsertion copy_insertion;
2014   const Shape element_shape = ShapeUtil::MakeShape(F32, {});
2015   std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
2016   for (int i = 0; i < num_iters; ++i) {
2017     auto builder = HloComputation::Builder("BM_ParallelWhiles");
2018     HloModule module("BM_ManyElementTuple", config);
2019     for (int j = 0; j < num_tuple_inputs; ++j) {
2020       tuple_params[j] = builder.AddInstruction(
2021           HloInstruction::CreateParameter(j, element_shape, ""));
2022     }
2023     HloInstruction* init =
2024         builder.AddInstruction(HloInstruction::CreateTuple(tuple_params));
2025     HloComputation* condition =
2026         module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2027     HloComputation* body =
2028         module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs));
2029     HloInstruction* xla_while = builder.AddInstruction(
2030         HloInstruction::CreateWhile(init->shape(), condition, body, init));
2031     builder.AddInstruction(HloInstruction::CreateGetTupleElement(
2032         ShapeUtil::MakeShape(F32, {}), xla_while, 0));
2033     module.AddEntryComputation(builder.Build());
2034     tensorflow::testing::StartTiming();
2035     ASSERT_IS_OK(copy_insertion.Run(&module).status());
2036     tensorflow::testing::StopTiming();
2037   }
2038 }
2039 
2040 BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
2041 BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
2042 BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288);
2043 
TEST_F(CopyInsertionTest, SimpleControlFlowTest) {
  // Runs copy insertion over a module with three levels of nested while loops:
  // the entry while (while.3) runs a body containing while.2, whose body in
  // turn contains a while over a nested-tuple state. The test makes no
  // explicit expectations of its own; it passes if the module parses and
  // verifies and InsertCopies completes (any further checking is whatever the
  // InsertCopies fixture helper performs).
  const string hlo_string = R"(
HloModule TestModule

if-body.v5 {
  constant.3 = s32[] constant(-1)
  p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
  get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
  get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
  add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
  tuple.33 = (s32[]) tuple(add.3)
  ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
  p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
  constant.4 = s32[] constant(0)
  ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

_functionalize_body_1__.v28 {
  arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0
  constant.7 = s32[] constant(1)
  add.4 = s32[] add(get-tuple-element.68, constant.7)
  get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
  get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
  less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE
  constant.8 = s32[] constant(0)
  select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
  get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
  tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70)
  tuple.36 = (s32[]) tuple(constant.8)
  tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36)
  while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5
  get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2
  get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0
  ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73)
}

cond_wrapper.v3.1 {
  inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
  constant.11 = s32[] constant(7)
  ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
  arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0
  get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2
  get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3
  get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4
  tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79)
  while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
  get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0
  get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1
  constant.12 = s32[] constant(1)
  add.5 = s32[] add(get-tuple-element.81, constant.12)
  get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3
  ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82)
}

cond_wrapper.v3.2 {
  inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
  constant.13 = s32[] constant(5)
  ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT
}

ENTRY TestComputation {
  arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // A parse/verify error fails this test through a gtest assertion instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2125 
TEST_F(CopyInsertionTest, ControlFlowTest) {
  // Runs copy insertion over a module with nested while loops in which the
  // middle loop body (_functionalize_body_1__.v28) chains two whiles back to
  // back: `while.1` consumes the result of `while`. The test makes no explicit
  // expectations of its own; it passes if the module parses and verifies and
  // InsertCopies completes (any further checking is whatever the InsertCopies
  // fixture helper performs).
  const string hlo_string = R"(
HloModule TestModule

if-body.v5 {
  constant.3 = s32[] constant(-1)
  p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
  get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
  get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
  add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
  tuple.33 = (s32[]) tuple(add.3)
  ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
  p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
  constant.4 = s32[] constant(0)
  ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

if-body.v5.1 {
  constant.5 = s32[] constant(-1)
  p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1
  get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2
  multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70)
  tuple.35 = (s32[]) tuple(multiply.1)
  ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35)
}

if-condition.v4.1 {
  p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
  constant.6 = s32[] constant(1)
  ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ
}

_functionalize_body_1__.v28 {
  arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0
  constant.7 = s32[] constant(1)
  add.4 = s32[] add(get-tuple-element.72, constant.7)
  get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
  get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
  less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE
  constant.8 = s32[] constant(0)
  select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
  get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
  tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74)
  tuple.38 = (s32[]) tuple(constant.8)
  tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38)
  while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5
  while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1
  get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2
  get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0
  ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77)
}

cond_wrapper.v3.1 {
  inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
  constant.11 = s32[] constant(7)
  ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
  arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0
  get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2
  get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3
  get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4
  tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82)
  while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
  get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0
  get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1
  constant.12 = s32[] constant(1)
  add.5 = s32[] add(get-tuple-element.84, constant.12)
  get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3
  ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85)
}

cond_wrapper.v3.2 {
  inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
  constant.13 = s32[] constant(5)
  ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT
}

ENTRY TestComputation {
  arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // A parse/verify error fails this test through a gtest assertion instead of
  // aborting the whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2225 
TEST_F(CopyInsertionTest, NestedWhiles) {
  // Trivial nested while loops must come out of copy insertion with no
  // unnecessary copies left behind (b/112472605).
  const string hlo_text = R"(
HloModule TestModule

cond.inner {
  ROOT param.cond.inner = pred[] parameter(0)
}

body.inner {
  param.body.inner = pred[] parameter(0)
  ROOT not = pred[] not(param.body.inner)
}

cond.outer {
  ROOT param.cond.outer = pred[] parameter(0)
}

body.outer {
  param.cond.outer = pred[] parameter(0)
  ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
}

ENTRY TestComputation {
  entry_param = pred[] parameter(0)
  ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  InsertCopies(hlo_module.get());

  // Exactly one copy is expected in the whole module, sitting between the
  // entry parameter and the entry computation's root while.
  HloInstruction* const entry_root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_EQ(CountCopies(*hlo_module), 1);
  EXPECT_THAT(entry_root, op::While(op::Copy(op::Parameter())));
}
2265 
2266 }  // namespace
2267 }  // namespace xla
2268