1 //===- Pass.cpp - Pass infrastructure implementation ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements common pass infrastructure.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Pass/Pass.h"
14 #include "PassDetail.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/Verifier.h"
18 #include "mlir/Support/FileUtilities.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/Support/CrashRecoveryContext.h"
24 #include "llvm/Support/Mutex.h"
25 #include "llvm/Support/Parallel.h"
26 #include "llvm/Support/Signals.h"
27 #include "llvm/Support/Threading.h"
28 #include "llvm/Support/ToolOutputFile.h"
29 
30 using namespace mlir;
31 using namespace mlir::detail;
32 
33 //===----------------------------------------------------------------------===//
34 // Pass
35 //===----------------------------------------------------------------------===//
36 
37 /// Out of line virtual method to ensure vtables and metadata are emitted to a
38 /// single .o file.
anchor()39 void Pass::anchor() {}
40 
41 /// Attempt to initialize the options of this pass from the given string.
initializeOptions(StringRef options)42 LogicalResult Pass::initializeOptions(StringRef options) {
43   return passOptions.parseFromString(options);
44 }
45 
46 /// Copy the option values from 'other', which is another instance of this
47 /// pass.
copyOptionValuesFrom(const Pass * other)48 void Pass::copyOptionValuesFrom(const Pass *other) {
49   passOptions.copyOptionValuesFrom(other->passOptions);
50 }
51 
52 /// Prints out the pass in the textual representation of pipelines. If this is
53 /// an adaptor pass, print with the op_name(sub_pass,...) format.
printAsTextualPipeline(raw_ostream & os)54 void Pass::printAsTextualPipeline(raw_ostream &os) {
55   // Special case for adaptors to use the 'op_name(sub_passes)' format.
56   if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
57     llvm::interleaveComma(adaptor->getPassManagers(), os,
58                           [&](OpPassManager &pm) {
59                             os << pm.getOpName() << "(";
60                             pm.printAsTextualPipeline(os);
61                             os << ")";
62                           });
63     return;
64   }
65   // Otherwise, print the pass argument followed by its options. If the pass
66   // doesn't have an argument, print the name of the pass to give some indicator
67   // of what pass was run.
68   StringRef argument = getArgument();
69   if (!argument.empty())
70     os << argument;
71   else
72     os << "unknown<" << getName() << ">";
73   passOptions.print(os);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // OpPassManagerImpl
78 //===----------------------------------------------------------------------===//
79 
80 namespace mlir {
81 namespace detail {
82 struct OpPassManagerImpl {
OpPassManagerImplmlir::detail::OpPassManagerImpl83   OpPassManagerImpl(Identifier identifier, OpPassManager::Nesting nesting)
84       : name(identifier.str()), identifier(identifier), nesting(nesting) {}
OpPassManagerImplmlir::detail::OpPassManagerImpl85   OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
86       : name(name), nesting(nesting) {}
87 
88   /// Merge the passes of this pass manager into the one provided.
89   void mergeInto(OpPassManagerImpl &rhs);
90 
91   /// Nest a new operation pass manager for the given operation kind under this
92   /// pass manager.
93   OpPassManager &nest(Identifier nestedName);
94   OpPassManager &nest(StringRef nestedName);
95 
96   /// Add the given pass to this pass manager. If this pass has a concrete
97   /// operation type, it must be the same type as this pass manager.
98   void addPass(std::unique_ptr<Pass> pass);
99 
100   /// Coalesce adjacent AdaptorPasses into one large adaptor. This runs
101   /// recursively through the pipeline graph.
102   void coalesceAdjacentAdaptorPasses();
103 
104   /// Split all of AdaptorPasses such that each adaptor only contains one leaf
105   /// pass.
106   void splitAdaptorPasses();
107 
getOpNamemlir::detail::OpPassManagerImpl108   Identifier getOpName(MLIRContext &context) {
109     if (!identifier)
110       identifier = Identifier::get(name, &context);
111     return *identifier;
112   }
113 
114   /// The name of the operation that passes of this pass manager operate on.
115   std::string name;
116 
117   /// The cached identifier (internalized in the context) for the name of the
118   /// operation that passes of this pass manager operate on.
119   Optional<Identifier> identifier;
120 
121   /// The set of passes to run as part of this pass manager.
122   std::vector<std::unique_ptr<Pass>> passes;
123 
124   /// Control the implicit nesting of passes that mismatch the name set for this
125   /// OpPassManager.
126   OpPassManager::Nesting nesting;
127 };
128 } // end namespace detail
129 } // end namespace mlir
130 
mergeInto(OpPassManagerImpl & rhs)131 void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
132   assert(name == rhs.name && "merging unrelated pass managers");
133   for (auto &pass : passes)
134     rhs.passes.push_back(std::move(pass));
135   passes.clear();
136 }
137 
nest(Identifier nestedName)138 OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) {
139   OpPassManager nested(nestedName, nesting);
140   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
141   addPass(std::unique_ptr<Pass>(adaptor));
142   return adaptor->getPassManagers().front();
143 }
144 
nest(StringRef nestedName)145 OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) {
146   OpPassManager nested(nestedName, nesting);
147   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
148   addPass(std::unique_ptr<Pass>(adaptor));
149   return adaptor->getPassManagers().front();
150 }
151 
addPass(std::unique_ptr<Pass> pass)152 void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
153   // If this pass runs on a different operation than this pass manager, then
154   // implicitly nest a pass manager for this operation if enabled.
155   auto passOpName = pass->getOpName();
156   if (passOpName && passOpName->str() != name) {
157     if (nesting == OpPassManager::Nesting::Implicit)
158       return nest(*passOpName).addPass(std::move(pass));
159     llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() +
160                              "' restricted to '" + *passOpName +
161                              "' on a PassManager intended to run on '" + name +
162                              "', did you intend to nest?");
163   }
164 
165   passes.emplace_back(std::move(pass));
166 }
167 
coalesceAdjacentAdaptorPasses()168 void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
169   // Bail out early if there are no adaptor passes.
170   if (llvm::none_of(passes, [](std::unique_ptr<Pass> &pass) {
171         return isa<OpToOpPassAdaptor>(pass.get());
172       }))
173     return;
174 
175   // Walk the pass list and merge adjacent adaptors.
176   OpToOpPassAdaptor *lastAdaptor = nullptr;
177   for (auto it = passes.begin(), e = passes.end(); it != e; ++it) {
178     // Check to see if this pass is an adaptor.
179     if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(it->get())) {
180       // If it is the first adaptor in a possible chain, remember it and
181       // continue.
182       if (!lastAdaptor) {
183         lastAdaptor = currentAdaptor;
184         continue;
185       }
186 
187       // Otherwise, merge into the existing adaptor and delete the current one.
188       currentAdaptor->mergeInto(*lastAdaptor);
189       it->reset();
190     } else if (lastAdaptor) {
191       // If this pass is not an adaptor, then coalesce and forget any existing
192       // adaptor.
193       for (auto &pm : lastAdaptor->getPassManagers())
194         pm.getImpl().coalesceAdjacentAdaptorPasses();
195       lastAdaptor = nullptr;
196     }
197   }
198 
199   // If there was an adaptor at the end of the manager, coalesce it as well.
200   if (lastAdaptor) {
201     for (auto &pm : lastAdaptor->getPassManagers())
202       pm.getImpl().coalesceAdjacentAdaptorPasses();
203   }
204 
205   // Now that the adaptors have been merged, erase the empty slot corresponding
206   // to the merged adaptors that were nulled-out in the loop above.
207   llvm::erase_if(passes, std::logical_not<std::unique_ptr<Pass>>());
208 }
209 
splitAdaptorPasses()210 void OpPassManagerImpl::splitAdaptorPasses() {
211   std::vector<std::unique_ptr<Pass>> oldPasses;
212   std::swap(passes, oldPasses);
213 
214   for (std::unique_ptr<Pass> &pass : oldPasses) {
215     // If this pass isn't an adaptor, move it directly to the new pass list.
216     auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(pass.get());
217     if (!currentAdaptor) {
218       addPass(std::move(pass));
219       continue;
220     }
221 
222     // Otherwise, split the adaptors of each manager within the adaptor.
223     for (OpPassManager &adaptorPM : currentAdaptor->getPassManagers()) {
224       adaptorPM.getImpl().splitAdaptorPasses();
225       for (std::unique_ptr<Pass> &nestedPass : adaptorPM.getImpl().passes)
226         nest(adaptorPM.getOpName()).addPass(std::move(nestedPass));
227     }
228   }
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // OpPassManager
233 //===----------------------------------------------------------------------===//
234 
/// Create a pass manager anchored on the operation identified by `name`.
OpPassManager::OpPassManager(Identifier name, Nesting nesting)
    : impl(std::make_unique<OpPassManagerImpl>(name, nesting)) {}

/// Create a pass manager anchored on the operation named `name`.
OpPassManager::OpPassManager(StringRef name, Nesting nesting)
    : impl(std::make_unique<OpPassManagerImpl>(name, nesting)) {}

/// Take over the implementation of `rhs`.
OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}

