1 //===- TestLoopMapping.cpp --- Parametric loop mapping pass ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to parametrically map scf.for loops to virtual
10 // processing element dimensions.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/SCF/SCF.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/LoopUtils.h"
18 #include "mlir/Transforms/Passes.h"
19 
20 #include "llvm/ADT/SetVector.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 class TestLoopMappingPass
26     : public PassWrapper<TestLoopMappingPass, FunctionPass> {
27 public:
TestLoopMappingPass()28   explicit TestLoopMappingPass() {}
29 
runOnFunction()30   void runOnFunction() override {
31     FuncOp func = getFunction();
32 
33     // SSA values for the transformation are created out of thin air by
34     // unregistered "new_processor_id_and_range" operations. This is enough to
35     // emulate mapping conditions.
36     SmallVector<Value, 8> processorIds, numProcessors;
37     func.walk([&processorIds, &numProcessors](Operation *op) {
38       if (op->getName().getStringRef() != "new_processor_id_and_range")
39         return;
40       processorIds.push_back(op->getResult(0));
41       numProcessors.push_back(op->getResult(1));
42     });
43 
44     func.walk([&processorIds, &numProcessors](scf::ForOp op) {
45       // Ignore nested loops.
46       if (op->getParentRegion()->getParentOfType<scf::ForOp>())
47         return;
48       mapLoopToProcessorIds(op, processorIds, numProcessors);
49     });
50   }
51 };
52 } // namespace
53 
54 namespace mlir {
55 namespace test {
registerTestLoopMappingPass()56 void registerTestLoopMappingPass() {
57   PassRegistration<TestLoopMappingPass>(
58       "test-mapping-to-processing-elements",
59       "test mapping a single loop on a virtual processor grid");
60 }
61 } // namespace test
62 } // namespace mlir
63