1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Transforms/Bufferize.h"
10 #include "PassDetail.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/Transforms/Passes.h"
13
14 using namespace mlir;
15
16 //===----------------------------------------------------------------------===//
17 // BufferizeTypeConverter
18 //===----------------------------------------------------------------------===//
19
materializeTensorLoad(OpBuilder & builder,TensorType type,ValueRange inputs,Location loc)20 static Value materializeTensorLoad(OpBuilder &builder, TensorType type,
21 ValueRange inputs, Location loc) {
22 assert(inputs.size() == 1);
23 assert(inputs[0].getType().isa<BaseMemRefType>());
24 return builder.create<TensorLoadOp>(loc, type, inputs[0]);
25 }
26
27 /// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter()28 BufferizeTypeConverter::BufferizeTypeConverter() {
29 // Keep all types unchanged.
30 addConversion([](Type type) { return type; });
31 // Convert RankedTensorType to MemRefType.
32 addConversion([](RankedTensorType type) -> Type {
33 return MemRefType::get(type.getShape(), type.getElementType());
34 });
35 // Convert UnrankedTensorType to UnrankedMemRefType.
36 addConversion([](UnrankedTensorType type) -> Type {
37 return UnrankedMemRefType::get(type.getElementType(), 0);
38 });
39 addArgumentMaterialization(materializeTensorLoad);
40 addSourceMaterialization(materializeTensorLoad);
41 addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
42 ValueRange inputs, Location loc) -> Value {
43 assert(inputs.size() == 1);
44 assert(inputs[0].getType().isa<TensorType>());
45 return builder.create<TensorToMemrefOp>(loc, type, inputs[0]);
46 });
47 }
48
populateBufferizeMaterializationLegality(ConversionTarget & target)49 void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
50 target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
51 }
52
53 namespace {
54 // In a finalizing bufferize conversion, we know that all tensors have been
55 // converted to memrefs, thus, this op becomes an identity.
56 class BufferizeTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
57 public:
58 using OpConversionPattern::OpConversionPattern;
59 LogicalResult
matchAndRewrite(TensorLoadOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const60 matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
61 ConversionPatternRewriter &rewriter) const override {
62 TensorLoadOp::Adaptor adaptor(operands);
63 rewriter.replaceOp(op, adaptor.memref());
64 return success();
65 }
66 };
67 } // namespace
68
69 namespace {
70 // In a finalizing bufferize conversion, we know that all tensors have been
71 // converted to memrefs, thus, this op becomes an identity.
72 class BufferizeTensorToMemrefOp : public OpConversionPattern<TensorToMemrefOp> {
73 public:
74 using OpConversionPattern::OpConversionPattern;
75 LogicalResult
matchAndRewrite(TensorToMemrefOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const76 matchAndRewrite(TensorToMemrefOp op, ArrayRef<Value> operands,
77 ConversionPatternRewriter &rewriter) const override {
78 TensorToMemrefOp::Adaptor adaptor(operands);
79 rewriter.replaceOp(op, adaptor.tensor());
80 return success();
81 }
82 };
83 } // namespace
84
populateEliminateBufferizeMaterializationsPatterns(MLIRContext * context,BufferizeTypeConverter & typeConverter,OwningRewritePatternList & patterns)85 void mlir::populateEliminateBufferizeMaterializationsPatterns(
86 MLIRContext *context, BufferizeTypeConverter &typeConverter,
87 OwningRewritePatternList &patterns) {
88 patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
89 typeConverter, context);
90 }
91
92 namespace {
93 struct FinalizingBufferizePass
94 : public FinalizingBufferizeBase<FinalizingBufferizePass> {
95 using FinalizingBufferizeBase<
96 FinalizingBufferizePass>::FinalizingBufferizeBase;
97
runOnFunction__anonafa32b390711::FinalizingBufferizePass98 void runOnFunction() override {
99 auto func = getFunction();
100 auto *context = &getContext();
101
102 BufferizeTypeConverter typeConverter;
103 OwningRewritePatternList patterns;
104 ConversionTarget target(*context);
105
106 populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
107 patterns);
108
109 // If all result types are legal, and all block arguments are legal (ensured
110 // by func conversion above), then all types in the program are legal.
111 //
112 // We also check that the operand types are legal to avoid creating invalid
113 // IR. For example, this prevents
114 // populateEliminateBufferizeMaterializationsPatterns from updating the
115 // types of the operands to a return op without updating the enclosing
116 // function.
117 target.markUnknownOpDynamicallyLegal(
118 [&](Operation *op) { return typeConverter.isLegal(op); });
119
120 if (failed(applyFullConversion(func, target, std::move(patterns))))
121 signalPassFailure();
122 }
123 };
124 } // namespace
125
createFinalizingBufferizePass()126 std::unique_ptr<FunctionPass> mlir::createFinalizingBufferizePass() {
127 return std::make_unique<FinalizingBufferizePass>();
128 }
129