1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert scf.parallel operations into OpenMP
10 // parallel loops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 
24 /// Converts SCF parallel operation into an OpenMP workshare loop construct.
25 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
26   using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
27 
matchAndRewrite__anon47b919990111::ParallelOpLowering28   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
29                                 PatternRewriter &rewriter) const override {
30     // TODO: add support for reductions when OpenMP loops have them.
31     if (parallelOp.getNumResults() != 0)
32       return rewriter.notifyMatchFailure(
33           parallelOp,
34           "OpenMP dialect does not yet support loops with reductions");
35 
36     // Replace SCF yield with OpenMP yield.
37     {
38       OpBuilder::InsertionGuard guard(rewriter);
39       rewriter.setInsertionPointToEnd(parallelOp.getBody());
40       assert(llvm::hasSingleElement(parallelOp.region()) &&
41              "expected scf.parallel to have one block");
42       rewriter.replaceOpWithNewOp<omp::YieldOp>(
43           parallelOp.getBody()->getTerminator(), ValueRange());
44     }
45 
46     // Replace the loop.
47     auto loop = rewriter.create<omp::WsLoopOp>(
48         parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
49         parallelOp.step());
50     rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
51                                 loop.region().begin());
52     rewriter.eraseOp(parallelOp);
53     return success();
54   }
55 };
56 
57 /// Inserts OpenMP "parallel" operations around top-level SCF "parallel"
58 /// operations in the given function. This is implemented as a direct IR
59 /// modification rather than as a conversion pattern because it does not
60 /// modify the top-level operation it matches, which is a requirement for
61 /// rewrite patterns.
62 //
63 // TODO: consider creating nested parallel operations when necessary.
insertOpenMPParallel(FuncOp func)64 static void insertOpenMPParallel(FuncOp func) {
65   // Collect top-level SCF "parallel" ops.
66   SmallVector<scf::ParallelOp, 4> topLevelParallelOps;
67   func.walk([&topLevelParallelOps](scf::ParallelOp parallelOp) {
68     // Ignore ops that are already within OpenMP parallel construct.
69     if (!parallelOp->getParentOfType<scf::ParallelOp>())
70       topLevelParallelOps.push_back(parallelOp);
71   });
72 
73   // Wrap SCF ops into OpenMP "parallel" ops.
74   for (scf::ParallelOp parallelOp : topLevelParallelOps) {
75     OpBuilder builder(parallelOp);
76     auto omp = builder.create<omp::ParallelOp>(parallelOp.getLoc());
77     Block *block = builder.createBlock(&omp.getRegion());
78     builder.create<omp::TerminatorOp>(parallelOp.getLoc());
79     block->getOperations().splice(block->begin(),
80                                   parallelOp->getBlock()->getOperations(),
81                                   parallelOp.getOperation());
82   }
83 }
84 
85 /// Applies the conversion patterns in the given function.
applyPatterns(FuncOp func)86 static LogicalResult applyPatterns(FuncOp func) {
87   ConversionTarget target(*func.getContext());
88   target.addIllegalOp<scf::ParallelOp>();
89   target.addDynamicallyLegalOp<scf::YieldOp>(
90       [](scf::YieldOp op) { return !isa<scf::ParallelOp>(op->getParentOp()); });
91   target.addLegalDialect<omp::OpenMPDialect>();
92 
93   OwningRewritePatternList patterns;
94   patterns.insert<ParallelOpLowering>(func.getContext());
95   FrozenRewritePatternList frozen(std::move(patterns));
96   return applyPartialConversion(func, target, frozen);
97 }
98 
99 /// A pass converting SCF operations to OpenMP operations.
100 struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
101   /// Pass entry point.
runOnFunction__anon47b919990111::SCFToOpenMPPass102   void runOnFunction() override {
103     insertOpenMPParallel(getFunction());
104     if (failed(applyPatterns(getFunction())))
105       signalPassFailure();
106   }
107 };
108 
109 } // end namespace
110 
createConvertSCFToOpenMPPass()111 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertSCFToOpenMPPass() {
112   return std::make_unique<SCFToOpenMPPass>();
113 }
114