1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering LHLO dialect to GPU dialect.
17 
18 #include <cstdint>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
23 #include "mlir/Dialect/Affine/IR/AffineOps.h"
24 #include "mlir/Dialect/GPU/GPUDialect.h"
25 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
26 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
27 #include "mlir/Dialect/SCF/SCF.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h"
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/BlockAndValueMapping.h"
31 #include "mlir/IR/Builders.h"
32 #include "mlir/IR/BuiltinOps.h"
33 #include "mlir/IR/BuiltinTypes.h"
34 #include "mlir/IR/Location.h"
35 #include "mlir/IR/MLIRContext.h"
36 #include "mlir/IR/Operation.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/Pass/Pass.h"
39 #include "mlir/Transforms/DialectConversion.h"
40 
41 namespace mlir {
42 namespace lmhlo {
43 namespace {
44 
45 // A simple translation of LHLO reduce operations to a corresponding gpu
46 // launch operation. The transformation does no tiling and also only supports
47 // 1d results.
48 class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
49  public:
50   using OpConversionPattern::OpConversionPattern;
51 
matchAndRewrite(ReduceOp reduce_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const52   LogicalResult matchAndRewrite(
53       ReduceOp reduce_op, ArrayRef<Value> args,
54       ConversionPatternRewriter& rewriter) const final {
55     auto loc = reduce_op.getLoc();
56     // Only support 1d reductions for now.
57     int64_t size = 0;
58     for (auto result : reduce_op.out()) {
59       auto shaped_type = result.getType().dyn_cast<ShapedType>();
60       if (!shaped_type || shaped_type.getRank() != 1) {
61         return failure();
62       }
63       auto dim_size = shaped_type.getDimSize(0);
64       if (size && size != dim_size) {
65         return failure();
66       }
67       size = dim_size;
68     }
69 
70     auto reducing_dimension = *reduce_op.dimensions().int_value_begin();
71 
72     // Require all inputs to have the same shape.
73     int64_t reduce_dim_size = 0;
74     for (auto input : reduce_op.operands()) {
75       auto shaped_type = input.getType().dyn_cast<ShapedType>();
76       if (!shaped_type || !shaped_type.hasStaticShape()) {
77         return failure();
78       }
79       reduce_dim_size =
80           shaped_type.getDimSize(reducing_dimension.getSExtValue());
81     }
82 
83     // Create a launch that is parallel in the result dimension.
84     auto block_size_x = rewriter.create<mlir::ConstantOp>(
85         loc, rewriter.getIndexType(),
86         rewriter.getIntegerAttr(rewriter.getIndexType(), size));
87     auto one = rewriter.create<mlir::ConstantOp>(
88         loc, rewriter.getIndexType(),
89         rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
90     auto launch_op = rewriter.create<mlir::gpu::LaunchOp>(
91         loc, one, one, one, block_size_x, one, one);
92     {
93       OpBuilder::InsertionGuard guard(rewriter);
94       rewriter.setInsertionPointToEnd(&launch_op.body().front());
95       auto index = launch_op.getThreadIds().x;
96 
97       // Load the initial value and store it to the output.
98       for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
99         auto init_value = rewriter.create<mlir::LoadOp>(loc, std::get<0>(pair));
100         rewriter.create<mlir::StoreOp>(loc, init_value, std::get<1>(pair),
101                                        ArrayRef<Value>{index});
102       }
103 
104       // Insert a loop into the body to compute the reduction. The loop ranges
105       // from [0.dim).
106       auto zero = rewriter.create<mlir::ConstantOp>(
107           loc, rewriter.getIndexType(),
108           rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
109       // TODO(b/137624192) Use dimOp to make it shape independent.
110       auto upper = rewriter.create<mlir::ConstantOp>(
111           loc, rewriter.getIndexType(),
112           rewriter.getIntegerAttr(rewriter.getIndexType(), reduce_dim_size));
113       auto step = rewriter.create<mlir::ConstantOp>(
114           loc, rewriter.getIndexType(),
115           rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
116       auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);
117 
118       rewriter.setInsertionPointToStart(loop.getBody());
119       // Compute memrefs for the value to reduce. This makes it easier to just
120       // inline the body.
121       auto output = *reduce_op.out().begin();
122       auto resType = MemRefType::get(
123           llvm::None, getElementTypeOrSelf(output.getType()),
124           makeStridedLinearLayoutMap(llvm::None,
125                                      MemRefType::getDynamicStrideOrOffset(),
126                                      rewriter.getContext()));
127       OpFoldResult offset = launch_op.getThreadIds().x;
128       auto oneAttr = rewriter.getI64IntegerAttr(1);
129       OpFoldResult size = oneAttr;
130       OpFoldResult stride = oneAttr;
131       auto accumulator = rewriter.create<SubViewOp>(loc, resType, output,
132                                                     offset, size, stride);
133       llvm::SmallVector<Value, 4> indexings;
134       auto input_buffer = *reduce_op.operands().begin();
135       auto input_type_rank =
136           input_buffer.getType().cast<MemRefType>().getRank();
137 
138       Value input = *reduce_op.operand_begin();
139       SmallVector<OpFoldResult> offsets = llvm::to_vector<4>(llvm::map_range(
140           llvm::seq<int>(0, input_type_rank), [&](int dim) -> OpFoldResult {
141             return dim == reducing_dimension ? loop.getInductionVar()
142                                              : launch_op.getThreadIds().x;
143           }));
144       SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
145       SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
146       auto rhs = rewriter.create<SubViewOp>(loc, accumulator.getType(), input,
147                                             offsets, sizes, strides);
148 
149       // Now copy over the actual body of the reduction, leaving out the
150       // terminator.
151       BlockAndValueMapping mapping;
152       mapping.map(reduce_op.body().getArgument(0), accumulator);
153       mapping.map(reduce_op.body().getArgument(1), rhs);
154       mapping.map(reduce_op.body().getArgument(2), accumulator);
155       for (auto& nested : reduce_op.body().front().without_terminator()) {
156         auto clone = rewriter.clone(nested, mapping);
157         for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
158           mapping.map(std::get<0>(pair), std::get<1>(pair));
159         }
160       }
161 
162       // Finally, insert the terminator for the launchOp.
163       rewriter.setInsertionPointToEnd(&launch_op.body().front());
164       rewriter.create<mlir::gpu::TerminatorOp>(loc);
165     }
166 
167     rewriter.eraseOp(reduce_op);
168     return success();
169   };
170 };
171 
172 struct LhloLegalizeToGpuPass
173     : public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
getDependentDialectsmlir::lmhlo::__anond23aa8980111::LhloLegalizeToGpuPass174   void getDependentDialects(DialectRegistry& registry) const override {
175     registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
176                     scf::SCFDialect>();
177   }
178 
runOnFunctionmlir::lmhlo::__anond23aa8980111::LhloLegalizeToGpuPass179   void runOnFunction() override {
180     OwningRewritePatternList patterns;
181     ConversionTarget target(getContext());
182     target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
183                            gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
184     target.addIllegalOp<ReduceOp>();
185     auto func = getFunction();
186     patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
187     if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
188       signalPassFailure();
189     }
190   }
191 };
192 
193 }  // namespace
194 
createLegalizeToGpuPass()195 std::unique_ptr<FunctionPass> createLegalizeToGpuPass() {
196   return std::make_unique<LhloLegalizeToGpuPass>();
197 }
198 
199 }  // namespace lmhlo
200 }  // namespace mlir
201