/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <vector>

#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {
namespace gpu {

namespace {

// Return whether the given shape is rank 2 excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) {
  return shape.rank() == batch_dimensions_size + 2;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                        const Shape& output_shape,
                        int64 batch_dimensions_size) {
  // The inputs and the output must
  // 1) be matrices with no padding and a non-zero number of elements,
  // 2) have an allowed element type.
  PrimitiveType output_primitive_type = output_shape.element_type();
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == F32 ||
       output_primitive_type == F64 || output_primitive_type == C64 ||
       output_primitive_type == C128);
  return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) &&
         IsRank2(rhs_shape, batch_dimensions_size) &&
         IsRank2(output_shape, batch_dimensions_size) &&
         !ShapeUtil::IsZeroElementArray(lhs_shape) &&
         !ShapeUtil::IsZeroElementArray(rhs_shape);
}

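// Returns true if the given dot instruction can be implemented as a BLAS gemm
// call, i.e. its operand and output shapes pass AreValidGemmShapes above.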
bool DotImplementedAsGemm(const HloInstruction& dot) {
  CHECK_EQ(dot.opcode(), HloOpcode::kDot);
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  // If gemm can accept the operand shapes, use it rather than a custom
  // kernel.
  if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(),
                         dim_numbers.lhs_batch_dimensions_size())) {
    // The size of the reduction dimension should match. The shape inference
    // guarantees this invariant, so the check here is for programming
    // errors.
    CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
             rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
    return true;
  }
  return false;
}
}  // namespace

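// Returns true if `hlo` is either a dot that maps to a BLAS gemm, or an output
// fusion whose root is an add/multiply fed by such a dot.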
bool ImplementedAsGemm(const HloInstruction& hlo) {
  // For certain types of Dot, we can call pre-canned BLAS gemm.
  if (hlo.opcode() == HloOpcode::kDot) {
    return DotImplementedAsGemm(hlo);
  }

  if (hlo.opcode() == HloOpcode::kFusion &&
      hlo.fusion_kind() == HloInstruction::FusionKind::kOutput &&
      (hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply ||
       hlo.fused_expression_root()->opcode() == HloOpcode::kAdd)) {
    // Try to find the dot inside the output fusion node.
    const HloInstruction* dot = hlo.fused_expression_root()->operand(0);
    if (dot->opcode() != HloOpcode::kDot) {
      dot = hlo.fused_expression_root()->operand(1);
    }
    if (dot->opcode() == HloOpcode::kDot) {
      return DotImplementedAsGemm(*dot);
    }
  }

  return false;
}

const char* const kCudnnBatchNormForwardInferenceCallTarget =
    "__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
    "__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
    "__cudnn$batchNormalizationBackward";

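// Returns true if `hlo` is a custom call whose target is one of the cuDNN
// batch normalization call targets defined above.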
bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnBatchNormForwardInferenceCallTarget ||
         target == kCudnnBatchNormForwardTrainingCallTarget ||
         target == kCudnnBatchNormBackwardCallTarget;
}

const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
    "__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
    "__cudnn$convBackwardFilter";
const char* const kCudnnConvBiasActivationForwardCallTarget =
    "__cudnn$convBiasActivationForward";

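// Returns true if `hlo` is a custom call whose target is one of the cuDNN
// convolution call targets defined above.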
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnConvForwardCallTarget ||
         target == kCudnnConvBackwardInputCallTarget ||
         target == kCudnnConvBackwardFilterCallTarget ||
         target == kCudnnConvBiasActivationForwardCallTarget;
}

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

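// Returns true if `hlo` is a custom call whose target is the cuSolver Cholesky
// call target defined above.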
bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

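// Returns true if `hlo` is implemented as a library call: a BLAS gemm or one
// of the cuDNN batch normalization or convolution custom calls above.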
bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
}

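// Returns true if `reduce` is a kReduce whose kept (non-reduced) input
// dimensions are consecutive in the input's layout and whose output shape
// equals the input shape filtered to those dimensions.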
bool IsReductionToVector(const HloInstruction& reduce) {
  if (HloOpcode::kReduce != reduce.opcode()) {
    return false;
  }
  const HloInstruction* input = reduce.operand(0);
  std::vector<int64> dims_to_keep;
  for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
    if (!absl::c_linear_search(reduce.dimensions(), dim)) {
      dims_to_keep.push_back(dim);
    }
  }
  return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                              dims_to_keep) &&
         ShapeUtil::Equal(
             reduce.shape(),
             ShapeUtil::FilterDimensions(
                 [&](int64 dim) { return absl::c_count(dims_to_keep, dim); },
                 input->shape()));
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
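//
// Illustrative usage (the builder `b` and value `index_value` are
// hypothetical):
//   EmitPrintf("index: %d\n", {index_value}, &b);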
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;
  for (auto argument : arguments) {
    argument_types.push_back(argument->getType());
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    builder->CreateStore(
        arguments[i],
        builder->CreateGEP(arguments_ptr,
                           {builder->getInt64(0), builder->getInt32(i)}));
  }
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(),
                                  {builder->getInt8Ty()->getPointerTo(),
                                   arguments_type->getPointerTo()},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       arguments_ptr});
}

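// Emits a warp-level "shuffle down" of `value` by `offset` lanes with all
// lanes of the warp participating. 32-bit floats use the f32 shfl intrinsic
// directly; wider values are zero-extended to a multiple of 32 bits and
// shuffled 32 bits at a time.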
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();
  llvm::Value* all_warps_mask = builder->getInt32(-1);

  // Special case for efficiency.
  if (value->getType()->isFloatTy() && bit_width == 32) {
    return llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::nvvm_shfl_sync_down_f32,
        {all_warps_mask, value, offset, builder->getInt32(kWarpSize - 1)}, {},
        builder);
  }

  // We must split values wider than 32 bits, as the "shfl" instruction
  // operates on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments));
  for (int i = 0; i < num_segments; ++i) {
    x = builder->CreateInsertElement(
        x,
        llvm_ir::EmitCallToIntrinsic(
            llvm::Intrinsic::nvvm_shfl_sync_down_i32,
            {all_warps_mask, builder->CreateExtractElement(x, i), offset,
             builder->getInt32(kWarpSize - 1)},
            {}, builder),
        i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}

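// Maps a cuDNN convolution custom-call target string to its CudnnConvKind, or
// returns an InternalError for an unexpected target.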
StatusOr<CudnnConvKind> GetCudnnConvKind(
    const HloCustomCallInstruction* instr) {
  absl::string_view target = instr->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return CudnnConvKind::kForward;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return CudnnConvKind::kBackwardInput;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return CudnnConvKind::kBackwardFilter;
  }
  if (target == kCudnnConvBiasActivationForwardCallTarget) {
    return CudnnConvKind::kForwardActivation;
  }
  return InternalError("Unexpected call target: %s", target);
}

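// Returns a human-readable name for the given CudnnConvKind.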
string CudnnConvKindToString(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
      return "forward";
    case CudnnConvKind::kBackwardFilter:
      return "backward_filter";
    case CudnnConvKind::kBackwardInput:
      return "backward_input";
    case CudnnConvKind::kForwardActivation:
      return "forward with activation";
  }
}

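// Emits a predicate that is true only for thread 0 of block 0, by comparing
// the tid.x and ctaid.x PTX special registers against zero.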
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  return b->CreateAnd(
      b->CreateICmpEQ(
          b->getInt32(0),
          llvm_ir::EmitCallToIntrinsic(
              llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b)),
      b->CreateICmpEQ(
          b->getInt32(0),
          llvm_ir::EmitCallToIntrinsic(
              llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b)));
}

}  // namespace gpu
}  // namespace xla