1 //===- IRPrinting.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Pass/PassManager.h"
11 #include "llvm/Support/Format.h"
12 #include "llvm/Support/FormatVariadic.h"
13 #include "llvm/Support/SHA1.h"
14 
15 using namespace mlir;
16 using namespace mlir::detail;
17 
18 namespace {
19 //===----------------------------------------------------------------------===//
20 // OperationFingerPrint
21 //===----------------------------------------------------------------------===//
22 
23 /// A unique fingerprint for a specific operation, and all of it's internal
24 /// operations.
25 class OperationFingerPrint {
26 public:
OperationFingerPrint(Operation * topOp)27   OperationFingerPrint(Operation *topOp) {
28     llvm::SHA1 hasher;
29 
30     // Hash each of the operations based upon their mutable bits:
31     topOp->walk([&](Operation *op) {
32       //   - Operation pointer
33       addDataToHash(hasher, op);
34       //   - Attributes
35       addDataToHash(hasher, op->getMutableAttrDict());
36       //   - Blocks in Regions
37       for (Region &region : op->getRegions()) {
38         for (Block &block : region) {
39           addDataToHash(hasher, &block);
40           for (BlockArgument arg : block.getArguments())
41             addDataToHash(hasher, arg);
42         }
43       }
44       //   - Location
45       addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
46       //   - Operands
47       for (Value operand : op->getOperands())
48         addDataToHash(hasher, operand);
49       //   - Successors
50       for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
51         addDataToHash(hasher, op->getSuccessor(i));
52     });
53     hash = hasher.result();
54   }
55 
operator ==(const OperationFingerPrint & other) const56   bool operator==(const OperationFingerPrint &other) const {
57     return hash == other.hash;
58   }
operator !=(const OperationFingerPrint & other) const59   bool operator!=(const OperationFingerPrint &other) const {
60     return !(*this == other);
61   }
62 
63 private:
addDataToHash(llvm::SHA1 & hasher,const T & data)64   template <typename T> void addDataToHash(llvm::SHA1 &hasher, const T &data) {
65     hasher.update(
66         ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
67   }
68 
69   SmallString<20> hash;
70 };
71 
72 //===----------------------------------------------------------------------===//
73 // IRPrinter
74 //===----------------------------------------------------------------------===//
75 
76 class IRPrinterInstrumentation : public PassInstrumentation {
77 public:
IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)78   IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)
79       : config(std::move(config)) {}
80 
81 private:
82   /// Instrumentation hooks.
83   void runBeforePass(Pass *pass, Operation *op) override;
84   void runAfterPass(Pass *pass, Operation *op) override;
85   void runAfterPassFailed(Pass *pass, Operation *op) override;
86 
87   /// Configuration to use.
88   std::unique_ptr<PassManager::IRPrinterConfig> config;
89 
90   /// The following is a set of fingerprints for operations that are currently
91   /// being operated on in a pass. This field is only used when the
92   /// configuration asked for change detection.
93   DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints;
94 };
95 } // end anonymous namespace
96 
printIR(Operation * op,bool printModuleScope,raw_ostream & out,OpPrintingFlags flags)97 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
98                     OpPrintingFlags flags) {
99   // Otherwise, check to see if we are not printing at module scope.
100   if (!printModuleScope)
101     return op->print(out << "\n",
102                      op->getBlock() ? flags.useLocalScope() : flags);
103 
104   // Otherwise, we are printing at module scope.
105   out << " ('" << op->getName() << "' operation";
106   if (auto symbolName =
107           op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
108     out << ": @" << symbolName.getValue();
109   out << ")\n";
110 
111   // Find the top-level operation.
112   auto *topLevelOp = op;
113   while (auto *parentOp = topLevelOp->getParentOp())
114     topLevelOp = parentOp;
115   topLevelOp->print(out, flags);
116 }
117 
118 /// Instrumentation hooks.
runBeforePass(Pass * pass,Operation * op)119 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
120   if (isa<OpToOpPassAdaptor>(pass))
121     return;
122   // If the config asked to detect changes, record the current fingerprint.
123   if (config->shouldPrintAfterOnlyOnChange())
124     beforePassFingerPrints.try_emplace(pass, op);
125 
126   config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) {
127     out << formatv("// *** IR Dump Before {0} ***", pass->getName());
128     printIR(op, config->shouldPrintAtModuleScope(), out,
129             config->getOpPrintingFlags());
130     out << "\n\n";
131   });
132 }
133 
runAfterPass(Pass * pass,Operation * op)134 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
135   if (isa<OpToOpPassAdaptor>(pass))
136     return;
137   // If the config asked to detect changes, compare the current fingerprint with
138   // the previous.
139   if (config->shouldPrintAfterOnlyOnChange()) {
140     auto fingerPrintIt = beforePassFingerPrints.find(pass);
141     assert(fingerPrintIt != beforePassFingerPrints.end() &&
142            "expected valid fingerprint");
143     // If the fingerprints are the same, we don't print the IR.
144     if (fingerPrintIt->second == OperationFingerPrint(op)) {
145       beforePassFingerPrints.erase(fingerPrintIt);
146       return;
147     }
148     beforePassFingerPrints.erase(fingerPrintIt);
149   }
150 
151   config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
152     out << formatv("// *** IR Dump After {0} ***", pass->getName());
153     printIR(op, config->shouldPrintAtModuleScope(), out,
154             config->getOpPrintingFlags());
155     out << "\n\n";
156   });
157 }
158 
runAfterPassFailed(Pass * pass,Operation * op)159 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
160   if (isa<OpToOpPassAdaptor>(pass))
161     return;
162   if (config->shouldPrintAfterOnlyOnChange())
163     beforePassFingerPrints.erase(pass);
164 
165   config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
166     out << formatv("// *** IR Dump After {0} Failed ***", pass->getName());
167     printIR(op, config->shouldPrintAtModuleScope(), out,
168             OpPrintingFlags().printGenericOpForm());
169     out << "\n\n";
170   });
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // IRPrinterConfig
175 //===----------------------------------------------------------------------===//
176 
177 /// Initialize the configuration.
IRPrinterConfig(bool printModuleScope,bool printAfterOnlyOnChange,OpPrintingFlags opPrintingFlags)178 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
179                                               bool printAfterOnlyOnChange,
180                                               OpPrintingFlags opPrintingFlags)
181     : printModuleScope(printModuleScope),
182       printAfterOnlyOnChange(printAfterOnlyOnChange),
183       opPrintingFlags(opPrintingFlags) {}
~IRPrinterConfig()184 PassManager::IRPrinterConfig::~IRPrinterConfig() {}
185 
186 /// A hook that may be overridden by a derived config that checks if the IR
187 /// of 'operation' should be dumped *before* the pass 'pass' has been
188 /// executed. If the IR should be dumped, 'printCallback' should be invoked
189 /// with the stream to dump into.
printBeforeIfEnabled(Pass * pass,Operation * operation,PrintCallbackFn printCallback)190 void PassManager::IRPrinterConfig::printBeforeIfEnabled(
191     Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
192   // By default, never print.
193 }
194 
195 /// A hook that may be overridden by a derived config that checks if the IR
196 /// of 'operation' should be dumped *after* the pass 'pass' has been
197 /// executed. If the IR should be dumped, 'printCallback' should be invoked
198 /// with the stream to dump into.
printAfterIfEnabled(Pass * pass,Operation * operation,PrintCallbackFn printCallback)199 void PassManager::IRPrinterConfig::printAfterIfEnabled(
200     Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
201   // By default, never print.
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // PassManager
206 //===----------------------------------------------------------------------===//
207 
208 namespace {
209 /// Simple wrapper config that allows for the simpler interface defined above.
210 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
BasicIRPrinterConfig__anonc17873250611::BasicIRPrinterConfig211   BasicIRPrinterConfig(
212       std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
213       std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
214       bool printModuleScope, bool printAfterOnlyOnChange,
215       OpPrintingFlags opPrintingFlags, raw_ostream &out)
216       : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
217                         opPrintingFlags),
218         shouldPrintBeforePass(shouldPrintBeforePass),
219         shouldPrintAfterPass(shouldPrintAfterPass), out(out) {
220     assert((shouldPrintBeforePass || shouldPrintAfterPass) &&
221            "expected at least one valid filter function");
222   }
223 
printBeforeIfEnabled__anonc17873250611::BasicIRPrinterConfig224   void printBeforeIfEnabled(Pass *pass, Operation *operation,
225                             PrintCallbackFn printCallback) final {
226     if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation))
227       printCallback(out);
228   }
229 
printAfterIfEnabled__anonc17873250611::BasicIRPrinterConfig230   void printAfterIfEnabled(Pass *pass, Operation *operation,
231                            PrintCallbackFn printCallback) final {
232     if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation))
233       printCallback(out);
234   }
235 
236   /// Filter functions for before and after pass execution.
237   std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
238   std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
239 
240   /// The stream to output to.
241   raw_ostream &out;
242 };
243 } // end anonymous namespace
244 
245 /// Add an instrumentation to print the IR before and after pass execution,
246 /// using the provided configuration.
enableIRPrinting(std::unique_ptr<IRPrinterConfig> config)247 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
248   if (config->shouldPrintAtModuleScope() &&
249       getContext()->isMultithreadingEnabled())
250     llvm::report_fatal_error("IR printing can't be setup on a pass-manager "
251                              "without disabling multi-threading first.");
252   addInstrumentation(
253       std::make_unique<IRPrinterInstrumentation>(std::move(config)));
254 }
255 
256 /// Add an instrumentation to print the IR before and after pass execution.
enableIRPrinting(std::function<bool (Pass *,Operation *)> shouldPrintBeforePass,std::function<bool (Pass *,Operation *)> shouldPrintAfterPass,bool printModuleScope,bool printAfterOnlyOnChange,raw_ostream & out,OpPrintingFlags opPrintingFlags)257 void PassManager::enableIRPrinting(
258     std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
259     std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
260     bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out,
261     OpPrintingFlags opPrintingFlags) {
262   enableIRPrinting(std::make_unique<BasicIRPrinterConfig>(
263       std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
264       printModuleScope, printAfterOnlyOnChange, opPrintingFlags, out));
265 }
266