1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file provides optional optimization patterns for mhlo, canonicalizing
// operations to equivalent but potentially more efficient operations.
18
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23
24 #include "llvm/ADT/STLExtras.h"
25 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
27 #include "mlir-hlo/utils/hlo_utils.h"
28 #include "mlir/IR/Attributes.h"
29 #include "mlir/IR/MLIRContext.h"
30 #include "mlir/IR/Operation.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/IR/Types.h"
34 #include "mlir/Pass/Pass.h"
35 #include "mlir/Pass/PassRegistry.h"
36
37 using mlir::OwningRewritePatternList;
38
39 namespace mlir {
40 namespace mhlo {
41 namespace {
42
43 // Returns 1D 64-bit dense elements attribute with the given values.
GetI64ElementsAttr(ArrayRef<int64_t> values,Builder * builder)44 static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
45 Builder* builder) {
46 RankedTensorType ty = RankedTensorType::get(
47 {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
48 return DenseIntElementsAttr::get(ty, values);
49 }
50
51 //===----------------------------------------------------------------------===//
52 // GatherOp
53 //===----------------------------------------------------------------------===//
54
55 class GatherIsSlice : public OpRewritePattern<GatherOp> {
56 using OpRewritePattern::OpRewritePattern;
matchAndRewrite(GatherOp gather,PatternRewriter & rewriter) const57 LogicalResult matchAndRewrite(GatherOp gather,
58 PatternRewriter& rewriter) const override {
59 auto dimension_numbers = gather.dimension_numbers();
60
61 // Inputs need to be ranked to lower.
62 if (!gather.operand().getType().cast<ShapedType>().hasRank() ||
63 !gather.operand().getType().cast<ShapedType>().hasStaticShape() ||
64 !gather.start_indices().getType().cast<ShapedType>().hasRank() ||
65 !gather.start_indices().getType().cast<ShapedType>().hasStaticShape()) {
66 return failure();
67 }
68
69 if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
70 return failure();
71 }
72
73 // TODO(suderman): Handle start index map != {0}.
74 if (!dimension_numbers.start_index_map() ||
75 dimension_numbers.start_index_map().getType().getRank() != 1 ||
76 dimension_numbers.start_index_map().getType().getDimSize(0) != 1 ||
77 dimension_numbers.start_index_map()
78 .getValue({0})
79 .cast<IntegerAttr>()
80 .getValue() != 0) {
81 return failure();
82 }
83
84 auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
85
86 // Requires a ranked output.
87 if (!result_ty) {
88 return failure();
89 }
90 if (dimension_numbers.offset_dims().getType().getNumElements() !=
91 result_ty.getRank()) {
92 return failure();
93 }
94 for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
95 if (it.index() != it.value()) {
96 return failure();
97 }
98 }
99
100 // Verify the gather slice sizes are correct.
101 if (gather.slice_sizes().getNumElements() !=
102 gather.operand().getType().cast<ShapedType>().getRank()) {
103 return failure();
104 }
105
106 // Validate the slice sizes are correct.
107 if (gather.slice_sizes().getType().cast<ShapedType>().getNumElements() <
108 result_ty.getShape().size() + 1) {
109 return failure();
110 }
111
112 for (auto it : llvm::enumerate(result_ty.getShape())) {
113 if (gather.slice_sizes()
114 .getValue(it.index() + 1)
115 .cast<IntegerAttr>()
116 .getValue() != it.value()) {
117 return failure();
118 }
119 }
120
121 auto gather_start_indices = gather.start_indices();
122 auto gather_start_indices_ty =
123 gather_start_indices.getType().cast<ShapedType>();
124
125 llvm::SmallVector<Value, 4> slice_start_indices;
126
127 if (gather_start_indices_ty.getRank() == 0) {
128 slice_start_indices.push_back(gather_start_indices);
129 } else if (gather_start_indices_ty.getRank() == 1) {
130 for (int i = 0; i < gather_start_indices_ty.getDimSize(0); i++) {
131 auto start = GetI64ElementsAttr({i}, &rewriter);
132 auto limit = GetI64ElementsAttr({i + 1}, &rewriter);
133 auto stride = GetI64ElementsAttr({1}, &rewriter);
134 auto indicesSlice = rewriter.create<SliceOp>(
135 gather.getLoc(), gather_start_indices, start, limit, stride);
136 auto reshaped = rewriter.create<ReshapeOp>(
137 gather.getLoc(),
138 RankedTensorType::get(
139 {}, indicesSlice.getType().cast<ShapedType>().getElementType()),
140 indicesSlice);
141 slice_start_indices.push_back(reshaped);
142 }
143 } else {
144 return failure();
145 }
146
147 auto sliceSizes = gather.slice_sizes();
148 auto sliceSizesTy = sliceSizes.getType();
149 if (sliceSizesTy.getRank() != 1) {
150 return failure();
151 }
152
153 // Start indices have implicit zeros when not specified. This is because
154 // Gather occurs similar to slicing where full slices are inferred. Add any
155 // missing zeros as necessary.
156 auto zero = rewriter.create<ConstOp>(
157 gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get(
158 {}, gather_start_indices_ty.getElementType())));
159 while (slice_start_indices.size() < sliceSizesTy.getDimSize(0)) {
160 slice_start_indices.push_back(zero);
161 }
162
163 SmallVector<int64_t, 5> sliceShape;
164 for (auto shapeValue : gather.slice_sizes().getIntValues()) {
165 sliceShape.push_back(shapeValue.getSExtValue());
166 }
167
168 auto sliceTy =
169 RankedTensorType::get(sliceShape, result_ty.getElementType());
170 auto slice = rewriter.create<DynamicSliceOp>(
171 gather.getLoc(), sliceTy, gather.operand(), slice_start_indices,
172 gather.slice_sizes());
173
174 rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);
175
176 return success();
177 }
178 };
179
180 } // end anonymous namespace
181
PopulateOptimizeMHLOPatterns(MLIRContext * context,OwningRewritePatternList * patterns)182 void PopulateOptimizeMHLOPatterns(MLIRContext* context,
183 OwningRewritePatternList* patterns) {
184 patterns->insert<GatherIsSlice>(context);
185 }
186 } // end namespace mhlo
187 } // end namespace mlir
188