1 //===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TABLEGEN_PATTERN_H_
15 #define MLIR_TABLEGEN_PATTERN_H_
16 
17 #include "mlir/Support/LLVM.h"
18 #include "mlir/TableGen/Argument.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/StringMap.h"
22 #include "llvm/ADT/StringSet.h"
23 
24 #include <unordered_map>
25 
26 namespace llvm {
27 class DagInit;
28 class Init;
29 class Record;
30 } // end namespace llvm
31 
32 namespace mlir {
33 namespace tblgen {
34 
35 // Mapping from TableGen Record to Operator wrapper object.
36 //
37 // We allocate each wrapper object in heap to make sure the pointer to it is
38 // valid throughout the lifetime of this map. This is important because this map
39 // is shared among multiple patterns to avoid creating the wrapper object for
40 // the same op again and again. But this map will continuously grow.
41 using RecordOperatorMap =
42     DenseMap<const llvm::Record *, std::unique_ptr<Operator>>;
43 
44 class Pattern;
45 
46 // Wrapper class providing helper methods for accessing TableGen DAG leaves
47 // used inside Patterns. This class is lightweight and designed to be used like
48 // values.
49 //
50 // A TableGen DAG construct is of the syntax
51 //   `(operator, arg0, arg1, ...)`.
52 //
53 // This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
54 // for handy helper methods. It only works on `arg*`s that are not nested DAG
55 // constructs.
56 class DagLeaf {
57 public:
DagLeaf(const llvm::Init * def)58   explicit DagLeaf(const llvm::Init *def) : def(def) {}
59 
60   // Returns true if this DAG leaf is not specified in the pattern. That is, it
61   // places no further constraints/transforms and just carries over the original
62   // value.
63   bool isUnspecified() const;
64 
65   // Returns true if this DAG leaf is matching an operand. That is, it specifies
66   // a type constraint.
67   bool isOperandMatcher() const;
68 
69   // Returns true if this DAG leaf is matching an attribute. That is, it
70   // specifies an attribute constraint.
71   bool isAttrMatcher() const;
72 
73   // Returns true if this DAG leaf is wrapping native code call.
74   bool isNativeCodeCall() const;
75 
76   // Returns true if this DAG leaf is specifying a constant attribute.
77   bool isConstantAttr() const;
78 
79   // Returns true if this DAG leaf is specifying an enum attribute case.
80   bool isEnumAttrCase() const;
81 
82   // Returns true if this DAG leaf is specifying a string attribute.
83   bool isStringAttr() const;
84 
85   // Returns this DAG leaf as a constraint. Asserts if fails.
86   Constraint getAsConstraint() const;
87 
88   // Returns this DAG leaf as an constant attribute. Asserts if fails.
89   ConstantAttr getAsConstantAttr() const;
90 
91   // Returns this DAG leaf as an enum attribute case.
92   // Precondition: isEnumAttrCase()
93   EnumAttrCase getAsEnumAttrCase() const;
94 
95   // Returns the matching condition template inside this DAG leaf. Assumes the
96   // leaf is an operand/attribute matcher and asserts otherwise.
97   std::string getConditionTemplate() const;
98 
99   // Returns the native code call template inside this DAG leaf.
100   // Precondition: isNativeCodeCall()
101   StringRef getNativeCodeTemplate() const;
102 
103   // Returns the string associated with the leaf.
104   // Precondition: isStringAttr()
105   std::string getStringAttr() const;
106 
107   void print(raw_ostream &os) const;
108 
109 private:
110   // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
111   // also a subclass of the given `superclass`.
112   bool isSubClassOf(StringRef superclass) const;
113 
114   const llvm::Init *def;
115 };
116 
117 // Wrapper class providing helper methods for accessing TableGen DAG constructs
118 // used inside Patterns. This class is lightweight and designed to be used like
119 // values.
120 //
121 // A TableGen DAG construct is of the syntax
122 //   `(operator, arg0, arg1, ...)`.
123 //
124 // When used inside Patterns, `operator` corresponds to some dialect op, or
// a known list of verbs that defines special transformation actions. Each
// `arg*` can be a nested DAG construct. This class provides getters to
127 // retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
128 // methods.
129 //
130 // A null DagNode contains a nullptr and converts to false implicitly.
131 class DagNode {
132 public:
DagNode(const llvm::DagInit * node)133   explicit DagNode(const llvm::DagInit *node) : node(node) {}
134 
135   // Implicit bool converter that returns true if this DagNode is not a null
136   // DagNode.
137   operator bool() const { return node != nullptr; }
138 
139   // Returns the symbol bound to this DAG node.
140   StringRef getSymbol() const;
141 
142   // Returns the operator wrapper object corresponding to the dialect op matched
143   // by this DAG. The operator wrapper will be queried from the given `mapper`
144   // and created in it if not existing.
145   Operator &getDialectOp(RecordOperatorMap *mapper) const;
146 
147   // Returns the number of operations recursively involved in the DAG tree
148   // rooted from this node.
149   int getNumOps() const;
150 
151   // Returns the number of immediate arguments to this DAG node.
152   int getNumArgs() const;
153 
154   // Returns true if the `index`-th argument is a nested DAG construct.
155   bool isNestedDagArg(unsigned index) const;
156 
157   // Gets the `index`-th argument as a nested DAG construct if possible. Returns
158   // null DagNode otherwise.
159   DagNode getArgAsNestedDag(unsigned index) const;
160 
161   // Gets the `index`-th argument as a DAG leaf.
162   DagLeaf getArgAsLeaf(unsigned index) const;
163 
164   // Returns the specified name of the `index`-th argument.
165   StringRef getArgName(unsigned index) const;
166 
167   // Returns true if this DAG construct means to replace with an existing SSA
168   // value.
169   bool isReplaceWithValue() const;
170 
171   // Returns whether this DAG represents the location of an op creation.
172   bool isLocationDirective() const;
173 
174   // Returns true if this DAG node is wrapping native code call.
175   bool isNativeCodeCall() const;
176 
177   // Returns true if this DAG node is an operation.
178   bool isOperation() const;
179 
180   // Returns the native code call template inside this DAG node.
181   // Precondition: isNativeCodeCall()
182   StringRef getNativeCodeTemplate() const;
183 
184   void print(raw_ostream &os) const;
185 
186 private:
187   const llvm::DagInit *node; // nullptr means null DagNode
188 };
189 
190 // A class for maintaining information for symbols bound in patterns and
191 // provides methods for resolving them according to specific use cases.
192 //
193 // Symbols can be bound to
194 //
195 // * Op arguments and op results in the source pattern and
196 // * Op results in result patterns.
197 //
198 // Symbols can be referenced in result patterns and additional constraints to
199 // the pattern.
200 //
201 // For example, in
202 //
203 // ```
204 // def : Pattern<
//     (SrcOp:$results1 $arg0, $arg1),
206 //     [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
207 // ```
208 //
209 // `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
// `SrcOp`. `$results2` is bound to `ResOp1`. `$results2` is referenced to build
211 // `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
212 //
213 // If a symbol binds to a multi-result op and it does not have the `__N`
214 // suffix, the symbol is expanded to represent all results generated by the
215 // multi-result op. If the symbol has a `__N` suffix, then it will expand to
216 // only the N-th *static* result as declared in ODS, and that can still
// correspond to multiple *dynamic* values if the N-th *static* result is
218 // variadic.
219 //
220 // This class keeps track of such symbols and resolves them into their bound
221 // values in a suitable way.
222 class SymbolInfoMap {
223 public:
SymbolInfoMap(ArrayRef<llvm::SMLoc> loc)224   explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {}
225 
226   // Class for information regarding a symbol.
227   class SymbolInfo {
228   public:
229     // Returns a string for defining a variable named as `name` to store the
230     // value bound by this symbol.
231     std::string getVarDecl(StringRef name) const;
232 
233     // Returns a variable name for the symbol named as `name`.
234     std::string getVarName(StringRef name) const;
235 
236   private:
237     // Allow SymbolInfoMap to access private methods.
238     friend class SymbolInfoMap;
239 
240     // What kind of entity this symbol represents:
241     // * Attr: op attribute
242     // * Operand: op operand
243     // * Result: op result
244     // * Value: a value not attached to an op (e.g., from NativeCodeCall)
245     enum class Kind : uint8_t { Attr, Operand, Result, Value };
246 
247     // Creates a SymbolInfo instance. `index` is only used for `Attr` and
248     // `Operand` so should be negative for `Result` and `Value` kind.
249     SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
250 
251     // Static methods for creating SymbolInfo.
getAttr(const Operator * op,int index)252     static SymbolInfo getAttr(const Operator *op, int index) {
253       return SymbolInfo(op, Kind::Attr, index);
254     }
getAttr()255     static SymbolInfo getAttr() {
256       return SymbolInfo(nullptr, Kind::Attr, llvm::None);
257     }
getOperand(const Operator * op,int index)258     static SymbolInfo getOperand(const Operator *op, int index) {
259       return SymbolInfo(op, Kind::Operand, index);
260     }
getResult(const Operator * op)261     static SymbolInfo getResult(const Operator *op) {
262       return SymbolInfo(op, Kind::Result, llvm::None);
263     }
getValue()264     static SymbolInfo getValue() {
265       return SymbolInfo(nullptr, Kind::Value, llvm::None);
266     }
267 
268     // Returns the number of static values this symbol corresponds to.
269     // A static value is an operand/result declared in ODS. Normally a symbol
270     // only represents one static value, but symbols bound to op results can
271     // represent more than one if the op is a multi-result op.
272     int getStaticValueCount() const;
273 
274     // Returns a string containing the C++ expression for referencing this
275     // symbol as a value (if this symbol represents one static value) or a value
276     // range (if this symbol represents multiple static values). `name` is the
277     // name of the C++ variable that this symbol bounds to. `index` should only
278     // be used for indexing results.  `fmt` is used to format each value.
279     // `separator` is used to separate values if this is a value range.
280     std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
281                                     const char *separator) const;
282 
283     // Returns a string containing the C++ expression for referencing this
284     // symbol as a value range regardless of how many static values this symbol
285     // represents. `name` is the name of the C++ variable that this symbol
286     // bounds to. `index` should only be used for indexing results. `fmt` is
287     // used to format each value. `separator` is used to separate values in the
288     // range.
289     std::string getAllRangeUse(StringRef name, int index, const char *fmt,
290                                const char *separator) const;
291 
292     const Operator *op; // The op where the bound entity belongs
293     Kind kind;          // The kind of the bound entity
294     // The argument index (for `Attr` and `Operand` only)
295     Optional<int> argIndex;
296     // Alternative name for the symbol. It is used in case the name
297     // is not unique. Applicable for `Operand` only.
298     Optional<std::string> alternativeName;
299   };
300 
301   using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
302 
303   // Iterators for accessing all symbols.
304   using iterator = BaseT::iterator;
begin()305   iterator begin() { return symbolInfoMap.begin(); }
end()306   iterator end() { return symbolInfoMap.end(); }
307 
308   // Const iterators for accessing all symbols.
309   using const_iterator = BaseT::const_iterator;
begin()310   const_iterator begin() const { return symbolInfoMap.begin(); }
end()311   const_iterator end() const { return symbolInfoMap.end(); }
312 
313   // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
314   // Returns false if `symbol` is already bound and symbols are not operands.
315   bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
316 
317   // Binds the given `symbol` to the results the given `op`. Returns false if
318   // `symbol` is already bound.
319   bool bindOpResult(StringRef symbol, const Operator &op);
320 
321   // Registers the given `symbol` as bound to a value. Returns false if `symbol`
322   // is already bound.
323   bool bindValue(StringRef symbol);
324 
325   // Registers the given `symbol` as bound to an attr. Returns false if `symbol`
326   // is already bound.
327   bool bindAttr(StringRef symbol);
328 
329   // Returns true if the given `symbol` is bound.
330   bool contains(StringRef symbol) const;
331 
332   // Returns an iterator to the information of the given symbol named as `key`.
333   const_iterator find(StringRef key) const;
334 
335   // Returns an iterator to the information of the given symbol named as `key`,
336   // with index `argIndex` for operator `op`.
337   const_iterator findBoundSymbol(StringRef key, const Operator &op,
338                                  int argIndex) const;
339 
340   // Returns the bounds of a range that includes all the elements which
341   // bind to the `key`.
342   std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);
343 
344   // Returns number of times symbol named as `key` was used.
345   int count(StringRef key) const;
346 
347   // Returns the number of static values of the given `symbol` corresponds to.
348   // A static value is an operand/result declared in ODS. Normally a symbol only
349   // represents one static value, but symbols bound to op results can represent
350   // more than one if the op is a multi-result op.
351   int getStaticValueCount(StringRef symbol) const;
352 
353   // Returns a string containing the C++ expression for referencing this
354   // symbol as a value (if this symbol represents one static value) or a value
355   // range (if this symbol represents multiple static values). `fmt` is used to
356   // format each value. `separator` is used to separate values if `symbol`
357   // represents a value range.
358   std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
359                                   const char *separator = ", ") const;
360 
361   // Returns a string containing the C++ expression for referencing this
362   // symbol as a value range regardless of how many static values this symbol
363   // represents. `fmt` is used to format each value. `separator` is used to
364   // separate values in the range.
365   std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
366                              const char *separator = ", ") const;
367 
368   // Assign alternative unique names to Operands that have equal names.
369   void assignUniqueAlternativeNames();
370 
371   // Splits the given `symbol` into a value pack name and an index. Returns the
372   // value pack name and writes the index to `index` on success. Returns
373   // `symbol` itself if it does not contain an index.
374   //
375   // We can use `name__N` to access the `N`-th value in the value pack bound to
376   // `name`. `name` is typically the results of an multi-result op.
377   static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
378 
379 private:
380   BaseT symbolInfoMap;
381 
382   // Pattern instantiation location. This is intended to be used as parameter
383   // to PrintFatalError() to report errors.
384   ArrayRef<llvm::SMLoc> loc;
385 };
386 
387 // Wrapper class providing helper methods for accessing MLIR Pattern defined
388 // in TableGen. This class should closely reflect what is defined as class
389 // `Pattern` in TableGen. This class contains maps so it is not intended to be
390 // used as values.
391 class Pattern {
392 public:
393   explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
394 
395   // Returns the source pattern to match.
396   DagNode getSourcePattern() const;
397 
398   // Returns the number of result patterns generated by applying this rewrite
399   // rule.
400   int getNumResultPatterns() const;
401 
402   // Returns the DAG tree root node of the `index`-th result pattern.
403   DagNode getResultPattern(unsigned index) const;
404 
405   // Collects all symbols bound in the source pattern into `infoMap`.
406   void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
407 
408   // Collects all symbols bound in result patterns into `infoMap`.
409   void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
410 
411   // Returns the op that the root node of the source pattern matches.
412   const Operator &getSourceRootOp();
413 
414   // Returns the operator wrapper object corresponding to the given `node`'s DAG
415   // operator.
416   Operator &getDialectOp(DagNode node);
417 
418   // Returns the constraints.
419   std::vector<AppliedConstraint> getConstraints() const;
420 
421   // Returns the benefit score of the pattern.
422   int getBenefit() const;
423 
424   using IdentifierLine = std::pair<StringRef, unsigned>;
425 
426   // Returns the file location of the pattern (buffer identifier + line number
427   // pair).
428   std::vector<IdentifierLine> getLocation() const;
429 
430 private:
431   // Helper function to verify variabld binding.
432   void verifyBind(bool result, StringRef symbolName);
433 
434   // Recursively collects all bound symbols inside the DAG tree rooted
435   // at `tree` and updates the given `infoMap`.
436   void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
437                            bool isSrcPattern);
438 
439   // The TableGen definition of this pattern.
440   const llvm::Record &def;
441 
442   // All operators.
443   // TODO: we need a proper context manager, like MLIRContext, for managing the
444   // lifetime of shared entities.
445   RecordOperatorMap *recordOpMap;
446 };
447 
448 } // end namespace tblgen
449 } // end namespace mlir
450 
451 #endif // MLIR_TABLEGEN_PATTERN_H_
452