/// Copy construction is expressed in terms of copy assignment.
OpPassManager::OpPassManager(const OpPassManager &rhs) { operator=(rhs); }
/// Replace the contents of this pass manager with a copy of `rhs`: the
/// operation name and nesting mode are copied and every pass is cloned.
///
/// Self-assignment is a no-op. Without the guard, resetting `impl` would
/// destroy the pass list before it could be cloned, silently leaving the
/// manager without any passes.
OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
  if (this == &rhs)
    return *this;

  // Build the copy on the side and only commit it once it is complete.
  auto newImpl =
      std::make_unique<OpPassManagerImpl>(rhs.impl->name, rhs.impl->nesting);
  newImpl->passes.reserve(rhs.impl->passes.size());
  for (auto &pass : rhs.impl->passes)
    newImpl->passes.emplace_back(pass->clone());
  impl = std::move(newImpl);
  return *this;
}
247 
~OpPassManager()248 OpPassManager::~OpPassManager() {}
249 
begin()250 OpPassManager::pass_iterator OpPassManager::begin() {
251   return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
252 }
end()253 OpPassManager::pass_iterator OpPassManager::end() {
254   return MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
255 }
256 
begin() const257 OpPassManager::const_pass_iterator OpPassManager::begin() const {
258   return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin();
259 }
end() const260 OpPassManager::const_pass_iterator OpPassManager::end() const {
261   return ArrayRef<std::unique_ptr<Pass>>{impl->passes}.end();
262 }
263 
264 /// Nest a new operation pass manager for the given operation kind under this
265 /// pass manager.
nest(Identifier nestedName)266 OpPassManager &OpPassManager::nest(Identifier nestedName) {
267   return impl->nest(nestedName);
268 }
nest(StringRef nestedName)269 OpPassManager &OpPassManager::nest(StringRef nestedName) {
270   return impl->nest(nestedName);
271 }
272 
273 /// Add the given pass to this pass manager. If this pass has a concrete
274 /// operation type, it must be the same type as this pass manager.
addPass(std::unique_ptr<Pass> pass)275 void OpPassManager::addPass(std::unique_ptr<Pass> pass) {
276   impl->addPass(std::move(pass));
277 }
278 
279 /// Returns the number of passes held by this manager.
size() const280 size_t OpPassManager::size() const { return impl->passes.size(); }
281 
282 /// Returns the internal implementation instance.
getImpl()283 OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
284 
285 /// Return the operation name that this pass manager operates on.
getOpName() const286 StringRef OpPassManager::getOpName() const { return impl->name; }
287 
288 /// Return the operation name that this pass manager operates on.
getOpName(MLIRContext & context) const289 Identifier OpPassManager::getOpName(MLIRContext &context) const {
290   return impl->getOpName(context);
291 }
292 
293 /// Prints out the given passes as the textual representation of a pipeline.
printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,raw_ostream & os)294 static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
295                                    raw_ostream &os) {
296   llvm::interleaveComma(passes, os, [&](const std::unique_ptr<Pass> &pass) {
297     pass->printAsTextualPipeline(os);
298   });
299 }
300 
301 /// Prints out the passes of the pass manager as the textual representation
302 /// of pipelines.
printAsTextualPipeline(raw_ostream & os)303 void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
304   ::printAsTextualPipeline(impl->passes, os);
305 }
306 
dump()307 void OpPassManager::dump() {
308   llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes: ";
309   ::printAsTextualPipeline(impl->passes, llvm::errs());
310   llvm::errs() << "\n";
311 }
312 
registerDialectsForPipeline(const OpPassManager & pm,DialectRegistry & dialects)313 static void registerDialectsForPipeline(const OpPassManager &pm,
314                                         DialectRegistry &dialects) {
315   for (const Pass &pass : pm.getPasses())
316     pass.getDependentDialects(dialects);
317 }
318 
getDependentDialects(DialectRegistry & dialects) const319 void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
320   registerDialectsForPipeline(*this, dialects);
321 }
322 
getNesting()323 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
324 
setNesting(Nesting nesting)325 void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
326 
327 //===----------------------------------------------------------------------===//
328 // OpToOpPassAdaptor
329 //===----------------------------------------------------------------------===//
330 
run(Pass * pass,Operation * op,AnalysisManager am,bool verifyPasses)331 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
332                                      AnalysisManager am, bool verifyPasses) {
333   if (!op->getName().getAbstractOperation())
334     return op->emitOpError()
335            << "trying to schedule a pass on an unregistered operation";
336   if (!op->getName().getAbstractOperation()->hasProperty(
337           OperationProperty::IsolatedFromAbove))
338     return op->emitOpError() << "trying to schedule a pass on an operation not "
339                                 "marked as 'IsolatedFromAbove'";
340 
341   // Initialize the pass state with a callback for the pass to dynamically
342   // execute a pipeline on the currently visited operation.
343   auto dynamic_pipeline_callback =
344       [op, &am, verifyPasses](OpPassManager &pipeline,
345                               Operation *root) -> LogicalResult {
346     if (!op->isAncestor(root))
347       return root->emitOpError()
348              << "Trying to schedule a dynamic pipeline on an "
349                 "operation that isn't "
350                 "nested under the current operation the pass is processing";
351 
352     AnalysisManager nestedAm = am.nest(root);
353     return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
354                                           verifyPasses);
355   };
356   pass->passState.emplace(op, am, dynamic_pipeline_callback);
357   // Instrument before the pass has run.
358   PassInstrumentor *pi = am.getPassInstrumentor();
359   if (pi)
360     pi->runBeforePass(pass, op);
361 
362   // Invoke the virtual runOnOperation method.
363   if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
364     adaptor->runOnOperation(verifyPasses);
365   else
366     pass->runOnOperation();
367   bool passFailed = pass->passState->irAndPassFailed.getInt();
368 
369   // Invalidate any non preserved analyses.
370   am.invalidate(pass->passState->preservedAnalyses);
371 
372   // Run the verifier if this pass didn't fail already.
373   if (!passFailed && verifyPasses)
374     passFailed = failed(verify(op));
375 
376   // Instrument after the pass has run.
377   if (pi) {
378     if (passFailed)
379       pi->runAfterPassFailed(pass, op);
380     else
381       pi->runAfterPass(pass, op);
382   }
383 
384   // Return if the pass signaled a failure.
385   return failure(passFailed);
386 }
387 
388 /// Run the given operation and analysis manager on a provided op pass manager.
runPipeline(iterator_range<OpPassManager::pass_iterator> passes,Operation * op,AnalysisManager am,bool verifyPasses)389 LogicalResult OpToOpPassAdaptor::runPipeline(
390     iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
391     AnalysisManager am, bool verifyPasses) {
392   auto scope_exit = llvm::make_scope_exit([&] {
393     // Clear out any computed operation analyses. These analyses won't be used
394     // any more in this pipeline, and this helps reduce the current working set
395     // of memory. If preserving these analyses becomes important in the future
396     // we can re-evaluate this.
397     am.clear();
398   });
399 
400   // Run the pipeline over the provided operation.
401   for (Pass &pass : passes)
402     if (failed(run(&pass, op, am, verifyPasses)))
403       return failure();
404 
405   return success();
406 }
407 
408 /// Find an operation pass manager that can operate on an operation of the given
409 /// type, or nullptr if one does not exist.
findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,StringRef name)410 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
411                                          StringRef name) {
412   auto it = llvm::find_if(
413       mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
414   return it == mgrs.end() ? nullptr : &*it;
415 }
416 
417 /// Find an operation pass manager that can operate on an operation of the given
418 /// type, or nullptr if one does not exist.
findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,Identifier name,MLIRContext & context)419 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
420                                          Identifier name,
421                                          MLIRContext &context) {
422   auto it = llvm::find_if(
423       mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
424   return it == mgrs.end() ? nullptr : &*it;
425 }
426 
OpToOpPassAdaptor(OpPassManager && mgr)427 OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
428   mgrs.emplace_back(std::move(mgr));
429 }
430 
getDependentDialects(DialectRegistry & dialects) const431 void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
432   for (auto &pm : mgrs)
433     pm.getDependentDialects(dialects);
434 }
435 
436 /// Merge the current pass adaptor into given 'rhs'.
mergeInto(OpToOpPassAdaptor & rhs)437 void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
438   for (auto &pm : mgrs) {
439     // If an existing pass manager exists, then merge the given pass manager
440     // into it.
441     if (auto *existingPM = findPassManagerFor(rhs.mgrs, pm.getOpName())) {
442       pm.getImpl().mergeInto(existingPM->getImpl());
443     } else {
444       // Otherwise, add the given pass manager to the list.
445       rhs.mgrs.emplace_back(std::move(pm));
446     }
447   }
448   mgrs.clear();
449 
450   // After coalescing, sort the pass managers within rhs by name.
451   llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(),
452                        [](const OpPassManager *lhs, const OpPassManager *rhs) {
453                          return lhs->getOpName().compare(rhs->getOpName());
454                        });
455 }
456 
457 /// Returns the adaptor pass name.
getAdaptorName()458 std::string OpToOpPassAdaptor::getAdaptorName() {
459   std::string name = "Pipeline Collection : [";
460   llvm::raw_string_ostream os(name);
461   llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
462     os << '\'' << pm.getOpName() << '\'';
463   });
464   os << ']';
465   return os.str();
466 }
467 
runOnOperation()468 void OpToOpPassAdaptor::runOnOperation() {
469   llvm_unreachable(
470       "Unexpected call to Pass::runOnOperation() on OpToOpPassAdaptor");
471 }
472 
473 /// Run the held pipeline over all nested operations.
runOnOperation(bool verifyPasses)474 void OpToOpPassAdaptor::runOnOperation(bool verifyPasses) {
475   if (getContext().isMultithreadingEnabled())
476     runOnOperationAsyncImpl(verifyPasses);
477   else
478     runOnOperationImpl(verifyPasses);
479 }
480 
481 /// Run this pass adaptor synchronously.
runOnOperationImpl(bool verifyPasses)482 void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
483   auto am = getAnalysisManager();
484   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
485                                                         this};
486   auto *instrumentor = am.getPassInstrumentor();
487   for (auto &region : getOperation()->getRegions()) {
488     for (auto &block : region) {
489       for (auto &op : block) {
490         auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier(),
491                                        *op.getContext());
492         if (!mgr)
493           continue;
494         Identifier opName = mgr->getOpName(*getOperation()->getContext());
495 
496         // Run the held pipeline over the current operation.
497         if (instrumentor)
498           instrumentor->runBeforePipeline(opName, parentInfo);
499         LogicalResult result =
500             runPipeline(mgr->getPasses(), &op, am.nest(&op), verifyPasses);
501         if (instrumentor)
502           instrumentor->runAfterPipeline(opName, parentInfo);
503 
504         if (failed(result))
505           return signalPassFailure();
506       }
507     }
508   }
509 }
510 
511 /// Utility functor that checks if the two ranges of pass managers have a size
512 /// mismatch.
hasSizeMismatch(ArrayRef<OpPassManager> lhs,ArrayRef<OpPassManager> rhs)513 static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
514                             ArrayRef<OpPassManager> rhs) {
515   return lhs.size() != rhs.size() ||
516          llvm::any_of(llvm::seq<size_t>(0, lhs.size()),
517                       [&](size_t i) { return lhs[i].size() != rhs[i].size(); });
518 }
519 
520 /// Run this pass adaptor synchronously.
runOnOperationAsyncImpl(bool verifyPasses)521 void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
522   AnalysisManager am = getAnalysisManager();
523 
524   // Create the async executors if they haven't been created, or if the main
525   // pipeline has changed.
526   if (asyncExecutors.empty() || hasSizeMismatch(asyncExecutors.front(), mgrs))
527     asyncExecutors.assign(llvm::hardware_concurrency().compute_thread_count(),
528                           mgrs);
529 
530   // Run a prepass over the operation to collect the nested operations to
531   // execute over. This ensures that an analysis manager exists for each
532   // operation, as well as providing a queue of operations to execute over.
533   std::vector<std::pair<Operation *, AnalysisManager>> opAMPairs;
534   for (auto &region : getOperation()->getRegions()) {
535     for (auto &block : region) {
536       for (auto &op : block) {
537         // Add this operation iff the name matches any of the pass managers.
538         if (findPassManagerFor(mgrs, op.getName().getIdentifier(),
539                                getContext()))
540           opAMPairs.emplace_back(&op, am.nest(&op));
541       }
542     }
543   }
544 
545   // A parallel diagnostic handler that provides deterministic diagnostic
546   // ordering.
547   ParallelDiagnosticHandler diagHandler(&getContext());
548 
549   // An index for the current operation/analysis manager pair.
550   std::atomic<unsigned> opIt(0);
551 
552   // Get the current thread for this adaptor.
553   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
554                                                         this};
555   auto *instrumentor = am.getPassInstrumentor();
556 
557   // An atomic failure variable for the async executors.
558   std::atomic<bool> passFailed(false);
559   llvm::parallelForEach(
560       asyncExecutors.begin(),
561       std::next(asyncExecutors.begin(),
562                 std::min(asyncExecutors.size(), opAMPairs.size())),
563       [&](MutableArrayRef<OpPassManager> pms) {
564         for (auto e = opAMPairs.size(); !passFailed && opIt < e;) {
565           // Get the next available operation index.
566           unsigned nextID = opIt++;
567           if (nextID >= e)
568             break;
569 
570           // Set the order id for this thread in the diagnostic handler.
571           diagHandler.setOrderIDForThread(nextID);
572 
573           // Get the pass manager for this operation and execute it.
574           auto &it = opAMPairs[nextID];
575           auto *pm = findPassManagerFor(
576               pms, it.first->getName().getIdentifier(), getContext());
577           assert(pm && "expected valid pass manager for operation");
578 
579           Identifier opName = pm->getOpName(*getOperation()->getContext());
580           if (instrumentor)
581             instrumentor->runBeforePipeline(opName, parentInfo);
582           auto pipelineResult =
583               runPipeline(pm->getPasses(), it.first, it.second, verifyPasses);
584           if (instrumentor)
585             instrumentor->runAfterPipeline(opName, parentInfo);
586 
587           // Drop this thread from being tracked by the diagnostic handler.
588           // After this task has finished, the thread may be used outside of
589           // this pass manager context meaning that we don't want to track
590           // diagnostics from it anymore.
591           diagHandler.eraseOrderIDForThread();
592 
593           // Handle a failed pipeline result.
594           if (failed(pipelineResult)) {
595             passFailed = true;
596             break;
597           }
598         }
599       });
600 
601   // Signal a failure if any of the executors failed.
602   if (passFailed)
603     signalPassFailure();
604 }
605 
606 //===----------------------------------------------------------------------===//
607 // PassCrashReproducer
608 //===----------------------------------------------------------------------===//
609 
610 namespace {
611 /// This class contains all of the context for generating a recovery reproducer.
612 /// Each recovery context is registered globally to allow for generating
613 /// reproducers when a signal is raised, such as a segfault.
614 struct RecoveryReproducerContext {
615   RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,
616                             Operation *op, StringRef filename,
617                             bool disableThreads, bool verifyPasses);
618   ~RecoveryReproducerContext();
619 
620   /// Generate a reproducer with the current context.
621   LogicalResult generate(std::string &error);
622 
623 private:
624   /// This function is invoked in the event of a crash.
625   static void crashHandler(void *);
626 
627   /// Register a signal handler to run in the event of a crash.
628   static void registerSignalHandler();
629 
630   /// The textual description of the currently executing pipeline.
631   std::string pipeline;
632 
633   /// The MLIR operation representing the IR before the crash.
634   Operation *preCrashOperation;
635 
636   /// The filename to use when generating the reproducer.
637   StringRef filename;
638 
639   /// Various pass manager and context flags.
640   bool disableThreads;
641   bool verifyPasses;
642 
643   /// The current set of active reproducer contexts. This is used in the event
644   /// of a crash. This is not thread_local as the pass manager may produce any
645   /// number of child threads. This uses a set to allow for multiple MLIR pass
646   /// managers to be running at the same time.
647   static llvm::ManagedStatic<llvm::sys::SmartMutex<true>> reproducerMutex;
648   static llvm::ManagedStatic<
649       llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
650       reproducerSet;
651 };
652 } // end anonymous namespace
653 
654 llvm::ManagedStatic<llvm::sys::SmartMutex<true>>
655     RecoveryReproducerContext::reproducerMutex;
656 llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
657     RecoveryReproducerContext::reproducerSet;
658 
RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,Operation * op,StringRef filename,bool disableThreads,bool verifyPasses)659 RecoveryReproducerContext::RecoveryReproducerContext(
660     MutableArrayRef<std::unique_ptr<Pass>> passes, Operation *op,
661     StringRef filename, bool disableThreads, bool verifyPasses)
662     : preCrashOperation(op->clone()), filename(filename),
663       disableThreads(disableThreads), verifyPasses(verifyPasses) {
664   // Grab the textual pipeline being executed..
665   {
666     llvm::raw_string_ostream pipelineOS(pipeline);
667     ::printAsTextualPipeline(passes, pipelineOS);
668   }
669 
670   // Make sure that the handler is registered, and update the current context.
671   llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
672   if (reproducerSet->empty())
673     llvm::CrashRecoveryContext::Enable();
674   registerSignalHandler();
675   reproducerSet->insert(this);
676 }
677 
~RecoveryReproducerContext()678 RecoveryReproducerContext::~RecoveryReproducerContext() {
679   // Erase the cloned preCrash IR that we cached.
680   preCrashOperation->erase();
681 
682   llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
683   reproducerSet->remove(this);
684   if (reproducerSet->empty())
685     llvm::CrashRecoveryContext::Disable();
686 }
687 
generate(std::string & error)688 LogicalResult RecoveryReproducerContext::generate(std::string &error) {
689   std::unique_ptr<llvm::ToolOutputFile> outputFile =
690       mlir::openOutputFile(filename, &error);
691   if (!outputFile)
692     return failure();
693   auto &outputOS = outputFile->os();
694 
695   // Output the current pass manager configuration.
696   outputOS << "// configuration: -pass-pipeline='" << pipeline << "'";
697   if (disableThreads)
698     outputOS << " -mlir-disable-threading";
699 
700   // TODO: Should this also be configured with a pass manager flag?
701   outputOS << "\n// note: verifyPasses=" << (verifyPasses ? "true" : "false")
702            << "\n";
703 
704   // Output the .mlir module.
705   preCrashOperation->print(outputOS);
706   outputFile->keep();
707   return success();
708 }
709 
crashHandler(void *)710 void RecoveryReproducerContext::crashHandler(void *) {
711   // Walk the current stack of contexts and generate a reproducer for each one.
712   // We can't know for certain which one was the cause, so we need to generate
713   // a reproducer for all of them.
714   std::string ignored;
715   for (RecoveryReproducerContext *context : *reproducerSet)
716     context->generate(ignored);
717 }
718 
registerSignalHandler()719 void RecoveryReproducerContext::registerSignalHandler() {
720   // Ensure that the handler is only registered once.
721   static bool registered =
722       (llvm::sys::AddSignalHandler(crashHandler, nullptr), false);
723   (void)registered;
724 }
725 
726 /// Run the pass manager with crash recover enabled.
runWithCrashRecovery(Operation * op,AnalysisManager am)727 LogicalResult PassManager::runWithCrashRecovery(Operation *op,
728                                                 AnalysisManager am) {
729   // If this isn't a local producer, run all of the passes in recovery mode.
730   if (!localReproducer)
731     return runWithCrashRecovery(impl->passes, op, am);
732 
733   // Split the passes within adaptors to ensure that each pass can be run in
734   // isolation.
735   impl->splitAdaptorPasses();
736 
737   // If this is a local producer, run each of the passes individually.
738   MutableArrayRef<std::unique_ptr<Pass>> passes = impl->passes;
739   for (std::unique_ptr<Pass> &pass : passes)
740     if (failed(runWithCrashRecovery(pass, op, am)))
741       return failure();
742   return success();
743 }
744 
745 /// Run the given passes with crash recover enabled.
746 LogicalResult
runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,Operation * op,AnalysisManager am)747 PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
748                                   Operation *op, AnalysisManager am) {
749   RecoveryReproducerContext context(passes, op, *crashReproducerFileName,
750                                     !getContext()->isMultithreadingEnabled(),
751                                     verifyPasses);
752 
753   // Safely invoke the passes within a recovery context.
754   LogicalResult passManagerResult = failure();
755   llvm::CrashRecoveryContext recoveryContext;
756   recoveryContext.RunSafelyOnThread([&] {
757     for (std::unique_ptr<Pass> &pass : passes)
758       if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses)))
759         return;
760     passManagerResult = success();
761   });
762   if (succeeded(passManagerResult))
763     return success();
764 
765   std::string error;
766   if (failed(context.generate(error)))
767     return op->emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
768   return op->emitError()
769          << "A failure has been detected while processing the MLIR module, a "
770             "reproducer has been generated in '"
771          << *crashReproducerFileName << "'";
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // PassManager
776 //===----------------------------------------------------------------------===//
777 
/// Construct a pass manager anchored on operations named `operationName`.
/// The name is turned into an `Identifier` within `ctx` and handed, together
/// with `nesting`, to the `OpPassManager` base. Pass timing and local
/// reproducers start out disabled; pass verification starts out enabled.
PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
                         StringRef operationName)
    : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
      passTiming(false), localReproducer(false), verifyPasses(true) {}
