/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit BFloat16ConversionFoldingVisitor(
      HloComputation* computation, const BFloat16Support* bfloat16_support)
      : computation_(computation), bfloat16_support_(bfloat16_support) {}

  Status DefaultAction(HloInstruction* hlo) override;

  // Special handling for all-reduce which can have a tuple output.
  Status HandleAllReduce(HloInstruction* crs) override;

  static bool Run(HloComputation* computation,
                  const BFloat16Support* bfloat16_support) {
    BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support);
    TF_CHECK_OK(computation->Accept(&visitor));
    return visitor.changed_;
  }

 private:
  // Checks if the HLO has a BF16 -> F32 conversion as input, or a F32 -> BF16
  // conversion as output, and folds them to the HLO itself if feasible.
  Status TryFoldBF16Conversions(HloInstruction* hlo);

  // Folds the F32 -> BF16 conversions from the HLO's output.
  //
  // Precondition: all of the HLO's users are F32 -> BF16 conversions.
  Status FoldOutputConversions(HloInstruction* hlo);

  // Folds the BF16 -> F32 conversion operand to the HLO.
  //
  // Precondition: the operand is a BF16 -> F32 conversion.
  Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index);

  HloComputation* computation_;
  const BFloat16Support* bfloat16_support_;
  bool changed_ = false;
};

Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
    HloInstruction* hlo) {
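  // Fold by rewriting in place: iterate over a snapshot of the user list
  // (ReplaceAllUsesWith below adds new users to hlo), switch hlo's element
  // type to BF16, then route each convert user's consumers directly to hlo.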
  std::vector<HloInstruction*> materialized_users = hlo->users();
  hlo->mutable_shape()->set_element_type(BF16);
  for (auto user : materialized_users) {
    CHECK_EQ(user->opcode(), HloOpcode::kConvert);
    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
    changed_ = true;
  }
  return Status::OK();
}

Status BFloat16ConversionFoldingVisitor::FoldOperandConversion(
    HloInstruction* hlo, int64 operand_index) {
  // The operand is a convert from BF16 to F32.
  auto operand = hlo->mutable_operand(operand_index);
  CHECK_EQ(operand->opcode(), HloOpcode::kConvert);
  TF_RETURN_IF_ERROR(
      hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0)));
  changed_ = true;
  return Status::OK();
}

namespace {

// Returns whether hlo has users and all users are conversions from F32 to
// BF16.
bool AllUsersAreF32ToBF16Converts(const HloInstruction* hlo) {
  if (hlo->user_count() == 0 || hlo->shape().element_type() != F32) {
    return false;
  }
  for (const auto user : hlo->users()) {
    if (user->opcode() == HloOpcode::kConvert &&
        user->shape().element_type() == BF16) {
      continue;
    }
    return false;
  }
  return true;
}

}  // namespace

Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
    HloInstruction* hlo) {
  std::vector<int64> bf16_to_f32_operands;
  bool has_other_f32_operands = false;
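  // Classify each F32 operand: record converts from BF16 that this HLO can
  // consume directly as BF16 at that operand index; any other F32 operand
  // has to keep its precision.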
  for (int64 i = 0; i < hlo->operands().size(); ++i) {
    auto operand = hlo->operand(i);
    if (operand->shape().element_type() == F32) {
      if (operand->opcode() == HloOpcode::kConvert &&
          operand->operand(0)->shape().element_type() == BF16 &&
          bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
        // Operand is a convert from BF16 to F32 and we support BF16 input
        // directly in the current HLO at the operand index.
        bf16_to_f32_operands.push_back(i);
      } else {
        has_other_f32_operands = true;
      }
      continue;
    }
  }

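  // The output-side conversions can be folded only if every user is an
  // F32 -> BF16 convert and the backend supports BF16 output for this HLO.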
  const bool fold_output_conversion =
      AllUsersAreF32ToBF16Converts(hlo) &&
      bfloat16_support_->SupportsBF16Output(*hlo);

  if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    if (has_other_f32_operands ||
        (!fold_output_conversion && hlo->shape().element_type() == F32)) {
      // Some of the operands/output will remain F32, but we cannot use mixed
      // precisions, so we cannot do anything here.
      return Status::OK();
    }
  }

  if (fold_output_conversion) {
    TF_RETURN_IF_ERROR(FoldOutputConversions(hlo));
  }

  for (int64 i : bf16_to_f32_operands) {
    TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i));
  }
  return Status::OK();
}

Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
  // Do not fold BF16 conversions for instructions related to tuples, entry and
  // exit of a computation, fusion, convert, side-effecting instructions and
  // control flow.
  if (hlo->opcode() == HloOpcode::kTuple ||            //
      hlo->opcode() == HloOpcode::kGetTupleElement ||  //
      hlo->opcode() == HloOpcode::kConstant ||         //
      hlo->opcode() == HloOpcode::kParameter ||        //
      hlo->opcode() == HloOpcode::kFusion ||           //
      hlo->opcode() == HloOpcode::kConvert ||          //
      hlo->opcode() == HloOpcode::kCall ||             //
      hlo->opcode() == HloOpcode::kCustomCall ||       //
      hlo->opcode() == HloOpcode::kWhile ||            //
      hlo->opcode() == HloOpcode::kConditional ||      //
      hlo->HasSideEffectNoRecurse()) {
    return Status::OK();
  }
  if (hlo == computation_->root_instruction() &&
      !bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    // If hlo is the root instruction, we cannot change its output, so folding
    // can only happen when it supports mixed precision so that we can change
    // its operands.
    return Status::OK();
  }
  return TryFoldBF16Conversions(hlo);
}

Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
  if (crs->IsCrossModuleAllReduce()) {
    // Cross-module all-reduce has side effect.
    return Status::OK();
  }
  // First use DefaultAction() to handle the operands. It can't handle
  // tuple-shaped output.
  TF_RETURN_IF_ERROR(DefaultAction(crs));

  if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) {
    return Status::OK();
  }

  // If the output is not a tuple, we don't need special handling.
  if (!crs->shape().IsTuple()) {
    return Status::OK();
  }

  // If crs is the root instruction, we should keep its original output type.
  // The root instruction implicitly has a use from being the result of the
  // computation, and the code below does not take this use into account.
  if (crs == computation_->root_instruction()) {
    return Status::OK();
  }

  // Then do per-tuple-element handling on the output.
  std::vector<std::vector<HloInstruction*>> per_tuple_element_gtes(
      crs->operand_count());
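  // Group the all-reduce's users by tuple index. If any user is not a
  // get-tuple-element, bail out, since we cannot track per-element uses.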
  for (auto user : crs->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      return Status::OK();
    }
    per_tuple_element_gtes[user->tuple_index()].push_back(user);
  }

  for (int64 i = 0; i < crs->operand_count(); ++i) {
    // Fold conversions only when all the get-tuple-elements' users are
    // conversions from F32 to BF16.
    auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() {
      for (auto gte : per_tuple_element_gtes[i]) {
        if (!AllUsersAreF32ToBF16Converts(gte)) {
          return false;
        }
      }
      return true;
    };
    if (!all_gte_users_are_bf16_convert()) {
      continue;
    }

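    // Every user of this tuple element is an F32 -> BF16 convert, so switch
    // the element type to BF16 and fold the output conversions on each GTE.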
    ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i})
        ->set_element_type(BF16);
    for (auto gte : per_tuple_element_gtes[i]) {
      TF_RETURN_IF_ERROR(FoldOutputConversions(gte));
    }
  }

  return Status::OK();
}

StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla