1 //===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions for "predicates" used when converting PDL into
10 // a matcher tree. Predicates are composed of three different parts:
11 //
12 //  * Positions
13 //    - A position refers to a specific location on the input DAG, i.e. an
14 //      existing MLIR entity being matched. These can be attributes, operands,
15 //      operations, results, and types. Each position also defines a relation to
16 //      its parent. For example, the operand `[0] -> 1` has a parent operation
17 //      position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation
18 //      position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge
19 //      `[0] -> 1` (i.e. it is the defining op of operand 1). The only position
20 //      without a parent is `[0]`, which refers to the root operation.
21 //  * Questions
22 //    - A question refers to a query on a specific positional value. For
23 //    example, an operation name question checks the name of an operation
24 //    position.
25 //  * Answers
//    - An answer is the expected result of a question. For example, when
//      matching an operation with the name "foo.op", the question would be an
//      operation name question with an expected answer of "foo.op".
29 //
30 //===----------------------------------------------------------------------===//
31 
32 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
33 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
34 
35 #include "mlir/IR/MLIRContext.h"
36 #include "mlir/IR/OperationSupport.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/Types.h"
39 
40 namespace mlir {
41 namespace pdl_to_pdl_interp {
namespace Predicates {
/// The kinds of predicates, grouped into positions, questions, and answers.
/// The order inside each group is documented below; keep it when adding new
/// kinds.
enum Kind : unsigned {
  // --- Positions, ordered by decreasing priority. ---
  OperationPos = 0,
  OperandPos,
  AttributePos,
  ResultPos,
  TypePos,

  // --- Questions, ordered by dependency and decreasing priority. ---
  IsNotNullQuestion,
  OperationNameQuestion,
  TypeQuestion,
  AttributeQuestion,
  OperandCountQuestion,
  ResultCountQuestion,
  EqualToQuestion,
  ConstraintQuestion,

  // --- Answers. ---
  AttributeAnswer,
  TrueAnswer,
  OperationNameAnswer,
  TypeAnswer,
  UnsignedAnswer,
};
} // end namespace Predicates
70 
71 /// Base class for all predicates, used to allow efficient pointer comparison.
72 template <typename ConcreteT, typename BaseT, typename Key,
73           Predicates::Kind Kind>
74 class PredicateBase : public BaseT {
75 public:
76   using KeyTy = Key;
77   using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>;
78 
79   template <typename KeyT>
PredicateBase(KeyT && key)80   explicit PredicateBase(KeyT &&key)
81       : BaseT(Kind), key(std::forward<KeyT>(key)) {}
82 
83   /// Get an instance of this position.
84   template <typename... Args>
get(StorageUniquer & uniquer,Args &&...args)85   static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
86     return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
87   }
88 
89   /// Construct an instance with the given storage allocator.
90   template <typename KeyT>
construct(StorageUniquer::StorageAllocator & alloc,KeyT && key)91   static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
92                               KeyT &&key) {
93     return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key));
94   }
95 
96   /// Utility methods required by the storage allocator.
97   bool operator==(const KeyTy &key) const { return this->key == key; }
classof(const BaseT * pred)98   static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
99 
100   /// Return the key value of this predicate.
getValue()101   const KeyTy &getValue() const { return key; }
102 
103 protected:
104   KeyTy key;
105 };
106 
107 /// Base storage for simple predicates that only unique with the kind.
108 template <typename ConcreteT, typename BaseT, Predicates::Kind Kind>
109 class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT {
110 public:
111   using Base = PredicateBase<ConcreteT, BaseT, void, Kind>;
112 
PredicateBase()113   explicit PredicateBase() : BaseT(Kind) {}
114 
get(StorageUniquer & uniquer)115   static ConcreteT *get(StorageUniquer &uniquer) {
116     return uniquer.get<ConcreteT>();
117   }
classof(const BaseT * pred)118   static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
119 };
120 
121 //===----------------------------------------------------------------------===//
122 // Positions
123 //===----------------------------------------------------------------------===//
124 
125 struct OperationPosition;
126 
127 /// A position describes a value on the input IR on which a predicate may be
128 /// applied, such as an operation or attribute. This enables re-use between
129 /// predicates, and assists generating bytecode and memory management.
130 ///
131 /// Operation positions form the base of other positions, which are formed
132 /// relative to a parent operation, e.g. OperandPosition<[0] -> 1>. Operations
133 /// are indexed by child index: [0, 1, 2] refers to the 3rd child of the 2nd
134 /// child of the root operation.
135 ///
136 /// Positions are linked to their parent position, which describes how to obtain
137 /// a positional value. As a concrete example, getting OperationPosition<[0, 1]>
138 /// would be `root->getOperand(1)->getDefiningOp()`, so its parent is
139 /// OperandPosition<[0] -> 1>, whose parent is OperationPosition<[0]>.
140 class Position : public StorageUniquer::BaseStorage {
141 public:
Position(Predicates::Kind kind)142   explicit Position(Predicates::Kind kind) : kind(kind) {}
143   virtual ~Position();
144 
145   /// Returns the base node position. This is an array of indices.
146   virtual ArrayRef<unsigned> getIndex() const = 0;
147 
148   /// Returns the parent position. The root operation position has no parent.
getParent()149   Position *getParent() const { return parent; }
150 
151   /// Returns the kind of this position.
getKind()152   Predicates::Kind getKind() const { return kind; }
153 
154 protected:
155   /// Link to the parent position.
156   Position *parent = nullptr;
157 
158 private:
159   /// The kind of this position.
160   Predicates::Kind kind;
161 };
162 
163 //===----------------------------------------------------------------------===//
164 // AttributePosition
165 
166 /// A position describing an attribute of an operation.
167 struct AttributePosition
168     : public PredicateBase<AttributePosition, Position,
169                            std::pair<OperationPosition *, Identifier>,
170                            Predicates::AttributePos> {
171   explicit AttributePosition(const KeyTy &key);
172 
173   /// Returns the index of this position.
getIndexAttributePosition174   ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); }
175 
176   /// Returns the attribute name of this position.
getNameAttributePosition177   Identifier getName() const { return key.second; }
178 };
179 
180 //===----------------------------------------------------------------------===//
181 // OperandPosition
182 
183 /// A position describing an operand of an operation.
184 struct OperandPosition
185     : public PredicateBase<OperandPosition, Position,
186                            std::pair<OperationPosition *, unsigned>,
187                            Predicates::OperandPos> {
188   explicit OperandPosition(const KeyTy &key);
189 
190   /// Returns the index of this position.
getIndexOperandPosition191   ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); }
192 
193   /// Returns the operand number of this position.
getOperandNumberOperandPosition194   unsigned getOperandNumber() const { return key.second; }
195 };
196 
197 //===----------------------------------------------------------------------===//
198 // OperationPosition
199 
200 /// An operation position describes an operation node in the IR. Other position
201 /// kinds are formed with respect to an operation position.
202 struct OperationPosition
203     : public PredicateBase<OperationPosition, Position, ArrayRef<unsigned>,
204                            Predicates::OperationPos> {
205   using Base::Base;
206 
207   /// Gets the root position, which is always [0].
getRootOperationPosition208   static OperationPosition *getRoot(StorageUniquer &uniquer) {
209     return get(uniquer, ArrayRef<unsigned>(0));
210   }
211   /// Gets a node position for the given index.
212   static OperationPosition *get(StorageUniquer &uniquer,
213                                 ArrayRef<unsigned> index);
214 
215   /// Constructs an instance with the given storage allocator.
constructOperationPosition216   static OperationPosition *construct(StorageUniquer::StorageAllocator &alloc,
217                                       ArrayRef<unsigned> key) {
218     return Base::construct(alloc, alloc.copyInto(key));
219   }
220 
221   /// Returns the index of this position.
getIndexOperationPosition222   ArrayRef<unsigned> getIndex() const final { return key; }
223 
224   /// Returns if this operation position corresponds to the root.
isRootOperationPosition225   bool isRoot() const { return key.size() == 1 && key[0] == 0; }
226 };
227 
228 //===----------------------------------------------------------------------===//
229 // ResultPosition
230 
231 /// A position describing a result of an operation.
232 struct ResultPosition
233     : public PredicateBase<ResultPosition, Position,
234                            std::pair<OperationPosition *, unsigned>,
235                            Predicates::ResultPos> {
ResultPositionResultPosition236   explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; }
237 
238   /// Returns the index of this position.
getIndexResultPosition239   ArrayRef<unsigned> getIndex() const final { return key.first->getIndex(); }
240 
241   /// Returns the result number of this position.
getResultNumberResultPosition242   unsigned getResultNumber() const { return key.second; }
243 };
244 
245 //===----------------------------------------------------------------------===//
246 // TypePosition
247 
248 /// A position describing the result type of an entity, i.e. an Attribute,
249 /// Operand, Result, etc.
250 struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
251                                            Predicates::TypePos> {
TypePositionTypePosition252   explicit TypePosition(const KeyTy &key) : Base(key) {
253     assert((isa<AttributePosition>(key) || isa<OperandPosition>(key) ||
254             isa<ResultPosition>(key)) &&
255            "expected parent to be an attribute, operand, or result");
256     parent = key;
257   }
258 
259   /// Returns the index of this position.
getIndexTypePosition260   ArrayRef<unsigned> getIndex() const final { return key->getIndex(); }
261 };
262 
263 //===----------------------------------------------------------------------===//
264 // Qualifiers
265 //===----------------------------------------------------------------------===//
266 
267 /// An ordinal predicate consists of a "Question" and a set of acceptable
268 /// "Answers" (later converted to ordinal values). A predicate will query some
269 /// property of a positional value and decide what to do based on the result.
270 ///
271 /// This makes top-level predicate representations ordinal (SwitchOp). Later,
272 /// predicates that end up with only one acceptable answer (including all
273 /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the
274 /// matcher.
275 ///
276 /// For simplicity, both are represented as "qualifiers", with a base kind and
277 /// perhaps additional properties. For example, all OperationName predicates ask
278 /// the same question, but GenericConstraint predicates may ask different ones.
279 class Qualifier : public StorageUniquer::BaseStorage {
280 public:
Qualifier(Predicates::Kind kind)281   explicit Qualifier(Predicates::Kind kind) : kind(kind) {}
282 
283   /// Returns the kind of this qualifier.
getKind()284   Predicates::Kind getKind() const { return kind; }
285 
286 private:
287   /// The kind of this position.
288   Predicates::Kind kind;
289 };
290 
291 //===----------------------------------------------------------------------===//
292 // Answers
293 
294 /// An Answer representing an `Attribute` value.
295 struct AttributeAnswer
296     : public PredicateBase<AttributeAnswer, Qualifier, Attribute,
297                            Predicates::AttributeAnswer> {
298   using Base::Base;
299 };
300 
301 /// An Answer representing an `OperationName` value.
302 struct OperationNameAnswer
303     : public PredicateBase<OperationNameAnswer, Qualifier, OperationName,
304                            Predicates::OperationNameAnswer> {
305   using Base::Base;
306 };
307 
308 /// An Answer representing a boolean `true` value.
309 struct TrueAnswer
310     : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> {
311   using Base::Base;
312 };
313 
314 /// An Answer representing a `Type` value.
315 struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Type,
316                                          Predicates::TypeAnswer> {
317   using Base::Base;
318 };
319 
320 /// An Answer representing an unsigned value.
321 struct UnsignedAnswer
322     : public PredicateBase<UnsignedAnswer, Qualifier, unsigned,
323                            Predicates::UnsignedAnswer> {
324   using Base::Base;
325 };
326 
327 //===----------------------------------------------------------------------===//
328 // Questions
329 
330 /// Compare an `Attribute` to a constant value.
331 struct AttributeQuestion
332     : public PredicateBase<AttributeQuestion, Qualifier, void,
333                            Predicates::AttributeQuestion> {};
334 
335 /// Apply a parameterized constraint to multiple position values.
336 struct ConstraintQuestion
337     : public PredicateBase<
338           ConstraintQuestion, Qualifier,
339           std::tuple<StringRef, ArrayRef<Position *>, Attribute>,
340           Predicates::ConstraintQuestion> {
341   using Base::Base;
342 
343   /// Construct an instance with the given storage allocator.
constructConstraintQuestion344   static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
345                                        KeyTy key) {
346     return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
347                                         alloc.copyInto(std::get<1>(key)),
348                                         std::get<2>(key)});
349   }
350 };
351 
352 /// Compare the equality of two values.
353 struct EqualToQuestion
354     : public PredicateBase<EqualToQuestion, Qualifier, Position *,
355                            Predicates::EqualToQuestion> {
356   using Base::Base;
357 };
358 
359 /// Compare a positional value with null, i.e. check if it exists.
360 struct IsNotNullQuestion
361     : public PredicateBase<IsNotNullQuestion, Qualifier, void,
362                            Predicates::IsNotNullQuestion> {};
363 
364 /// Compare the number of operands of an operation with a known value.
365 struct OperandCountQuestion
366     : public PredicateBase<OperandCountQuestion, Qualifier, void,
367                            Predicates::OperandCountQuestion> {};
368 
369 /// Compare the name of an operation with a known value.
370 struct OperationNameQuestion
371     : public PredicateBase<OperationNameQuestion, Qualifier, void,
372                            Predicates::OperationNameQuestion> {};
373 
374 /// Compare the number of results of an operation with a known value.
375 struct ResultCountQuestion
376     : public PredicateBase<ResultCountQuestion, Qualifier, void,
377                            Predicates::ResultCountQuestion> {};
378 
379 /// Compare the type of an attribute or value with a known type.
380 struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void,
381                                            Predicates::TypeQuestion> {};
382 
383 //===----------------------------------------------------------------------===//
384 // PredicateUniquer
385 //===----------------------------------------------------------------------===//
386 
387 /// This class provides a storage uniquer that is used to allocate predicate
388 /// instances.
389 class PredicateUniquer : public StorageUniquer {
390 public:
PredicateUniquer()391   PredicateUniquer() {
392     // Register the types of Positions with the uniquer.
393     registerParametricStorageType<AttributePosition>();
394     registerParametricStorageType<OperandPosition>();
395     registerParametricStorageType<OperationPosition>();
396     registerParametricStorageType<ResultPosition>();
397     registerParametricStorageType<TypePosition>();
398 
399     // Register the types of Questions with the uniquer.
400     registerParametricStorageType<AttributeAnswer>();
401     registerParametricStorageType<OperationNameAnswer>();
402     registerParametricStorageType<TypeAnswer>();
403     registerParametricStorageType<UnsignedAnswer>();
404     registerSingletonStorageType<TrueAnswer>();
405 
406     // Register the types of Answers with the uniquer.
407     registerParametricStorageType<ConstraintQuestion>();
408     registerParametricStorageType<EqualToQuestion>();
409     registerSingletonStorageType<AttributeQuestion>();
410     registerSingletonStorageType<IsNotNullQuestion>();
411     registerSingletonStorageType<OperandCountQuestion>();
412     registerSingletonStorageType<OperationNameQuestion>();
413     registerSingletonStorageType<ResultCountQuestion>();
414     registerSingletonStorageType<TypeQuestion>();
415   }
416 };
417 
418 //===----------------------------------------------------------------------===//
419 // PredicateBuilder
420 //===----------------------------------------------------------------------===//
421 
422 /// This class provides utilties for constructing predicates.
423 class PredicateBuilder {
424 public:
PredicateBuilder(PredicateUniquer & uniquer,MLIRContext * ctx)425   PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx)
426       : uniquer(uniquer), ctx(ctx) {}
427 
428   //===--------------------------------------------------------------------===//
429   // Positions
430   //===--------------------------------------------------------------------===//
431 
432   /// Returns the root operation position.
getRoot()433   Position *getRoot() { return OperationPosition::getRoot(uniquer); }
434 
435   /// Returns the parent position defining the value held by the given operand.
getParent(OperandPosition * p)436   Position *getParent(OperandPosition *p) {
437     std::vector<unsigned> index = p->getIndex();
438     index.push_back(p->getOperandNumber());
439     return OperationPosition::get(uniquer, index);
440   }
441 
442   /// Returns an attribute position for an attribute of the given operation.
getAttribute(OperationPosition * p,StringRef name)443   Position *getAttribute(OperationPosition *p, StringRef name) {
444     return AttributePosition::get(uniquer, p, Identifier::get(name, ctx));
445   }
446 
447   /// Returns an operand position for an operand of the given operation.
getOperand(OperationPosition * p,unsigned operand)448   Position *getOperand(OperationPosition *p, unsigned operand) {
449     return OperandPosition::get(uniquer, p, operand);
450   }
451 
452   /// Returns a result position for a result of the given operation.
getResult(OperationPosition * p,unsigned result)453   Position *getResult(OperationPosition *p, unsigned result) {
454     return ResultPosition::get(uniquer, p, result);
455   }
456 
457   /// Returns a type position for the given entity.
getType(Position * p)458   Position *getType(Position *p) { return TypePosition::get(uniquer, p); }
459 
460   //===--------------------------------------------------------------------===//
461   // Qualifiers
462   //===--------------------------------------------------------------------===//
463 
464   /// An ordinal predicate consists of a "Question" and a set of acceptable
465   /// "Answers" (later converted to ordinal values). A predicate will query some
466   /// property of a positional value and decide what to do based on the result.
467   using Predicate = std::pair<Qualifier *, Qualifier *>;
468 
469   /// Create a predicate comparing an attribute to a known value.
getAttributeConstraint(Attribute attr)470   Predicate getAttributeConstraint(Attribute attr) {
471     return {AttributeQuestion::get(uniquer),
472             AttributeAnswer::get(uniquer, attr)};
473   }
474 
475   /// Create a predicate comparing two values.
getEqualTo(Position * pos)476   Predicate getEqualTo(Position *pos) {
477     return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)};
478   }
479 
480   /// Create a predicate that applies a generic constraint.
getConstraint(StringRef name,ArrayRef<Position * > pos,Attribute params)481   Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
482                           Attribute params) {
483     return {
484         ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, params)),
485         TrueAnswer::get(uniquer)};
486   }
487 
488   /// Create a predicate comparing a value with null.
getIsNotNull()489   Predicate getIsNotNull() {
490     return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)};
491   }
492 
493   /// Create a predicate comparing the number of operands of an operation to a
494   /// known value.
getOperandCount(unsigned count)495   Predicate getOperandCount(unsigned count) {
496     return {OperandCountQuestion::get(uniquer),
497             UnsignedAnswer::get(uniquer, count)};
498   }
499 
500   /// Create a predicate comparing the name of an operation to a known value.
getOperationName(StringRef name)501   Predicate getOperationName(StringRef name) {
502     return {OperationNameQuestion::get(uniquer),
503             OperationNameAnswer::get(uniquer, OperationName(name, ctx))};
504   }
505 
506   /// Create a predicate comparing the number of results of an operation to a
507   /// known value.
getResultCount(unsigned count)508   Predicate getResultCount(unsigned count) {
509     return {ResultCountQuestion::get(uniquer),
510             UnsignedAnswer::get(uniquer, count)};
511   }
512 
513   /// Create a predicate comparing the type of an attribute or value to a known
514   /// type.
getTypeConstraint(Type type)515   Predicate getTypeConstraint(Type type) {
516     return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)};
517   }
518 
519 private:
520   /// The uniquer used when allocating predicate nodes.
521   PredicateUniquer &uniquer;
522 
523   /// The current MLIR context.
524   MLIRContext *ctx;
525 };
526 
527 } // end namespace pdl_to_pdl_interp
528 } // end namespace mlir
529 
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
531