1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_
18 
19 #include "llvm/IR/Value.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
22 
23 namespace xla {
24 namespace llvm_ir {
25 
26 // About 0-2-1 transpose:
27 //
28 // If a shape can be viewed as three logical components 0-1-2 in the order of
29 // major to minor, a 0-2-1-transpose changes the order of such logical
30 // components to 0-2-1. We call the shape being transposed the input shape and
31 // the transposed shape the output shape. The logical view of the input/output
32 // shapes for the transpose are called the 0-1-2/0-2-1 shapes or the normalized
33 // shapes. The original input/output shapes are called unnormalized shapes.
34 //
35 // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
36 // normalized shape of `b` or the 0-2-1 shape.
37 absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
38                                                      const Shape& b);
39 
40 // A tile is a spatial subdivision of a tensor. We group tensor elements into
41 // tiles so that we can launch kernels to process the tensor elements in blocks
42 // of tiles.
43 //
44 // A kernel mapping scheme describes a method to partition the tensors accessed
45 // by an unnested HLO instruction into tiles and blocks of tiles, and the
46 // associated information to use hardware threads to process the tensor elements
47 // in blocks of tiles.
48 //
// Currently, there are two main use cases for a tiling scheme. First, we
// implement kernels with 0-2-1 memory transpose using shared memory to improve
// the memory access pattern. Second, we implement reduction to contiguous
// dimensions in layout, with or without memory transpose, to achieve a better
// memory access pattern as well as to reduce the number of executed
// expensive instructions, such as thread synchronization related instructions
// and atomic operations. For both use cases, we can apply a normalization to
// the original tensors, to collapse contiguous dimensions for the same purpose
// and produce normalized three dimensional tensors. For this reason, the tiling
// scheme class only needs to handle normalized three dimensional tensors and
// two dimensional tiles.
60 //
61 // The current implementation of the class is somewhat NVIDIA GPU oriented. This
62 // situation can be improved when there is a need though. The idea of 0-2-1
63 // transpose using shared memory can be found in the following CUDA algorithm in
64 // TensorFlow: https://goo.gl/MStRV6.
65 //
66 // We use a thread block to process a tile because we want to use the HW thread
67 // block synchronization primitives to synchronize the processing of all the
68 // elements in the same tile. A thread block can be viewed as a two dimensional
69 // array of threads, described by the number of threads for the Y and X
70 // dimensions. A thread block (num_threads_y, num_threads_x) processes a tile of
71 // (tile_size_y, tile_size_x) as follows: each thread in the thread block
72 // processes one element in the tile so that all the threads in the thread block
73 // together process a subdivision of the tile that has the same dimension as the
74 // thread block array. Then the thread block moves on to process the next
75 // subdivision of the tile until the whole tile is processed. Therefore, each
76 // thread in the thread block processes
77 // tile_size_x/num_threads_x * tile_size_y/num_threads_y elements in a tile.
78 //
// There are situations where we want a thread block to process multiple
// tiles. We can't group those tiles into a bigger tile because we limit a tile
// to a two dimensional spatial subdivision of a tensor. For example, when we
// use tiling to implement reduction with transpose, we want the partial sum
// produced by each thread to accumulate values for more elements before using
// shfl_down and atomic_add instructions for further reduction, to amortize the
// cost of such expensive instructions. The concept of tile block is introduced
// for this purpose. A tile block is a three dimensional array of tiles, of
// which some dimensions may degenerate to a single tile.
88 class KernelMappingScheme {
89  public:
90   enum { DimZ = 0, DimY, DimX, DimTot };
91 
92  public:
KernelMappingScheme()93   KernelMappingScheme() {}
94   // dims_in_elems: the normalized tensor dimensions.
95   // req_block_sizes: the requested block size in number of tiles for each
96   //   dimension. The actual block size is set to min(req_block_size,
97   //   dims_in_number_of_blocks).
98   KernelMappingScheme(absl::Span<const int64> dims_in_elems, int64 tile_size_y,
99                       int64 tile_size_x,
100                       absl::Span<const int64> req_block_sizes,
101                       int64 num_threads_y, int64 num_threads_x,
102                       llvm::IRBuilder<>* b);
103 
GetDimensionsInElements()104   absl::Span<const int64> GetDimensionsInElements() const {
105     return dims_in_elems_;
106   }
GetDimensionsInTiles()107   absl::Span<const int64> GetDimensionsInTiles() const {
108     return dims_in_tiles_;
109   }
GetDimensionsInBlocks()110   absl::Span<const int64> GetDimensionsInBlocks() const {
111     return dims_in_blocks_;
112   }
113 
GetNumberOfTilesInTotal()114   int64 GetNumberOfTilesInTotal() const {
115     return absl::c_accumulate(dims_in_tiles_, 1LL, std::multiplies<int64>());
116   }
GetNumberOfTilesInOneBlock()117   int64 GetNumberOfTilesInOneBlock() const {
118     return absl::c_accumulate(block_sizes_, 1, std::multiplies<int64>());
119   }
GetNumberOfTilesInOneBlockForDimension(int d)120   int64 GetNumberOfTilesInOneBlockForDimension(int d) const {
121     DCHECK(d >= DimZ && d <= DimX);
122     return block_sizes_[d];
123   }
GetNumberOfBlocks()124   int64 GetNumberOfBlocks() const {
125     return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies<int64>());
126   }
127 
GetTileSizeForDimension(int d)128   int64 GetTileSizeForDimension(int d) const {
129     DCHECK(d >= DimZ && d <= DimX);
130     return tile_sizes_[d];
131   }
GetTileSizeForDimensionX()132   int64 GetTileSizeForDimensionX() const {
133     return GetTileSizeForDimension(DimX);
134   }
GetTileSizeForDimensionY()135   int64 GetTileSizeForDimensionY() const {
136     return GetTileSizeForDimension(DimY);
137   }
138 
GetBlockSizes()139   absl::Span<const int64> GetBlockSizes() const { return block_sizes_; }
GetTileBlockSizeForDimension(int d)140   int64 GetTileBlockSizeForDimension(int d) const {
141     DCHECK(d >= DimZ && d <= DimX);
142     return dims_in_blocks_[d];
143   }
144 
GetNumberOfThreadsForDimensionX()145   int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; }
GetNumberOfThreadsForDimensionY()146   int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; }
147 
GetThreadsPerBlock()148   int64 GetThreadsPerBlock() const {
149     return GetNumberOfThreadsForDimensionX() *
150            GetNumberOfThreadsForDimensionY();
151   }
152 
DilatedX()153   bool DilatedX() const { return dilated_x_; }
SetDilatedX(bool v)154   void SetDilatedX(bool v) {
155     dilated_x_ = v;
156     if (!dilated_x_) {
157       // dilated_x_=false is for the purpose of vectorization, which requires
158       // GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_.
159       CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0);
160     }
161   }
162 
163   IrArray::Index EmitBlockIndex(llvm::Type* index_ty);
164   // Returns the index for the first tile in the block with the given block
165   // index.
166   IrArray::Index GetTileIndexForBlockOrigin(const IrArray::Index& block_index);
167   // Returns the index for the first element in the tile with the given tile
168   // index.
169   IrArray::Index GetElementIndexForTileOrigin(const IrArray::Index& tile_index);
170 
171   std::tuple<llvm::Value*, llvm::Value*> EmitThreadYXCoordinate(
172       llvm::Type* index_ty);
173 
174   IrArray::Index GetUnnormalizedIndex(
175       const IrArray::Index& normalized_shape_index,
176       const Shape& unnormalized_shape);
177 
178   llvm::GlobalVariable* GetSharedMemoryBufferForElementType(
179       llvm::Type* elem_ty, absl::string_view buffer_name);
180 
181  private:
182   llvm::IRBuilder<>* b_;
183   // The number of elements in each dimension.
184   std::vector<int64> dims_in_elems_;
185 
186   // The number of elements for each dimension of a tile.
187   std::vector<int64> tile_sizes_;
188   // The number of tiles in each dimension. It is computed from dims_in_elem_
189   // and tile_sizes_.
190   std::vector<int64> dims_in_tiles_;
191 
192   // The number of tiles for each dimension of a tile block.
193   std::vector<int64> block_sizes_;
194   // The number of blocks in each dimension of a tile block. It is computed from
195   // dims_in_tile_ and block_sizes_.
196   std::vector<int64> dims_in_blocks_;
197 
198   // Number of threads used to process elements in the X direction of a tile.
199   int64 num_threads_x_;
200   // Number of threads used to process elements in the Y direction of a tile.
201   int64 num_threads_y_;
202 
203   // When num_threads_x threads process a total of tile_size_x elements in the
204   // X dimension of a tile, each threads process n=tile_size_x/num_threads_x
205   // elements. When dilated_x=false, the n elements processed by a thread are
206   // contiguous. On the other hand, when dilated_x=true the n elements are
207   // dilated by a factor of num_threads_x.
208   bool dilated_x_;
209 };
210 
211 // A class to represent information for tiled parameters to support IR emission
212 // for 021 transpose.
213 class TiledParameterInfo {
214  public:
TiledParameterInfo(absl::Span<llvm::Value * const> param_buffers,llvm::Value * y,llvm::Value * x)215   TiledParameterInfo(absl::Span<llvm::Value* const> param_buffers,
216                      llvm::Value* y, llvm::Value* x)
217       : param_buffers_(param_buffers), y_(y), x_(x) {}
218 
x()219   llvm::Value* x() const { return x_; }
y()220   llvm::Value* y() const { return y_; }
221 
set_x(llvm::Value * x)222   void set_x(llvm::Value* x) { x_ = x; }
set_y(llvm::Value * y)223   void set_y(llvm::Value* y) { y_ = y; }
224 
GetBufferForParameter(int64 index)225   llvm::Value* GetBufferForParameter(int64 index) const {
226     return param_buffers_[index];
227   }
228 
229  private:
230   // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
231   // if the parameter is not tiled.
232   absl::Span<llvm::Value* const> param_buffers_;
233   // The y coordinate within a tile.
234   llvm::Value* y_;
235   // The x coordinate within a tile.
236   llvm::Value* x_;
237 };
238 
239 }  // namespace llvm_ir
240 }  // namespace xla
241 
242 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_
243