1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Pass/PassRegistry.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "llvm/ADT/DenseMap.h"
13 #include "llvm/Support/ManagedStatic.h"
14 #include "llvm/Support/MemoryBuffer.h"
15 #include "llvm/Support/SourceMgr.h"
16
17 using namespace mlir;
18 using namespace detail;
19
20 /// Static mapping of all of the registered passes.
21 static llvm::ManagedStatic<DenseMap<TypeID, PassInfo>> passRegistry;
22
23 /// Static mapping of all of the registered pass pipelines.
24 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
25 passPipelineRegistry;
26
27 /// Utility to create a default registry function from a pass instance.
28 static PassRegistryFunction
buildDefaultRegistryFn(const PassAllocatorFunction & allocator)29 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
30 return [=](OpPassManager &pm, StringRef options,
31 function_ref<LogicalResult(const Twine &)> errorHandler) {
32 std::unique_ptr<Pass> pass = allocator();
33 LogicalResult result = pass->initializeOptions(options);
34 if ((pm.getNesting() == OpPassManager::Nesting::Explicit) &&
35 pass->getOpName() && *pass->getOpName() != pm.getOpName())
36 return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
37 "' restricted to '" + *pass->getOpName() +
38 "' on a PassManager intended to run on '" +
39 pm.getOpName() + "', did you intend to nest?");
40 pm.addPass(std::move(pass));
41 return result;
42 };
43 }
44
45 /// Utility to print the help string for a specific option.
printOptionHelp(StringRef arg,StringRef desc,size_t indent,size_t descIndent,bool isTopLevel)46 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
47 size_t descIndent, bool isTopLevel) {
48 size_t numSpaces = descIndent - indent - 4;
49 llvm::outs().indent(indent)
50 << "--" << llvm::left_justify(arg, numSpaces) << "- " << desc << '\n';
51 }
52
53 //===----------------------------------------------------------------------===//
54 // PassRegistry
55 //===----------------------------------------------------------------------===//
56
57 /// Print the help information for this pass. This includes the argument,
58 /// description, and any pass options. `descIndent` is the indent that the
59 /// descriptions should be aligned.
printHelpStr(size_t indent,size_t descIndent) const60 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
61 printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
62 /*isTopLevel=*/true);
63 // If this entry has options, print the help for those as well.
64 optHandler([=](const PassOptions &options) {
65 options.printHelp(indent, descIndent);
66 });
67 }
68
69 /// Return the maximum width required when printing the options of this
70 /// entry.
getOptionWidth() const71 size_t PassRegistryEntry::getOptionWidth() const {
72 size_t maxLen = 0;
73 optHandler([&](const PassOptions &options) mutable {
74 maxLen = options.getOptionWidth() + 2;
75 });
76 return maxLen;
77 }
78
79 //===----------------------------------------------------------------------===//
80 // PassPipelineInfo
81 //===----------------------------------------------------------------------===//
82
registerPassPipeline(StringRef arg,StringRef description,const PassRegistryFunction & function,std::function<void (function_ref<void (const PassOptions &)>)> optHandler)83 void mlir::registerPassPipeline(
84 StringRef arg, StringRef description, const PassRegistryFunction &function,
85 std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
86 PassPipelineInfo pipelineInfo(arg, description, function, optHandler);
87 bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
88 assert(inserted && "Pass pipeline registered multiple times");
89 (void)inserted;
90 }
91
92 //===----------------------------------------------------------------------===//
93 // PassInfo
94 //===----------------------------------------------------------------------===//
95
PassInfo(StringRef arg,StringRef description,TypeID passID,const PassAllocatorFunction & allocator)96 PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
97 const PassAllocatorFunction &allocator)
98 : PassRegistryEntry(
99 arg, description, buildDefaultRegistryFn(allocator),
100 // Use a temporary pass to provide an options instance.
101 [=](function_ref<void(const PassOptions &)> optHandler) {
102 optHandler(allocator()->passOptions);
103 }) {}
104
registerPass(StringRef arg,StringRef description,const PassAllocatorFunction & function)105 void mlir::registerPass(StringRef arg, StringRef description,
106 const PassAllocatorFunction &function) {
107 // TODO: We should use the 'arg' as the lookup key instead of the pass id.
108 TypeID passID = function()->getTypeID();
109 PassInfo passInfo(arg, description, passID, function);
110 passRegistry->try_emplace(passID, passInfo);
111 }
112
113 /// Returns the pass info for the specified pass class or null if unknown.
lookupPassInfo(TypeID passID)114 const PassInfo *mlir::Pass::lookupPassInfo(TypeID passID) {
115 auto it = passRegistry->find(passID);
116 if (it == passRegistry->end())
117 return nullptr;
118 return &it->getSecond();
119 }
120
121 //===----------------------------------------------------------------------===//
122 // PassOptions
123 //===----------------------------------------------------------------------===//
124
125 /// Out of line virtual function to provide home for the class.
anchor()126 void detail::PassOptions::OptionBase::anchor() {}
127
128 /// Copy the option values from 'other'.
copyOptionValuesFrom(const PassOptions & other)129 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
130 assert(options.size() == other.options.size());
131 if (options.empty())
132 return;
133 for (auto optionsIt : llvm::zip(options, other.options))
134 std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
135 }
136
parseFromString(StringRef options)137 LogicalResult detail::PassOptions::parseFromString(StringRef options) {
138 // TODO: Handle escaping strings.
139 // NOTE: `options` is modified in place to always refer to the unprocessed
140 // part of the string.
141 while (!options.empty()) {
142 size_t spacePos = options.find(' ');
143 StringRef arg = options;
144 if (spacePos != StringRef::npos) {
145 arg = options.substr(0, spacePos);
146 options = options.substr(spacePos + 1);
147 } else {
148 options = StringRef();
149 }
150 if (arg.empty())
151 continue;
152
153 // At this point, arg refers to everything that is non-space in options
154 // upto the next space, and options refers to the rest of the string after
155 // that point.
156
157 // Split the individual option on '=' to form key and value. If there is no
158 // '=', then value is `StringRef()`.
159 size_t equalPos = arg.find('=');
160 StringRef key = arg;
161 StringRef value;
162 if (equalPos != StringRef::npos) {
163 key = arg.substr(0, equalPos);
164 value = arg.substr(equalPos + 1);
165 }
166 auto it = OptionsMap.find(key);
167 if (it == OptionsMap.end()) {
168 llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
169 return failure();
170 }
171 if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
172 return failure();
173 }
174
175 return success();
176 }
177
178 /// Print the options held by this struct in a form that can be parsed via
179 /// 'parseFromString'.
print(raw_ostream & os)180 void detail::PassOptions::print(raw_ostream &os) {
181 // If there are no options, there is nothing left to do.
182 if (OptionsMap.empty())
183 return;
184
185 // Sort the options to make the ordering deterministic.
186 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
187 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
188 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
189 };
190 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
191
192 // Interleave the options with ' '.
193 os << '{';
194 llvm::interleave(
195 orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
196 os << '}';
197 }
198
199 /// Print the help string for the options held by this struct. `descIndent` is
200 /// the indent within the stream that the descriptions should be aligned.
printHelp(size_t indent,size_t descIndent) const201 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
202 // Sort the options to make the ordering deterministic.
203 SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
204 auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
205 return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
206 };
207 llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
208 for (OptionBase *option : orderedOps) {
209 // TODO: printOptionInfo assumes a specific indent and will
210 // print options with values with incorrect indentation. We should add
211 // support to llvm::cl::Option for passing in a base indent to use when
212 // printing.
213 llvm::outs().indent(indent);
214 option->getOption()->printOptionInfo(descIndent - indent);
215 }
216 }
217
218 /// Return the maximum width required when printing the help string.
getOptionWidth() const219 size_t detail::PassOptions::getOptionWidth() const {
220 size_t max = 0;
221 for (auto *option : options)
222 max = std::max(max, option->getOption()->getOptionWidth());
223 return max;
224 }
225
226 //===----------------------------------------------------------------------===//
227 // TextualPassPipeline Parser
228 //===----------------------------------------------------------------------===//
229
230 namespace {
231 /// This class represents a textual description of a pass pipeline.
232 class TextualPipeline {
233 public:
234 /// Try to initialize this pipeline with the given pipeline text.
235 /// `errorStream` is the output stream to emit errors to.
236 LogicalResult initialize(StringRef text, raw_ostream &errorStream);
237
238 /// Add the internal pipeline elements to the provided pass manager.
239 LogicalResult
240 addToPipeline(OpPassManager &pm,
241 function_ref<LogicalResult(const Twine &)> errorHandler) const;
242
243 private:
244 /// A functor used to emit errors found during pipeline handling. The first
245 /// parameter corresponds to the raw location within the pipeline string. This
246 /// should always return failure.
247 using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
248
249 /// A struct to capture parsed pass pipeline names.
250 ///
251 /// A pipeline is defined as a series of names, each of which may in itself
252 /// recursively contain a nested pipeline. A name is either the name of a pass
253 /// (e.g. "cse") or the name of an operation type (e.g. "func"). If the name
254 /// is the name of a pass, the InnerPipeline is empty, since passes cannot
255 /// contain inner pipelines.
256 struct PipelineElement {
PipelineElement__anon2c89db4f0811::TextualPipeline::PipelineElement257 PipelineElement(StringRef name) : name(name), registryEntry(nullptr) {}
258
259 StringRef name;
260 StringRef options;
261 const PassRegistryEntry *registryEntry;
262 std::vector<PipelineElement> innerPipeline;
263 };
264
265 /// Parse the given pipeline text into the internal pipeline vector. This
266 /// function only parses the structure of the pipeline, and does not resolve
267 /// its elements.
268 LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
269
270 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
271 /// the corresponding registry entry.
272 LogicalResult
273 resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
274 ErrorHandlerT errorHandler);
275
276 /// Resolve a single element of the pipeline.
277 LogicalResult resolvePipelineElement(PipelineElement &element,
278 ErrorHandlerT errorHandler);
279
280 /// Add the given pipeline elements to the provided pass manager.
281 LogicalResult
282 addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
283 function_ref<LogicalResult(const Twine &)> errorHandler) const;
284
285 std::vector<PipelineElement> pipeline;
286 };
287
288 } // end anonymous namespace
289
290 /// Try to initialize this pipeline with the given pipeline text. An option is
291 /// given to enable accurate error reporting.
initialize(StringRef text,raw_ostream & errorStream)292 LogicalResult TextualPipeline::initialize(StringRef text,
293 raw_ostream &errorStream) {
294 // Build a source manager to use for error reporting.
295 llvm::SourceMgr pipelineMgr;
296 pipelineMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(
297 text, "MLIR Textual PassPipeline Parser"),
298 llvm::SMLoc());
299 auto errorHandler = [&](const char *rawLoc, Twine msg) {
300 pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc),
301 llvm::SourceMgr::DK_Error, msg);
302 return failure();
303 };
304
305 // Parse the provided pipeline string.
306 if (failed(parsePipelineText(text, errorHandler)))
307 return failure();
308 return resolvePipelineElements(pipeline, errorHandler);
309 }
310
311 /// Add the internal pipeline elements to the provided pass manager.
addToPipeline(OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const312 LogicalResult TextualPipeline::addToPipeline(
313 OpPassManager &pm,
314 function_ref<LogicalResult(const Twine &)> errorHandler) const {
315 return addToPipeline(pipeline, pm, errorHandler);
316 }
317
318 /// Parse the given pipeline text into the internal pipeline vector. This
319 /// function only parses the structure of the pipeline, and does not resolve
320 /// its elements.
parsePipelineText(StringRef text,ErrorHandlerT errorHandler)321 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
322 ErrorHandlerT errorHandler) {
323 SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
324 for (;;) {
325 std::vector<PipelineElement> &pipeline = *pipelineStack.back();
326 size_t pos = text.find_first_of(",(){");
327 pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
328
329 // If we have a single terminating name, we're done.
330 if (pos == text.npos)
331 break;
332
333 text = text.substr(pos);
334 char sep = text[0];
335
336 // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
337 if (sep == '{') {
338 text = text.substr(1);
339
340 // Skip over everything until the closing '}' and store as options.
341 size_t close = text.find('}');
342
343 // TODO: Handle skipping over quoted sub-strings.
344 if (close == StringRef::npos) {
345 return errorHandler(
346 /*rawLoc=*/text.data() - 1,
347 "missing closing '}' while processing pass options");
348 }
349 pipeline.back().options = text.substr(0, close);
350 text = text.substr(close + 1);
351
352 // Skip checking for '(' because nested pipelines cannot have options.
353 } else if (sep == '(') {
354 text = text.substr(1);
355
356 // Push the inner pipeline onto the stack to continue processing.
357 pipelineStack.push_back(&pipeline.back().innerPipeline);
358 continue;
359 }
360
361 // When handling the close parenthesis, we greedily consume them to avoid
362 // empty strings in the pipeline.
363 while (text.consume_front(")")) {
364 // If we try to pop the outer pipeline we have unbalanced parentheses.
365 if (pipelineStack.size() == 1)
366 return errorHandler(/*rawLoc=*/text.data() - 1,
367 "encountered extra closing ')' creating unbalanced "
368 "parentheses while parsing pipeline");
369
370 pipelineStack.pop_back();
371 }
372
373 // Check if we've finished parsing.
374 if (text.empty())
375 break;
376
377 // Otherwise, the end of an inner pipeline always has to be followed by
378 // a comma, and then we can continue.
379 if (!text.consume_front(","))
380 return errorHandler(text.data(), "expected ',' after parsing pipeline");
381 }
382
383 // Check for unbalanced parentheses.
384 if (pipelineStack.size() > 1)
385 return errorHandler(
386 text.data(),
387 "encountered unbalanced parentheses while parsing pipeline");
388
389 assert(pipelineStack.back() == &pipeline &&
390 "wrong pipeline at the bottom of the stack");
391 return success();
392 }
393
394 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
395 /// the corresponding registry entry.
resolvePipelineElements(MutableArrayRef<PipelineElement> elements,ErrorHandlerT errorHandler)396 LogicalResult TextualPipeline::resolvePipelineElements(
397 MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
398 for (auto &elt : elements)
399 if (failed(resolvePipelineElement(elt, errorHandler)))
400 return failure();
401 return success();
402 }
403
404 /// Resolve a single element of the pipeline.
405 LogicalResult
resolvePipelineElement(PipelineElement & element,ErrorHandlerT errorHandler)406 TextualPipeline::resolvePipelineElement(PipelineElement &element,
407 ErrorHandlerT errorHandler) {
408 // If the inner pipeline of this element is not empty, this is an operation
409 // pipeline.
410 if (!element.innerPipeline.empty())
411 return resolvePipelineElements(element.innerPipeline, errorHandler);
412 // Otherwise, this must be a pass or pass pipeline.
413 // Check to see if a pipeline was registered with this name.
414 auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
415 if (pipelineRegistryIt != passPipelineRegistry->end()) {
416 element.registryEntry = &pipelineRegistryIt->second;
417 return success();
418 }
419
420 // If not, then this must be a specific pass name.
421 for (auto &passIt : *passRegistry) {
422 if (passIt.second.getPassArgument() == element.name) {
423 element.registryEntry = &passIt.second;
424 return success();
425 }
426 }
427
428 // Emit an error for the unknown pass.
429 auto *rawLoc = element.name.data();
430 return errorHandler(rawLoc, "'" + element.name +
431 "' does not refer to a "
432 "registered pass or pass pipeline");
433 }
434
435 /// Add the given pipeline elements to the provided pass manager.
addToPipeline(ArrayRef<PipelineElement> elements,OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const436 LogicalResult TextualPipeline::addToPipeline(
437 ArrayRef<PipelineElement> elements, OpPassManager &pm,
438 function_ref<LogicalResult(const Twine &)> errorHandler) const {
439 for (auto &elt : elements) {
440 if (elt.registryEntry) {
441 if (failed(
442 elt.registryEntry->addToPipeline(pm, elt.options, errorHandler)))
443 return failure();
444 } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
445 errorHandler))) {
446 return failure();
447 }
448 }
449 return success();
450 }
451
452 /// This function parses the textual representation of a pass pipeline, and adds
453 /// the result to 'pm' on success. This function returns failure if the given
454 /// pipeline was invalid. 'errorStream' is an optional parameter that, if
455 /// non-null, will be used to emit errors found during parsing.
parsePassPipeline(StringRef pipeline,OpPassManager & pm,raw_ostream & errorStream)456 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
457 raw_ostream &errorStream) {
458 TextualPipeline pipelineParser;
459 if (failed(pipelineParser.initialize(pipeline, errorStream)))
460 return failure();
461 auto errorHandler = [&](Twine msg) {
462 errorStream << msg << "\n";
463 return failure();
464 };
465 if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
466 return failure();
467 return success();
468 }
469
470 //===----------------------------------------------------------------------===//
471 // PassNameParser
472 //===----------------------------------------------------------------------===//
473
474 namespace {
475 /// This struct represents the possible data entries in a parsed pass pipeline
476 /// list.
477 struct PassArgData {
PassArgData__anon2c89db4f0b11::PassArgData478 PassArgData() : registryEntry(nullptr) {}
PassArgData__anon2c89db4f0b11::PassArgData479 PassArgData(const PassRegistryEntry *registryEntry)
480 : registryEntry(registryEntry) {}
481
482 /// This field is used when the parsed option corresponds to a registered pass
483 /// or pass pipeline.
484 const PassRegistryEntry *registryEntry;
485
486 /// This field is set when instance specific pass options have been provided
487 /// on the command line.
488 StringRef options;
489
490 /// This field is used when the parsed option corresponds to an explicit
491 /// pipeline.
492 TextualPipeline pipeline;
493 };
494 } // end anonymous namespace
495
496 namespace llvm {
497 namespace cl {
498 /// Define a valid OptionValue for the command line pass argument.
499 template <>
500 struct OptionValue<PassArgData> final
501 : OptionValueBase<PassArgData, /*isClass=*/true> {
OptionValuellvm::cl::OptionValue502 OptionValue(const PassArgData &value) { this->setValue(value); }
503 OptionValue() = default;
anchorllvm::cl::OptionValue504 void anchor() override {}
505
hasValuellvm::cl::OptionValue506 bool hasValue() const { return true; }
getValuellvm::cl::OptionValue507 const PassArgData &getValue() const { return value; }
setValuellvm::cl::OptionValue508 void setValue(const PassArgData &value) { this->value = value; }
509
510 PassArgData value;
511 };
512 } // end namespace cl
513 } // end namespace llvm
514
515 namespace {
516
517 /// The name for the command line option used for parsing the textual pass
518 /// pipeline.
519 static constexpr StringLiteral passPipelineArg = "pass-pipeline";
520
521 /// Adds command line option for each registered pass or pass pipeline, as well
522 /// as textual pass pipelines.
523 struct PassNameParser : public llvm::cl::parser<PassArgData> {
PassNameParser__anon2c89db4f0c11::PassNameParser524 PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
525
526 void initialize();
527 void printOptionInfo(const llvm::cl::Option &opt,
528 size_t globalWidth) const override;
529 size_t getOptionWidth(const llvm::cl::Option &opt) const override;
530 bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
531 PassArgData &value);
532 };
533 } // namespace
534
initialize()535 void PassNameParser::initialize() {
536 llvm::cl::parser<PassArgData>::initialize();
537
538 /// Add an entry for the textual pass pipeline option.
539 addLiteralOption(passPipelineArg, PassArgData(),
540 "A textual description of a pass pipeline to run");
541
542 /// Add the pass entries.
543 for (const auto &kv : *passRegistry) {
544 addLiteralOption(kv.second.getPassArgument(), &kv.second,
545 kv.second.getPassDescription());
546 }
547 /// Add the pass pipeline entries.
548 for (const auto &kv : *passPipelineRegistry) {
549 addLiteralOption(kv.second.getPassArgument(), &kv.second,
550 kv.second.getPassDescription());
551 }
552 }
553
printOptionInfo(const llvm::cl::Option & opt,size_t globalWidth) const554 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
555 size_t globalWidth) const {
556 // Print the information for the top-level option.
557 if (opt.hasArgStr()) {
558 llvm::outs() << " --" << opt.ArgStr;
559 opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
560 } else {
561 llvm::outs() << " " << opt.HelpStr << '\n';
562 }
563
564 // Print the top-level pipeline argument.
565 printOptionHelp(passPipelineArg,
566 "A textual description of a pass pipeline to run",
567 /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());
568
569 // Functor used to print the ordered entries of a registration map.
570 auto printOrderedEntries = [&](StringRef header, auto &map) {
571 llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
572 for (auto &kv : map)
573 orderedEntries.push_back(&kv.second);
574 llvm::array_pod_sort(
575 orderedEntries.begin(), orderedEntries.end(),
576 [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
577 return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
578 });
579
580 llvm::outs().indent(4) << header << ":\n";
581 for (PassRegistryEntry *entry : orderedEntries)
582 entry->printHelpStr(/*indent=*/6, globalWidth);
583 };
584
585 // Print the available passes.
586 printOrderedEntries("Passes", *passRegistry);
587
588 // Print the available pass pipelines.
589 if (!passPipelineRegistry->empty())
590 printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
591 }
592
getOptionWidth(const llvm::cl::Option & opt) const593 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
594 size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
595
596 // Check for any wider pass or pipeline options.
597 for (auto &entry : *passRegistry)
598 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
599 for (auto &entry : *passPipelineRegistry)
600 maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
601 return maxWidth;
602 }
603
parse(llvm::cl::Option & opt,StringRef argName,StringRef arg,PassArgData & value)604 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
605 StringRef arg, PassArgData &value) {
606 // Handle the pipeline option explicitly.
607 if (argName == passPipelineArg)
608 return failed(value.pipeline.initialize(arg, llvm::errs()));
609
610 // Otherwise, default to the base for handling.
611 if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
612 return true;
613 value.options = arg;
614 return false;
615 }
616
617 //===----------------------------------------------------------------------===//
618 // PassPipelineCLParser
619 //===----------------------------------------------------------------------===//
620
621 namespace mlir {
622 namespace detail {
623 struct PassPipelineCLParserImpl {
PassPipelineCLParserImplmlir::detail::PassPipelineCLParserImpl624 PassPipelineCLParserImpl(StringRef arg, StringRef description)
625 : passList(arg, llvm::cl::desc(description)) {
626 passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
627 }
628
629 /// The set of passes and pass pipelines to run.
630 llvm::cl::list<PassArgData, bool, PassNameParser> passList;
631 };
632 } // end namespace detail
633 } // end namespace mlir
634
635 /// Construct a pass pipeline parser with the given command line description.
PassPipelineCLParser(StringRef arg,StringRef description)636 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
637 : impl(std::make_unique<detail::PassPipelineCLParserImpl>(arg,
638 description)) {}
~PassPipelineCLParser()639 PassPipelineCLParser::~PassPipelineCLParser() {}
640
641 /// Returns true if this parser contains any valid options to add.
hasAnyOccurrences() const642 bool PassPipelineCLParser::hasAnyOccurrences() const {
643 return impl->passList.getNumOccurrences() != 0;
644 }
645
646 /// Returns true if the given pass registry entry was registered at the
647 /// top-level of the parser, i.e. not within an explicit textual pipeline.
contains(const PassRegistryEntry * entry) const648 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
649 return llvm::any_of(impl->passList, [&](const PassArgData &data) {
650 return data.registryEntry == entry;
651 });
652 }
653
654 /// Adds the passes defined by this parser entry to the given pass manager.
addToPipeline(OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const655 LogicalResult PassPipelineCLParser::addToPipeline(
656 OpPassManager &pm,
657 function_ref<LogicalResult(const Twine &)> errorHandler) const {
658 for (auto &passIt : impl->passList) {
659 if (passIt.registryEntry) {
660 if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
661 errorHandler)))
662 return failure();
663 } else {
664 OpPassManager::Nesting nesting = pm.getNesting();
665 pm.setNesting(OpPassManager::Nesting::Explicit);
666 LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler);
667 pm.setNesting(nesting);
668 if (failed(status))
669 return failure();
670 }
671 }
672 return success();
673 }
674