1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
17 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
18 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 
23 namespace mlir {
24 
25 namespace mhlo {
26 namespace {
27 
28 struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
29   using OpRewritePattern<GatherOp>::OpRewritePattern;
30 
matchAndRewritemlir::mhlo::__anon5939d9ee0111::GatherIsTorchIndexSelect31   LogicalResult matchAndRewrite(GatherOp gather,
32                                 PatternRewriter &rewriter) const override {
33     auto start_indices = gather.start_indices();
34     auto start_indices_ty = start_indices.getType().cast<ShapedType>();
35     if (!start_indices_ty.hasRank()) {
36       return failure();
37     }
38 
39     auto operand = gather.operand();
40     auto operand_ty = operand.getType().cast<ShapedType>();
41     if (!operand_ty.hasRank()) {
42       return failure();
43     }
44 
45     int64_t index_vector_dim =
46         std::max<int64_t>(0, start_indices_ty.getRank() - 1);
47 
48     // We can use torch_index_select if the last dimension represents the
49     // gather indices.
50     auto dimension_numbers = gather.dimension_numbers();
51     if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
52         index_vector_dim) {
53       return failure();
54     }
55 
56     // Index select only works across a single dimension.
57     if (!start_indices_ty.getShape().empty() &&
58         start_indices_ty.getShape().back() != 1) {
59       return failure();
60     }
61 
62     // Only support the default case for start_index_map.
63     if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
64         dimension_numbers.start_index_map()
65                 .getValue(0)
66                 .cast<IntegerAttr>()
67                 .getValue() != 0) {
68       return failure();
69     }
70 
71     auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
72     if (!result_ty) {
73       return failure();
74     }
75 
76     // Offset dimensions should be the defaults.
77     if (dimension_numbers.offset_dims().getType().getNumElements() !=
78         result_ty.getRank() - index_vector_dim) {
79       return failure();
80     }
81 
82     for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
83       if ((it.index() + index_vector_dim) != it.value()) {
84         return failure();
85       }
86     }
87 
88     for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
89       // First shape value must be 1.
90       if (it.index() == 0) {
91         if (it.value().getSExtValue() != 1) {
92           return failure();
93         }
94         continue;
95       }
96 
97       // The gather needs to index the entire slice for each other dimension.
98       if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
99         return failure();
100       }
101     }
102 
103     llvm::SmallVector<int64_t, 4> index_select_shape =
104         llvm::to_vector<4>(start_indices_ty.getShape());
105 
106     for (auto dim : operand_ty.getShape().drop_front()) {
107       index_select_shape.push_back(dim);
108     }
109 
110     if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
111         dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
112             1 ||
113         dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
114       return failure();
115     }
116 
117     auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
118         gather.getLoc(),
119         RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
120         operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
121         rewriter.getI64IntegerAttr(0));
122 
123     rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
124                                            torch_index_select);
125 
126     return success();
127   }
128 };
129 
130 struct LegalizeGatherToTorchIndexSelectPass
131     : public PassWrapper<LegalizeGatherToTorchIndexSelectPass, FunctionPass> {
132   /// Perform the lowering of standard dialect operations to approximations.
runOnFunctionmlir::mhlo::__anon5939d9ee0111::LegalizeGatherToTorchIndexSelectPass133   void runOnFunction() override {
134     OwningRewritePatternList patterns;
135     PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
136     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
137   }
138 };
139 }  // namespace
140 
PopulateGatherToTorchIndexSelectPatterns(mlir::MLIRContext * context,OwningRewritePatternList * patterns)141 void PopulateGatherToTorchIndexSelectPatterns(
142     mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
143   patterns->insert<GatherIsTorchIndexSelect>(context);
144 }
145 
createLegalizeGatherToTorchIndexSelectPass()146 std::unique_ptr<FunctionPass> createLegalizeGatherToTorchIndexSelectPass() {
147   return std::make_unique<LegalizeGatherToTorchIndexSelectPass>();
148 }
149 
150 }  // namespace mhlo
151 }  // namespace mlir
152