1 //===- PassRegistry.cpp - Pass Registration Utilities ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Pass/PassRegistry.h"
10 #include "mlir/Pass/Pass.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "llvm/ADT/DenseMap.h"
13 #include "llvm/Support/ManagedStatic.h"
14 #include "llvm/Support/MemoryBuffer.h"
15 #include "llvm/Support/SourceMgr.h"
16 
17 using namespace mlir;
18 using namespace detail;
19 
/// Static mapping of all of the registered passes, keyed by the pass TypeID.
/// Entries are added by `registerPass`. Being a ManagedStatic, the map is
/// constructed lazily on first access.
static llvm::ManagedStatic<DenseMap<TypeID, PassInfo>> passRegistry;

/// Static mapping of all of the registered pass pipelines, keyed by the
/// pipeline's command line argument. Entries are added by
/// `registerPassPipeline`.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
    passPipelineRegistry;
26 
27 /// Utility to create a default registry function from a pass instance.
28 static PassRegistryFunction
buildDefaultRegistryFn(const PassAllocatorFunction & allocator)29 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
30   return [=](OpPassManager &pm, StringRef options,
31              function_ref<LogicalResult(const Twine &)> errorHandler) {
32     std::unique_ptr<Pass> pass = allocator();
33     LogicalResult result = pass->initializeOptions(options);
34     if ((pm.getNesting() == OpPassManager::Nesting::Explicit) &&
35         pass->getOpName() && *pass->getOpName() != pm.getOpName())
36       return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
37                           "' restricted to '" + *pass->getOpName() +
38                           "' on a PassManager intended to run on '" +
39                           pm.getOpName() + "', did you intend to nest?");
40     pm.addPass(std::move(pass));
41     return result;
42   };
43 }
44 
45 /// Utility to print the help string for a specific option.
printOptionHelp(StringRef arg,StringRef desc,size_t indent,size_t descIndent,bool isTopLevel)46 static void printOptionHelp(StringRef arg, StringRef desc, size_t indent,
47                             size_t descIndent, bool isTopLevel) {
48   size_t numSpaces = descIndent - indent - 4;
49   llvm::outs().indent(indent)
50       << "--" << llvm::left_justify(arg, numSpaces) << "-   " << desc << '\n';
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // PassRegistry
55 //===----------------------------------------------------------------------===//
56 
57 /// Print the help information for this pass. This includes the argument,
58 /// description, and any pass options. `descIndent` is the indent that the
59 /// descriptions should be aligned.
printHelpStr(size_t indent,size_t descIndent) const60 void PassRegistryEntry::printHelpStr(size_t indent, size_t descIndent) const {
61   printOptionHelp(getPassArgument(), getPassDescription(), indent, descIndent,
62                   /*isTopLevel=*/true);
63   // If this entry has options, print the help for those as well.
64   optHandler([=](const PassOptions &options) {
65     options.printHelp(indent, descIndent);
66   });
67 }
68 
69 /// Return the maximum width required when printing the options of this
70 /// entry.
getOptionWidth() const71 size_t PassRegistryEntry::getOptionWidth() const {
72   size_t maxLen = 0;
73   optHandler([&](const PassOptions &options) mutable {
74     maxLen = options.getOptionWidth() + 2;
75   });
76   return maxLen;
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // PassPipelineInfo
81 //===----------------------------------------------------------------------===//
82 
registerPassPipeline(StringRef arg,StringRef description,const PassRegistryFunction & function,std::function<void (function_ref<void (const PassOptions &)>)> optHandler)83 void mlir::registerPassPipeline(
84     StringRef arg, StringRef description, const PassRegistryFunction &function,
85     std::function<void(function_ref<void(const PassOptions &)>)> optHandler) {
86   PassPipelineInfo pipelineInfo(arg, description, function, optHandler);
87   bool inserted = passPipelineRegistry->try_emplace(arg, pipelineInfo).second;
88   assert(inserted && "Pass pipeline registered multiple times");
89   (void)inserted;
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // PassInfo
94 //===----------------------------------------------------------------------===//
95 
PassInfo(StringRef arg,StringRef description,TypeID passID,const PassAllocatorFunction & allocator)96 PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
97                    const PassAllocatorFunction &allocator)
98     : PassRegistryEntry(
99           arg, description, buildDefaultRegistryFn(allocator),
100           // Use a temporary pass to provide an options instance.
101           [=](function_ref<void(const PassOptions &)> optHandler) {
102             optHandler(allocator()->passOptions);
103           }) {}
104 
registerPass(StringRef arg,StringRef description,const PassAllocatorFunction & function)105 void mlir::registerPass(StringRef arg, StringRef description,
106                         const PassAllocatorFunction &function) {
107   // TODO: We should use the 'arg' as the lookup key instead of the pass id.
108   TypeID passID = function()->getTypeID();
109   PassInfo passInfo(arg, description, passID, function);
110   passRegistry->try_emplace(passID, passInfo);
111 }
112 
113 /// Returns the pass info for the specified pass class or null if unknown.
lookupPassInfo(TypeID passID)114 const PassInfo *mlir::Pass::lookupPassInfo(TypeID passID) {
115   auto it = passRegistry->find(passID);
116   if (it == passRegistry->end())
117     return nullptr;
118   return &it->getSecond();
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // PassOptions
123 //===----------------------------------------------------------------------===//
124 
/// Out of line virtual function to provide home for the class.
/// Intentionally empty: defining a virtual member in this file gives the
/// class a single translation unit to live in (the usual LLVM "anchor" idiom).
void detail::PassOptions::OptionBase::anchor() {}
127 
128 /// Copy the option values from 'other'.
copyOptionValuesFrom(const PassOptions & other)129 void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
130   assert(options.size() == other.options.size());
131   if (options.empty())
132     return;
133   for (auto optionsIt : llvm::zip(options, other.options))
134     std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
135 }
136 
parseFromString(StringRef options)137 LogicalResult detail::PassOptions::parseFromString(StringRef options) {
138   // TODO: Handle escaping strings.
139   // NOTE: `options` is modified in place to always refer to the unprocessed
140   // part of the string.
141   while (!options.empty()) {
142     size_t spacePos = options.find(' ');
143     StringRef arg = options;
144     if (spacePos != StringRef::npos) {
145       arg = options.substr(0, spacePos);
146       options = options.substr(spacePos + 1);
147     } else {
148       options = StringRef();
149     }
150     if (arg.empty())
151       continue;
152 
153     // At this point, arg refers to everything that is non-space in options
154     // upto the next space, and options refers to the rest of the string after
155     // that point.
156 
157     // Split the individual option on '=' to form key and value. If there is no
158     // '=', then value is `StringRef()`.
159     size_t equalPos = arg.find('=');
160     StringRef key = arg;
161     StringRef value;
162     if (equalPos != StringRef::npos) {
163       key = arg.substr(0, equalPos);
164       value = arg.substr(equalPos + 1);
165     }
166     auto it = OptionsMap.find(key);
167     if (it == OptionsMap.end()) {
168       llvm::errs() << "<Pass-Options-Parser>: no such option " << key << "\n";
169       return failure();
170     }
171     if (llvm::cl::ProvidePositionalOption(it->second, value, 0))
172       return failure();
173   }
174 
175   return success();
176 }
177 
178 /// Print the options held by this struct in a form that can be parsed via
179 /// 'parseFromString'.
print(raw_ostream & os)180 void detail::PassOptions::print(raw_ostream &os) {
181   // If there are no options, there is nothing left to do.
182   if (OptionsMap.empty())
183     return;
184 
185   // Sort the options to make the ordering deterministic.
186   SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
187   auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
188     return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
189   };
190   llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
191 
192   // Interleave the options with ' '.
193   os << '{';
194   llvm::interleave(
195       orderedOps, os, [&](OptionBase *option) { option->print(os); }, " ");
196   os << '}';
197 }
198 
199 /// Print the help string for the options held by this struct. `descIndent` is
200 /// the indent within the stream that the descriptions should be aligned.
printHelp(size_t indent,size_t descIndent) const201 void detail::PassOptions::printHelp(size_t indent, size_t descIndent) const {
202   // Sort the options to make the ordering deterministic.
203   SmallVector<OptionBase *, 4> orderedOps(options.begin(), options.end());
204   auto compareOptionArgs = [](OptionBase *const *lhs, OptionBase *const *rhs) {
205     return (*lhs)->getArgStr().compare((*rhs)->getArgStr());
206   };
207   llvm::array_pod_sort(orderedOps.begin(), orderedOps.end(), compareOptionArgs);
208   for (OptionBase *option : orderedOps) {
209     // TODO: printOptionInfo assumes a specific indent and will
210     // print options with values with incorrect indentation. We should add
211     // support to llvm::cl::Option for passing in a base indent to use when
212     // printing.
213     llvm::outs().indent(indent);
214     option->getOption()->printOptionInfo(descIndent - indent);
215   }
216 }
217 
218 /// Return the maximum width required when printing the help string.
getOptionWidth() const219 size_t detail::PassOptions::getOptionWidth() const {
220   size_t max = 0;
221   for (auto *option : options)
222     max = std::max(max, option->getOption()->getOptionWidth());
223   return max;
224 }
225 
226 //===----------------------------------------------------------------------===//
227 // TextualPassPipeline Parser
228 //===----------------------------------------------------------------------===//
229 
230 namespace {
231 /// This class represents a textual description of a pass pipeline.
232 class TextualPipeline {
233 public:
234   /// Try to initialize this pipeline with the given pipeline text.
235   /// `errorStream` is the output stream to emit errors to.
236   LogicalResult initialize(StringRef text, raw_ostream &errorStream);
237 
238   /// Add the internal pipeline elements to the provided pass manager.
239   LogicalResult
240   addToPipeline(OpPassManager &pm,
241                 function_ref<LogicalResult(const Twine &)> errorHandler) const;
242 
243 private:
244   /// A functor used to emit errors found during pipeline handling. The first
245   /// parameter corresponds to the raw location within the pipeline string. This
246   /// should always return failure.
247   using ErrorHandlerT = function_ref<LogicalResult(const char *, Twine)>;
248 
249   /// A struct to capture parsed pass pipeline names.
250   ///
251   /// A pipeline is defined as a series of names, each of which may in itself
252   /// recursively contain a nested pipeline. A name is either the name of a pass
253   /// (e.g. "cse") or the name of an operation type (e.g. "func"). If the name
254   /// is the name of a pass, the InnerPipeline is empty, since passes cannot
255   /// contain inner pipelines.
256   struct PipelineElement {
PipelineElement__anon2c89db4f0811::TextualPipeline::PipelineElement257     PipelineElement(StringRef name) : name(name), registryEntry(nullptr) {}
258 
259     StringRef name;
260     StringRef options;
261     const PassRegistryEntry *registryEntry;
262     std::vector<PipelineElement> innerPipeline;
263   };
264 
265   /// Parse the given pipeline text into the internal pipeline vector. This
266   /// function only parses the structure of the pipeline, and does not resolve
267   /// its elements.
268   LogicalResult parsePipelineText(StringRef text, ErrorHandlerT errorHandler);
269 
270   /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
271   /// the corresponding registry entry.
272   LogicalResult
273   resolvePipelineElements(MutableArrayRef<PipelineElement> elements,
274                           ErrorHandlerT errorHandler);
275 
276   /// Resolve a single element of the pipeline.
277   LogicalResult resolvePipelineElement(PipelineElement &element,
278                                        ErrorHandlerT errorHandler);
279 
280   /// Add the given pipeline elements to the provided pass manager.
281   LogicalResult
282   addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
283                 function_ref<LogicalResult(const Twine &)> errorHandler) const;
284 
285   std::vector<PipelineElement> pipeline;
286 };
287 
288 } // end anonymous namespace
289 
290 /// Try to initialize this pipeline with the given pipeline text. An option is
291 /// given to enable accurate error reporting.
initialize(StringRef text,raw_ostream & errorStream)292 LogicalResult TextualPipeline::initialize(StringRef text,
293                                           raw_ostream &errorStream) {
294   // Build a source manager to use for error reporting.
295   llvm::SourceMgr pipelineMgr;
296   pipelineMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(
297                                      text, "MLIR Textual PassPipeline Parser"),
298                                  llvm::SMLoc());
299   auto errorHandler = [&](const char *rawLoc, Twine msg) {
300     pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc),
301                              llvm::SourceMgr::DK_Error, msg);
302     return failure();
303   };
304 
305   // Parse the provided pipeline string.
306   if (failed(parsePipelineText(text, errorHandler)))
307     return failure();
308   return resolvePipelineElements(pipeline, errorHandler);
309 }
310 
/// Add the internal pipeline elements to the provided pass manager.
/// Forwards to the recursive overload, starting from the top-level `pipeline`.
/// Errors are reported through `errorHandler`.
LogicalResult TextualPipeline::addToPipeline(
    OpPassManager &pm,
    function_ref<LogicalResult(const Twine &)> errorHandler) const {
  return addToPipeline(pipeline, pm, errorHandler);
}
317 
318 /// Parse the given pipeline text into the internal pipeline vector. This
319 /// function only parses the structure of the pipeline, and does not resolve
320 /// its elements.
parsePipelineText(StringRef text,ErrorHandlerT errorHandler)321 LogicalResult TextualPipeline::parsePipelineText(StringRef text,
322                                                  ErrorHandlerT errorHandler) {
323   SmallVector<std::vector<PipelineElement> *, 4> pipelineStack = {&pipeline};
324   for (;;) {
325     std::vector<PipelineElement> &pipeline = *pipelineStack.back();
326     size_t pos = text.find_first_of(",(){");
327     pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
328 
329     // If we have a single terminating name, we're done.
330     if (pos == text.npos)
331       break;
332 
333     text = text.substr(pos);
334     char sep = text[0];
335 
336     // Handle pulling ... from 'pass{...}' out as PipelineElement.options.
337     if (sep == '{') {
338       text = text.substr(1);
339 
340       // Skip over everything until the closing '}' and store as options.
341       size_t close = text.find('}');
342 
343       // TODO: Handle skipping over quoted sub-strings.
344       if (close == StringRef::npos) {
345         return errorHandler(
346             /*rawLoc=*/text.data() - 1,
347             "missing closing '}' while processing pass options");
348       }
349       pipeline.back().options = text.substr(0, close);
350       text = text.substr(close + 1);
351 
352       // Skip checking for '(' because nested pipelines cannot have options.
353     } else if (sep == '(') {
354       text = text.substr(1);
355 
356       // Push the inner pipeline onto the stack to continue processing.
357       pipelineStack.push_back(&pipeline.back().innerPipeline);
358       continue;
359     }
360 
361     // When handling the close parenthesis, we greedily consume them to avoid
362     // empty strings in the pipeline.
363     while (text.consume_front(")")) {
364       // If we try to pop the outer pipeline we have unbalanced parentheses.
365       if (pipelineStack.size() == 1)
366         return errorHandler(/*rawLoc=*/text.data() - 1,
367                             "encountered extra closing ')' creating unbalanced "
368                             "parentheses while parsing pipeline");
369 
370       pipelineStack.pop_back();
371     }
372 
373     // Check if we've finished parsing.
374     if (text.empty())
375       break;
376 
377     // Otherwise, the end of an inner pipeline always has to be followed by
378     // a comma, and then we can continue.
379     if (!text.consume_front(","))
380       return errorHandler(text.data(), "expected ',' after parsing pipeline");
381   }
382 
383   // Check for unbalanced parentheses.
384   if (pipelineStack.size() > 1)
385     return errorHandler(
386         text.data(),
387         "encountered unbalanced parentheses while parsing pipeline");
388 
389   assert(pipelineStack.back() == &pipeline &&
390          "wrong pipeline at the bottom of the stack");
391   return success();
392 }
393 
394 /// Resolve the elements of the pipeline, i.e. connect passes and pipelines to
395 /// the corresponding registry entry.
resolvePipelineElements(MutableArrayRef<PipelineElement> elements,ErrorHandlerT errorHandler)396 LogicalResult TextualPipeline::resolvePipelineElements(
397     MutableArrayRef<PipelineElement> elements, ErrorHandlerT errorHandler) {
398   for (auto &elt : elements)
399     if (failed(resolvePipelineElement(elt, errorHandler)))
400       return failure();
401   return success();
402 }
403 
404 /// Resolve a single element of the pipeline.
405 LogicalResult
resolvePipelineElement(PipelineElement & element,ErrorHandlerT errorHandler)406 TextualPipeline::resolvePipelineElement(PipelineElement &element,
407                                         ErrorHandlerT errorHandler) {
408   // If the inner pipeline of this element is not empty, this is an operation
409   // pipeline.
410   if (!element.innerPipeline.empty())
411     return resolvePipelineElements(element.innerPipeline, errorHandler);
412   // Otherwise, this must be a pass or pass pipeline.
413   // Check to see if a pipeline was registered with this name.
414   auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
415   if (pipelineRegistryIt != passPipelineRegistry->end()) {
416     element.registryEntry = &pipelineRegistryIt->second;
417     return success();
418   }
419 
420   // If not, then this must be a specific pass name.
421   for (auto &passIt : *passRegistry) {
422     if (passIt.second.getPassArgument() == element.name) {
423       element.registryEntry = &passIt.second;
424       return success();
425     }
426   }
427 
428   // Emit an error for the unknown pass.
429   auto *rawLoc = element.name.data();
430   return errorHandler(rawLoc, "'" + element.name +
431                                   "' does not refer to a "
432                                   "registered pass or pass pipeline");
433 }
434 
435 /// Add the given pipeline elements to the provided pass manager.
addToPipeline(ArrayRef<PipelineElement> elements,OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const436 LogicalResult TextualPipeline::addToPipeline(
437     ArrayRef<PipelineElement> elements, OpPassManager &pm,
438     function_ref<LogicalResult(const Twine &)> errorHandler) const {
439   for (auto &elt : elements) {
440     if (elt.registryEntry) {
441       if (failed(
442               elt.registryEntry->addToPipeline(pm, elt.options, errorHandler)))
443         return failure();
444     } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
445                                     errorHandler))) {
446       return failure();
447     }
448   }
449   return success();
450 }
451 
452 /// This function parses the textual representation of a pass pipeline, and adds
453 /// the result to 'pm' on success. This function returns failure if the given
454 /// pipeline was invalid. 'errorStream' is an optional parameter that, if
455 /// non-null, will be used to emit errors found during parsing.
parsePassPipeline(StringRef pipeline,OpPassManager & pm,raw_ostream & errorStream)456 LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
457                                       raw_ostream &errorStream) {
458   TextualPipeline pipelineParser;
459   if (failed(pipelineParser.initialize(pipeline, errorStream)))
460     return failure();
461   auto errorHandler = [&](Twine msg) {
462     errorStream << msg << "\n";
463     return failure();
464   };
465   if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
466     return failure();
467   return success();
468 }
469 
470 //===----------------------------------------------------------------------===//
471 // PassNameParser
472 //===----------------------------------------------------------------------===//
473 
474 namespace {
475 /// This struct represents the possible data entries in a parsed pass pipeline
476 /// list.
477 struct PassArgData {
PassArgData__anon2c89db4f0b11::PassArgData478   PassArgData() : registryEntry(nullptr) {}
PassArgData__anon2c89db4f0b11::PassArgData479   PassArgData(const PassRegistryEntry *registryEntry)
480       : registryEntry(registryEntry) {}
481 
482   /// This field is used when the parsed option corresponds to a registered pass
483   /// or pass pipeline.
484   const PassRegistryEntry *registryEntry;
485 
486   /// This field is set when instance specific pass options have been provided
487   /// on the command line.
488   StringRef options;
489 
490   /// This field is used when the parsed option corresponds to an explicit
491   /// pipeline.
492   TextualPipeline pipeline;
493 };
494 } // end anonymous namespace
495 
namespace llvm {
namespace cl {
/// Define a valid OptionValue for the command line pass argument.
/// This specialization lets PassArgData be the value type of a command line
/// option. It simply stores a PassArgData and always reports that a value is
/// present.
template <>
struct OptionValue<PassArgData> final
    : OptionValueBase<PassArgData, /*isClass=*/true> {
  OptionValue(const PassArgData &value) { this->setValue(value); }
  OptionValue() = default;
  void anchor() override {}

  /// A PassArgData value is always considered present.
  bool hasValue() const { return true; }
  const PassArgData &getValue() const { return value; }
  void setValue(const PassArgData &value) { this->value = value; }

  /// The stored value.
  PassArgData value;
};
} // end namespace cl
} // end namespace llvm
514 
515 namespace {
516 
517 /// The name for the command line option used for parsing the textual pass
518 /// pipeline.
519 static constexpr StringLiteral passPipelineArg = "pass-pipeline";
520 
521 /// Adds command line option for each registered pass or pass pipeline, as well
522 /// as textual pass pipelines.
523 struct PassNameParser : public llvm::cl::parser<PassArgData> {
PassNameParser__anon2c89db4f0c11::PassNameParser524   PassNameParser(llvm::cl::Option &opt) : llvm::cl::parser<PassArgData>(opt) {}
525 
526   void initialize();
527   void printOptionInfo(const llvm::cl::Option &opt,
528                        size_t globalWidth) const override;
529   size_t getOptionWidth(const llvm::cl::Option &opt) const override;
530   bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
531              PassArgData &value);
532 };
533 } // namespace
534 
initialize()535 void PassNameParser::initialize() {
536   llvm::cl::parser<PassArgData>::initialize();
537 
538   /// Add an entry for the textual pass pipeline option.
539   addLiteralOption(passPipelineArg, PassArgData(),
540                    "A textual description of a pass pipeline to run");
541 
542   /// Add the pass entries.
543   for (const auto &kv : *passRegistry) {
544     addLiteralOption(kv.second.getPassArgument(), &kv.second,
545                      kv.second.getPassDescription());
546   }
547   /// Add the pass pipeline entries.
548   for (const auto &kv : *passPipelineRegistry) {
549     addLiteralOption(kv.second.getPassArgument(), &kv.second,
550                      kv.second.getPassDescription());
551   }
552 }
553 
printOptionInfo(const llvm::cl::Option & opt,size_t globalWidth) const554 void PassNameParser::printOptionInfo(const llvm::cl::Option &opt,
555                                      size_t globalWidth) const {
556   // Print the information for the top-level option.
557   if (opt.hasArgStr()) {
558     llvm::outs() << "  --" << opt.ArgStr;
559     opt.printHelpStr(opt.HelpStr, globalWidth, opt.ArgStr.size() + 7);
560   } else {
561     llvm::outs() << "  " << opt.HelpStr << '\n';
562   }
563 
564   // Print the top-level pipeline argument.
565   printOptionHelp(passPipelineArg,
566                   "A textual description of a pass pipeline to run",
567                   /*indent=*/4, globalWidth, /*isTopLevel=*/!opt.hasArgStr());
568 
569   // Functor used to print the ordered entries of a registration map.
570   auto printOrderedEntries = [&](StringRef header, auto &map) {
571     llvm::SmallVector<PassRegistryEntry *, 32> orderedEntries;
572     for (auto &kv : map)
573       orderedEntries.push_back(&kv.second);
574     llvm::array_pod_sort(
575         orderedEntries.begin(), orderedEntries.end(),
576         [](PassRegistryEntry *const *lhs, PassRegistryEntry *const *rhs) {
577           return (*lhs)->getPassArgument().compare((*rhs)->getPassArgument());
578         });
579 
580     llvm::outs().indent(4) << header << ":\n";
581     for (PassRegistryEntry *entry : orderedEntries)
582       entry->printHelpStr(/*indent=*/6, globalWidth);
583   };
584 
585   // Print the available passes.
586   printOrderedEntries("Passes", *passRegistry);
587 
588   // Print the available pass pipelines.
589   if (!passPipelineRegistry->empty())
590     printOrderedEntries("Pass Pipelines", *passPipelineRegistry);
591 }
592 
getOptionWidth(const llvm::cl::Option & opt) const593 size_t PassNameParser::getOptionWidth(const llvm::cl::Option &opt) const {
594   size_t maxWidth = llvm::cl::parser<PassArgData>::getOptionWidth(opt) + 2;
595 
596   // Check for any wider pass or pipeline options.
597   for (auto &entry : *passRegistry)
598     maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
599   for (auto &entry : *passPipelineRegistry)
600     maxWidth = std::max(maxWidth, entry.second.getOptionWidth() + 4);
601   return maxWidth;
602 }
603 
parse(llvm::cl::Option & opt,StringRef argName,StringRef arg,PassArgData & value)604 bool PassNameParser::parse(llvm::cl::Option &opt, StringRef argName,
605                            StringRef arg, PassArgData &value) {
606   // Handle the pipeline option explicitly.
607   if (argName == passPipelineArg)
608     return failed(value.pipeline.initialize(arg, llvm::errs()));
609 
610   // Otherwise, default to the base for handling.
611   if (llvm::cl::parser<PassArgData>::parse(opt, argName, arg, value))
612     return true;
613   value.options = arg;
614   return false;
615 }
616 
617 //===----------------------------------------------------------------------===//
618 // PassPipelineCLParser
619 //===----------------------------------------------------------------------===//
620 
namespace mlir {
namespace detail {
/// Implementation storage for PassPipelineCLParser. It wraps the command line
/// list option that collects passes, pass pipelines, and textual pipelines.
struct PassPipelineCLParserImpl {
  /// Create the list option `arg` with the given description. Values are
  /// marked optional, so an entry may appear with or without an attached
  /// value. PassNameParser::parse keeps any such value as the entry's
  /// options.
  PassPipelineCLParserImpl(StringRef arg, StringRef description)
      : passList(arg, llvm::cl::desc(description)) {
    passList.setValueExpectedFlag(llvm::cl::ValueExpected::ValueOptional);
  }

