1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/map_util.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_dce.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_module.h"
25 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
26 #include "tensorflow/compiler/xla/shape_tree.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/core/lib/gtl/cleanup.h"
29 #include "tensorflow/core/platform/logging.h"
30 
31 namespace xla {
32 
BFloat16Propagation(const BFloat16Support * bfloat16_support)33 BFloat16Propagation::BFloat16Propagation(
34     const BFloat16Support* bfloat16_support)
35     : bfloat16_support_(bfloat16_support) {}
36 
DetermineFusionComputationPrecision(HloInstruction * fusion)37 void BFloat16Propagation::DetermineFusionComputationPrecision(
38     HloInstruction* fusion) {
39   CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
40   if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
41     return;
42   }
43 
44   // We are depending on the fusion node itself having already been analyzed
45   // for whether it can output BF16 and this has been adjusted in the output
46   // shape, and now we're looking to update the interior of the fusion node to
47   // match the new output shape, as well as recursively process the whole fusion
48   // node even if the output shape was not modified.
49   auto root = fusion->fused_instructions_computation()->root_instruction();
50 
51   // Adjust root's element types according to the fusion's output shape.
52   ShapeUtil::ForEachSubshape(
53       root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
54         if (subshape.element_type() != F32) {
55           return;
56         }
57         if (OutputTypeAfterChange(fusion, index) == BF16) {
58           AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
59           VLOG(2) << "Fused root " << root->ToString() << " at shape index "
60                   << index << " changed to BF16 precision for fusion "
61                   << fusion->ToString();
62         }
63       });
64 
65   // Propagate BF16 in the fusion computation.
66   auto insts =
67       fusion->fused_instructions_computation()->MakeInstructionPostOrder();
68   for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
69     DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
70   }
71   computations_visited_in_backward_pass_.insert(
72       fusion->fused_instructions_computation());
73 
74   RevertIfFusionInternalBF16Changes(fusion);
75 }
76 
RevertIfFusionInternalBF16Changes(HloInstruction * fusion)77 void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
78     HloInstruction* fusion) {
79   auto has_changes = [this](HloInstruction* inst) {
80     auto it = changes_to_bf16_.find(inst);
81     return it != changes_to_bf16_.end() && !it->second.empty();
82   };
83 
84   auto root = fusion->fused_instructions_computation()->root_instruction();
85   absl::flat_hash_set<const HloValue*> changed_root_buffers;
86 
87   auto root_changes_it = changes_to_bf16_.find(root);
88   if (root_changes_it != changes_to_bf16_.end()) {
89     for (const auto& entry : root_changes_it->second) {
90       for (const HloValue* value :
91            dataflow_->GetValueSet(root, entry.second).values()) {
92         changed_root_buffers.insert(value);
93       }
94     }
95   }
96 
97   auto aliases_changed_root_buffer =
98       [this, &changed_root_buffers](const HloInstruction* inst) {
99         bool aliasing = false;
100         ShapeUtil::ForEachSubshape(
101             inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
102               if (aliasing) {
103                 // Skip if aliasing is already found.
104                 return;
105               }
106               // Only F32 buffers are considered for changing to BF16 in this
107               // pass.
108               if (subshape.element_type() != F32) {
109                 return;
110               }
111               for (const HloValue* value :
112                    dataflow_->GetValueSet(inst, index).values()) {
113                 if (ContainsKey(changed_root_buffers, value)) {
114                   aliasing = true;
115                   break;
116                 }
117               }
118             });
119         return aliasing;
120       };
121 
122   for (auto inst :
123        fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
124     if (inst->opcode() == HloOpcode::kParameter) {
125       continue;
126     }
127     if (aliases_changed_root_buffer(inst)) {
128       continue;
129     }
130     if (inst->opcode() == HloOpcode::kFusion) {
131       bool parameter_reverted = false;
132       for (int64 i = 0; i < inst->operand_count(); ++i) {
133         if (has_changes(inst->mutable_operand(i))) {
134           // Changes on the operand have not been reverted.
135           continue;
136         }
137         auto* fused_parameter = inst->fused_parameter(i);
138         if (has_changes(fused_parameter)) {
139           changes_to_bf16_.erase(fused_parameter);
140           parameter_reverted = true;
141         }
142       }
143       if (parameter_reverted) {
144         RevertIfFusionInternalBF16Changes(inst);
145       }
146     }
147     if (!has_changes(inst)) {
148       continue;
149     }
150     bool revert_changes = true;
151     for (auto operand : inst->operands()) {
152       if (has_changes(operand)) {
153         revert_changes = false;
154         break;
155       }
156     }
157     if (revert_changes) {
158       changes_to_bf16_.erase(inst);
159     }
160   }
161 }
162 
DetermineWhileComputationsPrecision(HloInstruction * while_hlo)163 void BFloat16Propagation::DetermineWhileComputationsPrecision(
164     HloInstruction* while_hlo) {
165   CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
166 
167   // We are depending on the while node itself having already been analyzed for
168   // whether it can output BF16 and this has been adjusted in the output shape,
169   // and now we're looking to update the body and condition computations to
170   // match the new output shape, as well as recursively process the whole while
171   // node even if the output shape was not modified.
172   HloComputation* body = while_hlo->while_body();
173   auto body_root = body->root_instruction();
174   HloComputation* condition = while_hlo->while_condition();
175 
176   ShapeUtil::ForEachSubshape(
177       body_root->shape(), [this, while_hlo, body_root](
178                               const Shape& subshape, const ShapeIndex& index) {
179         if (subshape.element_type() != F32) {
180           return;
181         }
182         if (OutputTypeAfterChange(while_hlo, index) == BF16) {
183           AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16);
184           VLOG(2) << "While body root " << body_root->ToString()
185                   << " at shape index " << index
186                   << " changed to BF16 precision for while "
187                   << while_hlo->ToString();
188         }
189       });
190 
191   auto body_insts = body->MakeInstructionPostOrder();
192   for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
193        ++inst_it) {
194     DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
195   }
196   computations_visited_in_backward_pass_.insert(body);
197 
198   auto condition_insts = condition->MakeInstructionPostOrder();
199   for (auto inst_it = condition_insts.rbegin();
200        inst_it != condition_insts.rend(); ++inst_it) {
201     DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
202   }
203   computations_visited_in_backward_pass_.insert(condition);
204 }
205 
AllUsersConsumeBF16(const HloInstruction & hlo,const ShapeIndex & index) const206 bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
207                                               const ShapeIndex& index) const {
208   // If the subshape isn't floating point then none of the users will be BF16.
209   const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index);
210   if (subshape.element_type() != BF16 && subshape.element_type() != F32) {
211     return false;
212   }
213 
214   auto& value_set = dataflow_->GetValueSet(&hlo, index);
215   for (const HloValue* value : value_set.values()) {
216     if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
217       return false;
218     }
219     // We use the original type for the value because we are going to examine
220     // the uses of it, instead of the value itself. If ValueTypeAfterChange()
221     // were used, it would cause problems when there are aliasing buffers, i.e.,
222     // ResolveInconsistencyOfAliasingBuffers() would fail to revert the
223     // tentative change to BF16 even if the uses require F32.
224     if (value->shape().element_type() == BF16) {
225       continue;
226     }
227     for (const HloUse& use : value->uses()) {
228       if (!ContainsKey(instructions_visited_in_backward_pass_,
229                        use.instruction)) {
230         // We don't know yet whether use.instruction will consume BF16 since it
231         // hasn't been visited. Although we visit instructions in reverse
232         // topological order, this is still possible because there may be
233         // unvisited instruction that alias the same buffer. In this case, we
234         // aggressively skip this use, and if this causes inconsistency (e.g.,
235         // one use is in BF16 but another use is in F32), it will be resolved at
236         // the end of the BFloat16Propagation pass.
237         continue;
238       }
239       if (use.instruction->HasSideEffectNoRecurse()) {
240         // Keep side-effecting instruction's operands unchanged.
241         return false;
242       }
243       // Any visited user that can accept BF16 has already been updated if
244       // necessary, e.g., the output has been changed to BF16 if it propagates
245       // precision, or a called computation's parameters have been changed to
246       // BF16 for fusions or whiles.
247       if (use.instruction->opcode() == HloOpcode::kFusion) {
248         auto* fused_parameter =
249             use.instruction->fused_parameter(use.operand_number);
250         if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) {
251           return false;
252         }
253         continue;
254       } else if (use.instruction->opcode() == HloOpcode::kWhile) {
255         auto* cond_parameter =
256             use.instruction->while_condition()->parameter_instruction(
257                 use.operand_number);
258         if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
259           return false;
260         }
261         auto* body_parameter =
262             use.instruction->while_body()->parameter_instruction(
263                 use.operand_number);
264         if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) {
265           return false;
266         }
267         continue;
268       }
269       if (bfloat16_support_->EffectiveOperandPrecisionIsBF16(
270               *use.instruction, use.operand_number)) {
271         continue;
272       }
273       // If the op propagates precision and it outputs a BF16, then it's OK to
274       // supply BF16 also as the input. In the backward pass, the users shapes
275       // should have already been processed.
276       if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
277               *use.instruction, use.operand_number)) {
278         if (use.instruction->opcode() == HloOpcode::kTuple ||
279             (use.instruction->opcode() == HloOpcode::kAllReduce &&
280              use.instruction->shape().IsTuple())) {
281           ShapeIndex use_output_index{use.operand_number};
282           for (int64 i : use.operand_index) {
283             use_output_index.push_back(i);
284           }
285           if (OutputTypeAfterChange(use.instruction, use_output_index) ==
286               BF16) {
287             continue;
288           }
289         } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) {
290           ShapeIndex use_output_index;
291           for (int64 i = 1; i < use.operand_index.size(); ++i) {
292             use_output_index.push_back(use.operand_index[i]);
293           }
294           if (OutputTypeAfterChange(use.instruction, use_output_index) ==
295               BF16) {
296             continue;
297           }
298         } else {
299           if (OutputTypeAfterChange(use.instruction, use.operand_index) ==
300               BF16) {
301             continue;
302           }
303         }
304       }
305       return false;
306     }
307   }
308   return true;
309 }
310 
DetermineInstructionPrecision(HloInstruction * hlo,bool skip_parameters)311 void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
312                                                         bool skip_parameters) {
313   // We handle any fusion computation or while body/condition after the
314   // instruction is handled, because we need to know the output shape of a
315   // fusion or while before propagating inside its  computations.
316   bool postpone_processing_called_computations = false;
317   auto cleaner = tensorflow::gtl::MakeCleanup(
318       [this, hlo, &postpone_processing_called_computations] {
319         if (!postpone_processing_called_computations) {
320           if (hlo->opcode() == HloOpcode::kFusion) {
321             DetermineFusionComputationPrecision(hlo);
322           } else if (hlo->opcode() == HloOpcode::kWhile) {
323             DetermineWhileComputationsPrecision(hlo);
324           }
325         }
326         instructions_visited_in_backward_pass_.insert(hlo);
327       });
328 
329   if (hlo->opcode() == HloOpcode::kWhile &&
330       (caller_counts_[hlo->while_condition()] > 1 ||
331        caller_counts_[hlo->while_body()] > 1)) {
332     postpone_processing_called_computations = true;
333     return;
334   }
335 
336   // Prevent root instructions from having their output modified by recording
337   // all F32 output values as needing to stay as F32.
338   CHECK(hlo->parent() != nullptr);
339   if (hlo == hlo->parent()->root_instruction()) {
340     if (!hlo->parent()->IsFusionComputation()) {
341       ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */,
342                                                    const ShapeIndex& index) {
343         if (OutputTypeAfterChange(hlo, index) != F32) {
344           return;
345         }
346         for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
347           // Since we use HloValues from the dataflow analysis, this can also
348           // affect HLO instructions beyond the root, e.g., if the root is a
349           // Tuple HLO, then its operands are also affected.
350           values_that_must_be_kept_as_f32_.insert(value);
351         }
352       });
353     }
354     return;
355   }
356 
357   // Do not change precision for instructions related to entry and exit of a
358   // computation, side-effecting instructions, and control flow, because this
359   // pass might break the interfaces or assumptions for them.
360   if (hlo->opcode() == HloOpcode::kCustomCall ||   //
361       hlo->opcode() == HloOpcode::kCall ||         //
362       hlo->opcode() == HloOpcode::kConditional ||  //
363       hlo->HasSideEffectNoRecurse() ||             //
364       (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
365     return;
366   }
367 
368   if (!ContainsKey(consider_using_bfloat16_, hlo)) {
369     return;
370   }
371 
372   if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
373     return;
374   }
375 
376   ShapeUtil::ForEachSubshape(
377       hlo->shape(),
378       [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) {
379         if (OutputTypeAfterChange(hlo, index) == F32 &&
380             AllUsersConsumeBF16(*hlo, index)) {
381           AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16);
382           VLOG(2) << "HloInstruction output at shape index " << index
383                   << " changed to BF16 precision: " << hlo->ToString();
384         }
385       });
386 }
387 
InstructionIsCandidateForBF16Output(HloInstruction * hlo)388 bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
389     HloInstruction* hlo) {
390   if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) &&
391       hlo->opcode() != HloOpcode::kTuple &&
392       hlo->opcode() != HloOpcode::kGetTupleElement &&
393       hlo->opcode() != HloOpcode::kDomain &&
394       hlo->shape().element_type() != BF16) {
395     for (int64 i = 0; i < hlo->operand_count(); ++i) {
396       if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
397                                                                          i) ||
398           !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) {
399         return false;
400       }
401     }
402   }
403   return true;
404 }
405 
AdjustCalledComputationParameters(HloInstruction * hlo)406 void BFloat16Propagation::AdjustCalledComputationParameters(
407     HloInstruction* hlo) {
408   auto adjust_computation =
409       [this, hlo](HloComputation* computation,
410                   absl::Span<HloInstruction* const> operands) {
411         // Adjust parameters.
412         CHECK_EQ(operands.size(), computation->num_parameters());
413         for (int64 i = 0; i < operands.size(); ++i) {
414           auto parameter = computation->parameter_instruction(i);
415           ShapeUtil::ForEachSubshape(
416               parameter->shape(),
417               [this, i, hlo, &operands, parameter](const Shape& /* subshape */,
418                                                    const ShapeIndex& index) {
419                 if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
420                   return;
421                 }
422                 PrimitiveType operand_type =
423                     OutputTypeAfterChange(operands[i], index);
424                 if (OutputTypeAfterChange(parameter, index) == operand_type) {
425                   return;
426                 }
427                 AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type);
428                 VLOG(2) << "Called computation parameter "
429                         << parameter->ToString() << " at shape index " << index
430                         << " adjusted to "
431                         << (operand_type == BF16 ? "BF16" : "F32")
432                         << " to match operand in HLO " << hlo->ToString();
433               });
434         }
435       };
436 
437   switch (hlo->opcode()) {
438     case HloOpcode::kFusion:
439       adjust_computation(hlo->fused_instructions_computation(),
440                          hlo->operands());
441       break;
442     case HloOpcode::kWhile:
443       adjust_computation(hlo->while_condition(), hlo->operands());
444       adjust_computation(hlo->while_body(), hlo->operands());
445       break;
446     default:
447       break;
448   }
449 }
450 
AdjustCalledComputationRoot(HloInstruction * hlo)451 void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
452   auto adjust_computation = [this, hlo](HloComputation* computation,
453                                         HloInstruction* output) {
454     // Adjust root.
455     HloInstruction* root = computation->root_instruction();
456     ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output](
457                                                   const Shape& /* subshape */,
458                                                   const ShapeIndex& index) {
459       if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
460         return;
461       }
462       const PrimitiveType output_type = OutputTypeAfterChange(output, index);
463       if (OutputTypeAfterChange(root, index) == output_type) {
464         return;
465       }
466       AddToOrRemoveFromBF16ChangeSet(root, index, output_type);
467       // It's possible that output_type is F32, but the root instruction's
468       // type is BF16; e.g., a fusion node's output was changed to BF16
469       // initially but then adjusted back to F32, and the fusion computation
470       // is now being adjusted after the fusion node.
471       if (output_type == F32) {
472         for (const auto* value : dataflow_->GetValueSet(root, index).values()) {
473           // We rely on the fact that this adjustment works in reverse
474           // topological order so that called computation will be
475           // processed later. Adding the value to
476           // values_that_must_be_kept_as_f32_ will ensure the
477           // correctness of the adjustment for HLOs that will be
478           // processed later.
479           values_that_must_be_kept_as_f32_.insert(value);
480         }
481       }
482       VLOG(2) << "Called computation root " << root->ToString()
483               << " at shape index " << index << " adjusted to "
484               << (output_type == BF16 ? "BF16" : "F32")
485               << " to match output shape of " << hlo->ToString();
486     });
487   };
488 
489   switch (hlo->opcode()) {
490     case HloOpcode::kFusion:
491       adjust_computation(hlo->fused_instructions_computation(), hlo);
492       break;
493     case HloOpcode::kWhile:
494       adjust_computation(hlo->while_body(), hlo);
495       break;
496     default:
497       break;
498   }
499 }
500 
ResolveInconsistencyOfAliasingBuffersHelper(HloComputation * computation,absl::flat_hash_set<const HloComputation * > * visited_computations)501 bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
502     HloComputation* computation,
503     absl::flat_hash_set<const HloComputation*>* visited_computations) {
504   bool parameter_changed = false;
505   auto insts = computation->MakeInstructionPostOrder();
506   // Do the adjustment on each instruction in the computation in reverse
507   // topological order.
508   for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
509     auto hlo = *inst_it;
510     auto adjust_hlo_output = [this, hlo, &parameter_changed](
511                                  const Shape& /* subshape */,
512                                  const ShapeIndex& index) {
513       auto output_type = OutputTypeAfterChange(hlo, index);
514       if (output_type != F32 && output_type != BF16) {
515         return;
516       }
517       PrimitiveType type = BF16;
518       for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
519         auto value_type = ValueTypeAfterChange(value);
520         if (value_type == BF16) {
521           continue;
522         }
523         CHECK_EQ(value_type, F32);
524         type = F32;
525         break;
526       }
527       // It's possible that a user has been changed from BF16 to F32
528       // during this final adjustment pass, so we need to check
529       // AllUsersConsumeBF16() again.
530       if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
531         type = F32;
532       }
533       if (type == F32) {
534         for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
535           // We rely on the fact that this adjustment works in reverse
536           // topological order. Adding the value to
537           // values_that_must_be_kept_as_f32_ will ensure the correctness
538           // of the adjustment for HLOs that will be processed later.
539           values_that_must_be_kept_as_f32_.insert(value);
540         }
541       }
542       if (type != output_type) {
543         AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
544         VLOG(2) << "HloInstruction output at shape index " << index
545                 << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
546                 << hlo->ToString();
547         if (hlo->opcode() == HloOpcode::kParameter) {
548           parameter_changed = true;
549         }
550       }
551     };
552     ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
553     AdjustCalledComputationRoot(hlo);
554     if (hlo->opcode() == HloOpcode::kWhile) {
555       // We need to run on the while body and condition repeatedly until a fixed
556       // point is reached, i.e., the parameters do not change any more. We may
557       // need more than one iteration because the while input and output alias
558       // each other, so changing one input parameter requires changing the
559       // corresponding output element and thus may transitively require changing
560       // another input parameter. A fixed point will be reached because the
561       // parameters can only be changed from BF16 to F32, not the other way
562       // around.
563       absl::flat_hash_set<const HloComputation*> visited_in_while;
564       while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(),
565                                                          &visited_in_while) ||
566              ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
567                                                          &visited_in_while)) {
568         visited_in_while.clear();
569         ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
570         AdjustCalledComputationRoot(hlo);
571       }
572       visited_computations->insert(visited_in_while.begin(),
573                                    visited_in_while.end());
574     } else if (hlo->opcode() == HloOpcode::kFusion) {
575       ResolveInconsistencyOfAliasingBuffersHelper(
576           hlo->fused_instructions_computation(), visited_computations);
577     }
578   }
579   // Now adjust parameters of called computations.
580   for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
581     AdjustCalledComputationParameters(*inst_it);
582   }
583   return parameter_changed;
584 }
585 
ResolveInconsistencyOfAliasingBuffers(HloModule * module)586 void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
587     HloModule* module) {
588   const auto& computations_topological_order =
589       module->MakeComputationPostOrder();
590   absl::flat_hash_set<const HloComputation*> resolved;
591   for (auto comp_it = computations_topological_order.rbegin();
592        comp_it != computations_topological_order.rend(); ++comp_it) {
593     if (ContainsKey(resolved, *comp_it)) {
594       continue;
595     }
596     ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
597   }
598 }
599 
ResolveInconsistentFusions(HloModule * module)600 Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
601   // We could have changed a fusion computation's root shape to have a different
602   // precision than the fusion node's output, if the fusion root does not
603   // define a buffer (e.g., a tuple). Now we add conversions after such fusion
604   // roots to make them match the fusion output. If the fusion output is a
605   // (possibly nested) tuple, we first create get-tuple-elements, then convert
606   // the unmatching leaf nodes, and finally create a new tuple as the fusion
607   // computation's root. If tuples and get-tuple-elements are created, we will
608   // run tuple simplifier and dead code elimination at the end (dead code is not
609   // allowed in fusion computation). E.g.,
610   //
611   // (1)             (2)             (3)
612   // a  b            a  b            a  b
613   // |\ |            |\ |            |\ |
614   // \ add   ->      |add    ->      | add
615   //  \ |            \ |        convert |
616   //  tuple         tuple             \ |
617   //                 / \              tuple
618   //               gte gte
619   //                |   |
620   //           convert  |
621   //                 \  /
622   //                 tuple
623   // (1) a is F32 but tuple is BF16
624   // (2) after adding conversion
625   // (3) after tuple simplifier and DCE.
626   for (auto computation : module->MakeComputationPostOrder()) {
627     auto insts = computation->MakeInstructionPostOrder();
628     for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
629       auto hlo = *inst_it;
630       if (hlo->opcode() != HloOpcode::kFusion) {
631         continue;
632       }
633       auto fusion_computation = hlo->fused_instructions_computation();
634       auto fusion_root = fusion_computation->root_instruction();
635       if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) {
636         continue;
637       }
638       ShapeTree<HloInstruction*> converted_outputs(hlo->shape());
639       // Deep copy the fusion root, and convert a leaf node only if its shape
640       // does not match the fusion output.
641       TF_ASSIGN_OR_RETURN(
642           HloInstruction * copy,
643           fusion_computation->DeepCopyInstructionWithCustomCopier(
644               fusion_root,
645               [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index,
646                     HloComputation* comp) {
647                 const Shape& hlo_subshape =
648                     ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
649                 if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) {
650                   return leaf;
651                 }
652                 return comp->AddInstruction(
653                     HloInstruction::CreateConvert(hlo_subshape, leaf));
654               }));
655       fusion_computation->set_root_instruction(copy);
656     }
657   }
658   return Status::OK();
659 }
660 
ResolveConvertedConstants(HloModule * module)661 Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
662   // We may have converted some constants from F32 to BF16, so adjust the
663   // constant literals in such cases. We do this here instead of when the
664   // constant node's is changed because 1) the HloInstruction interface does not
665   // allow resetting the literal so we have to create a new kConstant
666   // instruction to replace the old one, which invalidates dataflow analysis,
667   // and 2) it's possible that a kConstant's output gets changed to BF16 at the
668   // beginning but later on adjusted back to F32, so converting literals here
669   // can avoid repeated conversions.
670   //
671   // TODO(b/73833576): Consider resetting literal in HloInstruction.
672   for (auto computation : module->MakeComputationPostOrder()) {
673     for (auto hlo : computation->MakeInstructionPostOrder()) {
674       if (hlo->opcode() != HloOpcode::kConstant) {
675         continue;
676       }
677       if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
678         TF_ASSIGN_OR_RETURN(auto converted_literal,
679                             hlo->literal().ConvertToShape(hlo->shape()));
680         auto new_constant = computation->AddInstruction(
681             HloInstruction::CreateConstant(std::move(converted_literal)));
682         TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
683       }
684     }
685   }
686   return Status::OK();
687 }
688 
SkipNoopConversions(HloModule * module)689 Status BFloat16Propagation::SkipNoopConversions(HloModule* module) {
690   for (auto computation : module->computations()) {
691     for (auto hlo : computation->MakeInstructionPostOrder()) {
692       if (hlo->opcode() != HloOpcode::kConvert) {
693         continue;
694       }
695       auto source = hlo->mutable_operand(0);
696       if (!ShapeUtil::Equal(source->shape(), hlo->shape())) {
697         continue;
698       }
699       const bool is_root = hlo == computation->root_instruction();
700       TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source));
701       if (is_root) {
702         computation->set_root_instruction(source);
703       }
704     }
705   }
706   return Status::OK();
707 }
708 
709 // The algorithm first does a forward pass (parameters to root) to determine a
710 // set of instructions to consider using bfloat16, then does a backward pass to
711 // determine the precisions of those instructions according to the need of
712 // their users. During the backward pass, the potential changes are stored in
713 // changes_to_bf16_ which are subject to further adjustments then applied to the
714 // HLOs.
Run(HloModule * module)715 StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
716   consider_using_bfloat16_.clear();
717   instructions_visited_in_backward_pass_.clear();
718   computations_visited_in_backward_pass_.clear();
719   values_that_must_be_kept_as_f32_.clear();
720   caller_counts_.clear();
721   changes_to_bf16_.clear();
722   changed_ = false;
723 
724   auto computations_topological_order = module->MakeComputationPostOrder();
725 
726   // Before running the propagation pass, we insert copies (kConvert to the same
727   // type) of F32 inputs to while loops. This prevents other uses of the same
728   // input from aliasing the while loop input/output, so that there's greater
729   // chance to use BF16 inside the loop. If some of these added copies do not
730   // help, they will remain F32 after BF16 propagation and will be removed since
731   // they are no-ops.
732   for (auto computation : computations_topological_order) {
733     for (auto inst : computation->MakeInstructionPostOrder()) {
734       if (inst->opcode() != HloOpcode::kWhile) {
735         continue;
736       }
737 
738       auto operand = inst->mutable_operand(0);
739       TF_ASSIGN_OR_RETURN(
740           HloInstruction * copy,
741           computation->DeepCopyInstructionWithCustomCopier(
742               operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index,
743                           HloComputation* comp) {
744                 if (leaf->shape().element_type() != F32) {
745                   return leaf;
746                 }
747                 return comp->AddInstruction(
748                     HloInstruction::CreateConvert(leaf->shape(), leaf));
749               }));
750       TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy));
751     }
752   }
753 
754   TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));
755 
756   // The first step is a forward pass (parameters to root), where we determine
757   // the potential candidate instructions to use bfloat16 in the outputs that
758   // are not likely to cause overhead from extra explicit conversions. This is
759   // done forwardly because we determine whether an HLO is a candidate partially
760   // based on whether its operands are candidates.
761   for (auto computation : computations_topological_order) {
762     for (auto inst : computation->MakeInstructionPostOrder()) {
763       if (InstructionIsCandidateForBF16Output(inst)) {
764         consider_using_bfloat16_.insert(inst);
765       }
766     }
767   }
768 
769   // The second step is a backward pass (root to parameters), where we modify
770   // the precisions of the instructions identified in the first step when
771   // feasible. This is done backwardly because we determine the precision of an
772   // HLO's output based on how it is later used.
773   //
774   // The precision of an instruction is determined by its users, so we do the
775   // propagation in reverse topological order.
776   for (auto comp_it = computations_topological_order.rbegin();
777        comp_it != computations_topological_order.rend(); ++comp_it) {
778     if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) {
779       continue;
780     }
781     auto insts = (*comp_it)->MakeInstructionPostOrder();
782     for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
783       DetermineInstructionPrecision(*inst_it,
784                                     /*skip_parameters=*/true);
785     }
786     computations_visited_in_backward_pass_.insert(*comp_it);
787   }
788 
789   // It's possible that an instruction does not define a buffer, but the
790   // defining instruction's shape has changed. So we need to adjust the output
791   // shapes of instructions according to the HLO values they refer to.
792   ResolveInconsistencyOfAliasingBuffers(module);
793 
794   // Apply the changes in changes_to_bf16_.
795   for (auto& change : changes_to_bf16_) {
796     for (const auto& entry : change.second) {
797       auto subshape = entry.first;
798       CHECK_EQ(subshape->element_type(), F32);
799       subshape->set_element_type(BF16);
800       changed_ = true;
801     }
802   }
803 
804   // Removes redundant HLOs added by this pass, either when inserting
805   // de-aliasing copies to while loop inputs, or later when converting output
806   // types.
807   auto clean_up = [this, module]() {
808     TF_RETURN_IF_ERROR(SkipNoopConversions(module));
809     TupleSimplifier tuple_simplifier;
810     TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
811     HloDCE dce;
812     TF_RETURN_IF_ERROR(dce.Run(module).status());
813     return Status::OK();
814   };
815 
816   if (!changed_) {
817     TF_RETURN_IF_ERROR(clean_up());
818     return false;
819   }
820 
821   TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module));
822   TF_RETURN_IF_ERROR(ResolveConvertedConstants(module));
823 
824   TF_RETURN_IF_ERROR(clean_up());
825   return true;
826 }
827 
OutputTypeAfterChange(HloInstruction * hlo,const ShapeIndex & index) const828 PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
829     HloInstruction* hlo, const ShapeIndex& index) const {
830   Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index);
831   const PrimitiveType type_on_hlo = subshape->element_type();
832   if (type_on_hlo != F32) {
833     return type_on_hlo;
834   }
835   auto it = changes_to_bf16_.find(hlo);
836   if (it == changes_to_bf16_.end()) {
837     return type_on_hlo;
838   }
839   return ContainsKey(it->second, subshape) ? BF16 : F32;
840 }
841 
ValueTypeAfterChange(const HloValue * value) const842 PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
843     const HloValue* value) const {
844   auto hlo = value->defining_instruction();
845   const auto& position = value->defining_position();
846   return OutputTypeAfterChange(hlo, position.index);
847 }
848 
AddToOrRemoveFromBF16ChangeSet(HloInstruction * hlo,const ShapeIndex & index,PrimitiveType target_type)849 void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
850     HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
851   if (target_type == BF16) {
852     auto& entry = changes_to_bf16_[hlo];
853     entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index),
854                   index);
855   } else {
856     CHECK_EQ(target_type, F32);
857     auto it = changes_to_bf16_.find(hlo);
858     if (it == changes_to_bf16_.end()) {
859       return;
860     }
861     it->second.erase(
862         ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index));
863   }
864 }
865 
866 }  // namespace xla
867