1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test various loop fusion utility functions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/Utils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/LoopFusionUtils.h"
18 #include "mlir/Transforms/LoopUtils.h"
19 #include "mlir/Transforms/Passes.h"
20 
21 #define DEBUG_TYPE "test-loop-fusion"
22 
23 using namespace mlir;
24 
25 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
26 
27 static llvm::cl::opt<bool> clTestDependenceCheck(
28     "test-loop-fusion-dependence-check",
29     llvm::cl::desc("Enable testing of loop fusion dependence check"),
30     llvm::cl::cat(clOptionsCategory));
31 
32 static llvm::cl::opt<bool> clTestSliceComputation(
33     "test-loop-fusion-slice-computation",
34     llvm::cl::desc("Enable testing of loop fusion slice computation"),
35     llvm::cl::cat(clOptionsCategory));
36 
37 static llvm::cl::opt<bool> clTestLoopFusionTransformation(
38     "test-loop-fusion-transformation",
39     llvm::cl::desc("Enable testing of loop fusion transformation"),
40     llvm::cl::cat(clOptionsCategory));
41 
42 namespace {
43 
44 struct TestLoopFusion : public PassWrapper<TestLoopFusion, FunctionPass> {
45   void runOnFunction() override;
46 };
47 
48 } // end anonymous namespace
49 
50 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
51 // in range ['loopDepth' + 1, 'maxLoopDepth'].
52 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
53 // Returns false as IR is not transformed.
testDependenceCheck(AffineForOp srcForOp,AffineForOp dstForOp,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)54 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
55                                 unsigned i, unsigned j, unsigned loopDepth,
56                                 unsigned maxLoopDepth) {
57   mlir::ComputationSliceState sliceUnion;
58   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
59     FusionResult result =
60         mlir::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
61     if (result.value == FusionResult::FailBlockDependence) {
62       srcForOp->emitRemark("block-level dependence preventing"
63                            " fusion of loop nest ")
64           << i << " into loop nest " << j << " at depth " << loopDepth;
65     }
66   }
67   return false;
68 }
69 
70 // Returns the index of 'op' in its block.
getBlockIndex(Operation & op)71 static unsigned getBlockIndex(Operation &op) {
72   unsigned index = 0;
73   for (auto &opX : *op.getBlock()) {
74     if (&op == &opX)
75       break;
76     ++index;
77   }
78   return index;
79 }
80 
81 // Returns a string representation of 'sliceUnion'.
getSliceStr(const mlir::ComputationSliceState & sliceUnion)82 static std::string getSliceStr(const mlir::ComputationSliceState &sliceUnion) {
83   std::string result;
84   llvm::raw_string_ostream os(result);
85   // Slice insertion point format [loop-depth, operation-block-index]
86   unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
87   unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
88   os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
89      << ")";
90   assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
91   os << " loop bounds: ";
92   for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
93     os << '[';
94     sliceUnion.lbs[k].print(os);
95     os << ", ";
96     sliceUnion.ubs[k].print(os);
97     os << "] ";
98   }
99   return os.str();
100 }
101 
102 // Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
103 // in range ['loopDepth' + 1, 'maxLoopDepth'].
104 // Emits a string representation of the slice union as a remark on 'loops[j]'.
105 // Returns false as IR is not transformed.
testSliceComputation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)106 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
107                                  unsigned i, unsigned j, unsigned loopDepth,
108                                  unsigned maxLoopDepth) {
109   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
110     mlir::ComputationSliceState sliceUnion;
111     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
112     if (result.value == FusionResult::Success) {
113       forOpB->emitRemark("slice (")
114           << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
115           << " : " << getSliceStr(sliceUnion) << ")";
116     }
117   }
118   return false;
119 }
120 
121 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
122 // ['loopDepth' + 1, 'maxLoopDepth'].
123 // Returns true if loops were successfully fused, false otherwise.
testLoopFusionTransformation(AffineForOp forOpA,AffineForOp forOpB,unsigned i,unsigned j,unsigned loopDepth,unsigned maxLoopDepth)124 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
125                                          unsigned i, unsigned j,
126                                          unsigned loopDepth,
127                                          unsigned maxLoopDepth) {
128   for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
129     mlir::ComputationSliceState sliceUnion;
130     FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
131     if (result.value == FusionResult::Success) {
132       mlir::fuseLoops(forOpA, forOpB, sliceUnion);
133       // Note: 'forOpA' is removed to simplify test output. A proper loop
134       // fusion pass should check the data dependence graph and run memref
135       // region analysis to ensure removing 'forOpA' is safe.
136       forOpA.erase();
137       return true;
138     }
139   }
140   return false;
141 }
142 
143 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
144                                    unsigned, unsigned)>;
145 
146 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
147 // If 'return_on_change' is true, returns on first invocation of 'fn' which
148 // returns true.
iterateLoops(ArrayRef<SmallVector<AffineForOp,2>> depthToLoops,LoopFunc fn,bool return_on_change=false)149 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
150                          LoopFunc fn, bool return_on_change = false) {
151   bool changed = false;
152   for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
153        ++loopDepth) {
154     auto &loops = depthToLoops[loopDepth];
155     unsigned numLoops = loops.size();
156     for (unsigned j = 0; j < numLoops; ++j) {
157       for (unsigned k = 0; k < numLoops; ++k) {
158         if (j != k)
159           changed |=
160               fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
161         if (changed && return_on_change)
162           return true;
163       }
164     }
165   }
166   return changed;
167 }
168 
runOnFunction()169 void TestLoopFusion::runOnFunction() {
170   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
171   if (clTestLoopFusionTransformation) {
172     // Run loop fusion until a fixed point is reached.
173     do {
174       depthToLoops.clear();
175       // Gather all AffineForOps by loop depth.
176       gatherLoops(getFunction(), depthToLoops);
177 
178       // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
179     } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
180                           /*return_on_change=*/true));
181     return;
182   }
183 
184   // Gather all AffineForOps by loop depth.
185   gatherLoops(getFunction(), depthToLoops);
186 
187   // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
188   if (clTestDependenceCheck)
189     iterateLoops(depthToLoops, testDependenceCheck);
190   if (clTestSliceComputation)
191     iterateLoops(depthToLoops, testSliceComputation);
192 }
193 
194 namespace mlir {
195 namespace test {
registerTestLoopFusion()196 void registerTestLoopFusion() {
197   PassRegistration<TestLoopFusion>("test-loop-fusion",
198                                    "Tests loop fusion utility functions.");
199 }
200 } // namespace test
201 } // namespace mlir
202