1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17
18 #include <set>
19
20 #include "tensorflow/compiler/xla/debug_options_flags.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_runner.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/test_benchmark.h"
34
35 namespace op = xla::testing::opcode_matchers;
36
37 namespace xla {
38 namespace {
39
40 using ::testing::UnorderedElementsAre;
41
CountCopies(const HloComputation & computation)42 int64 CountCopies(const HloComputation& computation) {
43 int64 count = 0;
44 for (const auto& instruction : computation.instructions()) {
45 if (instruction->opcode() == HloOpcode::kCopy) {
46 count++;
47 }
48 }
49 return count;
50 }
51
CountCopies(const HloModule & module)52 int64 CountCopies(const HloModule& module) {
53 int64 count = 0;
54 for (const auto& computation : module.computations()) {
55 count += CountCopies(*computation);
56 }
57 return count;
58 }
59
CountControlEdges(const HloComputation & computation)60 int64 CountControlEdges(const HloComputation& computation) {
61 int64 count = 0;
62 for (const auto& instruction : computation.instructions()) {
63 count += instruction->control_successors().size();
64 }
65 return count;
66 }
67
CountControlEdges(const HloModule & module)68 int64 CountControlEdges(const HloModule& module) {
69 int64 count = 0;
70 for (const auto& computation : module.computations()) {
71 count += CountControlEdges(*computation);
72 }
73 return count;
74 }
75
76 class CopyInsertionTest : public HloTestBase {
77 protected:
InsertCopies(HloModule * module)78 void InsertCopies(HloModule* module) {
79 CopyInsertion copy_insertion;
80 ASSERT_IS_OK(copy_insertion.Run(module).status());
81 }
82
83 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
84 };
85
TEST_F(CopyInsertionTest, SingleParameter) {
  // The entry computation wraps a lone parameter in a tuple. After copy
  // insertion the tuple operand is expected to be a copy of that parameter.
  HloComputation::Builder b(TestName());
  HloInstruction* param = b.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
  HloInstruction* root_tuple =
      b.AddInstruction(HloInstruction::CreateTuple({param}));

  // Before the pass, the tuple is the parameter's only user.
  EXPECT_THAT(param->users(), UnorderedElementsAre(root_tuple));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  InsertCopies(module.get());

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Tuple(op::Copy(param)));
}
105
TEST_F(CopyInsertionTest, SingleConstant) {
  // The entry computation wraps a lone constant in a tuple. After copy
  // insertion exactly one copy is expected, sitting between the constant and
  // the tuple.
  HloComputation::Builder b(TestName());
  HloInstruction* scalar = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* root_tuple =
      b.AddInstruction(HloInstruction::CreateTuple({scalar}));

  // Before the pass, the tuple is the constant's only user.
  EXPECT_THAT(scalar->users(), UnorderedElementsAre(root_tuple));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Tuple(op::Copy(scalar)));
}
126
TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
  // kCopy instructions that change layout and are already present before the
  // pass must still be present afterwards. The graph starts with three
  // copies; two are expected to remain, with `sum` as the resulting root.
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder b(TestName());
  HloInstruction* matrix = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));

  // Shape for the copies: the constant's shape, with a layout built by
  // feeding the constant's minor-to-major order in as a major-to-minor order.
  Shape copy_shape = matrix->shape();
  *copy_shape.mutable_layout() = LayoutUtil::MakeLayoutFromMajorToMinor(
      LayoutUtil::MinorToMajor(matrix->shape()));

  HloInstruction* lhs_copy = b.AddInstruction(
      HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, matrix));
  HloInstruction* rhs_copy = b.AddInstruction(
      HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, matrix));
  HloInstruction* sum = b.AddInstruction(HloInstruction::CreateBinary(
      matrix->shape(), HloOpcode::kAdd, lhs_copy, rhs_copy));
  // Third copy: a same-shape copy of the sum, added last.
  b.AddInstruction(
      HloInstruction::CreateUnary(sum->shape(), HloOpcode::kCopy, sum));

  module->AddEntryComputation(b.Build());

  EXPECT_EQ(CountCopies(*module), 3);

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_EQ(new_root, sum);
  EXPECT_THAT(new_root,
              op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
162
TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
  // Two constants and two parameters are created, but the output tuple refers
  // directly to only one constant and one parameter; the other pair reaches
  // the tuple through an add. Only the directly referenced pair should be
  // copied.
  HloComputation::Builder b(TestName());

  HloInstruction* added_const = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* direct_const = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));

  HloInstruction* direct_param = b.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
  HloInstruction* added_param = b.AddInstruction(
      HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "y"));

  HloInstruction* sum = b.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, added_const,
      added_param));

  b.AddInstruction(
      HloInstruction::CreateTuple({direct_const, direct_param, sum}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* new_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root,
              op::Tuple(op::Copy(direct_const), op::Copy(direct_param),
                        op::Add(added_const, added_param)));
}
194
TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
  // The root is a tuple-select between two tuples, so the points-to set of
  // the computation result is ambiguous. Two copies are expected, one for
  // each element of the selected tuple.
  HloComputation::Builder b(TestName());
  HloInstruction* c1 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* c3 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));

  HloInstruction* on_true =
      b.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* on_false =
      b.AddInstruction(HloInstruction::CreateTuple({c3, c2}));

  HloInstruction* predicate = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  b.AddInstruction(HloInstruction::CreateTernary(
      on_true->shape(), HloOpcode::kTupleSelect, predicate, on_true,
      on_false));

  // c2 is shared by both tuples; c1 and c3 each feed exactly one.
  EXPECT_THAT(c1->users(), UnorderedElementsAre(on_true));
  EXPECT_THAT(c2->users(), UnorderedElementsAre(on_true, on_false));
  EXPECT_THAT(c3->users(), UnorderedElementsAre(on_false));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  HloInstruction* root_before =
      module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(root_after,
              op::Tuple(op::Copy(op::GetTupleElement(root_before)),
                        op::Copy(op::GetTupleElement(root_before))));
}
231
TEST_F(CopyInsertionTest, BitcastParameter) {
  // A bitcast's output is the same buffer as its operand, so when a bitcast
  // of a parameter is the result, a copy has to be added.
  HloComputation::Builder b(TestName());
  HloInstruction* param = b.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x"));
  HloInstruction* reshaped = b.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, param));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  EXPECT_THAT(param->users(), UnorderedElementsAre(reshaped));

  HloInstruction* root_before =
      module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(root_after, op::Copy(root_before));
}
253
TEST_F(CopyInsertionTest, BitcastConstant) {
  // A bitcast's output is the same buffer as its operand, so when a bitcast
  // of a constant is the result, a copy has to be added.
  HloComputation::Builder b(TestName());
  HloInstruction* vec = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 42.0})));
  HloInstruction* reshaped = b.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, vec));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  EXPECT_THAT(vec->users(), UnorderedElementsAre(reshaped));

  HloInstruction* root_before =
      module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(root_after, op::Copy(root_before));
}
276
TEST_F(CopyInsertionTest, BitcastTupleElementParameter) {
  // Variant of BitcastParameter in which the bitcast is the single element of
  // the root tuple; the copy is expected on that tuple element.
  HloComputation::Builder b(TestName());
  HloInstruction* param = b.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x"));
  HloInstruction* reshaped = b.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, param));
  b.AddInstruction(HloInstruction::CreateTuple({reshaped}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  EXPECT_THAT(param->users(), UnorderedElementsAre(reshaped));

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(root_after, op::Tuple(op::Copy(reshaped)));
}
297
TEST_F(CopyInsertionTest, NestedTupleParameter) {
  // The entry computation consists of nothing but a nested tuple-shaped
  // parameter, which is therefore also the root. Copy insertion is expected
  // to deep copy it (three copies, one per leaf array) and to make the
  // rebuilt tuple the new root.
  HloComputation::Builder b(TestName());

  // Parameter shape: ((F32[], S32[1,2,3]), F32[42])
  const Shape inner_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {inner_shape, ShapeUtil::MakeShape(F32, {42})});
  b.AddInstruction(HloInstruction::CreateParameter(0, param_shape, "param0"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  EXPECT_EQ(HloOpcode::kParameter,
            module->entry_computation()->root_instruction()->opcode());

  HloInstruction* root_before =
      module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 3);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_NE(root_before, root_after);

  EXPECT_THAT(
      root_after,
      op::Tuple(
          op::Tuple(
              op::Copy(op::GetTupleElement(op::GetTupleElement(root_before))),
              op::Copy(op::GetTupleElement(op::GetTupleElement(root_before)))),
          op::Copy(op::GetTupleElement(root_before))));
}
334
TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
  // The root is a get-tuple-element that pulls the inner tuple out of a
  // nested tuple-shaped parameter. Two copies are expected, one per element
  // of that inner tuple.
  HloComputation::Builder b(TestName());

  // Parameter shape: ((F32[], S32[1,2,3]), F32[42])
  const Shape inner_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(S32, {1, 2, 3})});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {inner_shape, ShapeUtil::MakeShape(F32, {42})});
  HloInstruction* nested_param = b.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param0"));

  // The computation returns element 0 of the parameter, which is itself a
  // tuple.
  HloInstruction* inner_gte =
      b.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(nested_param->shape(), {0}), nested_param,
          0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  EXPECT_EQ(inner_gte, module->entry_computation()->root_instruction());

  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 2);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root_after,
      op::Tuple(
          op::Copy(op::GetTupleElement(op::GetTupleElement(nested_param))),
          op::Copy(op::GetTupleElement(op::GetTupleElement(nested_param)))));
}
367
TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
  // The root is element 0 of a tuple-select, so the top-level buffer of the
  // computation root has an ambiguous points-to set. A single shallow copy of
  // the old root is expected.
  HloComputation::Builder b(TestName());
  HloInstruction* c1 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));

  // The two candidate tuples hold the same constants in opposite order.
  HloInstruction* on_true =
      b.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* on_false =
      b.AddInstruction(HloInstruction::CreateTuple({c2, c1}));

  HloInstruction* predicate = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* chosen = b.AddInstruction(HloInstruction::CreateTernary(
      on_true->shape(), HloOpcode::kTupleSelect, predicate, on_true,
      on_false));
  HloInstruction* first_element =
      b.AddInstruction(HloInstruction::CreateGetTupleElement(
          ShapeUtil::GetSubshape(chosen->shape(), {0}), chosen, 0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  EXPECT_EQ(first_element, module->entry_computation()->root_instruction());

  HloInstruction* root_before =
      module->entry_computation()->root_instruction();
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);

  HloInstruction* root_after = module->entry_computation()->root_instruction();
  EXPECT_THAT(root_after, op::Copy(root_before));
}
403
404 class WhileCopyInsertionTest : public CopyInsertionTest {
405 protected:
WhileCopyInsertionTest()406 WhileCopyInsertionTest() : module_(CreateNewUnverifiedModule()) {}
407
408 // Builds a While condition computation which reads the induction variable
409 // from the tuple parameter, and returns a predicate indicating whether this
410 // value is less than the constant '10'.
411 // The parameter 'nested' specifies the loop state shape from which to
412 // read the induction variable.
BuildConditionComputation(const Shape & loop_state_shape)413 std::unique_ptr<HloComputation> BuildConditionComputation(
414 const Shape& loop_state_shape) {
415 auto builder = HloComputation::Builder(TestName() + ".Condition");
416 auto limit_const = builder.AddInstruction(
417 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10)));
418 auto loop_state = builder.AddInstruction(
419 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
420 auto induction_variable =
421 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
422 limit_const->shape(), loop_state, 0));
423 builder.AddInstruction(HloInstruction::CreateCompare(
424 condition_result_shape_, induction_variable, limit_const,
425 ComparisonDirection::kLt));
426 return builder.Build();
427 }
428
429 // Builds a While body computation with one output tuple element dependent on
430 // both input tuple elements.
431 // EX:
432 // Body({in0, in1})
433 // out0 = Add(in0, 1)
434 // out1 = Add(BCast(in0), in1)
435 // Tuple(out0, out1)
BuildDependentBodyComputation()436 std::unique_ptr<HloComputation> BuildDependentBodyComputation() {
437 auto builder = HloComputation::Builder(TestName() + ".Body");
438 // Create param instruction to access loop state.
439 auto loop_state = builder.AddInstruction(
440 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
441 // Update the induction variable GTE(0).
442 auto induction_variable =
443 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
444 induction_variable_shape_, loop_state, 0));
445 auto inc = builder.AddInstruction(
446 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
447 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
448 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
449 // Update data GTE(1).
450 auto data = builder.AddInstruction(
451 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
452 // Use 'induction_variable' in computation with no path to output tuple.
453 auto update = builder.AddInstruction(
454 HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
455 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
456 data_shape_, HloOpcode::kAdd, data, update));
457 // Create output Tuple.
458 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
459 return builder.Build();
460 }
461
462 // Builds a While body computation with two output tuple elements dependent on
463 // both input tuple elements.
464 //
465 // EX: Body({in0, in1, in2})
466 // out0 = Add(in0, 1)
467 // out1 = in1
468 // out2 = in2
469 // Tuple(out0, out1, out2)
BuildDependentBodyComputation2()470 std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
471 auto builder = HloComputation::Builder(TestName() + ".Body");
472
473 const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
474 {induction_variable_shape_, data_shape_, data_shape_});
475
476 auto loop_state = builder.AddInstruction(
477 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
478
479 // Update the induction variable GTE(0).
480 auto induction_variable =
481 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
482 induction_variable_shape_, loop_state, 0));
483 auto inc = builder.AddInstruction(
484 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
485
486 // add0 = Add(in0, 1)
487 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
488 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
489 // data1 = GTE(1).
490 HloInstruction* data1 = builder.AddInstruction(
491 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
492
493 // data2 = GTE(2).
494 HloInstruction* data2 = builder.AddInstruction(
495 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
496
497 // Create output Tuple.
498 builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
499
500 return builder.Build();
501 }
502
503 // Builds a While body computation with read-only tuple element 0.
504 // EX:
505 // Body({in0, in1})
506 // out0 = in0
507 // out1 = Add(BCast(in0), in1)
508 // Tuple(out0, out1)
BuildDependentBodyOneReadOnlyComputation()509 std::unique_ptr<HloComputation> BuildDependentBodyOneReadOnlyComputation() {
510 auto builder = HloComputation::Builder(TestName() + ".Body");
511 // Create param instruction to access loop state.
512 auto loop_state = builder.AddInstruction(
513 HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
514 // Update the induction variable GTE(0).
515 auto induction_variable =
516 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
517 induction_variable_shape_, loop_state, 0));
518 // Update data GTE(1).
519 auto data = builder.AddInstruction(
520 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
521
522 // Use 'induction_variable' in computation with no path to output tuple.
523 auto update = builder.AddInstruction(
524 HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
525 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
526 data_shape_, HloOpcode::kAdd, data, update));
527 // Create output Tuple.
528 builder.AddInstruction(
529 HloInstruction::CreateTuple({induction_variable, add1}));
530 return builder.Build();
531 }
532
533 // Builds a While body computation with independent outputs.
534 // EX:
535 // Body({in0, in1})
536 // out0 = Add(in0, 1)
537 // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
538 // Tuple(out0, out1)
BuildIndependentBodyComputation(bool nested=false)539 std::unique_ptr<HloComputation> BuildIndependentBodyComputation(
540 bool nested = false) {
541 auto builder = HloComputation::Builder(TestName() + ".Body");
542 // Create param instruction to access loop state.
543 const Shape& loop_state_shape =
544 nested ? nested_loop_state_shape_ : loop_state_shape_;
545
546 auto loop_state = builder.AddInstruction(
547 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
548 // Update the induction variable GTE(0).
549 auto induction_variable =
550 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
551 induction_variable_shape_, loop_state, 0));
552 auto inc = builder.AddInstruction(
553 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
554 // add0 = Add(in0, 1)
555 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
556 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
557 // Update data GTE(1).
558 HloInstruction* data = nullptr;
559 if (nested) {
560 data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
561 nested_tuple_shape_, loop_state, 1));
562 data = builder.AddInstruction(
563 HloInstruction::CreateGetTupleElement(data_shape_, data, 0));
564 } else {
565 data = builder.AddInstruction(
566 HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
567 }
568 auto update = builder.AddInstruction(
569 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
570 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
571 // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
572 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
573 data_shape_, HloOpcode::kAdd, data, update));
574 // Create output Tuple.
575 if (nested) {
576 auto nested_tuple =
577 builder.AddInstruction(HloInstruction::CreateTuple({add1, add1}));
578 builder.AddInstruction(HloInstruction::CreateTuple({add0, nested_tuple}));
579 } else {
580 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
581 }
582 return builder.Build();
583 }
584
585 // Builds a While body computation with the following nested tuple
586 // sub-computation:
587 // |
588 // GTE(loop_state, 1)
589 // / \
590 // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1)
591 // | |
592 // Add Reverse
593 // | |
BuildNestedBodyComputation()594 std::unique_ptr<HloComputation> BuildNestedBodyComputation() {
595 auto builder = HloComputation::Builder(TestName() + ".Body");
596 // Create param instruction to access loop state.
597 auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
598 0, nested_loop_state_shape_, "loop_state"));
599 // Update GTE(0).
600 auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
601 induction_variable_shape_, loop_state, 0));
602 auto inc = builder.AddInstruction(
603 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
604 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
605 gte0->shape(), HloOpcode::kAdd, gte0, inc));
606
607 // GTE(loop_state, 1)
608 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
609 nested_tuple_shape_, loop_state, 1));
610 // GTE(GTE(loop_state, 1), 0) -> Add
611 auto gte10 = builder.AddInstruction(
612 HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
613 auto update10 = builder.AddInstruction(
614 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
615 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
616 auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
617 data_shape_, HloOpcode::kAdd, gte10, update10));
618
619 // GTE(GTE(loop_state, 1), 1) -> Reverse
620 auto gte11 = builder.AddInstruction(
621 HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1));
622 auto rev11 = builder.AddInstruction(
623 HloInstruction::CreateReverse(data_shape_, gte11, {0}));
624
625 // Create output Tuple.
626 auto inner_tuple =
627 builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11}));
628 builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple}));
629 return builder.Build();
630 }
631
632 // Builds a While instruction using 'condition' and 'body' sub-computations.
633 // Init operand is initialized to zeros of appropriate shape.
BuildWhileInstruction(HloComputation * condition,HloComputation * body,bool nested=false)634 HloInstruction* BuildWhileInstruction(HloComputation* condition,
635 HloComputation* body,
636 bool nested = false) {
637 auto builder = HloComputation::Builder(TestName() + ".While");
638 auto induction_var_init = builder.AddInstruction(
639 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
640
641 auto data_init = builder.AddInstruction(
642 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
643 {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
644
645 if (nested) {
646 auto inner_init = builder.AddInstruction(
647 HloInstruction::CreateTuple({data_init, data_init}));
648 auto loop_state_init = builder.AddInstruction(
649 HloInstruction::CreateTuple({induction_var_init, inner_init}));
650 auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
651 loop_state_init->shape(), condition, body, loop_state_init));
652 module_->AddEntryComputation(builder.Build());
653 return while_hlo;
654 }
655
656 auto loop_state_init = builder.AddInstruction(
657 HloInstruction::CreateTuple({induction_var_init, data_init}));
658 auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
659 loop_state_shape_, condition, body, loop_state_init));
660 module_->AddEntryComputation(builder.Build());
661 return while_hlo;
662 }
663
BuildWhileInstruction_InitPointsToConstant()664 HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
665 auto builder = HloComputation::Builder(TestName() + ".While");
666 auto data_init = builder.AddInstruction(
667 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
668 {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
669 return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
670 &builder);
671 }
672
BuildWhileInstruction_InitPointsToParameter()673 HloInstruction* BuildWhileInstruction_InitPointsToParameter() {
674 auto builder = HloComputation::Builder(TestName() + ".While");
675 auto data_init = builder.AddInstruction(
676 HloInstruction::CreateParameter(0, data_shape_, "data_init"));
677 return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
678 &builder);
679 }
680
BuildWhileInstruction_InitPointsToAmbiguous()681 HloInstruction* BuildWhileInstruction_InitPointsToAmbiguous() {
682 auto builder = HloComputation::Builder(TestName() + ".While");
683
684 auto one = builder.AddInstruction(
685 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
686 auto v1 = builder.AddInstruction(
687 HloInstruction::CreateBroadcast(data_shape_, one, {1}));
688 auto zero = builder.AddInstruction(
689 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
690 auto v2 = builder.AddInstruction(
691 HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
692
693 auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2}));
694 auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
695
696 auto pred = builder.AddInstruction(
697 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
698 auto data_init = builder.AddInstruction(HloInstruction::CreateTernary(
699 nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
700
701 return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
702 data_init, &builder);
703 }
704
BuildWhileInstruction_InitPointsToNonDistinct()705 HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() {
706 auto builder = HloComputation::Builder(TestName() + ".While");
707
708 auto one = builder.AddInstruction(
709 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
710 auto one_vec = builder.AddInstruction(
711 HloInstruction::CreateBroadcast(data_shape_, one, {1}));
712 auto data_init =
713 builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec}));
714
715 return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
716 data_init, &builder);
717 }
718
BuildWhileInstruction_InitPointsToInterfering()719 HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
720 auto builder = HloComputation::Builder(TestName() + ".While");
721 auto one = builder.AddInstruction(
722 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
723 auto data_init = builder.AddInstruction(
724 HloInstruction::CreateBroadcast(data_shape_, one, {1}));
725 auto one_vec = builder.AddInstruction(
726 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
727 {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
728 // Take a reference to 'data_init' to make it interfere with while result.
729 auto add = builder.AddInstruction(HloInstruction::CreateBinary(
730 data_shape_, HloOpcode::kAdd, data_init, one_vec));
731
732 auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_,
733 data_init, &builder);
734
735 // Add an additional binary operation operating on the while and the
736 // interfering add so that neither operation is dead.
737 auto gte = xla_while->parent()->AddInstruction(
738 HloInstruction::CreateGetTupleElement(
739 ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1));
740 auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary(
741 data_shape_, HloOpcode::kSubtract, add, gte));
742 auto gte0 = xla_while->parent()->AddInstruction(
743 HloInstruction::CreateGetTupleElement(
744 ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0));
745 auto tuple = xla_while->parent()->AddInstruction(
746 HloInstruction::CreateTuple({gte0, sub}));
747
748 xla_while->parent()->set_root_instruction(tuple);
749
750 return xla_while;
751 }
752
BuildWhileInstructionWithCustomInit(const Shape & loop_state_shape,HloInstruction * data_init,HloComputation::Builder * builder)753 HloInstruction* BuildWhileInstructionWithCustomInit(
754 const Shape& loop_state_shape, HloInstruction* data_init,
755 HloComputation::Builder* builder) {
756 const bool nested =
757 ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
758 auto induction_var_init = builder->AddInstruction(
759 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
760 auto condition = module_->AddEmbeddedComputation(
761 BuildConditionComputation(loop_state_shape));
762 auto body = module_->AddEmbeddedComputation(
763 BuildIndependentBodyComputation(nested));
764 auto loop_state_init = builder->AddInstruction(
765 HloInstruction::CreateTuple({induction_var_init, data_init}));
766 auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile(
767 loop_state_shape, condition, body, loop_state_init));
768 module_->AddEntryComputation(builder->Build());
769 return while_hlo;
770 }
771
// Module that the builder helpers add their computations to.
std::unique_ptr<HloModule> module_;
// Scalar S32 shape of the loop induction variable.
Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {});
// F32 vector of 8 elements used for the loop data.
Shape data_shape_ = ShapeUtil::MakeShape(F32, {8});
// Flat loop state: (induction variable, data).
Shape loop_state_shape_ =
    ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_});
