/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

30 class TestBFloat16Support : public BFloat16Support {
31 public:
TestBFloat16Support()32 TestBFloat16Support() {}
~TestBFloat16Support()33 ~TestBFloat16Support() override {}
34
SupportsBF16Operand(const HloInstruction & hlo,int64 operand_index) const35 bool SupportsBF16Operand(const HloInstruction& hlo,
36 int64 operand_index) const override {
37 if (hlo.opcode() == HloOpcode::kAdd ||
38 hlo.opcode() == HloOpcode::kSubtract ||
39 hlo.opcode() == HloOpcode::kTuple ||
40 hlo.opcode() == HloOpcode::kGetTupleElement ||
41 hlo.opcode() == HloOpcode::kAllReduce) {
42 return true;
43 }
44 return false;
45 }
46
SupportsBF16Output(const HloInstruction & hlo) const47 bool SupportsBF16Output(const HloInstruction& hlo) const override {
48 if (hlo.opcode() == HloOpcode::kAdd ||
49 hlo.opcode() == HloOpcode::kSubtract ||
50 hlo.opcode() == HloOpcode::kTuple ||
51 hlo.opcode() == HloOpcode::kGetTupleElement ||
52 hlo.opcode() == HloOpcode::kAllReduce) {
53 return true;
54 }
55 return false;
56 }
57
SupportsMixedPrecisions(const HloInstruction & hlo) const58 bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
59 if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
60 hlo.opcode() == HloOpcode::kGetTupleElement ||
61 hlo.opcode() == HloOpcode::kAllReduce) {
62 return true;
63 }
64 return false;
65 }
66 };
67
68 class BFloat16ConversionFoldingTest : public HloTestBase {
69 protected:
BFloat16ConversionFoldingTest()70 BFloat16ConversionFoldingTest()
71 : HloTestBase(/*verifier_layout_sensitive=*/false,
72 /*allow_mixed_precision_in_hlo_verifier=*/true) {}
73
FoldConversions(HloModule * module)74 bool FoldConversions(HloModule* module) {
75 TestBFloat16Support bfloat16_support_;
76 BFloat16ConversionFolding fold(&bfloat16_support_);
77 StatusOr<bool> result = fold.Run(module);
78 EXPECT_IS_OK(result.status());
79 return result.ValueOrDie();
80 }
81 };
82
// When the producer/consumer ops support BF16 (kAdd here), the pass should
// fold the BF16->F32->BF16 convert chains directly into the adds.
TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  // Graph: add1 = (bf16->f32 convert of (f32->bf16 convert of add0)) + c,
  // followed by a final f32->bf16 convert of add1.
  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, a, b));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add0));
  HloInstruction* convert1 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, convert0));

  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c));
  builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(FoldConversions(module.get()));

  // Both adds now produce BF16 directly and the converts are gone.
  EXPECT_EQ(computation->root_instruction(), add1);
  EXPECT_EQ(add0->shape().element_type(), BF16);
  EXPECT_EQ(add1->shape().element_type(), BF16);
  EXPECT_EQ(add1->operand(0), add0);
}

// kMultiply is not in the BF16-supported opcode set, so the pass must leave
// every convert in place and report that nothing changed.
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  // Same convert-chain shape as FoldIfSupported, but around multiplies.
  HloInstruction* mul0 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kMultiply, a, b));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul0));
  HloInstruction* convert1 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, convert0));

  HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary(
      f32_shape, HloOpcode::kMultiply, convert1, c));
  HloInstruction* convert2 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

  // Multiplies stay in F32 and the convert chain is untouched.
  EXPECT_EQ(computation->root_instruction(), convert2);
  EXPECT_EQ(mul0->shape().element_type(), F32);
  EXPECT_EQ(mul1->shape().element_type(), F32);
  EXPECT_EQ(mul1->operand(0), convert1);
}

// kSubtract supports BF16 operands/outputs but NOT mixed precision (see
// TestBFloat16Support), so folding only one side would create an illegal
// mixed-precision op — the pass must therefore not fold at all.
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  // Same convert-chain shape as FoldIfSupported, but around subtracts.
  HloInstruction* sub0 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kSubtract, a, b));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub0));
  HloInstruction* convert1 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, convert0));

  HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary(
      f32_shape, HloOpcode::kSubtract, convert1, c));
  HloInstruction* convert2 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

  // Subtracts stay in F32 and the convert chain is untouched.
  EXPECT_EQ(computation->root_instruction(), convert2);
  EXPECT_EQ(sub0->shape().element_type(), F32);
  EXPECT_EQ(sub1->shape().element_type(), F32);
  EXPECT_EQ(sub1->operand(0), convert1);
}

// Converts feeding into / coming out of tuple and get-tuple-element must not
// be folded: the GTE here reads element 0 (F32), not the converted element.
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, b));

  // tuple = (a, convert(b)); gte = tuple[0]; root = bf16 convert of gte.
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({a, convert0}));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0));
  HloInstruction* convert1 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

  // Nothing changed: the GTE is still F32 and convert0 still feeds the tuple.
  EXPECT_EQ(computation->root_instruction(), convert1);
  EXPECT_EQ(gte->shape().element_type(), F32);
  EXPECT_EQ(tuple->operand(1), convert0);
}

// All-reduce supports BF16 and mixed precision, so converts around its
// per-element inputs/outputs should be folded into the all-reduce itself,
// leaving a mixed-precision (F32, BF16) tuple output shape.
TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) {
  auto builder = HloComputation::Builder(TestName());

  auto module = CreateNewVerifiedModule();

  // Scalar F32 add computation used as the all-reduce reduction.
  HloComputation::Builder sum_builder("add");
  auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
  auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
  sum_builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y));
  HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build());

  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  // Operand 0 is a BF16 parameter converted up to F32; operand 1 is F32.
  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, bf16_shape, "a"));
  HloInstruction* convert_a =
      builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, a));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));

  HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce(
      ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, sum,
      /*replica_groups=*/{}, /*barrier=*/"",
      /*all_reduce_id=*/absl::nullopt));
  HloInstruction* gte_a = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
  HloInstruction* gte_b = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, crs, 1));
  // Only element 1's output is converted back down to BF16.
  HloInstruction* convert_gte_b =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte_b));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({gte_a, convert_gte_b}));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(FoldConversions(module.get()));

  // convert_a and convert_gte_b were folded: the all-reduce consumes `a`
  // directly and produces a mixed (F32, BF16) tuple; element 0 is untouched.
  EXPECT_EQ(computation->root_instruction(), tuple);
  EXPECT_EQ(tuple->operand(0), gte_a);
  EXPECT_EQ(tuple->operand(1), gte_b);
  EXPECT_EQ(gte_a->shape().element_type(), F32);
  EXPECT_EQ(gte_b->shape().element_type(), BF16);
  EXPECT_EQ(crs->operand(0), a);
  EXPECT_EQ(crs->operand(1), b);
  EXPECT_EQ(a->shape().element_type(), BF16);
  EXPECT_EQ(b->shape().element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {0}).element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), BF16);
}

}  // namespace xla