/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"

namespace xla {
namespace gpu {

// Emits LLVM IR for an "unnested computation".
//
// An unnested computation is an HloComputation which you run by executing one
// or more kernels for each HloInstruction it contains.  Examples of unnested
// computations:
//
//  - An HloModule's root computation,
//  - The body of an HLO while loop,
//  - The true/false computation of an HLO conditional.
//
// Note the opportunity for confusion -- the while loop's computation is nested
// within the root computation, but it's emitted using IrEmitterUnnested!
// Don't think about it too hard.
//
// Examples of things that are not unnested computations:
//
//  - The reducer of a kReduce HLO.  This is emitted using IrEmitterNested.
//  - The body of a fusion node.  IrEmitterUnnested emits the relevant code
//    within a kernel function using FusedIrEmitter.  (FusedIrEmitter is not
//    really an IrEmitter, but is more an "IR generator generator".)
//
class IrEmitterUnnested : public IrEmitter {
 public:
  // The parameter block_contains_multi_tiles indicates whether a tile block
  // consists of multiple tiles.  If the tile block contains only one tile,
  // there is no need to use atomic operations to accumulate a local result
  // into a global result to implement reduction.
  using TileGenerator =
      std::function<void(const llvm_ir::IrArray::Index& output_tile_origin,
                         absl::Span<llvm::Value* const> output_tile_bounds,
                         bool block_contains_multi_tiles)>;
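  //
  // Illustrative sketch only (not part of this header's API): a hypothetical
  // TileGenerator callback.  EmitElementsOfTile and EmitAccumulateResult are
  // made-up helper names standing in for the real emission logic.
  //
  //   TileGenerator emit_one_tile =
  //       [&](const llvm_ir::IrArray::Index& output_tile_origin,
  //           absl::Span<llvm::Value* const> output_tile_bounds,
  //           bool block_contains_multi_tiles) {
  //         // Emit IR that walks the elements of the tile anchored at
  //         // `output_tile_origin` and bounded by `output_tile_bounds`.
  //         EmitElementsOfTile(output_tile_origin, output_tile_bounds);
  //         // Only a block holding more than one tile needs atomic operations
  //         // to fold its local result into the global result.
  //         EmitAccumulateResult(/*use_atomics=*/block_contains_multi_tiles);
  //       };
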
  // KernelCodegenInfo records the common information to support the code
  // generation for a kernel that processes tensor elements by blocks.  A block
  // of tensor elements may contain one or multiple tiles.  The code generators
  // that generate code for tile elements or block prologue/epilogue refer to
  // this class in their prototypes.  If the implementation of such a code
  // generator requires other information that is specific to the HLO
  // instruction, the implementation needs to define and use a class derived
  // from this one.
  class KernelCodegenInfo {
   public:
    explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme)
        : mapping_scheme_(mapping_scheme),
          tiled_param_info_(nullptr),
          lane_id_(nullptr),
          index_ty_(nullptr) {}
    virtual ~KernelCodegenInfo() {}

    void SetLaneId(llvm::Value* v) { lane_id_ = v; }
    void SetIndexType(llvm::Type* t) { index_ty_ = t; }
    void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) {
      tiled_param_info_ = tiled_param_info;
    }

    llvm::Value* GetLaneId() const { return lane_id_; }
    llvm_ir::KernelMappingScheme* GetKernelMappingScheme() const {
      return mapping_scheme_;
    }
    llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const {
      return tiled_param_info_;
    }
    llvm::Type* GetIndexType() const { return index_ty_; }

   protected:
    llvm_ir::KernelMappingScheme* mapping_scheme_;
    llvm_ir::TiledParameterInfo* tiled_param_info_;
    llvm::Value* lane_id_;
    llvm::Type* index_ty_;
  };

  // A function object to prepare for the code generation for a tile block.
  using BlockPrologueGenerator =
      std::function<void(HloInstruction* hlo, KernelCodegenInfo* kernel_info)>;
  // A function object to finalize the code generation for a tile block.
  using BlockEpilogueGenerator =
      std::function<void(HloInstruction* hlo, KernelCodegenInfo* kernel_info)>;
  // A function object to generate code to process one element in a tile.
  //
  // hlo: the instruction for which the code is generated.
  // index: the index for the first output element of the current thread.
  // y_loc: the y coordinate within a tile.
  // x_loc: the x coordinate within a tile.
  // kernel_info: other information to support the kernel code generation.
  // x_iter_num: when a thread processes N elements in the X dimension,
  //   x_iter_num has a value of 0..N-1 to identify the element being
  //   processed.
  using TileElementGenerator = std::function<void(
      HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
      const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
      llvm::Value* x_loc, int64 x_iter_num)>;
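  //
  // Illustrative sketch only (not part of this header's API): a hypothetical
  // TileElementGenerator matching the signature above.  EmitElement is a
  // made-up helper name standing in for the real per-element emission logic.
  //
  //   TileElementGenerator element_generator =
  //       [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
  //           const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
  //           llvm::Value* x_loc, int64 x_iter_num) {
  //         // (y_loc, x_loc) locate the element within the current tile;
  //         // x_iter_num identifies which of the N elements handled by this
  //         // thread along the X dimension is being emitted.
  //         EmitElement(hlo, index, kernel_info, y_loc, x_loc, x_iter_num);
  //       };
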
  // KernelCodeGenerator records the code generator objects that generate code
  // for tile elements or tile block prologue/epilogue.
  class KernelCodeGenerator {
   public:
    explicit KernelCodeGenerator(
        TileElementGenerator tile_element_generator,
        BlockPrologueGenerator block_prologue_generator = {},
        BlockEpilogueGenerator block_epilogue_generator = {})
        : tile_element_generator_(std::move(tile_element_generator)),
          block_prologue_generator_(std::move(block_prologue_generator)),
          block_epilogue_generator_(std::move(block_epilogue_generator)) {}

    const TileElementGenerator& GetTileElementGenerator() const {
      return tile_element_generator_;
    }
    const BlockPrologueGenerator& GetBlockPrologueGenerator() const {
      return block_prologue_generator_;
    }
    const BlockEpilogueGenerator& GetBlockEpilogueGenerator() const {
      return block_epilogue_generator_;
    }

   private:
    TileElementGenerator tile_element_generator_;
    BlockPrologueGenerator block_prologue_generator_;
    BlockEpilogueGenerator block_epilogue_generator_;
  };

  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                    const HloComputation* hlo_computation,
                    IrEmitterContext* ir_emitter_context);
  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;

  // Transfers the ownership of thunk_sequence_ out.
  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
    return std::move(thunk_sequence_);
  }

  Status DefaultAction(HloInstruction* hlo) override;

  // IrEmitterUnnested handles the following instructions differently from
  // IrEmitter.
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleSelectAndScatter(HloInstruction* instruction) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleInfeed(HloInstruction* xla_infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleRng(HloInstruction* random) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleSort(HloInstruction* sort) override;
  Status HandleTriangularSolve(HloInstruction* hlo) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleAfterAll(HloInstruction* after_all) override;

  Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) override;

  // Same as `EmitTargetElementLoop`, but emits code into the given `thunk`
  // rather than into `LastThunk()`.
  Status EmitTargetElementLoopInThunk(
      const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
      KernelThunk* thunk);
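
  // A rough sketch of how a caller might drive this emitter; the actual call
  // site lives in the GPU compiler and may differ in detail:
  //
  //   IrEmitterUnnested ir_emitter(hlo_module_config, computation,
  //                                &ir_emitter_context);
  //   TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
  //   TF_RETURN_IF_ERROR(computation->Accept(&ir_emitter));
  //   std::unique_ptr<ThunkSequence> thunks =
  //       ir_emitter.ConsumeThunkSequence();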

  // Emits LLVM global variables corresponding to constant instructions.
  Status EmitConstantGlobals();

 private:
  // Adds an owning Thunk object to the thunk sequence.
  void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
    thunk_sequence_->emplace_back(std::move(thunk));
  }

  // Builds the prototype of the IR kernel for `inst` and adds it to the
  // module.  This kernel takes as arguments pointers to the given buffer
  // allocations.
  llvm::Function* BuildKernelPrototype(
      const HloInstruction& inst,
      absl::Span<const BufferAllocation* const> args);

  // Helper for writing extra outputs from inside a reduce kernel.
  Status EmitExtraOutputsForReduce(
      const HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
          extra_output_gens);

  // Generates code for reduction to contiguous dimensions.
  //
  // Prerequisite: `IsReductionToVector(*unnested_hlo)`
  Status EmitReductionToVector(HloInstruction* unnested_hlo);

  // Computes the KernelMappingScheme for the reduce HLO and indicates whether
  // the reduction is a row reduction.  For an un-fused reduce op, unnested_hlo
  // and first_reduce are the same instruction.  For a kInput fusion,
  // unnested_hlo is the fusion instruction while first_reduce is the first
  // reduce op.
  std::tuple<llvm_ir::KernelMappingScheme, bool>
  ComputeMappingSchemeAndReductionKind(const HloInstruction* unnested_hlo,
                                       const HloInstruction* first_reduce);

  // Emits code for an in-place scatter, modifying `thunk`'s launch dimensions
  // in the process.  `scatter` may be fused; scatter indices are taken from
  // `scatter_indices_gen` and updates from `updates_gen`.  The output buffer
  // is expected to already contain the operand values.
  Status EmitScatter(Thunk* thunk, HloInstruction* scatter,
                     const llvm_ir::ElementGenerator& scatter_indices_gen,
                     const llvm_ir::ElementGenerator& updates_gen);

  // Checks whether a 0-2-1 tiling algorithm applies to the hlo instruction;
  // if so, emits the kernel using that algorithm and returns true.
  bool CheckAndEmitHloWithTile021(HloInstruction* hlo);
  // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
  // returns the launch dimensions for the kernel.  This is a helper to support
  // the implementation of CheckAndEmitHloWithTile021.
  LaunchDimensions EmitHlo021Tile(HloInstruction* hlo,
                                  absl::Span<const int64> reduced_output_dims,
                                  absl::Span<const int64> tiled_param_ids);
  // Emits a kernel for an unnested HLO instruction.
  LaunchDimensions EmitKernel(HloInstruction* unnested_hlo,
                              absl::Span<const int64> param_ids,
                              const KernelCodeGenerator& kernel_generator,
                              KernelCodegenInfo* kernel_info);
  void EmitBlock(const TileGenerator& emit_one_tile,
                 KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl,
                 llvm::Type* index_ty);
  // Emits code to process a tensor element in a tile for the given kCopy HLO
  // that performs a 0-2-1 transpose.
  void EmitTileElementForCopy(HloInstruction* hlo,
                              const llvm_ir::IrArray::Index& index,
                              const KernelCodegenInfo* kernel_info,
                              llvm::Value* y_loc, llvm::Value* x_loc,
                              int64 x_iter_num);
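  // For orientation (shapes below are illustrative only): a 0-2-1 transpose
  // views the operand as a normalized rank-3 shape [D0, D1, D2] and produces
  // an output with the last two dimensions swapped, [D0, D2, D1]; e.g.,
  // copying f32[8,32,64] into f32[8,64,32].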
  // Emits code to process a tensor element in a tile for the given kLoop
  // fusion HLO containing parameters that are 0-2-1 transposes of its outputs.
  void EmitTileElementForFusion(HloInstruction* hlo,
                                const llvm_ir::IrArray::Index& index,
                                const KernelCodegenInfo* kernel_info,
                                llvm::Value* y_loc, llvm::Value* x_loc,
                                int64 x_iter_num);
  // Emits code to process a tensor element in a tile for the given input hlo
  // that is either an unnested kReduce or a kInput fusion.
  void EmitTileElementForReduction(HloInstruction* unnested_hlo,
                                   const llvm_ir::IrArray::Index& index,
                                   const KernelCodegenInfo* kernel_info,
                                   llvm::Value* y_loc, llvm::Value* x_loc,
                                   int64 x_iter_num);
  // Prepares for the code generation for a tile block of a reduction kernel.
  void EmitPrologueForReduction(HloInstruction* unnested_hlo,
                                KernelCodegenInfo* kernel_info);
  void EmitPrologueForOneReduction(HloInstruction* unnested_hlo,
                                   HloInstruction* reduce_inst, int reduce_idx,
                                   KernelCodegenInfo* kernel_info,
                                   GpuElementalIrEmitter* elemental_emitter,
                                   ShapeIndex output_shape_index);
  // Wraps up the code generation for a tile block of a reduction kernel.
  void EmitEpilogueForReduction(HloInstruction* unnested_hlo,
                                KernelCodegenInfo* kernel_info);
  // For each reducer, emits the shuffle-down loop to accumulate the partial
  // result into the global result.
  void EmitFullWarpShuffleDownLoopForAllReduces(
      absl::Span<HloComputation* const> reducers,
      absl::Span<llvm::AllocaInst* const> partial_result_addresses);

  // Generates the IrArray for each input of an hlo and returns a vector that
  // contains such IrArrays.
  std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs(
      const HloInstruction& hlo);

  // For each input of the `hlo` instruction, checks its value in
  // `param_buffers` to find out whether the input has a reduced shape.  If the
  // input has a reduced shape, constructs the reduced shape for the input and
  // casts the original input IrArray in `param_arrays` to the reduced shape.
  // Returns the total number of inputs.
  int ConstructInputReducedShapeAndCastInputIrArrayToShape(
      const HloInstruction& hlo,
      const std::vector<llvm_ir::IrArray>& param_arrays,
      const std::vector<llvm::Value*>& param_buffers,
      absl::Span<const int64> reduced_output_dims,
      std::vector<Shape>* param_reduced_shapes,
      std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays);

  // Returns a KernelThunk that invokes the kernel emitted for `inst`.  The
  // caller needs to make sure `inst` outlives the lifetime of the returned
  // Thunk object.  The kernel implementation will be unrolled if unroll_factor
  // is greater than one.  'implements_whole_instruction' specifies whether
  // this KernelThunk implements the whole 'inst' HloInstruction.  In some
  // cases 'inst' will be implemented by a sequence of Thunks.
  std::unique_ptr<KernelThunk> BuildKernelThunk(
      const HloInstruction* inst, bool implements_whole_instruction,
      int unroll_factor = 1);

  // Returns a FftThunk that calls cuFFT to implement `inst`.
  std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);

  // Returns a CholeskyThunk that calls cuSolver to implement `inst`.
  std::unique_ptr<Thunk> BuildCholeskyThunk(const HloInstruction* inst);

  // Returns a TriangularSolveThunk that calls cuBlas to implement `inst`.
  std::unique_ptr<Thunk> BuildTriangularSolveThunk(const HloInstruction* inst);

  // Returns a GemmThunk that calls gemm to implement `inst`.
  // The caller needs to make sure `inst` outlives the lifetime of the
  // returned Thunk object.
  std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);

  // Returns a thunk that, given a reduce or select-and-scatter op, initializes
  // its memory to the appropriate initial value.
  StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
      HloInstruction* hlo, const ShapeIndex& index = {});

  // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
  std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);

  // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`.
  std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
      const HloInstruction* inst);

  // Returns an InfeedThunk that performs a host-to-device memcpy to implement
  // `inst`.
  std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);

  // Returns an OutfeedThunk that performs a device-to-host memcpy to implement
  // `inst`.
  std::unique_ptr<Thunk> BuildOutfeedThunk(const HloInstruction* inst);

  // Returns a WhileThunk that invokes thunk sequences for 'condition' and
  // 'body' sub-computations of while instruction 'hlo'.
  std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);

  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
  std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
                                       const int64 loop_limit);

  // Returns a ConditionalThunk which executes the thunk sequence for the
  // 'branch_computation' corresponding to the predicate/branch_index of the
  // given conditional instruction.
  std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);

  Status Postprocess(HloInstruction* hlo) override;

  // Returns the last generated thunk.
  Thunk* LastThunk() const { return thunk_sequence_->back().get(); }

  // The thunk sequence this IrEmitter generates for the input computation.
  std::unique_ptr<ThunkSequence> thunk_sequence_;

  // The HloComputation that this IrEmitter emits code for.
  const HloComputation* hlo_computation_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_