// Pair of data vectors used as the inner tuple of the nested loop state.
Shape nested_tuple_shape_ =
    ShapeUtil::MakeTupleShape({data_shape_, data_shape_});
// Nested loop state: (induction variable, (data, data)).
Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape(
    {induction_variable_shape_, nested_tuple_shape_});
// Scalar PRED shape; NOTE(review): presumably the result shape of the loop
// condition computation -- its use is outside this part of the file.
Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {});
782 };
783
784 // Tests while body computation with independent tuple elements:
785 //
786 // While.Body({in0, in1})
787 // out0 = Add(in0, 1)
788 // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
789 // Tuple(out0, out1)
790 //
791 // CopyInsertion pass should not generate any copies.
792 //
TEST_F(WhileCopyInsertionTest,IndependentTupleElements)793 TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
794 auto condition = module_->AddEmbeddedComputation(
795 BuildConditionComputation(loop_state_shape_));
796 auto body =
797 module_->AddEmbeddedComputation(BuildIndependentBodyComputation());
798 auto while_hlo = BuildWhileInstruction(condition, body);
799
800 InsertCopies(module_.get());
801
802 // Body should have no copies as the adds can be done inplace.
803 EXPECT_EQ(CountCopies(*body), 0);
804 EXPECT_EQ(CountControlEdges(*module_), 0);
805
806 // Both init indices need copies as they are constants.
807 EXPECT_THAT(while_hlo->operand(0),
808 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
809 }
810
811 // Tests while body computation with dependent tuple elements:
812 //
813 // While.Body({in0, in1})
814 // out0 = Add(in0, 1)
815 // out1 = Add(BCast(in0), in1)
816 // Tuple(out0, out1)
817 //
818 // CopyInsertion pass should convert the root instruction to:
819 //
820 // Tuple(Copy(out0), out1)
821 //
TEST_F(WhileCopyInsertionTest,DependentTupleElements)822 TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
823 auto condition = module_->AddEmbeddedComputation(
824 BuildConditionComputation(loop_state_shape_));
825 auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation());
826 auto while_hlo = BuildWhileInstruction(condition, body);
827
828 InsertCopies(module_.get());
829
830 EXPECT_EQ(CountCopies(*body), 1);
831 EXPECT_EQ(CountControlEdges(*body), 0);
832
833 EXPECT_THAT(
834 body->root_instruction(),
835 op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast())));
836
837 auto add = body->root_instruction()->operand(0);
838 auto bcast = body->root_instruction()->operand(1)->operand(1);
839 ASSERT_EQ(add->opcode(), HloOpcode::kAdd);
840 ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
841
842 EXPECT_THAT(
843 while_hlo->while_body()->root_instruction(),
844 op::Tuple(op::Add(op::Copy(), op::Constant()),
845 op::Add(op::GetTupleElement(), op::Broadcast(op::Copy()))));
846
847 // Both init indices need copies as they are constants.
848 EXPECT_THAT(while_hlo->operand(0),
849 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
850 }
851
852 // Tests while body computation with read-only tuple element 0:
853 //
854 // PARAMETER
855 // / \
856 // GTE(0) GTE(1)
857 // | \ |
858 // | BCAST |
859 // | \ |
860 // | ADD
861 // | |
862 // \ /
863 // TUPLE (root)
864 //
865 // CopyInsertion pass should not generate any copies for the while body.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly)866 TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) {
867 auto condition = module_->AddEmbeddedComputation(
868 BuildConditionComputation(loop_state_shape_));
869 auto body = module_->AddEmbeddedComputation(
870 BuildDependentBodyOneReadOnlyComputation());
871 BuildWhileInstruction(condition, body);
872
873 InsertCopies(module_.get());
874
875 // No copies or control edges should be inserted. The body is legal as is.
876 EXPECT_EQ(CountCopies(*body), 0);
877 EXPECT_EQ(CountControlEdges(*body), 0);
878 }
879
880 // Same as above, but with two while loops, sharing entry parameters.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly_TwoLoops_EntryParams)881 TEST_F(WhileCopyInsertionTest,
882 DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) {
883 auto condition1 = module_->AddEmbeddedComputation(
884 BuildConditionComputation(loop_state_shape_));
885 auto condition2 = module_->AddEmbeddedComputation(
886 BuildConditionComputation(loop_state_shape_));
887 auto body1 = module_->AddEmbeddedComputation(
888 BuildDependentBodyOneReadOnlyComputation());
889 auto body2 = module_->AddEmbeddedComputation(
890 BuildDependentBodyOneReadOnlyComputation());
891
892 auto builder = HloComputation::Builder(TestName() + ".While");
893 auto iter_param = builder.AddInstruction(
894 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
895 auto data_param = builder.AddInstruction(
896 HloInstruction::CreateParameter(1, data_shape_, "data"));
897 auto loop_init = builder.AddInstruction(
898 HloInstruction::CreateTuple({iter_param, data_param}));
899
900 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
901 loop_state_shape_, condition1, body1, loop_init));
902 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
903 loop_state_shape_, condition2, body2, loop_init));
904
905 // Add a couple elements from each of the while so both whiles are live.
906 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
907 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
908 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
909 ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
910 builder.AddInstruction(
911 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
912
913 auto entry = module_->AddEntryComputation(builder.Build());
914
915 InsertCopies(module_.get());
916
917 // Neither body should have any copies or control edges in them.
918 EXPECT_EQ(CountCopies(*body1), 0);
919 EXPECT_EQ(CountCopies(*body2), 0);
920 EXPECT_EQ(CountControlEdges(*body1), 0);
921 EXPECT_EQ(CountControlEdges(*body2), 0);
922
923 // Only two copies should be necessary. Each of the whiles should have
924 // a copy of tuple element 1 (init value is a parameter, and the element is
925 // not non-read-only) so each of the while bodies gets its own buffer to write
926 // element 1 into.
927 EXPECT_EQ(CountCopies(*entry), 2);
928
929 EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
930 EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
931
932 // The two copies of element 1 should be different.
933 EXPECT_NE(while_hlo1->operand(0)->operand(1),
934 while_hlo2->operand(0)->operand(1));
935 }
936
937 // Same as above, but with two while loops, sharing non-parameters.
TEST_F(WhileCopyInsertionTest,DependentTupleElements_OneReadOnly_TwoLoops_NonParams)938 TEST_F(WhileCopyInsertionTest,
939 DependentTupleElements_OneReadOnly_TwoLoops_NonParams) {
940 auto condition1 = module_->AddEmbeddedComputation(
941 BuildConditionComputation(loop_state_shape_));
942 auto condition2 = module_->AddEmbeddedComputation(
943 BuildConditionComputation(loop_state_shape_));
944 auto body1 = module_->AddEmbeddedComputation(
945 BuildDependentBodyOneReadOnlyComputation());
946 auto body2 = module_->AddEmbeddedComputation(
947 BuildDependentBodyOneReadOnlyComputation());
948
949 auto builder = HloComputation::Builder(TestName() + ".While");
950 auto iter_param = builder.AddInstruction(
951 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
952 auto data_param = builder.AddInstruction(
953 HloInstruction::CreateParameter(1, data_shape_, "data"));
954 // Add dummy ops to ensure loop_init elements aren't entry parameters.
955 auto iter_value = builder.AddInstruction(HloInstruction::CreateUnary(
956 iter_param->shape(), HloOpcode::kExp, iter_param));
957 auto data_value = builder.AddInstruction(HloInstruction::CreateUnary(
958 data_param->shape(), HloOpcode::kExp, data_param));
959 auto loop_init = builder.AddInstruction(
960 HloInstruction::CreateTuple({iter_value, data_value}));
961
962 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
963 loop_state_shape_, condition1, body1, loop_init));
964 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
965 loop_state_shape_, condition2, body2, loop_init));
966
967 // Add a couple elements from each of the while so both whiles are not dead.
968 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
969 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
970 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
971 ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
972 builder.AddInstruction(
973 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
974 auto entry = module_->AddEntryComputation(builder.Build());
975
976 InsertCopies(module_.get());
977
978 // Ideally only one copy should be necessary. One of the whiles should
979 // have a copy of tuple element 1 (the non-read-only element) so each of the
980 // while bodies gets its own buffer to write element 1 into. However, the
981 // analysis isn't perfect and adds an additional copy of element 0.
982 EXPECT_EQ(CountCopies(*entry), 2);
983
984 EXPECT_THAT(while_hlo1->operand(0),
985 op::Tuple(op::Exp(), op::Copy(op::Exp())));
986 EXPECT_THAT(while_hlo2->operand(0),
987 op::Tuple(op::Exp(), op::Copy(op::Exp())));
988 }
989
990 // Tests while body computation with nested tuple elements:
991 //
992 // |
993 // GTE(loop_state, 1)
994 // / \
995 // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1)
996 // | |
997 // Add Reverse
998 // | |
999 //
1000 // CopyInsertion pass will conceptually generate the following, but with the
1001 // actual GTE and Tuple instructions optimized away:
1002 //
1003 // Tuple // old root
1004 // / \
1005 // / \
1006 // GTE(0) GTE(1)
1007 // | / \
1008 // | / \
1009 // | GTE(0) GTE(1)
1010 // | | |
1011 // | | Copy
1012 // | | |
1013 // \ | /
1014 // \ Tuple // "inner" tuple.
1015 // \ /
1016 // \ /
1017 // Tuple // new root
1018 //
TEST_F(WhileCopyInsertionTest,NestedTupleElements)1019 TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
1020 auto condition = module_->AddEmbeddedComputation(
1021 BuildConditionComputation(nested_loop_state_shape_));
1022 auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation());
1023 BuildWhileInstruction(condition, body, true);
1024
1025 // HloInstruction* old_root = body->root_instruction();
1026 InsertCopies(module_.get());
1027
1028 // The only copy necessary is for the kReverse as it cannot be done
1029 // in-place (instruction can share buffer with operand). The other elements of
1030 // the loop state are kAdd instructions which can be done in-place.
1031 EXPECT_EQ(CountCopies(*body), 1);
1032
1033 // Each element of the init needs a copy as all are constants.
1034 EXPECT_EQ(CountCopies(*module_), 4);
1035
1036 // Either the kReverse itself must be copied or the operand of the kReverse
1037 // must be copied.
1038 if (body->root_instruction()->operand(1)->operand(1)->opcode() ==
1039 HloOpcode::kCopy) {
1040 EXPECT_THAT(
1041 body->root_instruction(),
1042 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse()))));
1043 } else {
1044 EXPECT_THAT(
1045 body->root_instruction(),
1046 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy()))));
1047 }
1048 }
1049
1050 // Tests while init instruction which points-to a constant.
1051 //
1052 // init = Tuple(Constant(S32, {}), Constant(F32, {8}))
1053 //
1054 // CopyInsertion pass should add copies for both constants.
1055 //
TEST_F(WhileCopyInsertionTest,InitPointsToConstant)1056 TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
1057 auto while_hlo = BuildWhileInstruction_InitPointsToConstant();
1058
1059 InsertCopies(module_.get());
1060 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
1061 EXPECT_EQ(CountCopies(*module_), 2);
1062
1063 EXPECT_THAT(while_hlo->operand(0),
1064 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
1065 }
1066
1067 // Tests while init instruction which points-to a parameter.
1068 //
1069 // init = Tuple(Constant(S32, {}), Parameter(F32, {8}))
1070 //
1071 // CopyInsertion pass should add copies for both the constant and parameter.
1072 //
TEST_F(WhileCopyInsertionTest,InitPointsToParameter)1073 TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
1074 auto while_hlo = BuildWhileInstruction_InitPointsToParameter();
1075
1076 InsertCopies(module_.get());
1077 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
1078 EXPECT_EQ(CountCopies(*module_), 2);
1079
1080 EXPECT_THAT(while_hlo->operand(0),
1081 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter())));
1082 }
1083
1084 // Tests while init instruction which has an ambiguous points-to set.
1085 //
1086 // select = Select(pred, tuple1, tuple2)
1087 // init = Tuple(Constant(S32, {}), Parameter(F32, {8}))
1088 //
1089 // CopyInsertion pass will conceptually generate the following, but with some of
1090 // the actual GTE and Tuple instructions optimized away:
1091 //
1092 // Tuple // old init
1093 // / \
1094 // / \
1095 // GTE(0) GTE(1)
1096 // | / \
1097 // | / \
1098 // | GTE(0) GTE(1)
1099 // | | |
1100 // Copy Copy Copy
1101 // | | |
1102 // \ | /
1103 // \ Tuple
1104 // \ /
1105 // \ /
1106 // Tuple // new init
1107 //
TEST_F(WhileCopyInsertionTest,InitPointsToAmbiguous)1108 TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) {
1109 auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous();
1110
1111 InsertCopies(module_.get());
1112 EXPECT_EQ(CountCopies(*module_), 4);
1113 // The entry computation requires three copies to resolve the ambiguity of two
1114 // init elements and the constant passed in as one of the init elements.
1115 EXPECT_EQ(CountCopies(*module_->entry_computation()), 3);
1116 EXPECT_THAT(while_hlo->operand(0),
1117 op::Tuple(op::Copy(op::Constant()),
1118 op::Tuple(op::Copy(op::GetTupleElement()),
1119 op::Copy(op::GetTupleElement()))));
1120
1121 // The body requires one copy because the buffer set is not distinct: the
1122 // result of one of the adds is written into two elements of the output of the
1123 // loop body. Either element might be copied.
1124 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1);
1125 if (while_hlo->while_body()
1126 ->root_instruction()
1127 ->operand(1)
1128 ->operand(0)
1129 ->opcode() == HloOpcode::kCopy) {
1130 EXPECT_THAT(
1131 while_hlo->while_body()->root_instruction(),
1132 op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
1133 } else {
1134 EXPECT_THAT(
1135 while_hlo->while_body()->root_instruction(),
1136 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
1137 }
1138 }
1139
1140 // Tests while init instruction which has a non-distinct points-to set.
1141 //
1142 // init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one}))
1143 //
1144 // CopyInsertion pass will conceptually generate the following, but with some of
1145 // the actual GTE and Tuple instructions optimized away:
1146 //
1147 // Tuple // old init
1148 // / \
1149 // / \
1150 // GTE(0) GTE(1)
1151 // | / \
1152 // | / \
1153 // | GTE(0) GTE(1)
1154 // | | |
1155 // Copy Copy Copy
1156 // | | |
1157 // \ | /
1158 // \ Tuple
1159 // \ /
1160 // \ /
1161 // Tuple // new init
1162 //
TEST_F(WhileCopyInsertionTest,InitPointsToNonDistinct)1163 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
1164 auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct();
1165
1166 InsertCopies(module_.get());
1167
1168 // The entry computation requires two copies to resolve the non-disinctness of
1169 // two init elements and the constant passed in as one of the init
1170 // elements. Either element can be copied for the distinctness issue.
1171 EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
1172 if (while_hlo->operand(0)->operand(1)->operand(0)->opcode() ==
1173 HloOpcode::kCopy) {
1174 EXPECT_THAT(
1175 while_hlo->operand(0),
1176 op::Tuple(op::Copy(op::Constant()),
1177 op::Tuple(op::Copy(op::Broadcast()), op::Broadcast())));
1178 } else {
1179 EXPECT_THAT(
1180 while_hlo->operand(0),
1181 op::Tuple(op::Copy(op::Constant()),
1182 op::Tuple(op::Broadcast(), op::Copy(op::Broadcast()))));
1183 }
1184
1185 // The body requires one copy because the buffer set is not distinct: the
1186 // result of one of the adds is written into two elements of the output of the
1187 // loop body. Either element might be copied.
1188 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1);
1189 if (while_hlo->while_body()
1190 ->root_instruction()
1191 ->operand(1)
1192 ->operand(0)
1193 ->opcode() == HloOpcode::kCopy) {
1194 EXPECT_THAT(
1195 while_hlo->while_body()->root_instruction(),
1196 op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
1197 } else {
1198 EXPECT_THAT(
1199 while_hlo->while_body()->root_instruction(),
1200 op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
1201 }
1202 }
1203
1204 // Tests while init instruction buffer which interferes with while result
1205 // buffer.
1206 //
1207 // init_data = Broadcast(...)
1208 // add_unrelated = Add(init_data) // takes a reference to cause interference
1209 // init = Tuple(Constant(S32, {}), init_data))
1210 //
1211 // CopyInsertion pass should copy both operands.
1212 //
TEST_F(WhileCopyInsertionTest,InitPointsToInterfering)1213 TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
1214 auto while_hlo = BuildWhileInstruction_InitPointsToInterfering();
1215
1216 InsertCopies(module_.get());
1217 EXPECT_EQ(CountCopies(*module_), 2);
1218 EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
1219
1220 EXPECT_THAT(while_hlo->operand(0),
1221 op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast())));
1222 }
1223
1224 // Tests while init instruction buffer which has a non-distinct points-to set:
1225 //
1226 // init = Tuple(Parameter(S32, {}), Parameter(F32, {8},
1227 // Parameter(F32, {8})))
1228 //
1229 // where the second and third parameters are identical *and* the tuple shared
1230 // by another while instruction.
1231 //
1232 // Verifies that the resulting point-to set is distinct in the resulting Tuple
1233 // (non-identical Copys). In other words, verifies that copy sharing does not
1234 // insert identical copies to the resulting tuple.
TEST_F(WhileCopyInsertionTest,InitPointsToNonDistinctUsedByTwoWhileLoops)1235 TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
1236 // Loop body that outputs tuple comprises two elements dependent on the init
1237 // tuple.
1238 const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
1239 {induction_variable_shape_, data_shape_, data_shape_});
1240
1241 auto condition1 = module_->AddEmbeddedComputation(
1242 BuildConditionComputation(loop_state_shape));
1243 auto condition2 = module_->AddEmbeddedComputation(
1244 BuildConditionComputation(loop_state_shape));
1245 auto body1 =
1246 module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
1247 auto body2 =
1248 module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
1249
1250 auto builder = HloComputation::Builder(TestName() + ".While");
1251
1252 auto iter_param = builder.AddInstruction(
1253 HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
1254 auto data_param = builder.AddInstruction(
1255 HloInstruction::CreateParameter(1, data_shape_, "data"));
1256
1257 // Loop init tuple contains two identical parameter buffers.
1258 auto loop_init = builder.AddInstruction(
1259 HloInstruction::CreateTuple({iter_param, data_param, data_param}));
1260
1261 // Two while loops shares the same loop init tuple.
1262 auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
1263 loop_state_shape, condition1, body1, loop_init));
1264 auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
1265 loop_state_shape, condition2, body2, loop_init));
1266
1267 // Add add instruction so neither while is dead.
1268 auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1269 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
1270 auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
1271 ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0));
1272 builder.AddInstruction(
1273 HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
1274
1275 module_->AddEntryComputation(builder.Build());
1276
1277 InsertCopies(module_.get());
1278
1279 // None of the bodies should have copies or control flow edges.
1280 EXPECT_EQ(CountCopies(*body1), 0);
1281 EXPECT_EQ(CountCopies(*body2), 0);
1282
1283 // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally
1284 // these should not need to be copied before either while. However, copy
1285 // insertion is not able to reason about the transparency of elements through
1286 // while bodies in all circumstances so extra copies are added (b/xxx).
1287 EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
1288
1289 EXPECT_THAT(while_hlo1->operand(0),
1290 op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
1291 EXPECT_THAT(while_hlo2->operand(0),
1292 op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
1293 }
1294
// Tests a while instruction whose body permutes the elements of its tuple
// parameter.
TEST_F(CopyInsertionTest, SwizzlingWhile) {
  auto module = CreateNewVerifiedModule();
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: returns the two loop state elements in swapped order.
  HloComputation::Builder swap_builder("body");
  HloInstruction* swap_param = swap_builder.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* first_element = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, swap_param, 0));
  HloInstruction* second_element = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, swap_param, 1));
  swap_builder.AddInstruction(
      HloInstruction::CreateTuple({second_element, first_element}));
  HloComputation* body = module->AddEmbeddedComputation(swap_builder.Build());

  // Condition: Not(false); the loop state parameter is not read.
  HloComputation::Builder condition_builder("condition");
  condition_builder.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* false_constant = condition_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  condition_builder.AddInstruction(HloInstruction::CreateUnary(
      false_constant->shape(), HloOpcode::kNot, false_constant));
  HloComputation* condition =
      module->AddEmbeddedComputation(condition_builder.Build());

  // Entry: While(Tuple(1.0, 2.0)).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* one = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* two = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({one, two}));
  HloInstruction* xla_while = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, body, init));
  module->AddEntryComputation(entry_builder.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 6);

  // The loop state elements are expected to be copied both at the parameter
  // and at the root, with a control edge in between (see
  // DeepCopyAndAddControlEdges). That is technically one copy more than
  // strictly necessary, but getting by with only three copies would require
  // ordering the copies of different loop state elements with a control edge.
  EXPECT_EQ(CountCopies(*body), 4);
  EXPECT_EQ(CountControlEdges(*body), 2);

  EXPECT_THAT(body->root_instruction(),
              op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy())));

  // The two remaining copies wrap the elements of the while init.
  EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
  EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
}
1353
// Tests the case where the dataflow of two parameter elements crosses while
// input and output are aliased at the same index:
//
//  (p0 ,  p1)
//   | \   /|
//   |  \ / |
// alias X  alias
//   |  / \ |
//   | /   \|
//  (p1  ,  p0)
TEST_F(CopyInsertionTest, CrossingParameters) {
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Entry: returns the two elements of the tuple parameter in swapped order.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({element1, element0}));
  module->AddEntryComputation(entry_builder.Build());

  // Alias output {0} with parameter element {0} and output {1} with
  // parameter element {1}.
  auto& alias_config = module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 4);
}
1388
// Tests the case where two aliased parameter elements are passed straight
// through, so their dataflow does not interfere:
//
//  (p0 ,  p1)
//   |      |
//   |      |
// alias   alias
//   |      |
//   |      |
//  (p0 ,  p1)
TEST_F(CopyInsertionTest, ParametersAliasing) {
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Entry: rebuilds the tuple parameter from its two elements, same order.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({element0, element1}));
  module->AddEntryComputation(entry_builder.Build());

  // Alias both output elements with the parameter element at the same index.
  auto& alias_config = module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  InsertCopies(module.get());

  // No copies are expected.
  EXPECT_EQ(CountCopies(*module), 0);
}
1423
// Tests the case where no parameter element is aliased with the result. Both
// passed-through elements are expected to be copied:
//
//  (p0 ,  p1)
//   |      |
//   |      |
//   |      |
//   |      |
//   |      |
//  (p0 ,  p1)
TEST_F(CopyInsertionTest, ParameterWithNoAliasing) {
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Entry: rebuilds the tuple parameter from its two elements, same order.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({element0, element1}));
  module->AddEntryComputation(entry_builder.Build());

  InsertCopies(module.get());

  // Each output element must be a copy of the matching parameter element.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(op::GetTupleElement(tuple_param, 0)),
                        op::Copy(op::GetTupleElement(tuple_param, 1))));

  EXPECT_EQ(CountCopies(*module), 2);
}
1456
// Tests the case where one parameter element is aliased with the result
// while the other one is not. Only the non-aliased element is expected to be
// copied:
//
//  (p0 ,  p1)
//   |      |
//   |      |
// alias    |
//   |      |
//   |      |
//  (p0 ,  p1)
TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Entry: rebuilds the tuple parameter from its two elements, same order.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({element0, element1}));
  module->AddEntryComputation(entry_builder.Build());

  // Only output {0} is aliased with the parameter.
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  InsertCopies(module.get());

  // Element 0 is passed through as is; element 1 must be copied.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::GetTupleElement(tuple_param, 0),
                        op::Copy(op::GetTupleElement(tuple_param, 1))));

  EXPECT_EQ(CountCopies(*module), 1);
}
1492
// Tests the case where one parameter element is aliased with the result
// while the other one is not, and each element passes through its own Negate
// before reaching the result. No copies are expected:
//
//   +-- (p0 ,  p1)
//   |    |      |
//   |    |      |
// alias Negate  Negate
//   |    |      |
//   |    |      |
//   +-- (p0 ,  p1)
TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
  auto module = CreateNewVerifiedModule();
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Entry: Tuple(Negate(GTE(param, 0)), Negate(GTE(param, 1))).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* tuple_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p0"));
  HloInstruction* element0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0));
  HloInstruction* element1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1));

  HloInstruction* negated0 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element0));

  HloInstruction* negated1 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element1));
  entry_builder.AddInstruction(
      HloInstruction::CreateTuple({negated0, negated1}));
  module->AddEntryComputation(entry_builder.Build());

  // Only output {0} is aliased with the parameter.
  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 0);
}
1530
TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
  // Both elements of the tuple parameter are negated. Output {0} is the sum of
  // the two negations and output {1} is the second negation. Output {0} is
  // aliased with parameter element {0}; output {1} is not aliased. No copies
  // are expected.
  //
  //   +-- (p0    ,   p1)
  //   |     |         |
  // alias Negate   Negate
  //   |     |         |
  //   |    Add -------+
  //   |     |         |
  //   +-- (p0    ,   p1)
  const Shape input_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "p0"));
  HloInstruction* element_0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* element_1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));

  HloInstruction* negated_0 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate,
                                  element_0));
  HloInstruction* negated_1 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate,
                                  element_1));

  HloInstruction* sum = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, negated_0,
                                   negated_1));
  entry_builder.AddInstruction(HloInstruction::CreateTuple({sum, negated_1}));
  module->AddEntryComputation(entry_builder.Build());

  // Alias output {0} with element {0} of parameter 0.
  auto& alias_config = module->input_output_alias_config();
  ASSERT_IS_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 0);
}
1572
TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
  // A while loop whose body permutes the two elements of its tuple parameter
  // and additionally applies one operation (a negate) to one of them. The
  // extra instruction gives the corresponding input and output elements
  // different live ranges than they would have in a pure swap.
  auto module = CreateNewVerifiedModule();
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: (a, b) -> (negate(b), a).
  HloComputation::Builder body_b("body");
  HloInstruction* body_state = body_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* state_0 = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 0));
  HloInstruction* state_1 = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 1));
  HloInstruction* negated_1 = body_b.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, state_1));
  body_b.AddInstruction(HloInstruction::CreateTuple({negated_1, state_0}));
  HloComputation* body = module->AddEmbeddedComputation(body_b.Build());

  // Loop condition: not(false), independent of the loop state.
  HloComputation::Builder cond_b("condition");
  cond_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* false_value = cond_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  cond_b.AddInstruction(HloInstruction::CreateUnary(
      false_value->shape(), HloOpcode::kNot, false_value));
  HloComputation* condition = module->AddEmbeddedComputation(cond_b.Build());

  // Entry: while((1.0, 2.0)).
  HloComputation::Builder entry_b(TestName());
  HloInstruction* one = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* two = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init_state =
      entry_b.AddInstruction(HloInstruction::CreateTuple({one, two}));
  HloInstruction* xla_while = entry_b.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, body, init_state));
  module->AddEntryComputation(entry_b.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 6);

  // Inside the body each loop state element is copied once at the parameter
  // and once at the root, with a control edge between the two (see
  // DeepCopyAndAddControlEdges).
  EXPECT_EQ(CountCopies(*body), 4);
  EXPECT_EQ(CountControlEdges(*body), 2);

  EXPECT_THAT(
      body->root_instruction(),
      op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy())));

  // The remaining two copies wrap the constants fed into the while.
  EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
  EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
}
1635
TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
  // A while loop whose body permutes its tuple parameter elements, similar to
  // the SwizzlingWhile test. Here, however, both loop state elements are
  // initialized from one and the same constant, so interchanging them is a
  // no-op and the body is expected to need no copies. The only copies
  // expected are the two in the entry computation.
  auto module = CreateNewVerifiedModule();
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: (a, b) -> (b, a).
  HloComputation::Builder body_b("body");
  HloInstruction* body_state = body_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* state_0 = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 0));
  HloInstruction* state_1 = body_b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_state, 1));
  body_b.AddInstruction(HloInstruction::CreateTuple({state_1, state_0}));
  HloComputation* body = module->AddEmbeddedComputation(body_b.Build());

  // Loop condition: not(false), independent of the loop state.
  HloComputation::Builder cond_b("condition");
  cond_b.AddInstruction(
      HloInstruction::CreateParameter(0, state_shape, "param"));
  HloInstruction* false_value = cond_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  cond_b.AddInstruction(HloInstruction::CreateUnary(
      false_value->shape(), HloOpcode::kNot, false_value));
  HloComputation* condition = module->AddEmbeddedComputation(cond_b.Build());

  // Entry: while((c, c)) where both elements are the same constant.
  HloComputation::Builder entry_b(TestName());
  HloInstruction* shared_constant = entry_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* init_state = entry_b.AddInstruction(
      HloInstruction::CreateTuple({shared_constant, shared_constant}));
  entry_b.AddInstruction(
      HloInstruction::CreateWhile(state_shape, condition, body, init_state));
  module->AddEntryComputation(entry_b.Build());

  InsertCopies(module.get());

  EXPECT_EQ(CountCopies(*module), 2);
  EXPECT_EQ(CountCopies(*body), 0);

  HloComputation* entry = module->entry_computation();
  EXPECT_EQ(CountCopies(*entry), 2);
  EXPECT_THAT(entry->root_instruction(), op::Tuple(op::Copy(), op::Copy()));
}
1686
TEST_F(CopyInsertionTest, SequentialWhiles) {
  // Builds a chain of while instructions in the entry computation, each one
  // feeding the next, over a loop state with four elements:
  //
  //   element 0: passed to every while directly from an entry parameter.
  //   element 1: passed transparently in series through all while bodies.
  //   element 2: negated in every while body (in-place possible).
  //   element 3: reversed in every while body (in-place not possible).
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {42});
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({vec_shape, vec_shape, vec_shape, vec_shape});

  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_b(TestName());
  HloInstruction* entry_param_0 = entry_b.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param_0"));
  HloInstruction* entry_param_1 = entry_b.AddInstruction(
      HloInstruction::CreateParameter(1, vec_shape, "param_1"));
  HloInstruction* entry_param_2 = entry_b.AddInstruction(
      HloInstruction::CreateParameter(2, vec_shape, "param_2"));
  HloInstruction* entry_param_3 = entry_b.AddInstruction(
      HloInstruction::CreateParameter(3, vec_shape, "param_3"));

  // How many while instructions are chained together.
  constexpr int kNumWhiles = 3;

  // Values carried from one while into the next; they start at the entry
  // parameters.
  HloInstruction* carried_1 = entry_param_1;
  HloInstruction* carried_2 = entry_param_2;
  HloInstruction* carried_3 = entry_param_3;

  // Every while instruction created below, in creation order.
  std::vector<const HloInstruction*> loops;
  for (int i = 0; i < kNumWhiles; ++i) {
    // Body: (a, b, c, d) -> (a, b, negate(c), reverse(d)).
    HloComputation::Builder body_b("body");
    HloInstruction* body_state = body_b.AddInstruction(
        HloInstruction::CreateParameter(0, state_shape, "param"));
    HloInstruction* state_0 = body_b.AddInstruction(
        HloInstruction::CreateGetTupleElement(vec_shape, body_state, 0));
    HloInstruction* state_1 = body_b.AddInstruction(
        HloInstruction::CreateGetTupleElement(vec_shape, body_state, 1));
    HloInstruction* state_2 = body_b.AddInstruction(
        HloInstruction::CreateGetTupleElement(vec_shape, body_state, 2));
    HloInstruction* state_3 = body_b.AddInstruction(
        HloInstruction::CreateGetTupleElement(vec_shape, body_state, 3));
    HloInstruction* negated = body_b.AddInstruction(
        HloInstruction::CreateUnary(vec_shape, HloOpcode::kNegate, state_2));
    HloInstruction* reversed = body_b.AddInstruction(
        HloInstruction::CreateReverse(vec_shape, state_3, {0}));
    body_b.AddInstruction(
        HloInstruction::CreateTuple({state_0, state_1, negated, reversed}));
    HloComputation* body = module->AddEmbeddedComputation(body_b.Build());

    // Condition: not(false), independent of the loop state.
    HloComputation::Builder cond_b("condition");
    cond_b.AddInstruction(
        HloInstruction::CreateParameter(0, state_shape, "param"));
    HloInstruction* false_value = cond_b.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
    cond_b.AddInstruction(HloInstruction::CreateUnary(
        false_value->shape(), HloOpcode::kNot, false_value));
    HloComputation* condition =
        module->AddEmbeddedComputation(cond_b.Build());

    HloInstruction* init_state =
        entry_b.AddInstruction(HloInstruction::CreateTuple(
            {entry_param_0, carried_1, carried_2, carried_3}));

    HloInstruction* loop = entry_b.AddInstruction(
        HloInstruction::CreateWhile(state_shape, condition, body, init_state));
    loops.push_back(loop);

    // Every while except the last one hands elements 1..3 to its successor.
    if (i + 1 < kNumWhiles) {
      carried_1 = entry_b.AddInstruction(
          HloInstruction::CreateGetTupleElement(vec_shape, loop, 1));
      carried_2 = entry_b.AddInstruction(
          HloInstruction::CreateGetTupleElement(vec_shape, loop, 2));
      carried_3 = entry_b.AddInstruction(
          HloInstruction::CreateGetTupleElement(vec_shape, loop, 3));
    }
  }

  module->AddEntryComputation(entry_b.Build());

  InsertCopies(module.get());

  // One copy per while body, plus one copy per loop state element in the
  // entry computation.
  EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles);

  // The single copy in each body is for element three, whose op (kReverse)
  // cannot be done in place.
  for (const HloInstruction* loop : loops) {
    EXPECT_EQ(CountCopies(*loop->while_body()), 1);
  }

  EXPECT_THAT(loops[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(),
                                              op::Copy(), op::Copy()));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(),
                        op::GetTupleElement()));
}
1789
TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
  // Both the while body and the while condition consist of a parameter plus a
  // constant, with the constant as the computation root. The constant at the
  // body root is expected to be copied; the one at the condition root is not.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_b(TestName());
  HloInstruction* entry_param = entry_b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));

  // Body: ignores its parameter and yields the constant 123.0.
  HloComputation::Builder body_b("body");
  body_b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  body_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
  HloComputation* body = module->AddEmbeddedComputation(body_b.Build());

  // Condition: ignores its parameter and yields the constant false.
  HloComputation::Builder cond_b("condition");
  cond_b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  cond_b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition = module->AddEmbeddedComputation(cond_b.Build());

  HloInstruction* xla_while = entry_b.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape_, condition, body, entry_param));

  module->AddEntryComputation(entry_b.Build());

  InsertCopies(module.get());

  // One copy on the while operand and one on the body root.
  EXPECT_EQ(CountCopies(*module), 2);

  EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter()));
  EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
  EXPECT_THAT(condition->root_instruction(), op::Constant());
}
1826
TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) {
  // Runs copy insertion on a module whose while loop carries an
  // (s32[], token[]) tuple and checks that no copy instructions are added.
  // In the HLO below, the body increments the s32 element and threads the
  // token through an after-all; the condition compares the s32 element
  // against 42.
  string module_string = R"(
HloModule TokensShouldNotBeCopied

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
%param.1 = (s32[], token[]) parameter(0)
%get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
%constant.1 = s32[] constant(1)
%add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
%get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
%after-all = token[] after-all(token[] %get-tuple-element.2)
ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
%param = (s32[], token[]) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
%constant = s32[] constant(42)
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %TokensShouldNotBeCopied () -> s32[] {
%one = s32[] constant(1)
%negative_one = s32[] negate(%one)
%init_token = token[] after-all()
%init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token)
%while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";
  // Build the module from the HLO text above, then run the pass under test.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_string));
  InsertCopies(module.get());

  // There should be no copies added because tokens should not be copied.
  EXPECT_EQ(CountCopies(*module), 0);
}
1864
MakeTrivialCondition(const Shape & shape)1865 std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
1866 auto builder = HloComputation::Builder("trivial_condition");
1867 builder.AddInstruction(
1868 HloInstruction::CreateParameter(0, shape, "loop_state"));
1869 auto constant = builder.AddInstruction(
1870 HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1871 builder.AddInstruction(HloInstruction::CreateUnary(
1872 constant->shape(), HloOpcode::kNot, constant));
1873 return builder.Build();
1874 }
1875
MakeBenchmarkWhileBody()1876 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() {
1877 auto builder = HloComputation::Builder("benchmark_loop_body");
1878 const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
1879 const Shape loop_state_shape =
1880 ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape});
1881 HloInstruction* param = builder.AddInstruction(
1882 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
1883 HloInstruction* element_0 = builder.AddInstruction(
1884 HloInstruction::CreateGetTupleElement(element_shape, param, 0));
1885 HloInstruction* element_1 = builder.AddInstruction(
1886 HloInstruction::CreateGetTupleElement(element_shape, param, 1));
1887 HloInstruction* element_2 = builder.AddInstruction(
1888 HloInstruction::CreateGetTupleElement(element_shape, param, 2));
1889
1890 HloInstruction* rev_1 = builder.AddInstruction(
1891 HloInstruction::CreateReverse(element_shape, element_1, {0}));
1892 HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary(
1893 element_shape, HloOpcode::kAdd, element_1, element_2));
1894
1895 builder.AddInstruction(
1896 HloInstruction::CreateTuple({element_0, rev_1, add_1_2}));
1897 return builder.Build();
1898 }
1899
BM_SequentialWhiles(int num_iters,int num_whiles)1900 void BM_SequentialWhiles(int num_iters, int num_whiles) {
1901 // This benchmark constructs a chain of sequential while instructions.
1902 tensorflow::testing::StopTiming();
1903 for (int i = 0; i < num_iters; ++i) {
1904 HloModuleConfig config;
1905 config.set_debug_options(GetDebugOptionsFromFlags());
1906 HloModule module("BM_SequentialWhiles", config);
1907
1908 auto builder = HloComputation::Builder("BM_SequentialWhiles");
1909 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
1910 0, ShapeUtil::MakeShape(F32, {42}), "x"));
1911 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
1912 1, ShapeUtil::MakeShape(F32, {42}), "y"));
1913 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
1914 2, ShapeUtil::MakeShape(F32, {42}), "z"));
1915 HloInstruction* init =
1916 builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
1917
1918 HloInstruction* prev_loop_state = init;
1919 for (int w = 0; w < num_whiles; ++w) {
1920 HloComputation* condition =
1921 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
1922 HloComputation* body =
1923 module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
1924 prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile(
1925 init->shape(), condition, body, prev_loop_state));
1926 }
1927 module.AddEntryComputation(builder.Build());
1928
1929 CopyInsertion copy_insertion;
1930
1931 tensorflow::testing::StartTiming();
1932 ASSERT_IS_OK(copy_insertion.Run(&module).status());
1933 tensorflow::testing::StopTiming();
1934
1935 // The entry computation should have three copies, and each body has one.
1936 ASSERT_EQ(CountCopies(module), 3 + num_whiles);
1937 }
1938 }
1939
BM_ParallelWhiles(int num_iters,int num_whiles)1940 void BM_ParallelWhiles(int num_iters, int num_whiles) {
1941 // This benchmark constructs a fan-out of parallel while instructions.
1942 tensorflow::testing::StopTiming();
1943 for (int i = 0; i < num_iters; ++i) {
1944 HloModuleConfig config;
1945 config.set_debug_options(GetDebugOptionsFromFlags());
1946 HloModule module("BM_SequentialWhiles", config);
1947
1948 auto builder = HloComputation::Builder("BM_ParallelWhiles");
1949 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
1950 0, ShapeUtil::MakeShape(F32, {42}), "x"));
1951 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
1952 1, ShapeUtil::MakeShape(F32, {42}), "y"));
1953 HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
1954 2, ShapeUtil::MakeShape(F32, {42}), "z"));
1955 HloInstruction* init =
1956 builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
1957
1958 HloInstruction* sum = nullptr;
1959 for (int w = 0; w < num_whiles; ++w) {
1960 HloComputation* condition =
1961 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
1962 HloComputation* body =
1963 module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
1964
1965 HloInstruction* xla_while = builder.AddInstruction(
1966 HloInstruction::CreateWhile(init->shape(), condition, body, init));
1967
1968 if (sum == nullptr) {
1969 sum = builder.AddInstruction(
1970 HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
1971 } else {
1972 HloInstruction* element_0 = builder.AddInstruction(
1973 HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
1974 sum = builder.AddInstruction(HloInstruction::CreateBinary(
1975 x->shape(), HloOpcode::kAdd, sum, element_0));
1976 }
1977 }
1978 module.AddEntryComputation(builder.Build());
1979
1980 CopyInsertion copy_insertion;
1981
1982 tensorflow::testing::StartTiming();
1983 ASSERT_IS_OK(copy_insertion.Run(&module).status());
1984 tensorflow::testing::StopTiming();
1985
1986 // Each body receives of copy of two of the parameters (the corresponding
1987 // elements in the body are modifed), and there is one copy in each body.
1988 ASSERT_EQ(CountCopies(module), 3 * num_whiles);
1989 }
1990 }
1991
MakeBenchmarkWhileBody(const int num_tuple_inputs)1992 std::unique_ptr<HloComputation> MakeBenchmarkWhileBody(
1993 const int num_tuple_inputs) {
1994 auto builder = HloComputation::Builder("benchmark_loop_body");
1995 const Shape element_shape = ShapeUtil::MakeShape(F32, {});
1996 std::vector<Shape> input_shape(num_tuple_inputs, element_shape);
1997 const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape);
1998 HloInstruction* param = builder.AddInstruction(
1999 HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
2000 std::vector<HloInstruction*> gte_nodes(num_tuple_inputs);
2001 for (int i = 0; i < num_tuple_inputs; ++i) {
2002 gte_nodes[i] = builder.AddInstruction(
2003 HloInstruction::CreateGetTupleElement(element_shape, param, i));
2004 }
2005 builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes));
2006 return builder.Build();
2007 }
2008
BM_ManyElementTuple(int num_iters,const int num_tuple_inputs)2009 void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) {
2010 tensorflow::testing::StopTiming();
2011 HloModuleConfig config;
2012 config.set_debug_options(GetDebugOptionsFromFlags());
2013 CopyInsertion copy_insertion;
2014 const Shape element_shape = ShapeUtil::MakeShape(F32, {});
2015 std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
2016 for (int i = 0; i < num_iters; ++i) {
2017 auto builder = HloComputation::Builder("BM_ParallelWhiles");
2018 HloModule module("BM_ManyElementTuple", config);
2019 for (int j = 0; j < num_tuple_inputs; ++j) {
2020 tuple_params[j] = builder.AddInstruction(
2021 HloInstruction::CreateParameter(j, element_shape, ""));
2022 }
2023 HloInstruction* init =
2024 builder.AddInstruction(HloInstruction::CreateTuple(tuple_params));
2025 HloComputation* condition =
2026 module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
2027 HloComputation* body =
2028 module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs));
2029 HloInstruction* xla_while = builder.AddInstruction(
2030 HloInstruction::CreateWhile(init->shape(), condition, body, init));
2031 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
2032 ShapeUtil::MakeShape(F32, {}), xla_while, 0));
2033 module.AddEntryComputation(builder.Build());
2034 tensorflow::testing::StartTiming();
2035 ASSERT_IS_OK(copy_insertion.Run(&module).status());
2036 tensorflow::testing::StopTiming();
2037 }
2038 }
2039
// Register the CopyInsertion benchmarks defined above.
// NOTE(review): each Arg value is presumably forwarded as the benchmark
// function's second parameter (num_whiles / num_tuple_inputs) -- confirm
// against tensorflow/core/platform/test_benchmark.h.
BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288);
2043
TEST_F(CopyInsertionTest, SimpleControlFlowTest) {
  // Smoke test: runs copy insertion on a module with three nested while
  // loops (while.3 -> while.2 -> while). The test makes no assertion about
  // the copies that are inserted; it only requires that the module parses
  // and that InsertCopies completes.
  const string hlo_string = R"(
HloModule TestModule

if-body.v5 {
constant.3 = s32[] constant(-1)
p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
tuple.33 = (s32[]) tuple(add.3)
ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
constant.4 = s32[] constant(0)
ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

_functionalize_body_1__.v28 {
arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0
constant.7 = s32[] constant(1)
add.4 = s32[] add(get-tuple-element.68, constant.7)
get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1
get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2
less-than-or-equal-to = pred[] compare(get-tuple-element.69, get-tuple-element.70), direction=LE
constant.8 = s32[] constant(0)
select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3
tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70)
tuple.36 = (s32[]) tuple(constant.8)
tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36)
while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5
get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2
get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0
ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73)
}

