1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering MHLO general dot to a regular dot.
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/IR/Attributes.h"
25 #include "mlir/IR/BuiltinOps.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/Location.h"
28 #include "mlir/IR/Operation.h"
29 #include "mlir/IR/TypeUtilities.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 
33 using mlir::DenseIntElementsAttr;
34 using mlir::ElementsAttr;
35 using mlir::failure;
36 using mlir::FunctionPass;
37 using mlir::LogicalResult;
38 using mlir::MLIRContext;
39 using mlir::OpRewritePattern;
40 using mlir::OwningRewritePatternList;
41 using mlir::PassWrapper;
42 using mlir::PatternRewriter;
43 using mlir::RankedTensorType;
44 using mlir::success;
45 using mlir::Value;
46 
47 namespace {
48 
TransposeReshape(Value arg,mlir::Location loc,llvm::ArrayRef<int64_t> left_dims,llvm::ArrayRef<int64_t> right_dims,llvm::ArrayRef<int64_t> arg_shape,PatternRewriter * rewriter)49 Value TransposeReshape(Value arg, mlir::Location loc,
50                        llvm::ArrayRef<int64_t> left_dims,
51                        llvm::ArrayRef<int64_t> right_dims,
52                        llvm::ArrayRef<int64_t> arg_shape,
53                        PatternRewriter *rewriter) {
54   auto element_type = mlir::getElementTypeOrSelf(arg.getType());
55 
56   int64_t left_size = 1;
57   for (auto dim : left_dims) {
58     left_size *= arg_shape[dim];
59   }
60 
61   int64_t right_size = 1;
62   for (auto dim : right_dims) {
63     right_size *= arg_shape[dim];
64   }
65 
66   // Generate the transpose permutation attribute.
67   llvm::SmallVector<int64_t, 5> transpose_permutation(left_dims.begin(),
68                                                       left_dims.end());
69   transpose_permutation.append(right_dims.begin(), right_dims.end());
70 
71   mlir::TensorType transpose_permutation_type = RankedTensorType::get(
72       {static_cast<int64_t>(transpose_permutation.size())},
73       rewriter->getIntegerType(64));
74 
75   auto transpose_permutation_attr =
76       DenseIntElementsAttr::get(transpose_permutation_type,
77                                 llvm::makeArrayRef(transpose_permutation))
78           .cast<DenseIntElementsAttr>();
79 
80   // Compute the resulting shape.
81   llvm::SmallVector<int64_t, 5> transposed_shape;
82   for (auto val : transpose_permutation) {
83     transposed_shape.push_back(arg_shape[val]);
84   }
85   auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
86   auto transpose_result = rewriter->create<mlir::mhlo::TransposeOp>(
87       loc, transpose_type, arg, transpose_permutation_attr);
88 
89   // Return the final result.
90   auto reshaped_type =
91       RankedTensorType::get({left_size, right_size}, element_type);
92   return rewriter->create<mlir::mhlo::ReshapeOp>(loc, reshaped_type,
93                                                  transpose_result);
94 }
95 
ProcessDotArg(Value arg,mlir::Location loc,ElementsAttr contract_dims_attr,bool outer_dims_first,PatternRewriter * rewriter)96 Value ProcessDotArg(Value arg, mlir::Location loc,
97                     ElementsAttr contract_dims_attr, bool outer_dims_first,
98                     PatternRewriter *rewriter) {
99   auto shape = arg.getType().cast<mlir::ShapedType>().getShape();
100 
101   llvm::SmallVector<bool, 5> is_outer_dim;
102   is_outer_dim.resize(shape.size(), true);
103 
104   // Compute the contract dimension ordering.
105   llvm::SmallVector<int64_t, 5> contract_dims;
106   for (auto dim : contract_dims_attr.getValues<int64_t>()) {
107     contract_dims.push_back(dim);
108     is_outer_dim[dim] = false;
109   }
110 
111   // Compute the outer dimension orderings.
112   llvm::SmallVector<int64_t, 5> outer_dims;
113   for (auto it : llvm::enumerate(is_outer_dim)) {
114     if (it.value()) {
115       outer_dims.push_back(it.index());
116     }
117   }
118 
119   if (outer_dims_first) {
120     return TransposeReshape(arg, loc, outer_dims, contract_dims, shape,
121                             rewriter);
122   }
123 
124   return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
125 }
126 
127 struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
128   // Attempts to lower a General Dot operator to a standard Dot operator.
129   // General dots include batching dimensions and can have collapsing
130   // dimensions along any axis. Inserting correctly arrange transpose and
131   // reshape operators organizes the tensors and allows the General Dot to be
132   // replaced with the standard Dot operator.
133   //
134   // Note: This requires an empty list of batch dimensions.
135 
GeneralDotConvert__anon06bc0c9c0111::GeneralDotConvert136   explicit GeneralDotConvert(MLIRContext *context)
137       : OpRewritePattern(context) {}
138 
matchAndRewrite__anon06bc0c9c0111::GeneralDotConvert139   LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op,
140                                 PatternRewriter &rewriter) const override {
141     auto dot_element_type = mlir::getElementTypeOrSelf(op);
142 
143     auto dot_numbers = op.dot_dimension_numbers();
144     if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 ||
145         dot_numbers.rhs_batching_dimensions().getNumElements() != 0) {
146       return failure();
147     }
148 
149     auto lhs = ProcessDotArg(op.lhs(), op.getLoc(),
150                              dot_numbers.lhs_contracting_dimensions(),
151                              /*outer_dims_first=*/true, &rewriter);
152 
153     auto rhs = ProcessDotArg(op.rhs(), op.getLoc(),
154                              dot_numbers.rhs_contracting_dimensions(),
155                              /*outer_dims_first=*/false, &rewriter);
156 
157     // Accept only static shaped types.
158     auto lhs_shape_type = lhs.getType().dyn_cast_or_null<mlir::ShapedType>();
159     auto rhs_shape_type = rhs.getType().dyn_cast_or_null<mlir::ShapedType>();
160     if (!lhs_shape_type || !rhs_shape_type) return failure();
161     if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape())
162       return failure();
163 
164     // Dot resulting shape.
165     auto lhs_shape = lhs_shape_type.getShape();
166     auto rhs_shape = rhs_shape_type.getShape();
167     auto new_dot_type =
168         RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
169 
170     auto new_dot_op = rewriter.create<mlir::mhlo::DotOp>(
171         op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config()));
172 
173     rewriter.replaceOpWithNewOp<mlir::mhlo::ReshapeOp>(op, op.getType(),
174                                                        new_dot_op);
175     return success();
176   }
177 };
178 
179 struct LegalizeGeneralDotPass
180     : public PassWrapper<LegalizeGeneralDotPass, FunctionPass> {
181   /// Lower all general dots that can be represented as a non-batched matmul.
runOnFunction__anon06bc0c9c0111::LegalizeGeneralDotPass182   void runOnFunction() override {
183     OwningRewritePatternList patterns;
184     mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
185     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
186   }
187 };
188 
189 }  // namespace
190 
PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList * patterns,MLIRContext * ctx)191 void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
192     OwningRewritePatternList *patterns, MLIRContext *ctx) {
193   patterns->insert<GeneralDotConvert>(ctx);
194 }
195 
createLegalizeGeneralDotPass()196 std::unique_ptr<::mlir::Pass> mlir::mhlo::createLegalizeGeneralDotPass() {
197   return std::make_unique<LegalizeGeneralDotPass>();
198 }
199