1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
17 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
18 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/Pass/Pass.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22
23 namespace mlir {
24
25 namespace mhlo {
26 namespace {
27
28 struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
29 using OpRewritePattern<GatherOp>::OpRewritePattern;
30
matchAndRewritemlir::mhlo::__anon5939d9ee0111::GatherIsTorchIndexSelect31 LogicalResult matchAndRewrite(GatherOp gather,
32 PatternRewriter &rewriter) const override {
33 auto start_indices = gather.start_indices();
34 auto start_indices_ty = start_indices.getType().cast<ShapedType>();
35 if (!start_indices_ty.hasRank()) {
36 return failure();
37 }
38
39 auto operand = gather.operand();
40 auto operand_ty = operand.getType().cast<ShapedType>();
41 if (!operand_ty.hasRank()) {
42 return failure();
43 }
44
45 int64_t index_vector_dim =
46 std::max<int64_t>(0, start_indices_ty.getRank() - 1);
47
48 // We can use torch_index_select if the last dimension represents the
49 // gather indices.
50 auto dimension_numbers = gather.dimension_numbers();
51 if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
52 index_vector_dim) {
53 return failure();
54 }
55
56 // Index select only works across a single dimension.
57 if (!start_indices_ty.getShape().empty() &&
58 start_indices_ty.getShape().back() != 1) {
59 return failure();
60 }
61
62 // Only support the default case for start_index_map.
63 if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
64 dimension_numbers.start_index_map()
65 .getValue(0)
66 .cast<IntegerAttr>()
67 .getValue() != 0) {
68 return failure();
69 }
70
71 auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
72 if (!result_ty) {
73 return failure();
74 }
75
76 // Offset dimensions should be the defaults.
77 if (dimension_numbers.offset_dims().getType().getNumElements() !=
78 result_ty.getRank() - index_vector_dim) {
79 return failure();
80 }
81
82 for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
83 if ((it.index() + index_vector_dim) != it.value()) {
84 return failure();
85 }
86 }
87
88 for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
89 // First shape value must be 1.
90 if (it.index() == 0) {
91 if (it.value().getSExtValue() != 1) {
92 return failure();
93 }
94 continue;
95 }
96
97 // The gather needs to index the entire slice for each other dimension.
98 if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
99 return failure();
100 }
101 }
102
103 llvm::SmallVector<int64_t, 4> index_select_shape =
104 llvm::to_vector<4>(start_indices_ty.getShape());
105
106 for (auto dim : operand_ty.getShape().drop_front()) {
107 index_select_shape.push_back(dim);
108 }
109
110 if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
111 dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
112 1 ||
113 dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
114 return failure();
115 }
116
117 auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
118 gather.getLoc(),
119 RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
120 operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
121 rewriter.getI64IntegerAttr(0));
122
123 rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
124 torch_index_select);
125
126 return success();
127 }
128 };
129
130 struct LegalizeGatherToTorchIndexSelectPass
131 : public PassWrapper<LegalizeGatherToTorchIndexSelectPass, FunctionPass> {
132 /// Perform the lowering of standard dialect operations to approximations.
runOnFunctionmlir::mhlo::__anon5939d9ee0111::LegalizeGatherToTorchIndexSelectPass133 void runOnFunction() override {
134 OwningRewritePatternList patterns;
135 PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
136 (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
137 }
138 };
139 } // namespace
140
PopulateGatherToTorchIndexSelectPatterns(mlir::MLIRContext * context,OwningRewritePatternList * patterns)141 void PopulateGatherToTorchIndexSelectPatterns(
142 mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
143 patterns->insert<GatherIsTorchIndexSelect>(context);
144 }
145
createLegalizeGatherToTorchIndexSelectPass()146 std::unique_ptr<FunctionPass> createLegalizeGatherToTorchIndexSelectPass() {
147 return std::make_unique<LegalizeGatherToTorchIndexSelectPass>();
148 }
149
150 } // namespace mhlo
151 } // namespace mlir
152