1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
17 #include "tensorflow/compiler/xla/service/bfloat16_support.h"
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/test_helpers.h"
25 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 
28 namespace xla {
29 
30 class TestBFloat16Support : public BFloat16Support {
31  public:
TestBFloat16Support()32   TestBFloat16Support() {}
~TestBFloat16Support()33   ~TestBFloat16Support() override {}
34 
SupportsBF16Operand(const HloInstruction & hlo,int64 operand_index) const35   bool SupportsBF16Operand(const HloInstruction& hlo,
36                            int64 operand_index) const override {
37     if (hlo.opcode() == HloOpcode::kAdd ||
38         hlo.opcode() == HloOpcode::kSubtract ||
39         hlo.opcode() == HloOpcode::kTuple ||
40         hlo.opcode() == HloOpcode::kGetTupleElement ||
41         hlo.opcode() == HloOpcode::kAllReduce) {
42       return true;
43     }
44     return false;
45   }
46 
SupportsBF16Output(const HloInstruction & hlo) const47   bool SupportsBF16Output(const HloInstruction& hlo) const override {
48     if (hlo.opcode() == HloOpcode::kAdd ||
49         hlo.opcode() == HloOpcode::kSubtract ||
50         hlo.opcode() == HloOpcode::kTuple ||
51         hlo.opcode() == HloOpcode::kGetTupleElement ||
52         hlo.opcode() == HloOpcode::kAllReduce) {
53       return true;
54     }
55     return false;
56   }
57 
SupportsMixedPrecisions(const HloInstruction & hlo) const58   bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
59     if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
60         hlo.opcode() == HloOpcode::kGetTupleElement ||
61         hlo.opcode() == HloOpcode::kAllReduce) {
62       return true;
63     }
64     return false;
65   }
66 };
67 
68 class BFloat16ConversionFoldingTest : public HloTestBase {
69  protected:
BFloat16ConversionFoldingTest()70   BFloat16ConversionFoldingTest()
71       : HloTestBase(/*verifier_layout_sensitive=*/false,
72                     /*allow_mixed_precision_in_hlo_verifier=*/true) {}
73 
FoldConversions(HloModule * module)74   bool FoldConversions(HloModule* module) {
75     TestBFloat16Support bfloat16_support_;
76     BFloat16ConversionFolding fold(&bfloat16_support_);
77     StatusOr<bool> result = fold.Run(module);
78     EXPECT_IS_OK(result.status());
79     return result.ValueOrDie();
80   }
81 };
82 
// Both adds accept BF16 and mixed precision, so the BF16 round trip between
// them and the trailing BF16 convert are folded away: the first add produces
// BF16 directly and feeds the second add, which becomes the BF16 root.
TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
  auto builder = HloComputation::Builder(TestName());
  const Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  const Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* param_c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  // first_add -> convert(BF16) -> convert(F32): a foldable round trip.
  HloInstruction* first_add =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32_shape, HloOpcode::kAdd, param_a, param_b));
  HloInstruction* to_bf16 = builder.AddInstruction(
      HloInstruction::CreateConvert(bf16_shape, first_add));
  HloInstruction* back_to_f32 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, to_bf16));

  HloInstruction* second_add =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32_shape, HloOpcode::kAdd, back_to_f32, param_c));
  // Before folding, this convert is the computation root.
  builder.AddInstruction(
      HloInstruction::CreateConvert(bf16_shape, second_add));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(FoldConversions(module.get()));

  EXPECT_EQ(computation->root_instruction(), second_add);
  EXPECT_EQ(first_add->shape().element_type(), BF16);
  EXPECT_EQ(second_add->shape().element_type(), BF16);
  EXPECT_EQ(second_add->operand(0), first_add);
}
116 
// kMultiply is outside TestBFloat16Support's allow-list, so none of the
// converts around the multiplies may be folded and both stay F32.
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
  auto builder = HloComputation::Builder(TestName());
  const Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  const Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* param_c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* first_mul =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32_shape, HloOpcode::kMultiply, param_a, param_b));
  HloInstruction* to_bf16 = builder.AddInstruction(
      HloInstruction::CreateConvert(bf16_shape, first_mul));
  HloInstruction* back_to_f32 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, to_bf16));

  HloInstruction* second_mul =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32_shape, HloOpcode::kMultiply, back_to_f32, param_c));
  HloInstruction* root_convert = builder.AddInstruction(
      HloInstruction::CreateConvert(bf16_shape, second_mul));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

  // The graph must come out unchanged.
  EXPECT_EQ(computation->root_instruction(), root_convert);
  EXPECT_EQ(first_mul->shape().element_type(), F32);
  EXPECT_EQ(second_mul->shape().element_type(), F32);
  EXPECT_EQ(second_mul->operand(0), back_to_f32);
}
151 
// kSubtract accepts BF16 operands/outputs in TestBFloat16Support but not
// mixed precision, so the converts around the subtracts must be kept and both
// subtracts stay F32.
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
  auto builder = HloComputation::Builder(TestName());
  const Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  const Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* param_c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* first_sub =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32_shape, HloOpcode::kSubtract, param_a, param_b));
  HloInstruction* to_bf16 = builder.AddInstruction(
      HloInstruction::CreateConvert(bf16_shape, first_sub));
  HloInstruction* back_to_f32 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, to_bf16));

  HloInstruction* second_sub =
      builder.AddInstruction(HloInstruction::CreateBinary(
          f32_shape, HloOpcode::kSubtract, back_to_f32, param_c));
  HloInstruction* root_convert = builder.AddInstruction(
      HloInstruction::CreateConvert(bf16_shape, second_sub));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

  // The graph must come out unchanged.
  EXPECT_EQ(computation->root_instruction(), root_convert);
  EXPECT_EQ(first_sub->shape().element_type(), F32);
  EXPECT_EQ(second_sub->shape().element_type(), F32);
  EXPECT_EQ(second_sub->operand(0), back_to_f32);
}
186 
// The conversion folding pass leaves these converts in place:
//  - the convert feeding into a tuple, and
//  - the convert consuming a get-tuple-element.
// This holds even though TestBFloat16Support accepts BF16 for kTuple and
// kGetTupleElement.
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
  auto builder = HloComputation::Builder(TestName());
  const Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  const Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* widened_b =
      builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, param_b));

  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({param_a, widened_b}));
  HloInstruction* elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0));
  HloInstruction* narrowed_elem0 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, elem0));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

  EXPECT_EQ(computation->root_instruction(), narrowed_elem0);
  EXPECT_EQ(elem0->shape().element_type(), F32);
  EXPECT_EQ(tuple->operand(1), widened_b);
}
215 
// An all-reduce with a tuple output. Two converts are folded into the
// all-reduce:
//  - the BF16->F32 convert feeding operand 0, and
//  - the F32->BF16 convert consuming output element 1.
// Afterwards the all-reduce reads BF16 `a` directly, and its tuple shape is
// (F32, BF16).
TEST_F(BFloat16ConversionFoldingTest, FoldAllReduceTupleOutput) {
  auto builder = HloComputation::Builder(TestName());

  auto module = CreateNewVerifiedModule();

  // Scalar F32 reduction computation: (x, y) -> x + y.
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder reducer_builder("add");
  HloInstruction* lhs = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, scalar_f32, "x"));
  HloInstruction* rhs = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/1, scalar_f32, "y"));
  reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, lhs, rhs));
  HloComputation* reducer =
      module->AddEmbeddedComputation(reducer_builder.Build());

  const Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  const Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  // Operand 0 arrives as BF16 and is widened to F32; operand 1 is plain F32.
  HloInstruction* param_a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, bf16_shape, "a"));
  HloInstruction* widened_a =
      builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, param_a));
  HloInstruction* param_b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));

  HloInstruction* all_reduce =
      builder.AddInstruction(HloInstruction::CreateAllReduce(
          ShapeUtil::MakeTupleShape({f32_shape, f32_shape}),
          {widened_a, param_b}, reducer,
          /*replica_groups=*/{}, /*barrier=*/"",
          /*all_reduce_id=*/absl::nullopt));
  HloInstruction* elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, all_reduce, 0));
  HloInstruction* elem1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, all_reduce, 1));
  // Only element 1 is narrowed back to BF16 before being returned.
  HloInstruction* narrowed_elem1 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, elem1));
  HloInstruction* result_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({elem0, narrowed_elem1}));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(FoldConversions(module.get()));

  EXPECT_EQ(computation->root_instruction(), result_tuple);
  EXPECT_EQ(result_tuple->operand(0), elem0);
  EXPECT_EQ(result_tuple->operand(1), elem1);
  EXPECT_EQ(elem0->shape().element_type(), F32);
  EXPECT_EQ(elem1->shape().element_type(), BF16);
  EXPECT_EQ(all_reduce->operand(0), param_a);
  EXPECT_EQ(all_reduce->operand(1), param_b);
  EXPECT_EQ(param_a->shape().element_type(), BF16);
  EXPECT_EQ(param_b->shape().element_type(), F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(all_reduce->shape(), {0}).element_type(),
            F32);
  EXPECT_EQ(ShapeUtil::GetSubshape(all_reduce->shape(), {1}).element_type(),
            BF16);
}
268 
269 }  // namespace xla
270