  /// The set of passes and pass pipelines to run.
  llvm::cl::list<PassArgData, bool, PassNameParser> passList;
};
} // end namespace detail
} // end namespace mlir
634 
635 /// Construct a pass pipeline parser with the given command line description.
PassPipelineCLParser(StringRef arg,StringRef description)636 PassPipelineCLParser::PassPipelineCLParser(StringRef arg, StringRef description)
637     : impl(std::make_unique<detail::PassPipelineCLParserImpl>(arg,
638                                                               description)) {}
~PassPipelineCLParser()639 PassPipelineCLParser::~PassPipelineCLParser() {}
640 
641 /// Returns true if this parser contains any valid options to add.
hasAnyOccurrences() const642 bool PassPipelineCLParser::hasAnyOccurrences() const {
643   return impl->passList.getNumOccurrences() != 0;
644 }
645 
646 /// Returns true if the given pass registry entry was registered at the
647 /// top-level of the parser, i.e. not within an explicit textual pipeline.
contains(const PassRegistryEntry * entry) const648 bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
649   return llvm::any_of(impl->passList, [&](const PassArgData &data) {
650     return data.registryEntry == entry;
651   });
652 }
653 
654 /// Adds the passes defined by this parser entry to the given pass manager.
addToPipeline(OpPassManager & pm,function_ref<LogicalResult (const Twine &)> errorHandler) const655 LogicalResult PassPipelineCLParser::addToPipeline(
656     OpPassManager &pm,
657     function_ref<LogicalResult(const Twine &)> errorHandler) const {
658   for (auto &passIt : impl->passList) {
659     if (passIt.registryEntry) {
660       if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
661                                                      errorHandler)))
662         return failure();
663     } else {
664       OpPassManager::Nesting nesting = pm.getNesting();
665       pm.setNesting(OpPassManager::Nesting::Explicit);
666       LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler);
667       pm.setNesting(nesting);
668       if (failed(status))
669         return failure();
670     }
671   }
672   return success();
673 }
674