1 //===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
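// An affine expression is an affine combination of dimension identifiers and
// symbols, including ceildiv/floordiv/mod by a constant integer. Mul, mod,
// floordiv and ceildiv may also take a symbolic right-hand side (see
// AffineExprKind); such expressions are not "pure affine".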
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_IR_AFFINE_EXPR_H
15 #define MLIR_IR_AFFINE_EXPR_H
16 
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/DenseMapInfo.h"
19 #include "llvm/Support/Casting.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 
// Forward declarations so this header does not need to include the
// definitions of these classes.
class MLIRContext;
class AffineMap;
class IntegerSet;

namespace detail {

// Opaque storage types that back the AffineExpr value classes declared below
// (each class exposes its storage type as `ImplType`).
// NOTE(review): AffineSymbolExprStorage is not referenced anywhere in this
// header; AffineSymbolExpr uses AffineDimExprStorage as its ImplType.
struct AffineExprStorage;
struct AffineBinaryOpExprStorage;
struct AffineDimExprStorage;
struct AffineSymbolExprStorage;
struct AffineConstantExprStorage;

} // namespace detail
37 
/// Discriminator for the concrete kind of an AffineExpr.
///
/// NOTE: the order of the enumerators is significant. All binary-op kinds must
/// stay contiguous and come first, ending at LAST_AFFINE_BINARY_OP, because
/// AffineExpr::isa<AffineBinaryOpExpr>() is implemented as
/// `getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP`.
enum class AffineExprKind {
  /// Sum of two affine expressions.
  Add,
  /// RHS of mul is always a constant or a symbolic expression.
  Mul,
  /// RHS of mod is always a constant or a symbolic expression with a positive
  /// value.
  Mod,
  /// RHS of floordiv is always a constant or a symbolic expression.
  FloorDiv,
  /// RHS of ceildiv is always a constant or a symbolic expression.
  CeilDiv,

  /// This is a marker for the last affine binary op. The range of binary
  /// op's is expected to be this element and earlier.
  LAST_AFFINE_BINARY_OP = CeilDiv,

  /// Constant integer.
  Constant,
  /// Dimensional identifier.
  DimId,
  /// Symbolic identifier.
  SymbolId,
};
61 
62 /// Base type for affine expression.
63 /// AffineExpr's are immutable value types with intuitive operators to
64 /// operate on chainable, lightweight compositions.
65 /// An AffineExpr is an interface to the underlying storage type pointer.
66 class AffineExpr {
67 public:
68   using ImplType = detail::AffineExprStorage;
69 
AffineExpr()70   constexpr AffineExpr() : expr(nullptr) {}
AffineExpr(const ImplType * expr)71   /* implicit */ AffineExpr(const ImplType *expr)
72       : expr(const_cast<ImplType *>(expr)) {}
73 
74   bool operator==(AffineExpr other) const { return expr == other.expr; }
75   bool operator!=(AffineExpr other) const { return !(*this == other); }
76   bool operator==(int64_t v) const;
77   bool operator!=(int64_t v) const { return !(*this == v); }
78   explicit operator bool() const { return expr; }
79 
80   bool operator!() const { return expr == nullptr; }
81 
82   template <typename U> bool isa() const;
83   template <typename U> U dyn_cast() const;
84   template <typename U> U dyn_cast_or_null() const;
85   template <typename U> U cast() const;
86 
87   MLIRContext *getContext() const;
88 
89   /// Return the classification for this type.
90   AffineExprKind getKind() const;
91 
92   void print(raw_ostream &os) const;
93   void dump() const;
94 
95   /// Returns true if this expression is made out of only symbols and
96   /// constants, i.e., it does not involve dimensional identifiers.
97   bool isSymbolicOrConstant() const;
98 
99   /// Returns true if this is a pure affine expression, i.e., multiplication,
100   /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
101   bool isPureAffine() const;
102 
103   /// Returns the greatest known integral divisor of this affine expression. The
104   /// result is always positive.
105   int64_t getLargestKnownDivisor() const;
106 
107   /// Return true if the affine expression is a multiple of 'factor'.
108   bool isMultipleOf(int64_t factor) const;
109 
110   /// Return true if the affine expression involves AffineDimExpr `position`.
111   bool isFunctionOfDim(unsigned position) const;
112 
113   /// Walk all of the AffineExpr's in this expression in postorder.
114   void walk(std::function<void(AffineExpr)> callback) const;
115 
116   /// This method substitutes any uses of dimensions and symbols (e.g.
117   /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
118   /// This is a dense replacement method: a replacement must be specified for
119   /// every single dim and symbol.
120   AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
121                                    ArrayRef<AffineExpr> symReplacements) const;
122 
123   /// Sparse replace method. Replace `expr` by `replacement` and return the
124   /// modified expression tree.
125   AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
126 
127   /// Sparse replace method. If `*this` appears in `map` replaces it by
128   /// `map[*this]` and return the modified expression tree. Otherwise traverse
129   /// `*this` and apply replace with `map` on its subexpressions.
130   AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
131 
132   /// Replace symbols[0 .. numDims - 1] by
133   /// symbols[shift .. shift + numDims - 1].
134   AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift) const;
135 
136   AffineExpr operator+(int64_t v) const;
137   AffineExpr operator+(AffineExpr other) const;
138   AffineExpr operator-() const;
139   AffineExpr operator-(int64_t v) const;
140   AffineExpr operator-(AffineExpr other) const;
141   AffineExpr operator*(int64_t v) const;
142   AffineExpr operator*(AffineExpr other) const;
143   AffineExpr floorDiv(uint64_t v) const;
144   AffineExpr floorDiv(AffineExpr other) const;
145   AffineExpr ceilDiv(uint64_t v) const;
146   AffineExpr ceilDiv(AffineExpr other) const;
147   AffineExpr operator%(uint64_t v) const;
148   AffineExpr operator%(AffineExpr other) const;
149 
150   /// Compose with an AffineMap.
151   /// Returns the composition of this AffineExpr with `map`.
152   ///
153   /// Prerequisites:
154   /// `this` and `map` are composable, i.e. that the number of AffineDimExpr of
155   /// `this` is smaller than the number of results of `map`. If a result of a
156   /// map does not have a corresponding AffineDimExpr, that result simply does
157   /// not appear in the produced AffineExpr.
158   ///
159   /// Example:
160   ///   expr: `d0 + d2`
161   ///   map:  `(d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)`
162   ///   returned expr: `d0 * 2 + d1 + d2 + s1`
163   AffineExpr compose(AffineMap map) const;
164 
165   friend ::llvm::hash_code hash_value(AffineExpr arg);
166 
167   /// Methods supporting C API.
getAsOpaquePointer()168   const void *getAsOpaquePointer() const {
169     return static_cast<const void *>(expr);
170   }
getFromOpaquePointer(const void * pointer)171   static AffineExpr getFromOpaquePointer(const void *pointer) {
172     return AffineExpr(
173         reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
174   }
175 
176 protected:
177   ImplType *expr;
178 };
179 
180 /// Affine binary operation expression. An affine binary operation could be an
181 /// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
182 /// represented through a multiply by -1 and add.) These expressions are always
183 /// constructed in a simplified form. For eg., the LHS and RHS operands can't
184 /// both be constants. There are additional canonicalizing rules depending on
185 /// the op type: see checks in the constructor.
186 class AffineBinaryOpExpr : public AffineExpr {
187 public:
188   using ImplType = detail::AffineBinaryOpExprStorage;
189   /* implicit */ AffineBinaryOpExpr(AffineExpr::ImplType *ptr);
190   AffineExpr getLHS() const;
191   AffineExpr getRHS() const;
192 };
193 
194 /// A dimensional identifier appearing in an affine expression.
195 class AffineDimExpr : public AffineExpr {
196 public:
197   using ImplType = detail::AffineDimExprStorage;
198   /* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
199   unsigned getPosition() const;
200 };
201 
202 /// A symbolic identifier appearing in an affine expression.
203 class AffineSymbolExpr : public AffineExpr {
204 public:
205   using ImplType = detail::AffineDimExprStorage;
206   /* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
207   unsigned getPosition() const;
208 };
209 
210 /// An integer constant appearing in affine expression.
211 class AffineConstantExpr : public AffineExpr {
212 public:
213   using ImplType = detail::AffineConstantExprStorage;
214   /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr = nullptr);
215   int64_t getValue() const;
216 };
217 
218 /// Make AffineExpr hashable.
hash_value(AffineExpr arg)219 inline ::llvm::hash_code hash_value(AffineExpr arg) {
220   return ::llvm::hash_value(arg.expr);
221 }
222 
223 inline AffineExpr operator+(int64_t val, AffineExpr expr) { return expr + val; }
224 inline AffineExpr operator*(int64_t val, AffineExpr expr) { return expr * val; }
225 inline AffineExpr operator-(int64_t val, AffineExpr expr) {
226   return expr * (-1) + val;
227 }
228 
229 /// These free functions allow clients of the API to not use classes in detail.
230 AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
231 AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
232 AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
233 AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
234                                  AffineExpr rhs);
235 
236 /// Constructs an affine expression from a flat ArrayRef. If there are local
237 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
238 /// products expression, 'localExprs' is expected to have the AffineExpr
239 /// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
240 /// format [dims, symbols, locals, constant term].
241 AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
242                                      unsigned numDims, unsigned numSymbols,
243                                      ArrayRef<AffineExpr> localExprs,
244                                      MLIRContext *context);
245 
246 raw_ostream &operator<<(raw_ostream &os, AffineExpr expr);
247 
isa()248 template <typename U> bool AffineExpr::isa() const {
249   if (std::is_same<U, AffineBinaryOpExpr>::value)
250     return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
251   if (std::is_same<U, AffineDimExpr>::value)
252     return getKind() == AffineExprKind::DimId;
253   if (std::is_same<U, AffineSymbolExpr>::value)
254     return getKind() == AffineExprKind::SymbolId;
255   if (std::is_same<U, AffineConstantExpr>::value)
256     return getKind() == AffineExprKind::Constant;
257 }
dyn_cast()258 template <typename U> U AffineExpr::dyn_cast() const {
259   if (isa<U>())
260     return U(expr);
261   return U(nullptr);
262 }
dyn_cast_or_null()263 template <typename U> U AffineExpr::dyn_cast_or_null() const {
264   return (!*this || !isa<U>()) ? U(nullptr) : U(expr);
265 }
cast()266 template <typename U> U AffineExpr::cast() const {
267   assert(isa<U>());
268   return U(expr);
269 }
270 
271 /// Simplify an affine expression by flattening and some amount of
272 /// simple analysis. This has complexity linear in the number of nodes in
273 /// 'expr'. Returns the simplified expression, which is the same as the input
274 ///  expression if it can't be simplified.
275 AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
276                               unsigned numSymbols);
277 
278 namespace detail {
bindDims(MLIRContext * ctx)279 template <int N> void bindDims(MLIRContext *ctx) {}
280 
281 template <int N, typename AffineExprTy, typename... AffineExprTy2>
bindDims(MLIRContext * ctx,AffineExprTy & e,AffineExprTy2 &...exprs)282 void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &... exprs) {
283   e = getAffineDimExpr(N, ctx);
284   bindDims<N + 1, AffineExprTy2 &...>(ctx, exprs...);
285 }
286 } // namespace detail
287 
288 /// Bind a list of AffineExpr references to DimExpr at positions:
289 ///   [0 .. sizeof...(exprs)]
290 template <typename... AffineExprTy>
bindDims(MLIRContext * ctx,AffineExprTy &...exprs)291 void bindDims(MLIRContext *ctx, AffineExprTy &... exprs) {
292   detail::bindDims<0>(ctx, exprs...);
293 }
294 
295 } // namespace mlir
296 
297 namespace llvm {
298 
299 // AffineExpr hash just like pointers
300 template <> struct DenseMapInfo<mlir::AffineExpr> {
301   static mlir::AffineExpr getEmptyKey() {
302     auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
303     return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
304   }
305   static mlir::AffineExpr getTombstoneKey() {
306     auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
307     return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
308   }
309   static unsigned getHashValue(mlir::AffineExpr val) {
310     return mlir::hash_value(val);
311   }
312   static bool isEqual(mlir::AffineExpr LHS, mlir::AffineExpr RHS) {
313     return LHS == RHS;
314   }
315 };
316 
317 } // namespace llvm
318 
319 #endif // MLIR_IR_AFFINE_EXPR_H
320