/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

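// Visitor that folds BF16 <-> F32 Convert instructions into the instructions
// that produce or consume them, based on which BF16 operands/outputs the
// backend's BFloat16Support reports as natively supported.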
class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit BFloat16ConversionFoldingVisitor(
      HloComputation* computation, const BFloat16Support* bfloat16_support)
      : computation_(computation), bfloat16_support_(bfloat16_support) {}

  Status DefaultAction(HloInstruction* hlo) override;

  // Special handling for all-reduce which can have a tuple output.
  Status HandleAllReduce(HloInstruction* crs) override;

  static bool Run(HloComputation* computation,
                  const BFloat16Support* bfloat16_support) {
    BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support);
    TF_CHECK_OK(computation->Accept(&visitor));
    return visitor.changed_;
  }

 private:
  // Checks if the HLO has a BF16 -> F32 conversion as an input, or an
  // F32 -> BF16 conversion as its output, and folds them into the HLO itself
  // if feasible.
  Status TryFoldBF16Conversions(HloInstruction* hlo);

  // Folds the F32 -> BF16 conversions from the HLO's output.
  //
  // Precondition: all of the HLO's users are F32 -> BF16 conversions.
  Status FoldOutputConversions(HloInstruction* hlo);

  // Folds the BF16 -> F32 conversion operand into the HLO.
  //
  // Precondition: the operand at operand_index is a BF16 -> F32 conversion.
  Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index);

  HloComputation* computation_;
  const BFloat16Support* bfloat16_support_;
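  // Set to true once any conversion has been folded in this computation.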
  bool changed_ = false;
};

Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
    HloInstruction* hlo) {
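  // Copy the user list up front: the folding below rewrites hlo's uses, which
  // mutates the list being iterated over.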
  std::vector<HloInstruction*> materialized_users = hlo->users();
  hlo->mutable_shape()->set_element_type(BF16);
  for (auto user : materialized_users) {
    CHECK_EQ(user->opcode(), HloOpcode::kConvert);
    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
    changed_ = true;
  }
  return Status::OK();
}

Status BFloat16ConversionFoldingVisitor::FoldOperandConversion(
    HloInstruction* hlo, int64 operand_index) {
  // The operand is a convert from BF16 to F32.
  auto operand = hlo->mutable_operand(operand_index);
  CHECK_EQ(operand->opcode(), HloOpcode::kConvert);
  TF_RETURN_IF_ERROR(
      hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0)));
  changed_ = true;
  return Status::OK();
}

namespace {

// Returns whether hlo has users and all users are conversions from F32 to
// BF16.
bool AllUsersAreF32ToBF16Converts(const HloInstruction* hlo) {
  if (hlo->user_count() == 0 || hlo->shape().element_type() != F32) {
    return false;
  }
  for (const auto user : hlo->users()) {
    if (user->opcode() == HloOpcode::kConvert &&
        user->shape().element_type() == BF16) {
      continue;
    }
    return false;
  }
  return true;
}

}  // namespace

Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
    HloInstruction* hlo) {
  std::vector<int64> bf16_to_f32_operands;
  bool has_other_f32_operands = false;
  for (int64 i = 0; i < hlo->operands().size(); ++i) {
    auto operand = hlo->operand(i);
    if (operand->shape().element_type() == F32) {
      if (operand->opcode() == HloOpcode::kConvert &&
          operand->operand(0)->shape().element_type() == BF16 &&
          bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
        // Operand is a convert from BF16 to F32 and we support BF16 input
        // directly in the current HLO at the operand index.
        bf16_to_f32_operands.push_back(i);
      } else {
        has_other_f32_operands = true;
      }
      continue;
    }
  }

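  // The output-side conversions can be folded only if every user converts the
  // F32 result back to BF16 and the HLO supports producing BF16 directly.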
  const bool fold_output_conversion =
      AllUsersAreF32ToBF16Converts(hlo) &&
      bfloat16_support_->SupportsBF16Output(*hlo);

  if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    if (has_other_f32_operands ||
        (!fold_output_conversion && hlo->shape().element_type() == F32)) {
      // Some of the operands/output will remain F32, but we cannot use mixed
      // precisions, so we cannot do anything here.
      return Status::OK();
    }
  }

  if (fold_output_conversion) {
    TF_RETURN_IF_ERROR(FoldOutputConversions(hlo));
  }

  for (int64 i : bf16_to_f32_operands) {
    TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i));
  }
  return Status::OK();
}

Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
  // Do not fold BF16 conversions for instructions related to tuples, entry and
  // exit of a computation, fusion, convert, side-effecting instructions and
  // control flow.
  if (hlo->opcode() == HloOpcode::kTuple ||            //
      hlo->opcode() == HloOpcode::kGetTupleElement ||  //
      hlo->opcode() == HloOpcode::kConstant ||         //
      hlo->opcode() == HloOpcode::kParameter ||        //
      hlo->opcode() == HloOpcode::kFusion ||           //
      hlo->opcode() == HloOpcode::kConvert ||          //
      hlo->opcode() == HloOpcode::kCall ||             //
      hlo->opcode() == HloOpcode::kCustomCall ||       //
      hlo->opcode() == HloOpcode::kWhile ||            //
      hlo->opcode() == HloOpcode::kConditional ||      //
      hlo->HasSideEffectNoRecurse()) {
    return Status::OK();
  }
  if (hlo == computation_->root_instruction() &&
      !bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    // If hlo is the root instruction, we cannot change its output, so folding
    // can only happen when it supports mixed precision so that we can change
    // its operands.
    return Status::OK();
  }
  return TryFoldBF16Conversions(hlo);
}

Status BFloat16ConversionFoldingVisitor::HandleAllReduce(HloInstruction* crs) {
  if (crs->IsCrossModuleAllReduce()) {
    // Cross-module all-reduce has side effects, so leave it unchanged.
    return Status::OK();
  }
  // First use DefaultAction() to handle the operands. It can't handle
  // tuple-shaped output.
  TF_RETURN_IF_ERROR(DefaultAction(crs));

  if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) {
    return Status::OK();
  }

  // If the output is not a tuple, we don't need special handling.
  if (!crs->shape().IsTuple()) {
    return Status::OK();
  }

  // If crs is the root instruction, we should keep its original output type.
  // The root instruction implicitly has a use from being the result of the
  // computation, and the code below does not take this use into account.
  if (crs == computation_->root_instruction()) {
    return Status::OK();
  }

  // Then do per-tuple-element handling on the output.
  std::vector<std::vector<HloInstruction*>> per_tuple_element_gtes(
      crs->operand_count());
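  // Group the get-tuple-element users by tuple index. If any user is not a
  // get-tuple-element, we cannot safely change the tuple's element types.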
  for (auto user : crs->users()) {
    if (user->opcode() != HloOpcode::kGetTupleElement) {
      return Status::OK();
    }
    per_tuple_element_gtes[user->tuple_index()].push_back(user);
  }

  for (int64 i = 0; i < crs->operand_count(); ++i) {
    // Fold conversions only when all the get-tuple-elements' users are
    // conversions from F32 to BF16.
    auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() {
      for (auto gte : per_tuple_element_gtes[i]) {
        if (!AllUsersAreF32ToBF16Converts(gte)) {
          return false;
        }
      }
      return true;
    };
    if (!all_gte_users_are_bf16_convert()) {
      continue;
    }

    ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i})
        ->set_element_type(BF16);
    for (auto gte : per_tuple_element_gtes[i]) {
      TF_RETURN_IF_ERROR(FoldOutputConversions(gte));
    }
  }

  return Status::OK();
}

StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla