1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
17 #include "tensorflow/compiler/xla/service/bfloat16_support.h"
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/test_helpers.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 
29 namespace xla {
30 
31 // A class specifying the BF16 support used to test the propagation pass. It
32 // specifies that BF16 and mixed precision are supported in all HloInstructions,
33 // and that kDot reduces its operands precision to BF16.
34 class TestBFloat16Support : public BFloat16Support {
35  public:
TestBFloat16Support()36   TestBFloat16Support() {}
~TestBFloat16Support()37   ~TestBFloat16Support() override {}
38 
SupportsBF16Operand(const HloInstruction & hlo,int64 operand_index) const39   bool SupportsBF16Operand(const HloInstruction& hlo,
40                            int64 operand_index) const override {
41     return true;
42   }
43 
SupportsBF16Output(const HloInstruction & hlo) const44   bool SupportsBF16Output(const HloInstruction& hlo) const override {
45     return true;
46   }
47 
SupportsMixedPrecisions(const HloInstruction & hlo) const48   bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
49     return true;
50   }
51 
EffectiveOperandPrecisionIsBF16(const HloInstruction & hlo,int64 operand_index) const52   bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo,
53                                        int64 operand_index) const override {
54     return hlo.opcode() == HloOpcode::kDot;
55   }
56 };
57 
58 class BFloat16PropagationTest : public HloTestBase {
59  protected:
BFloat16PropagationTest()60   BFloat16PropagationTest()
61       : HloTestBase(/*verifier_layout_sensitive=*/false,
62                     /*allow_mixed_precision_in_hlo_verifier=*/true) {}
63 
64   // Runs the propagation pass on the given module, and returns whether the
65   // module is changed after this pass.
PropagatePrecision(HloModule * module)66   bool PropagatePrecision(HloModule* module) {
67     TestBFloat16Support bfloat16_support;
68     BFloat16Propagation propagation(&bfloat16_support);
69     StatusOr<bool> result = propagation.Run(module);
70     EXPECT_IS_OK(result.status());
71     return result.ValueOrDie();
72   }
73 
74   // Returns whether the given HloInstruction's output element type is BF16 or
75   // the only use of it is converting to BF16.
OutputsBF16(const HloInstruction * inst)76   bool OutputsBF16(const HloInstruction* inst) {
77     if (inst->shape().element_type() == BF16) {
78       return true;
79     }
80     return inst->user_count() == 1 &&
81            inst->users()[0]->opcode() == HloOpcode::kConvert &&
82            inst->users()[0]->shape().element_type() == BF16;
83   }
84 
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)85   std::unique_ptr<HloInstruction> CreateDot(const Shape& shape,
86                                             HloInstruction* lhs,
87                                             HloInstruction* rhs) {
88     DotDimensionNumbers dot_dnums;
89     dot_dnums.add_lhs_contracting_dimensions(1);
90     dot_dnums.add_rhs_contracting_dimensions(0);
91     return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
92                                      DefaultPrecisionConfig(2));
93   }
94 };
95 
96 // Tests that BF16 can propagate through select over non-tuple buffers, but not
97 // through add where reducing operand precision can affect the result.
TEST_F(BFloat16PropagationTest,PropagateThroughSelectButNotAdd)98 TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
99   auto builder = HloComputation::Builder(TestName());
100   Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
101 
102   HloInstruction* a =
103       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
104   HloInstruction* b =
105       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
106   HloInstruction* c =
107       builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c"));
108   HloInstruction* add0 = builder.AddInstruction(
109       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
110   HloInstruction* add1 = builder.AddInstruction(
111       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
112   HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateCompare(
113       ShapeUtil::MakeShape(PRED, {2, 4}), a, b, ComparisonDirection::kEq));
114   HloInstruction* sel = builder.AddInstruction(
115       HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
116   HloInstruction* xpose =
117       builder.AddInstruction(HloInstruction::CreateTranspose(
118           ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0}));
119   HloInstruction* dot = builder.AddInstruction(
120       CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a));
121   HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
122       ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot));
123 
124   auto module = CreateNewVerifiedModule();
125   auto computation = module->AddEntryComputation(builder.Build());
126 
127   EXPECT_TRUE(PropagatePrecision(module.get()));
128 
129   EXPECT_EQ(computation->root_instruction(), root);
130   EXPECT_TRUE(OutputsBF16(xpose));
131   EXPECT_TRUE(OutputsBF16(sel));
132   EXPECT_TRUE(OutputsBF16(add1));
133   EXPECT_FALSE(OutputsBF16(add0));
134   EXPECT_FALSE(OutputsBF16(a));
135   EXPECT_FALSE(OutputsBF16(b));
136   EXPECT_FALSE(OutputsBF16(c));
137 }
138 
// Feeds add -> reduce-window(max) into the rhs of a dot and a transposed
// parameter into its lhs, then checks that the add, the reduce-window and the
// transpose all end up reported as BF16.
TEST_F(BFloat16PropagationTest, PropagateThroughMaxPoolReduceWindow) {
  auto module = CreateNewVerifiedModule();

  // Scalar maximum(a, b) computation applied by the reduce-window.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  auto max_builder = HloComputation::Builder("max");
  HloInstruction* max_lhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "a"));
  HloInstruction* max_rhs = max_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "b"));
  max_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape, HloOpcode::kMaximum, max_lhs, max_rhs));
  auto max_computation = module->AddEmbeddedComputation(max_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
  const Shape dot_shape = ShapeUtil::MakeShape(F32, {4, 4});

  HloInstruction* a =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
  HloInstruction* b =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
  HloInstruction* c =
      builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c"));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));

  // Both window dimensions use the same settings: size 2, stride 1, high
  // padding 1, window and base dilation 1.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_high(1);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(1);
  Window window;
  for (int i = 0; i < 2; ++i) {
    *window.add_dimensions() = window_dim;
  }

  HloInstruction* init_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  HloInstruction* reduce_window =
      builder.AddInstruction(HloInstruction::CreateReduceWindow(
          shape, sum, init_value, window, max_computation));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {4, 2}), c, {1, 0}));
  HloInstruction* dot = builder.AddInstruction(
      CreateDot(dot_shape, transpose, reduce_window));
  // Added last, so this becomes the computation root.
  HloInstruction* root = builder.AddInstruction(
      HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, dot));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));

  EXPECT_EQ(computation->root_instruction(), root);
  EXPECT_TRUE(OutputsBF16(sum));
  EXPECT_TRUE(OutputsBF16(transpose));
  EXPECT_TRUE(OutputsBF16(reduce_window));
}
194 
195 // Tests that side-effecting all-reduce should not be changed.
TEST_F(BFloat16PropagationTest,DoNotChangeAllReduce)196 TEST_F(BFloat16PropagationTest, DoNotChangeAllReduce) {
197   auto module = CreateNewVerifiedModule();
198 
199   auto builder = HloComputation::Builder(TestName());
200   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
201   HloInstruction* a =
202       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
203   HloInstruction* b =
204       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
205   auto rb = HloComputation::Builder(TestName());
206   rb.AddInstruction(HloInstruction::CreateBinary(
207       shape, HloOpcode::kAdd,
208       rb.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")),
209       rb.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"))));
210   auto reduction = module->AddEmbeddedComputation(rb.Build());
211   HloInstruction* all_reduce =
212       builder.AddInstruction(HloInstruction::CreateAllReduce(
213           ShapeUtil::MakeTupleShape({shape, shape}), {a, b}, reduction,
214           /*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/1));
215   HloInstruction* gte0 = builder.AddInstruction(
216       HloInstruction::CreateGetTupleElement(shape, all_reduce, 0));
217   HloInstruction* gte1 = builder.AddInstruction(
218       HloInstruction::CreateGetTupleElement(shape, all_reduce, 1));
219   HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
220   HloInstruction* root = builder.AddInstruction(
221       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
222 
223   auto computation = module->AddEntryComputation(builder.Build());
224 
225   EXPECT_FALSE(PropagatePrecision(module.get()));
226   EXPECT_EQ(computation->root_instruction(), root);
227 }
228 
229 // Tests that if a constant is converted to BF16 then its literal must also be
230 // converted.
TEST_F(BFloat16PropagationTest,ConvertConstantLiteral)231 TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
232   auto builder = HloComputation::Builder(TestName());
233   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
234   Array2D<float> array_a(4, 4);
235   array_a.FillUnique(1.0f);
236   Array2D<float> array_b(4, 4);
237   array_b.FillUnique(10.0f);
238 
239   HloInstruction* a = builder.AddInstruction(
240       HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a)));
241   HloInstruction* b = builder.AddInstruction(
242       HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b)));
243   HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b));
244 
245   auto module = CreateNewVerifiedModule();
246   auto computation = module->AddEntryComputation(builder.Build());
247 
248   EXPECT_TRUE(PropagatePrecision(module.get()));
249 
250   EXPECT_EQ(computation->root_instruction(), dot);
251   EXPECT_TRUE(OutputsBF16(dot->operand(0)));
252   EXPECT_TRUE(OutputsBF16(dot->operand(1)));
253   EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
254   EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
255   EXPECT_TRUE(LiteralTestUtil::Equal(
256       LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
257       dot->operand(0)->literal()));
258   EXPECT_TRUE(LiteralTestUtil::Equal(
259       LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
260       dot->operand(1)->literal()));
261 }
262 
263 // Tests that BF16 can be propagated through nested tuples.
TEST_F(BFloat16PropagationTest,PropagateThroughTuples)264 TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
265   auto builder = HloComputation::Builder(TestName());
266   Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
267 
268   HloInstruction* a =
269       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
270   HloInstruction* b =
271       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
272   HloInstruction* add0 = builder.AddInstruction(
273       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
274   HloInstruction* add1 = builder.AddInstruction(
275       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a));
276   HloInstruction* add2 = builder.AddInstruction(
277       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, b));
278   HloInstruction* xpose =
279       builder.AddInstruction(HloInstruction::CreateTranspose(
280           ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0}));
281 
282   HloInstruction* tuple0 =
283       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1, add2}));
284   HloInstruction* tuple1 =
285       builder.AddInstruction(HloInstruction::CreateTuple({tuple0, xpose}));
286 
287   HloInstruction* lhs = builder.AddInstruction(
288       HloInstruction::CreateGetTupleElement(xpose->shape(), tuple1, 1));
289   HloInstruction* rhs =
290       builder.AddInstruction(HloInstruction::CreateGetTupleElement(
291           add0->shape(),
292           builder.AddInstruction(HloInstruction::CreateGetTupleElement(
293               tuple0->shape(), tuple1, 0)),
294           0));
295   HloInstruction* dot = builder.AddInstruction(
296       CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
297 
298   HloInstruction* output_tuple =
299       builder.AddInstruction(HloInstruction::CreateTuple({dot, add2}));
300 
301   auto module = CreateNewVerifiedModule();
302   auto computation = module->AddEntryComputation(builder.Build());
303 
304   EXPECT_TRUE(PropagatePrecision(module.get()));
305 
306   EXPECT_EQ(computation->root_instruction(), output_tuple);
307   EXPECT_TRUE(OutputsBF16(xpose));
308   EXPECT_TRUE(OutputsBF16(add0));
309   EXPECT_TRUE(OutputsBF16(add1));
310   EXPECT_FALSE(OutputsBF16(add2));
311 }
312 
313 // Tests that even if an instruction does not define a buffer in its output, its
314 // shape must match the defining instruction.
TEST_F(BFloat16PropagationTest,SameValueReferencedTwice)315 TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
316   auto builder = HloComputation::Builder(TestName());
317   Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
318 
319   HloInstruction* a =
320       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
321   HloInstruction* b =
322       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
323   HloInstruction* add0 = builder.AddInstruction(
324       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
325   HloInstruction* add1 = builder.AddInstruction(
326       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a));
327 
328   HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateTranspose(
329       ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0}));
330 
331   HloInstruction* tuple =
332       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
333   HloInstruction* rhs = builder.AddInstruction(
334       HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1));
335 
336   // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1.
337   HloInstruction* dot = builder.AddInstruction(
338       CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
339 
340   auto module = CreateNewVerifiedModule();
341   auto computation = module->AddEntryComputation(builder.Build());
342 
343   EXPECT_TRUE(PropagatePrecision(module.get()));
344 
345   EXPECT_EQ(computation->root_instruction(), dot);
346   EXPECT_TRUE(OutputsBF16(add1));
347   EXPECT_TRUE(OutputsBF16(lhs));
348 
349   // add0 and rhs have been eliminated by simplification and DCE.
350 }
351 
352 // Tests that a non-fusion computation's root should not be changed.
TEST_F(BFloat16PropagationTest,DoNotChangeComputationRoot)353 TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
354   auto builder = HloComputation::Builder(TestName());
355   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
356 
357   HloInstruction* a =
358       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
359   HloInstruction* b =
360       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
361   HloInstruction* add = builder.AddInstruction(
362       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
363 
364   HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add));
365 
366   HloInstruction* tuple =
367       builder.AddInstruction(HloInstruction::CreateTuple({add, dot}));
368 
369   auto module = CreateNewVerifiedModule();
370   auto computation = module->AddEntryComputation(builder.Build());
371 
372   EXPECT_FALSE(PropagatePrecision(module.get()));
373 
374   EXPECT_EQ(computation->root_instruction(), tuple);
375   EXPECT_FALSE(OutputsBF16(add));
376 }
377 
378 // Tests that BF16 is propagated properly through fused computations.
TEST_F(BFloat16PropagationTest,PropagateThroughFusion)379 TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
380   auto module = CreateNewVerifiedModule();
381   auto builder = HloComputation::Builder(TestName());
382   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
383 
384   HloInstruction* param = builder.AddInstruction(
385       HloInstruction::CreateParameter(0, shape, "param"));
386   HloInstruction* add = builder.AddInstruction(
387       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
388 
389   auto builder_f0 = HloComputation::Builder("fusion0");
390   HloInstruction* a_f0 =
391       builder_f0.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
392   HloInstruction* b_f0 =
393       builder_f0.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
394   HloInstruction* tuple_f0 =
395       builder_f0.AddInstruction(HloInstruction::CreateTuple({a_f0, b_f0}));
396   auto comp_f0 = module->AddEmbeddedComputation(builder_f0.Build());
397   auto fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
398       tuple_f0->shape(), HloInstruction::FusionKind::kCustom, {add, add},
399       comp_f0));
400 
401   auto builder_f1 = HloComputation::Builder("fusion1");
402   HloInstruction* p_f1 = builder_f1.AddInstruction(
403       HloInstruction::CreateParameter(0, tuple_f0->shape(), "param"));
404   HloInstruction* a_f1 = builder_f1.AddInstruction(
405       HloInstruction::CreateGetTupleElement(shape, p_f1, 0));
406   HloInstruction* b_f1 = builder_f1.AddInstruction(
407       HloInstruction::CreateGetTupleElement(shape, p_f1, 1));
408   HloInstruction* dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1));
409   auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build());
410   auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion(
411       dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1));
412 
413   auto computation = module->AddEntryComputation(builder.Build());
414 
415   EXPECT_TRUE(PropagatePrecision(module.get()));
416 
417   EXPECT_EQ(computation->root_instruction(), fusion1);
418   EXPECT_TRUE(OutputsBF16(add));
419   EXPECT_TRUE(OutputsBF16(a_f0));
420   EXPECT_TRUE(OutputsBF16(b_f0));
421   EXPECT_TRUE(OutputsBF16(a_f1));
422   EXPECT_TRUE(OutputsBF16(b_f1));
423 }
424 
425 // Tests that changes to BF16 that cannot be propagated outside a fusion are
426 // discarded.
TEST_F(BFloat16PropagationTest,DiscardFusionInternalBF16Changes)427 TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
428   auto module = CreateNewVerifiedModule();
429   auto builder = HloComputation::Builder(TestName());
430   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
431 
432   HloInstruction* param = builder.AddInstruction(
433       HloInstruction::CreateParameter(0, shape, "param"));
434   HloInstruction* add = builder.AddInstruction(
435       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
436 
437   auto builder_f = HloComputation::Builder("fusion");
438   HloInstruction* a_f =
439       builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
440   HloInstruction* b_f =
441       builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
442   HloInstruction* add_f = builder_f.AddInstruction(
443       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
444   HloInstruction* dot_f =
445       builder_f.AddInstruction(CreateDot(shape, add_f, add_f));
446   auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
447   auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
448       dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f));
449 
450   auto computation = module->AddEntryComputation(builder.Build());
451 
452   EXPECT_FALSE(PropagatePrecision(module.get()));
453   EXPECT_EQ(computation->root_instruction(), fusion);
454 }
455 
456 // Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion
457 // outputs are only used by a dot, and 3) one element of the tuple is used by
458 // an add in the fusion computation, then the propagation pass should create a
459 // convert in the fusion computation to keep the add's operand in F32 but change
460 // the fusion output to BF16. E.g., the following fusion computation
461 //   (F32, F32) fusion_computation(F32 a, F32 b)
462 //     = tuple(F32 a, F32 add(F32 a, F32 b))
463 // will be changed to
464 //   (BF16, BF16) fusion_computation(F32 a, F32 b)
465 //     = tuple(BF16 convert(a), BF16 add(F32 a, F32 b))
TEST_F(BFloat16PropagationTest,ConvertTupleFusionElementIfUsedByAdd)466 TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) {
467   auto module = CreateNewVerifiedModule();
468   auto builder = HloComputation::Builder(TestName());
469   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
470 
471   HloInstruction* param = builder.AddInstruction(
472       HloInstruction::CreateParameter(0, shape, "param"));
473   HloInstruction* add = builder.AddInstruction(
474       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
475 
476   auto builder_f = HloComputation::Builder("fusion0");
477   HloInstruction* a_f =
478       builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
479   HloInstruction* b_f =
480       builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
481   HloInstruction* add_f = builder_f.AddInstruction(
482       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
483   HloInstruction* tuple_f =
484       builder_f.AddInstruction(HloInstruction::CreateTuple({a_f, add_f}));
485   auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
486   auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
487       tuple_f->shape(), HloInstruction::FusionKind::kCustom, {add, add},
488       comp_f));
489 
490   HloInstruction* gte0 = builder.AddInstruction(
491       HloInstruction::CreateGetTupleElement(shape, fusion, 0));
492   HloInstruction* gte1 = builder.AddInstruction(
493       HloInstruction::CreateGetTupleElement(shape, fusion, 1));
494   HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
495 
496   auto computation = module->AddEntryComputation(builder.Build());
497 
498   EXPECT_TRUE(PropagatePrecision(module.get()));
499 
500   EXPECT_EQ(computation->root_instruction(), dot);
501   EXPECT_TRUE(OutputsBF16(gte0));
502   EXPECT_TRUE(OutputsBF16(gte1));
503   EXPECT_FALSE(OutputsBF16(a_f));
504   EXPECT_FALSE(OutputsBF16(b_f));
505   EXPECT_TRUE(OutputsBF16(add_f));
506   auto new_fusion_root = comp_f->root_instruction();
507   EXPECT_EQ(new_fusion_root->opcode(), HloOpcode::kTuple);
508   EXPECT_EQ(new_fusion_root->operand(1), add_f);
509   EXPECT_EQ(new_fusion_root->operand(0)->opcode(), HloOpcode::kConvert);
510   EXPECT_TRUE(OutputsBF16(new_fusion_root->operand(0)));
511 }
512 
513 // A select over tuples does not define the leaf buffers, so the types in
514 // on_true and on_false must match, so that as long as one of them is F32, the
515 // other must be F32 as well.
TEST_F(BFloat16PropagationTest,SelectOverTuples)516 TEST_F(BFloat16PropagationTest, SelectOverTuples) {
517   auto module = CreateNewVerifiedModule();
518   auto builder = HloComputation::Builder(TestName());
519   Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
520 
521   HloInstruction* param = builder.AddInstruction(
522       HloInstruction::CreateParameter(0, shape, "param"));
523   HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateParameter(
524       1, ShapeUtil::MakeShape(PRED, {}), "pred"));
525 
526   HloInstruction* add0 = builder.AddInstruction(
527       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
528   HloInstruction* add1 = builder.AddInstruction(
529       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, param));
530   HloInstruction* tuple0 =
531       builder.AddInstruction(HloInstruction::CreateTuple({param, add0}));
532   HloInstruction* tuple1 =
533       builder.AddInstruction(HloInstruction::CreateTuple({param, add1}));
534   HloInstruction* sel = builder.AddInstruction(HloInstruction::CreateTernary(
535       tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));
536   HloInstruction* gte0 = builder.AddInstruction(
537       HloInstruction::CreateGetTupleElement(shape, sel, 0));
538   HloInstruction* gte1 = builder.AddInstruction(
539       HloInstruction::CreateGetTupleElement(shape, sel, 1));
540   HloInstruction* xpose =
541       builder.AddInstruction(HloInstruction::CreateTranspose(
542           ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0}));
543   HloInstruction* dot = builder.AddInstruction(
544       CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1));
545 
546   auto computation = module->AddEntryComputation(builder.Build());
547 
548   EXPECT_TRUE(PropagatePrecision(module.get()));
549 
550   EXPECT_EQ(computation->root_instruction(), dot);
551   EXPECT_FALSE(OutputsBF16(add0));
552   EXPECT_FALSE(OutputsBF16(add1));
553   EXPECT_FALSE(OutputsBF16(gte0));
554   EXPECT_FALSE(OutputsBF16(gte1));
555   EXPECT_TRUE(OutputsBF16(xpose));
556 }
557 
558 // Tests that BF16 is propagated properly through a while computation with
559 // non-tuple input/output.
TEST_F(BFloat16PropagationTest,PropagateThroughSimpleWhile)560 TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
561   auto module = CreateNewVerifiedModule();
562   auto builder = HloComputation::Builder(TestName());
563   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
564 
565   HloInstruction* param0 = builder.AddInstruction(
566       HloInstruction::CreateParameter(0, shape, "param0"));
567   HloInstruction* param1 = builder.AddInstruction(
568       HloInstruction::CreateParameter(1, shape, "param1"));
569   HloInstruction* add = builder.AddInstruction(
570       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
571 
572   auto builder_cond = HloComputation::Builder("cond");
573   auto cond_param = builder_cond.AddInstruction(
574       HloInstruction::CreateParameter(0, shape, "cond_param"));
575   auto cond_dot =
576       builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
577   auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateCompare(
578       ShapeUtil::MakeShape(PRED, {}),
579       builder_cond.AddInstruction(HloInstruction::CreateReshape(
580           ShapeUtil::MakeShape(F32, {}),
581           builder_cond.AddInstruction(
582               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
583                                           cond_dot, {0, 0}, {1, 1}, {1, 1})))),
584       builder_cond.AddInstruction(HloInstruction::CreateReshape(
585           ShapeUtil::MakeShape(F32, {}),
586           builder_cond.AddInstruction(
587               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
588                                           cond_dot, {1, 1}, {2, 2}, {1, 1})))),
589       ComparisonDirection::kGt));
590   auto cond = module->AddEmbeddedComputation(builder_cond.Build());
591 
592   auto builder_body = HloComputation::Builder("body");
593   auto body_param = builder_body.AddInstruction(
594       HloInstruction::CreateParameter(0, shape, "body_param"));
595   auto body_dot =
596       builder_body.AddInstruction(CreateDot(shape, body_param, body_param));
597   auto body = module->AddEmbeddedComputation(builder_body.Build());
598 
599   auto while_hlo = builder.AddInstruction(
600       HloInstruction::CreateWhile(shape, cond, body, add));
601 
602   auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
603   auto computation = module->AddEntryComputation(builder.Build());
604 
605   EXPECT_TRUE(PropagatePrecision(module.get()));
606 
607   EXPECT_EQ(computation->root_instruction(), dot);
608   EXPECT_TRUE(
609       ShapeUtil::Equal(cond_root->shape(), ShapeUtil::MakeShape(PRED, {})));
610   EXPECT_TRUE(OutputsBF16(add));
611   EXPECT_TRUE(OutputsBF16(body_dot));
612   EXPECT_TRUE(OutputsBF16(body_param));
613   EXPECT_TRUE(OutputsBF16(cond_param));
614   EXPECT_FALSE(OutputsBF16(dot));
615 }
616 
617 // Tests that if the while condition prevents using BF16, no changes should be
618 // made to the while body and thus the fusion node inside it.
TEST_F(BFloat16PropagationTest,ConditionPreventsPropagationForFusionInsideWhile)619 TEST_F(BFloat16PropagationTest,
620        ConditionPreventsPropagationForFusionInsideWhile) {
621   auto module = CreateNewVerifiedModule();
622   auto builder = HloComputation::Builder(TestName());
623   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
624 
625   HloInstruction* param0 = builder.AddInstruction(
626       HloInstruction::CreateParameter(0, shape, "param0"));
627   HloInstruction* param1 = builder.AddInstruction(
628       HloInstruction::CreateParameter(1, shape, "param1"));
629   HloInstruction* add = builder.AddInstruction(
630       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
631 
632   auto builder_cond = HloComputation::Builder("cond");
633   auto cond_param = builder_cond.AddInstruction(
634       HloInstruction::CreateParameter(0, shape, "cond_param"));
635   builder_cond.AddInstruction(HloInstruction::CreateCompare(
636       ShapeUtil::MakeShape(PRED, {}),
637       builder_cond.AddInstruction(HloInstruction::CreateReshape(
638           ShapeUtil::MakeShape(F32, {}),
639           builder_cond.AddInstruction(HloInstruction::CreateSlice(
640               ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1},
641               {1, 1})))),
642       builder_cond.AddInstruction(HloInstruction::CreateReshape(
643           ShapeUtil::MakeShape(F32, {}),
644           builder_cond.AddInstruction(HloInstruction::CreateSlice(
645               ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
646               {1, 1})))),
647       ComparisonDirection::kGt));
648   auto cond = module->AddEmbeddedComputation(builder_cond.Build());
649 
650   auto builder_body = HloComputation::Builder("body");
651   auto body_param = builder_body.AddInstruction(
652       HloInstruction::CreateParameter(0, shape, "body_param"));
653   auto body_transpose = builder_body.AddInstruction(
654       HloInstruction::CreateTranspose(shape, body_param, {0, 1}));
655 
656   auto builder_f = HloComputation::Builder("fusion");
657   HloInstruction* a_f =
658       builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
659   builder_f.AddInstruction(HloInstruction::CreateTranspose(shape, a_f, {0, 1}));
660   auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
661   auto body_fusion = builder_body.AddInstruction(HloInstruction::CreateFusion(
662       shape, HloInstruction::FusionKind::kCustom, {body_transpose}, comp_f));
663   auto body = module->AddEmbeddedComputation(builder_body.Build());
664 
665   auto while_hlo = builder.AddInstruction(
666       HloInstruction::CreateWhile(shape, cond, body, add));
667 
668   auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
669   auto computation = module->AddEntryComputation(builder.Build());
670 
671   EXPECT_FALSE(PropagatePrecision(module.get()));
672   EXPECT_EQ(computation->root_instruction(), dot);
673   EXPECT_FALSE(OutputsBF16(add));
674   EXPECT_FALSE(OutputsBF16(body_fusion));
675   EXPECT_FALSE(OutputsBF16(body_param));
676   EXPECT_FALSE(OutputsBF16(body_transpose));
677   EXPECT_FALSE(OutputsBF16(a_f));
678 }
679 
680 // Tests that BF16 is propagated properly through while computations with
681 // tuple-shaped input/output.
TEST_F(BFloat16PropagationTest,PropagateThroughTupleWhile)682 TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
683   auto module = CreateNewVerifiedModule();
684   auto builder = HloComputation::Builder(TestName());
685   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
686 
687   HloInstruction* param0 = builder.AddInstruction(
688       HloInstruction::CreateParameter(0, shape, "param0"));
689   HloInstruction* param1 = builder.AddInstruction(
690       HloInstruction::CreateParameter(1, shape, "param1"));
691   HloInstruction* add0 = builder.AddInstruction(
692       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
693   HloInstruction* add1 = builder.AddInstruction(
694       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
695   HloInstruction* tuple =
696       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
697 
698   auto builder_cond = HloComputation::Builder("cond");
699   auto cond_param = builder_cond.AddInstruction(
700       HloInstruction::CreateParameter(0, tuple->shape(), "cond_param"));
701   auto cond_lhs = builder_cond.AddInstruction(
702       HloInstruction::CreateGetTupleElement(shape, cond_param, 0));
703   auto cond_rhs = builder_cond.AddInstruction(
704       HloInstruction::CreateGetTupleElement(shape, cond_param, 1));
705   // This add should prevent RHS from using BF16
706   auto cond_add_rhs = builder_cond.AddInstruction(
707       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
708   auto cond_dot =
709       builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs));
710   builder_cond.AddInstruction(HloInstruction::CreateCompare(
711       ShapeUtil::MakeShape(PRED, {}),
712       builder_cond.AddInstruction(HloInstruction::CreateReshape(
713           ShapeUtil::MakeShape(F32, {}),
714           builder_cond.AddInstruction(
715               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
716                                           cond_dot, {0, 0}, {1, 1}, {1, 1})))),
717       builder_cond.AddInstruction(HloInstruction::CreateReshape(
718           ShapeUtil::MakeShape(F32, {}),
719           builder_cond.AddInstruction(
720               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
721                                           cond_dot, {1, 1}, {2, 2}, {1, 1})))),
722       ComparisonDirection::kGt));
723   auto cond = module->AddEmbeddedComputation(builder_cond.Build());
724 
725   auto builder_body = HloComputation::Builder("body");
726   auto body_param = builder_body.AddInstruction(
727       HloInstruction::CreateParameter(0, tuple->shape(), "body_param"));
728   auto body_lhs = builder_body.AddInstruction(
729       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
730   auto body_rhs = builder_body.AddInstruction(
731       HloInstruction::CreateGetTupleElement(shape, body_param, 1));
732   auto body_dot1 =
733       builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
734   auto body_dot2 =
735       builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs));
736   auto body_transpose = builder_body.AddInstruction(
737       HloInstruction::CreateTranspose(shape, body_dot2, {0, 1}));
738   builder_body.AddInstruction(
739       HloInstruction::CreateTuple({body_dot1, body_transpose}));
740   auto body = module->AddEmbeddedComputation(builder_body.Build());
741 
742   auto while_hlo = builder.AddInstruction(
743       HloInstruction::CreateWhile(tuple->shape(), cond, body, tuple));
744 
745   auto lhs = builder.AddInstruction(
746       HloInstruction::CreateGetTupleElement(shape, while_hlo, 0));
747   auto rhs = builder.AddInstruction(
748       HloInstruction::CreateGetTupleElement(shape, while_hlo, 1));
749   auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
750   auto computation = module->AddEntryComputation(builder.Build());
751 
752   EXPECT_TRUE(PropagatePrecision(module.get()));
753 
754   EXPECT_EQ(computation->root_instruction(), dot);
755   EXPECT_TRUE(OutputsBF16(lhs));
756   EXPECT_FALSE(OutputsBF16(rhs));
757   EXPECT_TRUE(OutputsBF16(body_dot1));
758   EXPECT_TRUE(OutputsBF16(body_lhs));
759   EXPECT_FALSE(OutputsBF16(body_rhs));
760   EXPECT_FALSE(OutputsBF16(body_dot2));
761   EXPECT_FALSE(OutputsBF16(body_transpose));
762   EXPECT_TRUE(OutputsBF16(cond_lhs));
763   EXPECT_FALSE(OutputsBF16(cond_rhs));
764   EXPECT_TRUE(OutputsBF16(add0));
765   EXPECT_FALSE(OutputsBF16(add1));
766 }
767 
768 // Tests that BF16 is not propagated through multiple whiles that invoke the
769 // same computation as long as one while prevents the propagation.
TEST_F(BFloat16PropagationTest,DoNotPropagateWhilesCallingSameComputation)770 TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
771   auto module = CreateNewVerifiedModule();
772   auto builder = HloComputation::Builder(TestName());
773   Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
774 
775   HloInstruction* param0 = builder.AddInstruction(
776       HloInstruction::CreateParameter(0, shape, "param0"));
777   HloInstruction* param1 = builder.AddInstruction(
778       HloInstruction::CreateParameter(1, shape, "param1"));
779   HloInstruction* add0 = builder.AddInstruction(
780       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
781   HloInstruction* add1 = builder.AddInstruction(
782       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
783   HloInstruction* add2 = builder.AddInstruction(
784       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
785   HloInstruction* add3 = builder.AddInstruction(
786       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
787   HloInstruction* tuple0 =
788       builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
789   HloInstruction* tuple1 =
790       builder.AddInstruction(HloInstruction::CreateTuple({add2, add3}));
791 
792   // Condition computation for the first while.
793   auto builder_cond0 = HloComputation::Builder("cond0");
794   auto cond0_param = builder_cond0.AddInstruction(
795       HloInstruction::CreateParameter(0, tuple0->shape(), "cond0_param"));
796   auto cond0_lhs = builder_cond0.AddInstruction(
797       HloInstruction::CreateGetTupleElement(shape, cond0_param, 0));
798   auto cond0_rhs = builder_cond0.AddInstruction(
799       HloInstruction::CreateGetTupleElement(shape, cond0_param, 1));
800   // This add should prevent RHS from using BF16
801   auto cond0_add_rhs =
802       builder_cond0.AddInstruction(HloInstruction::CreateBinary(
803           shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
804   auto cond0_dot =
805       builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs));
806   builder_cond0.AddInstruction(HloInstruction::CreateCompare(
807       ShapeUtil::MakeShape(PRED, {}),
808       builder_cond0.AddInstruction(HloInstruction::CreateReshape(
809           ShapeUtil::MakeShape(F32, {}),
810           builder_cond0.AddInstruction(
811               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
812                                           cond0_dot, {0, 0}, {1, 1}, {1, 1})))),
813       builder_cond0.AddInstruction(HloInstruction::CreateReshape(
814           ShapeUtil::MakeShape(F32, {}),
815           builder_cond0.AddInstruction(
816               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
817                                           cond0_dot, {1, 1}, {2, 2}, {1, 1})))),
818       ComparisonDirection::kGt));
819   auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
820 
821   // Condition computation for the second while.
822   auto builder_cond1 = HloComputation::Builder("cond1");
823   auto cond1_param = builder_cond1.AddInstruction(
824       HloInstruction::CreateParameter(0, tuple1->shape(), "cond1_param"));
825   auto cond1_lhs = builder_cond1.AddInstruction(
826       HloInstruction::CreateGetTupleElement(shape, cond1_param, 0));
827   auto cond1_rhs = builder_cond1.AddInstruction(
828       HloInstruction::CreateGetTupleElement(shape, cond1_param, 1));
829   // This add should prevent LHS from using BF16
830   auto cond1_add_lhs =
831       builder_cond1.AddInstruction(HloInstruction::CreateBinary(
832           shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
833   auto cond1_dot =
834       builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs));
835   builder_cond1.AddInstruction(HloInstruction::CreateCompare(
836       ShapeUtil::MakeShape(PRED, {}),
837       builder_cond1.AddInstruction(HloInstruction::CreateReshape(
838           ShapeUtil::MakeShape(F32, {}),
839           builder_cond1.AddInstruction(
840               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
841                                           cond1_dot, {0, 0}, {1, 1}, {1, 1})))),
842       builder_cond1.AddInstruction(HloInstruction::CreateReshape(
843           ShapeUtil::MakeShape(F32, {}),
844           builder_cond1.AddInstruction(
845               HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
846                                           cond1_dot, {1, 1}, {2, 2}, {1, 1})))),
847       ComparisonDirection::kGt));
848   auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
849 
850   // Body computation shared by both whiles.
851   auto builder_body = HloComputation::Builder("body");
852   auto body_param = builder_body.AddInstruction(
853       HloInstruction::CreateParameter(0, tuple0->shape(), "body_param"));
854   auto body_lhs = builder_body.AddInstruction(
855       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
856   auto body_rhs = builder_body.AddInstruction(
857       HloInstruction::CreateGetTupleElement(shape, body_param, 1));
858   auto body_dot =
859       builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
860   builder_body.AddInstruction(
861       HloInstruction::CreateTuple({body_dot, body_rhs}));
862   auto body = module->AddEmbeddedComputation(builder_body.Build());
863 
864   auto while0 = builder.AddInstruction(
865       HloInstruction::CreateWhile(tuple0->shape(), cond0, body, tuple0));
866   auto while1 = builder.AddInstruction(
867       HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1));
868 
869   auto lhs = builder.AddInstruction(
870       CreateDot(shape,
871                 builder.AddInstruction(
872                     HloInstruction::CreateGetTupleElement(shape, while0, 0)),
873                 builder.AddInstruction(
874                     HloInstruction::CreateGetTupleElement(shape, while0, 1))));
875   auto rhs = builder.AddInstruction(
876       CreateDot(shape,
877                 builder.AddInstruction(
878                     HloInstruction::CreateGetTupleElement(shape, while1, 0)),
879                 builder.AddInstruction(
880                     HloInstruction::CreateGetTupleElement(shape, while1, 1))));
881   auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
882   auto computation = module->AddEntryComputation(builder.Build());
883 
884   EXPECT_TRUE(PropagatePrecision(module.get()));
885   EXPECT_FALSE(OutputsBF16(body_dot));
886   EXPECT_FALSE(OutputsBF16(body_rhs));
887   EXPECT_FALSE(OutputsBF16(body_lhs));
888   EXPECT_FALSE(OutputsBF16(cond0_lhs));
889   EXPECT_FALSE(OutputsBF16(cond0_rhs));
890   EXPECT_FALSE(OutputsBF16(cond1_lhs));
891   EXPECT_FALSE(OutputsBF16(cond1_rhs));
892   EXPECT_TRUE(OutputsBF16(cond0_add_rhs));
893   EXPECT_TRUE(OutputsBF16(cond1_add_lhs));
894   EXPECT_EQ(computation->root_instruction(), dot);
895 }
896 
897 // Tests that if this pass turns an F32 -> BF16 conversion into a no-op (BF16 ->
898 // BF16 conversion), then it will remove that conversion.
TEST_F(BFloat16PropagationTest, NoopConversionRemoved) {
  HloComputation::Builder entry_builder(TestName());
  const Shape wide_shape = ShapeUtil::MakeShape(F32, {4, 4});
  const Shape narrow_shape = ShapeUtil::MakeShape(BF16, {4, 4});

  // Two F32 adds are packed into a tuple and unpacked again.
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, wide_shape, "param"));
  HloInstruction* add0 = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(wide_shape, HloOpcode::kAdd, input, input));
  HloInstruction* add1 = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(wide_shape, HloOpcode::kAdd, input, input));
  HloInstruction* packed =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
  HloInstruction* unpacked0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(wide_shape, packed, 0));
  HloInstruction* unpacked1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(wide_shape, packed, 1));

  // Each unpacked element is explicitly converted to BF16 before the final
  // BF16 add, which is the root.
  HloInstruction* to_bf16_0 = entry_builder.AddInstruction(
      HloInstruction::CreateConvert(narrow_shape, unpacked0));
  HloInstruction* to_bf16_1 = entry_builder.AddInstruction(
      HloInstruction::CreateConvert(narrow_shape, unpacked1));
  HloInstruction* add2 =
      entry_builder.AddInstruction(HloInstruction::CreateBinary(
          narrow_shape, HloOpcode::kAdd, to_bf16_0, to_bf16_1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(entry_builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));

  // After the pass the root add is expected to consume the original adds
  // directly, and those adds are expected to produce BF16.
  EXPECT_EQ(computation->root_instruction(), add2);
  EXPECT_EQ(add2->operand(0), add0);
  EXPECT_EQ(add2->operand(1), add1);
  EXPECT_EQ(add0->shape().element_type(), BF16);
  EXPECT_EQ(add1->shape().element_type(), BF16);
}
934 
// Tests that BF16 is propagated through a domain instruction whose operand is
// a tuple of transposes, while the entry parameters feeding those transposes
// stay in F32.
TEST_F(BFloat16PropagationTest, TupleDomain) {
  HloComputation::Builder entry_builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {4, 4});

  HloInstruction* a = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "a"));
  HloInstruction* b = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, matrix_shape, "b"));
  HloInstruction* a_trans = entry_builder.AddInstruction(
      HloInstruction::CreateTranspose(matrix_shape, a, {0, 1}));
  HloInstruction* b_trans = entry_builder.AddInstruction(
      HloInstruction::CreateTranspose(matrix_shape, b, {0, 1}));

  // Wrap both transposes in a tuple and route the tuple through a domain
  // instruction that carries no metadata on either side.
  HloInstruction* packed = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({a_trans, b_trans}));
  HloInstruction* domain = entry_builder.AddInstruction(
      HloInstruction::CreateDomain(packed->shape(), packed, nullptr, nullptr));
  HloInstruction* a_gte = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_shape, domain, 0));
  HloInstruction* b_gte = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_shape, domain, 1));

  // The dot consumes both elements; the root adds the dot to itself.
  HloInstruction* dot =
      entry_builder.AddInstruction(CreateDot(matrix_shape, a_gte, b_gte));
  HloInstruction* root = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(matrix_shape, HloOpcode::kAdd, dot, dot));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(entry_builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));
  EXPECT_EQ(computation->root_instruction(), root);

  // Both elements of the domain's tuple shape should now be BF16.
  EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 0).element_type(),
            BF16);
  EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 1).element_type(),
            BF16);

  EXPECT_TRUE(OutputsBF16(a_trans));
  EXPECT_TRUE(OutputsBF16(b_trans));
  EXPECT_TRUE(OutputsBF16(a_gte));
  EXPECT_TRUE(OutputsBF16(b_gte));
  EXPECT_FALSE(OutputsBF16(a));
  EXPECT_FALSE(OutputsBF16(b));
}
978 
979 // Tests that bf16 is not propagated through a domain in case its input cannot
980 // be propagated. In the case below the input of the domain is the parameter
981 // tuple which cannot be propagated, so the domain instruction is not propagated
982 // either.
TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) {
  HloComputation::Builder entry_builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {4, 4});
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({matrix_shape, matrix_shape});

  // The domain instruction wraps the tuple-shaped entry parameter directly.
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  HloInstruction* domain = entry_builder.AddInstruction(
      HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr));
  HloInstruction* a_gte = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_shape, domain, 0));
  HloInstruction* b_gte = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_shape, domain, 1));

  // Each element is transposed before the dot; the root adds the dot to
  // itself.
  HloInstruction* a_trans = entry_builder.AddInstruction(
      HloInstruction::CreateTranspose(matrix_shape, a_gte, {0, 1}));
  HloInstruction* b_trans = entry_builder.AddInstruction(
      HloInstruction::CreateTranspose(matrix_shape, b_gte, {0, 1}));
  HloInstruction* dot =
      entry_builder.AddInstruction(CreateDot(matrix_shape, a_trans, b_trans));
  HloInstruction* root = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(matrix_shape, HloOpcode::kAdd, dot, dot));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation =
      module->AddEntryComputation(entry_builder.Build());

  EXPECT_TRUE(PropagatePrecision(module.get()));

  // Only the transposes are expected to become BF16; the parameter, the
  // domain and the get-tuple-elements are expected to stay unchanged.
  EXPECT_EQ(computation->root_instruction(), root);
  EXPECT_TRUE(OutputsBF16(a_trans));
  EXPECT_TRUE(OutputsBF16(b_trans));
  EXPECT_FALSE(OutputsBF16(a_gte));
  EXPECT_FALSE(OutputsBF16(b_gte));
  EXPECT_FALSE(OutputsBF16(domain));
  EXPECT_FALSE(OutputsBF16(param));
}
1018 
1019 }  // namespace xla
1020