1 //===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the AffineExpr visitor class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_AFFINE_EXPR_VISITOR_H
14 #define MLIR_IR_AFFINE_EXPR_VISITOR_H
15 
16 #include "mlir/IR/AffineExpr.h"
17 
18 namespace mlir {
19 
20 /// Base class for AffineExpr visitors/walkers.
21 ///
22 /// AffineExpr visitors are used when you want to perform different actions
23 /// for different kinds of AffineExprs without having to use lots of casts
24 /// and a big switch instruction.
25 ///
26 /// To define your own visitor, inherit from this class, specifying your
27 /// new type for the 'SubClass' template parameter, and "override" visitXXX
28 /// functions in your class. This class is defined in terms of statically
29 /// resolved overloading, not virtual functions.
30 ///
31 /// For example, here is a visitor that counts the number of for AffineDimExprs
32 /// in an AffineExpr.
33 ///
34 ///  /// Declare the class.  Note that we derive from AffineExprVisitor
35 ///  /// instantiated with our new subclasses_ type.
36 ///
37 ///  struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
38 ///    unsigned numDimExprs;
39 ///    DimExprCounter() : numDimExprs(0) {}
40 ///    void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; }
41 ///  };
42 ///
43 ///  And this class would be used like this:
44 ///    DimExprCounter dec;
45 ///    dec.visit(affineExpr);
46 ///    numDimExprs = dec.numDimExprs;
47 ///
48 /// AffineExprVisitor provides visit methods for the following binary affine
49 /// op expressions:
50 /// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr,
51 /// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr,
52 /// AffineBinaryCeilDivOpExpr. Note that default implementations of these
53 /// methods will call the general AffineBinaryOpExpr method.
54 ///
55 /// In addition, visit methods are provided for the following affine
56 //  expressions: AffineConstantExpr, AffineDimExpr, and
57 //  AffineSymbolExpr.
58 ///
59 /// Note that if you don't implement visitXXX for some affine expression type,
60 /// the visitXXX method for Instruction superclass will be invoked.
61 ///
62 /// Note that this class is specifically designed as a template to avoid
63 /// virtual function call overhead. Defining and using a AffineExprVisitor is
64 /// just as efficient as having your own switch instruction over the instruction
65 /// opcode.
66 
67 template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
68   //===--------------------------------------------------------------------===//
69   // Interface code - This is the public interface of the AffineExprVisitor
70   // that you use to visit affine expressions...
71 public:
72   // Function to walk an AffineExpr (in post order).
walkPostOrder(AffineExpr expr)73   RetTy walkPostOrder(AffineExpr expr) {
74     static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
75                   "Must instantiate with a derived type of AffineExprVisitor");
76     switch (expr.getKind()) {
77     case AffineExprKind::Add: {
78       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
79       walkOperandsPostOrder(binOpExpr);
80       return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
81     }
82     case AffineExprKind::Mul: {
83       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
84       walkOperandsPostOrder(binOpExpr);
85       return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
86     }
87     case AffineExprKind::Mod: {
88       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
89       walkOperandsPostOrder(binOpExpr);
90       return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
91     }
92     case AffineExprKind::FloorDiv: {
93       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
94       walkOperandsPostOrder(binOpExpr);
95       return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
96     }
97     case AffineExprKind::CeilDiv: {
98       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
99       walkOperandsPostOrder(binOpExpr);
100       return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
101     }
102     case AffineExprKind::Constant:
103       return static_cast<SubClass *>(this)->visitConstantExpr(
104           expr.cast<AffineConstantExpr>());
105     case AffineExprKind::DimId:
106       return static_cast<SubClass *>(this)->visitDimExpr(
107           expr.cast<AffineDimExpr>());
108     case AffineExprKind::SymbolId:
109       return static_cast<SubClass *>(this)->visitSymbolExpr(
110           expr.cast<AffineSymbolExpr>());
111     }
112   }
113 
114   // Function to visit an AffineExpr.
visit(AffineExpr expr)115   RetTy visit(AffineExpr expr) {
116     static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
117                   "Must instantiate with a derived type of AffineExprVisitor");
118     switch (expr.getKind()) {
119     case AffineExprKind::Add: {
120       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
121       return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
122     }
123     case AffineExprKind::Mul: {
124       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
125       return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
126     }
127     case AffineExprKind::Mod: {
128       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
129       return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
130     }
131     case AffineExprKind::FloorDiv: {
132       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
133       return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
134     }
135     case AffineExprKind::CeilDiv: {
136       auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
137       return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
138     }
139     case AffineExprKind::Constant:
140       return static_cast<SubClass *>(this)->visitConstantExpr(
141           expr.cast<AffineConstantExpr>());
142     case AffineExprKind::DimId:
143       return static_cast<SubClass *>(this)->visitDimExpr(
144           expr.cast<AffineDimExpr>());
145     case AffineExprKind::SymbolId:
146       return static_cast<SubClass *>(this)->visitSymbolExpr(
147           expr.cast<AffineSymbolExpr>());
148     }
149     llvm_unreachable("Unknown AffineExpr");
150   }
151 
152   //===--------------------------------------------------------------------===//
153   // Visitation functions... these functions provide default fallbacks in case
154   // the user does not specify what to do for a particular instruction type.
155   // The default behavior is to generalize the instruction type to its subtype
156   // and try visiting the subtype.  All of this should be inlined perfectly,
157   // because there are no virtual functions to get in the way.
158   //
159 
160   // Default visit methods. Note that the default op-specific binary op visit
161   // methods call the general visitAffineBinaryOpExpr visit method.
visitAffineBinaryOpExpr(AffineBinaryOpExpr expr)162   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {}
visitAddExpr(AffineBinaryOpExpr expr)163   void visitAddExpr(AffineBinaryOpExpr expr) {
164     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
165   }
visitMulExpr(AffineBinaryOpExpr expr)166   void visitMulExpr(AffineBinaryOpExpr expr) {
167     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
168   }
visitModExpr(AffineBinaryOpExpr expr)169   void visitModExpr(AffineBinaryOpExpr expr) {
170     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
171   }
visitFloorDivExpr(AffineBinaryOpExpr expr)172   void visitFloorDivExpr(AffineBinaryOpExpr expr) {
173     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
174   }
visitCeilDivExpr(AffineBinaryOpExpr expr)175   void visitCeilDivExpr(AffineBinaryOpExpr expr) {
176     static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
177   }
visitConstantExpr(AffineConstantExpr expr)178   void visitConstantExpr(AffineConstantExpr expr) {}
visitDimExpr(AffineDimExpr expr)179   void visitDimExpr(AffineDimExpr expr) {}
visitSymbolExpr(AffineSymbolExpr expr)180   void visitSymbolExpr(AffineSymbolExpr expr) {}
181 
182 private:
183   // Walk the operands - each operand is itself walked in post order.
walkOperandsPostOrder(AffineBinaryOpExpr expr)184   void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
185     walkPostOrder(expr.getLHS());
186     walkPostOrder(expr.getRHS());
187   }
188 };
189 
190 // This class is used to flatten a pure affine expression (AffineExpr,
191 // which is in a tree form) into a sum of products (w.r.t constants) when
192 // possible, and in that process simplifying the expression. For a modulo,
193 // floordiv, or a ceildiv expression, an additional identifier, called a local
194 // identifier, is introduced to rewrite the expression as a sum of product
195 // affine expression. Each local identifier is always and by construction a
196 // floordiv of a pure add/mul affine function of dimensional, symbolic, and
197 // other local identifiers, in a non-mutually recursive way. Hence, every local
198 // identifier can ultimately always be recovered as an affine function of
199 // dimensional and symbolic identifiers (involving floordiv's); note however
200 // that by AffineExpr construction, some floordiv combinations are converted to
201 // mod's. The result of the flattening is a flattened expression and a set of
202 // constraints involving just the local variables.
203 //
204 // d2 + (d0 + d1) floordiv 4  is flattened to d2 + q where 'q' is the local
205 // variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
206 //
207 // The simplification performed includes the accumulation of contributions for
208 // each dimensional and symbolic identifier together, the simplification of
209 // floordiv/ceildiv/mod expressions and other simplifications that in turn
210 // happen as a result. A simplification that this flattening naturally performs
211 // is of simplifying the numerator and denominator of floordiv/ceildiv, and
212 // folding a modulo expression to a zero, if possible. Three examples are below:
213 //
214 // (d0 + 3 * d1) + d0) - 2 * d1) - d0    simplified to     d0 + d1
215 // (d0 - d0 mod 4 + 4) mod 4             simplified to     0
216 // (3*d0 + 2*d1 + d0) floordiv 2 + d1    simplified to     2*d0 + 2*d1
217 //
218 // The way the flattening works for the second example is as follows: d0 % 4 is
219 // replaced by d0 - 4*q with q being introduced: the expression then simplifies
220 // to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
221 // zero. Note that an affine expression may not always be expressible purely as
222 // a sum of products involving just the original dimensional and symbolic
223 // identifiers due to the presence of modulo/floordiv/ceildiv expressions that
224 // may not be eliminated after simplification; in such cases, the final
225 // expression can be reconstructed by replacing the local identifiers with their
226 // corresponding explicit form stored in 'localExprs' (note that each of the
227 // explicit forms itself would have been simplified).
228 //
229 // The expression walk method here performs a linear time post order walk that
230 // performs the above simplifications through visit methods, with partial
231 // results being stored in 'operandExprStack'. When a parent expr is visited,
232 // the flattened expressions corresponding to its two operands would already be
233 // on the stack - the parent expression looks at the two flattened expressions
234 // and combines the two. It pops off the operand expressions and pushes the
235 // combined result (although this is done in-place on its LHS operand expr).
236 // When the walk is completed, the flattened form of the top-level expression
237 // would be left on the stack.
238 //
239 // A flattener can be repeatedly used for multiple affine expressions that bind
240 // to the same operands, for example, for all result expressions of an
241 // AffineMap or AffineValueMap. In such cases, using it for multiple expressions
242 // is more efficient than creating a new flattener for each expression since
243 // common identical div and mod expressions appearing across different
244 // expressions are mapped to the same local identifier (same column position in
245 // 'localVarCst').
246 class SimpleAffineExprFlattener
247     : public AffineExprVisitor<SimpleAffineExprFlattener> {
248 public:
249   // Flattend expression layout: [dims, symbols, locals, constant]
250   // Stack that holds the LHS and RHS operands while visiting a binary op expr.
251   // In future, consider adding a prepass to determine how big the SmallVector's
252   // will be, and linearize this to std::vector<int64_t> to prevent
253   // SmallVector moves on re-allocation.
254   std::vector<SmallVector<int64_t, 8>> operandExprStack;
255 
256   unsigned numDims;
257   unsigned numSymbols;
258 
259   // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
260   unsigned numLocals;
261 
262   // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
263   // which new identifiers were introduced; if the latter do not get canceled
264   // out, these expressions can be readily used to reconstruct the AffineExpr
265   // (tree) form. Note that these expressions themselves would have been
266   // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4
267   // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
268   // ceildiv 2 would be the local expression stored for q.
269   SmallVector<AffineExpr, 4> localExprs;
270 
271   SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols);
272 
273   virtual ~SimpleAffineExprFlattener() = default;
274 
275   // Visitor method overrides.
276   void visitMulExpr(AffineBinaryOpExpr expr);
277   void visitAddExpr(AffineBinaryOpExpr expr);
278   void visitDimExpr(AffineDimExpr expr);
279   void visitSymbolExpr(AffineSymbolExpr expr);
280   void visitConstantExpr(AffineConstantExpr expr);
281   void visitCeilDivExpr(AffineBinaryOpExpr expr);
282   void visitFloorDivExpr(AffineBinaryOpExpr expr);
283 
284   //
285   // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
286   //
287   // A mod expression "expr mod c" is thus flattened by introducing a new local
288   // variable q (= expr floordiv c), such that expr mod c is replaced with
289   // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
290   void visitModExpr(AffineBinaryOpExpr expr);
291 
292 protected:
293   // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
294   // The local identifier added is always a floordiv of a pure add/mul affine
295   // function of other identifiers, coefficients of which are specified in
296   // dividend and with respect to a positive constant divisor. localExpr is the
297   // simplified tree expression (AffineExpr) corresponding to the quantifier.
298   virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
299                                   AffineExpr localExpr);
300 
301 private:
302   // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
303   // A floordiv is thus flattened by introducing a new local variable q, and
304   // replacing that expression with 'q' while adding the constraints
305   // c * q <= expr <= c * q + c - 1 to localVarCst (done by
306   // FlatAffineConstraints::addLocalFloorDiv).
307   //
308   // A ceildiv is similarly flattened:
309   // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
310   void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
311 
312   int findLocalId(AffineExpr localExpr);
313 
getNumCols()314   inline unsigned getNumCols() const {
315     return numDims + numSymbols + numLocals + 1;
316   }
getConstantIndex()317   inline unsigned getConstantIndex() const { return getNumCols() - 1; }
getLocalVarStartIndex()318   inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
getSymbolStartIndex()319   inline unsigned getSymbolStartIndex() const { return numDims; }
getDimStartIndex()320   inline unsigned getDimStartIndex() const { return 0; }
321 };
322 
323 } // end namespace mlir
324 
325 #endif // MLIR_IR_AFFINE_EXPR_VISITOR_H
326