782 
~PassManager()783 PassManager::~PassManager() {}
784 
enableVerifier(bool enabled)785 void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
786 
787 /// Run the passes within this manager on the provided operation.
run(Operation * op)788 LogicalResult PassManager::run(Operation *op) {
789   MLIRContext *context = getContext();
790   assert(op->getName().getIdentifier() == getOpName(*context) &&
791          "operation has a different name than the PassManager");
792 
793   // Before running, make sure to coalesce any adjacent pass adaptors in the
794   // pipeline.
795   getImpl().coalesceAdjacentAdaptorPasses();
796 
797   // Register all dialects for the current pipeline.
798   DialectRegistry dependentDialects;
799   getDependentDialects(dependentDialects);
800   dependentDialects.loadAll(context);
801 
802   // Construct a top level analysis manager for the pipeline.
803   ModuleAnalysisManager am(op, instrumentor.get());
804 
805   // Notify the context that we start running a pipeline for book keeping.
806   context->enterMultiThreadedExecution();
807 
808   // If reproducer generation is enabled, run the pass manager with crash
809   // handling enabled.
810   LogicalResult result =
811       crashReproducerFileName
812           ? runWithCrashRecovery(op, am)
813           : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses);
814 
815   // Notify the context that the run is done.
816   context->exitMultiThreadedExecution();
817 
818   // Dump all of the pass statistics if necessary.
819   if (passStatisticsMode)
820     dumpStatistics();
821   return result;
822 }
823 
824 /// Enable support for the pass manager to generate a reproducer on the event
825 /// of a crash or a pass failure. `outputFile` is a .mlir filename used to write
826 /// the generated reproducer. If `genLocalReproducer` is true, the pass manager
827 /// will attempt to generate a local reproducer that contains the smallest
828 /// pipeline.
enableCrashReproducerGeneration(StringRef outputFile,bool genLocalReproducer)829 void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
830                                                   bool genLocalReproducer) {
831   crashReproducerFileName = std::string(outputFile);
832   localReproducer = genLocalReproducer;
833 }
834 
835 /// Add the provided instrumentation to the pass manager.
addInstrumentation(std::unique_ptr<PassInstrumentation> pi)836 void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
837   if (!instrumentor)
838     instrumentor = std::make_unique<PassInstrumentor>();
839 
840   instrumentor->addInstrumentation(std::move(pi));
841 }
842 
843 //===----------------------------------------------------------------------===//
844 // AnalysisManager
845 //===----------------------------------------------------------------------===//
846 
847 /// Returns a pass instrumentation object for the current operation.
getPassInstrumentor() const848 PassInstrumentor *AnalysisManager::getPassInstrumentor() const {
849   ParentPointerT curParent = parent;
850   while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>())
851     curParent = parentAM->parent;
852   return curParent.get<const ModuleAnalysisManager *>()->getPassInstrumentor();
853 }
854 
855 /// Get an analysis manager for the given child operation.
nest(Operation * op)856 AnalysisManager AnalysisManager::nest(Operation *op) {
857   auto it = impl->childAnalyses.find(op);
858   if (it == impl->childAnalyses.end())
859     it = impl->childAnalyses
860              .try_emplace(op, std::make_unique<NestedAnalysisMap>(op))
861              .first;
862   return {this, it->second.get()};
863 }
864 
865 /// Invalidate any non preserved analyses.
invalidate(const detail::PreservedAnalyses & pa)866 void detail::NestedAnalysisMap::invalidate(
867     const detail::PreservedAnalyses &pa) {
868   // If all analyses were preserved, then there is nothing to do here.
869   if (pa.isAll())
870     return;
871 
872   // Invalidate the analyses for the current operation directly.
873   analyses.invalidate(pa);
874 
875   // If no analyses were preserved, then just simply clear out the child
876   // analysis results.
877   if (pa.isNone()) {
878     childAnalyses.clear();
879     return;
880   }
881 
882   // Otherwise, invalidate each child analysis map.
883   SmallVector<NestedAnalysisMap *, 8> mapsToInvalidate(1, this);
884   while (!mapsToInvalidate.empty()) {
885     auto *map = mapsToInvalidate.pop_back_val();
886     for (auto &analysisPair : map->childAnalyses) {
887       analysisPair.second->invalidate(pa);
888       if (!analysisPair.second->childAnalyses.empty())
889         mapsToInvalidate.push_back(analysisPair.second.get());
890     }
891   }
892 }
893 
894 //===----------------------------------------------------------------------===//
895 // PassInstrumentation
896 //===----------------------------------------------------------------------===//
897 
~PassInstrumentation()898 PassInstrumentation::~PassInstrumentation() {}
899 
900 //===----------------------------------------------------------------------===//
901 // PassInstrumentor
902 //===----------------------------------------------------------------------===//
903 
namespace mlir {
namespace detail {
/// Private state of `PassInstrumentor`: the registered instrumentations and
/// the mutex that every `PassInstrumentor` method in this file locks before
/// touching them.
struct PassInstrumentorImpl {
  /// Mutex to keep instrumentation access thread-safe.
  llvm::sys::SmartMutex<true> mutex;

