1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file provides optional optimization patterns for mhlo, canonicalizing
// operations to equivalent but potentially more efficient operations.
18 
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <numeric>
23 
24 #include "llvm/ADT/STLExtras.h"
25 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
27 #include "mlir-hlo/utils/hlo_utils.h"
28 #include "mlir/IR/Attributes.h"
29 #include "mlir/IR/MLIRContext.h"
30 #include "mlir/IR/Operation.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
33 #include "mlir/IR/Types.h"
34 #include "mlir/Pass/Pass.h"
35 #include "mlir/Pass/PassRegistry.h"
36 
37 using mlir::OwningRewritePatternList;
38 
39 namespace mlir {
40 namespace mhlo {
41 namespace {
42 
43 // Returns 1D 64-bit dense elements attribute with the given values.
GetI64ElementsAttr(ArrayRef<int64_t> values,Builder * builder)44 static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
45                                                Builder* builder) {
46   RankedTensorType ty = RankedTensorType::get(
47       {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
48   return DenseIntElementsAttr::get(ty, values);
49 }
50 
51 //===----------------------------------------------------------------------===//
52 // GatherOp
53 //===----------------------------------------------------------------------===//
54 
// Rewrites an mhlo.gather that extracts exactly one contiguous slice into a
// mhlo.dynamic-slice followed by a mhlo.reshape.
//
// The pattern only fires for a restricted gather form where every check below
// holds:
//   * operand and start_indices are ranked with static shapes,
//   * index_vector_dim == 0,
//   * start_index_map is exactly [0] (indices select along operand dim 0),
//   * offset_dims is the identity [0, 1, ..., result_rank - 1],
//   * slice_sizes[1:] matches the result shape.
// Under those constraints the gather is a single dynamic slice whose leading
// (indexed) dimension is collapsed by the trailing reshape.
class GatherIsSlice : public OpRewritePattern<GatherOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter& rewriter) const override {
    auto dimension_numbers = gather.dimension_numbers();

    // Inputs need to be ranked to lower.
    if (!gather.operand().getType().cast<ShapedType>().hasRank() ||
        !gather.operand().getType().cast<ShapedType>().hasStaticShape() ||
        !gather.start_indices().getType().cast<ShapedType>().hasRank() ||
        !gather.start_indices().getType().cast<ShapedType>().hasStaticShape()) {
      return failure();
    }

    // Only handle the case where the index vector occupies dimension 0 of
    // start_indices.
    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
      return failure();
    }

    // TODO(suderman): Handle start index map != {0}.
    // start_index_map must be the single-entry map [0]: each start index
    // offsets operand dimension 0 only.
    if (!dimension_numbers.start_index_map() ||
        dimension_numbers.start_index_map().getType().getRank() != 1 ||
        dimension_numbers.start_index_map().getType().getDimSize(0) != 1 ||
        dimension_numbers.start_index_map()
                .getValue({0})
                .cast<IntegerAttr>()
                .getValue() != 0) {
      return failure();
    }

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();

    // Requires a ranked output.
    if (!result_ty) {
      return failure();
    }
    // offset_dims must be the identity permutation [0, ..., rank-1] so the
    // gather result has the same layout as the extracted slice.
    if (dimension_numbers.offset_dims().getType().getNumElements() !=
        result_ty.getRank()) {
      return failure();
    }
    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
      if (it.index() != it.value()) {
        return failure();
      }
    }

    // Verify the gather slice sizes are correct: one slice size per operand
    // dimension.
    if (gather.slice_sizes().getNumElements() !=
        gather.operand().getType().cast<ShapedType>().getRank()) {
      return failure();
    }

    // Validate the slice sizes are correct: there must be enough entries to
    // index slice_sizes[i + 1] for every result dimension i below.
    if (gather.slice_sizes().getType().cast<ShapedType>().getNumElements() <
        result_ty.getShape().size() + 1) {
      return failure();
    }

    // slice_sizes[i + 1] must equal result dim i; slot 0 is the indexed
    // dimension that the final reshape collapses away.
    for (auto it : llvm::enumerate(result_ty.getShape())) {
      if (gather.slice_sizes()
              .getValue(it.index() + 1)
              .cast<IntegerAttr>()
              .getValue() != it.value()) {
        return failure();
      }
    }

    auto gather_start_indices = gather.start_indices();
    auto gather_start_indices_ty =
        gather_start_indices.getType().cast<ShapedType>();

    // DynamicSliceOp wants one scalar start index per operand dimension.
    llvm::SmallVector<Value, 4> slice_start_indices;

    if (gather_start_indices_ty.getRank() == 0) {
      // A scalar start_indices value is already in the required form.
      slice_start_indices.push_back(gather_start_indices);
    } else if (gather_start_indices_ty.getRank() == 1) {
      // Unpack the 1-D start_indices vector: slice out element i and reshape
      // it to a scalar tensor.
      for (int i = 0; i < gather_start_indices_ty.getDimSize(0); i++) {
        auto start = GetI64ElementsAttr({i}, &rewriter);
        auto limit = GetI64ElementsAttr({i + 1}, &rewriter);
        auto stride = GetI64ElementsAttr({1}, &rewriter);
        auto indicesSlice = rewriter.create<SliceOp>(
            gather.getLoc(), gather_start_indices, start, limit, stride);
        auto reshaped = rewriter.create<ReshapeOp>(
            gather.getLoc(),
            RankedTensorType::get(
                {}, indicesSlice.getType().cast<ShapedType>().getElementType()),
            indicesSlice);
        slice_start_indices.push_back(reshaped);
      }
    } else {
      // Higher-rank start_indices are not supported.
      return failure();
    }

    auto sliceSizes = gather.slice_sizes();
    auto sliceSizesTy = sliceSizes.getType();
    if (sliceSizesTy.getRank() != 1) {
      return failure();
    }

    // Start indices have implicit zeros when not specified. This is because
    // Gather occurs similar to slicing where full slices are inferred. Add any
    // missing zeros as necessary.
    auto zero = rewriter.create<ConstOp>(
        gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get(
                             {}, gather_start_indices_ty.getElementType())));
    while (slice_start_indices.size() < sliceSizesTy.getDimSize(0)) {
      slice_start_indices.push_back(zero);
    }

    // Materialize slice_sizes as the static shape of the dynamic-slice
    // result.
    SmallVector<int64_t, 5> sliceShape;
    for (auto shapeValue : gather.slice_sizes().getIntValues()) {
      sliceShape.push_back(shapeValue.getSExtValue());
    }

    auto sliceTy =
        RankedTensorType::get(sliceShape, result_ty.getElementType());
    auto slice = rewriter.create<DynamicSliceOp>(
        gather.getLoc(), sliceTy, gather.operand(), slice_start_indices,
        gather.slice_sizes());

    // Collapse the leading (indexed) dimension to match the gather's result
    // type.
    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);

    return success();
  }
};
179 
180 }  // end anonymous namespace
181 
PopulateOptimizeMHLOPatterns(MLIRContext * context,OwningRewritePatternList * patterns)182 void PopulateOptimizeMHLOPatterns(MLIRContext* context,
183                                   OwningRewritePatternList* patterns) {
184   patterns->insert<GatherIsSlice>(context);
185 }
186 }  // end namespace mhlo
187 }  // end namespace mlir
188