1 //===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a utility that runs an optimization pass and prints the result back
10 // out. It is designed to support unit testing.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Support/MlirOptMain.h"
15 #include "mlir/IR/AsmState.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Parser.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Support/FileUtilities.h"
26 #include "mlir/Support/ToolUtilities.h"
27 #include "mlir/Transforms/Passes.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/FileUtilities.h"
30 #include "llvm/Support/InitLLVM.h"
31 #include "llvm/Support/Regex.h"
32 #include "llvm/Support/SourceMgr.h"
33 #include "llvm/Support/ToolOutputFile.h"
34
35 using namespace mlir;
36 using namespace llvm;
37 using llvm::SMLoc;
38
39 /// Perform the actions on the input file indicated by the command line flags
40 /// within the specified context.
41 ///
42 /// This typically parses the main source file, runs zero or more optimization
43 /// passes, then prints the output.
44 ///
performActions(raw_ostream & os,bool verifyDiagnostics,bool verifyPasses,SourceMgr & sourceMgr,MLIRContext * context,const PassPipelineCLParser & passPipeline)45 static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
46 bool verifyPasses, SourceMgr &sourceMgr,
47 MLIRContext *context,
48 const PassPipelineCLParser &passPipeline) {
49 // Disable multi-threading when parsing the input file. This removes the
50 // unnecessary/costly context synchronization when parsing.
51 bool wasThreadingEnabled = context->isMultithreadingEnabled();
52 context->disableMultithreading();
53
54 // Parse the input file and reset the context threading state.
55 OwningModuleRef module(parseSourceFile(sourceMgr, context));
56 context->enableMultithreading(wasThreadingEnabled);
57 if (!module)
58 return failure();
59
60 // Apply any pass manager command line options.
61 PassManager pm(context, OpPassManager::Nesting::Implicit);
62 pm.enableVerifier(verifyPasses);
63 applyPassManagerCLOptions(pm);
64
65 auto errorHandler = [&](const Twine &msg) {
66 emitError(UnknownLoc::get(context)) << msg;
67 return failure();
68 };
69
70 // Build the provided pipeline.
71 if (failed(passPipeline.addToPipeline(pm, errorHandler)))
72 return failure();
73
74 // Run the pipeline.
75 if (failed(pm.run(*module)))
76 return failure();
77
78 // Print the output.
79 module->print(os);
80 os << '\n';
81 return success();
82 }
83
84 /// Parses the memory buffer. If successfully, run a series of passes against
85 /// it and print the result.
processBuffer(raw_ostream & os,std::unique_ptr<MemoryBuffer> ownedBuffer,bool verifyDiagnostics,bool verifyPasses,bool allowUnregisteredDialects,bool preloadDialectsInContext,const PassPipelineCLParser & passPipeline,DialectRegistry & registry)86 static LogicalResult processBuffer(raw_ostream &os,
87 std::unique_ptr<MemoryBuffer> ownedBuffer,
88 bool verifyDiagnostics, bool verifyPasses,
89 bool allowUnregisteredDialects,
90 bool preloadDialectsInContext,
91 const PassPipelineCLParser &passPipeline,
92 DialectRegistry ®istry) {
93 // Tell sourceMgr about this buffer, which is what the parser will pick up.
94 SourceMgr sourceMgr;
95 sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
96
97 // Parse the input file.
98 MLIRContext context;
99 registry.appendTo(context.getDialectRegistry());
100 if (preloadDialectsInContext)
101 registry.loadAll(&context);
102 context.allowUnregisteredDialects(allowUnregisteredDialects);
103 context.printOpOnDiagnostic(!verifyDiagnostics);
104
105 // If we are in verify diagnostics mode then we have a lot of work to do,
106 // otherwise just perform the actions without worrying about it.
107 if (!verifyDiagnostics) {
108 SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
109 return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
110 &context, passPipeline);
111 }
112
113 SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
114
115 // Do any processing requested by command line flags. We don't care whether
116 // these actions succeed or fail, we only care what diagnostics they produce
117 // and whether they match our expectations.
118 performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
119 passPipeline);
120
121 // Verify the diagnostic handler to make sure that each of the diagnostics
122 // matched.
123 return sourceMgrHandler.verify();
124 }
125
MlirOptMain(raw_ostream & outputStream,std::unique_ptr<MemoryBuffer> buffer,const PassPipelineCLParser & passPipeline,DialectRegistry & registry,bool splitInputFile,bool verifyDiagnostics,bool verifyPasses,bool allowUnregisteredDialects,bool preloadDialectsInContext)126 LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
127 std::unique_ptr<MemoryBuffer> buffer,
128 const PassPipelineCLParser &passPipeline,
129 DialectRegistry ®istry, bool splitInputFile,
130 bool verifyDiagnostics, bool verifyPasses,
131 bool allowUnregisteredDialects,
132 bool preloadDialectsInContext) {
133 // The split-input-file mode is a very specific mode that slices the file
134 // up into small pieces and checks each independently.
135 if (splitInputFile)
136 return splitAndProcessBuffer(
137 std::move(buffer),
138 [&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
139 return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
140 verifyPasses, allowUnregisteredDialects,
141 preloadDialectsInContext, passPipeline,
142 registry);
143 },
144 outputStream);
145
146 return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
147 verifyPasses, allowUnregisteredDialects,
148 preloadDialectsInContext, passPipeline, registry);
149 }
150
MlirOptMain(int argc,char ** argv,llvm::StringRef toolName,DialectRegistry & registry,bool preloadDialectsInContext)151 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
152 DialectRegistry ®istry,
153 bool preloadDialectsInContext) {
154 static cl::opt<std::string> inputFilename(
155 cl::Positional, cl::desc("<input file>"), cl::init("-"));
156
157 static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
158 cl::value_desc("filename"),
159 cl::init("-"));
160
161 static cl::opt<bool> splitInputFile(
162 "split-input-file",
163 cl::desc("Split the input file into pieces and process each "
164 "chunk independently"),
165 cl::init(false));
166
167 static cl::opt<bool> verifyDiagnostics(
168 "verify-diagnostics",
169 cl::desc("Check that emitted diagnostics match "
170 "expected-* lines on the corresponding line"),
171 cl::init(false));
172
173 static cl::opt<bool> verifyPasses(
174 "verify-each",
175 cl::desc("Run the verifier after each transformation pass"),
176 cl::init(true));
177
178 static cl::opt<bool> allowUnregisteredDialects(
179 "allow-unregistered-dialect",
180 cl::desc("Allow operation with no registered dialects"), cl::init(false));
181
182 static cl::opt<bool> showDialects(
183 "show-dialects", cl::desc("Print the list of registered dialects"),
184 cl::init(false));
185
186 InitLLVM y(argc, argv);
187
188 // Register any command line options.
189 registerAsmPrinterCLOptions();
190 registerMLIRContextCLOptions();
191 registerPassManagerCLOptions();
192 PassPipelineCLParser passPipeline("", "Compiler passes to run");
193
194 // Build the list of dialects as a header for the --help message.
195 std::string helpHeader = (toolName + "\nAvailable Dialects: ").str();
196 {
197 llvm::raw_string_ostream os(helpHeader);
198 MLIRContext context;
199 interleaveComma(registry, os, [&](auto ®istryEntry) {
200 StringRef name = registryEntry.first;
201 os << name;
202 });
203 }
204 // Parse pass names in main to ensure static initialization completed.
205 cl::ParseCommandLineOptions(argc, argv, helpHeader);
206
207 if (showDialects) {
208 llvm::outs() << "Available Dialects:\n";
209 interleave(
210 registry, llvm::outs(),
211 [](auto ®istryEntry) { llvm::outs() << registryEntry.first; }, "\n");
212 return success();
213 }
214
215 // Set up the input file.
216 std::string errorMessage;
217 auto file = openInputFile(inputFilename, &errorMessage);
218 if (!file) {
219 llvm::errs() << errorMessage << "\n";
220 return failure();
221 }
222
223 auto output = openOutputFile(outputFilename, &errorMessage);
224 if (!output) {
225 llvm::errs() << errorMessage << "\n";
226 return failure();
227 }
228
229 if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
230 splitInputFile, verifyDiagnostics, verifyPasses,
231 allowUnregisteredDialects, preloadDialectsInContext)))
232 return failure();
233
234 // Keep the output file if the invocation of MlirOptMain was successful.
235 output->keep();
236 return success();
237 }
238