19 #include <utility>
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/Value.h"
24 #include "mlir/IR/Operation.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
26 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
27 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
28 #include "tensorflow/compiler/xla/service/gpu/gpu_device_info.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
33 // TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
34 // don't belong in "ir_emission_utils".
36 namespace xla {
37 namespace gpu {
39 // Different types of convolutions supported by cudnn.
40 //
41 // A way to think about these is that a convolution is defined by three arrays
42 // -- the "input", the "filter", and the "output" -- and given any two of these,
43 // we can compute the third.  For example, a backward-input convolution takes as
44 // input a filter and an "output" and produces an "input" such that if one were
45 // to do a forward convolution of "input" using filter, the result would be
46 // something with the same shape as "output".
47 //
48 // This way of thinking is not correct if you look at the values produced. For
49 // example, a backward-input convolution is not actually the mathematical
50 // inverse of a forward convolution.  But it's right as far as the shapes and
51 // "connectivity" (i.e. which elements of the input affect which elements of
52 // the output) are concerned.
//
// kForwardActivation extends the three-array picture: besides the input and
// the filter it also consumes a bias and, optionally, a side input (see the
// formula next to the enumerator).
53 enum class CudnnConvKind {
54   kForward,            // input  + filter => output
55   kBackwardInput,      // filter + output => input
56   kBackwardFilter,     // input  + output => filter
57   kForwardActivation,  // activation(conv(input, filter) + broadcast(bias) +
58                        // (optionally) side_input) => output
59 };
// Returns the CudnnConvKind for the custom call `instr`, wrapped in a StatusOr.
// NOTE(review): presumably the kind is derived from the call target and an
// error status is returned for non-convolution targets -- confirm against the
// definition.
61 StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr);
63 // Converts a CudnnConvKind value to a string.
64 string CudnnConvKindToString(CudnnConvKind kind);
66 // Returns true if `dot` is a matrix multiplication before the rewrite, i.e.
// one that has not yet been turned into a GEMM custom call.
67 //
68 // This function should never return "true" on instructions after the
69 // GemmRewriter pass has finished.
70 bool IsMatrixMultiplication(const HloInstruction& dot);
72 // Returns true if `hlo` is a matrix multiplication rewritten into a GEMM
73 // custom call. All matrix multiplications should be rewritten as such custom
74 // calls after a GemmRewriter lowering pass.
75 bool IsCublasGemm(const HloInstruction& hlo);
// Warp size assumed by the helpers declared in this header; for example, the
// shuffle emitted by EmitFullWarpShuffleDown uses `kWarpSize - 1` as its last
// operand.
77 constexpr int64 kWarpSize = 32;
79 // A call to cuBLAS general matrix multiplication API.
80 extern const char* const kGemmCallTarget;
82 // A call to cuDNN for batch normalization is represented as CustomCall HLO with
83 // a call target equal to one of these strings.
84 //
85 // The operands to and outputs of these calls are the same as those of the
86 // corresponding HLOs, except:
87 //
88 //  - epsilon and feature_index are proper operands, at the end of the operands
89 //    list.  They must be HLO constants.
90 //  - The cuDNN forward training call returns inv_stddev =
91 //    1/sqrt(variance + epsilon) in place of plain variance.
92 //  - Similarly, BatchNormGrad accepts inv_stddev in place of the variance
93 //    operand.
94 extern const char* const kCudnnBatchNormForwardInferenceCallTarget;
95 extern const char* const kCudnnBatchNormForwardTrainingCallTarget;
96 extern const char* const kCudnnBatchNormBackwardCallTarget;
98 // Returns true if `hlo` will be implemented as a call to a cuDNN batch
99 // normalization routine.
100 //
101 // This returns true if `hlo` is a CustomCall HLO with a call target equal to
102 // one of the kCudnnBatchNormFoo constants above, but returns *false* for HLOs
103 // with one of the kBatchNorm opcodes, because these are lowered either to a
104 // sequence of generic HLOs or to a cuDNN CustomCall.
105 bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);
107 // A call to cuDNN for convolution (forward, backward filter, or backward input)
108 // is represented as a CustomCall HLO with a call target equal to one of these
109 // strings.
110 //
111 // These CustomCalls have window() and convolution_dimension_numbers() set like
112 // regular convolution ops.  They have the same LHS and RHS operands, plus two
113 // additional constant operands: an int64 operand for the cudnn algorithm and
114 // a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn
115 // algorithm means that the implementation is free to choose the best algorithm
116 // it can.
117 //
118 // These calls output a tuple (conv_result, scratch_memory), where conv_result
119 // is the actual result of the convolution, and scratch_memory is temporary
120 // memory used by cudnn.  Callers shouldn't inspect scratch_memory, as its value
121 // is not well-defined.
122 //
123 // GpuConvRewriter lowers kConvolution HLOs to these custom calls.
124 // When it does so, it chooses algorithm -1 and 0 bytes of scratch space.  Later
125 // on in the pipeline, CudnnConvAlgorithmChooser chooses an explicit
126 // algorithm for each conv and sets the amount of scratch space needed.
127 //
128 // (Representing the scratch memory as an output may seem strange at first, but
129 // it's quite sensible, from a certain point of view.  The scratch buffer is a
130 // location in memory that the conv can write into, but which it can't legally
131 // read from, at least until it's written something first.  But that's exactly
132 // the definition of an output buffer.)
133 extern const char* const kCudnnConvForwardCallTarget;
134 extern const char* const kCudnnConvBackwardInputCallTarget;
135 extern const char* const kCudnnConvBackwardFilterCallTarget;
136 extern const char* const kCudnnConvBiasActivationForwardCallTarget;
138 // Returns true if `hlo` will be implemented as a call to a cuDNN convolution
139 // routine.
140 //
141 // This returns true if `hlo` is a CustomCall HLO with a call target equal to
142 // one of the kCudnnConvFoo constants above, but returns *false* for HLOs with a
143 // kConvolution opcode.
144 bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);
146 // Returns true if `hlo` will be implemented as a call to a cuSolver routine.
147 //
148 // This returns true if `hlo` is a CustomCall HLO with a call target equal to
149 // one of the kCusolver... constants, but returns *false* for HLOs with
150 // say, a kCholesky opcode.
151 bool IsCustomCallToCusolver(const HloInstruction& hlo);
153 // Cholesky decomposition. Takes a (batched) matrix as input, and returns a
154 // tuple of (result, workspace, info), where result is the result of the
155 // Cholesky decomposition, workspace is scratch space for cuSolver, and info
156 // is a success/failure code per batch element.
157 extern const char* const kCusolverCholeskyCallTarget;
159 // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
160 // or cuDNN convolution.
161 bool ImplementedAsLibraryCall(const HloInstruction& hlo);
163 // Returns true if either the dimensions being reduced or the dimensions being
164 // kept are contiguous in the input of the reduce instruction.
165 bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce);
166 bool IsReductionFromOrToContiguousDimensions(mlir::Operation* reduce);
168 // Returns whether unnested_hlo is an input fusion whose root is either a slice
169 // or a tuple of slices. If verify_no_strides is true, returns false unless all
170 // ROOT slices have no strides.
171 bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
172                           bool verify_no_strides);
// Summary of a reduction as a kind (row or column) plus the sizes of three
// contiguous components. Returned by GetReductionKindAndContiguousComponents
// and taken as input by GetReductionTiling. Going by the shapes listed for
// `dimensions`, a row reduction removes W and a column reduction removes H.
174 struct ReductionDimensions {
175   // Indicates whether the reduction is a row reduction or a column reduction.
176   bool is_row_reduction;
178   // Contains the size of the three contiguous components for
179   // the reduction [depth, height, width] (major-to-minor ordering).
180   //
181   // For row reduction, we do: [D, H, W] -> [D, H].
182   // For column reduction, we do: [D, H, W] -> [D, W].
183   std::array<int64, 3> dimensions;
184 };
186 // Given the input shape and dimensions to reduce for a reduction, returns
187 // ReductionDimensions.
188 //
189 // Prerequisite: the reduction instruction passes the check
190 // IsReductionFromOrToContiguousDimensions, which guarantees either the
191 // dimensions to reduce or the dimensions to keep are consecutive.
192 ReductionDimensions GetReductionKindAndContiguousComponents(
193     const HloInstruction& reduce);
194 ReductionDimensions GetReductionKindAndContiguousComponents(
195     mlir::Operation* reduce);
197 // Returns the per-thread tiling for the given reduction, expressed in
198 // dimensions [D, H, W].
199 // If the device isn't known, pass absl::nullopt for cuda_compute_capability
200 // and you will get a non-optimized value.
//
// NOTE(review): smallest_input_dtype_bits is presumably the bit width of the
// narrowest input element type of the reduction -- confirm against callers.
201 std::array<int64, 3> GetReductionTiling(
202     const ReductionDimensions& reduction_dimensions,
203     int smallest_input_dtype_bits,
204     absl::optional<CudaComputeCapability> cuda_compute_capability);
206 // Emits a call to "vprintf" with the given format string and arguments.
//
// `fmt` is the format string, `arguments` are the values passed along with
// it, and `builder` is the IRBuilder the call is emitted through.
// NOTE(review): the returned llvm::Value* is presumably the result of the
// emitted vprintf call -- confirm against the definition.
207 llvm::Value* EmitPrintf(absl::string_view fmt,
208                         absl::Span<llvm::Value* const> arguments,
209                         llvm::IRBuilder<>* builder);
211 // Emits code to shuffle data between threads of a warp. This has the same
212 // semantics as the PTX "shfl.sync.down" instruction but works for values that
213 // aren't 32 bits in size. The last operand of the emitted "shfl" is
214 // `kWarpSize - 1`.
215 //
216 // This function emits a "full-warp" shuffle, which all threads of a warp
217 // participate in.  *Do not use this function from a divergent context:* You
218 // can't correctly do so on both Volta and earlier GPUs.
219 //
220 // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync
221 llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
222                                      llvm::IRBuilder<>* builder);
224 // Emits code that determines whether the current thread is thread 0 within
225 // block 0 of the kernel.
226 llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b);
228 // Returns whether `inst`, an output of a fusion with reduction, is consistent
229 // with `first_reduce`. Overloads are provided for HloInstruction and for
// mlir::mhlo::ReduceOp.
// NOTE(review): what "consistent" requires is not spelled out in this header
// -- see the definition.
230 bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
231                                       const HloInstruction* first_reduce);
232 bool IsFusedReductionOutputConsistent(mlir::mhlo::ReduceOp inst,
233                                       mlir::mhlo::ReduceOp first_reduce);
AreFusedReductionOutputsConsistent(absl::Span<const HloInstruction * const> output_instructions,const HloInstruction * first_reduce)235 inline bool AreFusedReductionOutputsConsistent(
236     absl::Span<const HloInstruction* const> output_instructions,
237     const HloInstruction* first_reduce) {
238   return absl::c_all_of(output_instructions, [=](const HloInstruction* inst) {
239     return IsFusedReductionOutputConsistent(inst, first_reduce);
240   });
241 }
MlirToString(mlir::Operation * op)243 inline std::string MlirToString(mlir::Operation* op) {
244   std::string s;
245   {
246     llvm::raw_string_ostream os(s);
247     op->print(os);
248   }
249   return s;
250 }
// Helpers for inspecting the operands and outputs of an MLIR operation.
// NOTE(review): none of these are documented in this header; the notes below
// are read off the names and signatures only -- confirm against the
// definitions.
//
// Presumably returns the position at which `op`'s operand list switches from
// inputs to outputs.
252 int PartitionLmhloOperandsAndOutputs(mlir::Operation* op);
// Presumably the values `op` reads, returned as a std::vector.
253 std::vector<mlir::Value> GetHloOperands(mlir::Operation* op);
// Presumably the values `op` produces or writes, returned as a std::vector.
254 std::vector<mlir::Value> GetHloOutputs(mlir::Operation* op);
// Presumably reports whether `op` writes to the buffer `operand`.
256 bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand);
258 template <typename T>
ToStdVector(const llvm::SmallVectorImpl<T> & v)259 std::vector<T> ToStdVector(const llvm::SmallVectorImpl<T>& v) {
260   return std::vector<T>(v.begin(), v.end());
261 }
// NOTE(review): the two declarations below are undocumented in this header;
// the descriptions are inferred from names and signatures only -- confirm
// against the definitions.
//
// Presumably looks up the BufferAllocation::Slice backing the MLIR value `v`
// among `allocations`, reporting failure through the StatusOr.
263 StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
264     mlir::Value v, absl::Span<const BufferAllocation> allocations);
// Presumably reports whether the dynamic-update-slice in `fusion` can be
// emitted in place on GPU given `allocations`.
266 bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
267     mlir::lmhlo::FusionOp fusion,
268     absl::Span<const BufferAllocation> allocations);
270 }  // namespace gpu
271 }  // namespace xla