/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"

namespace xla {
namespace gpu {
// Emits LLVM IR for an "unnested computation".
//
// An unnested computation is an HloComputation which you run by executing one
// or more kernels for each HloInstruction it contains.  Examples of unnested
// computations:
//
//  - An HloModule's root computation,
//  - The body of an HLO while loop,
//  - The true/false computation of an HLO conditional.
//
// Note the opportunity for confusion -- the while loop's computation is nested
// within the root computation, but it's emitted using IrEmitterUnnested!  Don't
// think about it too hard.
//
// Examples of things that are not unnested computations:
//
//  - The reducer of a kReduce HLO.  This is emitted using IrEmitterNested.
//  - The body of a fusion node.  IrEmitterUnnested emits the relevant code
//    within a kernel function using FusedIrEmitter.  (FusedIrEmitter is not
//    really an IrEmitter, but is more an "IR generator generator".)
//
class IrEmitterUnnested : public IrEmitter {
 public:
  // Parameter block_contains_multi_tiles indicates whether a tile block
  // consists of multiple tiles or not. If the tile block contains only one
  // tile, there is no need to use atomic operations to accumulate a local
  // result into the global result when implementing a reduction.
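  //
  // Example (illustrative sketch only, not part of this header): a
  // TileGenerator is typically a lambda that captures the emitter state and
  // generates the IR for a single tile. The helper `EmitOneTileOfOutput`
  // below is hypothetical.
  //
  //   TileGenerator emit_one_tile =
  //       [&](const llvm_ir::IrArray::Index& output_tile_origin,
  //           absl::Span<llvm::Value* const> output_tile_bounds,
  //           bool block_contains_multi_tiles) {
  //         // Use atomics only when other tiles in the same block may also
  //         // contribute to the same global result.
  //         EmitOneTileOfOutput(output_tile_origin, output_tile_bounds,
  //                             /*use_atomics=*/block_contains_multi_tiles);
  //       };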
  using TileGenerator =
      std::function<void(const llvm_ir::IrArray::Index& output_tile_origin,
                         absl::Span<llvm::Value* const> output_tile_bounds,
                         bool block_contains_multi_tiles)>;
  // KernelCodegenInfo records the common information needed to generate code
  // for a kernel that processes tensor elements by blocks. A block of tensor
  // elements may contain one or multiple tiles. The code generators that
  // generate code for tile elements or block prologue/epilogue refer to this
  // class in their prototypes. If the implementation of such a code generator
  // requires other information that is specific to the HLO instruction, the
  // implementation needs to define and use a derived class of this class.
  class KernelCodegenInfo {
   public:
    explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme)
        : mapping_scheme_(mapping_scheme),
          tiled_param_info_(nullptr),
          lane_id_(nullptr),
          index_ty_(nullptr) {}
    virtual ~KernelCodegenInfo() {}

    void SetLaneId(llvm::Value* v) { lane_id_ = v; }
    void SetIndexType(llvm::Type* t) { index_ty_ = t; }
    void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) {
      tiled_param_info_ = tiled_param_info;
    }

    llvm::Value* GetLaneId() const { return lane_id_; }
    llvm_ir::KernelMappingScheme* GetKernelMappingScheme() const {
      return mapping_scheme_;
    }
    llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const {
      return tiled_param_info_;
    }
    llvm::Type* GetIndexType() const { return index_ty_; }

   protected:
    llvm_ir::KernelMappingScheme* mapping_scheme_;
    llvm_ir::TiledParameterInfo* tiled_param_info_;
    llvm::Value* lane_id_;
    llvm::Type* index_ty_;
  };
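  //
  // When a code generator needs HLO-specific state, it can carry that state in
  // a derived class. A hypothetical sketch (names are illustrative only, not
  // part of this header):
  //
  //   class ReductionCodegenInfoSketch : public KernelCodegenInfo {
  //    public:
  //     explicit ReductionCodegenInfoSketch(
  //         llvm_ir::KernelMappingScheme* mapping_scheme)
  //         : KernelCodegenInfo(mapping_scheme) {}
  //     // Extra state used by the reduction prologue/epilogue generators,
  //     // e.g. the address of a per-thread partial result.
  //     void SetPartialResultAddress(llvm::AllocaInst* addr) {
  //       partial_result_address_ = addr;
  //     }
  //     llvm::AllocaInst* GetPartialResultAddress() const {
  //       return partial_result_address_;
  //     }
  //
  //    private:
  //     llvm::AllocaInst* partial_result_address_ = nullptr;
  //   };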

  // A function object to prepare for the code generation for a tile block.
  using BlockPrologueGenerator =
      std::function<void(HloInstruction* hlo, KernelCodegenInfo* kernel_info)>;
  // A function object to finalize the code generation for a tile block.
  using BlockEpilogueGenerator =
      std::function<void(HloInstruction* hlo, KernelCodegenInfo* kernel_info)>;
  // A function object to generate code to process one element in a tile.
  //
  // hlo: the instruction for which the code is generated.
  // index: the index for the first output element of the current thread.
  // y_loc: The y coordinate within a tile.
  // x_loc: The x coordinate within a tile.
  // kernel_info: Other information to support the kernel code generation.
  // x_iter_num: When a thread processes N elements in the X dimension,
  //             x_iter_num has a value of 0..N-1 to identify the element
  //             being processed.
  using TileElementGenerator = std::function<void(
      HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
      const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
      llvm::Value* x_loc, int64 x_iter_num)>;
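  //
  // Example (illustrative sketch only; `EmitElementAt` is a hypothetical
  // helper, not part of this header):
  //
  //   TileElementGenerator element_generator =
  //       [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
  //           const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
  //           llvm::Value* x_loc, int64 x_iter_num) {
  //         // Emit IR that computes and stores the output element addressed
  //         // by `index`, at offset (y_loc, x_loc) within the current tile.
  //         EmitElementAt(hlo, index, y_loc, x_loc, x_iter_num);
  //       };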

  // KernelCodeGenerator records the code generator objects that generate code
  // for tile elements or tile block prologue/epilogue.
  class KernelCodeGenerator {
   public:
    explicit KernelCodeGenerator(
        TileElementGenerator tile_element_generator,
        BlockPrologueGenerator block_prologue_generator = {},
        BlockEpilogueGenerator block_epilogue_generator = {})
        : tile_element_generator_(std::move(tile_element_generator)),
          block_prologue_generator_(std::move(block_prologue_generator)),
          block_epilogue_generator_(std::move(block_epilogue_generator)) {}

    const TileElementGenerator& GetTileElementGenerator() const {
      return tile_element_generator_;
    }
    const BlockPrologueGenerator& GetBlockPrologueGenerator() const {
      return block_prologue_generator_;
    }
    const BlockEpilogueGenerator& GetBlockEpilogueGenerator() const {
      return block_epilogue_generator_;
    }

   private:
    TileElementGenerator tile_element_generator_;
    BlockPrologueGenerator block_prologue_generator_;
    BlockEpilogueGenerator block_epilogue_generator_;
  };
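  //
  // Example (illustrative sketch only): a KernelCodeGenerator bundles the
  // per-element generator with optional block prologue/epilogue generators,
  // e.g. for a reduction kernel, reusing the `element_generator` sketched
  // above. The lambda bodies are placeholders; real implementations emit LLVM
  // IR through the emitter's IRBuilder.
  //
  //   KernelCodeGenerator kernel_generator(
  //       /*tile_element_generator=*/element_generator,
  //       /*block_prologue_generator=*/
  //       [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
  //         // Allocate and initialize per-thread partial results.
  //       },
  //       /*block_epilogue_generator=*/
  //       [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
  //         // Combine partial results and write them to the output buffer.
  //       });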

  IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
                    const HloComputation* hlo_computation,
                    IrEmitterContext* ir_emitter_context);
  IrEmitterUnnested(const IrEmitterUnnested&) = delete;
  IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;

  // Transfers the ownership of thunk_sequence_ out.
  std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
    return std::move(thunk_sequence_);
  }
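  //
  // Typical usage (sketch only; the exact driver code lives in the GPU
  // compiler, not in this header):
  //
  //   IrEmitterUnnested ir_emitter(hlo_module_config, entry_computation,
  //                                &ir_emitter_context);
  //   TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
  //   TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
  //   std::unique_ptr<ThunkSequence> thunks =
  //       ir_emitter.ConsumeThunkSequence();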

  Status DefaultAction(HloInstruction* hlo) override;

  // IrEmitterUnnested handles the following instructions differently from
  // IrEmitter.
  Status HandleCopy(HloInstruction* copy) override;
  Status HandleConditional(HloInstruction* conditional) override;
  Status HandleConvolution(HloInstruction* convolution) override;
  Status HandleCustomCall(HloInstruction* custom_call) override;
  Status HandleDot(HloInstruction* dot) override;
  Status HandleFft(HloInstruction* fft) override;
  Status HandleFusion(HloInstruction* fusion) override;
  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
  Status HandleReduce(HloInstruction* reduce) override;
  Status HandleSelectAndScatter(HloInstruction* instruction) override;
  Status HandleTuple(HloInstruction* tuple) override;
  Status HandleWhile(HloInstruction* xla_while) override;
  Status HandleInfeed(HloInstruction* xla_infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleRng(HloInstruction* random) override;
  Status HandleScatter(HloInstruction* scatter) override;
  Status HandleSelect(HloInstruction* select) override;
  Status HandleSort(HloInstruction* sort) override;
  Status HandleTriangularSolve(HloInstruction* hlo) override;
  Status HandleTupleSelect(HloInstruction* tuple_select) override;
  Status HandleAllReduce(HloInstruction* crs) override;
  Status HandleAfterAll(HloInstruction* after_all) override;

  Status EmitTargetElementLoop(
      const HloInstruction& hlo,
      const llvm_ir::ElementGenerator& body_emitter) override;

  // Same as `EmitTargetElementLoop`, but in the given `thunk` rather than
  // `LastThunk()`.
  Status EmitTargetElementLoopInThunk(
      const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
      KernelThunk* thunk);

  // Emits LLVM global variables corresponding to constant instructions.
  Status EmitConstantGlobals();

 private:
  // Adds an owning Thunk object to the thunk sequence.
  void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
    thunk_sequence_->emplace_back(std::move(thunk));
  }

  // Builds the prototype of the IR kernel for `inst` and adds it to the
  // module. This kernel takes as arguments pointers to the given buffer
  // allocations.
  llvm::Function* BuildKernelPrototype(
      const HloInstruction& inst,
      absl::Span<const BufferAllocation* const> args);

  // Helper for writing extra outputs from inside a reduce kernel.
  Status EmitExtraOutputsForReduce(
      const HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index,
      absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
          extra_output_gens);

  // Generates code for reduction to contiguous dimensions.
  //
  // Prerequisite: `IsReductionToVector(*unnested_hlo)`
  Status EmitReductionToVector(HloInstruction* unnested_hlo);

  // Computes the KernelMappingScheme for the reduce HLO and indicates whether
  // the reduction is a row reduction. For an un-fused reduce op, unnested_hlo
  // and first_reduce are the same instruction. For a kInput fusion,
  // unnested_hlo is the fusion instruction while first_reduce is the first
  // reduce op.
  std::tuple<llvm_ir::KernelMappingScheme, bool>
  ComputeMappingSchemeAndReductionKind(const HloInstruction* unnested_hlo,
                                       const HloInstruction* first_reduce);

  // Emits code for an in-place scatter, modifying `thunk`'s launch dimensions
  // in the process. `scatter` may be fused; scatter indices are taken from
  // `scatter_indices_gen` and updates from `updates_gen`. The output buffer is
  // expected to have the operand values in it already.
  Status EmitScatter(Thunk* thunk, HloInstruction* scatter,
                     const llvm_ir::ElementGenerator& scatter_indices_gen,
                     const llvm_ir::ElementGenerator& updates_gen);

  // Returns true if a 0-2-1 tiling algorithm was used to emit the kernel for
  // the hlo instruction.
  bool CheckAndEmitHloWithTile021(HloInstruction* hlo);
  // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
  // returns the launch dimensions for the kernel. This is a helper to support
  // the implementation of CheckAndEmitHloWithTile021.
  LaunchDimensions EmitHlo021Tile(HloInstruction* hlo,
                                  absl::Span<const int64> reduced_output_dims,
                                  absl::Span<const int64> tiled_param_ids);
  // Emits a kernel for an unnested HLO instruction.
  LaunchDimensions EmitKernel(HloInstruction* unnested_hlo,
                              absl::Span<const int64> param_ids,
                              const KernelCodeGenerator& kernel_generator,
                              KernelCodegenInfo* kernel_info);
  void EmitBlock(const TileGenerator& emit_one_tile,
                 KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl,
                 llvm::Type* index_ty);
  // Emits code to process a tensor element in a tile for the given kCopy HLO
  // that performs a 0-2-1 transpose.
  void EmitTileElementForCopy(HloInstruction* hlo,
                              const llvm_ir::IrArray::Index& index,
                              const KernelCodegenInfo* kernel_info,
                              llvm::Value* y_loc, llvm::Value* x_loc,
                              int64 x_iter_num);
  // Emits code to process a tensor element in a tile for the given kLoop
  // fusion HLO containing parameters that are a 0-2-1 transpose of its
  // outputs.
  void EmitTileElementForFusion(HloInstruction* hlo,
                                const llvm_ir::IrArray::Index& index,
                                const KernelCodegenInfo* kernel_info,
                                llvm::Value* y_loc, llvm::Value* x_loc,
                                int64 x_iter_num);
  // Emits code to process a tensor element in a tile for the given input hlo
  // that is either an unnested kReduce or a kInput fusion.
  void EmitTileElementForReduction(HloInstruction* unnested_hlo,
                                   const llvm_ir::IrArray::Index& index,
                                   const KernelCodegenInfo* kernel_info,
                                   llvm::Value* y_loc, llvm::Value* x_loc,
                                   int64 x_iter_num);
  // Prepares for the code generation for a tile block of a reduction kernel.
  void EmitPrologueForReduction(HloInstruction* unnested_hlo,
                                KernelCodegenInfo* kernel_info);
  void EmitPrologueForOneReduction(HloInstruction* unnested_hlo,
                                   HloInstruction* reduce_inst, int reduce_idx,
                                   KernelCodegenInfo* kernel_info,
                                   GpuElementalIrEmitter* elemental_emitter,
                                   ShapeIndex output_shape_index);
  // Wraps up the code generation for a tile block of a reduction kernel.
  void EmitEpilogueForReduction(HloInstruction* unnested_hlo,
                                KernelCodegenInfo* kernel_info);
  // For each reducer, emits the shuffle-down loop to accumulate the partial
  // result to the global result.
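  //
  // Conceptually, the emitted IR performs the usual warp-level tree reduction.
  // As CUDA-style pseudocode (illustration only; the actual code is emitted as
  // LLVM IR, and `Reduce` stands for applying the reducer computation):
  //
  //   for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
  //     partial_result = Reduce(
  //         partial_result,
  //         __shfl_down_sync(0xffffffff, partial_result, offset));
  //   }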
  void EmitFullWarpShuffleDownLoopForAllReduces(
      absl::Span<HloComputation* const> reducers,
      absl::Span<llvm::AllocaInst* const> partial_result_addresses);

  // Generates the IrArray for each input of an hlo and returns a vector that
  // contains such IrArrays.
  std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs(
      const HloInstruction& hlo);

  // For each input of the `hlo` instruction, checks its value in
  // `param_buffers` to find out whether the input has a reduced shape. If the
  // input has a reduced shape, constructs the reduced shape for the input and
  // casts the original input IrArray in `param_arrays` to the reduced shape.
  // Returns the total number of inputs.
  int ConstructInputReducedShapeAndCastInputIrArrayToShape(
      const HloInstruction& hlo,
      const std::vector<llvm_ir::IrArray>& param_arrays,
      const std::vector<llvm::Value*>& param_buffers,
      absl::Span<const int64> reduced_output_dims,
      std::vector<Shape>* param_reduced_shapes,
      std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays);

  // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
  // caller needs to make sure `inst` outlives the lifetime of the returned
  // Thunk object. The kernel implementation will be unrolled if unroll_factor
  // is greater than one. 'implements_whole_instruction' specifies whether this
  // KernelThunk implements the whole 'inst' HloInstruction. In some cases
  // 'inst' will be implemented by a sequence of Thunks.
  std::unique_ptr<KernelThunk> BuildKernelThunk(
      const HloInstruction* inst, bool implements_whole_instruction,
      int unroll_factor = 1);

  // Returns a FftThunk that calls cuFFT to implement `inst`.
  std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);

  // Returns a CholeskyThunk that calls cuSolver to implement `inst`.
  std::unique_ptr<Thunk> BuildCholeskyThunk(const HloInstruction* inst);

  // Returns a TriangularSolveThunk that calls cuBLAS to implement `inst`.
  std::unique_ptr<Thunk> BuildTriangularSolveThunk(const HloInstruction* inst);

  // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
  // to make sure `inst` outlives the lifetime of the returned Thunk object.
  std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);

  // Returns a thunk that, given a reduce or select-and-scatter op, initializes
  // its memory to the appropriate initial value.
  StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
      HloInstruction* hlo, const ShapeIndex& index = {});

  // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
  std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);

  // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`.
  std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
      const HloInstruction* inst);

  // Returns an InfeedThunk that performs a host-to-device memcpy to implement
  // `inst`.
  std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);

  // Returns an OutfeedThunk that performs a device-to-host memcpy to implement
  // `inst`.
  std::unique_ptr<Thunk> BuildOutfeedThunk(const HloInstruction* inst);

  // Returns a WhileThunk that invokes thunk sequences for the 'condition' and
  // 'body' sub-computations of the while instruction 'hlo'.
  std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);

  // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
  // sequence from the 'body' sub-computation of the while instruction 'hlo'.
  std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
                                       const int64 loop_limit);

  // Returns a ConditionalThunk which executes the thunk sequence for the
  // 'branch_computation' corresponding to the predicate/branch_index of the
  // given conditional instruction.
  std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);

  Status Postprocess(HloInstruction* hlo) override;

  // Returns the last generated thunk.
  Thunk* LastThunk() const { return thunk_sequence_->back().get(); }

  // The thunk sequence this IrEmitter generates for the input computation.
  std::unique_ptr<ThunkSequence> thunk_sequence_;

  // The HloComputation that this IrEmitter emits code for.
  const HloComputation* hlo_computation_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_