/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

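// A BFloat16Support for tests that reports BF16 support only for a fixed set
// of opcodes, so the normalization pass can be exercised on both supported
// and unsupported BF16 uses.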
class TestBFloat16Support : public BFloat16Support {
 public:
  TestBFloat16Support() {}
  ~TestBFloat16Support() override {}

  bool SupportsBF16Operand(const HloInstruction& hlo,
                           int64 operand_index) const override {
    if (hlo.opcode() == HloOpcode::kAdd ||
        hlo.opcode() == HloOpcode::kSubtract ||
        hlo.opcode() == HloOpcode::kReduce ||
        hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    if (hlo.opcode() == HloOpcode::kDot) {
      // Test that only the first operand of kDot supports BF16.
      return operand_index == 0;
    }
    return false;
  }

  bool SupportsBF16Output(const HloInstruction& hlo) const override {
    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce ||
        hlo.opcode() == HloOpcode::kSubtract ||
        hlo.opcode() == HloOpcode::kDot || hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }

  bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }
};

class BFloat16NormalizationTest : public HloTestBase {
 protected:
  BFloat16NormalizationTest()
      : HloTestBase(/*verifier_layout_sensitive=*/false,
                    /*allow_mixed_precision_in_hlo_verifier=*/true) {}

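  // Runs the BFloat16Normalization pass on the module, verifies the result
  // with mixed precision allowed, and returns whether the pass made changes.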
  bool Normalize(HloModule* module) {
    TestBFloat16Support bfloat16_support;
    BFloat16Normalization normalization(&bfloat16_support);
    StatusOr<bool> result = normalization.Run(module);
    EXPECT_IS_OK(result.status());

    HloVerifier verifier(/*layout_sensitive=*/false,
                         /*allow_mixed_precision=*/true);
    EXPECT_IS_OK(verifier.Run(module).status());

    return result.ValueOrDie();
  }
};

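// Tests that the pass is a no-op when all BF16 uses are supported.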
TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kAdd, a, b));

  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), add1);
  EXPECT_EQ(add0->shape().element_type(), BF16);
  EXPECT_EQ(add1->shape().element_type(), F32);
}

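// Tests that BF16 on an opcode with no BF16 support (kMultiply) is resolved
// by computing in F32 and converting back to BF16.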
TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* mul0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b));

  HloInstruction* mul1 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
  EXPECT_EQ(mul0->shape().element_type(), F32);
  EXPECT_EQ(mul1->shape().element_type(), F32);
  EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert);
}

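// Tests that unsupported mixed-precision subtractions are normalized to F32,
// with a convert back to BF16 at the root.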
TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* sub0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, a, b));

  HloInstruction* sub1 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
  EXPECT_EQ(sub0->shape().element_type(), F32);
  EXPECT_EQ(sub1->shape().element_type(), F32);
  EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert);
}

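// Tests that a reduce with mixed BF16/F32 operands is normalized so that the
// reduce, its init value, and its reduction computation all use F32.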
TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
  Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4});

  Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {});

  auto reduce_comp_builder = HloComputation::Builder("reduce_comp");
  auto reduce_comp_param0 = reduce_comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0"));
  auto reduce_comp_param1 = reduce_comp_builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1"));
  reduce_comp_builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd,
                                   reduce_comp_param0, reduce_comp_param1));

  auto module = CreateNewVerifiedModule();
  auto reduce_computation =
      module->AddEmbeddedComputation(reduce_comp_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_input_shape, "a"));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_scalar_shape, "init"));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      f32_output_shape, input, init, {0}, reduce_computation));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), reduce);
  EXPECT_EQ(reduce->called_computations().size(), 1);
  EXPECT_EQ(reduce->called_computations()[0]->num_parameters(), 2);
  EXPECT_EQ(reduce->called_computations()[0]
                ->parameter_instruction(0)
                ->shape()
                .element_type(),
            F32);
  EXPECT_EQ(reduce->called_computations()[0]
                ->parameter_instruction(1)
                ->shape()
                .element_type(),
            F32);
  EXPECT_EQ(reduce->called_computations()[0]
                ->root_instruction()
                ->shape()
                .element_type(),
            F32);
  EXPECT_EQ(reduce->shape().element_type(), F32);
  EXPECT_EQ(reduce->operand(0), input);
  EXPECT_EQ(input->shape().element_type(), F32);
  EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32);
}

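// Tests that a tuple-shaped all-reduce with a BF16 element is normalized to
// all-reduce in F32, while the get-tuple-element user (which supports mixed
// precision) keeps producing BF16.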
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder sum_builder("sum");
  auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
  auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
  sum_builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y));
  HloComputation* reduction =
      module->AddEmbeddedComputation(sum_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));

  HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
      ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
      /*replica_groups=*/{}, /*barrier=*/"",
      /*all_reduce_id=*/absl::nullopt));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), gte);
  EXPECT_EQ(gte->shape().element_type(), BF16);
  EXPECT_EQ(crs->operand(1)->shape().element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
}

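// Tests that a tuple-shaped sort with an unsupported BF16 key output is
// normalized to sort in F32, with the get-tuple-element still producing BF16.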
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {1024});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024});
  Shape s32_shape = ShapeUtil::MakeShape(S32, {1024});

  HloInstruction* key = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "key"));
  HloInstruction* value = builder.AddInstruction(
      HloInstruction::CreateParameter(1, s32_shape, "value"));

  TF_ASSERT_OK_AND_ASSIGN(
      auto* sort,
      MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}),
                  {key, value}, 0, /*is_stable=*/false, &builder,
                  module.get()));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), gte);
  EXPECT_EQ(gte->shape().element_type(), BF16);
  EXPECT_EQ(sort->operand(0)->shape().element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32);
}

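// Tests that when a mixed-precision sort is the computation root,
// normalization makes the root a tuple around the sort and keeps the sort's
// comparison computation in F32.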
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) {
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {1024});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024});

  HloInstruction* key = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "key"));
  HloInstruction* value = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "value"));

  TF_ASSERT_OK_AND_ASSIGN(
      auto* sort,
      MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}),
                  {key, value}, 0, /*is_stable=*/false, &builder,
                  module.get()));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(sort->operand(0)->shape().element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32);
  EXPECT_NE(computation->root_instruction(), sort);
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple);
  EXPECT_EQ(sort->to_apply()->parameter_instruction(1)->shape().element_type(),
            F32);
  // Make sure that no convert to BF16 was added to the 'to_apply' comparison
  // computation.
  auto users = sort->to_apply()->parameter_instruction(1)->users();
  for (auto user : users) {
    EXPECT_NE(user->opcode(), HloOpcode::kConvert);
  }
}

// Tests that normalization does not introduce unsupported mixed precision as
// a side effect of resolving an unsupported BF16 operand.
TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
  auto builder = HloComputation::Builder(TestName());
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, bf16_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  HloInstruction* dot = builder.AddInstruction(
      HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(dot->shape().element_type(), F32);
  EXPECT_EQ(dot->operand(0)->shape().element_type(), F32);
  EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(dot->operand(1)->shape().element_type(), F32);
  EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConvert);
}

}  // namespace xla