1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering LHLO dialect to Affine dialect.
17 
18 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
19 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
20 #include "mlir/Dialect/Affine/IR/AffineOps.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Location.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 
27 namespace mlir {
28 namespace lmhlo {
29 namespace {
30 
31 // Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
32 // steps, and populates the body of the innermost loop using "body_builder".
BuildBoundedAffineLoopNest(OpBuilder & builder,Location location,ArrayRef<int64_t> upper_bounds,function_ref<void (OpBuilder &,Location,ValueRange)> body_builder)33 static void BuildBoundedAffineLoopNest(
34     OpBuilder& builder, Location location, ArrayRef<int64_t> upper_bounds,
35     function_ref<void(OpBuilder&, Location, ValueRange)> body_builder) {
36   SmallVector<int64_t, 3> lower_bounds(upper_bounds.size(), /*Value=*/0);
37   SmallVector<int64_t, 3> steps(upper_bounds.size(), /*Value=*/1);
38   buildAffineLoopNest(builder, location, lower_bounds, upper_bounds, steps,
39                       body_builder);
40 }
41 
42 struct DotOpConverter : public OpRewritePattern<DotOp> {
43   using OpRewritePattern<DotOp>::OpRewritePattern;
44 
45   // Supports only rank-2 tensors for LHS and RHS.
matchAndRewritemlir::lmhlo::__anonceafcfb50111::DotOpConverter46   LogicalResult matchAndRewrite(DotOp op,
47                                 PatternRewriter& rewriter) const override {
48     Value lhs = op.lhs();
49     Value rhs = op.rhs();
50     MemRefType lhs_type = lhs.getType().cast<MemRefType>();
51     MemRefType rhs_type = rhs.getType().cast<MemRefType>();
52     Type element_type = lhs_type.getElementType();
53     ArrayRef<int64_t> shape_lhs = lhs_type.getShape();
54     ArrayRef<int64_t> shape_rhs = rhs_type.getShape();
55 
56     if ((lhs_type.getRank() != 2) || (rhs_type.getRank() != 2)) {
57       return failure();
58     }
59 
60     // We don't currently support batching dimensions, or multiple contraction
61     // dimensions.
62     mhlo::DotDimensionNumbers dot_dimension_numbers =
63         op.dot_dimension_numbers();
64     if (dot_dimension_numbers.lhs_batching_dimensions().size() > 0 ||
65         dot_dimension_numbers.rhs_batching_dimensions().size() > 0)
66       return failure();
67     if (dot_dimension_numbers.lhs_contracting_dimensions().size() != 1 ||
68         *dot_dimension_numbers.lhs_contracting_dimensions().begin() != 1 ||
69         dot_dimension_numbers.rhs_contracting_dimensions().size() != 1 ||
70         *dot_dimension_numbers.rhs_contracting_dimensions().begin() != 0) {
71       return failure();
72     }
73 
74     LogicalResult map_status = success();
75     auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
76       SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]},
77           rhs_indices{ivs[2], ivs[1]}, result_indices{ivs[0], ivs[1]};
78 
79       auto l = builder.create<AffineLoadOp>(loc, lhs, lhs_indices);
80       auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
81       auto result =
82           rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
83       Value op_result = lmhlo::HloOpToStdScalarOp::map<DotOp>(
84           op, element_type, {l, r, result}, &builder);
85       map_status = success(op_result != nullptr);
86       if (failed(map_status)) return;
87       builder.create<AffineStoreOp>(loc, op_result, op.output(),
88                                     result_indices);
89     };
90 
91     BuildBoundedAffineLoopNest(rewriter, op.getLoc(),
92                                {shape_lhs[0], shape_rhs[1], shape_rhs[0]},
93                                body_builder);
94     if (failed(map_status)) return failure();
95 
96     rewriter.eraseOp(op);
97     return success();
98   }
99 };
100 
101 template <typename LhloOpTy>
102 struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
103   using OpRewritePattern<LhloOpTy>::OpRewritePattern;
104 
matchAndRewritemlir::lmhlo::__anonceafcfb50111::BinaryOpConverter105   LogicalResult matchAndRewrite(LhloOpTy op,
106                                 PatternRewriter& rewriter) const override {
107     const auto& lhs = op.lhs();
108     const auto& rhs = op.rhs();
109     const auto& lhs_type = lhs.getType().template cast<MemRefType>();
110     const auto& rhs_type = rhs.getType().template cast<MemRefType>();
111     const auto& element_type = lhs_type.getElementType();
112 
113     if (lhs_type.getShape() != rhs_type.getShape()) {
114       return failure();
115     }
116 
117     LogicalResult map_status = success();
118     auto body_builder = [&](OpBuilder& builder, Location loc,
119                             ValueRange induction_vars) {
120       auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
121       auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
122       Value op_result = lmhlo::HloOpToStdScalarOp::map<LhloOpTy>(
123           op, element_type, {l, r}, &builder);
124       map_status = success(op_result != nullptr);
125       if (failed(map_status)) return;
126       rewriter.create<AffineStoreOp>(loc, op_result, op.out(), induction_vars);
127     };
128 
129     BuildBoundedAffineLoopNest(rewriter, op.getLoc(), lhs_type.getShape(),
130                                body_builder);
131     if (failed(map_status)) return failure();
132     rewriter.eraseOp(op);
133     return success();
134   }
135 };
136 
populateLHLOToAffineConversionPattern(MLIRContext * context,OwningRewritePatternList * patterns)137 void populateLHLOToAffineConversionPattern(MLIRContext* context,
138                                            OwningRewritePatternList* patterns) {
139   // clang-format off
140   patterns->insert<
141       BinaryOpConverter<lmhlo::AddOp>,
142       BinaryOpConverter<lmhlo::AndOp>,
143       BinaryOpConverter<lmhlo::DivOp>,
144       BinaryOpConverter<lmhlo::MaxOp>,
145       BinaryOpConverter<lmhlo::MinOp>,
146       BinaryOpConverter<lmhlo::MulOp>,
147       BinaryOpConverter<lmhlo::SubOp>,
148       DotOpConverter>(context);
149   // clang-format on
150 }
151 
152 struct LhloLegalizeToAffinePass
153     : public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
getDependentDialectsmlir::lmhlo::__anonceafcfb50111::LhloLegalizeToAffinePass154   void getDependentDialects(DialectRegistry& registry) const override {
155     registry.insert<AffineDialect>();
156   }
runOnFunctionmlir::lmhlo::__anonceafcfb50111::LhloLegalizeToAffinePass157   void runOnFunction() override {
158     OwningRewritePatternList patterns;
159     auto func = getFunction();
160     populateLHLOToAffineConversionPattern(func.getContext(), &patterns);
161     (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
162   }
163 };
164 
165 }  // namespace
166 
createLhloLegalizeToAffinePass()167 std::unique_ptr<OperationPass<FuncOp>> createLhloLegalizeToAffinePass() {
168   return std::make_unique<LhloLegalizeToAffinePass>();
169 }
170 
171 }  // namespace lmhlo
172 }  // namespace mlir
173