1 //===- NumberOfExecutions.cpp - Number of executions analysis -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implementation of the number of executions analysis.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/NumberOfExecutions.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/RegionKindInterface.h"
17 #include "mlir/Interfaces/ControlFlowInterfaces.h"
18 
19 #include "llvm/ADT/FunctionExtras.h"
20 #include "llvm/ADT/SmallSet.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 #define DEBUG_TYPE "number-of-executions-analysis"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // NumberOfExecutions
29 //===----------------------------------------------------------------------===//
30 
31 /// Computes blocks number of executions information for the given region.
computeRegionBlockNumberOfExecutions(Region & region,DenseMap<Block *,BlockNumberOfExecutionsInfo> & blockInfo)32 static void computeRegionBlockNumberOfExecutions(
33     Region &region, DenseMap<Block *, BlockNumberOfExecutionsInfo> &blockInfo) {
34   Operation *parentOp = region.getParentOp();
35   int regionId = region.getRegionNumber();
36 
37   auto regionKindInterface = dyn_cast<RegionKindInterface>(parentOp);
38   bool isGraphRegion =
39       regionKindInterface &&
40       regionKindInterface.getRegionKind(regionId) == RegionKind::Graph;
41 
42   // CFG analysis does not make sense for Graph regions, set the number of
43   // executions for all blocks as unknown.
44   if (isGraphRegion) {
45     for (Block &block : region)
46       blockInfo.insert({&block, {&block, None, None}});
47     return;
48   }
49 
50   // Number of region invocations for all attached regions.
51   SmallVector<int64_t, 4> numRegionsInvocations;
52 
53   // Query RegionBranchOpInterface interface if it is available.
54   if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp)) {
55     SmallVector<Attribute, 4> operands(parentOp->getNumOperands());
56     for (auto operandIt : llvm::enumerate(parentOp->getOperands()))
57       matchPattern(operandIt.value(), m_Constant(&operands[operandIt.index()]));
58 
59     regionInterface.getNumRegionInvocations(operands, numRegionsInvocations);
60   }
61 
62   // Number of region invocations *each time* parent operation is invoked.
63   Optional<int64_t> numRegionInvocations;
64 
65   if (!numRegionsInvocations.empty() &&
66       numRegionsInvocations[regionId] != kUnknownNumRegionInvocations) {
67     numRegionInvocations = numRegionsInvocations[regionId];
68   }
69 
70   // DFS traversal looking for loops in the CFG.
71   llvm::SmallSet<Block *, 4> loopStart;
72 
73   llvm::unique_function<void(Block *, llvm::SmallSet<Block *, 4> &)> dfs =
74       [&](Block *block, llvm::SmallSet<Block *, 4> &visited) {
75         // Found a loop in the CFG that starts at the `block`.
76         if (visited.contains(block)) {
77           loopStart.insert(block);
78           return;
79         }
80 
81         // Continue DFS traversal.
82         visited.insert(block);
83         for (Block *successor : block->getSuccessors())
84           dfs(successor, visited);
85         visited.erase(block);
86       };
87 
88   llvm::SmallSet<Block *, 4> visited;
89   dfs(&region.front(), visited);
90 
91   // Start from the entry block and follow only blocks with single succesor.
92   Block *block = &region.front();
93   while (block && !loopStart.contains(block)) {
94     // Block will be executed exactly once.
95     blockInfo.insert(
96         {block, BlockNumberOfExecutionsInfo(block, numRegionInvocations,
97                                             /*numberOfBlockExecutions=*/1)});
98 
99     // We reached the exit block or block with multiple successors.
100     if (block->getNumSuccessors() != 1)
101       break;
102 
103     // Continue traversal.
104     block = block->getSuccessor(0);
105   }
106 
107   // For all blocks that we did not visit set the executions number to unknown.
108   for (Block &block : region)
109     if (blockInfo.count(&block) == 0)
110       blockInfo.insert({&block, BlockNumberOfExecutionsInfo(
111                                     &block, numRegionInvocations,
112                                     /*numberOfBlockExecutions=*/None)});
113 }
114 
115 /// Creates a new NumberOfExecutions analysis that computes how many times a
116 /// block within a region is executed for all associated regions.
NumberOfExecutions(Operation * op)117 NumberOfExecutions::NumberOfExecutions(Operation *op) : operation(op) {
118   operation->walk([&](Region *region) {
119     computeRegionBlockNumberOfExecutions(*region, blockNumbersOfExecution);
120   });
121 }
122 
123 Optional<int64_t>
getNumberOfExecutions(Operation * op,Region * perEntryOfThisRegion) const124 NumberOfExecutions::getNumberOfExecutions(Operation *op,
125                                           Region *perEntryOfThisRegion) const {
126   // Assuming that all operations complete in a finite amount of time (do not
127   // abort and do not go into the infinite loop), the number of operation
128   // executions is equal to the number of block executions that contains the
129   // operation.
130   return getNumberOfExecutions(op->getBlock(), perEntryOfThisRegion);
131 }
132 
133 Optional<int64_t>
getNumberOfExecutions(Block * block,Region * perEntryOfThisRegion) const134 NumberOfExecutions::getNumberOfExecutions(Block *block,
135                                           Region *perEntryOfThisRegion) const {
136   // Return None if the given `block` does not lie inside the
137   // `perEntryOfThisRegion` region.
138   if (!perEntryOfThisRegion->findAncestorBlockInRegion(*block))
139     return None;
140 
141   // Find the block information for the given `block.
142   auto blockIt = blockNumbersOfExecution.find(block);
143   if (blockIt == blockNumbersOfExecution.end())
144     return None;
145   const auto &blockInfo = blockIt->getSecond();
146 
147   // Override the number of region invocations with `1` if the
148   // `perEntryOfThisRegion` region owns the block.
149   auto getNumberOfExecutions = [&](const BlockNumberOfExecutionsInfo &info) {
150     if (info.getBlock()->getParent() == perEntryOfThisRegion)
151       return info.getNumberOfExecutions(/*numberOfRegionInvocations=*/1);
152     return info.getNumberOfExecutions();
153   };
154 
155   // Immediately return None if we do not know the block number of executions.
156   auto blockExecutions = getNumberOfExecutions(blockInfo);
157   if (!blockExecutions.hasValue())
158     return None;
159 
160   // Follow parent operations until we reach the operations that owns the
161   // `perEntryOfThisRegion`.
162   int64_t numberOfExecutions = *blockExecutions;
163   Operation *parentOp = block->getParentOp();
164 
165   while (parentOp != perEntryOfThisRegion->getParentOp()) {
166     // Find how many times will be executed the block that owns the parent
167     // operation.
168     Block *parentBlock = parentOp->getBlock();
169 
170     auto parentBlockIt = blockNumbersOfExecution.find(parentBlock);
171     if (parentBlockIt == blockNumbersOfExecution.end())
172       return None;
173     const auto &parentBlockInfo = parentBlockIt->getSecond();
174     auto parentBlockExecutions = getNumberOfExecutions(parentBlockInfo);
175 
176     // We stumbled upon an operation with unknown number of executions.
177     if (!parentBlockExecutions.hasValue())
178       return None;
179 
180     // Number of block executions is a product of all parent blocks executions.
181     numberOfExecutions *= *parentBlockExecutions;
182     parentOp = parentOp->getParentOp();
183 
184     assert(parentOp != nullptr);
185   }
186 
187   return numberOfExecutions;
188 }
189 
printBlockExecutions(raw_ostream & os,Region * perEntryOfThisRegion) const190 void NumberOfExecutions::printBlockExecutions(
191     raw_ostream &os, Region *perEntryOfThisRegion) const {
192   unsigned blockId = 0;
193 
194   operation->walk([&](Block *block) {
195     llvm::errs() << "Block: " << blockId++ << "\n";
196     llvm::errs() << "Number of executions: ";
197     if (auto n = getNumberOfExecutions(block, perEntryOfThisRegion))
198       llvm::errs() << *n << "\n";
199     else
200       llvm::errs() << "<unknown>\n";
201   });
202 }
203 
printOperationExecutions(raw_ostream & os,Region * perEntryOfThisRegion) const204 void NumberOfExecutions::printOperationExecutions(
205     raw_ostream &os, Region *perEntryOfThisRegion) const {
206   operation->walk([&](Block *block) {
207     block->walk([&](Operation *operation) {
208       // Skip the operation that was used to build the analysis.
209       if (operation == this->operation)
210         return;
211 
212       llvm::errs() << "Operation: " << operation->getName() << "\n";
213       llvm::errs() << "Number of executions: ";
214       if (auto n = getNumberOfExecutions(operation, perEntryOfThisRegion))
215         llvm::errs() << *n << "\n";
216       else
217         llvm::errs() << "<unknown>\n";
218     });
219   });
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // BlockNumberOfExecutionsInfo
224 //===----------------------------------------------------------------------===//
225 
BlockNumberOfExecutionsInfo(Block * block,Optional<int64_t> numberOfRegionInvocations,Optional<int64_t> numberOfBlockExecutions)226 BlockNumberOfExecutionsInfo::BlockNumberOfExecutionsInfo(
227     Block *block, Optional<int64_t> numberOfRegionInvocations,
228     Optional<int64_t> numberOfBlockExecutions)
229     : block(block), numberOfRegionInvocations(numberOfRegionInvocations),
230       numberOfBlockExecutions(numberOfBlockExecutions) {}
231 
getNumberOfExecutions() const232 Optional<int64_t> BlockNumberOfExecutionsInfo::getNumberOfExecutions() const {
233   if (numberOfRegionInvocations && numberOfBlockExecutions)
234     return *numberOfRegionInvocations * *numberOfBlockExecutions;
235   return None;
236 }
237 
getNumberOfExecutions(int64_t numberOfRegionInvocations) const238 Optional<int64_t> BlockNumberOfExecutionsInfo::getNumberOfExecutions(
239     int64_t numberOfRegionInvocations) const {
240   if (numberOfBlockExecutions)
241     return numberOfRegionInvocations * *numberOfBlockExecutions;
242   return None;
243 }
244