1 //===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a byte-code and interpreter for pattern rewrites in MLIR.
10 // The byte-code is constructed from the PDL Interpreter dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_REWRITE_BYTECODE_H_
15 #define MLIR_REWRITE_BYTECODE_H_
16 
17 #include "mlir/IR/PatternMatch.h"
18 
19 namespace mlir {
namespace pdl_interp {
/// Forward declaration; only used by value in a function declaration in this
/// header, so the full definition is not required here.
class RecordMatchOp;
} // end namespace pdl_interp
23 
24 namespace detail {
/// Forward declaration: the mutable state class refers to `PDLByteCode` before
/// the class itself is defined.
class PDLByteCode;
26 
/// Generic bytecode storage types.
///  * `ByteCodeField` is the type of a single bytecode entry. It is currently
///    16 bits wide; switching it to `uint8_t` would yield a literal "byte"
///    bytecode.
///  * `ByteCodeAddr` is the type used for indices (addresses) into the
///    bytecode.
/// Correctness is checked with static asserts.
using ByteCodeField = uint16_t;
using ByteCodeAddr = uint32_t;
32 
33 //===----------------------------------------------------------------------===//
34 // PDLByteCodePattern
35 //===----------------------------------------------------------------------===//
36 
37 /// All of the data pertaining to a specific pattern within the bytecode.
38 class PDLByteCodePattern : public Pattern {
39 public:
40   static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
41                                    ByteCodeAddr rewriterAddr);
42 
43   /// Return the bytecode address of the rewriter for this pattern.
getRewriterAddr()44   ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
45 
46 private:
47   template <typename... Args>
PDLByteCodePattern(ByteCodeAddr rewriterAddr,Args &&...patternArgs)48   PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs)
49       : Pattern(std::forward<Args>(patternArgs)...),
50         rewriterAddr(rewriterAddr) {}
51 
52   /// The address of the rewriter for this pattern.
53   ByteCodeAddr rewriterAddr;
54 };
55 
56 //===----------------------------------------------------------------------===//
57 // PDLByteCodeMutableState
58 //===----------------------------------------------------------------------===//
59 
60 /// This class contains the mutable state of a bytecode instance. This allows
61 /// for a bytecode instance to be cached and reused across various different
62 /// threads/drivers.
63 class PDLByteCodeMutableState {
64 public:
65   /// Initialize the state from a bytecode instance.
66   void initialize(PDLByteCode &bytecode);
67 
68   /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
69   /// to the position of the pattern within the range returned by
70   /// `PDLByteCode::getPatterns`.
71   void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
72 
73 private:
74   /// Allow access to data fields.
75   friend class PDLByteCode;
76 
77   /// The mutable block of memory used during the matching and rewriting phases
78   /// of the bytecode.
79   std::vector<const void *> memory;
80 
81   /// The up-to-date benefits of the patterns held by the bytecode. The order
82   /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
83   std::vector<PatternBenefit> currentPatternBenefits;
84 };
85 
86 //===----------------------------------------------------------------------===//
87 // PDLByteCode
88 //===----------------------------------------------------------------------===//
89 
90 /// The bytecode class is also the interpreter. Contains the bytecode itself,
91 /// the static info, addresses of the rewriter functions, the interpreter
92 /// memory buffer, and the execution context.
93 class PDLByteCode {
94 public:
95   /// Each successful match returns a MatchResult, which contains information
96   /// necessary to execute the rewriter and indicates the originating pattern.
97   struct MatchResult {
MatchResultMatchResult98     MatchResult(Location loc, const PDLByteCodePattern &pattern,
99                 PatternBenefit benefit)
100         : location(loc), pattern(&pattern), benefit(benefit) {}
101 
102     /// The location of operations to be replaced.
103     Location location;
104     /// Memory values defined in the matcher that are passed to the rewriter.
105     SmallVector<const void *, 4> values;
106     /// The originating pattern that was matched. This is always non-null, but
107     /// represented with a pointer to allow for assignment.
108     const PDLByteCodePattern *pattern;
109     /// The current benefit of the pattern that was matched.
110     PatternBenefit benefit;
111   };
112 
113   /// Create a ByteCode instance from the given module containing operations in
114   /// the PDL interpreter dialect.
115   PDLByteCode(ModuleOp module,
116               llvm::StringMap<PDLConstraintFunction> constraintFns,
117               llvm::StringMap<PDLCreateFunction> createFns,
118               llvm::StringMap<PDLRewriteFunction> rewriteFns);
119 
120   /// Return the patterns held by the bytecode.
getPatterns()121   ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }
122 
123   /// Initialize the given state such that it can be used to execute the current
124   /// bytecode.
125   void initializeMutableState(PDLByteCodeMutableState &state) const;
126 
127   /// Run the pattern matcher on the given root operation, collecting the
128   /// matched patterns in `matches`.
129   void match(Operation *op, PatternRewriter &rewriter,
130              SmallVectorImpl<MatchResult> &matches,
131              PDLByteCodeMutableState &state) const;
132 
133   /// Run the rewriter of the given pattern that was previously matched in
134   /// `match`.
135   void rewrite(PatternRewriter &rewriter, const MatchResult &match,
136                PDLByteCodeMutableState &state) const;
137 
138 private:
139   /// Execute the given byte code starting at the provided instruction `inst`.
140   /// `matches` is an optional field provided when this function is executed in
141   /// a matching context.
142   void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
143                        PDLByteCodeMutableState &state,
144                        SmallVectorImpl<MatchResult> *matches) const;
145 
146   /// A vector containing pointers to unqiued data. The storage is intentionally
147   /// opaque such that we can store a wide range of data types. The types of
148   /// data stored here include:
149   ///  * Attribute, Identifier, OperationName, Type
150   std::vector<const void *> uniquedData;
151 
152   /// A vector containing the generated bytecode for the matcher.
153   SmallVector<ByteCodeField, 64> matcherByteCode;
154 
155   /// A vector containing the generated bytecode for all of the rewriters.
156   SmallVector<ByteCodeField, 64> rewriterByteCode;
157 
158   /// The set of patterns contained within the bytecode.
159   SmallVector<PDLByteCodePattern, 32> patterns;
160 
161   /// A set of user defined functions invoked via PDL.
162   std::vector<PDLConstraintFunction> constraintFunctions;
163   std::vector<PDLCreateFunction> createFunctions;
164   std::vector<PDLRewriteFunction> rewriteFunctions;
165 
166   /// The maximum memory index used by a value.
167   ByteCodeField maxValueMemoryIndex = 0;
168 };
169 
170 } // end namespace detail
171 } // end namespace mlir
172 
173 #endif // MLIR_REWRITE_BYTECODE_H_
174