/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/lib/hash/hash.h"

namespace xla {

// HLO pass which reduces the precision of some HLO instructions to BF16
// according to the backend-specific BFloat16Support rule provided by the
// caller.
//
// This pass can be used to reduce instruction precision without affecting the
// numerical accuracy of the module, i.e., the final output of the module would
// be bitwise identical to that without this pass; this is possible if the
// backend already reduces precision to BF16 on some HLO instructions.
//
// This pass will not modify the signature of a computation, unless it is a
// fusion computation or its only caller is a while.
//
// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs,
// which has two issues:
//
// 1) It does not guarantee to respect the passed-in BFloat16Support
// specification in terms of mixed precision, so the backend may not support an
// HLO that has mixed precision produced by this pass. To address this issue,
// run BFloat16Normalization with the same BFloat16Support after this pass.
//
// 2) In general, mixed precision may break the assumptions of some other HLO
// passes even if the specific backend supports the individual HLOs. Such
// assumptions include that there are no HLOs using mixed precision, or that
// the precision of an HLO's output is determined by its inputs. This pass
// should be used at the end of the HLO optimization pipeline but before
// BFloat16ConversionFolding. If other passes are needed after this pass, run
// BFloat16MixedPrecisionRemoval first to undo some of the changes made by
// this pass.
class BFloat16Propagation : public HloModulePass {
 public:
  explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);

  ~BFloat16Propagation() override = default;

  absl::string_view name() const override { return "bfloat16-propagation"; }
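
  // A minimal usage sketch (illustrative only; a backend would normally pass
  // its own BFloat16Support subclass, surround these passes with others, and
  // module below is an HloModule*):
  //
  //   BFloat16Support bfloat16_support;
  //   HloPassPipeline pipeline("bf16");
  //   pipeline.AddPass<BFloat16Propagation>(&bfloat16_support);
  //   pipeline.AddPass<BFloat16Normalization>(&bfloat16_support);
  //   StatusOr<bool> changed = pipeline.Run(module);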

  // Runs the pass on the given module. Returns whether the module was changed
  // (precision reductions were added).
  StatusOr<bool> Run(HloModule* module) override;

 private:
  // ***************************
  // Function called and state produced by the forward analysis pass (from
  // parameters to root) that determines the candidate HLOs to use BF16
  // outputs.

  // Determines whether we should consider changing the precision of the given
  // instruction in the forward pass.
  bool InstructionIsCandidateForBF16Output(HloInstruction* hlo);

  // The set of instructions to consider using bfloat16, computed in the
  // forward pass.
  absl::flat_hash_set<const HloInstruction*> consider_using_bfloat16_;

  // ***************************
  // Functions called and state produced by the backward pass (from root to
  // parameters) that finds opportunities to use BF16.

  // Determines the precision for the given instruction in the
  // opportunity-finding pass.
  void DetermineInstructionPrecision(HloInstruction* hlo,
                                     bool skip_parameters);

  // Special handling in the opportunity-finding pass for fusion computations.
  //
  // Precondition: hlo->opcode() == kFusion
  void DetermineFusionComputationPrecision(HloInstruction* fusion);

  // Reverts changes to BF16 that will not propagate outside a fusion
  // computation. This avoids the overhead of BF16 casts inside a fusion when
  // they would not save memory bandwidth.
  //
  // Precondition: hlo->opcode() == kFusion
  void RevertIfFusionInternalBF16Changes(HloInstruction* fusion);

  // Special handling in the opportunity-finding pass for while computations.
  //
  // Precondition: hlo->opcode() == kWhile
  void DetermineWhileComputationsPrecision(HloInstruction* while_hlo);

  // Special handling in the opportunity-finding pass for conditional branches.
  //
  // Precondition: hlo->opcode() == kConditional
  void DetermineConditionalComputationsPrecision(HloInstruction* cond);

  // The set of HloInstructions that have been visited in the
  // opportunity-finding pass.
  absl::flat_hash_set<const HloInstruction*>
      instructions_visited_in_backward_pass_;

  // The set of HloComputations that have been visited in the
  // opportunity-finding pass.
  absl::flat_hash_set<const HloComputation*>
      computations_visited_in_backward_pass_;

  // ***************************
  // Functions called by the final inconsistency resolving pass.

  // Adjusts the output shapes of HloInstructions so that if two
  // HloInstructions have aliasing buffers in their outputs, they have the
  // same precision.
  void ResolveInconsistencyOfAliasingBuffers(HloModule* module);

  // Resolves inconsistency of aliasing buffers for the given computation, and
  // recursively runs on a while instruction's condition and body until a fixed
  // point is reached.
  bool ResolveInconsistencyOfAliasingBuffersHelper(
      HloComputation* computation,
      absl::flat_hash_set<const HloComputation*>* visited_computations);

  // Makes the parameters of called computations match how they are called by
  // the given HLO.
  void AdjustCalledComputationParameters(HloInstruction* hlo);

  // Makes the root instructions of called computations match how they are used
  // by the given HLO.
  void AdjustCalledComputationRoot(HloInstruction* hlo);
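
  // Illustrative example of the aliasing constraint above: the output of a
  // while instruction, its body parameter, and its body root alias the same
  // buffers, so the resolution forces them to the same element type at each
  // shape index, iterating on the while's condition and body until a fixed
  // point is reached.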

  // ***************************
  // Functions called after changes in changes_to_bf16_ are applied.

  // Resolves inconsistencies introduced by this pass for fusions with
  // tuple-type output.
  Status ResolveInconsistentFusions(HloModule* module);

  // Converts the literals in kConstant HLOs which have their types changed to
  // BF16 by this pass.
  Status ResolveConvertedConstants(HloModule* module);

  // Skips no-op conversions (same source and target shapes) that can be
  // produced by this pass, i.e., replaces them in their uses with their
  // operands.
  Status SkipNoopConversions(HloModule* module);

  // ***************************
  // Functions called and state used by two or more passes.

  // Returns whether all uses of the given HloInstruction can consume BF16
  // input.
  bool AllUsersConsumeBF16(const HloInstruction& hlo,
                           const ShapeIndex& index) const;

  // The output element type of the HLO at the given shape index after changes
  // in changes_to_bf16_ are applied.
  PrimitiveType OutputTypeAfterChange(HloInstruction* hlo,
                                      const ShapeIndex& index) const;

  // The element type of the HLO value after changes in changes_to_bf16_ are
  // applied.
  PrimitiveType ValueTypeAfterChange(const HloValue* value) const;

  // If target_type == BF16, adds the HLO at the given index to
  // changes_to_bf16_; otherwise, target_type must be F32 and this function
  // removes the HLO at the given index from changes_to_bf16_ if it was earlier
  // added.
  void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo,
                                      const ShapeIndex& index,
                                      PrimitiveType target_type);

  // The set of F32 HLO values that must be kept in F32.
  absl::flat_hash_set<const HloValue*> values_that_must_be_kept_as_f32_;

  // Mapping from each HloComputation to the number of callers to it in the
  // module. Populated at the beginning of this pass.
  absl::flat_hash_map<const HloComputation*, int64> caller_counts_;

  // We first store the potential F32-to-BF16 changes in changes_to_bf16_;
  // these changes are subject to further adjustment before finally being
  // applied to the HLOs. This avoids setting changed_ to true when all of the
  // changes are later reverted during adjustment.
  //
  // For each HloInstruction, changes_to_bf16_ stores the affected buffers in
  // the output as a map from in-place pointers to subshapes to shape indices.
  absl::flat_hash_map<HloInstruction*, absl::flat_hash_map<Shape*, ShapeIndex>>
      changes_to_bf16_;

  // Whether the last processed HLO module has been changed by this pass.
  bool changed_ = false;

  const BFloat16Support* bfloat16_support_;
  std::unique_ptr<HloDataflowAnalysis> dataflow_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_