cond_wrapper.v3.1 {
inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0
constant.11 = s32[] constant(7)
ROOT less-than.2 = pred[] compare(get-tuple-element.75, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0
get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2
get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3
get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4
tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79)
while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0
get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1
constant.12 = s32[] constant(1)
add.5 = s32[] add(get-tuple-element.81, constant.12)
get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3
ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82)
}

cond_wrapper.v3.2 {
inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1
constant.13 = s32[] constant(5)
ROOT less-than.3 = pred[] compare(get-tuple-element.83, constant.13), direction=LT
}

ENTRY TestComputation {
arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // Check the parse result through the assertion macro, as
  // TokensShouldNotBeCopied does, rather than consuming the StatusOr
  // unchecked, so a parse error is reported as a test failure.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2125
TEST_F(CopyInsertionTest, ControlFlowTest) {
  // Smoke test on a module with nested control flow: the entry while
  // (while.3) runs a body containing while.2, whose body in turn chains two
  // further whiles (while -> while.1). No expectations are placed on the
  // resulting graph; the test only requires that the module parses and that
  // InsertCopies() runs on it.
  // NOTE(review): InsertCopies() is a fixture helper defined outside this
  // chunk; presumably it runs the CopyInsertion pass and asserts success.
  const string hlo_string = R"(
HloModule TestModule

if-body.v5 {
  constant.3 = s32[] constant(-1)
  p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1
  get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0
  get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1
  add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66)
  tuple.33 = (s32[]) tuple(add.3)
  ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33)
}

