1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert scf.parallel operations into OpenMP
10 // parallel loops.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Transforms/DialectConversion.h"
19
20 using namespace mlir;
21
22 namespace {
23
24 /// Converts SCF parallel operation into an OpenMP workshare loop construct.
25 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
26 using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
27
matchAndRewrite__anon47b919990111::ParallelOpLowering28 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
29 PatternRewriter &rewriter) const override {
30 // TODO: add support for reductions when OpenMP loops have them.
31 if (parallelOp.getNumResults() != 0)
32 return rewriter.notifyMatchFailure(
33 parallelOp,
34 "OpenMP dialect does not yet support loops with reductions");
35
36 // Replace SCF yield with OpenMP yield.
37 {
38 OpBuilder::InsertionGuard guard(rewriter);
39 rewriter.setInsertionPointToEnd(parallelOp.getBody());
40 assert(llvm::hasSingleElement(parallelOp.region()) &&
41 "expected scf.parallel to have one block");
42 rewriter.replaceOpWithNewOp<omp::YieldOp>(
43 parallelOp.getBody()->getTerminator(), ValueRange());
44 }
45
46 // Replace the loop.
47 auto loop = rewriter.create<omp::WsLoopOp>(
48 parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
49 parallelOp.step());
50 rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
51 loop.region().begin());
52 rewriter.eraseOp(parallelOp);
53 return success();
54 }
55 };
56
57 /// Inserts OpenMP "parallel" operations around top-level SCF "parallel"
58 /// operations in the given function. This is implemented as a direct IR
59 /// modification rather than as a conversion pattern because it does not
60 /// modify the top-level operation it matches, which is a requirement for
61 /// rewrite patterns.
62 //
63 // TODO: consider creating nested parallel operations when necessary.
insertOpenMPParallel(FuncOp func)64 static void insertOpenMPParallel(FuncOp func) {
65 // Collect top-level SCF "parallel" ops.
66 SmallVector<scf::ParallelOp, 4> topLevelParallelOps;
67 func.walk([&topLevelParallelOps](scf::ParallelOp parallelOp) {
68 // Ignore ops that are already within OpenMP parallel construct.
69 if (!parallelOp->getParentOfType<scf::ParallelOp>())
70 topLevelParallelOps.push_back(parallelOp);
71 });
72
73 // Wrap SCF ops into OpenMP "parallel" ops.
74 for (scf::ParallelOp parallelOp : topLevelParallelOps) {
75 OpBuilder builder(parallelOp);
76 auto omp = builder.create<omp::ParallelOp>(parallelOp.getLoc());
77 Block *block = builder.createBlock(&omp.getRegion());
78 builder.create<omp::TerminatorOp>(parallelOp.getLoc());
79 block->getOperations().splice(block->begin(),
80 parallelOp->getBlock()->getOperations(),
81 parallelOp.getOperation());
82 }
83 }
84
85 /// Applies the conversion patterns in the given function.
applyPatterns(FuncOp func)86 static LogicalResult applyPatterns(FuncOp func) {
87 ConversionTarget target(*func.getContext());
88 target.addIllegalOp<scf::ParallelOp>();
89 target.addDynamicallyLegalOp<scf::YieldOp>(
90 [](scf::YieldOp op) { return !isa<scf::ParallelOp>(op->getParentOp()); });
91 target.addLegalDialect<omp::OpenMPDialect>();
92
93 OwningRewritePatternList patterns;
94 patterns.insert<ParallelOpLowering>(func.getContext());
95 FrozenRewritePatternList frozen(std::move(patterns));
96 return applyPartialConversion(func, target, frozen);
97 }
98
99 /// A pass converting SCF operations to OpenMP operations.
100 struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
101 /// Pass entry point.
runOnFunction__anon47b919990111::SCFToOpenMPPass102 void runOnFunction() override {
103 insertOpenMPParallel(getFunction());
104 if (failed(applyPatterns(getFunction())))
105 signalPassFailure();
106 }
107 };
108
109 } // end namespace
110
createConvertSCFToOpenMPPass()111 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertSCFToOpenMPPass() {
112 return std::make_unique<SCFToOpenMPPass>();
113 }
114