/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

BFloat16Propagation::BFloat16Propagation(
    const BFloat16Support* bfloat16_support)
    : bfloat16_support_(bfloat16_support) {}

void BFloat16Propagation::DetermineFusionComputationPrecision(
    HloInstruction* fusion) {
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
  if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
    return;
  }

  // The fusion node itself has already been analyzed for whether it can output
  // BF16, and its output shape has been adjusted accordingly. Now update the
  // interior of the fusion computation to match the new output shape, and
  // recursively process the fused computation even if the output shape was not
  // modified.
  auto root = fusion->fused_instructions_computation()->root_instruction();

  // Adjust root's element types according to the fusion's output shape.
  ShapeUtil::ForEachSubshape(
      root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.element_type() != F32) {
          return;
        }
        if (OutputTypeAfterChange(fusion, index) == BF16) {
          AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
          VLOG(2) << "Fused root " << root->ToString() << " at shape index "
                  << index << " changed to BF16 precision for fusion "
                  << fusion->ToString();
        }
      });

  // Propagate BF16 in the fusion computation.
  auto insts =
      fusion->fused_instructions_computation()->MakeInstructionPostOrder();
  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
    DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
  }
  computations_visited_in_backward_pass_.insert(
      fusion->fused_instructions_computation());

  RevertIfFusionInternalBF16Changes(fusion);
}

void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
    HloInstruction* fusion) {
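  // Changes that stay entirely inside this fusion computation cannot reduce
  // memory bandwidth at the fusion boundary; they would only add converts
  // inside the fusion, so such internal-only changes are reverted below.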
  auto has_changes = [this](HloInstruction* inst) {
    auto it = changes_to_bf16_.find(inst);
    return it != changes_to_bf16_.end() && !it->second.empty();
  };

  auto root = fusion->fused_instructions_computation()->root_instruction();
  absl::flat_hash_set<const HloValue*> changed_root_buffers;

  auto root_changes_it = changes_to_bf16_.find(root);
  if (root_changes_it != changes_to_bf16_.end()) {
    for (const auto& entry : root_changes_it->second) {
      for (const HloValue* value :
           dataflow_->GetValueSet(root, entry.second).values()) {
        changed_root_buffers.insert(value);
      }
    }
  }

  auto aliases_changed_root_buffer =
      [this, &changed_root_buffers](const HloInstruction* inst) {
        bool aliasing = false;
        ShapeUtil::ForEachSubshape(
            inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
              if (aliasing) {
                // Skip if aliasing is already found.
                return;
              }
              // Only F32 buffers are considered for changing to BF16 in this
              // pass.
              if (subshape.element_type() != F32) {
                return;
              }
              for (const HloValue* value :
                   dataflow_->GetValueSet(inst, index).values()) {
                if (ContainsKey(changed_root_buffers, value)) {
                  aliasing = true;
                  break;
                }
              }
            });
        return aliasing;
      };

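  // In post order, revert changes on instructions that neither alias a changed
  // root buffer nor consume an operand that was itself changed to BF16; such
  // changes cannot save bandwidth at the fusion boundary and would only add
  // converts inside the fusion.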
  for (auto inst :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (inst->opcode() == HloOpcode::kParameter) {
      continue;
    }
    if (aliases_changed_root_buffer(inst)) {
      continue;
    }
    if (inst->opcode() == HloOpcode::kFusion) {
      bool parameter_reverted = false;
      for (int64 i = 0; i < inst->operand_count(); ++i) {
        if (has_changes(inst->mutable_operand(i))) {
          // Changes on the operand have not been reverted.
          continue;
        }
        auto* fused_parameter = inst->fused_parameter(i);
        if (has_changes(fused_parameter)) {
          changes_to_bf16_.erase(fused_parameter);
          parameter_reverted = true;
        }
      }
      if (parameter_reverted) {
        RevertIfFusionInternalBF16Changes(inst);
      }
    }
    if (!has_changes(inst)) {
      continue;
    }
    bool revert_changes = true;
    for (auto operand : inst->operands()) {
      if (has_changes(operand)) {
        revert_changes = false;
        break;
      }
    }
    if (revert_changes) {
      changes_to_bf16_.erase(inst);
    }
  }
}

void BFloat16Propagation::DetermineWhileComputationsPrecision(
    HloInstruction* while_hlo) {
  CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);

  // The while node itself has already been analyzed for whether it can output
  // BF16, and its output shape has been adjusted accordingly. Now update the
  // body and condition computations to match the new output shape, and
  // recursively process the whole while node even if the output shape was not
  // modified.
  HloComputation* body = while_hlo->while_body();
  auto body_root = body->root_instruction();
  HloComputation* condition = while_hlo->while_condition();

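  // Adjust the body root's element types according to the while node's output
  // shape.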
  ShapeUtil::ForEachSubshape(
      body_root->shape(), [this, while_hlo, body_root](
                              const Shape& subshape, const ShapeIndex& index) {
        if (subshape.element_type() != F32) {
          return;
        }
        if (OutputTypeAfterChange(while_hlo, index) == BF16) {
          AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16);
          VLOG(2) << "While body root " << body_root->ToString()
                  << " at shape index " << index
                  << " changed to BF16 precision for while "
                  << while_hlo->ToString();
        }
      });

  auto body_insts = body->MakeInstructionPostOrder();
  for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
       ++inst_it) {
    DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
  }
  computations_visited_in_backward_pass_.insert(body);

  auto condition_insts = condition->MakeInstructionPostOrder();
  for (auto inst_it = condition_insts.rbegin();
       inst_it != condition_insts.rend(); ++inst_it) {
    DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
  }
  computations_visited_in_backward_pass_.insert(condition);
}

void BFloat16Propagation::DetermineConditionalComputationsPrecision(
    HloInstruction* cond) {
  CHECK_EQ(cond->opcode(), HloOpcode::kConditional);
  for (int64 i = 0; i < cond->branch_count(); ++i) {
    auto branch = cond->branch_computation(i);
    auto root = branch->root_instruction();
    ShapeUtil::ForEachSubshape(
        root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
          if (subshape.element_type() != F32) {
            return;
          }
          if (OutputTypeAfterChange(cond, index) == BF16) {
            AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
            VLOG(2) << "Conditional branch " << i << " root "
                    << root->ToString() << " at shape index " << index
                    << " changed to BF16 precision for conditional "
                    << cond->ToString();
          }
        });
    auto insts = branch->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
    }
    computations_visited_in_backward_pass_.insert(branch);
  }
}

bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
                                              const ShapeIndex& index) const {
  // If the subshape isn't floating point then none of the users will be BF16.
  const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index);
  if (subshape.element_type() != BF16 && subshape.element_type() != F32) {
    return false;
  }

  auto& value_set = dataflow_->GetValueSet(&hlo, index);
  for (const HloValue* value : value_set.values()) {
    if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
      return false;
    }
    // We use the original type for the value because we are going to examine
    // the uses of it, instead of the value itself. If ValueTypeAfterChange()
    // were used, it would cause problems when there are aliasing buffers, i.e.,
    // ResolveInconsistencyOfAliasingBuffers() would fail to revert the
    // tentative change to BF16 even if the uses require F32.
    if (value->shape().element_type() == BF16) {
      continue;
    }
    for (const HloUse& use : value->uses()) {
      if (!ContainsKey(instructions_visited_in_backward_pass_,
                       use.instruction)) {
        // We don't know yet whether use.instruction will consume BF16 since it
        // hasn't been visited. Although we visit instructions in reverse
        // topological order, this is still possible because there may be
        // unvisited instructions that alias the same buffer. In this case, we
        // aggressively skip this use, and if this causes inconsistency (e.g.,
        // one use is in BF16 but another use is in F32), it will be resolved at
        // the end of the BFloat16Propagation pass.
        continue;
      }
      if (use.instruction->HasSideEffectNoRecurse()) {
        // Keep side-effecting instruction's operands unchanged.
        return false;
      }
      // Any visited user that can accept BF16 has already been updated if
      // necessary, e.g., the output has been changed to BF16 if it propagates
      // precision, or a called computation's parameters have been changed to
      // BF16 for fusions or whiles.
      if (use.instruction->opcode() == HloOpcode::kFusion) {
        auto* fused_parameter =
            use.instruction->fused_parameter(use.operand_number);
        if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      } else if (use.instruction->opcode() == HloOpcode::kWhile) {
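        // The while operand feeds both the condition and the body, so both
        // parameters must already accept BF16 for this use to consume BF16.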
        auto* cond_parameter =
            use.instruction->while_condition()->parameter_instruction(
                use.operand_number);
        if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
          return false;
        }
        auto* body_parameter =
            use.instruction->while_body()->parameter_instruction(
                use.operand_number);
        if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      } else if (use.instruction->opcode() == HloOpcode::kConditional) {
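        // Operand 0 of a conditional is the branch index (or predicate);
        // operand i + 1 feeds parameter 0 of branch_computation(i), hence the
        // "- 1" below.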
        auto* cond_parameter =
            use.instruction->branch_computation(use.operand_number - 1)
                ->parameter_instruction(0);
        if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      }
      if (bfloat16_support_->EffectiveOperandPrecisionIsBF16(
              *use.instruction, use.operand_number)) {
        continue;
      }
      // If the op propagates precision and it outputs BF16, then it's OK to
      // supply BF16 as the input as well. In the backward pass, the users'
      // shapes have already been processed.
      if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
              *use.instruction, use.operand_number)) {
        if (use.instruction->opcode() == HloOpcode::kTuple ||
            (use.instruction->opcode() == HloOpcode::kAllReduce &&
             use.instruction->shape().IsTuple())) {
          ShapeIndex use_output_index{use.operand_number};
          for (int64 i : use.operand_index) {
            use_output_index.push_back(i);
          }
          if (OutputTypeAfterChange(use.instruction, use_output_index) ==
              BF16) {
            continue;
          }
        } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) {
          ShapeIndex use_output_index;
          for (int64 i = 1; i < use.operand_index.size(); ++i) {
            use_output_index.push_back(use.operand_index[i]);
          }
          if (OutputTypeAfterChange(use.instruction, use_output_index) ==
              BF16) {
            continue;
          }
        } else {
          if (OutputTypeAfterChange(use.instruction, use.operand_index) ==
              BF16) {
            continue;
          }
        }
      }
      return false;
    }
  }
  return true;
}

namespace {

// Returns whether we should avoid changing the precision of inst regardless of
// the producers and users.
bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst) {
  if (inst->opcode() == HloOpcode::kFusion &&
      inst->fusion_kind() == HloInstruction::FusionKind::kCustom) {
    return ShouldKeepPrecisionUnchanged(
        inst->fused_instructions_computation()->root_instruction());
  }
  // Do not change precision for side-effecting instructions, control flow, and
  // bitcast-convert, because this pass might break the interfaces or
  // assumptions for them.
  return inst->opcode() == HloOpcode::kCustomCall ||      //
         inst->opcode() == HloOpcode::kCall ||            //
         inst->opcode() == HloOpcode::kBitcastConvert ||  //
         inst->HasSideEffectNoRecurse();
}

}  // namespace

void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
                                                        bool skip_parameters) {
  // Any fusion computation, while body/condition, or conditional branch is
  // handled only after the instruction itself, because we need to know the
  // output shape of a fusion, while, or conditional before propagating inside
  // its called computations.
  bool postpone_processing_called_computations = false;
  auto cleaner = tensorflow::gtl::MakeCleanup(
      [this, hlo, &postpone_processing_called_computations] {
        if (!postpone_processing_called_computations) {
          if (hlo->opcode() == HloOpcode::kFusion) {
            DetermineFusionComputationPrecision(hlo);
          } else if (hlo->opcode() == HloOpcode::kWhile) {
            DetermineWhileComputationsPrecision(hlo);
          } else if (hlo->opcode() == HloOpcode::kConditional) {
            DetermineConditionalComputationsPrecision(hlo);
          }
        }
        instructions_visited_in_backward_pass_.insert(hlo);
      });

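  // If the while's condition or body (or any conditional branch) has more than
  // one caller, it cannot be adjusted for this callsite alone, so postpone
  // processing of those computations; they are handled when visited in the
  // module-level backward pass.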
  if (hlo->opcode() == HloOpcode::kWhile &&
      (caller_counts_[hlo->while_condition()] > 1 ||
       caller_counts_[hlo->while_body()] > 1)) {
    postpone_processing_called_computations = true;
    return;
  }

  if (hlo->opcode() == HloOpcode::kConditional &&
      absl::c_any_of(hlo->branch_computations(), [&](const HloComputation* c) {
        return caller_counts_[c] > 1;
      })) {
    postpone_processing_called_computations = true;
    return;
  }

  // Prevent root instructions from having their output modified by recording
  // all F32 output values as needing to stay as F32.
  CHECK(hlo->parent() != nullptr);
  if (hlo == hlo->parent()->root_instruction()) {
    if (!hlo->parent()->IsFusionComputation()) {
      ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */,
                                                   const ShapeIndex& index) {
        if (OutputTypeAfterChange(hlo, index) != F32) {
          return;
        }
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
          // Since we use HloValues from the dataflow analysis, this can also
          // affect HLO instructions beyond the root, e.g., if the root is a
          // Tuple HLO, then its operands are also affected.
          values_that_must_be_kept_as_f32_.insert(value);
        }
      });
    }
    return;
  }

  if (ShouldKeepPrecisionUnchanged(hlo) ||
      (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
    return;
  }

  if (!ContainsKey(consider_using_bfloat16_, hlo)) {
    return;
  }

  if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
    return;
  }

  ShapeUtil::ForEachSubshape(
      hlo->shape(),
      [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) {
        if (OutputTypeAfterChange(hlo, index) == F32 &&
            AllUsersConsumeBF16(*hlo, index)) {
          AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16);
          VLOG(2) << "HloInstruction output at shape index " << index
                  << " changed to BF16 precision: " << hlo->ToString();
        }
      });
}

bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
    HloInstruction* hlo) {
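  // If the instruction cannot mix precisions, making its output BF16
  // effectively requires its operands to be BF16 as well, so it is only a
  // candidate if every operand propagates its precision to the output and is
  // itself a candidate. Tuples, get-tuple-elements, and domains just pass
  // values through, and an instruction already producing BF16 is trivially a
  // candidate.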
  if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) &&
      hlo->opcode() != HloOpcode::kTuple &&
      hlo->opcode() != HloOpcode::kGetTupleElement &&
      hlo->opcode() != HloOpcode::kDomain &&
      hlo->shape().element_type() != BF16) {
    for (int64 i = 0; i < hlo->operand_count(); ++i) {
      if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
                                                                         i) ||
          !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) {
        return false;
      }
    }
  }
  return true;
}

void BFloat16Propagation::AdjustCalledComputationParameters(
    HloInstruction* hlo) {
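  // Make each called computation's parameter precisions match the (possibly
  // changed) precisions of the corresponding operands at this callsite.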
  auto adjust_computation =
      [this, hlo](HloComputation* computation,
                  absl::Span<HloInstruction* const> operands) {
        // Adjust parameters.
        CHECK_EQ(operands.size(), computation->num_parameters());
        for (int64 i = 0; i < operands.size(); ++i) {
          auto parameter = computation->parameter_instruction(i);
          ShapeUtil::ForEachSubshape(
              parameter->shape(),
              [this, i, hlo, &operands, parameter](const Shape& /* subshape */,
                                                   const ShapeIndex& index) {
                if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
                  return;
                }
                PrimitiveType operand_type =
                    OutputTypeAfterChange(operands[i], index);
                if (OutputTypeAfterChange(parameter, index) == operand_type) {
                  return;
                }
                AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type);
                VLOG(2) << "Called computation parameter "
                        << parameter->ToString() << " at shape index " << index
                        << " adjusted to "
                        << (operand_type == BF16 ? "BF16" : "F32")
                        << " to match operand in HLO " << hlo->ToString();
              });
        }
      };

  switch (hlo->opcode()) {
    case HloOpcode::kFusion:
      adjust_computation(hlo->fused_instructions_computation(),
                         hlo->operands());
      break;
    case HloOpcode::kWhile:
      adjust_computation(hlo->while_condition(), hlo->operands());
      adjust_computation(hlo->while_body(), hlo->operands());
      break;
    case HloOpcode::kConditional:
      for (int64 i = 0; i < hlo->branch_count(); ++i) {
        adjust_computation(hlo->branch_computation(i),
                           {hlo->mutable_operand(i + 1)});
      }
      break;
    default:
      break;
  }
}

void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
  auto adjust_computation = [this, hlo](HloComputation* computation,
                                        HloInstruction* output) {
    // Adjust root.
    HloInstruction* root = computation->root_instruction();
    ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output](
                                                  const Shape& /* subshape */,
                                                  const ShapeIndex& index) {
      if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
        return;
      }
      const PrimitiveType output_type = OutputTypeAfterChange(output, index);
      if (OutputTypeAfterChange(root, index) == output_type) {
        return;
      }
      AddToOrRemoveFromBF16ChangeSet(root, index, output_type);
      // It's possible that output_type is F32, but the root instruction's
      // type is BF16; e.g., a fusion node's output was changed to BF16
      // initially but then adjusted back to F32, and the fusion computation
      // is now being adjusted after the fusion node.
      if (output_type == F32) {
        for (const auto* value : dataflow_->GetValueSet(root, index).values()) {
          // We rely on the fact that this adjustment works in reverse
          // topological order so that the called computation will be
          // processed later. Adding the value to
          // values_that_must_be_kept_as_f32_ will ensure the
          // correctness of the adjustment for HLOs that will be
          // processed later.
          values_that_must_be_kept_as_f32_.insert(value);
        }
      }
      VLOG(2) << "Called computation root " << root->ToString()
              << " at shape index " << index << " adjusted to "
              << (output_type == BF16 ? "BF16" : "F32")
              << " to match output shape of " << hlo->ToString();
    });
  };

  switch (hlo->opcode()) {
    case HloOpcode::kFusion:
      adjust_computation(hlo->fused_instructions_computation(), hlo);
      break;
    case HloOpcode::kWhile:
      adjust_computation(hlo->while_body(), hlo);
      break;
    case HloOpcode::kConditional:
      for (auto* branch : hlo->branch_computations()) {
        adjust_computation(branch, hlo);
      }
      break;
    default:
      break;
  }
}

bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
    HloComputation* computation,
    absl::flat_hash_set<const HloComputation*>* visited_computations) {
  bool parameter_changed = false;
  auto insts = computation->MakeInstructionPostOrder();
  // Do the adjustment on each instruction in the computation in reverse
  // topological order.
  while (true) {
    bool any_change = false;
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      auto hlo = *inst_it;
      auto adjust_hlo_output = [&](const Shape& /* subshape */,
                                   const ShapeIndex& index) {
        auto output_type = OutputTypeAfterChange(hlo, index);
        VLOG(2) << "output_type is " << ((output_type == BF16) ? "BF16" : "F32")
                << " for :" << hlo->ToString() << "\n";
        if (output_type != F32 && output_type != BF16) {
          return;
        }
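        // Start optimistically with BF16 and fall back to F32 if any HLO value
        // aliased with this output requires F32.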
        PrimitiveType type = BF16;
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
          auto value_type = ValueTypeAfterChange(value);
          if (value_type == BF16) {
            continue;
          }
          VLOG(2) << "Adjust to F32 due to aliased dataflow value: "
                  << value->ToString() << "\n";
          CHECK_EQ(value_type, F32);
          type = F32;
          break;
        }
        // In order to find aliases due to in-place operations, use
        // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
        // but this code works with HloModules that aren't ready yet to use
        // HloAliasAnalysis (e.g., their computation graphs may not have been
        // flattened yet).
        for (const auto& operand_and_output_index :
             HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
          if (operand_and_output_index.second == index) {
            const HloUse& operand = operand_and_output_index.first;
            for (const auto* value :
                 dataflow_
                     ->GetValueSet(hlo->operand(operand.operand_number),
                                   operand.operand_index)
                     .values()) {
              auto value_type = ValueTypeAfterChange(value);
              if (value_type == BF16) {
                continue;
              }
              VLOG(2) << "Adjust to F32 due to in-place input/output pair: "
                      << value->ToString() << "\n";
              CHECK_EQ(value_type, F32);
              type = F32;
              break;
            }
          }
        }

        // It's possible that a user has been changed from BF16 to F32
        // during this final adjustment pass, so we need to check
        // AllUsersConsumeBF16() again.
        if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
          VLOG(2) << "Adjust to F32 because AllUsersConsumeBF16 failed\n";
          type = F32;
        }
        if (type == F32) {
          for (const auto* value :
               dataflow_->GetValueSet(hlo, index).values()) {
            // We rely on the fact that this adjustment works in reverse
            // topological order. Adding the value to
            // values_that_must_be_kept_as_f32_ will ensure the correctness
            // of the adjustment for HLOs that will be processed later.
            values_that_must_be_kept_as_f32_.insert(value);
          }
        }
        if (type != output_type) {
          any_change = true;
          AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
          VLOG(2) << "HloInstruction output at shape index " << index
                  << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
                  << hlo->ToString();
          if (hlo->opcode() == HloOpcode::kParameter) {
            parameter_changed = true;
          }
        }
      };
      ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
      AdjustCalledComputationRoot(hlo);
      if (hlo->opcode() == HloOpcode::kWhile) {
        // We need to run on the while body and condition repeatedly until a
        // fixed point is reached, i.e., the parameters do not change any more.
        // We may need more than one iteration because the while input and
        // output alias each other, so changing one input parameter requires
        // changing the corresponding output element and thus may transitively
        // require changing another input parameter. A fixed point will be
        // reached because the parameters can only be changed from BF16 to F32,
        // not the other way around.
        absl::flat_hash_set<const HloComputation*> visited_in_while;
        while (ResolveInconsistencyOfAliasingBuffersHelper(
                   hlo->while_condition(), &visited_in_while) ||
               ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
                                                           &visited_in_while)) {
          visited_in_while.clear();
          ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
          AdjustCalledComputationRoot(hlo);
        }
        visited_computations->insert(visited_in_while.begin(),
                                     visited_in_while.end());
      } else if (hlo->opcode() == HloOpcode::kFusion) {
        ResolveInconsistencyOfAliasingBuffersHelper(
            hlo->fused_instructions_computation(), visited_computations);
      } else if (hlo->opcode() == HloOpcode::kConditional) {
        for (auto* branch : hlo->branch_computations()) {
          ResolveInconsistencyOfAliasingBuffersHelper(branch,
                                                      visited_computations);
        }
      }
    }
    if (!any_change) {
      break;
    }
  }
  // Now adjust parameters of called computations.
  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
    AdjustCalledComputationParameters(*inst_it);
  }
  return parameter_changed;
}

void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
    HloModule* module) {
  const auto& computations_topological_order =
      module->MakeComputationPostOrder();
  absl::flat_hash_set<const HloComputation*> resolved;
  for (auto comp_it = computations_topological_order.rbegin();
       comp_it != computations_topological_order.rend(); ++comp_it) {
    if (ContainsKey(resolved, *comp_it)) {
      continue;
    }
    ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
  }
}

Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
  // We could have changed a fusion computation's root shape to have a
  // different precision than the fusion node's output, if the fusion root
  // does not define a buffer (e.g., a tuple). Now we add conversions after
  // such fusion roots to make them match the fusion output. If the fusion
  // output is a (possibly nested) tuple, we first create get-tuple-elements,
  // then convert the unmatching leaf nodes, and finally create a new tuple as
  // the fusion computation's root. If tuples and get-tuple-elements are
  // created, we will run the tuple simplifier and dead code elimination at
  // the end (dead code is not allowed in a fusion computation). E.g.,
  //
  // (1)             (2)             (3)
  // a  b            a  b            a  b
  // |\ |            |\ |            |\ |
  // \ add   ->      |add    ->      | add
  //  \ |            \ |        convert |
  //  tuple         tuple             \ |
  //                 / \              tuple
  //               gte gte
  //                |   |
  //           convert  |
  //                 \  /
  //                 tuple
  // (1) a is F32 but tuple is BF16
  // (2) after adding conversion
  // (3) after tuple simplifier and DCE.
  for (auto computation : module->MakeComputationPostOrder()) {
    auto insts = computation->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      auto hlo = *inst_it;
      if (hlo->opcode() != HloOpcode::kFusion) {
        continue;
      }
      auto fusion_computation = hlo->fused_instructions_computation();
      auto fusion_root = fusion_computation->root_instruction();
      if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) {
        continue;
      }
      ShapeTree<HloInstruction*> converted_outputs(hlo->shape());
      // Deep copy the fusion root, and convert a leaf node only if its shape
      // does not match the fusion output.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          fusion_computation->DeepCopyInstructionWithCustomCopier(
              fusion_root,
              [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index,
                    HloComputation* comp) {
                const Shape& hlo_subshape =
                    ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
                if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) {
                  return leaf;
                }
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(hlo_subshape, leaf));
              }));
      fusion_computation->set_root_instruction(copy);
    }
  }
  return Status::OK();
}

Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
  // We may have converted some constants from F32 to BF16, so adjust the
  // constant literals in such cases. We do this here instead of when the
  // constant node's output is changed because 1) the HloInstruction interface
  // does not allow resetting the literal, so we have to create a new kConstant
  // instruction to replace the old one, which invalidates dataflow analysis,
  // and 2) it's possible that a kConstant's output gets changed to BF16 at the
  // beginning but later on adjusted back to F32, so converting literals here
  // can avoid repeated conversions.
  //
  // TODO(b/73833576): Consider resetting literal in HloInstruction.
  for (auto computation : module->MakeComputationPostOrder()) {
    for (auto hlo : computation->MakeInstructionPostOrder()) {
      if (hlo->opcode() != HloOpcode::kConstant) {
        continue;
      }
      if (!Shape::Equal().MinorToMajorOnlyInLayout()(hlo->literal().shape(),
                                                     hlo->shape())) {
        TF_ASSIGN_OR_RETURN(auto converted_literal,
                            hlo->literal().ConvertToShape(hlo->shape()));
        auto new_constant = computation->AddInstruction(
            HloInstruction::CreateConstant(std::move(converted_literal)));
        UpdateLayout(new_constant->mutable_shape());
        TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
      }
    }
  }
  return Status::OK();
}

Status BFloat16Propagation::SkipNoopConversions(HloModule* module) {
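  // Same-shape converts are no-ops; typically these are the de-aliasing copies
  // inserted at the start of Run() that ended up staying F32. Forward their
  // operands to all users so DCE can remove them.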
  for (auto computation : module->computations()) {
    for (auto hlo : computation->MakeInstructionPostOrder()) {
      if (hlo->opcode() != HloOpcode::kConvert) {
        continue;
      }
      auto source = hlo->mutable_operand(0);
      if (!ShapeUtil::Equal(source->shape(), hlo->shape())) {
        continue;
      }
      const bool is_root = hlo == computation->root_instruction();
      TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source));
      if (is_root) {
        computation->set_root_instruction(source);
      }
    }
  }
  return Status::OK();
}

// The algorithm first does a forward pass (parameters to root) to determine a
// set of instructions to consider using bfloat16, then does a backward pass to
// determine the precisions of those instructions according to the needs of
// their users. During the backward pass, the potential changes are stored in
// changes_to_bf16_, which are subject to further adjustments and then applied
// to the HLOs.
StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
  consider_using_bfloat16_.clear();
  instructions_visited_in_backward_pass_.clear();
  computations_visited_in_backward_pass_.clear();
  values_that_must_be_kept_as_f32_.clear();
  caller_counts_.clear();
  changes_to_bf16_.clear();
  changed_ = false;

  auto computations_topological_order = module->MakeComputationPostOrder();

  // Before running the propagation pass, we insert copies (kConvert to the
  // same type) of F32 inputs to while loops. This prevents other uses of the
  // same input from aliasing the while loop input/output, so that there's a
  // greater chance to use BF16 inside the loop. If some of these added copies
  // do not help, they will remain F32 after BF16 propagation and will be
  // removed since they are no-ops.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (inst->opcode() != HloOpcode::kWhile) {
        continue;
      }

      auto operand = inst->mutable_operand(0);
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          computation->DeepCopyInstructionWithCustomCopier(
              operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* comp) {
                if (leaf->shape().element_type() != F32) {
                  return leaf;
                }
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(leaf->shape(), leaf));
              }));
      TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy));
    }
  }

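  // Dataflow analysis is used throughout the pass to find the HLO values that
  // may share a buffer with a given instruction output (e.g., through tuples,
  // fusions, and while loops).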
  TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));

  // The first step is a forward pass (parameters to root), where we determine
  // the potential candidate instructions to use bfloat16 in the outputs that
  // are not likely to cause overhead from extra explicit conversions. This is
  // done in the forward direction because we determine whether an HLO is a
  // candidate partly based on whether its operands are candidates.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (InstructionIsCandidateForBF16Output(inst)) {
        consider_using_bfloat16_.insert(inst);
      }
    }
  }

  // The second step is a backward pass (root to parameters), where we modify
  // the precisions of the instructions identified in the first step when
  // feasible. This is done in the backward direction because we determine the
  // precision of an HLO's output based on how it is later used.
  //
  // The precision of an instruction is determined by its users, so we do the
  // propagation in reverse topological order.
  for (auto comp_it = computations_topological_order.rbegin();
       comp_it != computations_topological_order.rend(); ++comp_it) {
    if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) {
      continue;
    }
    auto insts = (*comp_it)->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      DetermineInstructionPrecision(*inst_it,
                                    /*skip_parameters=*/true);
    }
    computations_visited_in_backward_pass_.insert(*comp_it);
  }

  // It's possible that an instruction does not define a buffer, but the
  // defining instruction's shape has changed. So we need to adjust the output
  // shapes of instructions according to the HLO values they refer to.
  ResolveInconsistencyOfAliasingBuffers(module);

  // Apply the changes in changes_to_bf16_.
  for (auto& change : changes_to_bf16_) {
    auto inst = change.first;
    // It is possible that we marked inst to change precision even if it is an
    // unsupported change, when inst is the root of a fusion computation and it
    // has to match the fusion node's output precision. We do a convert instead
    // of an in-place change for such cases.
    if (ShouldKeepPrecisionUnchanged(inst)) {
      auto users = inst->users();
      bool is_root = inst == inst->parent()->root_instruction();
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          inst->parent()->DeepCopyInstructionWithCustomCopier(
              inst, [&](HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* comp) {
                if (!ContainsKey(change.second,
                                 ShapeUtil::GetMutableSubshape(
                                     inst->mutable_shape(), leaf_index))) {
                  return leaf;
                }
                auto converted_shape =
                    ShapeUtil::ChangeElementType(leaf->shape(), BF16);
                UpdateLayout(&converted_shape);
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(converted_shape, leaf));
              }));
      for (auto user : users) {
        TF_RETURN_IF_ERROR(inst->ReplaceUseWithDifferentShape(user, copy));
      }
      if (is_root) {
        inst->parent()->set_root_instruction(copy,
                                             /*accept_different_shape=*/true);
      }
      continue;
    }
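    // Otherwise, it is safe to change the element types of the recorded
    // subshapes in place.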
    for (const auto& entry : change.second) {
      auto subshape = entry.first;
      CHECK_EQ(subshape->element_type(), F32);
      subshape->set_element_type(BF16);
      UpdateLayout(subshape);
      changed_ = true;
    }
  }

  // Removes redundant HLOs added by this pass, either when inserting
  // de-aliasing copies to while loop inputs, or later when converting output
  // types.
  auto clean_up = [this, module]() {
    TF_RETURN_IF_ERROR(SkipNoopConversions(module));
    TupleSimplifier tuple_simplifier;
    TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
    HloDCE dce;
    TF_RETURN_IF_ERROR(dce.Run(module).status());
    return Status::OK();
  };

  if (!changed_) {
    TF_RETURN_IF_ERROR(clean_up());
    return false;
  }

  TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module));
  TF_RETURN_IF_ERROR(ResolveConvertedConstants(module));

  TF_RETURN_IF_ERROR(clean_up());
  return true;
}

PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
    HloInstruction* hlo, const ShapeIndex& index) const {
  Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index);
  const PrimitiveType type_on_hlo = subshape->element_type();
  if (type_on_hlo != F32) {
    return type_on_hlo;
  }
  auto it = changes_to_bf16_.find(hlo);
  if (it == changes_to_bf16_.end()) {
    return type_on_hlo;
  }
  return ContainsKey(it->second, subshape) ? BF16 : F32;
}

PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
    const HloValue* value) const {
  auto hlo = value->defining_instruction();
  const auto& position = value->defining_position();
  return OutputTypeAfterChange(hlo, position.index);
}

void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
    HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
  if (target_type == BF16) {
    auto& entry = changes_to_bf16_[hlo];
    entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index),
                  index);
  } else {
    CHECK_EQ(target_type, F32);
    auto it = changes_to_bf16_.find(hlo);
    if (it == changes_to_bf16_.end()) {
      return;
    }
    it->second.erase(
        ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index));
  }
}

}  // namespace xla