if-condition.v4 {
  p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0
  constant.4 = s32[] constant(0)
  ROOT equal-to = pred[] compare(get-tuple-element.67, constant.4), direction=EQ
}

if-body.v5.1 {
  constant.5 = s32[] constant(-1)
  p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1
  get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2
  multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70)
  tuple.35 = (s32[]) tuple(multiply.1)
  ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35)
}

if-condition.v4.1 {
  p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0)
  get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0
  constant.6 = s32[] constant(1)
  ROOT equal-to.1 = pred[] compare(get-tuple-element.71, constant.6), direction=EQ
}

_functionalize_body_1__.v28 {
  arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0
  constant.7 = s32[] constant(1)
  add.4 = s32[] add(get-tuple-element.72, constant.7)
  get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1
  get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2
  less-than-or-equal-to = pred[] compare(get-tuple-element.73, get-tuple-element.74), direction=LE
  constant.8 = s32[] constant(0)
  select = s32[] select(less-than-or-equal-to, constant.8, constant.7)
  get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3
  tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74)
  tuple.38 = (s32[]) tuple(constant.8)
  tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38)
  while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5
  while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1
  get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2
  get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0
  ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77)
}

cond_wrapper.v3.1 {
  inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0
  constant.11 = s32[] constant(7)
  ROOT less-than.2 = pred[] compare(get-tuple-element.78, constant.11), direction=LT
}

