1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
17 #include "tensorflow/compiler/xla/service/bfloat16_support.h"
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/test.h"
26 #include "tensorflow/compiler/xla/test_helpers.h"
27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29
30 namespace xla {
31
32 class TestBFloat16Support : public BFloat16Support {
33 public:
TestBFloat16Support()34 TestBFloat16Support() {}
~TestBFloat16Support()35 ~TestBFloat16Support() override {}
36
SupportsBF16Operand(const HloInstruction & hlo,int64 operand_index) const37 bool SupportsBF16Operand(const HloInstruction& hlo,
38 int64 operand_index) const override {
39 if (hlo.opcode() == HloOpcode::kAdd ||
40 hlo.opcode() == HloOpcode::kSubtract ||
41 hlo.opcode() == HloOpcode::kReduce ||
42 hlo.opcode() == HloOpcode::kTuple ||
43 hlo.opcode() == HloOpcode::kGetTupleElement) {
44 return true;
45 }
46 if (hlo.opcode() == HloOpcode::kDot) {
47 // Test that only the first operand of kDot supports BF16.
48 return operand_index == 0;
49 }
50 return false;
51 }
52
SupportsBF16Output(const HloInstruction & hlo) const53 bool SupportsBF16Output(const HloInstruction& hlo) const override {
54 if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce ||
55 hlo.opcode() == HloOpcode::kSubtract ||
56 hlo.opcode() == HloOpcode::kDot || hlo.opcode() == HloOpcode::kTuple ||
57 hlo.opcode() == HloOpcode::kGetTupleElement) {
58 return true;
59 }
60 return false;
61 }
62
SupportsMixedPrecisions(const HloInstruction & hlo) const63 bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
64 if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
65 hlo.opcode() == HloOpcode::kGetTupleElement) {
66 return true;
67 }
68 return false;
69 }
70 };
71
72 class BFloat16NormalizationTest : public HloTestBase {
73 protected:
BFloat16NormalizationTest()74 BFloat16NormalizationTest()
75 : HloTestBase(/*verifier_layout_sensitive=*/false,
76 /*allow_mixed_precision_in_hlo_verifier=*/true) {}
77
Normalize(HloModule * module)78 bool Normalize(HloModule* module) {
79 TestBFloat16Support bfloat16_support_;
80 BFloat16Normalization normalization(&bfloat16_support_);
81 StatusOr<bool> result = normalization.Run(module);
82 EXPECT_IS_OK(result.status());
83
84 HloVerifier verifier(/*layout_sensitive=*/false,
85 /*allow_mixed_precision=*/true);
86 EXPECT_IS_OK(verifier.Run(module).status());
87
88 return result.ValueOrDie();
89 }
90 };
91
// When every instruction's precisions are already supported, the pass must
// leave the module untouched and report no change.
TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* param_c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  // kAdd supports BF16 and mixed precision, so neither add needs rewriting.
  HloInstruction* bf16_add = builder.AddInstruction(HloInstruction::CreateBinary(
      bf16_shape, HloOpcode::kAdd, param_a, param_b));
  HloInstruction* f32_add = builder.AddInstruction(HloInstruction::CreateBinary(
      f32_shape, HloOpcode::kAdd, bf16_add, param_c));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(Normalize(module.get()));

  // Nothing changed: same root, same element types.
  EXPECT_EQ(computation->root_instruction(), f32_add);
  EXPECT_EQ(bf16_add->shape().element_type(), BF16);
  EXPECT_EQ(f32_add->shape().element_type(), F32);
}
119
// kMultiply has no BF16 support in the test policy, so both multiplies must
// be lifted to F32 and the final result converted back to BF16.
TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* param_c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* mul0 = builder.AddInstruction(HloInstruction::CreateBinary(
      bf16_shape, HloOpcode::kMultiply, param_a, param_b));
  HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary(
      bf16_shape, HloOpcode::kMultiply, mul0, param_c));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  // The root is now a convert back to BF16; the multiplies run in F32 with
  // converts inserted on their BF16 inputs.
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
  EXPECT_EQ(mul0->shape().element_type(), F32);
  EXPECT_EQ(mul1->shape().element_type(), F32);
  EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert);
}
149
// kSubtract supports BF16 but not mixed precision, so the mixed F32/BF16
// subtracts must be normalized to a uniform precision.
TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* param_c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* sub0 = builder.AddInstruction(HloInstruction::CreateBinary(
      bf16_shape, HloOpcode::kSubtract, param_a, param_b));
  HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary(
      bf16_shape, HloOpcode::kSubtract, sub0, param_c));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  // Both subtracts were lifted to F32, with a convert back to BF16 as root.
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
  EXPECT_EQ(sub0->shape().element_type(), F32);
  EXPECT_EQ(sub1->shape().element_type(), F32);
  EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert);
}
179
// A reduce with F32 input, BF16 init, and a BF16 reduction computation mixes
// precisions in a way kReduce does not support; the pass must convert the
// init value and rewrite the reduction computation to operate in F32.
TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
  Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4});
  Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {});

  // Reduction computation: BF16 scalar add.
  auto sum_builder = HloComputation::Builder("reduce_comp");
  auto sum_lhs = sum_builder.AddInstruction(
      HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0"));
  auto sum_rhs = sum_builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1"));
  sum_builder.AddInstruction(HloInstruction::CreateBinary(
      bf16_scalar_shape, HloOpcode::kAdd, sum_lhs, sum_rhs));

  auto module = CreateNewVerifiedModule();
  auto reduce_computation =
      module->AddEmbeddedComputation(sum_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_input_shape, "a"));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_scalar_shape, "init"));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      f32_output_shape, input, init, {0}, reduce_computation));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), reduce);
  EXPECT_EQ(reduce->called_computations().size(), 1);

  // The reduction computation now works entirely in F32.
  HloComputation* applied = reduce->called_computations()[0];
  EXPECT_EQ(applied->num_parameters(), 2);
  EXPECT_EQ(applied->parameter_instruction(0)->shape().element_type(), F32);
  EXPECT_EQ(applied->parameter_instruction(1)->shape().element_type(), F32);
  EXPECT_EQ(applied->root_instruction()->shape().element_type(), F32);

  // The F32 input is untouched; the BF16 init is converted to F32.
  EXPECT_EQ(reduce->shape().element_type(), F32);
  EXPECT_EQ(reduce->operand(0), input);
  EXPECT_EQ(input->shape().element_type(), F32);
  EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32);
}
235
// A tuple-shaped all-reduce with one BF16 element must have that element
// lifted to F32 (all-reduce is not in the supported-ops list), while the
// consumer GTE keeps its BF16 result via an inserted conversion.
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleAllReduce) {
  auto module = CreateNewVerifiedModule();

  // Scalar F32 sum used as the all-reduce reduction.
  HloComputation::Builder sum_builder("sum");
  auto lhs = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
  auto rhs = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
  sum_builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, lhs, rhs));
  HloComputation* reduction =
      module->AddEmbeddedComputation(sum_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));

  HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
      ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {param_a, param_b},
      reduction,
      /*replica_groups=*/{}, /*barrier=*/"",
      /*all_reduce_id=*/absl::nullopt));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  // The GTE keeps a BF16 result, but the all-reduce itself runs in F32.
  EXPECT_EQ(computation->root_instruction(), gte);
  EXPECT_EQ(gte->shape().element_type(), BF16);
  EXPECT_EQ(crs->operand(1)->shape().element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
}
273
// Tests a tuple-shaped sort whose key output is BF16 (unsupported): the key
// element must be lifted to F32 while the GTE consumer keeps BF16.
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {1024});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024});
  // Fixed: this shape was constructed with BF16, contradicting both its name
  // and the test's intent of sorting a non-floating-point value operand.
  Shape s32_shape = ShapeUtil::MakeShape(S32, {1024});

  HloInstruction* key = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "key"));
  HloInstruction* value = builder.AddInstruction(
      HloInstruction::CreateParameter(1, s32_shape, "value"));

  // Sort along dimension 0; the key output is BF16, which the test policy
  // does not support for kSort.
  TF_ASSERT_OK_AND_ASSIGN(
      auto* sort,
      MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}),
                  {key, value}, 0, /*is_stable=*/false, &builder,
                  module.get()));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  // The GTE still produces BF16; the sort's key operand and key output are
  // now F32.
  EXPECT_EQ(computation->root_instruction(), gte);
  EXPECT_EQ(gte->shape().element_type(), BF16);
  EXPECT_EQ(sort->operand(0)->shape().element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32);
}
303
// When the mixed-precision tuple sort is the computation root, normalization
// must insert a new tuple root rather than convert inside the sort, and must
// not add BF16 converts inside the sort's comparison computation.
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSortRoot) {
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {1024});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {1024});

  HloInstruction* key = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "key"));
  HloInstruction* value = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "value"));

  TF_ASSERT_OK_AND_ASSIGN(
      auto* sort,
      MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}),
                  {key, value}, 0, /*is_stable=*/false, &builder,
                  module.get()));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  // The sort's key side is lifted to F32 and a fresh tuple becomes the root.
  EXPECT_EQ(sort->operand(0)->shape().element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(sort->shape(), {0}).element_type(), F32);
  EXPECT_NE(computation->root_instruction(), sort);
  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kTuple);
  EXPECT_EQ(sort->to_apply()->parameter_instruction(1)->shape().element_type(),
            F32);
  // Make sure that no convert to BF16 was added to the 'to_apply' comparison
  // computation.
  for (auto user : sort->to_apply()->parameter_instruction(1)->users()) {
    EXPECT_NE(user->opcode(), HloOpcode::kConvert);
  }
}
338
339 // Tests that the normalization should not cause unsupported mixed precision due
340 // to resolving unsupported BF16 operand.
TEST_F(BFloat16NormalizationTest,DoNotAddUnsupportedMixedPrecision)341 TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) {
342 auto builder = HloComputation::Builder(TestName());
343 Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4});
344
345 HloInstruction* a = builder.AddInstruction(
346 HloInstruction::CreateParameter(0, bf16_shape, "a"));
347 HloInstruction* b = builder.AddInstruction(
348 HloInstruction::CreateParameter(1, bf16_shape, "b"));
349
350 DotDimensionNumbers dot_dnums;
351 dot_dnums.add_lhs_contracting_dimensions(1);
352 dot_dnums.add_rhs_contracting_dimensions(0);
353 PrecisionConfig precision_config;
354 precision_config.mutable_operand_precision()->Resize(
355 2, PrecisionConfig::DEFAULT);
356 HloInstruction* dot = builder.AddInstruction(
357 HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config));
358
359 auto module = CreateNewVerifiedModule();
360 auto computation = module->AddEntryComputation(builder.Build());
361
362 EXPECT_TRUE(Normalize(module.get()));
363
364 EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
365 EXPECT_EQ(dot->shape().element_type(), F32);
366 EXPECT_EQ(dot->operand(0)->shape().element_type(), F32);
367 EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConvert);
368 EXPECT_EQ(dot->operand(1)->shape().element_type(), F32);
369 EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConvert);
370 }
371
372 } // namespace xla
373