/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <array>
#include <vector>

#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/stream_executor/device_description.h"

namespace xla {
namespace gpu {

namespace {

// Return whether the given shape is rank 2 excluding the batch dimensions.
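// For example, with one batch dimension, a [batch, m, n] shape has rank 3 and
// is considered rank 2 here.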
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) {
  return shape.rank() == batch_dimensions_size + 2;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                        const Shape& output_shape,
                        int64 batch_dimensions_size) {
  // The inputs and the output must
  // 1) be matrices with no padding and a non-zero number of elements,
  // 2) have an allowed element type.
  PrimitiveType output_primitive_type = output_shape.element_type();
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == F32 ||
       output_primitive_type == F64 || output_primitive_type == C64 ||
       output_primitive_type == C128);
  return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) &&
         IsRank2(rhs_shape, batch_dimensions_size) &&
         IsRank2(output_shape, batch_dimensions_size) &&
         !ShapeUtil::IsZeroElementArray(lhs_shape) &&
         !ShapeUtil::IsZeroElementArray(rhs_shape);
}

// Given a shape and a group of contiguous dimensions in the shape, returns
// a tuple of three values (major, middle, minor), where major is the size of
// the dimensions more major than the given dimensions, minor is the size of
// dimensions more minor than the given dimensions, and middle is the size of
// the given dimensions.
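// For example, for a shape with dimensions [2, 3, 4, 5], layout {3, 2, 1, 0}
// (dimension 3 most minor) and dims_middle = {1, 2}, the result is
// {/*major=*/2, /*middle=*/3 * 4, /*minor=*/5}.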
std::array<int64, 3> PartitionShapeByMiddleDimensions(
    const Shape& shape, absl::Span<const int64> dims_middle) {
  CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
  std::array<int64, 3> values = {1, 1, 1};
  enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
  Segment cur_segment = kMinor;

  for (int64 cur_dim : LayoutUtil::MinorToMajor(shape)) {
    if (cur_segment != kMajor) {
      // Handle change of segments.
      bool cur_dim_in_middle = absl::c_linear_search(dims_middle, cur_dim);
      if (cur_segment == kMinor) {
        if (cur_dim_in_middle) {
          cur_segment = kMiddle;
        }
      } else if (cur_segment == kMiddle) {
        if (!cur_dim_in_middle) {
          cur_segment = kMajor;
        }
      }
    }
    values[cur_segment] *= shape.dimensions(cur_dim);
  }
  return values;
}

}  // namespace

bool IsMatrixMultiplication(const HloInstruction& dot) {
  if (dot.opcode() != HloOpcode::kDot) {
    return false;
  }
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  // If gemm can accept the operand shapes, use it rather than a custom
  // kernel.
  if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(),
                         dim_numbers.lhs_batch_dimensions_size())) {
    // The size of the reduction dimension should match. The shape inference
    // guarantees this invariant, so the check here is for programming
    // errors.
    CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
             rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
    return true;
  }
  return false;
}

bool IsCublasGemm(const HloInstruction& hlo) {
  return hlo.opcode() == HloOpcode::kCustomCall &&
         hlo.custom_call_target() == kGemmCallTarget;
}

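// Chooses the tile sizes {tile_z, tile_y, tile_x} for a reduction with the
// given shape classification: row reductions unroll further along x for 8-
// and 16-bit inputs on compute capability >= 6, while column reductions use a
// fixed {1, 128, 1} tile.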
std::array<int64, 3> GetReductionTiling(
    const ReductionDimensions& reduction_dimensions,
    int smallest_input_dtype_bits,
    absl::optional<CudaComputeCapability> cuda_compute_capability) {
  if (reduction_dimensions.is_row_reduction) {
    int64 tile_z = std::min(reduction_dimensions.dimensions[0], int64{8});
    if (reduction_dimensions.dimensions[1] == 1) {
      CHECK_EQ(reduction_dimensions.dimensions[0], 1);
      return {tile_z, 1, 16};
    }
    if (reduction_dimensions.dimensions[2] % (kWarpSize * kWarpSize * 64) ==
        0) {
      return {tile_z, 1, 64};
    }
    int cc_major = 0;
    if (cuda_compute_capability) {
      cc_major = cuda_compute_capability->cc_major;
    }
    int unroll_x = 8;
    if (cc_major >= 6 && smallest_input_dtype_bits == 16) {
      unroll_x = 16;
    } else if (cc_major >= 6 && smallest_input_dtype_bits == 8) {
      unroll_x = 64;
    }
    return {tile_z, 1, unroll_x};
  }

  // Column reduction.
  return {1, 128, 1};
}

const char* const kCudnnBatchNormForwardInferenceCallTarget =
    "__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
    "__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
    "__cudnn$batchNormalizationBackward";

bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnBatchNormForwardInferenceCallTarget ||
         target == kCudnnBatchNormForwardTrainingCallTarget ||
         target == kCudnnBatchNormBackwardCallTarget;
}

const char* const kGemmCallTarget = "__cublas$gemm";
const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
    "__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
    "__cudnn$convBackwardFilter";
const char* const kCudnnConvBiasActivationForwardCallTarget =
    "__cudnn$convBiasActivationForward";

bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnConvForwardCallTarget ||
         target == kCudnnConvBackwardInputCallTarget ||
         target == kCudnnConvBackwardFilterCallTarget ||
         target == kCudnnConvBiasActivationForwardCallTarget;
}

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return IsCublasGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
}

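// Classifies the reduction of `dims_to_reduce` over `input_shape` as a row or
// column reduction and computes the sizes of the three contiguous components
// stored in ReductionDimensions::dimensions.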
static ReductionDimensions GetReductionKindAndContiguousComponentsImpl(
    const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
  DimensionVector dims_to_keep;
  for (int64 dim = 0; dim < input_shape.rank(); ++dim) {
    if (!absl::c_linear_search(dims_to_reduce, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  if (dims_to_keep.empty()) {
    return {/*is_row_reduction=*/true,
            {1, 1, ShapeUtil::ElementsIn(input_shape)}};
  }

  if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
                                           dims_to_keep)) {
    std::array<int64, 3> shape_partition =
        PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
    if (shape_partition[1] == 1) {
      return {/*is_row_reduction=*/true,
              {1, 1, shape_partition[0] * shape_partition[2]}};
    }
    if (shape_partition[2] == 1) {
      return {/*is_row_reduction=*/false,
              {1, shape_partition[0], shape_partition[1]}};
    }
    return {/*is_row_reduction=*/true, shape_partition};
  }

  std::array<int64, 3> shape_partition =
      PartitionShapeByMiddleDimensions(input_shape, dims_to_reduce);

  if (shape_partition[2] == 1) {
    return {/*is_row_reduction=*/true,
            {1, shape_partition[0], shape_partition[1]}};
  }
  return {/*is_row_reduction=*/false, shape_partition};
}

bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
  if (HloOpcode::kReduce != reduce.opcode()) {
    return false;
  }

  // TODO(b/129698548): Remove this check after fixing the bug.
  if (reduce.shape().element_type() == C128) {
    return false;
  }

  const HloInstruction* input = reduce.operand(0);
  std::vector<int64> dims_to_keep;
  for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
    if (!absl::c_linear_search(reduce.dimensions(), dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  // We support fast codegen for three cases:
  // 1) Row reduction: (K, R)
  // 2) Column reduction: (K, R, K)
  // 3) "Batched" row reduction: (R, K, R)
  if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            dims_to_keep) &&
      !LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            reduce.dimensions())) {
    return false;
  }

  ReductionDimensions reduction_dimensions =
      GetReductionKindAndContiguousComponents(reduce);

  if (reduction_dimensions.is_row_reduction) {
    // For row reduction, the tile block is 1 x tile_size_x, and we are reducing
    // along tile_size_x which needs to be large enough to make the tiling
    // implementation efficient.
    return reduction_dimensions.dimensions[2] >= kWarpSize;
  }

  // For column reduction, the tile block is tile_size_y x tile_size_x, and we
  // are reducing along tile_size_y. Only tile_size_y needs to be
  // large enough to make the tiling implementation efficient.
  return reduction_dimensions.dimensions[1] >= kWarpSize;
}

bool IsReductionFromOrToContiguousDimensions(mlir::Operation* reduce) {
  if (!mlir::isa<mlir::lmhlo::ReduceOp>(reduce) &&
      !mlir::isa<mlir::mhlo::ReduceOp>(reduce)) {
    return false;
  }
  std::vector<mlir::Value> results = GetHloOutputs(reduce);
  CHECK_EQ(1, results.size());

  auto c128_type =
      mlir::ComplexType::get(mlir::FloatType::getF64(reduce->getContext()));

  // TODO(b/129698548): Remove this check after fixing the bug.
  if (results[0].getType().cast<mlir::ShapedType>().getElementType() ==
      c128_type) {
    return false;
  }

  mlir::Value input = reduce->getOperand(0);
  Shape operand_shape = TypeToShape(input.getType());
  if (auto tensor_type = input.getType().dyn_cast<mlir::TensorType>()) {
    if (auto attr = mlir::GetLayoutFromMlirHlo(input.getDefiningOp())) {
      std::vector<int64> minor_to_major;
      absl::c_transform(
          attr, std::back_inserter(minor_to_major),
          std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
      *operand_shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
    }
  }

  std::vector<int64> dimensions;
  {
    auto attr = reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
    CHECK(attr);
    absl::c_transform(
        attr, std::back_inserter(dimensions),
        std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
  }

  std::vector<int64> dims_to_keep;
  for (int64 dim = 0; dim < operand_shape.dimensions().size(); ++dim) {
    if (!absl::c_linear_search(dimensions, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  // We support fast codegen for three cases:
  // 1) Row reduction: (K, R)
  // 2) Column reduction: (K, R, K)
  // 3) "Batched" row reduction: (R, K, R)
  if (!LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                            dims_to_keep) &&
      !LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                            dimensions)) {
    return false;
  }

  ReductionDimensions reduction_dimensions =
      GetReductionKindAndContiguousComponentsImpl(operand_shape, dimensions);

  if (reduction_dimensions.is_row_reduction) {
    // For row reduction, the tile block is 1 x tile_size_x, and we are reducing
    // along tile_size_x which needs to be large enough to make the tiling
    // implementation efficient.
    return reduction_dimensions.dimensions[2] >= kWarpSize;
  }

  // For column reduction, the tile block is tile_size_y x tile_size_x, and we
  // are reducing along tile_size_y. Only tile_size_y needs to be
  // large enough to make the tiling implementation efficient.
  return reduction_dimensions.dimensions[1] >= kWarpSize;
}

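// Returns true if `unnested_hlo` is a fusion whose results are all produced by
// mhlo::SliceOp; with `verify_no_strides`, the slices must additionally have
// unit strides.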
bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
                          bool verify_no_strides) {
  auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
  if (!fusion) {
    return false;
  }

  auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool {
    return absl::c_all_of(
        strides, [](const llvm::APInt& stride) { return stride == 1; });
  };

  for (mlir::Value value : fusion.getFusionResults()) {
    auto slice =
        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(value.getDefiningOp());
    if (!slice) {
      return false;
    }
    if (verify_no_strides && !is_non_strided(slice.strides())) {
      return false;
    }
  }
  return true;
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    const HloInstruction& reduce) {
  return GetReductionKindAndContiguousComponentsImpl(reduce.operand(0)->shape(),
                                                     reduce.dimensions());
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    mlir::Operation* reduce) {
  mlir::Value input = reduce->getOperand(0);
  Shape operand_shape = TypeToShape(input.getType());
  std::vector<int64> dimensions;
  {
    auto attr = reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
    CHECK(attr);
    absl::c_transform(
        attr, std::back_inserter(dimensions),
        std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
  }
  return GetReductionKindAndContiguousComponentsImpl(operand_shape, dimensions);
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
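// A typical (hypothetical) use from an emitter would be
//   EmitPrintf("value = %f\n", {some_f32_value}, &b);
// where the float argument is promoted to double below to match vprintf's
// variadic calling convention.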
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;

  // Implicit promotion of variadic arguments [1] converts float to double,
  // and bool/char/short to int.
  // [1] https://en.cppreference.com/w/cpp/language/variadic_arguments
  auto requires_int32_promotion = [](llvm::Type* type) {
    return type->isIntegerTy(/*BitWidth=*/1) ||
           type->isIntegerTy(/*BitWidth=*/8) ||
           type->isIntegerTy(/*BitWidth=*/16);
  };
  auto requires_double_promotion = [](llvm::Type* type) {
    return type->isFloatingPointTy();
  };

  for (auto argument : arguments) {
    llvm::Type* type = argument->getType();
    if (requires_double_promotion(type)) {
      argument_types.push_back(builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      argument_types.push_back(builder->getInt32Ty());
    } else {
      argument_types.push_back(type);
    }
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    llvm::Value* value = arguments[i];
    llvm::Type* type = value->getType();
    if (requires_double_promotion(type)) {
      value = builder->CreateFPCast(value, builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      value = builder->CreateIntCast(value, builder->getInt32Ty(),
                                     /*isSigned=*/true);
    }
    builder->CreateStore(
        value, builder->CreateGEP(arguments_ptr, {builder->getInt64(0),
                                                  builder->getInt32(i)}));
  }
  llvm::Type* ptr_ty = builder->getInt8Ty()->getPointerTo();
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(), {ptr_ty, ptr_ty},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       builder->CreatePointerCast(arguments_ptr, ptr_ty)});
}

// Helper function to emit a call to the AMDGPU shfl_down function.
llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset,
                                llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  auto* i32_ty = b->getInt32Ty();
  llvm::FunctionCallee shfl_fn = module->getOrInsertFunction(
      llvm_ir::AsStringRef("__ockl_readuplane_i32"),
      llvm::FunctionType::get(/*Result=*/i32_ty, {i32_ty, i32_ty},
                              /*isVarArg=*/false));
  // The AMDGPU device function requires its first argument to be an i32.
  llvm::Value* result =
      b->CreateCall(shfl_fn, {b->CreateBitCast(value, i32_ty), offset});
  // The AMDGPU device function always returns an i32.
  return b->CreateBitCast(result, value->getType());
}

// Helper function to emit a call to the NVPTX shfl_down intrinsic.
llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset,
                               llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  llvm::Intrinsic::ID llvm_intrinsic_id;
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  if (value->getType()->isFloatTy()) {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_f32;
  } else {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_i32;
  }
  llvm::Function* intrinsic =
      llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {});
  return b->CreateCall(
      intrinsic, {b->getInt32(-1), value, offset, b->getInt32(kWarpSize - 1)});
}

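// Emits a full-warp shuffle-down of `value` by `offset` lanes. Values wider
// than 32 bits are split into 32-bit segments, shuffled segment by segment,
// and reassembled, since the underlying shfl instructions operate on 32-bit
// registers.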
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();
  llvm::Module* module = builder->GetInsertBlock()->getModule();
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());

  // Special case for efficiency
  if (value->getType()->isFloatTy() && bit_width == 32) {
    if (target_triple.isNVPTX()) {
      return EmitNVPTXShflDown(value, offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      return EmitAMDGPUShflDown(value, offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments, false));
  for (int i = 0; i < num_segments; ++i) {
    llvm::Value* insert_val;
    if (target_triple.isNVPTX()) {
      insert_val = EmitNVPTXShflDown(builder->CreateExtractElement(x, i),
                                     offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i),
                                      offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
    x = builder->CreateInsertElement(x, insert_val, i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}

StatusOr<CudnnConvKind> GetCudnnConvKind(
    const HloCustomCallInstruction* instr) {
  absl::string_view target = instr->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return CudnnConvKind::kForward;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return CudnnConvKind::kBackwardInput;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return CudnnConvKind::kBackwardFilter;
  }
  if (target == kCudnnConvBiasActivationForwardCallTarget) {
    return CudnnConvKind::kForwardActivation;
  }
  return InternalError("Unexpected call target: %s", target);
}

string CudnnConvKindToString(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
      return "forward";
    case CudnnConvKind::kBackwardFilter:
      return "backward_filter";
    case CudnnConvKind::kBackwardInput:
      return "backward_input";
    case CudnnConvKind::kForwardActivation:
      return "forward with activation";
  }
}

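// Emits a predicate that is true iff the target's thread-index and block-index
// intrinsics (presumably the x dimension) both evaluate to 0.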
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  llvm::Value* is_thread0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b));

  llvm::Value* is_block0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b));
  return b->CreateAnd(is_thread0, is_block0);
}

bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
                                      const HloInstruction* first_reduce) {
  if (IsReductionFromOrToContiguousDimensions(*inst)) {
    // Shapes, layouts and dimensions must be the same for all reduces
    // inside of this fusion.
    // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
    return ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(0)->shape(),
                            inst->operand(0)->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(1)->shape(),
                            inst->operand(1)->shape()) &&
           first_reduce->dimensions() == inst->dimensions();
  }
  return ShapeUtil::CompatibleIgnoringElementType(
             first_reduce->operand(0)->shape(), inst->shape()) &&
         LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
                           inst->shape().layout());
}

bool IsFusedReductionOutputConsistent(mlir::mhlo::ReduceOp inst,
                                      mlir::mhlo::ReduceOp first_reduce) {
  CHECK_EQ(1, first_reduce.getNumResults());
  Shape first_reduce_operand_shape =
      TypeToShape(first_reduce.operands()[0].getType());
  CHECK_EQ(1, inst.getNumResults());
  auto inst_shape = TypeToShape(inst.getResult(0).getType());

  if (IsReductionFromOrToContiguousDimensions(inst)) {
    auto first_reduce_shape = TypeToShape(first_reduce.getResult(0).getType());
    auto first_reduce_init_shape =
        TypeToShape(first_reduce.init_values()[0].getType());

    auto inst_operand_shape = TypeToShape(inst.operands()[0].getType());
    auto inst_init_shape = TypeToShape(inst.init_values()[0].getType());

    // Shapes, layouts and dimensions must be the same for all reduces
    // inside of this fusion.
    // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
    if (!(ShapeUtil::Equal(first_reduce_shape, inst_shape) &&
          ShapeUtil::Equal(first_reduce_operand_shape, inst_operand_shape) &&
          ShapeUtil::Equal(first_reduce_init_shape, inst_init_shape) &&
          absl::c_equal(first_reduce.dimensions(), inst.dimensions()))) {
      return false;
    }
  } else {
    if (!(ShapeUtil::CompatibleIgnoringElementType(first_reduce_operand_shape,
                                                   inst_shape) &&
          LayoutUtil::Equal(first_reduce_operand_shape.layout(),
                            inst_shape.layout()))) {
      return false;
    }
  }
  return true;
}

// Given an LMHLO op, returns the operand index of the first output operand.
//
// Note that an operand aliased to an output isn't an output, even though in
// that case WritesMlirBuffer() returns true on that operand.
//
// An operand either does not write its buffer (!WritesMlirBuffer()) or is
// equal (aliased) to a later operand. An output is the opposite: it writes its
// buffer and is not equal to any later operand.
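// For example (with a hypothetical operand order), an op whose operands are
// (input0, input1, output0) returns 2: only the trailing operand writes its
// buffer and is not aliased to a later operand.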
int PartitionLmhloOperandsAndOutputs(mlir::Operation* op) {
  CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));

  int i;
  for (i = op->getOperands().size() - 1; i >= 0; i--) {
    const bool aliased =
        std::find(op->getOperands().begin() + i + 1, op->getOperands().end(),
                  op->getOperand(i)) != op->getOperands().end();
    if (!WritesMlirBuffer(op, op->getOperand(i)) || aliased) {
      break;
    }
  }
  return i + 1;
}

std::vector<mlir::Value> GetHloOperands(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getInputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> operands;
    operands.reserve(output_start);
    for (int i = 0; i < output_start; i++) {
      operands.push_back(op->getOperand(i));
    }
    return operands;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getOperands().begin(),
                                    op->getOperands().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

std::vector<mlir::Value> GetHloOutputs(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getOutputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> outputs;
    for (int i = output_start; i < op->getNumOperands(); i++) {
      outputs.push_back(op->getOperand(i));
    }
    return outputs;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getResults().begin(),
                                    op->getResults().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand) {
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
  mlir::cast<mlir::MemoryEffectOpInterface>(op).getEffectsOnValue(operand,
                                                                  effects);
  return absl::c_any_of(
      effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
        return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
      });
}

static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) {
  // For i1 memrefs, the underlying allocation is 8 bits.
  if (type.getElementType().isInteger(/*width=*/1)) {
    return type.getNumElements();
  } else {
    return type.getSizeInBits() / CHAR_BIT;
  }
}

static int64_t GetAllocationIndex(mlir::BlockArgument func_arg) {
  auto func_op =
      mlir::cast<mlir::FuncOp>(func_arg.getParentRegion()->getParentOp());
  return func_op
      .getArgAttrOfType<mlir::IntegerAttr>(func_arg.getArgNumber(),
                                           "lmhlo.alloc")
      .getValue()
      .getSExtValue();
}

StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
    mlir::Value v, absl::Span<const BufferAllocation> allocations) {
  int64 size = GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>());

  if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
    return BufferAllocation::Slice(&allocations[GetAllocationIndex(arg)], 0,
                                   size);
  }

  // We match the following patterns here:
  //  base := ViewOp(arg) | get_global_memref (global_memref)
  //  root := base | MemRefReinterpretCastOp(base)

  if (mlir::Operation* op = v.getDefiningOp()) {
    if (auto cast = mlir::dyn_cast<mlir::MemRefReinterpretCastOp>(op)) {
      mlir::Value source = cast.getViewSource();
      op = source.getDefiningOp();
      if (!op) {
        return Unimplemented("MemRefReinterpretCastOp has to wrap an op");
      }
    }
    if (auto view = mlir::dyn_cast<mlir::ViewOp>(op)) {
      return BufferAllocation::Slice(
          &allocations[GetAllocationIndex(
              view.source().cast<mlir::BlockArgument>())],
          mlir::cast<mlir::ConstantOp>(view.byte_shift().getDefiningOp())
              .value()
              .cast<mlir::IntegerAttr>()
              .getValue()
              .getSExtValue(),
          size);
    } else if (auto get_global = mlir::dyn_cast<mlir::GetGlobalMemrefOp>(op)) {
      auto module = get_global->getParentOfType<mlir::ModuleOp>();
      auto global = mlir::cast<mlir::GlobalMemrefOp>(
          module.lookupSymbol(get_global.name()));
      int64_t index =
          global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
      return BufferAllocation::Slice(&allocations[index], 0,
                                     allocations[index].size());
    }
    return Unimplemented("MemRefReinterpretCastOp has to wrap a ViewOp");
  }

  return Unimplemented(
      "Operand has to be in the form of ViewOp(arg) or "
      "StaticMemRefCastOp(ViewOp(arg))");
}

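// Returns true if `fusion` has a single dynamic-update-slice root whose
// updated operand lives in the same allocation slice as the fusion output, so
// the update can be emitted in place.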
bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
    mlir::lmhlo::FusionOp fusion,
    absl::Span<const BufferAllocation> allocations) {
  auto results = fusion.getFusionResults();
  if (results.size() != 1) {
    return false;
  }
  auto dus = mlir::dyn_cast<mlir::mhlo::DynamicUpdateSliceOp>(
      results[0].getDefiningOp());
  if (!dus) {
    return false;
  }

  auto output_buffers = fusion.getOutputBuffers();
  CHECK_EQ(1, output_buffers.size());
  auto parameter =
      mlir::dyn_cast<mlir::TensorLoadOp>(dus.operand().getDefiningOp());

  if (!parameter) {
    return false;
  }

  auto maybe_lhs = GetAllocationSliceForMlir(parameter.memref(), allocations);
  auto maybe_rhs = GetAllocationSliceForMlir(output_buffers[0], allocations);
  return maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs;
}

}  // namespace gpu
}  // namespace xla