1 //===- OptUtils.cpp - MLIR Execution Engine optimization pass utilities ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the utility functions to trigger LLVM optimizations from
10 // MLIR Execution Engine.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/ExecutionEngine/OptUtils.h"
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/Analysis/TargetTransformInfo.h"
18 #include "llvm/IR/LegacyPassManager.h"
19 #include "llvm/IR/LegacyPassNameParser.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/InitializePasses.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Support/Allocator.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Error.h"
26 #include "llvm/Support/StringSaver.h"
27 #include "llvm/Target/TargetMachine.h"
28 #include "llvm/Transforms/Coroutines.h"
29 #include "llvm/Transforms/IPO.h"
30 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
31 #include <climits>
32 #include <mutex>
33
34 // Run the module and function passes managed by the module manager.
runPasses(llvm::legacy::PassManager & modulePM,llvm::legacy::FunctionPassManager & funcPM,llvm::Module & m)35 static void runPasses(llvm::legacy::PassManager &modulePM,
36 llvm::legacy::FunctionPassManager &funcPM,
37 llvm::Module &m) {
38 funcPM.doInitialization();
39 for (auto &func : m) {
40 funcPM.run(func);
41 }
42 funcPM.doFinalization();
43 modulePM.run(m);
44 }
45
46 // Initialize basic LLVM transformation passes under lock.
initializeLLVMPasses()47 void mlir::initializeLLVMPasses() {
48 static std::mutex mutex;
49 std::lock_guard<std::mutex> lock(mutex);
50
51 auto ®istry = *llvm::PassRegistry::getPassRegistry();
52 llvm::initializeCore(registry);
53 llvm::initializeTransformUtils(registry);
54 llvm::initializeScalarOpts(registry);
55 llvm::initializeIPO(registry);
56 llvm::initializeInstCombine(registry);
57 llvm::initializeAggressiveInstCombine(registry);
58 llvm::initializeAnalysis(registry);
59 llvm::initializeVectorization(registry);
60 llvm::initializeCoroutines(registry);
61 }
62
63 // Populate pass managers according to the optimization and size levels.
64 // This behaves similarly to LLVM opt.
populatePassManagers(llvm::legacy::PassManager & modulePM,llvm::legacy::FunctionPassManager & funcPM,unsigned optLevel,unsigned sizeLevel,llvm::TargetMachine * targetMachine)65 static void populatePassManagers(llvm::legacy::PassManager &modulePM,
66 llvm::legacy::FunctionPassManager &funcPM,
67 unsigned optLevel, unsigned sizeLevel,
68 llvm::TargetMachine *targetMachine) {
69 llvm::PassManagerBuilder builder;
70 builder.OptLevel = optLevel;
71 builder.SizeLevel = sizeLevel;
72 builder.Inliner = llvm::createFunctionInliningPass(
73 optLevel, sizeLevel, /*DisableInlineHotCallSite=*/false);
74 builder.LoopVectorize = optLevel > 1 && sizeLevel < 2;
75 builder.SLPVectorize = optLevel > 1 && sizeLevel < 2;
76 builder.DisableUnrollLoops = (optLevel == 0);
77
78 // Add all coroutine passes to the builder.
79 addCoroutinePassesToExtensionPoints(builder);
80
81 if (targetMachine) {
82 // Add pass to initialize TTI for this specific target. Otherwise, TTI will
83 // be initialized to NoTTIImpl by default.
84 modulePM.add(createTargetTransformInfoWrapperPass(
85 targetMachine->getTargetIRAnalysis()));
86 funcPM.add(createTargetTransformInfoWrapperPass(
87 targetMachine->getTargetIRAnalysis()));
88 }
89
90 builder.populateModulePassManager(modulePM);
91 builder.populateFunctionPassManager(funcPM);
92 }
93
94 // Create and return a lambda that uses LLVM pass manager builder to set up
95 // optimizations based on the given level.
96 std::function<llvm::Error(llvm::Module *)>
makeOptimizingTransformer(unsigned optLevel,unsigned sizeLevel,llvm::TargetMachine * targetMachine)97 mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel,
98 llvm::TargetMachine *targetMachine) {
99 return [optLevel, sizeLevel, targetMachine](llvm::Module *m) -> llvm::Error {
100 llvm::legacy::PassManager modulePM;
101 llvm::legacy::FunctionPassManager funcPM(m);
102 populatePassManagers(modulePM, funcPM, optLevel, sizeLevel, targetMachine);
103 runPasses(modulePM, funcPM, *m);
104
105 return llvm::Error::success();
106 };
107 }
108
109 // Create and return a lambda that is given a set of passes to run, plus an
110 // optional optimization level to pre-populate the pass manager.
makeLLVMPassesTransformer(llvm::ArrayRef<const llvm::PassInfo * > llvmPasses,llvm::Optional<unsigned> mbOptLevel,llvm::TargetMachine * targetMachine,unsigned optPassesInsertPos)111 std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
112 llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
113 llvm::Optional<unsigned> mbOptLevel, llvm::TargetMachine *targetMachine,
114 unsigned optPassesInsertPos) {
115 return [llvmPasses, mbOptLevel, optPassesInsertPos,
116 targetMachine](llvm::Module *m) -> llvm::Error {
117 llvm::legacy::PassManager modulePM;
118 llvm::legacy::FunctionPassManager funcPM(m);
119
120 bool insertOptPasses = mbOptLevel.hasValue();
121 for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
122 const auto *passInfo = llvmPasses[i];
123 if (!passInfo->getNormalCtor())
124 continue;
125
126 if (insertOptPasses && optPassesInsertPos == i) {
127 populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
128 targetMachine);
129 insertOptPasses = false;
130 }
131
132 auto *pass = passInfo->createPass();
133 if (!pass)
134 return llvm::make_error<llvm::StringError>(
135 "could not create pass " + passInfo->getPassName(),
136 llvm::inconvertibleErrorCode());
137 modulePM.add(pass);
138 }
139
140 if (insertOptPasses)
141 populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
142 targetMachine);
143
144 runPasses(modulePM, funcPM, *m);
145 return llvm::Error::success();
146 };
147 }
148