/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/map_util.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_dce.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
28 #include "tensorflow/compiler/xla/shape_tree.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/core/lib/gtl/cleanup.h"
31 #include "tensorflow/core/platform/logging.h"
32
33 namespace xla {
34
BFloat16Propagation(const BFloat16Support * bfloat16_support)35 BFloat16Propagation::BFloat16Propagation(
36 const BFloat16Support* bfloat16_support)
37 : bfloat16_support_(bfloat16_support) {}
38
DetermineFusionComputationPrecision(HloInstruction * fusion)39 void BFloat16Propagation::DetermineFusionComputationPrecision(
40 HloInstruction* fusion) {
41 CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
42 if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
43 return;
44 }
45
46 // We are depending on the fusion node itself having already been analyzed
47 // for whether it can output BF16 and this has been adjusted in the output
48 // shape, and now we're looking to update the interior of the fusion node to
49 // match the new output shape, as well as recursively process the whole fusion
50 // node even if the output shape was not modified.
51 auto root = fusion->fused_instructions_computation()->root_instruction();
52
53 // Adjust root's element types according to the fusion's output shape.
54 ShapeUtil::ForEachSubshape(
55 root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
56 if (subshape.element_type() != F32) {
57 return;
58 }
59 if (OutputTypeAfterChange(fusion, index) == BF16) {
60 AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
61 VLOG(2) << "Fused root " << root->ToString() << " at shape index "
62 << index << " changed to BF16 precision for fusion "
63 << fusion->ToString();
64 }
65 });
66
67 // Propagate BF16 in the fusion computation.
68 auto insts =
69 fusion->fused_instructions_computation()->MakeInstructionPostOrder();
70 for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
71 DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
72 }
73 computations_visited_in_backward_pass_.insert(
74 fusion->fused_instructions_computation());
75
76 RevertIfFusionInternalBF16Changes(fusion);
77 }
78
RevertIfFusionInternalBF16Changes(HloInstruction * fusion)79 void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
80 HloInstruction* fusion) {
81 auto has_changes = [this](HloInstruction* inst) {
82 auto it = changes_to_bf16_.find(inst);
83 return it != changes_to_bf16_.end() && !it->second.empty();
84 };
85
86 auto root = fusion->fused_instructions_computation()->root_instruction();
87 absl::flat_hash_set<const HloValue*> changed_root_buffers;
88
89 auto root_changes_it = changes_to_bf16_.find(root);
90 if (root_changes_it != changes_to_bf16_.end()) {
91 for (const auto& entry : root_changes_it->second) {
92 for (const HloValue* value :
93 dataflow_->GetValueSet(root, entry.second).values()) {
94 changed_root_buffers.insert(value);
95 }
96 }
97 }
98
99 auto aliases_changed_root_buffer =
100 [this, &changed_root_buffers](const HloInstruction* inst) {
101 bool aliasing = false;
102 ShapeUtil::ForEachSubshape(
103 inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
104 if (aliasing) {
105 // Skip if aliasing is already found.
106 return;
107 }
108 // Only F32 buffers are considered for changing to BF16 in this
109 // pass.
110 if (subshape.element_type() != F32) {
111 return;
112 }
113 for (const HloValue* value :
114 dataflow_->GetValueSet(inst, index).values()) {
115 if (ContainsKey(changed_root_buffers, value)) {
116 aliasing = true;
117 break;
118 }
119 }
120 });
121 return aliasing;
122 };
123
124 for (auto inst :
125 fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
126 if (inst->opcode() == HloOpcode::kParameter) {
127 continue;
128 }
129 if (aliases_changed_root_buffer(inst)) {
130 continue;
131 }
132 if (inst->opcode() == HloOpcode::kFusion) {
133 bool parameter_reverted = false;
134 for (int64 i = 0; i < inst->operand_count(); ++i) {
135 if (has_changes(inst->mutable_operand(i))) {
136 // Changes on the operand have not been reverted.
137 continue;
138 }
139 auto* fused_parameter = inst->fused_parameter(i);
140 if (has_changes(fused_parameter)) {
141 changes_to_bf16_.erase(fused_parameter);
142 parameter_reverted = true;
143 }
144 }
145 if (parameter_reverted) {
146 RevertIfFusionInternalBF16Changes(inst);
147 }
148 }
149 if (!has_changes(inst)) {
150 continue;
151 }
152 bool revert_changes = true;
153 for (auto operand : inst->operands()) {
154 if (has_changes(operand)) {
155 revert_changes = false;
156 break;
157 }
158 }
159 if (revert_changes) {
160 changes_to_bf16_.erase(inst);
161 }
162 }
163 }
164
DetermineWhileComputationsPrecision(HloInstruction * while_hlo)165 void BFloat16Propagation::DetermineWhileComputationsPrecision(
166 HloInstruction* while_hlo) {
167 CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
168
169 // We are depending on the while node itself having already been analyzed for
170 // whether it can output BF16 and this has been adjusted in the output shape,
171 // and now we're looking to update the body and condition computations to
172 // match the new output shape, as well as recursively process the whole while
173 // node even if the output shape was not modified.
174 HloComputation* body = while_hlo->while_body();
175 auto body_root = body->root_instruction();
176 HloComputation* condition = while_hlo->while_condition();
177
178 ShapeUtil::ForEachSubshape(
179 body_root->shape(), [this, while_hlo, body_root](
180 const Shape& subshape, const ShapeIndex& index) {
181 if (subshape.element_type() != F32) {
182 return;
183 }
184 if (OutputTypeAfterChange(while_hlo, index) == BF16) {
185 AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16);
186 VLOG(2) << "While body root " << body_root->ToString()
187 << " at shape index " << index
188 << " changed to BF16 precision for while "
189 << while_hlo->ToString();
190 }
191 });
192
193 auto body_insts = body->MakeInstructionPostOrder();
194 for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
195 ++inst_it) {
196 DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
197 }
198 computations_visited_in_backward_pass_.insert(body);
199
200 auto condition_insts = condition->MakeInstructionPostOrder();
201 for (auto inst_it = condition_insts.rbegin();
202 inst_it != condition_insts.rend(); ++inst_it) {
203 DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
204 }
205 computations_visited_in_backward_pass_.insert(condition);
206 }
207
DetermineConditionalComputationsPrecision(HloInstruction * cond)208 void BFloat16Propagation::DetermineConditionalComputationsPrecision(
209 HloInstruction* cond) {
210 CHECK_EQ(cond->opcode(), HloOpcode::kConditional);
211 for (int64 i = 0; i < cond->branch_count(); ++i) {
212 auto branch = cond->branch_computation(i);
213 auto root = branch->root_instruction();
214 ShapeUtil::ForEachSubshape(
215 root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
216 if (subshape.element_type() != F32) {
217 return;
218 }
219 if (OutputTypeAfterChange(cond, index) == BF16) {
220 AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
221 VLOG(2) << "Conditional branch " << i << " root "
222 << root->ToString() << " at shape index " << index
223 << " changed to BF16 precision for conditional "
224 << cond->ToString();
225 }
226 });
227 auto insts = branch->MakeInstructionPostOrder();
228 for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
229 DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
230 }
231 computations_visited_in_backward_pass_.insert(branch);
232 }
233 }
234
AllUsersConsumeBF16(const HloInstruction & hlo,const ShapeIndex & index) const235 bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
236 const ShapeIndex& index) const {
237 // If the subshape isn't floating point then none of the users will be BF16.
238 const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index);
239 if (subshape.element_type() != BF16 && subshape.element_type() != F32) {
240 return false;
241 }
242
243 auto& value_set = dataflow_->GetValueSet(&hlo, index);
244 for (const HloValue* value : value_set.values()) {
245 if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
246 return false;
247 }
248 // We use the original type for the value because we are going to examine
249 // the uses of it, instead of the value itself. If ValueTypeAfterChange()
250 // were used, it would cause problems when there are aliasing buffers, i.e.,
251 // ResolveInconsistencyOfAliasingBuffers() would fail to revert the
252 // tentative change to BF16 even if the uses require F32.
253 if (value->shape().element_type() == BF16) {
254 continue;
255 }
256 for (const HloUse& use : value->uses()) {
257 if (!ContainsKey(instructions_visited_in_backward_pass_,
258 use.instruction)) {
259 // We don't know yet whether use.instruction will consume BF16 since it
260 // hasn't been visited. Although we visit instructions in reverse
261 // topological order, this is still possible because there may be
262 // unvisited instruction that alias the same buffer. In this case, we
263 // aggressively skip this use, and if this causes inconsistency (e.g.,
264 // one use is in BF16 but another use is in F32), it will be resolved at
265 // the end of the BFloat16Propagation pass.
266 continue;
267 }
268 if (use.instruction->HasSideEffectNoRecurse()) {
269 // Keep side-effecting instruction's operands unchanged.
270 return false;
271 }
272 // Any visited user that can accept BF16 has already been updated if
273 // necessary, e.g., the output has been changed to BF16 if it propagates
274 // precision, or a called computation's parameters have been changed to
275 // BF16 for fusions or whiles.
276 if (use.instruction->opcode() == HloOpcode::kFusion) {
277 auto* fused_parameter =
278 use.instruction->fused_parameter(use.operand_number);
279 if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) {
280 return false;
281 }
282 continue;
283 } else if (use.instruction->opcode() == HloOpcode::kWhile) {
284 auto* cond_parameter =
285 use.instruction->while_condition()->parameter_instruction(
286 use.operand_number);
287 if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
288 return false;
289 }
290 auto* body_parameter =
291 use.instruction->while_body()->parameter_instruction(
292 use.operand_number);
293 if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) {
294 return false;
295 }
296 continue;
297 } else if (use.instruction->opcode() == HloOpcode::kConditional) {
298 auto* cond_parameter =
299 use.instruction->branch_computation(use.operand_number - 1)
300 ->parameter_instruction(0);
301 if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
302 return false;
303 }
304 continue;
305 }
306 if (bfloat16_support_->EffectiveOperandPrecisionIsBF16(
307 *use.instruction, use.operand_number)) {
308 continue;
309 }
310 // If the op propagates precision and it outputs a BF16, then it's OK to
311 // supply BF16 also as the input. In the backward pass, the users shapes
312 // should have already been processed.
313 if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
314 *use.instruction, use.operand_number)) {
315 if (use.instruction->opcode() == HloOpcode::kTuple ||
316 (use.instruction->opcode() == HloOpcode::kAllReduce &&
317 use.instruction->shape().IsTuple())) {
318 ShapeIndex use_output_index{use.operand_number};
319 for (int64 i : use.operand_index) {
320 use_output_index.push_back(i);
321 }
322 if (OutputTypeAfterChange(use.instruction, use_output_index) ==
323 BF16) {
324 continue;
325 }
326 } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) {
327 ShapeIndex use_output_index;
328 for (int64 i = 1; i < use.operand_index.size(); ++i) {
329 use_output_index.push_back(use.operand_index[i]);
330 }
331 if (OutputTypeAfterChange(use.instruction, use_output_index) ==
332 BF16) {
333 continue;
334 }
335 } else {
336 if (OutputTypeAfterChange(use.instruction, use.operand_index) ==
337 BF16) {
338 continue;
339 }
340 }
341 }
342 return false;
343 }
344 }
345 return true;
346 }
347
348 namespace {
349
350 // Returns whether we should avoid changing the precision of inst regardless of
351 // the producers and users.
ShouldKeepPrecisionUnchanged(const HloInstruction * inst)352 bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst) {
353 if (inst->opcode() == HloOpcode::kFusion &&
354 inst->fusion_kind() == HloInstruction::FusionKind::kCustom) {
355 return ShouldKeepPrecisionUnchanged(
356 inst->fused_instructions_computation()->root_instruction());
357 }
358 // Do not change precision for side-effecting instructions, control flow, and
359 // bitcast-convert, because this pass might break the interfaces or
360 // assumptions for them.
361 return inst->opcode() == HloOpcode::kCustomCall || //
362 inst->opcode() == HloOpcode::kCall || //
363 inst->opcode() == HloOpcode::kBitcastConvert || //
364 inst->HasSideEffectNoRecurse();
365 }
366
367 } // namespace
368
DetermineInstructionPrecision(HloInstruction * hlo,bool skip_parameters)369 void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
370 bool skip_parameters) {
371 // We handle any fusion computation, while body/condition or conditional
372 // branches after the instruction is handled, because we need to know the
373 // output shape of a fusion or while before propagating inside its
374 // computations.
375 bool postpone_processing_called_computations = false;
376 auto cleaner = tensorflow::gtl::MakeCleanup(
377 [this, hlo, &postpone_processing_called_computations] {
378 if (!postpone_processing_called_computations) {
379 if (hlo->opcode() == HloOpcode::kFusion) {
380 DetermineFusionComputationPrecision(hlo);
381 } else if (hlo->opcode() == HloOpcode::kWhile) {
382 DetermineWhileComputationsPrecision(hlo);
383 } else if (hlo->opcode() == HloOpcode::kConditional) {
384 DetermineConditionalComputationsPrecision(hlo);
385 }
386 }
387 instructions_visited_in_backward_pass_.insert(hlo);
388 });
389
390 if (hlo->opcode() == HloOpcode::kWhile &&
391 (caller_counts_[hlo->while_condition()] > 1 ||
392 caller_counts_[hlo->while_body()] > 1)) {
393 postpone_processing_called_computations = true;
394 return;
395 }
396
397 if (hlo->opcode() == HloOpcode::kConditional &&
398 absl::c_any_of(hlo->branch_computations(), [&](const HloComputation* c) {
399 return caller_counts_[c] > 1;
400 })) {
401 postpone_processing_called_computations = true;
402 return;
403 }
404
405 // Prevent root instructions from having their output modified by recording
406 // all F32 output values as needing to stay as F32.
407 CHECK(hlo->parent() != nullptr);
408 if (hlo == hlo->parent()->root_instruction()) {
409 if (!hlo->parent()->IsFusionComputation()) {
410 ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */,
411 const ShapeIndex& index) {
412 if (OutputTypeAfterChange(hlo, index) != F32) {
413 return;
414 }
415 for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
416 // Since we use HloValues from the dataflow analysis, this can also
417 // affect HLO instructions beyond the root, e.g., if the root is a
418 // Tuple HLO, then its operands are also affected.
419 values_that_must_be_kept_as_f32_.insert(value);
420 }
421 });
422 }
423 return;
424 }
425
426 if (ShouldKeepPrecisionUnchanged(hlo) ||
427 (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
428 return;
429 }
430
431 if (!ContainsKey(consider_using_bfloat16_, hlo)) {
432 return;
433 }
434
435 if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
436 return;
437 }
438
439 ShapeUtil::ForEachSubshape(
440 hlo->shape(),
441 [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) {
442 if (OutputTypeAfterChange(hlo, index) == F32 &&
443 AllUsersConsumeBF16(*hlo, index)) {
444 AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16);
445 VLOG(2) << "HloInstruction output at shape index " << index
446 << " changed to BF16 precision: " << hlo->ToString();
447 }
448 });
449 }
450
InstructionIsCandidateForBF16Output(HloInstruction * hlo)451 bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
452 HloInstruction* hlo) {
453 if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) &&
454 hlo->opcode() != HloOpcode::kTuple &&
455 hlo->opcode() != HloOpcode::kGetTupleElement &&
456 hlo->opcode() != HloOpcode::kDomain &&
457 hlo->shape().element_type() != BF16) {
458 for (int64 i = 0; i < hlo->operand_count(); ++i) {
459 if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
460 i) ||
461 !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) {
462 return false;
463 }
464 }
465 }
466 return true;
467 }
468
AdjustCalledComputationParameters(HloInstruction * hlo)469 void BFloat16Propagation::AdjustCalledComputationParameters(
470 HloInstruction* hlo) {
471 auto adjust_computation =
472 [this, hlo](HloComputation* computation,
473 absl::Span<HloInstruction* const> operands) {
474 // Adjust parameters.
475 CHECK_EQ(operands.size(), computation->num_parameters());
476 for (int64 i = 0; i < operands.size(); ++i) {
477 auto parameter = computation->parameter_instruction(i);
478 ShapeUtil::ForEachSubshape(
479 parameter->shape(),
480 [this, i, hlo, &operands, parameter](const Shape& /* subshape */,
481 const ShapeIndex& index) {
482 if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
483 return;
484 }
485 PrimitiveType operand_type =
486 OutputTypeAfterChange(operands[i], index);
487 if (OutputTypeAfterChange(parameter, index) == operand_type) {
488 return;
489 }
490 AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type);
491 VLOG(2) << "Called computation parameter "
492 << parameter->ToString() << " at shape index " << index
493 << " adjusted to "
494 << (operand_type == BF16 ? "BF16" : "F32")
495 << " to match operand in HLO " << hlo->ToString();
496 });
497 }
498 };
499
500 switch (hlo->opcode()) {
501 case HloOpcode::kFusion:
502 adjust_computation(hlo->fused_instructions_computation(),
503 hlo->operands());
504 break;
505 case HloOpcode::kWhile:
506 adjust_computation(hlo->while_condition(), hlo->operands());
507 adjust_computation(hlo->while_body(), hlo->operands());
508 break;
509 case HloOpcode::kConditional:
510 for (int64 i = 0; i < hlo->branch_count(); ++i) {
511 adjust_computation(hlo->branch_computation(i),
512 {hlo->mutable_operand(i + 1)});
513 }
514 break;
515 default:
516 break;
517 }
518 }
519
AdjustCalledComputationRoot(HloInstruction * hlo)520 void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
521 auto adjust_computation = [this, hlo](HloComputation* computation,
522 HloInstruction* output) {
523 // Adjust root.
524 HloInstruction* root = computation->root_instruction();
525 ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output](
526 const Shape& /* subshape */,
527 const ShapeIndex& index) {
528 if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
529 return;
530 }
531 const PrimitiveType output_type = OutputTypeAfterChange(output, index);
532 if (OutputTypeAfterChange(root, index) == output_type) {
533 return;
534 }
535 AddToOrRemoveFromBF16ChangeSet(root, index, output_type);
536 // It's possible that output_type is F32, but the root instruction's
537 // type is BF16; e.g., a fusion node's output was changed to BF16
538 // initially but then adjusted back to F32, and the fusion computation
539 // is now being adjusted after the fusion node.
540 if (output_type == F32) {
541 for (const auto* value : dataflow_->GetValueSet(root, index).values()) {
542 // We rely on the fact that this adjustment works in reverse
543 // topological order so that called computation will be
544 // processed later. Adding the value to
545 // values_that_must_be_kept_as_f32_ will ensure the
546 // correctness of the adjustment for HLOs that will be
547 // processed later.
548 values_that_must_be_kept_as_f32_.insert(value);
549 }
550 }
551 VLOG(2) << "Called computation root " << root->ToString()
552 << " at shape index " << index << " adjusted to "
553 << (output_type == BF16 ? "BF16" : "F32")
554 << " to match output shape of " << hlo->ToString();
555 });
556 };
557
558 switch (hlo->opcode()) {
559 case HloOpcode::kFusion:
560 adjust_computation(hlo->fused_instructions_computation(), hlo);
561 break;
562 case HloOpcode::kWhile:
563 adjust_computation(hlo->while_body(), hlo);
564 break;
565 case HloOpcode::kConditional:
566 for (auto* branch : hlo->branch_computations()) {
567 adjust_computation(branch, hlo);
568 }
569 break;
570 default:
571 break;
572 }
573 }
574
ResolveInconsistencyOfAliasingBuffersHelper(HloComputation * computation,absl::flat_hash_set<const HloComputation * > * visited_computations)575 bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
576 HloComputation* computation,
577 absl::flat_hash_set<const HloComputation*>* visited_computations) {
578 bool parameter_changed = false;
579 auto insts = computation->MakeInstructionPostOrder();
580 // Do the adjustment on each instruction in the computation in reverse
581 // topological order.
582 while (true) {
583 bool any_change = false;
584 for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
585 auto hlo = *inst_it;
586 auto adjust_hlo_output = [&](const Shape& /* subshape */,
587 const ShapeIndex& index) {
588 auto output_type = OutputTypeAfterChange(hlo, index);
589 VLOG(2) << "output_type is " << ((output_type == BF16) ? "BF16" : "F32")
590 << " for :" << hlo->ToString() << "\n";
591 if (output_type != F32 && output_type != BF16) {
592 return;
593 }
594 PrimitiveType type = BF16;
595 for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
596 auto value_type = ValueTypeAfterChange(value);
597 if (value_type == BF16) {
598 continue;
599 }
600 VLOG(2) << "Adjust to F32 due to aliased dataflow value: "
601 << value->ToString() << "\n";
602 CHECK_EQ(value_type, F32);
603 type = F32;
604 break;
605 }
606 // In order to find aliases due to in-place operations, use
607 // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
608 // but this code works with HloModules that aren't ready yet to use
609 // HloAliasAnalysis (e.g., their computation graphs may not have been
610 // flattened yet).
611 for (const auto& operand_and_output_index :
612 HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
613 if (operand_and_output_index.second == index) {
614 const HloUse& operand = operand_and_output_index.first;
615 for (const auto* value :
616 dataflow_
617 ->GetValueSet(hlo->operand(operand.operand_number),
618 operand.operand_index)
619 .values()) {
620 auto value_type = ValueTypeAfterChange(value);
621 if (value_type == BF16) {
622 continue;
623 }
624 VLOG(2) << "Adjust to F32 due to InputOutPair: "
625 << value->ToString() << "\n";
626 CHECK_EQ(value_type, F32);
627 type = F32;
628 break;
629 }
630 }
631 }
632
633 // It's possible that a user has been changed from BF16 to F32
634 // during this final adjustment pass, so we need to check
635 // AllUsersConsumeBF16() again.
636 if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
637 VLOG(2) << "Adjust to F32 due to All user consumeBF16 fail\n";
638 type = F32;
639 }
640 if (type == F32) {
641 for (const auto* value :
642 dataflow_->GetValueSet(hlo, index).values()) {
643 // We rely on the fact that this adjustment works in reverse
644 // topological order. Adding the value to
645 // values_that_must_be_kept_as_f32_ will ensure the correctness
646 // of the adjustment for HLOs that will be processed later.
647 values_that_must_be_kept_as_f32_.insert(value);
648 }
649 }
650 if (type != output_type) {
651 any_change = true;
652 AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
653 VLOG(2) << "HloInstruction output at shape index " << index
654 << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
655 << hlo->ToString();
656 if (hlo->opcode() == HloOpcode::kParameter) {
657 parameter_changed = true;
658 }
659 }
660 };
661 ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
662 AdjustCalledComputationRoot(hlo);
663 if (hlo->opcode() == HloOpcode::kWhile) {
664 // We need to run on the while body and condition repeatedly until a
665 // fixed point is reached, i.e., the parameters do not change any more.
666 // We may need more than one iteration because the while input and
667 // output alias each other, so changing one input parameter requires
668 // changing the corresponding output element and thus may transitively
669 // require changing another input parameter. A fixed point will be
670 // reached because the parameters can only be changed from BF16 to F32,
671 // not the other way around.
672 absl::flat_hash_set<const HloComputation*> visited_in_while;
673 while (ResolveInconsistencyOfAliasingBuffersHelper(
674 hlo->while_condition(), &visited_in_while) ||
675 ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
676 &visited_in_while)) {
677 visited_in_while.clear();
678 ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
679 AdjustCalledComputationRoot(hlo);
680 }
681 visited_computations->insert(visited_in_while.begin(),
682 visited_in_while.end());
683 } else if (hlo->opcode() == HloOpcode::kFusion) {
684 ResolveInconsistencyOfAliasingBuffersHelper(
685 hlo->fused_instructions_computation(), visited_computations);
686 } else if (hlo->opcode() == HloOpcode::kConditional) {
687 for (auto* branch : hlo->branch_computations()) {
688 ResolveInconsistencyOfAliasingBuffersHelper(branch,
689 visited_computations);
690 }
691 }
692 }
693 if (!any_change) {
694 break;
695 }
696 }
697 // Now adjust parameters of called computations.
698 for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
699 AdjustCalledComputationParameters(*inst_it);
700 }
701 return parameter_changed;
702 }
703
ResolveInconsistencyOfAliasingBuffers(HloModule * module)704 void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
705 HloModule* module) {
706 const auto& computations_topological_order =
707 module->MakeComputationPostOrder();
708 absl::flat_hash_set<const HloComputation*> resolved;
709 for (auto comp_it = computations_topological_order.rbegin();
710 comp_it != computations_topological_order.rend(); ++comp_it) {
711 if (ContainsKey(resolved, *comp_it)) {
712 continue;
713 }
714 ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
715 }
716 }
717
ResolveInconsistentFusions(HloModule * module)718 Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
719 // We could have changed a fusion computation's root shape to have a different
720 // precision than the fusion node's output, if the fusion root does not
721 // define a buffer (e.g., a tuple). Now we add conversions after such fusion
722 // roots to make them match the fusion output. If the fusion output is a
723 // (possibly nested) tuple, we first create get-tuple-elements, then convert
724 // the unmatching leaf nodes, and finally create a new tuple as the fusion
725 // computation's root. If tuples and get-tuple-elements are created, we will
726 // run tuple simplifier and dead code elimination at the end (dead code is not
727 // allowed in fusion computation). E.g.,
728 //
729 // (1) (2) (3)
730 // a b a b a b
731 // |\ | |\ | |\ |
732 // \ add -> |add -> | add
733 // \ | \ | convert |
734 // tuple tuple \ |
735 // / \ tuple
736 // gte gte
737 // | |
738 // convert |
739 // \ /
740 // tuple
741 // (1) a is F32 but tuple is BF16
742 // (2) after adding conversion
743 // (3) after tuple simplifier and DCE.
744 for (auto computation : module->MakeComputationPostOrder()) {
745 auto insts = computation->MakeInstructionPostOrder();
746 for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
747 auto hlo = *inst_it;
748 if (hlo->opcode() != HloOpcode::kFusion) {
749 continue;
750 }
751 auto fusion_computation = hlo->fused_instructions_computation();
752 auto fusion_root = fusion_computation->root_instruction();
753 if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) {
754 continue;
755 }
756 ShapeTree<HloInstruction*> converted_outputs(hlo->shape());
757 // Deep copy the fusion root, and convert a leaf node only if its shape
758 // does not match the fusion output.
759 TF_ASSIGN_OR_RETURN(
760 HloInstruction * copy,
761 fusion_computation->DeepCopyInstructionWithCustomCopier(
762 fusion_root,
763 [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index,
764 HloComputation* comp) {
765 const Shape& hlo_subshape =
766 ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
767 if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) {
768 return leaf;
769 }
770 return comp->AddInstruction(
771 HloInstruction::CreateConvert(hlo_subshape, leaf));
772 }));
773 fusion_computation->set_root_instruction(copy);
774 }
775 }
776 return Status::OK();
777 }
778
ResolveConvertedConstants(HloModule * module)779 Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
780 // We may have converted some constants from F32 to BF16, so adjust the
781 // constant literals in such cases. We do this here instead of when the
782 // constant node's is changed because 1) the HloInstruction interface does not
783 // allow resetting the literal so we have to create a new kConstant
784 // instruction to replace the old one, which invalidates dataflow analysis,
785 // and 2) it's possible that a kConstant's output gets changed to BF16 at the
786 // beginning but later on adjusted back to F32, so converting literals here
787 // can avoid repeated conversions.
788 //
789 // TODO(b/73833576): Consider resetting literal in HloInstruction.
790 for (auto computation : module->MakeComputationPostOrder()) {
791 for (auto hlo : computation->MakeInstructionPostOrder()) {
792 if (hlo->opcode() != HloOpcode::kConstant) {
793 continue;
794 }
795 if (!Shape::Equal().MinorToMajorOnlyInLayout()(hlo->literal().shape(),
796 hlo->shape())) {
797 TF_ASSIGN_OR_RETURN(auto converted_literal,
798 hlo->literal().ConvertToShape(hlo->shape()));
799 auto new_constant = computation->AddInstruction(
800 HloInstruction::CreateConstant(std::move(converted_literal)));
801 UpdateLayout(new_constant->mutable_shape());
802 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
803 }
804 }
805 }
806 return Status::OK();
807 }
808
SkipNoopConversions(HloModule * module)809 Status BFloat16Propagation::SkipNoopConversions(HloModule* module) {
810 for (auto computation : module->computations()) {
811 for (auto hlo : computation->MakeInstructionPostOrder()) {
812 if (hlo->opcode() != HloOpcode::kConvert) {
813 continue;
814 }
815 auto source = hlo->mutable_operand(0);
816 if (!ShapeUtil::Equal(source->shape(), hlo->shape())) {
817 continue;
818 }
819 const bool is_root = hlo == computation->root_instruction();
820 TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source));
821 if (is_root) {
822 computation->set_root_instruction(source);
823 }
824 }
825 }
826 return Status::OK();
827 }
828
// The algorithm first does a forward pass (parameters to root) to determine a
// set of instructions to consider using bfloat16, then does a backward pass to
// determine the precisions of those instructions according to the need of
// their users. During the backward pass, the potential changes are stored in
// changes_to_bf16_ which are subject to further adjustments then applied to the
// HLOs.
StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
  // Reset all per-run state so the pass can be run repeatedly on different
  // (or re-optimized) modules.
  consider_using_bfloat16_.clear();
  instructions_visited_in_backward_pass_.clear();
  computations_visited_in_backward_pass_.clear();
  values_that_must_be_kept_as_f32_.clear();
  caller_counts_.clear();
  changes_to_bf16_.clear();
  changed_ = false;

  auto computations_topological_order = module->MakeComputationPostOrder();

  // Before running the propagation pass, we insert copies (kConvert to the same
  // type) of F32 inputs to while loops. This prevents other uses of the same
  // input from aliasing the while loop input/output, so that there's greater
  // chance to use BF16 inside the loop. If some of these added copies do not
  // help, they will remain F32 after BF16 propagation and will be removed since
  // they are no-ops.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (inst->opcode() != HloOpcode::kWhile) {
        continue;
      }

      auto operand = inst->mutable_operand(0);
      // Deep-copy the while operand, converting only the F32 leaves; non-F32
      // leaves are passed through unchanged.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          computation->DeepCopyInstructionWithCustomCopier(
              operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* comp) {
                if (leaf->shape().element_type() != F32) {
                  return leaf;
                }
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(leaf->shape(), leaf));
              }));
      TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy));
    }
  }

  // Run dataflow analysis only after the de-aliasing copies above, so the
  // analysis reflects the modified graph.
  TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));

  // The first step is a forward pass (parameters to root), where we determine
  // the potential candidate instructions to use bfloat16 in the outputs that
  // are not likely to cause overhead from extra explicit conversions. This is
  // done forwardly because we determine whether an HLO is a candidate partially
  // based on whether its operands are candidates.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (InstructionIsCandidateForBF16Output(inst)) {
        consider_using_bfloat16_.insert(inst);
      }
    }
  }

  // The second step is a backward pass (root to parameters), where we modify
  // the precisions of the instructions identified in the first step when
  // feasible. This is done backwardly because we determine the precision of an
  // HLO's output based on how it is later used.
  //
  // The precision of an instruction is determined by its users, so we do the
  // propagation in reverse topological order.
  for (auto comp_it = computations_topological_order.rbegin();
       comp_it != computations_topological_order.rend(); ++comp_it) {
    // Called computations may already have been visited from a caller's
    // backward pass; skip them here.
    if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) {
      continue;
    }
    auto insts = (*comp_it)->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      DetermineInstructionPrecision(*inst_it,
                                    /*skip_parameters=*/true);
    }
    computations_visited_in_backward_pass_.insert(*comp_it);
  }

  // It's possible that an instruction does not define a buffer, but the
  // defining instruction's shape has changed. So we need to adjust the output
  // shapes of instructions according to the HLO values they refer to.
  ResolveInconsistencyOfAliasingBuffers(module);

  // Apply the changes in changes_to_bf16_.
  for (auto& change : changes_to_bf16_) {
    auto inst = change.first;
    // It is possible that we marked inst to change precision even if it is an
    // unsupported change, when inst is the root of a fusion computation and it
    // has to match the fusion node's output precision. We do a convert instead
    // of in-place change for such cases.
    if (ShouldKeepPrecisionUnchanged(inst)) {
      // Snapshot users before rewriting, since the deep copy below adds new
      // uses of inst.
      auto users = inst->users();
      bool is_root = inst == inst->parent()->root_instruction();
      // Deep-copy inst, converting to BF16 only the leaves that were marked
      // for change; other leaves are passed through unchanged.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          inst->parent()->DeepCopyInstructionWithCustomCopier(
              inst, [&](HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* comp) {
                if (!ContainsKey(change.second,
                                 ShapeUtil::GetMutableSubshape(
                                     inst->mutable_shape(), leaf_index))) {
                  return leaf;
                }
                auto converted_shape =
                    ShapeUtil::ChangeElementType(leaf->shape(), BF16);
                UpdateLayout(&converted_shape);
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(converted_shape, leaf));
              }));
      for (auto user : users) {
        TF_RETURN_IF_ERROR(inst->ReplaceUseWithDifferentShape(user, copy));
      }
      if (is_root) {
        inst->parent()->set_root_instruction(copy,
                                             /*accept_different_shape=*/true);
      }
      continue;
    }
    // Supported changes are applied in place on the instruction's subshapes.
    for (const auto& entry : change.second) {
      auto subshape = entry.first;
      CHECK_EQ(subshape->element_type(), F32);
      subshape->set_element_type(BF16);
      UpdateLayout(subshape);
      changed_ = true;
    }
  }

  // Removes redundant HLOs added by this pass, either when inserting
  // de-aliasing copies to while loop inputs, or later when converting output
  // types.
  auto clean_up = [this, module]() {
    TF_RETURN_IF_ERROR(SkipNoopConversions(module));
    TupleSimplifier tuple_simplifier;
    TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
    HloDCE dce;
    TF_RETURN_IF_ERROR(dce.Run(module).status());
    return Status::OK();
  };

  if (!changed_) {
    // Nothing changed precision; still clean up the copies inserted above.
    TF_RETURN_IF_ERROR(clean_up());
    return false;
  }

  // Fix up fusion computations and constant literals that the in-place type
  // changes may have left inconsistent, then clean up.
  TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module));
  TF_RETURN_IF_ERROR(ResolveConvertedConstants(module));

  TF_RETURN_IF_ERROR(clean_up());
  return true;
}
981
OutputTypeAfterChange(HloInstruction * hlo,const ShapeIndex & index) const982 PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
983 HloInstruction* hlo, const ShapeIndex& index) const {
984 Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index);
985 const PrimitiveType type_on_hlo = subshape->element_type();
986 if (type_on_hlo != F32) {
987 return type_on_hlo;
988 }
989 auto it = changes_to_bf16_.find(hlo);
990 if (it == changes_to_bf16_.end()) {
991 return type_on_hlo;
992 }
993 return ContainsKey(it->second, subshape) ? BF16 : F32;
994 }
995
ValueTypeAfterChange(const HloValue * value) const996 PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
997 const HloValue* value) const {
998 auto hlo = value->defining_instruction();
999 const auto& position = value->defining_position();
1000 return OutputTypeAfterChange(hlo, position.index);
1001 }
1002
AddToOrRemoveFromBF16ChangeSet(HloInstruction * hlo,const ShapeIndex & index,PrimitiveType target_type)1003 void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
1004 HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
1005 if (target_type == BF16) {
1006 auto& entry = changes_to_bf16_[hlo];
1007 entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index),
1008 index);
1009 } else {
1010 CHECK_EQ(target_type, F32);
1011 auto it = changes_to_bf16_.find(hlo);
1012 if (it == changes_to_bf16_.end()) {
1013 return;
1014 }
1015 it->second.erase(
1016 ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index));
1017 }
1018 }
1019
1020 } // namespace xla
1021