/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

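// Visitor that inserts F32 <-> BF16 conversions around instructions (and
// their called computations) that use BF16 in a way the target backend does
// not support.
//
// Illustrative example (not tied to any particular backend): if an add does
// not support BF16 operands or outputs, then
//   bf16[4] add(bf16[4] x, bf16[4] y)
// is rewritten along the lines of
//   bf16[4] convert(f32[4] add(f32[4] convert(x), f32[4] convert(y)))
// leaving later passes to simplify or fold the inserted converts.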
class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit BFloat16NormalizationVisitor(HloComputation* computation,
                                        const BFloat16Support* bfloat16_support)
      : computation_(computation), bfloat16_support_(bfloat16_support) {}

  Status DefaultAction(HloInstruction* hlo) override;

  static bool Run(HloComputation* computation,
                  const BFloat16Support* bfloat16_support) {
    BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
    TF_CHECK_OK(computation->Accept(&visitor));
    return visitor.changed_;
  }

 private:
  // Checks if the HLO uses BF16 in an unsupported way, and if so, inserts
  // conversions between F32 and BF16 to make it supported.
  Status HandleInstruction(HloInstruction* hlo);

  // Handles instructions with tuple outputs by examining each output
  // independently.
  Status HandleMultipleOutputs(HloInstruction* hlo);

  // Inserts a conversion HLO that changes the given HLO's output type.
  Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to,
                                  HloComputation* computation);

  // Changes the output type to the specified type, then inserts a conversion
  // back to the original type.
  Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo,
                                               PrimitiveType to,
                                               HloComputation* computation);

  // Inserts a conversion HLO that changes the given HLO's operand type.
  Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx,
                                    PrimitiveType to,
                                    HloComputation* computation);

  // Inserts conversion HLOs to change the called computations' BF16
  // operands/outputs to F32.
  Status ConvertCalledComputations(
      HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps);

  HloComputation* computation_;
  const BFloat16Support* bfloat16_support_;
  bool changed_ = false;
};

Status BFloat16NormalizationVisitor::InsertConvertAfterOutput(
    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
  bool is_root = computation->root_instruction() == hlo;
  std::vector<HloInstruction*> materialized_users = hlo->users();
  // Use the hlo's shape temporarily, in order to pass checks in ReplaceUseWith.
  auto convert = computation->AddInstruction(
      HloInstruction::CreateConvert(hlo->shape(), hlo));
  for (auto* user : materialized_users) {
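    // A user that is itself a convert to F32 does not need to go through the
    // new convert; forward this HLO to that user's users directly, leaving
    // the existing convert unused.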
    if (user->opcode() == HloOpcode::kConvert &&
        user->shape().element_type() == F32) {
      TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
    } else {
      TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert));
    }
  }
  if (is_root) {
    computation->set_root_instruction(convert);
  }
  convert->mutable_shape()->set_element_type(to);
  changed_ = true;
  return Status::OK();
}

Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
  auto original_type = hlo->shape().element_type();
  hlo->mutable_shape()->set_element_type(to);
  return InsertConvertAfterOutput(hlo, original_type, computation);
}

Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
    HloInstruction* hlo, int64 operand_idx, PrimitiveType to,
    HloComputation* computation) {
  auto operand = hlo->mutable_operand(operand_idx);
  auto convert = computation->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(operand->shape(), to), operand));
  TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert));
  changed_ = true;
  return Status::OK();
}

Status BFloat16NormalizationVisitor::ConvertCalledComputations(
    HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps) {
  std::map<HloComputation*, HloComputation*> cloned_computations;
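  // Clone each BF16-using called computation before rewriting it, so that
  // other callers of the original computation are not affected.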
  for (auto& comp : bf16_called_comps) {
    auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone());
    cloned_computations[comp] = cloned;
    changed_ = true;
  }
  hlo->ReplaceCalledComputations([&](HloComputation* comp) {
    auto it = cloned_computations.find(comp);
    if (it != cloned_computations.end()) {
      return it->second;
    }
    return comp;
  });
  for (auto& comp_pair : cloned_computations) {
    auto comp = comp_pair.second;
    if (comp->root_instruction()->shape().element_type() == BF16) {
      TF_RETURN_IF_ERROR(
          InsertConvertAfterOutput(comp->root_instruction(), F32, comp));
    }
    for (auto* param : comp->parameter_instructions()) {
      if (param->shape().element_type() == BF16) {
        // This changes the parameter to F32 then inserts a convert after it.
        TF_RETURN_IF_ERROR(
            ChangeOutputTypeThenInsertConvertBack(param, F32, comp));
      }
    }
  }
  return Status::OK();
}

Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
    HloInstruction* hlo) {
  std::vector<PrimitiveType> operand_types(hlo->operand_count());
  std::vector<PrimitiveType> output_types(hlo->operand_count());
  int64 f32_count = 0;
  int64 bf16_count = 0;
  bool has_unsupported_bf16_operand = false;
  bool has_unsupported_bf16_output = false;
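  // Count F32 and BF16 occurrences among the operands and the per-element
  // tuple outputs, and record any BF16 use that the backend does not support.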
  for (int64 i = 0; i < hlo->operand_count(); ++i) {
    operand_types[i] = hlo->operand(i)->shape().element_type();
    output_types[i] = ShapeUtil::GetSubshape(hlo->shape(), {i}).element_type();
    if (operand_types[i] == F32) {
      f32_count += 1;
    } else if (operand_types[i] == BF16) {
      bf16_count += 1;
      if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
        has_unsupported_bf16_operand = true;
      }
    }
    if (output_types[i] == F32) {
      f32_count += 1;
    } else if (output_types[i] == BF16) {
      bf16_count += 1;
      if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
        has_unsupported_bf16_output = true;
      }
    }
  }

  if (bf16_count == 0) {
    return Status::OK();
  }

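  // A BF16 operand must be converted to F32 if the backend does not support
  // it directly, or if mixed precision is unsupported and something else
  // already forces F32 into the instruction.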
  auto should_convert_operand = [&](int64 i) {
    if (operand_types[i] != BF16) {
      return false;
    }
    if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
      return true;
    }
    if (bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
      return false;
    }
    return has_unsupported_bf16_operand || has_unsupported_bf16_output ||
           f32_count > 0;
  };

  for (int64 i = 0; i < hlo->operand_count(); ++i) {
    if (should_convert_operand(i)) {
      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
      f32_count += 1;
      bf16_count -= 1;
    }
  }

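  // If the outputs are supported as-is and the remaining mix of precisions is
  // acceptable (or there is no longer a mix at all), nothing more to do.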
  if (!has_unsupported_bf16_output &&
      (bfloat16_support_->SupportsMixedPrecisions(*hlo) || f32_count == 0 ||
       bf16_count == 0)) {
    return Status::OK();
  }

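  // Called computations (e.g., a sort comparator or an all-reduce reduction)
  // contribute to the precision mix as well; collect the ones that use BF16.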
  std::vector<HloComputation*> bf16_called_comps;
  for (auto* comp : hlo->called_computations()) {
    bool comp_has_bf16 = false;
    if (comp->root_instruction()->shape().element_type() == F32) {
      f32_count += 1;
    } else if (comp->root_instruction()->shape().element_type() == BF16) {
      bf16_count += 1;
      comp_has_bf16 = true;
    }
    for (auto* param : comp->parameter_instructions()) {
      if (param->shape().element_type() == F32) {
        f32_count += 1;
      } else if (param->shape().element_type() == BF16) {
        bf16_count += 1;
        comp_has_bf16 = true;
      }
    }
    if (comp_has_bf16) {
      bf16_called_comps.push_back(comp);
    }
  }

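  // Change the BF16 tuple elements of this HLO's output to F32, and rebuild
  // the original BF16 tuple from get-tuple-elements followed by converts.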
  std::vector<HloInstruction*> materialized_users = hlo->users();
  std::vector<HloInstruction*> output_elements(hlo->operand_count());
  auto original_shape = hlo->shape();
  for (int64 i = 0; i < hlo->operand_count(); ++i) {
    auto subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), {i});
    if (output_types[i] != BF16) {
      output_elements[i] = computation_->AddInstruction(
          HloInstruction::CreateGetTupleElement(*subshape, hlo, i));
      continue;
    }
    subshape->set_element_type(F32);
    auto gte = computation_->AddInstruction(
        HloInstruction::CreateGetTupleElement(*subshape, hlo, i));
    output_elements[i] =
        computation_->AddInstruction(HloInstruction::CreateConvert(
            ShapeUtil::ChangeElementType(*subshape, BF16), gte));
  }
  auto tuple = computation_->AddInstruction(
      HloInstruction::CreateTuple(output_elements));

  // Use the hlo's shape temporarily, in order to pass checks in
  // ReplaceUseWith.
  *tuple->mutable_shape() = hlo->shape();
  for (auto* user : materialized_users) {
    TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple));
  }
  bool is_root = computation_->root_instruction() == hlo;
  if (is_root) {
    computation_->set_root_instruction(tuple);
  }
  *tuple->mutable_shape() = original_shape;
  return ConvertCalledComputations(hlo, bf16_called_comps);
}

Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
  int f32_count = 0;
  int bf16_count = 0;

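  // Count F32 and BF16 occurrences among the operands and the output.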
  for (int64 i = 0; i < hlo->operand_count(); ++i) {
    if (hlo->operand(i)->shape().element_type() == F32) {
      f32_count += 1;
    } else if (hlo->operand(i)->shape().element_type() == BF16) {
      bf16_count += 1;
    }
  }

  if (hlo->shape().element_type() == F32) {
    f32_count += 1;
  } else if (hlo->shape().element_type() == BF16) {
    bf16_count += 1;
  }

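  // Called computations contribute to the precision mix as well; collect the
  // ones that use BF16 so their signatures can be rewritten if needed.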
  std::vector<HloComputation*> bf16_called_comps;
  for (auto* comp : hlo->called_computations()) {
    bool comp_has_bf16 = false;
    if (comp->root_instruction()->shape().element_type() == F32) {
      f32_count += 1;
    } else if (comp->root_instruction()->shape().element_type() == BF16) {
      bf16_count += 1;
      comp_has_bf16 = true;
    }
    for (auto* param : comp->parameter_instructions()) {
      if (param->shape().element_type() == F32) {
        f32_count += 1;
      } else if (param->shape().element_type() == BF16) {
        bf16_count += 1;
        comp_has_bf16 = true;
      }
    }
    if (comp_has_bf16) {
      bf16_called_comps.push_back(comp);
    }
  }

  // Resolve unsupported BF16 operands.
  for (int i = 0; i < hlo->operand_count(); ++i) {
    if (hlo->operand(i)->shape().element_type() == BF16 &&
        !bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
      bf16_count -= 1;
      f32_count += 1;
    }
  }

  // Resolve unsupported BF16 output.
  if (hlo->shape().element_type() == BF16 &&
      !bfloat16_support_->SupportsBF16Output(*hlo)) {
    TF_RETURN_IF_ERROR(
        ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
    bf16_count -= 1;
    f32_count += 1;
  }

  // Resolve unsupported mixed precision after resolving unsupported BF16
  // operands and output, because the numbers of BF16 operands/output and F32
  // operands/output may have changed.
  if (bfloat16_support_->SupportsMixedPrecisions(*hlo) || bf16_count == 0 ||
      f32_count == 0) {
    return Status::OK();
  }
  // See if we can change everything to BF16.
  if (hlo->called_computations().empty() &&
      hlo->shape().element_type() == BF16) {
    bool can_use_bf16 = true;
    for (int i = 0; i < hlo->operand_count(); ++i) {
      if (hlo->operand(i)->shape().element_type() == BF16) {
        continue;
      }
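      // An F32 operand may be demoted to BF16 if the instruction only needs
      // BF16 precision for it (or simply passes it through at the output
      // precision, which is BF16 here) and the backend accepts a BF16 operand.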
      if ((bfloat16_support_->EffectiveOperandPrecisionIsBF16(*hlo, i) ||
           bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
                                                                         i)) &&
          bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
        continue;
      }
      can_use_bf16 = false;
      break;
    }
    if (can_use_bf16) {
      for (int i = 0; i < hlo->operand_count(); ++i) {
        if (hlo->operand(i)->shape().element_type() == F32) {
          TF_RETURN_IF_ERROR(
              InsertConvertBeforeOperand(hlo, i, BF16, computation_));
        }
      }
      return Status::OK();
    }
  }
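  // Otherwise fall back to F32: convert the output and every BF16 operand,
  // and rewrite any BF16-using called computations.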
  if (hlo->shape().element_type() == BF16) {
    TF_RETURN_IF_ERROR(
        ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
  }
  for (int i = 0; i < hlo->operand_count(); ++i) {
    if (hlo->operand(i)->shape().element_type() == BF16) {
      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
    }
  }
  return ConvertCalledComputations(hlo, bf16_called_comps);
}

Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
  // Do not change instructions related to entry and exit of a computation,
  // tuples, fusion, convert, side-effecting instructions, and control flow.
  if (hlo->opcode() == HloOpcode::kTuple ||            //
      hlo->opcode() == HloOpcode::kGetTupleElement ||  //
      hlo->opcode() == HloOpcode::kConstant ||         //
      hlo->opcode() == HloOpcode::kParameter ||        //
      hlo->opcode() == HloOpcode::kFusion ||           //
      hlo->opcode() == HloOpcode::kConvert ||          //
      hlo->opcode() == HloOpcode::kCall ||             //
      hlo->opcode() == HloOpcode::kCustomCall ||       //
      hlo->opcode() == HloOpcode::kWhile ||            //
      hlo->opcode() == HloOpcode::kConditional ||      //
      hlo->HasSideEffectNoRecurse()) {
    return Status::OK();
  }
  // TODO(b/112040122): Correctly normalize variadic reduce.
  if ((hlo->opcode() == HloOpcode::kSort ||
       hlo->opcode() == HloOpcode::kAllReduce) &&
      hlo->shape().IsTuple()) {
    return HandleMultipleOutputs(hlo);
  }
  return HandleInstruction(hlo);
}

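// A minimal usage sketch (the pipeline name and the direct use of the base
// BFloat16Support class are illustrative assumptions; backends typically pass
// a backend-specific BFloat16Support subclass):
//
//   BFloat16Support bf16_support;
//   HloPassPipeline pipeline("bf16-normalization");
//   pipeline.AddPass<BFloat16Normalization>(&bf16_support);
//   TF_RETURN_IF_ERROR(pipeline.Run(module).status());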
StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeComputationPostOrder()) {
    if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2,
                 "BFloat16Normalization::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla