1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering LHLO dialect to Affine dialect.
17
18 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
19 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
20 #include "mlir/Dialect/Affine/IR/AffineOps.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Location.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27 namespace mlir {
28 namespace lmhlo {
29 namespace {
30
31 // Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
32 // steps, and populates the body of the innermost loop using "body_builder".
BuildBoundedAffineLoopNest(OpBuilder & builder,Location location,ArrayRef<int64_t> upper_bounds,function_ref<void (OpBuilder &,Location,ValueRange)> body_builder)33 static void BuildBoundedAffineLoopNest(
34 OpBuilder& builder, Location location, ArrayRef<int64_t> upper_bounds,
35 function_ref<void(OpBuilder&, Location, ValueRange)> body_builder) {
36 SmallVector<int64_t, 3> lower_bounds(upper_bounds.size(), /*Value=*/0);
37 SmallVector<int64_t, 3> steps(upper_bounds.size(), /*Value=*/1);
38 buildAffineLoopNest(builder, location, lower_bounds, upper_bounds, steps,
39 body_builder);
40 }
41
42 struct DotOpConverter : public OpRewritePattern<DotOp> {
43 using OpRewritePattern<DotOp>::OpRewritePattern;
44
45 // Supports only rank-2 tensors for LHS and RHS.
matchAndRewritemlir::lmhlo::__anonceafcfb50111::DotOpConverter46 LogicalResult matchAndRewrite(DotOp op,
47 PatternRewriter& rewriter) const override {
48 Value lhs = op.lhs();
49 Value rhs = op.rhs();
50 MemRefType lhs_type = lhs.getType().cast<MemRefType>();
51 MemRefType rhs_type = rhs.getType().cast<MemRefType>();
52 Type element_type = lhs_type.getElementType();
53 ArrayRef<int64_t> shape_lhs = lhs_type.getShape();
54 ArrayRef<int64_t> shape_rhs = rhs_type.getShape();
55
56 if ((lhs_type.getRank() != 2) || (rhs_type.getRank() != 2)) {
57 return failure();
58 }
59
60 // We don't currently support batching dimensions, or multiple contraction
61 // dimensions.
62 mhlo::DotDimensionNumbers dot_dimension_numbers =
63 op.dot_dimension_numbers();
64 if (dot_dimension_numbers.lhs_batching_dimensions().size() > 0 ||
65 dot_dimension_numbers.rhs_batching_dimensions().size() > 0)
66 return failure();
67 if (dot_dimension_numbers.lhs_contracting_dimensions().size() != 1 ||
68 *dot_dimension_numbers.lhs_contracting_dimensions().begin() != 1 ||
69 dot_dimension_numbers.rhs_contracting_dimensions().size() != 1 ||
70 *dot_dimension_numbers.rhs_contracting_dimensions().begin() != 0) {
71 return failure();
72 }
73
74 LogicalResult map_status = success();
75 auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
76 SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]},
77 rhs_indices{ivs[2], ivs[1]}, result_indices{ivs[0], ivs[1]};
78
79 auto l = builder.create<AffineLoadOp>(loc, lhs, lhs_indices);
80 auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
81 auto result =
82 rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
83 Value op_result = lmhlo::HloOpToStdScalarOp::map<DotOp>(
84 op, element_type, {l, r, result}, &builder);
85 map_status = success(op_result != nullptr);
86 if (failed(map_status)) return;
87 builder.create<AffineStoreOp>(loc, op_result, op.output(),
88 result_indices);
89 };
90
91 BuildBoundedAffineLoopNest(rewriter, op.getLoc(),
92 {shape_lhs[0], shape_rhs[1], shape_rhs[0]},
93 body_builder);
94 if (failed(map_status)) return failure();
95
96 rewriter.eraseOp(op);
97 return success();
98 }
99 };
100
101 template <typename LhloOpTy>
102 struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
103 using OpRewritePattern<LhloOpTy>::OpRewritePattern;
104
matchAndRewritemlir::lmhlo::__anonceafcfb50111::BinaryOpConverter105 LogicalResult matchAndRewrite(LhloOpTy op,
106 PatternRewriter& rewriter) const override {
107 const auto& lhs = op.lhs();
108 const auto& rhs = op.rhs();
109 const auto& lhs_type = lhs.getType().template cast<MemRefType>();
110 const auto& rhs_type = rhs.getType().template cast<MemRefType>();
111 const auto& element_type = lhs_type.getElementType();
112
113 if (lhs_type.getShape() != rhs_type.getShape()) {
114 return failure();
115 }
116
117 LogicalResult map_status = success();
118 auto body_builder = [&](OpBuilder& builder, Location loc,
119 ValueRange induction_vars) {
120 auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
121 auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
122 Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOpTy>(
123 op, element_type, {l, r}, &builder);
124 map_status = success(op_result != nullptr);
125 if (failed(map_status)) return;
126 rewriter.create<AffineStoreOp>(loc, op_result, op.out(), induction_vars);
127 };
128
129 BuildBoundedAffineLoopNest(rewriter, op.getLoc(), lhs_type.getShape(),
130 body_builder);
131 if (failed(map_status)) return failure();
132 rewriter.eraseOp(op);
133 return success();
134 }
135 };
136
populateLHLOToAffineConversionPattern(MLIRContext * context,OwningRewritePatternList * patterns)137 void populateLHLOToAffineConversionPattern(MLIRContext* context,
138 OwningRewritePatternList* patterns) {
139 // clang-format off
140 patterns->insert<
141 BinaryOpConverter<lmhlo::AddOp>,
142 BinaryOpConverter<lmhlo::AndOp>,
143 BinaryOpConverter<lmhlo::DivOp>,
144 BinaryOpConverter<lmhlo::MaxOp>,
145 BinaryOpConverter<lmhlo::MinOp>,
146 BinaryOpConverter<lmhlo::MulOp>,
147 BinaryOpConverter<lmhlo::SubOp>,
148 DotOpConverter>(context);
149 // clang-format on
150 }
151
152 struct LhloLegalizeToAffinePass
153 : public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
getDependentDialectsmlir::lmhlo::__anonceafcfb50111::LhloLegalizeToAffinePass154 void getDependentDialects(DialectRegistry& registry) const override {
155 registry.insert<AffineDialect>();
156 }
runOnFunctionmlir::lmhlo::__anonceafcfb50111::LhloLegalizeToAffinePass157 void runOnFunction() override {
158 OwningRewritePatternList patterns;
159 auto func = getFunction();
160 populateLHLOToAffineConversionPattern(func.getContext(), &patterns);
161 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
162 }
163 };
164
165 } // namespace
166
createLhloLegalizeToAffinePass()167 std::unique_ptr<OperationPass<FuncOp>> createLhloLegalizeToAffinePass() {
168 return std::make_unique<LhloLegalizeToAffinePass>();
169 }
170
171 } // namespace lmhlo
172 } // namespace mlir
173