1 //===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // An affine expression is an affine combination of dimension identifiers and
10 // symbols, including ceildiv/floordiv/mod by a constant integer.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef MLIR_IR_AFFINE_EXPR_H
15 #define MLIR_IR_AFFINE_EXPR_H
16
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/DenseMapInfo.h"
19 #include "llvm/Support/Casting.h"
20 #include <type_traits>
21
22 namespace mlir {
23
// Classes referenced by this header only through pointers or by-value
// interfaces; their full definitions live elsewhere.
class AffineMap;
class IntegerSet;
class MLIRContext;

namespace detail {

// Storage types backing the AffineExpr value classes. They are only
// forward-declared here; their definitions are not part of this header.
struct AffineExprStorage;
struct AffineBinaryOpExprStorage;
struct AffineConstantExprStorage;
struct AffineDimExprStorage;
struct AffineSymbolExprStorage;

} // namespace detail
37
/// Discriminator identifying the concrete kind of an AffineExpr.
///
/// All binary-operation kinds are listed first, so that a kind is a binary
/// operation exactly when `kind <= LAST_AFFINE_BINARY_OP`. The enumerator
/// order is therefore significant and must not be rearranged.
enum class AffineExprKind {
  /// Binary addition.
  Add,
  /// Binary multiplication; the RHS is always a constant or a symbolic
  /// expression.
  Mul,
  /// Modulo; the RHS is always a constant or a symbolic expression with a
  /// positive value.
  Mod,
  /// Floor division; the RHS is always a constant or a symbolic expression.
  FloorDiv,
  /// Ceiling division; the RHS is always a constant or a symbolic expression.
  CeilDiv,

  /// Marker aliasing the last binary-operation kind. Every binary-op kind
  /// compares less than or equal to this value.
  LAST_AFFINE_BINARY_OP = CeilDiv,

  /// Integer constant.
  Constant,
  /// Dimensional identifier.
  DimId,
  /// Symbolic identifier.
  SymbolId,
};
61
62 /// Base type for affine expression.
63 /// AffineExpr's are immutable value types with intuitive operators to
64 /// operate on chainable, lightweight compositions.
65 /// An AffineExpr is an interface to the underlying storage type pointer.
66 class AffineExpr {
67 public:
68 using ImplType = detail::AffineExprStorage;
69
AffineExpr()70 constexpr AffineExpr() : expr(nullptr) {}
AffineExpr(const ImplType * expr)71 /* implicit */ AffineExpr(const ImplType *expr)
72 : expr(const_cast<ImplType *>(expr)) {}
73
74 bool operator==(AffineExpr other) const { return expr == other.expr; }
75 bool operator!=(AffineExpr other) const { return !(*this == other); }
76 bool operator==(int64_t v) const;
77 bool operator!=(int64_t v) const { return !(*this == v); }
78 explicit operator bool() const { return expr; }
79
80 bool operator!() const { return expr == nullptr; }
81
82 template <typename U> bool isa() const;
83 template <typename U> U dyn_cast() const;
84 template <typename U> U dyn_cast_or_null() const;
85 template <typename U> U cast() const;
86
87 MLIRContext *getContext() const;
88
89 /// Return the classification for this type.
90 AffineExprKind getKind() const;
91
92 void print(raw_ostream &os) const;
93 void dump() const;
94
95 /// Returns true if this expression is made out of only symbols and
96 /// constants, i.e., it does not involve dimensional identifiers.
97 bool isSymbolicOrConstant() const;
98
99 /// Returns true if this is a pure affine expression, i.e., multiplication,
100 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
101 bool isPureAffine() const;
102
103 /// Returns the greatest known integral divisor of this affine expression. The
104 /// result is always positive.
105 int64_t getLargestKnownDivisor() const;
106
107 /// Return true if the affine expression is a multiple of 'factor'.
108 bool isMultipleOf(int64_t factor) const;
109
110 /// Return true if the affine expression involves AffineDimExpr `position`.
111 bool isFunctionOfDim(unsigned position) const;
112
113 /// Walk all of the AffineExpr's in this expression in postorder.
114 void walk(std::function<void(AffineExpr)> callback) const;
115
116 /// This method substitutes any uses of dimensions and symbols (e.g.
117 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
118 /// This is a dense replacement method: a replacement must be specified for
119 /// every single dim and symbol.
120 AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
121 ArrayRef<AffineExpr> symReplacements) const;
122
123 /// Sparse replace method. Replace `expr` by `replacement` and return the
124 /// modified expression tree.
125 AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
126
127 /// Sparse replace method. If `*this` appears in `map` replaces it by
128 /// `map[*this]` and return the modified expression tree. Otherwise traverse
129 /// `*this` and apply replace with `map` on its subexpressions.
130 AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
131
132 /// Replace symbols[0 .. numDims - 1] by
133 /// symbols[shift .. shift + numDims - 1].
134 AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift) const;
135
136 AffineExpr operator+(int64_t v) const;
137 AffineExpr operator+(AffineExpr other) const;
138 AffineExpr operator-() const;
139 AffineExpr operator-(int64_t v) const;
140 AffineExpr operator-(AffineExpr other) const;
141 AffineExpr operator*(int64_t v) const;
142 AffineExpr operator*(AffineExpr other) const;
143 AffineExpr floorDiv(uint64_t v) const;
144 AffineExpr floorDiv(AffineExpr other) const;
145 AffineExpr ceilDiv(uint64_t v) const;
146 AffineExpr ceilDiv(AffineExpr other) const;
147 AffineExpr operator%(uint64_t v) const;
148 AffineExpr operator%(AffineExpr other) const;
149
150 /// Compose with an AffineMap.
151 /// Returns the composition of this AffineExpr with `map`.
152 ///
153 /// Prerequisites:
154 /// `this` and `map` are composable, i.e. that the number of AffineDimExpr of
155 /// `this` is smaller than the number of results of `map`. If a result of a
156 /// map does not have a corresponding AffineDimExpr, that result simply does
157 /// not appear in the produced AffineExpr.
158 ///
159 /// Example:
160 /// expr: `d0 + d2`
161 /// map: `(d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)`
162 /// returned expr: `d0 * 2 + d1 + d2 + s1`
163 AffineExpr compose(AffineMap map) const;
164
165 friend ::llvm::hash_code hash_value(AffineExpr arg);
166
167 /// Methods supporting C API.
getAsOpaquePointer()168 const void *getAsOpaquePointer() const {
169 return static_cast<const void *>(expr);
170 }
getFromOpaquePointer(const void * pointer)171 static AffineExpr getFromOpaquePointer(const void *pointer) {
172 return AffineExpr(
173 reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
174 }
175
176 protected:
177 ImplType *expr;
178 };
179
180 /// Affine binary operation expression. An affine binary operation could be an
181 /// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
182 /// represented through a multiply by -1 and add.) These expressions are always
183 /// constructed in a simplified form. For eg., the LHS and RHS operands can't
184 /// both be constants. There are additional canonicalizing rules depending on
185 /// the op type: see checks in the constructor.
186 class AffineBinaryOpExpr : public AffineExpr {
187 public:
188 using ImplType = detail::AffineBinaryOpExprStorage;
189 /* implicit */ AffineBinaryOpExpr(AffineExpr::ImplType *ptr);
190 AffineExpr getLHS() const;
191 AffineExpr getRHS() const;
192 };
193
194 /// A dimensional identifier appearing in an affine expression.
195 class AffineDimExpr : public AffineExpr {
196 public:
197 using ImplType = detail::AffineDimExprStorage;
198 /* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
199 unsigned getPosition() const;
200 };
201
202 /// A symbolic identifier appearing in an affine expression.
203 class AffineSymbolExpr : public AffineExpr {
204 public:
205 using ImplType = detail::AffineDimExprStorage;
206 /* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
207 unsigned getPosition() const;
208 };
209
210 /// An integer constant appearing in affine expression.
211 class AffineConstantExpr : public AffineExpr {
212 public:
213 using ImplType = detail::AffineConstantExprStorage;
214 /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr = nullptr);
215 int64_t getValue() const;
216 };
217
218 /// Make AffineExpr hashable.
hash_value(AffineExpr arg)219 inline ::llvm::hash_code hash_value(AffineExpr arg) {
220 return ::llvm::hash_value(arg.expr);
221 }
222
223 inline AffineExpr operator+(int64_t val, AffineExpr expr) { return expr + val; }
224 inline AffineExpr operator*(int64_t val, AffineExpr expr) { return expr * val; }
225 inline AffineExpr operator-(int64_t val, AffineExpr expr) {
226 return expr * (-1) + val;
227 }
228
229 /// These free functions allow clients of the API to not use classes in detail.
230 AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
231 AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
232 AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
233 AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
234 AffineExpr rhs);
235
236 /// Constructs an affine expression from a flat ArrayRef. If there are local
237 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
238 /// products expression, 'localExprs' is expected to have the AffineExpr
239 /// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
240 /// format [dims, symbols, locals, constant term].
241 AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
242 unsigned numDims, unsigned numSymbols,
243 ArrayRef<AffineExpr> localExprs,
244 MLIRContext *context);
245
246 raw_ostream &operator<<(raw_ostream &os, AffineExpr expr);
247
isa()248 template <typename U> bool AffineExpr::isa() const {
249 if (std::is_same<U, AffineBinaryOpExpr>::value)
250 return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
251 if (std::is_same<U, AffineDimExpr>::value)
252 return getKind() == AffineExprKind::DimId;
253 if (std::is_same<U, AffineSymbolExpr>::value)
254 return getKind() == AffineExprKind::SymbolId;
255 if (std::is_same<U, AffineConstantExpr>::value)
256 return getKind() == AffineExprKind::Constant;
257 }
dyn_cast()258 template <typename U> U AffineExpr::dyn_cast() const {
259 if (isa<U>())
260 return U(expr);
261 return U(nullptr);
262 }
dyn_cast_or_null()263 template <typename U> U AffineExpr::dyn_cast_or_null() const {
264 return (!*this || !isa<U>()) ? U(nullptr) : U(expr);
265 }
cast()266 template <typename U> U AffineExpr::cast() const {
267 assert(isa<U>());
268 return U(expr);
269 }
270
271 /// Simplify an affine expression by flattening and some amount of
272 /// simple analysis. This has complexity linear in the number of nodes in
273 /// 'expr'. Returns the simplified expression, which is the same as the input
274 /// expression if it can't be simplified.
275 AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
276 unsigned numSymbols);
277
278 namespace detail {
bindDims(MLIRContext * ctx)279 template <int N> void bindDims(MLIRContext *ctx) {}
280
281 template <int N, typename AffineExprTy, typename... AffineExprTy2>
bindDims(MLIRContext * ctx,AffineExprTy & e,AffineExprTy2 &...exprs)282 void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &... exprs) {
283 e = getAffineDimExpr(N, ctx);
284 bindDims<N + 1, AffineExprTy2 &...>(ctx, exprs...);
285 }
286 } // namespace detail
287
288 /// Bind a list of AffineExpr references to DimExpr at positions:
289 /// [0 .. sizeof...(exprs)]
290 template <typename... AffineExprTy>
bindDims(MLIRContext * ctx,AffineExprTy &...exprs)291 void bindDims(MLIRContext *ctx, AffineExprTy &... exprs) {
292 detail::bindDims<0>(ctx, exprs...);
293 }
294
295 } // namespace mlir
296
297 namespace llvm {
298
299 // AffineExpr hash just like pointers
300 template <> struct DenseMapInfo<mlir::AffineExpr> {
301 static mlir::AffineExpr getEmptyKey() {
302 auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
303 return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
304 }
305 static mlir::AffineExpr getTombstoneKey() {
306 auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
307 return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
308 }
309 static unsigned getHashValue(mlir::AffineExpr val) {
310 return mlir::hash_value(val);
311 }
312 static bool isEqual(mlir::AffineExpr LHS, mlir::AffineExpr RHS) {
313 return LHS == RHS;
314 }
315 };
316
317 } // namespace llvm
318
319 #endif // MLIR_IR_AFFINE_EXPR_H
320