/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"

#include <map>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

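// Visitor that walks one HloComputation and inserts F32<->BF16 conversions
// wherever an instruction uses BF16 in a way the backend (as described by the
// given BFloat16Support) cannot handle directly. For example, for a
// hypothetical backend that supports neither BF16 operands nor mixed
// precision for an add,
//   add(bf16 x, bf16 y) -> bf16
// would be rewritten to roughly
//   convert_bf16(add(convert_f32(x), convert_f32(y)) -> f32).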
class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit BFloat16NormalizationVisitor(HloComputation* computation,
                                        const BFloat16Support* bfloat16_support)
      : computation_(computation), bfloat16_support_(bfloat16_support) {}

  Status DefaultAction(HloInstruction* hlo) override;

  static bool Run(HloComputation* computation,
                  const BFloat16Support* bfloat16_support) {
    BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
    TF_CHECK_OK(computation->Accept(&visitor));
    return visitor.changed_;
  }

 private:
  // Checks if the HLO uses BF16 in an unsupported way, and if so, inserts
  // conversions between F32 and BF16 to make it supported.
  Status HandleInstruction(HloInstruction* hlo);

  // Handles instructions with tuple outputs by examining each output
  // independently.
  Status HandleMultipleOutputs(HloInstruction* hlo);

  // Inserts a conversion HLO that changes the given HLO's output type.
  Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to,
                                  HloComputation* computation);

  // Changes the output type to the specified type, then inserts a conversion
  // back to the original type.
  Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo,
                                               PrimitiveType to,
                                               HloComputation* computation);

  // Inserts a conversion HLO that changes the given HLO's operand type.
  Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx,
                                    PrimitiveType to,
                                    HloComputation* computation);

  // Inserts conversion HLOs to change the called computations' BF16
  // operands/outputs to F32.
  Status ConvertCalledComputations(
      HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps);

  HloComputation* computation_;
  const BFloat16Support* bfloat16_support_;
  bool changed_ = false;
};

Status BFloat16NormalizationVisitor::InsertConvertAfterOutput(
    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
  bool is_root = computation->root_instruction() == hlo;
  std::vector<HloInstruction*> materialized_users = hlo->users();
  // Use hlo's shape temporarily, in order to pass checks in ReplaceUseWith.
  auto convert = computation->AddInstruction(
      HloInstruction::CreateConvert(hlo->shape(), hlo));
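  // Redirect users to the new convert. A user that is itself a convert to F32
  // is bypassed instead (its uses are pointed at hlo directly), so we do not
  // build chains of redundant conversions.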
  for (auto* user : materialized_users) {
    if (user->opcode() == HloOpcode::kConvert &&
        user->shape().element_type() == F32) {
      TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
    } else {
      TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert));
    }
  }
  if (is_root) {
    computation->set_root_instruction(convert);
  }
  convert->mutable_shape()->set_element_type(to);
  changed_ = true;
  return Status::OK();
}

Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
  auto original_type = hlo->shape().element_type();
  hlo->mutable_shape()->set_element_type(to);
  return InsertConvertAfterOutput(hlo, original_type, computation);
}

Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
    HloInstruction* hlo, int64 operand_idx, PrimitiveType to,
    HloComputation* computation) {
  auto operand = hlo->mutable_operand(operand_idx);
  auto convert = computation->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(operand->shape(), to), operand));
  TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert));
  changed_ = true;
  return Status::OK();
}

Status BFloat16NormalizationVisitor::ConvertCalledComputations(
    HloInstruction* hlo, absl::Span<HloComputation* const> bf16_called_comps) {
  std::map<HloComputation*, HloComputation*> cloned_computations;
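  // Clone each called computation that uses BF16 before modifying it, so that
  // other callers of the original computation are not affected.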
  for (auto& comp : bf16_called_comps) {
    auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone());
    cloned_computations[comp] = cloned;
    changed_ = true;
  }
  hlo->ReplaceCalledComputations([&](HloComputation* comp) {
    auto it = cloned_computations.find(comp);
    if (it != cloned_computations.end()) {
      return it->second;
    }
    return comp;
  });
  for (auto& comp_pair : cloned_computations) {
    auto comp = comp_pair.second;
    if (comp->root_instruction()->shape().element_type() == BF16) {
      TF_RETURN_IF_ERROR(
          InsertConvertAfterOutput(comp->root_instruction(), F32, comp));
    }
    for (auto* param : comp->parameter_instructions()) {
      if (param->shape().element_type() == BF16) {
        // This changes the parameter to F32 then inserts a convert after it.
        TF_RETURN_IF_ERROR(
            ChangeOutputTypeThenInsertConvertBack(param, F32, comp));
      }
    }
  }
  return Status::OK();
}

Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
    HloInstruction* hlo) {
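  // The instructions handled here (sort/all-reduce with tuple shape) produce
  // one tuple element per operand, so operand i and tuple output i are treated
  // as a pair below.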
  std::vector<PrimitiveType> operand_types(hlo->operand_count());
  std::vector<PrimitiveType> output_types(hlo->operand_count());
  int64 f32_count = 0;
  int64 bf16_count = 0;
  bool has_unsupported_bf16_operand = false;
  bool has_unsupported_bf16_output = false;
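  // Count F32 and BF16 operands/outputs, and record whether any BF16 use is
  // unsupported by the backend.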
  for (int64 i = 0; i < hlo->operand_count(); ++i) {
    operand_types[i] = hlo->operand(i)->shape().element_type();
    output_types[i] = ShapeUtil::GetSubshape(hlo->shape(), {i}).element_type();
    if (operand_types[i] == F32) {
      f32_count += 1;
    } else if (operand_types[i] == BF16) {
      bf16_count += 1;
      if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
        has_unsupported_bf16_operand = true;
      }
    }
    if (output_types[i] == F32) {
      f32_count += 1;
    } else if (output_types[i] == BF16) {
      bf16_count += 1;
      if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
        has_unsupported_bf16_output = true;
      }
    }
  }

  if (bf16_count == 0) {
    return Status::OK();
  }

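  // A BF16 operand must be converted to F32 if the backend cannot consume it
  // as BF16, or if mixed precision is not supported and some other
  // operand/output is (or has to become) F32.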
  auto should_convert_operand = [&](int64 i) {
    if (operand_types[i] != BF16) {
      return false;
    }
    if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
      return true;
    }
    if (bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
      return false;
    }
    return has_unsupported_bf16_operand || has_unsupported_bf16_output ||
           f32_count > 0;
  };

  for (int64 i = 0; i < hlo->operand_count(); ++i) {
    if (should_convert_operand(i)) {
      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
      f32_count += 1;
      bf16_count -= 1;
    }
  }

  if (!has_unsupported_bf16_output &&
      (bfloat16_support_->SupportsMixedPrecisions(*hlo) || f32_count == 0 ||
       bf16_count == 0)) {
    return Status::OK();
  }

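  // Collect the called computations that produce or consume BF16; they will be
  // rewritten by ConvertCalledComputations at the end.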
  std::vector<HloComputation*> bf16_called_comps;
  for (auto* comp : hlo->called_computations()) {
    bool comp_has_bf16 = false;
    if (comp->root_instruction()->shape().element_type() == F32) {
      f32_count += 1;
    } else if (comp->root_instruction()->shape().element_type() == BF16) {
      bf16_count += 1;
      comp_has_bf16 = true;
    }
    for (auto* param : comp->parameter_instructions()) {
      if (param->shape().element_type() == F32) {
        f32_count += 1;
      } else if (param->shape().element_type() == BF16) {
        bf16_count += 1;
        comp_has_bf16 = true;
      }
    }
    if (comp_has_bf16) {
      bf16_called_comps.push_back(comp);
    }
  }

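  // Change each BF16 tuple element to F32 and rebuild the output: every such
  // element is read back with a get-tuple-element and converted to BF16 again,
  // so downstream users still see the original types.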
  std::vector<HloInstruction*> materialized_users = hlo->users();
  std::vector<HloInstruction*> output_elements(hlo->operand_count());
  auto original_shape = hlo->shape();
  for (int64 i = 0; i < hlo->operand_count(); ++i) {
    auto subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), {i});
    if (output_types[i] != BF16) {
      output_elements[i] = computation_->AddInstruction(
          HloInstruction::CreateGetTupleElement(*subshape, hlo, i));
      continue;
    }
    subshape->set_element_type(F32);
    auto gte = computation_->AddInstruction(
        HloInstruction::CreateGetTupleElement(*subshape, hlo, i));
    output_elements[i] =
        computation_->AddInstruction(HloInstruction::CreateConvert(
            ShapeUtil::ChangeElementType(*subshape, BF16), gte));
  }
  auto tuple = computation_->AddInstruction(
      HloInstruction::CreateTuple(output_elements));

  // Use hlo's shape temporarily, in order to pass checks in ReplaceUseWith.
  *tuple->mutable_shape() = hlo->shape();
  for (auto* user : materialized_users) {
    TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, tuple));
  }
  bool is_root = computation_->root_instruction() == hlo;
  if (is_root) {
    computation_->set_root_instruction(tuple);
  }
  *tuple->mutable_shape() = original_shape;
  return ConvertCalledComputations(hlo, bf16_called_comps);
}

Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
  int f32_count = 0;
  int bf16_count = 0;

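  // Count the F32 and BF16 types among this instruction's operands and output,
  // to decide later whether an unsupported mixed-precision combination remains.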
  for (int64 i = 0; i < hlo->operand_count(); ++i) {
    if (hlo->operand(i)->shape().element_type() == F32) {
      f32_count += 1;
    } else if (hlo->operand(i)->shape().element_type() == BF16) {
      bf16_count += 1;
    }
  }

  if (hlo->shape().element_type() == F32) {
    f32_count += 1;
  } else if (hlo->shape().element_type() == BF16) {
    bf16_count += 1;
  }

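  // Called computations (e.g. a reduce's reduction function) also contribute
  // their parameter and root types to the counts; remember the ones that use
  // BF16 so they can be rewritten if needed.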
  std::vector<HloComputation*> bf16_called_comps;
  for (auto* comp : hlo->called_computations()) {
    bool comp_has_bf16 = false;
    if (comp->root_instruction()->shape().element_type() == F32) {
      f32_count += 1;
    } else if (comp->root_instruction()->shape().element_type() == BF16) {
      bf16_count += 1;
      comp_has_bf16 = true;
    }
    for (auto* param : comp->parameter_instructions()) {
      if (param->shape().element_type() == F32) {
        f32_count += 1;
      } else if (param->shape().element_type() == BF16) {
        bf16_count += 1;
        comp_has_bf16 = true;
      }
    }
    if (comp_has_bf16) {
      bf16_called_comps.push_back(comp);
    }
  }

  // Resolve unsupported BF16 operands.
  for (int i = 0; i < hlo->operand_count(); ++i) {
    if (hlo->operand(i)->shape().element_type() == BF16 &&
        !bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
      bf16_count -= 1;
      f32_count += 1;
    }
  }

  // Resolve unsupported BF16 output.
  if (hlo->shape().element_type() == BF16 &&
      !bfloat16_support_->SupportsBF16Output(*hlo)) {
    TF_RETURN_IF_ERROR(
        ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
    bf16_count -= 1;
    f32_count += 1;
  }

  // Resolve unsupported mixed precision after resolving unsupported BF16
  // operands and output, because the numbers of BF16 operands/output and F32
  // operands/output may have changed.
  if (bfloat16_support_->SupportsMixedPrecisions(*hlo) || bf16_count == 0 ||
      f32_count == 0) {
    return Status::OK();
  }
  // See if we can change everything to BF16.
  if (hlo->called_computations().empty() &&
      hlo->shape().element_type() == BF16) {
    bool can_use_bf16 = true;
    for (int i = 0; i < hlo->operand_count(); ++i) {
      if (hlo->operand(i)->shape().element_type() == BF16) {
        continue;
      }
      if ((bfloat16_support_->EffectiveOperandPrecisionIsBF16(*hlo, i) ||
           bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
                                                                         i)) &&
          bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
        continue;
      }
      can_use_bf16 = false;
      break;
    }
    if (can_use_bf16) {
      for (int i = 0; i < hlo->operand_count(); ++i) {
        if (hlo->operand(i)->shape().element_type() == F32) {
          TF_RETURN_IF_ERROR(
              InsertConvertBeforeOperand(hlo, i, BF16, computation_));
        }
      }
      return Status::OK();
    }
  }
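  // Otherwise, fall back to F32: convert the output (if BF16) and every BF16
  // operand, and rewrite any called computations that still use BF16.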
  if (hlo->shape().element_type() == BF16) {
    TF_RETURN_IF_ERROR(
        ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
  }
  for (int i = 0; i < hlo->operand_count(); ++i) {
    if (hlo->operand(i)->shape().element_type() == BF16) {
      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
    }
  }
  return ConvertCalledComputations(hlo, bf16_called_comps);
}

Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
  // Do not change instructions related to entry and exit of a computation,
  // tuples, fusion, convert, side-effecting instructions, and control flow.
  if (hlo->opcode() == HloOpcode::kTuple ||            //
      hlo->opcode() == HloOpcode::kGetTupleElement ||  //
      hlo->opcode() == HloOpcode::kConstant ||         //
      hlo->opcode() == HloOpcode::kParameter ||        //
      hlo->opcode() == HloOpcode::kFusion ||           //
      hlo->opcode() == HloOpcode::kConvert ||          //
      hlo->opcode() == HloOpcode::kCall ||             //
      hlo->opcode() == HloOpcode::kCustomCall ||       //
      hlo->opcode() == HloOpcode::kWhile ||            //
      hlo->opcode() == HloOpcode::kConditional ||      //
      hlo->HasSideEffectNoRecurse()) {
    return Status::OK();
  }
  // TODO(b/112040122): Correctly normalize variadic reduce.
  if ((hlo->opcode() == HloOpcode::kSort ||
       hlo->opcode() == HloOpcode::kAllReduce) &&
      hlo->shape().IsTuple()) {
    return HandleMultipleOutputs(hlo);
  }
  return HandleInstruction(hlo);
}


StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeComputationPostOrder()) {
    if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2,
                 "BFloat16Normalization::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla