/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_

#include <utility>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/statusor.h"

// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
// don't belong in "ir_emission_utils".

namespace xla {
namespace gpu {

// Different types of convolutions supported by cudnn.
//
// A way to think about these is that a convolution is defined by three arrays
// -- the "input", the "filter", and the "output" -- and given any two of these,
// we can compute the third.  For example, a backward-input convolution takes as
// input a filter and an "output" and produces an "input" such that if one were
// to do a forward convolution of "input" using filter, the result would be
// something with the same shape as "output".
//
// This way of thinking is not correct if you look at the values produced. For
// example, a backward-input convolution is not actually the mathematical
// inverse of a forward convolution.  But it's right as far as the shapes and
// "connectivity" (i.e. which elements of the input affect which elements of
// the output) are concerned.
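//
// As a purely illustrative example (NHWC inputs, HWIO filters, valid padding,
// stride 1 -- none of which is required), the three "plain" kinds relate
// shapes as follows:
//
//   kForward:        input [8,32,32,3]  + filter [5,5,3,16]   => output [8,28,28,16]
//   kBackwardInput:  filter [5,5,3,16]  + output [8,28,28,16] => input  [8,32,32,3]
//   kBackwardFilter: input [8,32,32,3]  + output [8,28,28,16] => filter [5,5,3,16]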
enum class CudnnConvKind {
  kForward,            // input  + filter => output
  kBackwardInput,      // filter + output => input
  kBackwardFilter,     // input  + output => filter
  kForwardActivation,  // activation(conv(input, filter) + broadcast(bias) +
                       // (optionally) side_input) => output
};

// Returns the CudnnConvKind of `instr`, which must be one of the cudnn
// convolution CustomCalls described below; otherwise returns an error.
StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr);

// Converts a CudnnConvKind value to a string.
string CudnnConvKindToString(CudnnConvKind kind);

// The number of threads per warp on NVIDIA GPUs.
constexpr int64 kWarpSize = 32;

// Returns true if `hlo` will be implemented as a call to BLAS gemm.
//
// Precondition: `hlo` is in an "unnested context", meaning it lives within the
// entry computation, within either of a while loop's subcomputations, within
// any of a conditional's subcomputations, etc., but *does not* live within a
// reduce subcomputation, a map subcomputation, a fusion subcomputation, etc.
// It's OK if `hlo` *is* a fusion.
bool ImplementedAsGemm(const HloInstruction& hlo);

// A call to cuDNN for batch normalization is represented as CustomCall HLO with
// a call target equal to one of these strings.
//
// The operands to and outputs of these calls are the same as those of the
// corresponding HLOs, except:
//
//  - epsilon and feature_index are proper operands, at the end of the operands
//    list.  They must be HLO constants.
//  - The cuDNN forward training call returns inv_stddev =
//    1/sqrt(variance + epsilon) in place of plain variance.
//  - Similarly, BatchNormGrad accepts inv_stddev in place of the variance
//    operand.
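//
// For example (purely an illustration of the operand layout described above,
// not a normative spec), a forward-training call takes the operand, scale, and
// offset of the corresponding kBatchNormTraining HLO, followed by scalar
// constants for epsilon and feature_index, and returns a tuple of
// (output, batch_mean, batch_inv_stddev).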
extern const char* const kCudnnBatchNormForwardInferenceCallTarget;
extern const char* const kCudnnBatchNormForwardTrainingCallTarget;
extern const char* const kCudnnBatchNormBackwardCallTarget;

// Returns true if `hlo` will be implemented as a call to a cuDNN batch
// normalization routine.
//
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCudnnBatchNormFoo constants above, but returns *false* for HLOs
// with one of the kBatchNorm opcodes, because these are lowered either to a
// sequence of generic HLOs or to a cuDNN CustomCall.
bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);

// A call to cuDNN for convolution (forward, backward filter, or backward input)
// is represented as a CustomCall HLO with a call target equal to one of these
// strings.
//
// These CustomCalls have window() and convolution_dimension_numbers() set like
// regular convolution ops.  They have the same LHS and RHS operands, plus two
// additional constant operands: an int64 operand for the cudnn algorithm and
// a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn
// algorithm means that the implementation is free to choose the best algorithm
// it can.
//
// These calls output a tuple (conv_result, scratch_memory), where conv_result
// is the actual result of the convolution, and scratch_memory is temporary
// memory used by cudnn.  Callers shouldn't inspect scratch_memory, as its value
// is not well-defined.
//
// CudnnConvRewriter lowers kConvolution HLOs to these custom calls.
// When it does so, it chooses algorithm -1 and 0 bytes of scratch space.  Later
// on in the pipeline, CudnnConvAlgorithmChooser chooses an explicit
// algorithm for each conv and sets the amount of scratch space needed.
//
// (Representing the scratch memory as an output may seem strange at first, but
// it's quite sensible, from a certain point of view.  The scratch buffer is a
// location in memory that the conv can write into, but which it can't legally
// read from, at least until it's written something first.  But that's exactly
// the definition of an output buffer.)
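//
// As a purely illustrative sketch (shapes, window, and dim_labels are made up;
// the exact call-target strings are defined in the corresponding .cc file), a
// freshly lowered forward convolution might look roughly like:
//
//   (f32[8,28,28,16], u8[0]) custom-call(
//       f32[8,32,32,3] %input, f32[5,5,3,16] %filter,
//       s64[] constant(-1), pred[] constant(false)),
//     window={size=5x5}, dim_labels=b01f_01io->b01f,
//     custom_call_target=<kCudnnConvForwardCallTarget>
//
// where u8[0] is the initially empty scratch buffer, and the trailing constants
// are the algorithm (-1 = "choose for me") and tensor_ops operands described
// above.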
extern const char* const kCudnnConvForwardCallTarget;
extern const char* const kCudnnConvBackwardInputCallTarget;
extern const char* const kCudnnConvBackwardFilterCallTarget;
extern const char* const kCudnnConvBiasActivationForwardCallTarget;

// Returns true if `hlo` will be implemented as a call to a cuDNN convolution
// routine.
//
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCudnnConvFoo constants above, but returns *false* for HLOs with a
// kConvolution opcode.
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);

// Returns true if `hlo` will be implemented as a call to a cuSolver routine.
//
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCusolver... constants, but returns *false* for HLOs with, say, a
// kCholesky opcode.
bool IsCustomCallToCusolver(const HloInstruction& hlo);

// Cholesky decomposition. Takes a (batched) matrix as input, and returns a
// tuple of (result, workspace, info), where result is the result of the
// Cholesky decomposition, workspace is scratch space for cuSolver, and info
// is a success/failure code per batch element.
extern const char* const kCusolverCholeskyCallTarget;

// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
// or cuDNN convolution.
bool ImplementedAsLibraryCall(const HloInstruction& hlo);

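// Returns true if `reduce` is a reduction that the GPU backend treats as a
// reduction of an array down to a vector (or scalar).  (This is a summary; see
// the implementation for the exact criteria.)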
bool IsReductionToVector(const HloInstruction& reduce);

// Emits a call to "vprintf" with the given format and arguments.
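//
// A minimal usage sketch (the IRBuilder `b` and the i32 value `index` are
// assumed to already exist in the caller):
//
//   EmitPrintf("index = %d\n", {index}, &b);
//
// This is mainly useful for printf-style debugging of generated kernels.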
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder);

// Emits code to shuffle data between threads of a warp. This has the same
// semantics as the PTX "shfl.sync.down" instruction but works for values that
// aren't 32 bits in size. The last operand of the emitted "shfl" is
// `kWarpSize - 1`.
//
// This function emits a "full-warp" shuffle, which all threads of a warp
// participate in.  *Do not use this function from a divergent context:* You
// can't correctly do so on both Volta and earlier GPUs.
//
// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync
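//
// A typical use is a tree reduction across the lanes of a warp.  A minimal
// sketch (assuming an IRBuilder `b` and a float value `partial_sum` already
// exist in the caller):
//
//   for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
//     llvm::Value* shifted = EmitFullWarpShuffleDown(
//         partial_sum, b.getInt32(offset), &b);
//     partial_sum = b.CreateFAdd(partial_sum, shifted);
//   }
//
// After the loop, lane 0 of each warp holds the sum of the warp's 32 values.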
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder);

// Emits code that determines whether the current thread is thread 0 within
// block 0 of the kernel.
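//
// Typically used to guard work that should happen exactly once per kernel
// launch.  Sketch (assuming an IRBuilder `b` in scope):
//
//   llvm::Value* is_first = IsBlock0Thread0(&b);
//   // ...branch on `is_first` to emit the once-per-launch code.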
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b);

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_