  /// Registered instrumentations, stored in the order they were added. The
  /// "before" hooks iterate this forwards and the "after" hooks backwards.
  std::vector<std::unique_ptr<PassInstrumentation>> instrumentations;
};
} // end namespace detail
} // end namespace mlir
915 
PassInstrumentor()916 PassInstrumentor::PassInstrumentor() : impl(new PassInstrumentorImpl()) {}
~PassInstrumentor()917 PassInstrumentor::~PassInstrumentor() {}
918 
919 /// See PassInstrumentation::runBeforePipeline for details.
runBeforePipeline(Identifier name,const PassInstrumentation::PipelineParentInfo & parentInfo)920 void PassInstrumentor::runBeforePipeline(
921     Identifier name,
922     const PassInstrumentation::PipelineParentInfo &parentInfo) {
923   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
924   for (auto &instr : impl->instrumentations)
925     instr->runBeforePipeline(name, parentInfo);
926 }
927 
928 /// See PassInstrumentation::runAfterPipeline for details.
runAfterPipeline(Identifier name,const PassInstrumentation::PipelineParentInfo & parentInfo)929 void PassInstrumentor::runAfterPipeline(
930     Identifier name,
931     const PassInstrumentation::PipelineParentInfo &parentInfo) {
932   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
933   for (auto &instr : llvm::reverse(impl->instrumentations))
934     instr->runAfterPipeline(name, parentInfo);
935 }
936 
937 /// See PassInstrumentation::runBeforePass for details.
runBeforePass(Pass * pass,Operation * op)938 void PassInstrumentor::runBeforePass(Pass *pass, Operation *op) {
939   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
940   for (auto &instr : impl->instrumentations)
941     instr->runBeforePass(pass, op);
942 }
943 
944 /// See PassInstrumentation::runAfterPass for details.
runAfterPass(Pass * pass,Operation * op)945 void PassInstrumentor::runAfterPass(Pass *pass, Operation *op) {
946   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
947   for (auto &instr : llvm::reverse(impl->instrumentations))
948     instr->runAfterPass(pass, op);
949 }
950 
951 /// See PassInstrumentation::runAfterPassFailed for details.
runAfterPassFailed(Pass * pass,Operation * op)952 void PassInstrumentor::runAfterPassFailed(Pass *pass, Operation *op) {
953   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
954   for (auto &instr : llvm::reverse(impl->instrumentations))
955     instr->runAfterPassFailed(pass, op);
956 }
957 
958 /// See PassInstrumentation::runBeforeAnalysis for details.
runBeforeAnalysis(StringRef name,TypeID id,Operation * op)959 void PassInstrumentor::runBeforeAnalysis(StringRef name, TypeID id,
960                                          Operation *op) {
961   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
962   for (auto &instr : impl->instrumentations)
963     instr->runBeforeAnalysis(name, id, op);
964 }
965 
966 /// See PassInstrumentation::runAfterAnalysis for details.
runAfterAnalysis(StringRef name,TypeID id,Operation * op)967 void PassInstrumentor::runAfterAnalysis(StringRef name, TypeID id,
968                                         Operation *op) {
969   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
970   for (auto &instr : llvm::reverse(impl->instrumentations))
971     instr->runAfterAnalysis(name, id, op);
972 }
973 
974 /// Add the given instrumentation to the collection.
addInstrumentation(std::unique_ptr<PassInstrumentation> pi)975 void PassInstrumentor::addInstrumentation(
976     std::unique_ptr<PassInstrumentation> pi) {
977   llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
978   impl->instrumentations.emplace_back(std::move(pi));
979 }
980