1 //===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for working with "normalized" expressions.
10 // See the comments at the top of ScalarEvolutionNormalization.h for details.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Analysis/ScalarEvolutionNormalization.h"
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
17 using namespace llvm;
18 
namespace {
/// TransformKind - Different types of transformations that
/// NormalizeDenormalizeRewriter can do.
///
/// This enum is an implementation detail of this file.  It lives in an
/// anonymous namespace so that it has internal linkage and cannot collide
/// (ODR violation) with an unrelated global-namespace type of the same name
/// defined in another translation unit.
enum TransformKind {
  /// Normalize - Normalize according to the given loops.
  Normalize,
  /// Denormalize - Perform the inverse transform on the expression with the
  /// given loop set.
  Denormalize
};
} // namespace
28 
29 namespace {
30 struct NormalizeDenormalizeRewriter
31     : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> {
32   const TransformKind Kind;
33 
34   // NB! Pred is a function_ref.  Storing it here is okay only because
35   // we're careful about the lifetime of NormalizeDenormalizeRewriter.
36   const NormalizePredTy Pred;
37 
NormalizeDenormalizeRewriter__anonf3211b890111::NormalizeDenormalizeRewriter38   NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred,
39                                ScalarEvolution &SE)
40       : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind),
41         Pred(Pred) {}
42   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
43 };
44 } // namespace
45 
46 const SCEV *
visitAddRecExpr(const SCEVAddRecExpr * AR)47 NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
48   SmallVector<const SCEV *, 8> Operands;
49 
50   transform(AR->operands(), std::back_inserter(Operands),
51             [&](const SCEV *Op) { return visit(Op); });
52 
53   if (!Pred(AR))
54     return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
55 
56   // Normalization and denormalization are fancy names for decrementing and
57   // incrementing a SCEV expression with respect to a set of loops.  Since
58   // Pred(AR) has returned true, we know we need to normalize or denormalize AR
59   // with respect to its loop.
60 
61   if (Kind == Denormalize) {
62     // Denormalization / "partial increment" is essentially the same as \c
63     // SCEVAddRecExpr::getPostIncExpr.  Here we use an explicit loop to make the
64     // symmetry with Normalization clear.
65     for (int i = 0, e = Operands.size() - 1; i < e; i++)
66       Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]);
67   } else {
68     assert(Kind == Normalize && "Only two possibilities!");
69 
70     // Normalization / "partial decrement" is a bit more subtle.  Since
71     // incrementing a SCEV expression (in general) changes the step of the SCEV
72     // expression as well, we cannot use the step of the current expression.
73     // Instead, we have to use the step of the very expression we're trying to
74     // compute!
75     //
76     // We solve the issue by recursively building up the result, starting from
77     // the "least significant" operand in the add recurrence:
78     //
79     // Base case:
80     //   Single operand add recurrence.  It's its own normalization.
81     //
82     // N-operand case:
83     //   {S_{N-1},+,S_{N-2},+,...,+,S_0} = S
84     //
85     //   Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its
86     //   normalization by induction.  We subtract the normalized step
87     //   recurrence from S_{N-1} to get the normalization of S.
88 
89     for (int i = Operands.size() - 2; i >= 0; i--)
90       Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]);
91   }
92 
93   return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
94 }
95 
/// Normalizes \p S with respect to every add recurrence in it whose loop is a
/// member of \p Loops.
///
/// \param S     The expression to normalize.
/// \param Loops Loops whose add recurrences should be normalized.
/// \param SE    ScalarEvolution used to build the rewritten expression.
/// \returns The normalized expression, or \p S itself when \p Loops is empty.
const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
                                         const PostIncLoopSet &Loops,
                                         ScalarEvolution &SE) {
  // With no loops selected the predicate below can never accept an add
  // recurrence, so no normalization step could be applied; skip the full
  // expression traversal.
  if (Loops.empty())
    return S;

  auto Pred = [&](const SCEVAddRecExpr *AR) {
    return Loops.count(AR->getLoop());
  };
  // The rewriter stores Pred as a function_ref; it is a temporary that dies at
  // the end of this full-expression, so the lambda outlives it.
  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
}
104 
/// Normalizes \p S with respect to every add recurrence in it for which
/// \p Pred returns true.
///
/// \p Pred is a function_ref held by the rewriter for the duration of this
/// call only; the callable it refers to must stay alive until this returns.
const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
                                           ScalarEvolution &SE) {
  NormalizeDenormalizeRewriter Rewriter(Normalize, Pred, SE);
  return Rewriter.visit(S);
}
109 
/// Denormalizes \p S (the inverse of normalization) with respect to every add
/// recurrence in it whose loop is a member of \p Loops.
///
/// \param S     The expression to denormalize.
/// \param Loops Loops whose add recurrences should be denormalized.
/// \param SE    ScalarEvolution used to build the rewritten expression.
/// \returns The denormalized expression, or \p S itself when \p Loops is
///          empty.
const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
                                           const PostIncLoopSet &Loops,
                                           ScalarEvolution &SE) {
  // With no loops selected the predicate below can never accept an add
  // recurrence, so no denormalization step could be applied; skip the full
  // expression traversal.
  if (Loops.empty())
    return S;

  auto Pred = [&](const SCEVAddRecExpr *AR) {
    return Loops.count(AR->getLoop());
  };
  // The rewriter stores Pred as a function_ref; it is a temporary that dies at
  // the end of this full-expression, so the lambda outlives it.
  return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S);
}
118