1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
17 #include "tensorflow/compiler/xla/service/bfloat16_support.h"
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/test_helpers.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28
29 namespace xla {
30
31 // A class specifying the BF16 support used to test the propagation pass. It
32 // specifies that BF16 and mixed precision are supported in all HloInstructions,
33 // and that kDot reduces its operands precision to BF16.
34 class TestBFloat16Support : public BFloat16Support {
35 public:
TestBFloat16Support()36 TestBFloat16Support() {}
~TestBFloat16Support()37 ~TestBFloat16Support() override {}
38
SupportsBF16Operand(const HloInstruction & hlo,int64 operand_index) const39 bool SupportsBF16Operand(const HloInstruction& hlo,
40 int64 operand_index) const override {
41 return true;
42 }
43
SupportsBF16Output(const HloInstruction & hlo) const44 bool SupportsBF16Output(const HloInstruction& hlo) const override {
45 return true;
46 }
47
SupportsMixedPrecisions(const HloInstruction & hlo) const48 bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
49 return true;
50 }
51
EffectiveOperandPrecisionIsBF16(const HloInstruction & hlo,int64 operand_index) const52 bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo,
53 int64 operand_index) const override {
54 return hlo.opcode() == HloOpcode::kDot;
55 }
56 };
57
58 class BFloat16PropagationTest : public HloTestBase {
59 protected:
BFloat16PropagationTest()60 BFloat16PropagationTest()
61 : HloTestBase(/*verifier_layout_sensitive=*/false,
62 /*allow_mixed_precision_in_hlo_verifier=*/true) {}
63
64 // Runs the propagation pass on the given module, and returns whether the
65 // module is changed after this pass.
PropagatePrecision(HloModule * module)66 bool PropagatePrecision(HloModule* module) {
67 TestBFloat16Support bfloat16_support;
68 BFloat16Propagation propagation(&bfloat16_support);
69 StatusOr<bool> result = propagation.Run(module);
70 EXPECT_IS_OK(result.status());
71 return result.ValueOrDie();
72 }
73
74 // Returns whether the given HloInstruction's output element type is BF16 or
75 // the only use of it is converting to BF16.
OutputsBF16(const HloInstruction * inst)76 bool OutputsBF16(const HloInstruction* inst) {
77 if (inst->shape().element_type() == BF16) {
78 return true;
79 }
80 return inst->user_count() == 1 &&
81 inst->users()[0]->opcode() == HloOpcode::kConvert &&
82 inst->users()[0]->shape().element_type() == BF16;
83 }
84
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)85 std::unique_ptr<HloInstruction> CreateDot(const Shape& shape,
86 HloInstruction* lhs,
87 HloInstruction* rhs) {
88 DotDimensionNumbers dot_dnums;
89 dot_dnums.add_lhs_contracting_dimensions(1);
90 dot_dnums.add_rhs_contracting_dimensions(0);
91 return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
92 DefaultPrecisionConfig(2));
93 }
94 };
95
96 // Tests that BF16 can propagate through select over non-tuple buffers, but not
97 // through add where reducing operand precision can affect the result.
TEST_F(BFloat16PropagationTest,PropagateThroughSelectButNotAdd)98 TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
99 auto builder = HloComputation::Builder(TestName());
100 Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
101
102 HloInstruction* a =
103 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
104 HloInstruction* b =
105 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
106 HloInstruction* c =
107 builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c"));
108 HloInstruction* add0 = builder.AddInstruction(
109 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
110 HloInstruction* add1 = builder.AddInstruction(
111 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
112 HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateCompare(
113 ShapeUtil::MakeShape(PRED, {2, 4}), a, b, ComparisonDirection::kEq));
114 HloInstruction* sel = builder.AddInstruction(
115 HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
116 HloInstruction* xpose =
117 builder.AddInstruction(HloInstruction::CreateTranspose(
118 ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0}));
119 HloInstruction* dot = builder.AddInstruction(
120 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a));
121 HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
122 ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot));
123
124 auto module = CreateNewVerifiedModule();
125 auto computation = module->AddEntryComputation(builder.Build());
126
127 EXPECT_TRUE(PropagatePrecision(module.get()));
128
129 EXPECT_EQ(computation->root_instruction(), root);
130 EXPECT_TRUE(OutputsBF16(xpose));
131 EXPECT_TRUE(OutputsBF16(sel));
132 EXPECT_TRUE(OutputsBF16(add1));
133 EXPECT_FALSE(OutputsBF16(add0));
134 EXPECT_FALSE(OutputsBF16(a));
135 EXPECT_FALSE(OutputsBF16(b));
136 EXPECT_FALSE(OutputsBF16(c));
137 }
138
// Checks that BF16 propagates from one dot operand back through a max-pool
// style reduce-window (scalar maximum reducer) to the add feeding it. It also
// checks propagation through a transpose on the other dot operand.
TEST_F(BFloat16PropagationTest, PropagateThroughMaxPoolReduceWindow) {
  auto module = CreateNewVerifiedModule();
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  // Scalar reducer: max(a, b).
  HloComputation::Builder max_builder("max");
  HloInstruction* lhs_scalar = max_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "a"));
  HloInstruction* rhs_scalar = max_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "b"));
  max_builder.AddInstruction(HloInstruction::CreateBinary(
      f32_scalar, HloOpcode::kMaximum, lhs_scalar, rhs_scalar));
  HloComputation* max_computation =
      module->AddEmbeddedComputation(max_builder.Build());

  HloComputation::Builder builder(TestName());
  const Shape f32_2x4 = ShapeUtil::MakeShape(F32, {2, 4});

  HloInstruction* param_a =
      builder.AddInstruction(HloInstruction::CreateParameter(0, f32_2x4, "a"));
  HloInstruction* param_b =
      builder.AddInstruction(HloInstruction::CreateParameter(1, f32_2x4, "b"));
  HloInstruction* param_c =
      builder.AddInstruction(HloInstruction::CreateParameter(2, f32_2x4, "c"));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      f32_2x4, HloOpcode::kAdd, param_a, param_b));

  // Size-2 window with stride 1 and one element of high padding in both
  // dimensions, so the reduce-window output keeps the 2x4 input shape.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_high(1);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(1);
  Window window;
  for (int i = 0; i < 2; ++i) {
    *window.add_dimensions() = window_dim;
  }

  HloInstruction* init_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  HloInstruction* reduce_window =
      builder.AddInstruction(HloInstruction::CreateReduceWindow(
          f32_2x4, sum, init_value, window, max_computation));
  HloInstruction* transposed =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {4, 2}), param_c, {1, 0}));
  HloInstruction* dot = builder.AddInstruction(
      CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), transposed, reduce_window));
  HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot));

  HloComputation* entry = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));

  EXPECT_EQ(entry->root_instruction(), root);
  EXPECT_TRUE(OutputsBF16(sum));
  EXPECT_TRUE(OutputsBF16(transposed));
  EXPECT_TRUE(OutputsBF16(reduce_window));
}
194
195 // Tests that side-effecting all-reduce should not be changed.
TEST_F(BFloat16PropagationTest,DoNotChangeAllReduce)196 TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
197 auto module = CreateNewVerifiedModule();
198
199 auto builder = HloComputation::Builder(TestName());
200 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
201 HloInstruction* a =
202 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
203 HloInstruction* b =
204 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
205 auto rb = HloComputation::Builder(TestName());
206 rb.AddInstruction(HloInstruction::CreateBinary(
207 shape, HloOpcode::kAdd,
208 rb.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")),
209 rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"))));
210 auto reduction = module->AddEmbeddedComputation(rb.Build());
211 HloInstruction* all_reduce =
212 builder.AddInstruction(HloInstruction::CreateAllReduce(
213 ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
214 /*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/1));
215 HloInstruction* gte0 = builder.AddInstruction(
216 HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
217 HloInstruction* gte1 = builder.AddInstruction(
218 HloInstruction::CreateGetTupleElement(shape, all_reduce, 1));
219 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
220 HloInstruction* root = builder.AddInstruction(
221 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
222
223 auto computation = module->AddEntryComputation(builder.Build());
224
225 EXPECT_FALSE(PropagatePrecision(module.get()));
226 EXPECT_EQ(computation->root_instruction(), root);
227 }
228
229 // Tests that if a constant is converted to BF16 then its literal must also be
230 // converted.
TEST_F(BFloat16PropagationTest,ConvertConstantLiteral)231 TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
232 auto builder = HloComputation::Builder(TestName());
233 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
234 Array2D<float> array_a(4, 4);
235 array_a.FillUnique(1.0f);
236 Array2D<float> array_b(4, 4);
237 array_b.FillUnique(10.0f);
238
239 HloInstruction* a = builder.AddInstruction(
240 HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a)));
241 HloInstruction* b = builder.AddInstruction(
242 HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b)));
243 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b));
244
245 auto module = CreateNewVerifiedModule();
246 auto computation = module->AddEntryComputation(builder.Build());
247
248 EXPECT_TRUE(PropagatePrecision(module.get()));
249
250 EXPECT_EQ(computation->root_instruction(), dot);
251 EXPECT_TRUE(OutputsBF16(dot->operand(0)));
252 EXPECT_TRUE(OutputsBF16(dot->operand(1)));
253 EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
254 EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
255 EXPECT_TRUE(LiteralTestUtil::Equal(
256 LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
257 dot->operand(0)->literal()));
258 EXPECT_TRUE(LiteralTestUtil::Equal(
259 LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
260 dot->operand(1)->literal()));
261 }
262
263 // Tests that BF16 can be propagated through nested tuples.
TEST_F(BFloat16PropagationTest,PropagateThroughTuples)264 TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
265 auto builder = HloComputation::Builder(TestName());
266 Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
267
268 HloInstruction* a =
269 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
270 HloInstruction* b =
271 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
272 HloInstruction* add0 = builder.AddInstruction(
273 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
274 HloInstruction* add1 = builder.AddInstruction(
275 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a));
276 HloInstruction* add2 = builder.AddInstruction(
277 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, b));
278 HloInstruction* xpose =
279 builder.AddInstruction(HloInstruction::CreateTranspose(
280 ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0}));
281
282 HloInstruction* tuple0 =
283 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1, add2}));
284 HloInstruction* tuple1 =
285 builder.AddInstruction(HloInstruction::CreateTuple({tuple0, xpose}));
286
287 HloInstruction* lhs = builder.AddInstruction(
288 HloInstruction::CreateGetTupleElement(xpose->shape(), tuple1, 1));
289 HloInstruction* rhs =
290 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
291 add0->shape(),
292 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
293 tuple0->shape(), tuple1, 0)),
294 0));
295 HloInstruction* dot = builder.AddInstruction(
296 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
297
298 HloInstruction* output_tuple =
299 builder.AddInstruction(HloInstruction::CreateTuple({dot, add2}));
300
301 auto module = CreateNewVerifiedModule();
302 auto computation = module->AddEntryComputation(builder.Build());
303
304 EXPECT_TRUE(PropagatePrecision(module.get()));
305
306 EXPECT_EQ(computation->root_instruction(), output_tuple);
307 EXPECT_TRUE(OutputsBF16(xpose));
308 EXPECT_TRUE(OutputsBF16(add0));
309 EXPECT_TRUE(OutputsBF16(add1));
310 EXPECT_FALSE(OutputsBF16(add2));
311 }
312
313 // Tests that even if an instruction does not define a buffer in its output, its
314 // shape must match the defining instruction.
TEST_F(BFloat16PropagationTest,SameValueReferencedTwice)315 TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
316 auto builder = HloComputation::Builder(TestName());
317 Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
318
319 HloInstruction* a =
320 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
321 HloInstruction* b =
322 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
323 HloInstruction* add0 = builder.AddInstruction(
324 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
325 HloInstruction* add1 = builder.AddInstruction(
326 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a));
327
328 HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateTranspose(
329 ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0}));
330
331 HloInstruction* tuple =
332 builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
333 HloInstruction* rhs = builder.AddInstruction(
334 HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1));
335
336 // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1.
337 HloInstruction* dot = builder.AddInstruction(
338 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
339
340 auto module = CreateNewVerifiedModule();
341 auto computation = module->AddEntryComputation(builder.Build());
342
343 EXPECT_TRUE(PropagatePrecision(module.get()));
344
345 EXPECT_EQ(computation->root_instruction(), dot);
346 EXPECT_TRUE(OutputsBF16(add1));
347 EXPECT_TRUE(OutputsBF16(lhs));
348
349 // add0 and rhs have been eliminated by simplification and DCE.
350 }
351
352 // Tests that a non-fusion computation's root should not be changed.
TEST_F(BFloat16PropagationTest,DoNotChangeComputationRoot)353 TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
354 auto builder = HloComputation::Builder(TestName());
355 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
356
357 HloInstruction* a =
358 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
359 HloInstruction* b =
360 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
361 HloInstruction* add = builder.AddInstruction(
362 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
363
364 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add));
365
366 HloInstruction* tuple =
367 builder.AddInstruction(HloInstruction::CreateTuple({add, dot}));
368
369 auto module = CreateNewVerifiedModule();
370 auto computation = module->AddEntryComputation(builder.Build());
371
372 EXPECT_FALSE(PropagatePrecision(module.get()));
373
374 EXPECT_EQ(computation->root_instruction(), tuple);
375 EXPECT_FALSE(OutputsBF16(add));
376 }
377
378 // Tests that BF16 is propagated properly through fused computations.
TEST_F(BFloat16PropagationTest,PropagateThroughFusion)379 TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
380 auto module = CreateNewVerifiedModule();
381 auto builder = HloComputation::Builder(TestName());
382 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
383
384 HloInstruction* param = builder.AddInstruction(
385 HloInstruction::CreateParameter(0, shape, "param"));
386 HloInstruction* add = builder.AddInstruction(
387 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
388
389 auto builder_f0 = HloComputation::Builder("fusion0");
390 HloInstruction* a_f0 =
391 builder_f0.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
392 HloInstruction* b_f0 =
393 builder_f0.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
394 HloInstruction* tuple_f0 =
395 builder_f0.AddInstruction(HloInstruction::CreateTuple({a_f0, b_f0}));
396 auto comp_f0 = module->AddEmbeddedComputation(builder_f0.Build());
397 auto fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
398 tuple_f0->shape(), HloInstruction::FusionKind::kCustom, {add, add},
399 comp_f0));
400
401 auto builder_f1 = HloComputation::Builder("fusion1");
402 HloInstruction* p_f1 = builder_f1.AddInstruction(
403 HloInstruction::CreateParameter(0, tuple_f0->shape(), "param"));
404 HloInstruction* a_f1 = builder_f1.AddInstruction(
405 HloInstruction::CreateGetTupleElement(shape, p_f1, 0));
406 HloInstruction* b_f1 = builder_f1.AddInstruction(
407 HloInstruction::CreateGetTupleElement(shape, p_f1, 1));
408 HloInstruction* dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1));
409 auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build());
410 auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion(
411 dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1));
412
413 auto computation = module->AddEntryComputation(builder.Build());
414
415 EXPECT_TRUE(PropagatePrecision(module.get()));
416
417 EXPECT_EQ(computation->root_instruction(), fusion1);
418 EXPECT_TRUE(OutputsBF16(add));
419 EXPECT_TRUE(OutputsBF16(a_f0));
420 EXPECT_TRUE(OutputsBF16(b_f0));
421 EXPECT_TRUE(OutputsBF16(a_f1));
422 EXPECT_TRUE(OutputsBF16(b_f1));
423 }
424
425 // Tests that changes to BF16 that cannot be propagated outside a fusion are
426 // discarded.
TEST_F(BFloat16PropagationTest,DiscardFusionInternalBF16Changes)427 TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
428 auto module = CreateNewVerifiedModule();
429 auto builder = HloComputation::Builder(TestName());
430 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
431
432 HloInstruction* param = builder.AddInstruction(
433 HloInstruction::CreateParameter(0, shape, "param"));
434 HloInstruction* add = builder.AddInstruction(
435 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
436
437 auto builder_f = HloComputation::Builder("fusion");
438 HloInstruction* a_f =
439 builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
440 HloInstruction* b_f =
441 builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
442 HloInstruction* add_f = builder_f.AddInstruction(
443 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
444 HloInstruction* dot_f =
445 builder_f.AddInstruction(CreateDot(shape, add_f, add_f));
446 auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
447 auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
448 dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f));
449
450 auto computation = module->AddEntryComputation(builder.Build());
451
452 EXPECT_FALSE(PropagatePrecision(module.get()));
453 EXPECT_EQ(computation->root_instruction(), fusion);
454 }
455
456 // Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion
457 // outputs are only used by a dot, and 3) one element of the tuple is used by
458 // an add in the fusion computation, then the propagation pass should create a
459 // convert in the fusion computation to keep the add's operand in F32 but change
460 // the fusion output to BF16. E.g., the following fusion computation
461 // (F32, F32) fusion_computation(F32 a, F32 b)
462 // = tuple(F32 a, F32 add(F32 a, F32 b))
463 // will be changed to
464 // (BF16, BF16) fusion_computation(F32 a, F32 b)
465 // = tuple(BF16 convert(a), BF16 add(F32 a, F32 b))
TEST_F(BFloat16PropagationTest,ConvertTupleFusionElementIfUsedByAdd)466 TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) {
467 auto module = CreateNewVerifiedModule();
468 auto builder = HloComputation::Builder(TestName());
469 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
470
471 HloInstruction* param = builder.AddInstruction(
472 HloInstruction::CreateParameter(0, shape, "param"));
473 HloInstruction* add = builder.AddInstruction(
474 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
475
476 auto builder_f = HloComputation::Builder("fusion0");
477 HloInstruction* a_f =
478 builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
479 HloInstruction* b_f =
480 builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
481 HloInstruction* add_f = builder_f.AddInstruction(
482 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
483 HloInstruction* tuple_f =
484 builder_f.AddInstruction(HloInstruction::CreateTuple({a_f, add_f}));
485 auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
486 auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
487 tuple_f->shape(), HloInstruction::FusionKind::kCustom, {add, add},
488 comp_f));
489
490 HloInstruction* gte0 = builder.AddInstruction(
491 HloInstruction::CreateGetTupleElement(shape, fusion, 0));
492 HloInstruction* gte1 = builder.AddInstruction(
493 HloInstruction::CreateGetTupleElement(shape, fusion, 1));
494 HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
495
496 auto computation = module->AddEntryComputation(builder.Build());
497
498 EXPECT_TRUE(PropagatePrecision(module.get()));
499
500 EXPECT_EQ(computation->root_instruction(), dot);
501 EXPECT_TRUE(OutputsBF16(gte0));
502 EXPECT_TRUE(OutputsBF16(gte1));
503 EXPECT_FALSE(OutputsBF16(a_f));
504 EXPECT_FALSE(OutputsBF16(b_f));
505 EXPECT_TRUE(OutputsBF16(add_f));
506 auto new_fusion_root = comp_f->root_instruction();
507 EXPECT_EQ(new_fusion_root->opcode(), HloOpcode::kTuple);
508 EXPECT_EQ(new_fusion_root->operand(1), add_f);
509 EXPECT_EQ(new_fusion_root->operand(0)->opcode(), HloOpcode::kConvert);
510 EXPECT_TRUE(OutputsBF16(new_fusion_root->operand(0)));
511 }
512
513 // A select over tuples does not define the leaf buffers, so the types in
514 // on_true and on_false must match, so that as long as one of them is F32, the
515 // other must be F32 as well.
TEST_F(BFloat16PropagationTest,SelectOverTuples)516 TEST_F(BFloat16PropagationTest, SelectOverTuples) {
517 auto module = CreateNewVerifiedModule();
518 auto builder = HloComputation::Builder(TestName());
519 Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
520
521 HloInstruction* param = builder.AddInstruction(
522 HloInstruction::CreateParameter(0, shape, "param"));
523 HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateParameter(
524 1, ShapeUtil::MakeShape(PRED, {}), "pred"));
525
526 HloInstruction* add0 = builder.AddInstruction(
527 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
528 HloInstruction* add1 = builder.AddInstruction(
529 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, param));
530 HloInstruction* tuple0 =
531 builder.AddInstruction(HloInstruction::CreateTuple({param, add0}));
532 HloInstruction* tuple1 =
533 builder.AddInstruction(HloInstruction::CreateTuple({param, add1}));
534 HloInstruction* sel = builder.AddInstruction(HloInstruction::CreateTernary(
535 tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));
536 HloInstruction* gte0 = builder.AddInstruction(
537 HloInstruction::CreateGetTupleElement(shape, sel, 0));
538 HloInstruction* gte1 = builder.AddInstruction(
539 HloInstruction::CreateGetTupleElement(shape, sel, 1));
540 HloInstruction* xpose =
541 builder.AddInstruction(HloInstruction::CreateTranspose(
542 ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0}));
543 HloInstruction* dot = builder.AddInstruction(
544 CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1));
545
546 auto computation = module->AddEntryComputation(builder.Build());
547
548 EXPECT_TRUE(PropagatePrecision(module.get()));
549
550 EXPECT_EQ(computation->root_instruction(), dot);
551 EXPECT_FALSE(OutputsBF16(add0));
552 EXPECT_FALSE(OutputsBF16(add1));
553 EXPECT_FALSE(OutputsBF16(gte0));
554 EXPECT_FALSE(OutputsBF16(gte1));
555 EXPECT_TRUE(OutputsBF16(xpose));
556 }
557
558 // Tests that BF16 is propagated properly through a while computation with
559 // non-tuple input/output.
TEST_F(BFloat16PropagationTest,PropagateThroughSimpleWhile)560 TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
561 auto module = CreateNewVerifiedModule();
562 auto builder = HloComputation::Builder(TestName());
563 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
564
565 HloInstruction* param0 = builder.AddInstruction(
566 HloInstruction::CreateParameter(0, shape, "param0"));
567 HloInstruction* param1 = builder.AddInstruction(
568 HloInstruction::CreateParameter(1, shape, "param1"));
569 HloInstruction* add = builder.AddInstruction(
570 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
571
572 auto builder_cond = HloComputation::Builder("cond");
573 auto cond_param = builder_cond.AddInstruction(
574 HloInstruction::CreateParameter(0, shape, "cond_param"));
575 auto cond_dot =
576 builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
577 auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateCompare(
578 ShapeUtil::MakeShape(PRED, {}),
579 builder_cond.AddInstruction(HloInstruction::CreateReshape(
580 ShapeUtil::MakeShape(F32, {}),
581 builder_cond.AddInstruction(
582 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
583 cond_dot, {0, 0}, {1, 1}, {1, 1})))),
584 builder_cond.AddInstruction(HloInstruction::CreateReshape(
585 ShapeUtil::MakeShape(F32, {}),
586 builder_cond.AddInstruction(
587 HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
588 cond_dot, {1, 1}, {2, 2}, {1, 1})))),
589 ComparisonDirection::kGt));
590 auto cond = module->AddEmbeddedComputation(builder_cond.Build());
591
592 auto builder_body = HloComputation::Builder("body");
593 auto body_param = builder_body.AddInstruction(
594 HloInstruction::CreateParameter(0, shape, "body_param"));
595 auto body_dot =
596 builder_body.AddInstruction(CreateDot(shape, body_param, body_param));
597 auto body = module->AddEmbeddedComputation(builder_body.Build());
598
599 auto while_hlo = builder.AddInstruction(
600 HloInstruction::CreateWhile(shape, cond, body, add));
601
602 auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
603 auto computation = module->AddEntryComputation(builder.Build());
604
605 EXPECT_TRUE(PropagatePrecision(module.get()));
606
607 EXPECT_EQ(computation->root_instruction(), dot);
608 EXPECT_TRUE(
609 ShapeUtil::Equal(cond_root->shape(), ShapeUtil::MakeShape(PRED, {})));
610 EXPECT_TRUE(OutputsBF16(add));
611 EXPECT_TRUE(OutputsBF16(body_dot));
612 EXPECT_TRUE(OutputsBF16(body_param));
613 EXPECT_TRUE(OutputsBF16(cond_param));
614 EXPECT_FALSE(OutputsBF16(dot));
615 }
616
617 // Tests that if the while condition prevents using BF16, no changes should be
618 // made to the while body and thus the fusion node inside it.
TEST_F(BFloat16PropagationTest,ConditionPreventsPropagationForFusionInsideWhile)619 TEST_F(BFloat16PropagationTest,
620 ConditionPreventsPropagationForFusionInsideWhile) {
621 auto module = CreateNewVerifiedModule();
622 auto builder = HloComputation::Builder(TestName());
623 Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
624
625 HloInstruction* param0 = builder.AddInstruction(
626 HloInstruction::CreateParameter(0, shape, "param0"));
627 HloInstruction* param1 = builder.AddInstruction(
628 HloInstruction::CreateParameter(1, shape, "param1"));
629 HloInstruction* add = builder.AddInstruction(
630 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
631
632 auto builder_cond = HloComputation::Builder("cond");
633 auto cond_param = builder_cond.AddInstruction(
634 HloInstruction::CreateParameter(0, shape, "cond_param"));
635 builder_cond.AddInstruction(HloInstruction::CreateCompare(
636 ShapeUtil::MakeShape(PRED, {}),
637 builder_cond.AddInstruction(HloInstruction::CreateReshape(
638 ShapeUtil::MakeShape(F32, {}),
639 builder_cond.AddInstruction(HloInstruction::CreateSlice(
640 ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1},
641 {1, 1})))),
642 builder_cond.AddInstruction(HloInstruction::CreateReshape(
643 ShapeUtil::MakeShape(F32, {}),
644 builder_cond.AddInstruction(HloInstruction::CreateSlice(
645 ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
646 {1, 1})))),
647 ComparisonDirection::kGt));
648 auto cond = module->AddEmbeddedComputation(builder_cond.Build());
649
650 auto builder_body = HloComputation::Builder("body");
651 auto body_param = builder_body.AddInstruction(
652 HloInstruction::CreateParameter(0, shape, "body_param"));
653 auto body_transpose = builder_body.AddInstruction(
654 HloInstruction::CreateTranspose(shape, body_param, {0, 1}));
655
656 auto builder_f = HloComputation::Builder("fusion");
657 HloInstruction* a_f =
658 builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
659 builder_f.AddInstruction(HloInstruction::CreateTranspose(shape, a_f, {0, 1}));
660 auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
661 auto body_fusion = builder_body.AddInstruction(HloInstruction::CreateFusion(
662 shape, HloInstruction::FusionKind::kCustom, {body_transpose}, comp_f));
663 auto body = module->AddEmbeddedComputation(builder_body.Build());
664
665 auto while_hlo = builder.AddInstruction(
666 HloInstruction::CreateWhile(shape, cond, body, add));
667
668 auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
669 auto computation = module->AddEntryComputation(builder.Build());
670
671 EXPECT_FALSE(PropagatePrecision(module.get()));
672 EXPECT_EQ(computation->root_instruction(), dot);
673 EXPECT_FALSE(OutputsBF16(add));
674 EXPECT_FALSE(OutputsBF16(body_fusion));
675 EXPECT_FALSE(OutputsBF16(body_param));
676 EXPECT_FALSE(OutputsBF16(body_transpose));
677 EXPECT_FALSE(OutputsBF16(a_f));
678 }
679
680 // Tests that BF16 is propagated properly through while computations with
681 // tuple-shaped input/output.
TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  const Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4});

  // Loop-carried state: an (lhs, rhs) tuple built from two independent adds.
  HloInstruction* p0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "param0"));
  HloInstruction* p1 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "param1"));
  HloInstruction* init_lhs = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, p0, p1));
  HloInstruction* init_rhs = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, p0, p1));
  HloInstruction* init = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({init_lhs, init_rhs}));
  const Shape loop_shape = init->shape();

  // Condition: compares two elements of dot(lhs, rhs + rhs). The extra add on
  // the rhs element is intended to keep that element from using BF16.
  HloComputation::Builder cond_builder("cond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_shape, "cond_param"));
  HloInstruction* cond_lhs = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, cond_param, 0));
  HloInstruction* cond_rhs = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, cond_param, 1));
  HloInstruction* cond_rhs_sum = cond_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, cond_rhs,
                                   cond_rhs));
  HloInstruction* cond_dot = cond_builder.AddInstruction(
      CreateDot(f32_shape, cond_lhs, cond_rhs_sum));
  const Shape f32_1x1 = ShapeUtil::MakeShape(F32, {1, 1});
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* slice_00 = cond_builder.AddInstruction(
      HloInstruction::CreateSlice(f32_1x1, cond_dot, {0, 0}, {1, 1}, {1, 1}));
  HloInstruction* scalar_00 = cond_builder.AddInstruction(
      HloInstruction::CreateReshape(f32_scalar, slice_00));
  HloInstruction* slice_11 = cond_builder.AddInstruction(
      HloInstruction::CreateSlice(f32_1x1, cond_dot, {1, 1}, {2, 2}, {1, 1}));
  HloInstruction* scalar_11 = cond_builder.AddInstruction(
      HloInstruction::CreateReshape(f32_scalar, slice_11));
  // Kept as the last instruction added: Build() is called without an
  // explicit root instruction.
  cond_builder.AddInstruction(HloInstruction::CreateCompare(
      ShapeUtil::MakeShape(PRED, {}), scalar_00, scalar_11,
      ComparisonDirection::kGt));
  HloComputation* cond = module->AddEmbeddedComputation(cond_builder.Build());

  // Body: (lhs, rhs) -> (dot(lhs, rhs), transpose(dot(rhs, lhs))).
  HloComputation::Builder body_builder("body");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_shape, "body_param"));
  HloInstruction* body_lhs = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, body_param, 0));
  HloInstruction* body_rhs = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, body_param, 1));
  HloInstruction* body_dot1 =
      body_builder.AddInstruction(CreateDot(f32_shape, body_lhs, body_rhs));
  HloInstruction* body_dot2 =
      body_builder.AddInstruction(CreateDot(f32_shape, body_rhs, body_lhs));
  HloInstruction* body_transpose = body_builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_shape, body_dot2, {0, 1}));
  body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_dot1, body_transpose}));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  // Entry: run the loop, then dot its two outputs together.
  HloInstruction* loop = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_shape, cond, body, init));
  HloInstruction* lhs = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, loop, 0));
  HloInstruction* rhs = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, loop, 1));
  HloInstruction* dot =
      entry_builder.AddInstruction(CreateDot(f32_shape, lhs, rhs));
  HloComputation* computation =
      module->AddEntryComputation(entry_builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));
  EXPECT_EQ(computation->root_instruction(), dot);

  // Tuple element 0 is expected to use BF16 in every computation...
  for (const HloInstruction* hlo :
       {lhs, body_dot1, body_lhs, cond_lhs, init_lhs}) {
    EXPECT_TRUE(OutputsBF16(hlo)) << hlo->name();
  }
  // ...while element 1 is expected to stay in F32 everywhere, per the extra
  // add in the condition.
  for (const HloInstruction* hlo :
       {rhs, body_rhs, body_dot2, body_transpose, cond_rhs, init_rhs}) {
    EXPECT_FALSE(OutputsBF16(hlo)) << hlo->name();
  }
}
767
768 // Tests that BF16 is not propagated through multiple whiles that invoke the
769 // same computation as long as one while prevents the propagation.
TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  const Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4});

  // Appends the scalar predicate `dot[0, 0] > dot[1, 1]` to `b`. It must be
  // the last instruction added to a condition builder, because Build() is
  // called without an explicit root instruction.
  auto append_predicate = [](HloComputation::Builder* b, HloInstruction* dot) {
    const Shape f32_1x1 = ShapeUtil::MakeShape(F32, {1, 1});
    const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
    HloInstruction* slice_00 = b->AddInstruction(
        HloInstruction::CreateSlice(f32_1x1, dot, {0, 0}, {1, 1}, {1, 1}));
    HloInstruction* scalar_00 = b->AddInstruction(
        HloInstruction::CreateReshape(f32_scalar, slice_00));
    HloInstruction* slice_11 = b->AddInstruction(
        HloInstruction::CreateSlice(f32_1x1, dot, {1, 1}, {2, 2}, {1, 1}));
    HloInstruction* scalar_11 = b->AddInstruction(
        HloInstruction::CreateReshape(f32_scalar, slice_11));
    b->AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {}), scalar_00, scalar_11,
        ComparisonDirection::kGt));
  };

  // Two independent (lhs, rhs) init tuples, one per while loop.
  HloInstruction* p0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "param0"));
  HloInstruction* p1 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "param1"));
  HloInstruction* init0_lhs = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, p0, p1));
  HloInstruction* init0_rhs = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, p0, p1));
  HloInstruction* init1_lhs = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, p0, p1));
  HloInstruction* init1_rhs = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, p0, p1));
  HloInstruction* init0 = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({init0_lhs, init0_rhs}));
  HloInstruction* init1 = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({init1_lhs, init1_rhs}));

  // First condition: the extra add on element 1 is intended to keep the RHS
  // from using BF16.
  HloComputation::Builder cond0_builder("cond0");
  HloInstruction* cond0_param = cond0_builder.AddInstruction(
      HloInstruction::CreateParameter(0, init0->shape(), "cond0_param"));
  HloInstruction* cond0_lhs = cond0_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, cond0_param, 0));
  HloInstruction* cond0_rhs = cond0_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, cond0_param, 1));
  HloInstruction* cond0_add_rhs = cond0_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, cond0_rhs,
                                   cond0_rhs));
  HloInstruction* cond0_dot = cond0_builder.AddInstruction(
      CreateDot(f32_shape, cond0_lhs, cond0_add_rhs));
  append_predicate(&cond0_builder, cond0_dot);
  HloComputation* cond0 =
      module->AddEmbeddedComputation(cond0_builder.Build());

  // Second condition: the extra add on element 0 is intended to keep the LHS
  // from using BF16.
  HloComputation::Builder cond1_builder("cond1");
  HloInstruction* cond1_param = cond1_builder.AddInstruction(
      HloInstruction::CreateParameter(0, init1->shape(), "cond1_param"));
  HloInstruction* cond1_lhs = cond1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, cond1_param, 0));
  HloInstruction* cond1_rhs = cond1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, cond1_param, 1));
  HloInstruction* cond1_add_lhs = cond1_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, cond1_lhs,
                                   cond1_lhs));
  HloInstruction* cond1_dot = cond1_builder.AddInstruction(
      CreateDot(f32_shape, cond1_add_lhs, cond1_rhs));
  append_predicate(&cond1_builder, cond1_dot);
  HloComputation* cond1 =
      module->AddEmbeddedComputation(cond1_builder.Build());

  // Body shared by both whiles: (lhs, rhs) -> (dot(lhs, rhs), rhs).
  HloComputation::Builder body_builder("body");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, init0->shape(), "body_param"));
  HloInstruction* body_lhs = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, body_param, 0));
  HloInstruction* body_rhs = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, body_param, 1));
  HloInstruction* body_dot =
      body_builder.AddInstruction(CreateDot(f32_shape, body_lhs, body_rhs));
  body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_dot, body_rhs}));
  HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());

  HloInstruction* while0 = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(init0->shape(), cond0, body, init0));
  HloInstruction* while1 = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(init1->shape(), cond1, body, init1));

  // Collapse each while's outputs with a dot, then dot the two results.
  HloInstruction* while0_elem0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, while0, 0));
  HloInstruction* while0_elem1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, while0, 1));
  HloInstruction* lhs = entry_builder.AddInstruction(
      CreateDot(f32_shape, while0_elem0, while0_elem1));
  HloInstruction* while1_elem0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, while1, 0));
  HloInstruction* while1_elem1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, while1, 1));
  HloInstruction* rhs = entry_builder.AddInstruction(
      CreateDot(f32_shape, while1_elem0, while1_elem1));
  HloInstruction* dot =
      entry_builder.AddInstruction(CreateDot(f32_shape, lhs, rhs));
  HloComputation* computation =
      module->AddEntryComputation(entry_builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));
  EXPECT_EQ(computation->root_instruction(), dot);

  // Each condition blocks a different element and the body is shared, so
  // neither element is expected to use BF16 in the body or either condition.
  for (const HloInstruction* hlo : {body_dot, body_rhs, body_lhs, cond0_lhs,
                                    cond0_rhs, cond1_lhs, cond1_rhs}) {
    EXPECT_FALSE(OutputsBF16(hlo)) << hlo->name();
  }
  // The extra adds themselves only feed dots and are expected to use BF16.
  for (const HloInstruction* hlo : {cond0_add_rhs, cond1_add_lhs}) {
    EXPECT_TRUE(OutputsBF16(hlo)) << hlo->name();
  }
}
896
897 // Tests that if this pass turns an F32 -> BF16 conversion into a no-op (BF16 ->
898 // BF16 conversion), then it will remove that conversion.
TEST_F(BFloat16PropagationTest, NoopConversionRemoved) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());
  const Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4});
  const Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4});

  // Two F32 adds are packed into a tuple, unpacked, explicitly converted to
  // BF16, and then summed in BF16.
  HloInstruction* input = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "param"));
  HloInstruction* sum0 = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, input, input));
  HloInstruction* sum1 = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, input, input));
  HloInstruction* packed =
      hlo_builder.AddInstruction(HloInstruction::CreateTuple({sum0, sum1}));
  HloInstruction* unpacked0 = hlo_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, packed, 0));
  HloInstruction* unpacked1 = hlo_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, packed, 1));
  HloInstruction* narrowed0 = hlo_builder.AddInstruction(
      HloInstruction::CreateConvert(bf16_shape, unpacked0));
  HloInstruction* narrowed1 = hlo_builder.AddInstruction(
      HloInstruction::CreateConvert(bf16_shape, unpacked1));
  HloInstruction* result =
      hlo_builder.AddInstruction(HloInstruction::CreateBinary(
          bf16_shape, HloOpcode::kAdd, narrowed0, narrowed1));

  HloComputation* computation =
      module->AddEntryComputation(hlo_builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));

  // The adds are expected to produce BF16 themselves. That turns the
  // converts into no-ops that get removed, so `result` reads the adds
  // directly.
  EXPECT_EQ(computation->root_instruction(), result);
  EXPECT_EQ(result->operand(0), sum0);
  EXPECT_EQ(result->operand(1), sum1);
  EXPECT_EQ(sum0->shape().element_type(), BF16);
  EXPECT_EQ(sum1->shape().element_type(), BF16);
}
934
TEST_F(BFloat16PropagationTest, TupleDomain) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());
  const Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4});

  // Transposed parameters travel through a tuple -> domain -> GTE chain into a
  // dot, whose result is added to itself.
  HloInstruction* a_param = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b_param = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* a_trans = hlo_builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_shape, a_param, {0, 1}));
  HloInstruction* b_trans = hlo_builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_shape, b_param, {0, 1}));
  HloInstruction* pair = hlo_builder.AddInstruction(
      HloInstruction::CreateTuple({a_trans, b_trans}));
  HloInstruction* domain = hlo_builder.AddInstruction(
      HloInstruction::CreateDomain(pair->shape(), pair, nullptr, nullptr));
  HloInstruction* a_gte = hlo_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, domain, 0));
  HloInstruction* b_gte = hlo_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, domain, 1));
  HloInstruction* dot =
      hlo_builder.AddInstruction(CreateDot(f32_shape, a_gte, b_gte));
  HloInstruction* root = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, dot, dot));

  HloComputation* computation =
      module->AddEntryComputation(hlo_builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));
  EXPECT_EQ(computation->root_instruction(), root);

  // BF16 is expected to flow through the domain: both elements of its tuple
  // output change type.
  for (int64 index = 0; index < 2; ++index) {
    EXPECT_EQ(
        ShapeUtil::GetTupleElementShape(domain->shape(), index).element_type(),
        BF16)
        << "domain tuple element " << index;
  }
  for (const HloInstruction* hlo : {a_trans, b_trans, a_gte, b_gte}) {
    EXPECT_TRUE(OutputsBF16(hlo)) << hlo->name();
  }
  // The entry parameters themselves are not expected to change.
  for (const HloInstruction* hlo : {a_param, b_param}) {
    EXPECT_FALSE(OutputsBF16(hlo)) << hlo->name();
  }
}
978
// Tests that BF16 is not propagated through a domain if its input cannot be
// propagated. In the case below, the input of the domain is the parameter
// tuple, which cannot be propagated, so the domain instruction is not
// propagated either.
TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder hlo_builder(TestName());
  const Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({f32_shape, f32_shape});

  // A tuple-shaped parameter goes straight into the domain; its elements are
  // extracted, transposed, and fed to a dot whose result is added to itself.
  HloInstruction* param = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  HloInstruction* domain = hlo_builder.AddInstruction(
      HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr));
  HloInstruction* a_gte = hlo_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, domain, 0));
  HloInstruction* b_gte = hlo_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, domain, 1));
  HloInstruction* a_trans = hlo_builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_shape, a_gte, {0, 1}));
  HloInstruction* b_trans = hlo_builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_shape, b_gte, {0, 1}));
  HloInstruction* dot =
      hlo_builder.AddInstruction(CreateDot(f32_shape, a_trans, b_trans));
  HloInstruction* root = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, dot, dot));

  HloComputation* computation =
      module->AddEntryComputation(hlo_builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));
  EXPECT_EQ(computation->root_instruction(), root);

  // The transposes feed the dot and are expected to use BF16...
  for (const HloInstruction* hlo : {a_trans, b_trans}) {
    EXPECT_TRUE(OutputsBF16(hlo)) << hlo->name();
  }
  // ...but the parameter tuple cannot change, so neither the domain nor the
  // GTEs reading through it are expected to use BF16.
  for (const HloInstruction* hlo : {a_gte, b_gte, domain, param}) {
    EXPECT_FALSE(OutputsBF16(hlo)) << hlo->name();
  }
}
1018
1019 } // namespace xla
1020