_functionalize_body_2__.v25 {
  arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0
  get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2
  get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3
  get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4
  tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82)
  while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28
  get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0
  get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1
  constant.12 = s32[] constant(1)
  add.5 = s32[] add(get-tuple-element.84, constant.12)
  get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3
  ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85)
}

cond_wrapper.v3.2 {
  inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1
  constant.13 = s32[] constant(5)
  ROOT less-than.3 = pred[] compare(get-tuple-element.86, constant.13), direction=LT
}

ENTRY TestComputation {
  arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0)
  ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25
}
)";
  // Assert on the returned status instead of unwrapping it with
  // ConsumeValueOrDie(), matching how the other tests in this file unwrap
  // ParseAndReturnVerifiedModule().
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
}
2225
TEST_F(CopyInsertionTest, NestedWhiles) {
  // Checks that copy insertion leaves no unnecessary copies behind when
  // applied to trivial nested while loops (b/112472605).
  const string& module_text = R"(
HloModule TestModule

cond.inner {
  ROOT param.cond.inner = pred[] parameter(0)
}

body.inner {
  param.body.inner = pred[] parameter(0)
  ROOT not = pred[] not(param.body.inner)
}

cond.outer {
  ROOT param.cond.outer = pred[] parameter(0)
}

body.outer {
  param.cond.outer = pred[] parameter(0)
  ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
}

ENTRY TestComputation {
  entry_param = pred[] parameter(0)
  ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_text));
  InsertCopies(module.get());

  // Exactly one copy is expected in the module after the pass...
  EXPECT_EQ(CountCopies(*module), 1);

  // ...and it must sit in the entry computation, between the parameter and
  // the root while instruction.
  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, op::While(op::Copy(op::Parameter())));
}
2265
2266 } // namespace
2267 } // namespace xla
2268