1 //===-------- TestLoopUnrolling.cpp --- loop unrolling test pass ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to unroll loops by a specified unroll factor.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/SCF.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/LoopUtils.h"
17 #include "mlir/Transforms/Passes.h"
18 
19 using namespace mlir;
20 
21 namespace {
22 
getNestingDepth(Operation * op)23 static unsigned getNestingDepth(Operation *op) {
24   Operation *currOp = op;
25   unsigned depth = 0;
26   while ((currOp = currOp->getParentOp())) {
27     if (isa<scf::ForOp>(currOp))
28       depth++;
29   }
30   return depth;
31 }
32 
33 class TestLoopUnrollingPass
34     : public PassWrapper<TestLoopUnrollingPass, FunctionPass> {
35 public:
36   TestLoopUnrollingPass() = default;
TestLoopUnrollingPass(const TestLoopUnrollingPass &)37   TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
TestLoopUnrollingPass(uint64_t unrollFactorParam,unsigned loopDepthParam)38   explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
39                                  unsigned loopDepthParam) {
40     unrollFactor = unrollFactorParam;
41     loopDepth = loopDepthParam;
42   }
43 
runOnFunction()44   void runOnFunction() override {
45     FuncOp func = getFunction();
46     SmallVector<scf::ForOp, 4> loops;
47     func.walk([&](scf::ForOp forOp) {
48       if (getNestingDepth(forOp) == loopDepth)
49         loops.push_back(forOp);
50     });
51     for (auto loop : loops) {
52       loopUnrollByFactor(loop, unrollFactor);
53     }
54   }
55   Option<uint64_t> unrollFactor{*this, "unroll-factor",
56                                 llvm::cl::desc("Loop unroll factor."),
57                                 llvm::cl::init(1)};
58   Option<bool> unrollUpToFactor{*this, "unroll-up-to-factor",
59                                 llvm::cl::desc("Loop unroll up to factor."),
60                                 llvm::cl::init(false)};
61   Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
62                              llvm::cl::init(0)};
63 };
64 } // namespace
65 
66 namespace mlir {
67 namespace test {
registerTestLoopUnrollingPass()68 void registerTestLoopUnrollingPass() {
69   PassRegistration<TestLoopUnrollingPass>(
70       "test-loop-unrolling", "Tests loop unrolling transformation");
71 }
72 } // namespace test
73 } // namespace mlir
74