1 //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file defines the classes used to represent and build scalar expressions. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 15 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 16 17 #include "llvm/ADT/DenseMap.h" 18 #include "llvm/ADT/FoldingSet.h" 19 #include "llvm/ADT/SmallPtrSet.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/ADT/iterator_range.h" 22 #include "llvm/Analysis/ScalarEvolution.h" 23 #include "llvm/IR/Constants.h" 24 #include "llvm/IR/Value.h" 25 #include "llvm/IR/ValueHandle.h" 26 #include "llvm/Support/Casting.h" 27 #include "llvm/Support/ErrorHandling.h" 28 #include <cassert> 29 #include <cstddef> 30 31 namespace llvm { 32 33 class APInt; 34 class Constant; 35 class ConstantRange; 36 class Loop; 37 class Type; 38 39 enum SCEVTypes { 40 // These should be ordered in terms of increasing complexity to make the 41 // folders simpler. 42 scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr, 43 scUDivExpr, scAddRecExpr, scUMaxExpr, scSMaxExpr, 44 scUnknown, scCouldNotCompute 45 }; 46 47 /// This class represents a constant integer value. 48 class SCEVConstant : public SCEV { 49 friend class ScalarEvolution; 50 51 ConstantInt *V; 52 SCEVConstant(const FoldingSetNodeIDRef ID,ConstantInt * v)53 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) : 54 SCEV(ID, scConstant), V(v) {} 55 56 public: getValue()57 ConstantInt *getValue() const { return V; } getAPInt()58 const APInt &getAPInt() const { return getValue()->getValue(); } 59 getType()60 Type *getType() const { return V->getType(); } 61 62 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)63 static bool classof(const SCEV *S) { 64 return S->getSCEVType() == scConstant; 65 } 66 }; 67 68 /// This is the base class for unary cast operator classes. 69 class SCEVCastExpr : public SCEV { 70 protected: 71 const SCEV *Op; 72 Type *Ty; 73 74 SCEVCastExpr(const FoldingSetNodeIDRef ID, 75 unsigned SCEVTy, const SCEV *op, Type *ty); 76 77 public: getOperand()78 const SCEV *getOperand() const { return Op; } getType()79 Type *getType() const { return Ty; } 80 81 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)82 static bool classof(const SCEV *S) { 83 return S->getSCEVType() == scTruncate || 84 S->getSCEVType() == scZeroExtend || 85 S->getSCEVType() == scSignExtend; 86 } 87 }; 88 89 /// This class represents a truncation of an integer value to a 90 /// smaller integer value. 91 class SCEVTruncateExpr : public SCEVCastExpr { 92 friend class ScalarEvolution; 93 94 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, 95 const SCEV *op, Type *ty); 96 97 public: 98 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)99 static bool classof(const SCEV *S) { 100 return S->getSCEVType() == scTruncate; 101 } 102 }; 103 104 /// This class represents a zero extension of a small integer value 105 /// to a larger integer value. 106 class SCEVZeroExtendExpr : public SCEVCastExpr { 107 friend class ScalarEvolution; 108 109 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, 110 const SCEV *op, Type *ty); 111 112 public: 113 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)114 static bool classof(const SCEV *S) { 115 return S->getSCEVType() == scZeroExtend; 116 } 117 }; 118 119 /// This class represents a sign extension of a small integer value 120 /// to a larger integer value. 121 class SCEVSignExtendExpr : public SCEVCastExpr { 122 friend class ScalarEvolution; 123 124 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, 125 const SCEV *op, Type *ty); 126 127 public: 128 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)129 static bool classof(const SCEV *S) { 130 return S->getSCEVType() == scSignExtend; 131 } 132 }; 133 134 /// This node is a base class providing common functionality for 135 /// n'ary operators. 136 class SCEVNAryExpr : public SCEV { 137 protected: 138 // Since SCEVs are immutable, ScalarEvolution allocates operand 139 // arrays with its SCEVAllocator, so this class just needs a simple 140 // pointer rather than a more elaborate vector-like data structure. 141 // This also avoids the need for a non-trivial destructor. 142 const SCEV *const *Operands; 143 size_t NumOperands; 144 SCEVNAryExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)145 SCEVNAryExpr(const FoldingSetNodeIDRef ID, 146 enum SCEVTypes T, const SCEV *const *O, size_t N) 147 : SCEV(ID, T), Operands(O), NumOperands(N) {} 148 149 public: getNumOperands()150 size_t getNumOperands() const { return NumOperands; } 151 getOperand(unsigned i)152 const SCEV *getOperand(unsigned i) const { 153 assert(i < NumOperands && "Operand index out of range!"); 154 return Operands[i]; 155 } 156 157 using op_iterator = const SCEV *const *; 158 using op_range = iterator_range<op_iterator>; 159 op_begin()160 op_iterator op_begin() const { return Operands; } op_end()161 op_iterator op_end() const { return Operands + NumOperands; } operands()162 op_range operands() const { 163 return make_range(op_begin(), op_end()); 164 } 165 getType()166 Type *getType() const { return getOperand(0)->getType(); } 167 168 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { 169 return (NoWrapFlags)(SubclassData & Mask); 170 } 171 hasNoUnsignedWrap()172 bool hasNoUnsignedWrap() const { 173 return getNoWrapFlags(FlagNUW) != FlagAnyWrap; 174 } 175 hasNoSignedWrap()176 bool hasNoSignedWrap() const { 177 return getNoWrapFlags(FlagNSW) != FlagAnyWrap; 178 } 179 hasNoSelfWrap()180 bool hasNoSelfWrap() const { 181 return getNoWrapFlags(FlagNW) != FlagAnyWrap; 182 } 183 184 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)185 static bool classof(const SCEV *S) { 186 return S->getSCEVType() == scAddExpr || 187 S->getSCEVType() == scMulExpr || 188 S->getSCEVType() == scSMaxExpr || 189 S->getSCEVType() == scUMaxExpr || 190 S->getSCEVType() == scAddRecExpr; 191 } 192 }; 193 194 /// This node is the base class for n'ary commutative operators. 195 class SCEVCommutativeExpr : public SCEVNAryExpr { 196 protected: SCEVCommutativeExpr(const FoldingSetNodeIDRef ID,enum SCEVTypes T,const SCEV * const * O,size_t N)197 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, 198 enum SCEVTypes T, const SCEV *const *O, size_t N) 199 : SCEVNAryExpr(ID, T, O, N) {} 200 201 public: 202 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)203 static bool classof(const SCEV *S) { 204 return S->getSCEVType() == scAddExpr || 205 S->getSCEVType() == scMulExpr || 206 S->getSCEVType() == scSMaxExpr || 207 S->getSCEVType() == scUMaxExpr; 208 } 209 210 /// Set flags for a non-recurrence without clearing previously set flags. setNoWrapFlags(NoWrapFlags Flags)211 void setNoWrapFlags(NoWrapFlags Flags) { 212 SubclassData |= Flags; 213 } 214 }; 215 216 /// This node represents an addition of some number of SCEVs. 217 class SCEVAddExpr : public SCEVCommutativeExpr { 218 friend class ScalarEvolution; 219 SCEVAddExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)220 SCEVAddExpr(const FoldingSetNodeIDRef ID, 221 const SCEV *const *O, size_t N) 222 : SCEVCommutativeExpr(ID, scAddExpr, O, N) {} 223 224 public: getType()225 Type *getType() const { 226 // Use the type of the last operand, which is likely to be a pointer 227 // type, if there is one. This doesn't usually matter, but it can help 228 // reduce casts when the expressions are expanded. 229 return getOperand(getNumOperands() - 1)->getType(); 230 } 231 232 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)233 static bool classof(const SCEV *S) { 234 return S->getSCEVType() == scAddExpr; 235 } 236 }; 237 238 /// This node represents multiplication of some number of SCEVs. 239 class SCEVMulExpr : public SCEVCommutativeExpr { 240 friend class ScalarEvolution; 241 SCEVMulExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)242 SCEVMulExpr(const FoldingSetNodeIDRef ID, 243 const SCEV *const *O, size_t N) 244 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {} 245 246 public: 247 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)248 static bool classof(const SCEV *S) { 249 return S->getSCEVType() == scMulExpr; 250 } 251 }; 252 253 /// This class represents a binary unsigned division operation. 254 class SCEVUDivExpr : public SCEV { 255 friend class ScalarEvolution; 256 257 const SCEV *LHS; 258 const SCEV *RHS; 259 SCEVUDivExpr(const FoldingSetNodeIDRef ID,const SCEV * lhs,const SCEV * rhs)260 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) 261 : SCEV(ID, scUDivExpr), LHS(lhs), RHS(rhs) {} 262 263 public: getLHS()264 const SCEV *getLHS() const { return LHS; } getRHS()265 const SCEV *getRHS() const { return RHS; } 266 getType()267 Type *getType() const { 268 // In most cases the types of LHS and RHS will be the same, but in some 269 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't 270 // depend on the type for correctness, but handling types carefully can 271 // avoid extra casts in the SCEVExpander. The LHS is more likely to be 272 // a pointer type than the RHS, so use the RHS' type here. 273 return getRHS()->getType(); 274 } 275 276 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)277 static bool classof(const SCEV *S) { 278 return S->getSCEVType() == scUDivExpr; 279 } 280 }; 281 282 /// This node represents a polynomial recurrence on the trip count 283 /// of the specified loop. This is the primary focus of the 284 /// ScalarEvolution framework; all the other SCEV subclasses are 285 /// mostly just supporting infrastructure to allow SCEVAddRecExpr 286 /// expressions to be created and analyzed. 287 /// 288 /// All operands of an AddRec are required to be loop invariant. 289 /// 290 class SCEVAddRecExpr : public SCEVNAryExpr { 291 friend class ScalarEvolution; 292 293 const Loop *L; 294 SCEVAddRecExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N,const Loop * l)295 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, 296 const SCEV *const *O, size_t N, const Loop *l) 297 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} 298 299 public: getStart()300 const SCEV *getStart() const { return Operands[0]; } getLoop()301 const Loop *getLoop() const { return L; } 302 303 /// Constructs and returns the recurrence indicating how much this 304 /// expression steps by. If this is a polynomial of degree N, it 305 /// returns a chrec of degree N-1. We cannot determine whether 306 /// the step recurrence has self-wraparound. getStepRecurrence(ScalarEvolution & SE)307 const SCEV *getStepRecurrence(ScalarEvolution &SE) const { 308 if (isAffine()) return getOperand(1); 309 return SE.getAddRecExpr(SmallVector<const SCEV *, 3>(op_begin()+1, 310 op_end()), 311 getLoop(), FlagAnyWrap); 312 } 313 314 /// Return true if this represents an expression A + B*x where A 315 /// and B are loop invariant values. isAffine()316 bool isAffine() const { 317 // We know that the start value is invariant. This expression is thus 318 // affine iff the step is also invariant. 319 return getNumOperands() == 2; 320 } 321 322 /// Return true if this represents an expression A + B*x + C*x^2 323 /// where A, B and C are loop invariant values. This corresponds 324 /// to an addrec of the form {L,+,M,+,N} isQuadratic()325 bool isQuadratic() const { 326 return getNumOperands() == 3; 327 } 328 329 /// Set flags for a recurrence without clearing any previously set flags. 330 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here 331 /// to make it easier to propagate flags. setNoWrapFlags(NoWrapFlags Flags)332 void setNoWrapFlags(NoWrapFlags Flags) { 333 if (Flags & (FlagNUW | FlagNSW)) 334 Flags = ScalarEvolution::setFlags(Flags, FlagNW); 335 SubclassData |= Flags; 336 } 337 338 /// Return the value of this chain of recurrences at the specified 339 /// iteration number. 340 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; 341 342 /// Return the number of iterations of this loop that produce 343 /// values in the specified constant range. Another way of 344 /// looking at this is that it returns the first iteration number 345 /// where the value is not in the condition, thus computing the 346 /// exit count. If the iteration count can't be computed, an 347 /// instance of SCEVCouldNotCompute is returned. 348 const SCEV *getNumIterationsInRange(const ConstantRange &Range, 349 ScalarEvolution &SE) const; 350 351 /// Return an expression representing the value of this expression 352 /// one iteration of the loop ahead. 353 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const; 354 355 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)356 static bool classof(const SCEV *S) { 357 return S->getSCEVType() == scAddRecExpr; 358 } 359 }; 360 361 /// This class represents a signed maximum selection. 362 class SCEVSMaxExpr : public SCEVCommutativeExpr { 363 friend class ScalarEvolution; 364 SCEVSMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)365 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, 366 const SCEV *const *O, size_t N) 367 : SCEVCommutativeExpr(ID, scSMaxExpr, O, N) { 368 // Max never overflows. 369 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 370 } 371 372 public: 373 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)374 static bool classof(const SCEV *S) { 375 return S->getSCEVType() == scSMaxExpr; 376 } 377 }; 378 379 /// This class represents an unsigned maximum selection. 380 class SCEVUMaxExpr : public SCEVCommutativeExpr { 381 friend class ScalarEvolution; 382 SCEVUMaxExpr(const FoldingSetNodeIDRef ID,const SCEV * const * O,size_t N)383 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, 384 const SCEV *const *O, size_t N) 385 : SCEVCommutativeExpr(ID, scUMaxExpr, O, N) { 386 // Max never overflows. 387 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); 388 } 389 390 public: 391 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)392 static bool classof(const SCEV *S) { 393 return S->getSCEVType() == scUMaxExpr; 394 } 395 }; 396 397 /// This means that we are dealing with an entirely unknown SCEV 398 /// value, and only represent it as its LLVM Value. This is the 399 /// "bottom" value for the analysis. 400 class SCEVUnknown final : public SCEV, private CallbackVH { 401 friend class ScalarEvolution; 402 403 /// The parent ScalarEvolution value. This is used to update the 404 /// parent's maps when the value associated with a SCEVUnknown is 405 /// deleted or RAUW'd. 406 ScalarEvolution *SE; 407 408 /// The next pointer in the linked list of all SCEVUnknown 409 /// instances owned by a ScalarEvolution. 410 SCEVUnknown *Next; 411 SCEVUnknown(const FoldingSetNodeIDRef ID,Value * V,ScalarEvolution * se,SCEVUnknown * next)412 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, 413 ScalarEvolution *se, SCEVUnknown *next) : 414 SCEV(ID, scUnknown), CallbackVH(V), SE(se), Next(next) {} 415 416 // Implement CallbackVH. 417 void deleted() override; 418 void allUsesReplacedWith(Value *New) override; 419 420 public: getValue()421 Value *getValue() const { return getValPtr(); } 422 423 /// @{ 424 /// Test whether this is a special constant representing a type 425 /// size, alignment, or field offset in a target-independent 426 /// manner, and hasn't happened to have been folded with other 427 /// operations into something unrecognizable. This is mainly only 428 /// useful for pretty-printing and other situations where it isn't 429 /// absolutely required for these to succeed. 430 bool isSizeOf(Type *&AllocTy) const; 431 bool isAlignOf(Type *&AllocTy) const; 432 bool isOffsetOf(Type *&STy, Constant *&FieldNo) const; 433 /// @} 434 getType()435 Type *getType() const { return getValPtr()->getType(); } 436 437 /// Methods for support type inquiry through isa, cast, and dyn_cast: classof(const SCEV * S)438 static bool classof(const SCEV *S) { 439 return S->getSCEVType() == scUnknown; 440 } 441 }; 442 443 /// This class defines a simple visitor class that may be used for 444 /// various SCEV analysis purposes. 445 template<typename SC, typename RetVal=void> 446 struct SCEVVisitor { visitSCEVVisitor447 RetVal visit(const SCEV *S) { 448 switch (S->getSCEVType()) { 449 case scConstant: 450 return ((SC*)this)->visitConstant((const SCEVConstant*)S); 451 case scTruncate: 452 return ((SC*)this)->visitTruncateExpr((const SCEVTruncateExpr*)S); 453 case scZeroExtend: 454 return ((SC*)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr*)S); 455 case scSignExtend: 456 return ((SC*)this)->visitSignExtendExpr((const SCEVSignExtendExpr*)S); 457 case scAddExpr: 458 return ((SC*)this)->visitAddExpr((const SCEVAddExpr*)S); 459 case scMulExpr: 460 return ((SC*)this)->visitMulExpr((const SCEVMulExpr*)S); 461 case scUDivExpr: 462 return ((SC*)this)->visitUDivExpr((const SCEVUDivExpr*)S); 463 case scAddRecExpr: 464 return ((SC*)this)->visitAddRecExpr((const SCEVAddRecExpr*)S); 465 case scSMaxExpr: 466 return ((SC*)this)->visitSMaxExpr((const SCEVSMaxExpr*)S); 467 case scUMaxExpr: 468 return ((SC*)this)->visitUMaxExpr((const SCEVUMaxExpr*)S); 469 case scUnknown: 470 return ((SC*)this)->visitUnknown((const SCEVUnknown*)S); 471 case scCouldNotCompute: 472 return ((SC*)this)->visitCouldNotCompute((const SCEVCouldNotCompute*)S); 473 default: 474 llvm_unreachable("Unknown SCEV type!"); 475 } 476 } 477 visitCouldNotComputeSCEVVisitor478 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) { 479 llvm_unreachable("Invalid use of SCEVCouldNotCompute!"); 480 } 481 }; 482 483 /// Visit all nodes in the expression tree using worklist traversal. 484 /// 485 /// Visitor implements: 486 /// // return true to follow this node. 487 /// bool follow(const SCEV *S); 488 /// // return true to terminate the search. 489 /// bool isDone(); 490 template<typename SV> 491 class SCEVTraversal { 492 SV &Visitor; 493 SmallVector<const SCEV *, 8> Worklist; 494 SmallPtrSet<const SCEV *, 8> Visited; 495 push(const SCEV * S)496 void push(const SCEV *S) { 497 if (Visited.insert(S).second && Visitor.follow(S)) 498 Worklist.push_back(S); 499 } 500 501 public: SCEVTraversal(SV & V)502 SCEVTraversal(SV& V): Visitor(V) {} 503 visitAll(const SCEV * Root)504 void visitAll(const SCEV *Root) { 505 push(Root); 506 while (!Worklist.empty() && !Visitor.isDone()) { 507 const SCEV *S = Worklist.pop_back_val(); 508 509 switch (S->getSCEVType()) { 510 case scConstant: 511 case scUnknown: 512 break; 513 case scTruncate: 514 case scZeroExtend: 515 case scSignExtend: 516 push(cast<SCEVCastExpr>(S)->getOperand()); 517 break; 518 case scAddExpr: 519 case scMulExpr: 520 case scSMaxExpr: 521 case scUMaxExpr: 522 case scAddRecExpr: 523 for (const auto *Op : cast<SCEVNAryExpr>(S)->operands()) 524 push(Op); 525 break; 526 case scUDivExpr: { 527 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S); 528 push(UDiv->getLHS()); 529 push(UDiv->getRHS()); 530 break; 531 } 532 case scCouldNotCompute: 533 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); 534 default: 535 llvm_unreachable("Unknown SCEV kind!"); 536 } 537 } 538 } 539 }; 540 541 /// Use SCEVTraversal to visit all nodes in the given expression tree. 542 template<typename SV> visitAll(const SCEV * Root,SV & Visitor)543 void visitAll(const SCEV *Root, SV& Visitor) { 544 SCEVTraversal<SV> T(Visitor); 545 T.visitAll(Root); 546 } 547 548 /// Return true if any node in \p Root satisfies the predicate \p Pred. 549 template <typename PredTy> SCEVExprContains(const SCEV * Root,PredTy Pred)550 bool SCEVExprContains(const SCEV *Root, PredTy Pred) { 551 struct FindClosure { 552 bool Found = false; 553 PredTy Pred; 554 555 FindClosure(PredTy Pred) : Pred(Pred) {} 556 557 bool follow(const SCEV *S) { 558 if (!Pred(S)) 559 return true; 560 561 Found = true; 562 return false; 563 } 564 565 bool isDone() const { return Found; } 566 }; 567 568 FindClosure FC(Pred); 569 visitAll(Root, FC); 570 return FC.Found; 571 } 572 573 /// This visitor recursively visits a SCEV expression and re-writes it. 574 /// The result from each visit is cached, so it will return the same 575 /// SCEV for the same input. 576 template<typename SC> 577 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> { 578 protected: 579 ScalarEvolution &SE; 580 // Memoize the result of each visit so that we only compute once for 581 // the same input SCEV. This is to avoid redundant computations when 582 // a SCEV is referenced by multiple SCEVs. Without memoization, this 583 // visit algorithm would have exponential time complexity in the worst 584 // case, causing the compiler to hang on certain tests. 585 DenseMap<const SCEV *, const SCEV *> RewriteResults; 586 587 public: SCEVRewriteVisitor(ScalarEvolution & SE)588 SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} 589 visit(const SCEV * S)590 const SCEV *visit(const SCEV *S) { 591 auto It = RewriteResults.find(S); 592 if (It != RewriteResults.end()) 593 return It->second; 594 auto* Visited = SCEVVisitor<SC, const SCEV *>::visit(S); 595 auto Result = RewriteResults.try_emplace(S, Visited); 596 assert(Result.second && "Should insert a new entry"); 597 return Result.first->second; 598 } 599 visitConstant(const SCEVConstant * Constant)600 const SCEV *visitConstant(const SCEVConstant *Constant) { 601 return Constant; 602 } 603 visitTruncateExpr(const SCEVTruncateExpr * Expr)604 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { 605 const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); 606 return Operand == Expr->getOperand() 607 ? Expr 608 : SE.getTruncateExpr(Operand, Expr->getType()); 609 } 610 visitZeroExtendExpr(const SCEVZeroExtendExpr * Expr)611 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { 612 const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); 613 return Operand == Expr->getOperand() 614 ? Expr 615 : SE.getZeroExtendExpr(Operand, Expr->getType()); 616 } 617 visitSignExtendExpr(const SCEVSignExtendExpr * Expr)618 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { 619 const SCEV *Operand = ((SC*)this)->visit(Expr->getOperand()); 620 return Operand == Expr->getOperand() 621 ? Expr 622 : SE.getSignExtendExpr(Operand, Expr->getType()); 623 } 624 visitAddExpr(const SCEVAddExpr * Expr)625 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { 626 SmallVector<const SCEV *, 2> Operands; 627 bool Changed = false; 628 for (auto *Op : Expr->operands()) { 629 Operands.push_back(((SC*)this)->visit(Op)); 630 Changed |= Op != Operands.back(); 631 } 632 return !Changed ? Expr : SE.getAddExpr(Operands); 633 } 634 visitMulExpr(const SCEVMulExpr * Expr)635 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { 636 SmallVector<const SCEV *, 2> Operands; 637 bool Changed = false; 638 for (auto *Op : Expr->operands()) { 639 Operands.push_back(((SC*)this)->visit(Op)); 640 Changed |= Op != Operands.back(); 641 } 642 return !Changed ? Expr : SE.getMulExpr(Operands); 643 } 644 visitUDivExpr(const SCEVUDivExpr * Expr)645 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { 646 auto *LHS = ((SC *)this)->visit(Expr->getLHS()); 647 auto *RHS = ((SC *)this)->visit(Expr->getRHS()); 648 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS(); 649 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS); 650 } 651 visitAddRecExpr(const SCEVAddRecExpr * Expr)652 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 653 SmallVector<const SCEV *, 2> Operands; 654 bool Changed = false; 655 for (auto *Op : Expr->operands()) { 656 Operands.push_back(((SC*)this)->visit(Op)); 657 Changed |= Op != Operands.back(); 658 } 659 return !Changed ? Expr 660 : SE.getAddRecExpr(Operands, Expr->getLoop(), 661 Expr->getNoWrapFlags()); 662 } 663 visitSMaxExpr(const SCEVSMaxExpr * Expr)664 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { 665 SmallVector<const SCEV *, 2> Operands; 666 bool Changed = false; 667 for (auto *Op : Expr->operands()) { 668 Operands.push_back(((SC *)this)->visit(Op)); 669 Changed |= Op != Operands.back(); 670 } 671 return !Changed ? Expr : SE.getSMaxExpr(Operands); 672 } 673 visitUMaxExpr(const SCEVUMaxExpr * Expr)674 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { 675 SmallVector<const SCEV *, 2> Operands; 676 bool Changed = false; 677 for (auto *Op : Expr->operands()) { 678 Operands.push_back(((SC*)this)->visit(Op)); 679 Changed |= Op != Operands.back(); 680 } 681 return !Changed ? Expr : SE.getUMaxExpr(Operands); 682 } 683 visitUnknown(const SCEVUnknown * Expr)684 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 685 return Expr; 686 } 687 visitCouldNotCompute(const SCEVCouldNotCompute * Expr)688 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { 689 return Expr; 690 } 691 }; 692 693 using ValueToValueMap = DenseMap<const Value *, Value *>; 694 695 /// The SCEVParameterRewriter takes a scalar evolution expression and updates 696 /// the SCEVUnknown components following the Map (Value -> Value). 697 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> { 698 public: 699 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, 700 ValueToValueMap &Map, 701 bool InterpretConsts = false) { 702 SCEVParameterRewriter Rewriter(SE, Map, InterpretConsts); 703 return Rewriter.visit(Scev); 704 } 705 SCEVParameterRewriter(ScalarEvolution & SE,ValueToValueMap & M,bool C)706 SCEVParameterRewriter(ScalarEvolution &SE, ValueToValueMap &M, bool C) 707 : SCEVRewriteVisitor(SE), Map(M), InterpretConsts(C) {} 708 visitUnknown(const SCEVUnknown * Expr)709 const SCEV *visitUnknown(const SCEVUnknown *Expr) { 710 Value *V = Expr->getValue(); 711 if (Map.count(V)) { 712 Value *NV = Map[V]; 713 if (InterpretConsts && isa<ConstantInt>(NV)) 714 return SE.getConstant(cast<ConstantInt>(NV)); 715 return SE.getUnknown(NV); 716 } 717 return Expr; 718 } 719 720 private: 721 ValueToValueMap ⤅ 722 bool InterpretConsts; 723 }; 724 725 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>; 726 727 /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies 728 /// the Map (Loop -> SCEV) to all AddRecExprs. 729 class SCEVLoopAddRecRewriter 730 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> { 731 public: SCEVLoopAddRecRewriter(ScalarEvolution & SE,LoopToScevMapT & M)732 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) 733 : SCEVRewriteVisitor(SE), Map(M) {} 734 rewrite(const SCEV * Scev,LoopToScevMapT & Map,ScalarEvolution & SE)735 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, 736 ScalarEvolution &SE) { 737 SCEVLoopAddRecRewriter Rewriter(SE, Map); 738 return Rewriter.visit(Scev); 739 } 740 visitAddRecExpr(const SCEVAddRecExpr * Expr)741 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { 742 SmallVector<const SCEV *, 2> Operands; 743 for (const SCEV *Op : Expr->operands()) 744 Operands.push_back(visit(Op)); 745 746 const Loop *L = Expr->getLoop(); 747 const SCEV *Res = SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags()); 748 749 if (0 == Map.count(L)) 750 return Res; 751 752 const SCEVAddRecExpr *Rec = cast<SCEVAddRecExpr>(Res); 753 return Rec->evaluateAtIteration(Map[L], SE); 754 } 755 756 private: 757 LoopToScevMapT ⤅ 758 }; 759 760 } // end namespace llvm 761 762 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H 763