/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering LHLO dialect to GPU dialect.
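// The lowering is deliberately simple: each lmhlo.reduce is rewritten into a
// single gpu.launch whose block provides one thread per result element; each
// thread then performs a sequential scf.for reduction over the reduced
// dimension.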

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace lmhlo {
namespace {

// A simple translation of LHLO reduce operations to a corresponding gpu.launch
// operation. The transformation does no tiling and supports only 1-d results.
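//
// Schematically (the syntax follows the lmhlo tests of this vintage and may
// differ in other versions), a row reduction such as
//
//   "lmhlo.reduce"(%arg, %init, %result) ({
//   ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
//     "lmhlo.add"(%lhs, %rhs, %res)
//         : (memref<f32>, memref<f32>, memref<f32>) -> ()
//     "lmhlo.terminator"() : () -> ()
//   }) {dimensions = dense<[0]> : tensor<1xi64>}
//       : (memref<100x10xf32>, memref<f32>, memref<10xf32>) -> ()
//
// becomes a gpu.launch with a 10-thread block in which every thread
// sequentially folds the 100 elements of its column into %result.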
class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = reduce_op.getLoc();
    // Only support 1d reductions for now.
    int64_t size = 0;
    for (auto result : reduce_op.out()) {
      auto shaped_type = result.getType().dyn_cast<ShapedType>();
      if (!shaped_type || shaped_type.getRank() != 1) {
        return failure();
      }
      auto dim_size = shaped_type.getDimSize(0);
      if (size && size != dim_size) {
        return failure();
      }
      size = dim_size;
    }

    // Only a single reduction dimension is supported; use the first entry.
    auto reducing_dimension = *reduce_op.dimensions().int_value_begin();

    // Require all inputs to have a static shape; the size of the reduced
    // dimension is taken from the inputs (which are assumed to agree).
    int64_t reduce_dim_size = 0;
    for (auto input : reduce_op.operands()) {
      auto shaped_type = input.getType().dyn_cast<ShapedType>();
      if (!shaped_type || !shaped_type.hasStaticShape()) {
        return failure();
      }
      reduce_dim_size =
          shaped_type.getDimSize(reducing_dimension.getSExtValue());
    }

    // Create a launch that is parallel in the result dimension.
    auto block_size_x = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIntegerAttr(rewriter.getIndexType(), size));
    auto one = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
    auto launch_op = rewriter.create<mlir::gpu::LaunchOp>(
        loc, one, one, one, block_size_x, one, one);
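    // Everything below is emitted into the body of this launch: each thread,
    // identified by its x thread id, computes one element of every result.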
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToEnd(&launch_op.body().front());
      auto index = launch_op.getThreadIds().x;

      // Load the initial value and store it to the output.
      for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
        auto init_value =
            rewriter.create<mlir::LoadOp>(loc, std::get<0>(pair));
        rewriter.create<mlir::StoreOp>(loc, init_value, std::get<1>(pair),
                                       ArrayRef<Value>{index});
      }

      // Insert a loop into the body to compute the reduction. The loop ranges
      // over [0, dim).
      auto zero = rewriter.create<mlir::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
      // TODO(b/137624192) Use dimOp to make it shape independent.
      auto upper = rewriter.create<mlir::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), reduce_dim_size));
      auto step = rewriter.create<mlir::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
      auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);

      rewriter.setInsertionPointToStart(loop.getBody());
      // Compute memrefs for the value to reduce. This makes it easier to just
      // inline the body.
      auto output = *reduce_op.out().begin();
      auto resType = MemRefType::get(
          llvm::None, getElementTypeOrSelf(output.getType()),
          makeStridedLinearLayoutMap(llvm::None,
                                     MemRefType::getDynamicStrideOrOffset(),
                                     rewriter.getContext()));
      OpFoldResult offset = launch_op.getThreadIds().x;
      auto oneAttr = rewriter.getI64IntegerAttr(1);
      OpFoldResult size = oneAttr;
      OpFoldResult stride = oneAttr;
      auto accumulator = rewriter.create<SubViewOp>(loc, resType, output,
                                                    offset, size, stride);
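      // `accumulator` is a rank-0 strided view of this thread's element of
      // the output buffer; the cloned reduction body reads and writes it in
      // place.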
      auto input_buffer = *reduce_op.operands().begin();
      auto input_type_rank =
          input_buffer.getType().cast<MemRefType>().getRank();

      Value input = *reduce_op.operand_begin();
      SmallVector<OpFoldResult> offsets = llvm::to_vector<4>(llvm::map_range(
          llvm::seq<int>(0, input_type_rank), [&](int dim) -> OpFoldResult {
            return dim == reducing_dimension ? loop.getInductionVar()
                                             : launch_op.getThreadIds().x;
          }));
      SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
      SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
      auto rhs = rewriter.create<SubViewOp>(loc, accumulator.getType(), input,
                                            offsets, sizes, strides);
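      // `rhs` is a rank-0 view of the current input element: the loop
      // induction variable indexes the reduced dimension and the x thread id
      // indexes all other dimensions.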

      // Now copy over the actual body of the reduction, leaving out the
      // terminator. Arguments 0 (lhs) and 2 (result) both map to the
      // accumulator view, so the body reduces into the output in place.
      BlockAndValueMapping mapping;
      mapping.map(reduce_op.body().getArgument(0), accumulator);
      mapping.map(reduce_op.body().getArgument(1), rhs);
      mapping.map(reduce_op.body().getArgument(2), accumulator);
      for (auto& nested : reduce_op.body().front().without_terminator()) {
        auto clone = rewriter.clone(nested, mapping);
        for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
          mapping.map(std::get<0>(pair), std::get<1>(pair));
        }
      }

      // Finally, insert the terminator for the launchOp.
      rewriter.setInsertionPointToEnd(&launch_op.body().front());
      rewriter.create<mlir::gpu::TerminatorOp>(loc);
    }

    rewriter.eraseOp(reduce_op);
    return success();
  }
};

struct LhloLegalizeToGpuPass
    : public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
                    scf::SCFDialect>();
  }

  void runOnFunction() override {
    OwningRewritePatternList patterns;
    ConversionTarget target(getContext());
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
                           gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
    target.addIllegalOp<ReduceOp>();
    auto func = getFunction();
    patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

std::unique_ptr<FunctionPass> createLegalizeToGpuPass() {
  return std::make_unique<LhloLegalizeToGpuPass>();
}
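
// Illustrative only: a typical way to schedule this pass from a pass manager,
// assuming the MLIR pass-manager API contemporaneous with this file
// (PassManager, addNestedPass, and FuncOp are standard MLIR names, not
// defined here).
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::FuncOp>(mlir::lmhlo::createLegalizeToGpuPass());
//   if (failed(pm.run(module))) { /* handle failure */ }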

}  // namespace lmhlo
}  // namespace mlir