//===- Bufferize.cpp - Bufferization of linalg ops ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

using namespace ::mlir;
using namespace ::mlir::linalg;

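/// Cast `val` to the `index` type, inserting an `IndexCastOp` when it has a
/// different integer type.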
static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
  if (val.getType().isIndex())
    return val;
  return b.create<IndexCastOp>(loc, val, b.getIndexType());
}

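/// Allocate a buffer with the same type as `memref`, recovering dynamic
/// dimension sizes with `DimOp`, and copy the original contents into it with
/// a `linalg.copy`.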
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
  auto memrefType = memref.getType().cast<MemRefType>();
  SmallVector<Value, 4> dynOperands;
  for (auto dim : llvm::enumerate(memrefType.getShape())) {
    if (dim.value() == TensorType::kDynamicSize) {
      dynOperands.push_back(b.create<DimOp>(loc, memref, dim.index()));
    }
  }
  auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
  b.create<linalg::CopyOp>(loc, memref, alloc);
  return alloc;
}

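/// Allocate one buffer per tensor result of `linalgOp`:
///   1. results assumed to fold onto an init tensor reuse a clone of that
///      tensor's buffer;
///   2. statically-shaped results get a plain `AllocOp`;
///   3. dynamically-shaped results derive their sizes from the loop ranges
///      through the result's indexing map.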
static LogicalResult
allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                          linalg::GenericOpAdaptor &adaptor,
                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
  // Lazily compute loopRanges.
  SmallVector<Range, 4> loopRanges;

  // Allocate a buffer for every tensor result.
  for (auto en : llvm::enumerate(linalgOp->getResultTypes())) {
    size_t resultIndex = en.index();
    Type resultType = en.value();

    auto tensorType = resultType.dyn_cast<RankedTensorType>();
    if (tensorType == nullptr) {
      linalgOp.emitOpError()
          << "tensor to buffer conversion expects ranked tensor results";
      return failure();
    }
    auto tensorShape = tensorType.getShape();
    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());

    // Allocate buffers for init tensors that are assumed to fold onto the
    // first results.
    // TODO: update this assumption because the reality is more complex under
    // linalg-on-tensors transformations.
    bool hasInitTensor = resultIndex < linalgOp.getNumInitTensors();
    if (hasInitTensor) {
      resultBuffers.push_back(
          cloneMemref(loc, adaptor.init_tensors()[resultIndex], b));
      continue;
    }

    // Allocate buffers for statically-shaped results.
    if (memrefType.hasStaticShape()) {
      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
      continue;
    }

    // Perform a naive shape inference for the dynamically-shaped results:
    // recover each dynamic dimension size from the corresponding loop range
    // via the result's indexing map.
    SmallVector<Value, 4> dynOperands;
    auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
    for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
      if (loopRanges.empty())
        loopRanges = linalgOp.createLoopRanges(b, loc);
      if (shapeElement.value() != ShapedType::kDynamicSize)
        continue;
      AffineExpr expr = resultIndexingMap.getResult(shapeElement.index());
      switch (expr.getKind()) {
      case AffineExprKind::DimId: {
        int64_t loopIndex = expr.cast<AffineDimExpr>().getPosition();
        Value size = maybeConvertToIndex(loc, loopRanges[loopIndex].size, b);
        dynOperands.push_back(size);
        break;
      }
      default:
        return failure();
      }
    }
    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
  }
  return success();
}

/// Specialization for `linalg::GenericOp` and `linalg::IndexedGenericOp`:
/// a pattern to convert generic Linalg operations that work on tensors to
/// use buffers. The BufferPlacement pass should be run afterwards to move
/// the Alloc operations to the correct positions and to insert the missing
/// Dealloc operations in the correct places.
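///
/// The conversion proceeds in three steps:
///   1. create a new op with the same attributes but no tensor results,
///      taking `inputs` and `outputs` as buffer operands;
///   2. clone the old body into a new block whose trailing arguments are the
///      element types of the result buffers;
///   3. replace the tensor results of the old op with the output buffers.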
template <typename GenericOpTy>
static void
finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
                                     GenericOpTy genericOp, ValueRange inputs,
                                     ValueRange outputs) {
  // Generate a new linalg operation that works on buffers.
  auto newGenericOp = rewriter.create<GenericOpTy>(
      genericOp.getLoc(),
      /*resultTensorTypes=*/llvm::None,
      /*inputs=*/inputs,
      /*outputBuffers=*/outputs,
      /*initTensors=*/llvm::None, genericOp.indexing_maps(),
      genericOp.iterator_types(), genericOp.docAttr(),
      genericOp.library_callAttr(), genericOp.sparseAttr());

  // Create a new block in the region of the new Generic Op.
  Block *oldBlock = genericOp.getBody();
  Region &newRegion = newGenericOp.region();
  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
                                         oldBlock->getArgumentTypes());

  // Add the result arguments to the new block.
  for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
    newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());

  // Clone the body of the old block to the new block.
  BlockAndValueMapping mapping;
  mapping.map(oldBlock->getArguments(), newBlock->getArguments());

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToEnd(newBlock);
  for (auto &op : oldBlock->getOperations()) {
    Operation *clonedOp = rewriter.clone(op, mapping);
    mapping.map(op.getResults(), clonedOp->getResults());
  }

  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(genericOp, outputs);
}

/// Specialization for all other `linalg::LinalgOp`.
static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
                                     linalg::LinalgOp linalgOp,
                                     ValueRange inputs, ValueRange outputs) {
  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
                                               /*resultTypes=*/ArrayRef<Type>{},
                                               newOperands));
  // Need to mutate the operands_segment_sizes in the resulting op.
  res.setNumOutputBuffers(outputs.size());
  res.setNumInitTensors(0);
  // Replace the results of the old op with the new output buffers.
  rewriter.replaceOp(linalgOp, outputs);
}

//===----------------------------------------------------------------------===//
// Bufferization patterns.
//===----------------------------------------------------------------------===//

namespace {
/// Generic conversion pattern that matches any LinalgOp. This avoids
/// instantiating one pattern template per concrete LinalgOp.
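///
/// The pattern allocates new output buffers (cloning init tensors when
/// present) and then delegates to the `finalizeBufferAllocation*` helpers
/// above to rebuild the matched op on memrefs.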
class BufferizeAnyLinalgOp : public ConversionPattern {
public:
  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
    if (!linalgOp)
      return failure();

    // We abuse the GenericOpAdaptor here.
    // TODO: Manually create an Adaptor that captures inputs, output_buffers
    // and init_tensors for all linalg::LinalgOp interface ops.
    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());

    Location loc = linalgOp.getLoc();
    SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
                                           adaptor.output_buffers().end());

    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
                                         newOutputBuffers, rewriter))) {
      linalgOp.emitOpError()
          << "failed to allocate buffers for tensor results";
      return failure();
    }

    // Delegate to the linalg generic pattern.
    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<GenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    // Delegate to the linalg indexed generic pattern.
    if (auto genericOp = dyn_cast<linalg::IndexedGenericOp>(op)) {
      finalizeBufferAllocationForGenericOp<IndexedGenericOp>(
          rewriter, genericOp, adaptor.inputs(), newOutputBuffers);
      return success();
    }

    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                             newOutputBuffers);
    return success();
  }
};

// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Convert `subtensor %t [offsets][sizes][strides] -> %st` to an alloc + copy
/// pattern.
/// ```
/// %a = alloc(sizes)
/// %sv = subview %source [offsets][sizes][strides]
/// linalg_copy(%sv, %a)
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
public:
  using OpConversionPattern<SubTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemref = adaptor.source();
    assert(sourceMemref.getType().isa<MemRefType>());

    MemRefType subviewMemRefType =
        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
    // op.sizes() captures exactly the dynamic alloc operands matching the
    // subviewMemRefType, thanks to subview/subtensor canonicalization and
    // verification.
    Value alloc =
        rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
    Value subView = rewriter.create<SubViewOp>(
        op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), op.offsets(), op.sizes(),
        op.strides());
    rewriter.create<linalg::CopyOp>(op.getLoc(), subView, alloc);
    rewriter.replaceOp(op, alloc);
    return success();
  }
};

/// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
/// %t` to a tensor_to_memref + subview + copy + tensor_load pattern.
/// tensor_to_memref and tensor_load are inserted automatically by the
/// conversion infra:
/// ```
/// %sv = subview %dest [offsets][sizes][strides]
/// linalg_copy(%source, %sv)
/// // replace with %dest
/// ```
///
/// This pattern is arguably a std pattern once linalg::CopyOp becomes
/// std::CopyOp.
class SubTensorInsertOpConverter
    : public OpConversionPattern<SubTensorInsertOp> {
public:
  using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    SubTensorInsertOpAdaptor adaptor(operands, op->getAttrDictionary());
    Value sourceMemRef = adaptor.source();
    assert(sourceMemRef.getType().isa<MemRefType>());

    // For now, be conservative and copy the converted input memref. In
    // general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to
    // miscompilations.
    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
    assert(destMemRef.getType().isa<MemRefType>());

    // Take a subview to copy the small memref.
    Value subview = rewriter.create<SubViewOp>(
        op.getLoc(), destMemRef, extractFromI64ArrayAttr(op.static_offsets()),
        extractFromI64ArrayAttr(op.static_sizes()),
        extractFromI64ArrayAttr(op.static_strides()), adaptor.offsets(),
        adaptor.sizes(), adaptor.strides());
    // Copy the small memref.
    rewriter.create<linalg::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
} // namespace


namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
  void runOnOperation() override {
    MLIRContext &context = getContext();
    ConversionTarget target(context);
    BufferizeTypeConverter typeConverter;

    // Mark all Affine and Standard operations legal.
    target.addLegalDialect<AffineDialect, StandardOpsDialect>();
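    // SubTensorOp and SubTensorInsertOp are always rewritten by the dedicated
    // converters registered below.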
    target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();

    // Mark all Linalg operations illegal as long as they work on tensors.
    auto isLegalOperation = [&](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);

    OwningRewritePatternList patterns;
    populateLinalgBufferizePatterns(&context, typeConverter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
  return std::make_unique<LinalgBufferizePass>();
}
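
// Illustrative usage (not part of this file; assumes a PassManager `pm`
// exists in the caller):
//   pm.addNestedPass<FuncOp>(createLinalgBufferizePass());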

void mlir::linalg::populateLinalgBufferizePatterns(
    MLIRContext *context, BufferizeTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
  // TODO: Drop this once tensor constants work in standard.
  patterns.insert<
      // clang-format off
      SubTensorOpConverter,
      SubTensorInsertOpConverter
      // clang-format on
      >(typeConverter, context);
}