/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <vector>

#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {
namespace gpu {

namespace {

// Return whether the given shape is rank 2 excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) {
  return shape.rank() == batch_dimensions_size + 2;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                        const Shape& output_shape,
                        int64 batch_dimensions_size) {
  // The inputs and the output must
  // 1) be matrices with no padding and a non-zero number of elements,
  // 2) have an allowed element type.
  PrimitiveType output_primitive_type = output_shape.element_type();
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == F32 ||
       output_primitive_type == F64 || output_primitive_type == C64 ||
       output_primitive_type == C128);
  return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) &&
         IsRank2(rhs_shape, batch_dimensions_size) &&
         IsRank2(output_shape, batch_dimensions_size) &&
         !ShapeUtil::IsZeroElementArray(lhs_shape) &&
         !ShapeUtil::IsZeroElementArray(rhs_shape);
}

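// Returns true if the given dot can be lowered to a call to a BLAS gemm
// routine, i.e. its operand and result shapes are acceptable to gemm.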
bool DotImplementedAsGemm(const HloInstruction& dot) {
  CHECK_EQ(dot.opcode(), HloOpcode::kDot);
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  // If gemm can accept the operand shapes, use it rather than a custom
  // kernel.
  if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(),
                         dim_numbers.lhs_batch_dimensions_size())) {
    // The size of the reduction dimension should match. The shape inference
    // guarantees this invariant, so the check here is for programming
    // errors.
    CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
             rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
    return true;
  }
  return false;
}
}  // namespace

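// Returns true if `hlo` will be emitted as a BLAS gemm call: either a dot
// with gemm-compatible shapes, or an output fusion whose root multiplies or
// adds the result of such a dot.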
bool ImplementedAsGemm(const HloInstruction& hlo) {
  // For certain types of Dot, we can call pre-canned BLAS gemm.
  if (hlo.opcode() == HloOpcode::kDot) {
    return DotImplementedAsGemm(hlo);
  }

  if (hlo.opcode() == HloOpcode::kFusion &&
      hlo.fusion_kind() == HloInstruction::FusionKind::kOutput &&
      (hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply ||
       hlo.fused_expression_root()->opcode() == HloOpcode::kAdd)) {
    // Try to find the dot inside the output fusion node.
    const HloInstruction* dot = hlo.fused_expression_root()->operand(0);
    if (dot->opcode() != HloOpcode::kDot) {
      dot = hlo.fused_expression_root()->operand(1);
    }
    if (dot->opcode() == HloOpcode::kDot) {
      return DotImplementedAsGemm(*dot);
    }
  }

  return false;
}

const char* const kCudnnBatchNormForwardInferenceCallTarget =
    "__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
    "__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
    "__cudnn$batchNormalizationBackward";

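// Returns true if `hlo` is a custom call to one of the cuDNN batch-norm
// targets declared above.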
bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnBatchNormForwardInferenceCallTarget ||
         target == kCudnnBatchNormForwardTrainingCallTarget ||
         target == kCudnnBatchNormBackwardCallTarget;
}

const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
    "__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
    "__cudnn$convBackwardFilter";
const char* const kCudnnConvBiasActivationForwardCallTarget =
    "__cudnn$convBiasActivationForward";

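// Returns true if `hlo` is a custom call to one of the cuDNN convolution
// targets declared above.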
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnConvForwardCallTarget ||
         target == kCudnnConvBackwardInputCallTarget ||
         target == kCudnnConvBackwardFilterCallTarget ||
         target == kCudnnConvBiasActivationForwardCallTarget;
}

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

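// Returns true if `hlo` is a custom call to the cuSolver Cholesky routine.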
bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

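// Returns true if `hlo` is lowered to a library call (a BLAS gemm or a cuDNN
// batch-norm/convolution custom call) rather than an emitted kernel.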
bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
}

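// Returns true if `reduce` is a reduction whose kept (non-reduced) dimensions
// are physically consecutive in the input layout and whose output shape is the
// input shape with the reduced dimensions filtered out.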
bool IsReductionToVector(const HloInstruction& reduce) {
  if (HloOpcode::kReduce != reduce.opcode()) {
    return false;
  }
  const HloInstruction* input = reduce.operand(0);
  std::vector<int64> dims_to_keep;
  for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
    if (!absl::c_linear_search(reduce.dimensions(), dim)) {
      dims_to_keep.push_back(dim);
    }
  }
  return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                              dims_to_keep) &&
         ShapeUtil::Equal(
             reduce.shape(),
             ShapeUtil::FilterDimensions(
                 [&](int64 dim) { return absl::c_count(dims_to_keep, dim); },
                 input->shape()));
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;
  for (auto argument : arguments) {
    argument_types.push_back(argument->getType());
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    builder->CreateStore(
        arguments[i],
        builder->CreateGEP(arguments_ptr,
                           {builder->getInt64(0), builder->getInt32(i)}));
  }
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(),
                                  {builder->getInt8Ty()->getPointerTo(),
                                   arguments_type->getPointerTo()},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       arguments_ptr});
}

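// Emits a full-warp shuffle-down of `value` by `offset` via the nvvm
// shfl.sync.down intrinsics. 32-bit floats use the f32 intrinsic directly;
// wider values are bitcast into 32-bit segments that are shuffled one at a
// time and then reassembled.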
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();
  llvm::Value* all_warps_mask = builder->getInt32(-1);

  // Special case for efficiency
  if (value->getType()->isFloatTy() && bit_width == 32) {
    return llvm_ir::EmitCallToIntrinsic(
        llvm::Intrinsic::nvvm_shfl_sync_down_f32,
        {all_warps_mask, value, offset, builder->getInt32(kWarpSize - 1)}, {},
        builder);
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments));
  for (int i = 0; i < num_segments; ++i) {
    x = builder->CreateInsertElement(
        x,
        llvm_ir::EmitCallToIntrinsic(
            llvm::Intrinsic::nvvm_shfl_sync_down_i32,
            {all_warps_mask, builder->CreateExtractElement(x, i), offset,
             builder->getInt32(kWarpSize - 1)},
            {}, builder),
        i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}

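// Maps the custom-call target of a cuDNN convolution instruction to its
// CudnnConvKind, or returns an error for an unrecognized target.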
StatusOr<CudnnConvKind> GetCudnnConvKind(
    const HloCustomCallInstruction* instr) {
  absl::string_view target = instr->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return CudnnConvKind::kForward;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return CudnnConvKind::kBackwardInput;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return CudnnConvKind::kBackwardFilter;
  }
  if (target == kCudnnConvBiasActivationForwardCallTarget) {
    return CudnnConvKind::kForwardActivation;
  }
  return InternalError("Unexpected call target: %s", target);
}

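// Returns a human-readable name for the given convolution kind.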
string CudnnConvKindToString(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
      return "forward";
    case CudnnConvKind::kBackwardFilter:
      return "backward_filter";
    case CudnnConvKind::kBackwardInput:
      return "backward_input";
    case CudnnConvKind::kForwardActivation:
      return "forward with activation";
  }
}

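// Emits IR that evaluates to true for the thread with threadIdx.x == 0 in the
// block with blockIdx.x == 0.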
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  return b->CreateAnd(
      b->CreateICmpEQ(
          b->getInt32(0),
          llvm_ir::EmitCallToIntrinsic(
              llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b)),
      b->CreateICmpEQ(
          b->getInt32(0),
          llvm_ir::EmitCallToIntrinsic(
              llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b)));
}

}  // namespace gpu
}  // namespace xla