1 //===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Transforms/Bufferize.h"
10 #include "PassDetail.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/Transforms/Passes.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // BufferizeTypeConverter
18 //===----------------------------------------------------------------------===//
19 
materializeTensorLoad(OpBuilder & builder,TensorType type,ValueRange inputs,Location loc)20 static Value materializeTensorLoad(OpBuilder &builder, TensorType type,
21                                    ValueRange inputs, Location loc) {
22   assert(inputs.size() == 1);
23   assert(inputs[0].getType().isa<BaseMemRefType>());
24   return builder.create<TensorLoadOp>(loc, type, inputs[0]);
25 }
26 
27 /// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter()28 BufferizeTypeConverter::BufferizeTypeConverter() {
29   // Keep all types unchanged.
30   addConversion([](Type type) { return type; });
31   // Convert RankedTensorType to MemRefType.
32   addConversion([](RankedTensorType type) -> Type {
33     return MemRefType::get(type.getShape(), type.getElementType());
34   });
35   // Convert UnrankedTensorType to UnrankedMemRefType.
36   addConversion([](UnrankedTensorType type) -> Type {
37     return UnrankedMemRefType::get(type.getElementType(), 0);
38   });
39   addArgumentMaterialization(materializeTensorLoad);
40   addSourceMaterialization(materializeTensorLoad);
41   addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
42                               ValueRange inputs, Location loc) -> Value {
43     assert(inputs.size() == 1);
44     assert(inputs[0].getType().isa<TensorType>());
45     return builder.create<TensorToMemrefOp>(loc, type, inputs[0]);
46   });
47 }
48 
populateBufferizeMaterializationLegality(ConversionTarget & target)49 void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
50   target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
51 }
52 
53 namespace {
54 // In a finalizing bufferize conversion, we know that all tensors have been
55 // converted to memrefs, thus, this op becomes an identity.
56 class BufferizeTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
57 public:
58   using OpConversionPattern::OpConversionPattern;
59   LogicalResult
matchAndRewrite(TensorLoadOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const60   matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
61                   ConversionPatternRewriter &rewriter) const override {
62     TensorLoadOp::Adaptor adaptor(operands);
63     rewriter.replaceOp(op, adaptor.memref());
64     return success();
65   }
66 };
67 } // namespace
68 
69 namespace {
70 // In a finalizing bufferize conversion, we know that all tensors have been
71 // converted to memrefs, thus, this op becomes an identity.
72 class BufferizeTensorToMemrefOp : public OpConversionPattern<TensorToMemrefOp> {
73 public:
74   using OpConversionPattern::OpConversionPattern;
75   LogicalResult
matchAndRewrite(TensorToMemrefOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const76   matchAndRewrite(TensorToMemrefOp op, ArrayRef<Value> operands,
77                   ConversionPatternRewriter &rewriter) const override {
78     TensorToMemrefOp::Adaptor adaptor(operands);
79     rewriter.replaceOp(op, adaptor.tensor());
80     return success();
81   }
82 };
83 } // namespace
84 
populateEliminateBufferizeMaterializationsPatterns(MLIRContext * context,BufferizeTypeConverter & typeConverter,OwningRewritePatternList & patterns)85 void mlir::populateEliminateBufferizeMaterializationsPatterns(
86     MLIRContext *context, BufferizeTypeConverter &typeConverter,
87     OwningRewritePatternList &patterns) {
88   patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
89       typeConverter, context);
90 }
91 
92 namespace {
93 struct FinalizingBufferizePass
94     : public FinalizingBufferizeBase<FinalizingBufferizePass> {
95   using FinalizingBufferizeBase<
96       FinalizingBufferizePass>::FinalizingBufferizeBase;
97 
runOnFunction__anonafa32b390711::FinalizingBufferizePass98   void runOnFunction() override {
99     auto func = getFunction();
100     auto *context = &getContext();
101 
102     BufferizeTypeConverter typeConverter;
103     OwningRewritePatternList patterns;
104     ConversionTarget target(*context);
105 
106     populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
107                                                        patterns);
108 
109     // If all result types are legal, and all block arguments are legal (ensured
110     // by func conversion above), then all types in the program are legal.
111     //
112     // We also check that the operand types are legal to avoid creating invalid
113     // IR. For example, this prevents
114     // populateEliminateBufferizeMaterializationsPatterns from updating the
115     // types of the operands to a return op without updating the enclosing
116     // function.
117     target.markUnknownOpDynamicallyLegal(
118         [&](Operation *op) { return typeConverter.isLegal(op); });
119 
120     if (failed(applyFullConversion(func, target, std::move(patterns))))
121       signalPassFailure();
122   }
123 };
124 } // namespace
125 
createFinalizingBufferizePass()126 std::unique_ptr<FunctionPass> mlir::createFinalizingBufferizePass() {
127   return std::make_unique<FinalizingBufferizePass>();
128 }
129