1 //===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a utility that runs an optimization pass and prints the result back
10 // out. It is designed to support unit testing.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Support/MlirOptMain.h"
15 #include "mlir/IR/AsmState.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Parser.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Support/FileUtilities.h"
26 #include "mlir/Support/ToolUtilities.h"
27 #include "mlir/Transforms/Passes.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/FileUtilities.h"
30 #include "llvm/Support/InitLLVM.h"
31 #include "llvm/Support/Regex.h"
32 #include "llvm/Support/SourceMgr.h"
33 #include "llvm/Support/ToolOutputFile.h"
34 
35 using namespace mlir;
36 using namespace llvm;
37 using llvm::SMLoc;
38 
39 /// Perform the actions on the input file indicated by the command line flags
40 /// within the specified context.
41 ///
42 /// This typically parses the main source file, runs zero or more optimization
43 /// passes, then prints the output.
44 ///
performActions(raw_ostream & os,bool verifyDiagnostics,bool verifyPasses,SourceMgr & sourceMgr,MLIRContext * context,const PassPipelineCLParser & passPipeline)45 static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
46                                     bool verifyPasses, SourceMgr &sourceMgr,
47                                     MLIRContext *context,
48                                     const PassPipelineCLParser &passPipeline) {
49   // Disable multi-threading when parsing the input file. This removes the
50   // unnecessary/costly context synchronization when parsing.
51   bool wasThreadingEnabled = context->isMultithreadingEnabled();
52   context->disableMultithreading();
53 
54   // Parse the input file and reset the context threading state.
55   OwningModuleRef module(parseSourceFile(sourceMgr, context));
56   context->enableMultithreading(wasThreadingEnabled);
57   if (!module)
58     return failure();
59 
60   // Apply any pass manager command line options.
61   PassManager pm(context, OpPassManager::Nesting::Implicit);
62   pm.enableVerifier(verifyPasses);
63   applyPassManagerCLOptions(pm);
64 
65   auto errorHandler = [&](const Twine &msg) {
66     emitError(UnknownLoc::get(context)) << msg;
67     return failure();
68   };
69 
70   // Build the provided pipeline.
71   if (failed(passPipeline.addToPipeline(pm, errorHandler)))
72     return failure();
73 
74   // Run the pipeline.
75   if (failed(pm.run(*module)))
76     return failure();
77 
78   // Print the output.
79   module->print(os);
80   os << '\n';
81   return success();
82 }
83 
84 /// Parses the memory buffer.  If successfully, run a series of passes against
85 /// it and print the result.
processBuffer(raw_ostream & os,std::unique_ptr<MemoryBuffer> ownedBuffer,bool verifyDiagnostics,bool verifyPasses,bool allowUnregisteredDialects,bool preloadDialectsInContext,const PassPipelineCLParser & passPipeline,DialectRegistry & registry)86 static LogicalResult processBuffer(raw_ostream &os,
87                                    std::unique_ptr<MemoryBuffer> ownedBuffer,
88                                    bool verifyDiagnostics, bool verifyPasses,
89                                    bool allowUnregisteredDialects,
90                                    bool preloadDialectsInContext,
91                                    const PassPipelineCLParser &passPipeline,
92                                    DialectRegistry &registry) {
93   // Tell sourceMgr about this buffer, which is what the parser will pick up.
94   SourceMgr sourceMgr;
95   sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
96 
97   // Parse the input file.
98   MLIRContext context;
99   registry.appendTo(context.getDialectRegistry());
100   if (preloadDialectsInContext)
101     registry.loadAll(&context);
102   context.allowUnregisteredDialects(allowUnregisteredDialects);
103   context.printOpOnDiagnostic(!verifyDiagnostics);
104 
105   // If we are in verify diagnostics mode then we have a lot of work to do,
106   // otherwise just perform the actions without worrying about it.
107   if (!verifyDiagnostics) {
108     SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
109     return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
110                           &context, passPipeline);
111   }
112 
113   SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
114 
115   // Do any processing requested by command line flags.  We don't care whether
116   // these actions succeed or fail, we only care what diagnostics they produce
117   // and whether they match our expectations.
118   performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
119                  passPipeline);
120 
121   // Verify the diagnostic handler to make sure that each of the diagnostics
122   // matched.
123   return sourceMgrHandler.verify();
124 }
125 
MlirOptMain(raw_ostream & outputStream,std::unique_ptr<MemoryBuffer> buffer,const PassPipelineCLParser & passPipeline,DialectRegistry & registry,bool splitInputFile,bool verifyDiagnostics,bool verifyPasses,bool allowUnregisteredDialects,bool preloadDialectsInContext)126 LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
127                                 std::unique_ptr<MemoryBuffer> buffer,
128                                 const PassPipelineCLParser &passPipeline,
129                                 DialectRegistry &registry, bool splitInputFile,
130                                 bool verifyDiagnostics, bool verifyPasses,
131                                 bool allowUnregisteredDialects,
132                                 bool preloadDialectsInContext) {
133   // The split-input-file mode is a very specific mode that slices the file
134   // up into small pieces and checks each independently.
135   if (splitInputFile)
136     return splitAndProcessBuffer(
137         std::move(buffer),
138         [&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
139           return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
140                                verifyPasses, allowUnregisteredDialects,
141                                preloadDialectsInContext, passPipeline,
142                                registry);
143         },
144         outputStream);
145 
146   return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
147                        verifyPasses, allowUnregisteredDialects,
148                        preloadDialectsInContext, passPipeline, registry);
149 }
150 
MlirOptMain(int argc,char ** argv,llvm::StringRef toolName,DialectRegistry & registry,bool preloadDialectsInContext)151 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
152                                 DialectRegistry &registry,
153                                 bool preloadDialectsInContext) {
154   static cl::opt<std::string> inputFilename(
155       cl::Positional, cl::desc("<input file>"), cl::init("-"));
156 
157   static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
158                                              cl::value_desc("filename"),
159                                              cl::init("-"));
160 
161   static cl::opt<bool> splitInputFile(
162       "split-input-file",
163       cl::desc("Split the input file into pieces and process each "
164                "chunk independently"),
165       cl::init(false));
166 
167   static cl::opt<bool> verifyDiagnostics(
168       "verify-diagnostics",
169       cl::desc("Check that emitted diagnostics match "
170                "expected-* lines on the corresponding line"),
171       cl::init(false));
172 
173   static cl::opt<bool> verifyPasses(
174       "verify-each",
175       cl::desc("Run the verifier after each transformation pass"),
176       cl::init(true));
177 
178   static cl::opt<bool> allowUnregisteredDialects(
179       "allow-unregistered-dialect",
180       cl::desc("Allow operation with no registered dialects"), cl::init(false));
181 
182   static cl::opt<bool> showDialects(
183       "show-dialects", cl::desc("Print the list of registered dialects"),
184       cl::init(false));
185 
186   InitLLVM y(argc, argv);
187 
188   // Register any command line options.
189   registerAsmPrinterCLOptions();
190   registerMLIRContextCLOptions();
191   registerPassManagerCLOptions();
192   PassPipelineCLParser passPipeline("", "Compiler passes to run");
193 
194   // Build the list of dialects as a header for the --help message.
195   std::string helpHeader = (toolName + "\nAvailable Dialects: ").str();
196   {
197     llvm::raw_string_ostream os(helpHeader);
198     MLIRContext context;
199     interleaveComma(registry, os, [&](auto &registryEntry) {
200       StringRef name = registryEntry.first;
201       os << name;
202     });
203   }
204   // Parse pass names in main to ensure static initialization completed.
205   cl::ParseCommandLineOptions(argc, argv, helpHeader);
206 
207   if (showDialects) {
208     llvm::outs() << "Available Dialects:\n";
209     interleave(
210         registry, llvm::outs(),
211         [](auto &registryEntry) { llvm::outs() << registryEntry.first; }, "\n");
212     return success();
213   }
214 
215   // Set up the input file.
216   std::string errorMessage;
217   auto file = openInputFile(inputFilename, &errorMessage);
218   if (!file) {
219     llvm::errs() << errorMessage << "\n";
220     return failure();
221   }
222 
223   auto output = openOutputFile(outputFilename, &errorMessage);
224   if (!output) {
225     llvm::errs() << errorMessage << "\n";
226     return failure();
227   }
228 
229   if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
230                          splitInputFile, verifyDiagnostics, verifyPasses,
231                          allowUnregisteredDialects, preloadDialectsInContext)))
232     return failure();
233 
234   // Keep the output file if the invocation of MlirOptMain was successful.
235   output->keep();
236   return success();
237 }
238