1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/TableGen/Pattern.h"
19 #include "mlir/TableGen/Predicate.h"
20 #include "mlir/TableGen/Type.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringSet.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatAdapters.h"
26 #include "llvm/Support/PrettyStackTrace.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Main.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
32 
33 using namespace mlir;
34 using namespace mlir::tblgen;
35 
36 using llvm::formatv;
37 using llvm::Record;
38 using llvm::RecordKeeper;
39 
40 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
41 
42 namespace llvm {
43 template <>
44 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
formatllvm::format_provider45   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
46                      raw_ostream &os, StringRef style) {
47     os << v.first << ":" << v.second;
48   }
49 };
50 } // end namespace llvm
51 
52 //===----------------------------------------------------------------------===//
53 // PatternEmitter
54 //===----------------------------------------------------------------------===//
55 
56 namespace {
57 class PatternEmitter {
58 public:
59   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
60 
61   // Emits the mlir::RewritePattern struct named `rewriteName`.
62   void emit(StringRef rewriteName);
63 
64 private:
65   // Emits the code for matching ops.
66   void emitMatchLogic(DagNode tree, StringRef opName);
67 
68   // Emits the code for rewriting ops.
69   void emitRewriteLogic();
70 
71   //===--------------------------------------------------------------------===//
72   // Match utilities
73   //===--------------------------------------------------------------------===//
74 
75   // Emits C++ statements for matching the DAG structure.
76   void emitMatch(DagNode tree, StringRef name, int depth);
77 
78   // Emits C++ statements for matching using a native code call.
79   void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
80 
81   // Emits C++ statements for matching the op constrained by the given DAG
82   // `tree` returning the op's variable name.
83   void emitOpMatch(DagNode tree, StringRef opName, int depth);
84 
85   // Emits C++ statements for matching the `argIndex`-th argument of the given
86   // DAG `tree` as an operand.
87   void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
88                         int depth);
89 
90   // Emits C++ statements for matching the `argIndex`-th argument of the given
91   // DAG `tree` as an attribute.
92   void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
93                           int depth);
94 
95   // Emits C++ for checking a match with a corresponding match failure
96   // diagnostic.
97   void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
98                       const llvm::formatv_object_base &failureFmt);
99 
100   // Emits C++ for checking a match with a corresponding match failure
101   // diagnostics.
102   void emitMatchCheck(StringRef opName, const std::string &matchStr,
103                       const std::string &failureStr);
104 
105   //===--------------------------------------------------------------------===//
106   // Rewrite utilities
107   //===--------------------------------------------------------------------===//
108 
109   // The entry point for handling a result pattern rooted at `resultTree`. This
110   // method dispatches to concrete handlers according to `resultTree`'s kind and
111   // returns a symbol representing the whole value pack. Callers are expected to
112   // further resolve the symbol according to the specific use case.
113   //
114   // `depth` is the nesting level of `resultTree`; 0 means top-level result
115   // pattern. For top-level result pattern, `resultIndex` indicates which result
116   // of the matched root op this pattern is intended to replace, which can be
117   // used to deduce the result type of the op generated from this result
118   // pattern.
119   std::string handleResultPattern(DagNode resultTree, int resultIndex,
120                                   int depth);
121 
122   // Emits the C++ statement to replace the matched DAG with a value built via
123   // calling native C++ code.
124   std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
125 
126   // Returns the symbol of the old value serving as the replacement.
127   StringRef handleReplaceWithValue(DagNode tree);
128 
129   // Returns the location value to use.
130   std::pair<bool, std::string> getLocation(DagNode tree);
131 
132   // Returns the location value to use.
133   std::string handleLocationDirective(DagNode tree);
134 
135   // Emits the C++ statement to build a new op out of the given DAG `tree` and
136   // returns the variable name that this op is assigned to. If the root op in
137   // DAG `tree` has a specified name, the created op will be assigned to a
138   // variable of the given name. Otherwise, a unique name will be used as the
139   // result value name.
140   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
141 
142   using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
143 
144   // Emits a local variable for each value and attribute to be used for creating
145   // an op.
146   void createSeparateLocalVarsForOpArgs(DagNode node,
147                                         ChildNodeIndexNameMap &childNodeNames);
148 
149   // Emits the concrete arguments used to call an op's builder.
150   void supplyValuesForOpArgs(DagNode node,
151                              const ChildNodeIndexNameMap &childNodeNames,
152                              int depth);
153 
154   // Emits the local variables for holding all values as a whole and all named
155   // attributes as a whole to be used for creating an op.
156   void createAggregateLocalVarsForOpArgs(
157       DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
158 
159   // Returns the C++ expression to construct a constant attribute of the given
160   // `value` for the given attribute kind `attr`.
161   std::string handleConstantAttr(Attribute attr, StringRef value);
162 
163   // Returns the C++ expression to build an argument from the given DAG `leaf`.
164   // `patArgName` is used to bound the argument to the source pattern.
165   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
166 
167   //===--------------------------------------------------------------------===//
168   // General utilities
169   //===--------------------------------------------------------------------===//
170 
171   // Collects all of the operations within the given dag tree.
172   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
173 
174   // Returns a unique symbol for a local variable of the given `op`.
175   std::string getUniqueSymbol(const Operator *op);
176 
177   //===--------------------------------------------------------------------===//
178   // Symbol utilities
179   //===--------------------------------------------------------------------===//
180 
181   // Returns how many static values the given DAG `node` correspond to.
182   int getNodeValueCount(DagNode node);
183 
184 private:
185   // Pattern instantiation location followed by the location of multiclass
186   // prototypes used. This is intended to be used as a whole to
187   // PrintFatalError() on errors.
188   ArrayRef<llvm::SMLoc> loc;
189 
190   // Op's TableGen Record to wrapper object.
191   RecordOperatorMap *opMap;
192 
193   // Handy wrapper for pattern being emitted.
194   Pattern pattern;
195 
196   // Map for all bound symbols' info.
197   SymbolInfoMap symbolInfoMap;
198 
199   // The next unused ID for newly created values.
200   unsigned nextValueId;
201 
202   raw_indented_ostream os;
203 
204   // Format contexts containing placeholder substitutions.
205   FmtContext fmtCtx;
206 
207   // Number of op processed.
208   int opCounter = 0;
209 };
210 } // end anonymous namespace
211 
PatternEmitter(Record * pat,RecordOperatorMap * mapper,raw_ostream & os)212 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
213                                raw_ostream &os)
214     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
215       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
216   fmtCtx.withBuilder("rewriter");
217 }
218 
handleConstantAttr(Attribute attr,StringRef value)219 std::string PatternEmitter::handleConstantAttr(Attribute attr,
220                                                StringRef value) {
221   if (!attr.isConstBuildable())
222     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
223                              " does not have the 'constBuilderCall' field");
224 
225   // TODO: Verify the constants here
226   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
227 }
228 
229 // Helper function to match patterns.
emitMatch(DagNode tree,StringRef name,int depth)230 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
231   if (tree.isNativeCodeCall()) {
232     emitNativeCodeMatch(tree, name, depth);
233     return;
234   }
235 
236   if (tree.isOperation()) {
237     emitOpMatch(tree, name, depth);
238     return;
239   }
240 
241   PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
242 }
243 
244 // Helper function to match patterns.
emitNativeCodeMatch(DagNode tree,StringRef opName,int depth)245 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
246                                          int depth) {
247   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
248   LLVM_DEBUG(tree.print(llvm::dbgs()));
249   LLVM_DEBUG(llvm::dbgs() << '\n');
250 
251   // TODO(suderman): iterate through arguments, determine their types, output
252   // names.
253   SmallVector<std::string, 8> capture(8);
254   if (tree.getNumArgs() > 8) {
255     PrintFatalError(loc,
256                     "unsupported NativeCodeCall matcher argument numbers: " +
257                         Twine(tree.getNumArgs()));
258   }
259 
260   raw_indented_ostream::DelimitedScope scope(os);
261 
262   os << "if(!" << opName << ") return failure();\n";
263   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
264     std::string argName = formatv("arg{0}_{1}", depth, i);
265     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
266       os << "Value " << argName << ";\n";
267     } else {
268       auto leaf = tree.getArgAsLeaf(i);
269       if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
270         os << "Attribute " << argName << ";\n";
271       } else if (leaf.isOperandMatcher()) {
272         os << "Operation " << argName << ";\n";
273       }
274     }
275 
276     capture[i] = std::move(argName);
277   }
278 
279   bool hasLocationDirective;
280   std::string locToUse;
281   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
282 
283   auto fmt = tree.getNativeCodeTemplate();
284   auto nativeCodeCall = std::string(tgfmt(
285       fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
286       capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
287 
288   os << "if (failed(" << nativeCodeCall << ")) return failure();\n";
289 
290   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
291     auto name = tree.getArgName(i);
292     if (!name.empty() && name != "_") {
293       os << formatv("{0} = {1};\n", name, capture[i]);
294     }
295   }
296 
297   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
298     std::string argName = capture[i];
299 
300     // Handle nested DAG construct first
301     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
302       PrintFatalError(
303           loc, formatv("Matching nested tree in NativeCodecall not support for "
304                        "{0} as arg {1}",
305                        argName, i));
306     }
307 
308     DagLeaf leaf = tree.getArgAsLeaf(i);
309     auto constraint = leaf.getAsConstraint();
310 
311     auto self = formatv("{0}", argName);
312     emitMatchCheck(
313         opName,
314         tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
315         formatv("\"operand {0} of native code call '{1}' failed to satisfy "
316                 "constraint: "
317                 "'{2}'\"",
318                 i, tree.getNativeCodeTemplate(), constraint.getDescription()));
319   }
320 
321   LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
322 }
323 
324 // Helper function to match patterns.
emitOpMatch(DagNode tree,StringRef opName,int depth)325 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
326   Operator &op = tree.getDialectOp(opMap);
327   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
328                           << op.getOperationName() << "' at depth " << depth
329                           << '\n');
330 
331   std::string castedName = formatv("castedOp{0}", depth);
332   os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); "
333                 "(void){0};\n",
334                 castedName, opName, op.getQualCppClassName());
335   // Skip the operand matching at depth 0 as the pattern rewriter already does.
336   if (depth != 0) {
337     // Skip if there is no defining operation (e.g., arguments to function).
338     os << formatv("if (!{0}) return failure();\n", castedName);
339   }
340   if (tree.getNumArgs() != op.getNumArgs()) {
341     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
342                                  "pattern vs. {2} in definition",
343                                  op.getOperationName(), tree.getNumArgs(),
344                                  op.getNumArgs()));
345   }
346 
347   // If the operand's name is set, set to that variable.
348   auto name = tree.getSymbol();
349   if (!name.empty())
350     os << formatv("{0} = {1};\n", name, castedName);
351 
352   for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
353     auto opArg = op.getArg(i);
354     std::string argName = formatv("op{0}", depth + 1);
355 
356     // Handle nested DAG construct first
357     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
358       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
359         if (operand->isVariableLength()) {
360           auto error = formatv("use nested DAG construct to match op {0}'s "
361                                "variadic operand #{1} unsupported now",
362                                op.getOperationName(), i);
363           PrintFatalError(loc, error);
364         }
365       }
366       os << "{\n";
367 
368       // Attributes don't count for getODSOperands.
369       os.indent() << formatv(
370           "auto *{0} = "
371           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
372           argName, castedName, nextOperand++);
373       emitMatch(argTree, argName, depth + 1);
374       os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
375       os.unindent() << "}\n";
376       continue;
377     }
378 
379     // Next handle DAG leaf: operand or attribute
380     if (opArg.is<NamedTypeConstraint *>()) {
381       // emitOperandMatch's argument indexing counts attributes.
382       emitOperandMatch(tree, castedName, i, depth);
383       ++nextOperand;
384     } else if (opArg.is<NamedAttribute *>()) {
385       emitAttributeMatch(tree, opName, i, depth);
386     } else {
387       PrintFatalError(loc, "unhandled case when matching op");
388     }
389   }
390   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
391                           << op.getOperationName() << "' at depth " << depth
392                           << '\n');
393 }
394 
emitOperandMatch(DagNode tree,StringRef opName,int argIndex,int depth)395 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
396                                       int argIndex, int depth) {
397   Operator &op = tree.getDialectOp(opMap);
398   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
399   auto matcher = tree.getArgAsLeaf(argIndex);
400 
401   // If a constraint is specified, we need to generate C++ statements to
402   // check the constraint.
403   if (!matcher.isUnspecified()) {
404     if (!matcher.isOperandMatcher()) {
405       PrintFatalError(
406           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
407                        op.getOperationName(), argIndex + 1));
408     }
409 
410     // Only need to verify if the matcher's type is different from the one
411     // of op definition.
412     Constraint constraint = matcher.getAsConstraint();
413     if (operand->constraint != constraint) {
414       if (operand->isVariableLength()) {
415         auto error = formatv(
416             "further constrain op {0}'s variadic operand #{1} unsupported now",
417             op.getOperationName(), argIndex);
418         PrintFatalError(loc, error);
419       }
420       auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
421                           opName, argIndex);
422       emitMatchCheck(
423           opName,
424           tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
425           formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
426                   "'{2}'\"",
427                   operand - op.operand_begin(), op.getOperationName(),
428                   constraint.getDescription()));
429     }
430   }
431 
432   // Capture the value
433   auto name = tree.getArgName(argIndex);
434   // `$_` is a special symbol to ignore op argument matching.
435   if (!name.empty() && name != "_") {
436     // We need to subtract the number of attributes before this operand to get
437     // the index in the operand list.
438     auto numPrevAttrs = std::count_if(
439         op.arg_begin(), op.arg_begin() + argIndex,
440         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
441 
442     auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
443     os << formatv("{0} = {1}.getODSOperands({2});\n",
444                   res->second.getVarName(name), opName,
445                   argIndex - numPrevAttrs);
446   }
447 }
448 
emitAttributeMatch(DagNode tree,StringRef opName,int argIndex,int depth)449 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
450                                         int argIndex, int depth) {
451   Operator &op = tree.getDialectOp(opMap);
452   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
453   const auto &attr = namedAttr->attr;
454 
455   os << "{\n";
456   os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
457                          "(void)tblgen_attr;\n",
458                          opName, attr.getStorageType(), namedAttr->name);
459 
460   // TODO: This should use getter method to avoid duplication.
461   if (attr.hasDefaultValue()) {
462     os << "if (!tblgen_attr) tblgen_attr = "
463        << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
464                             attr.getDefaultValue()))
465        << ";\n";
466   } else if (attr.isOptional()) {
467     // For a missing attribute that is optional according to definition, we
468     // should just capture a mlir::Attribute() to signal the missing state.
469     // That is precisely what getAttr() returns on missing attributes.
470   } else {
471     emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
472                    formatv("\"expected op '{0}' to have attribute '{1}' "
473                            "of type '{2}'\"",
474                            op.getOperationName(), namedAttr->name,
475                            attr.getStorageType()));
476   }
477 
478   auto matcher = tree.getArgAsLeaf(argIndex);
479   if (!matcher.isUnspecified()) {
480     if (!matcher.isAttrMatcher()) {
481       PrintFatalError(
482           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
483                        op.getOperationName(), argIndex + 1));
484     }
485 
486     // If a constraint is specified, we need to generate C++ statements to
487     // check the constraint.
488     emitMatchCheck(
489         opName,
490         tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
491         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
492                 "{2}\"",
493                 op.getOperationName(), namedAttr->name,
494                 matcher.getAsConstraint().getDescription()));
495   }
496 
497   // Capture the value
498   auto name = tree.getArgName(argIndex);
499   // `$_` is a special symbol to ignore op argument matching.
500   if (!name.empty() && name != "_") {
501     os << formatv("{0} = tblgen_attr;\n", name);
502   }
503 
504   os.unindent() << "}\n";
505 }
506 
emitMatchCheck(StringRef opName,const FmtObjectBase & matchFmt,const llvm::formatv_object_base & failureFmt)507 void PatternEmitter::emitMatchCheck(
508     StringRef opName, const FmtObjectBase &matchFmt,
509     const llvm::formatv_object_base &failureFmt) {
510   emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
511 }
512 
emitMatchCheck(StringRef opName,const std::string & matchStr,const std::string & failureStr)513 void PatternEmitter::emitMatchCheck(StringRef opName,
514                                     const std::string &matchStr,
515                                     const std::string &failureStr) {
516 
517   os << "if (!(" << matchStr << "))";
518   os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
519                               << ", [&](::mlir::Diagnostic &diag) {\n  diag << "
520                               << failureStr << ";\n});";
521 }
522 
emitMatchLogic(DagNode tree,StringRef opName)523 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
524   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
525   int depth = 0;
526   emitMatch(tree, opName, depth);
527 
528   for (auto &appliedConstraint : pattern.getConstraints()) {
529     auto &constraint = appliedConstraint.constraint;
530     auto &entities = appliedConstraint.entities;
531 
532     auto condition = constraint.getConditionTemplate();
533     if (isa<TypeConstraint>(constraint)) {
534       auto self = formatv("({0}.getType())",
535                           symbolInfoMap.getValueAndRangeUse(entities.front()));
536       emitMatchCheck(
537           opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
538           formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
539                   entities.front(), constraint.getDescription()));
540 
541     } else if (isa<AttrConstraint>(constraint)) {
542       PrintFatalError(
543           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
544     } else {
545       // TODO: replace formatv arguments with the exact specified
546       // args.
547       if (entities.size() > 4) {
548         PrintFatalError(loc, "only support up to 4-entity constraints now");
549       }
550       SmallVector<std::string, 4> names;
551       int i = 0;
552       for (int e = entities.size(); i < e; ++i)
553         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
554       std::string self = appliedConstraint.self;
555       if (!self.empty())
556         self = symbolInfoMap.getValueAndRangeUse(self);
557       for (; i < 4; ++i)
558         names.push_back("<unused>");
559       emitMatchCheck(opName,
560                      tgfmt(condition, &fmtCtx.withSelf(self), names[0],
561                            names[1], names[2], names[3]),
562                      formatv("\"entities '{0}' failed to satisfy constraint: "
563                              "{1}\"",
564                              llvm::join(entities, ", "),
565                              constraint.getDescription()));
566     }
567   }
568 
569   // Some of the operands could be bound to the same symbol name, we need
570   // to enforce equality constraint on those.
571   // TODO: we should be able to emit equality checks early
572   // and short circuit unnecessary work if vars are not equal.
573   for (auto symbolInfoIt = symbolInfoMap.begin();
574        symbolInfoIt != symbolInfoMap.end();) {
575     auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
576     auto startRange = range.first;
577     auto endRange = range.second;
578 
579     auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
580     for (++startRange; startRange != endRange; ++startRange) {
581       auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
582       emitMatchCheck(
583           opName,
584           formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
585           formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
586                   secondOperand));
587     }
588 
589     symbolInfoIt = endRange;
590   }
591 
592   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
593 }
594 
collectOps(DagNode tree,llvm::SmallPtrSetImpl<const Operator * > & ops)595 void PatternEmitter::collectOps(DagNode tree,
596                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
597   // Check if this tree is an operation.
598   if (tree.isOperation()) {
599     const Operator &op = tree.getDialectOp(opMap);
600     LLVM_DEBUG(llvm::dbgs()
601                << "found operation " << op.getOperationName() << '\n');
602     ops.insert(&op);
603   }
604 
605   // Recurse the arguments of the tree.
606   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
607     if (auto child = tree.getArgAsNestedDag(i))
608       collectOps(child, ops);
609 }
610 
emit(StringRef rewriteName)611 void PatternEmitter::emit(StringRef rewriteName) {
612   // Get the DAG tree for the source pattern.
613   DagNode sourceTree = pattern.getSourcePattern();
614 
615   const Operator &rootOp = pattern.getSourceRootOp();
616   auto rootName = rootOp.getOperationName();
617 
618   // Collect the set of result operations.
619   llvm::SmallPtrSet<const Operator *, 4> resultOps;
620   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
621   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
622     collectOps(pattern.getResultPattern(i), resultOps);
623   }
624   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
625 
626   // Emit RewritePattern for Pattern.
627   auto locs = pattern.getLocation();
628   os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
629                 make_range(locs.rbegin(), locs.rend()));
630   os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
631   {0}(::mlir::MLIRContext *context)
632       : ::mlir::RewritePattern("{1}", {{)",
633                 rewriteName, rootName);
634   // Sort result operators by name.
635   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
636                                                          resultOps.end());
637   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
638     return lhs->getOperationName() < rhs->getOperationName();
639   });
640   llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
641     os << '"' << op->getOperationName() << '"';
642   });
643   os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
644 
645   // Emit matchAndRewrite() function.
646   {
647     auto classScope = os.scope();
648     os.reindent(R"(
649     ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
650         ::mlir::PatternRewriter &rewriter) const override {)")
651         << '\n';
652     {
653       auto functionScope = os.scope();
654 
655       // Register all symbols bound in the source pattern.
656       pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
657 
658       LLVM_DEBUG(llvm::dbgs()
659                  << "start creating local variables for capturing matches\n");
660       os << "// Variables for capturing values and attributes used while "
661             "creating ops\n";
662       // Create local variables for storing the arguments and results bound
663       // to symbols.
664       for (const auto &symbolInfoPair : symbolInfoMap) {
665         const auto &symbol = symbolInfoPair.first;
666         const auto &info = symbolInfoPair.second;
667 
668         os << info.getVarDecl(symbol);
669       }
670       // TODO: capture ops with consistent numbering so that it can be
671       // reused for fused loc.
672       os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
673                     pattern.getSourcePattern().getNumOps());
674       LLVM_DEBUG(llvm::dbgs()
675                  << "done creating local variables for capturing matches\n");
676 
677       os << "// Match\n";
678       os << "tblgen_ops[0] = op0;\n";
679       emitMatchLogic(sourceTree, "op0");
680 
681       os << "\n// Rewrite\n";
682       emitRewriteLogic();
683 
684       os << "return ::mlir::success();\n";
685     }
686     os << "};\n";
687   }
688   os << "};\n\n";
689 }
690 
emitRewriteLogic()691 void PatternEmitter::emitRewriteLogic() {
692   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
693   const Operator &rootOp = pattern.getSourceRootOp();
694   int numExpectedResults = rootOp.getNumResults();
695   int numResultPatterns = pattern.getNumResultPatterns();
696 
697   // First register all symbols bound to ops generated in result patterns.
698   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
699 
700   // Only the last N static values generated are used to replace the matched
701   // root N-result op. We need to calculate the starting index (of the results
702   // of the matched op) each result pattern is to replace.
703   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
704   // If we don't need to replace any value at all, set the replacement starting
705   // index as the number of result patterns so we skip all of them when trying
706   // to replace the matched op's results.
707   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
708   for (int i = numResultPatterns - 1; i >= 0; --i) {
709     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
710     offsets[i] = offsets[i + 1] - numValues;
711     if (offsets[i] == 0) {
712       if (replStartIndex == -1)
713         replStartIndex = i;
714     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
715       auto error = formatv(
716           "cannot use the same multi-result op '{0}' to generate both "
717           "auxiliary values and values to be used for replacing the matched op",
718           pattern.getResultPattern(i).getSymbol());
719       PrintFatalError(loc, error);
720     }
721   }
722 
723   if (offsets.front() > 0) {
724     const char error[] = "no enough values generated to replace the matched op";
725     PrintFatalError(loc, error);
726   }
727 
728   os << "auto odsLoc = rewriter.getFusedLoc({";
729   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
730     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
731   }
732   os << "}); (void)odsLoc;\n";
733 
734   // Process auxiliary result patterns.
735   for (int i = 0; i < replStartIndex; ++i) {
736     DagNode resultTree = pattern.getResultPattern(i);
737     auto val = handleResultPattern(resultTree, offsets[i], 0);
738     // Normal op creation will be streamed to `os` by the above call; but
739     // NativeCodeCall will only be materialized to `os` if it is used. Here
740     // we are handling auxiliary patterns so we want the side effect even if
741     // NativeCodeCall is not replacing matched root op's results.
742     if (resultTree.isNativeCodeCall())
743       os << val << ";\n";
744   }
745 
746   if (numExpectedResults == 0) {
747     assert(replStartIndex >= numResultPatterns &&
748            "invalid auxiliary vs. replacement pattern division!");
749     // No result to replace. Just erase the op.
750     os << "rewriter.eraseOp(op0);\n";
751   } else {
752     // Process replacement result patterns.
753     os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
754     for (int i = replStartIndex; i < numResultPatterns; ++i) {
755       DagNode resultTree = pattern.getResultPattern(i);
756       auto val = handleResultPattern(resultTree, offsets[i], 0);
757       os << "\n";
758       // Resolve each symbol for all range use so that we can loop over them.
759       // We need an explicit cast to `SmallVector` to capture the cases where
760       // `{0}` resolves to an `Operation::result_range` as well as cases that
761       // are not iterable (e.g. vector that gets wrapped in additional braces by
762       // RewriterGen).
763       // TODO: Revisit the need for materializing a vector.
764       os << symbolInfoMap.getAllRangeUse(
765           val,
766           "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
767           "  tblgen_repl_values.push_back(v);\n}\n",
768           "\n");
769     }
770     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
771   }
772 
773   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
774 }
775 
getUniqueSymbol(const Operator * op)776 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
777   return std::string(
778       formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
779 }
780 
handleResultPattern(DagNode resultTree,int resultIndex,int depth)781 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
782                                                 int resultIndex, int depth) {
783   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
784   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
785   LLVM_DEBUG(llvm::dbgs() << '\n');
786 
787   if (resultTree.isLocationDirective()) {
788     PrintFatalError(loc,
789                     "location directive can only be used with op creation");
790   }
791 
792   if (resultTree.isNativeCodeCall()) {
793     auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
794     symbolInfoMap.bindValue(symbol);
795     return symbol;
796   }
797 
798   if (resultTree.isReplaceWithValue())
799     return handleReplaceWithValue(resultTree).str();
800 
801   // Normal op creation.
802   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
803   if (resultTree.getSymbol().empty()) {
804     // This is an op not explicitly bound to a symbol in the rewrite rule.
805     // Register the auto-generated symbol for it.
806     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
807   }
808   return symbol;
809 }
810 
handleReplaceWithValue(DagNode tree)811 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
812   assert(tree.isReplaceWithValue());
813 
814   if (tree.getNumArgs() != 1) {
815     PrintFatalError(
816         loc, "replaceWithValue directive must take exactly one argument");
817   }
818 
819   if (!tree.getSymbol().empty()) {
820     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
821   }
822 
823   return tree.getArgName(0);
824 }
825 
handleLocationDirective(DagNode tree)826 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
827   assert(tree.isLocationDirective());
828   auto lookUpArgLoc = [this, &tree](int idx) {
829     const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
830     return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
831   };
832 
833   if (tree.getNumArgs() == 0)
834     llvm::PrintFatalError(
835         "At least one argument to location directive required");
836 
837   if (!tree.getSymbol().empty())
838     PrintFatalError(loc, "cannot bind symbol to location");
839 
840   if (tree.getNumArgs() == 1) {
841     DagLeaf leaf = tree.getArgAsLeaf(0);
842     if (leaf.isStringAttr())
843       return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), "
844                      "rewriter.getContext())",
845                      leaf.getStringAttr())
846           .str();
847     return lookUpArgLoc(0);
848   }
849 
850   std::string ret;
851   llvm::raw_string_ostream os(ret);
852   std::string strAttr;
853   os << "rewriter.getFusedLoc({";
854   bool first = true;
855   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
856     DagLeaf leaf = tree.getArgAsLeaf(i);
857     // Handle the optional string value.
858     if (leaf.isStringAttr()) {
859       if (!strAttr.empty())
860         llvm::PrintFatalError("Only one string attribute may be specified");
861       strAttr = leaf.getStringAttr();
862       continue;
863     }
864     os << (first ? "" : ", ") << lookUpArgLoc(i);
865     first = false;
866   }
867   os << "}";
868   if (!strAttr.empty()) {
869     os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
870   }
871   os << ")";
872   return os.str();
873 }
874 
handleOpArgument(DagLeaf leaf,StringRef patArgName)875 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
876                                              StringRef patArgName) {
877   if (leaf.isStringAttr())
878     PrintFatalError(loc, "raw string not supported as argument");
879   if (leaf.isConstantAttr()) {
880     auto constAttr = leaf.getAsConstantAttr();
881     return handleConstantAttr(constAttr.getAttribute(),
882                               constAttr.getConstantValue());
883   }
884   if (leaf.isEnumAttrCase()) {
885     auto enumCase = leaf.getAsEnumAttrCase();
886     if (enumCase.isStrCase())
887       return handleConstantAttr(enumCase, enumCase.getSymbol());
888     // This is an enum case backed by an IntegerAttr. We need to get its value
889     // to build the constant.
890     std::string val = std::to_string(enumCase.getValue());
891     return handleConstantAttr(enumCase, val);
892   }
893 
894   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
895   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
896   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
897     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
898                             << "' (via symbol ref)\n");
899     return argName;
900   }
901   if (leaf.isNativeCodeCall()) {
902     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
903     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
904                             << "' (via NativeCodeCall)\n");
905     return std::string(repl);
906   }
907   PrintFatalError(loc, "unhandled case when rewriting op");
908 }
909 
handleReplaceWithNativeCodeCall(DagNode tree,int depth)910 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
911                                                             int depth) {
912   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
913   LLVM_DEBUG(tree.print(llvm::dbgs()));
914   LLVM_DEBUG(llvm::dbgs() << '\n');
915 
916   auto fmt = tree.getNativeCodeTemplate();
917   // TODO: replace formatv arguments with the exact specified args.
918   SmallVector<std::string, 8> attrs(8);
919   if (tree.getNumArgs() > 8) {
920     PrintFatalError(loc,
921                     "unsupported NativeCodeCall replace argument numbers: " +
922                         Twine(tree.getNumArgs()));
923   }
924   bool hasLocationDirective;
925   std::string locToUse;
926   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
927 
928   for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
929     if (tree.isNestedDagArg(i)) {
930       attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
931     } else {
932       attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
933     }
934     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
935                             << " replacement: " << attrs[i] << "\n");
936   }
937   return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
938                            attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
939                            attrs[6], attrs[7]));
940 }
941 
getNodeValueCount(DagNode node)942 int PatternEmitter::getNodeValueCount(DagNode node) {
943   if (node.isOperation()) {
944     // If the op is bound to a symbol in the rewrite rule, query its result
945     // count from the symbol info map.
946     auto symbol = node.getSymbol();
947     if (!symbol.empty()) {
948       return symbolInfoMap.getStaticValueCount(symbol);
949     }
950     // Otherwise this is an unbound op; we will use all its results.
951     return pattern.getDialectOp(node).getNumResults();
952   }
953   // TODO: This considers all NativeCodeCall as returning one
954   // value. Enhance if multi-value ones are needed.
955   return 1;
956 }
957 
getLocation(DagNode tree)958 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
959   auto numPatArgs = tree.getNumArgs();
960 
961   if (numPatArgs != 0) {
962     if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
963       if (lastArg.isLocationDirective()) {
964         return std::make_pair(true, handleLocationDirective(lastArg));
965       }
966   }
967 
968   // If no explicit location is given, use the default, all fused, location.
969   return std::make_pair(false, "odsLoc");
970 }
971 
handleOpCreation(DagNode tree,int resultIndex,int depth)972 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
973                                              int depth) {
974   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
975   LLVM_DEBUG(tree.print(llvm::dbgs()));
976   LLVM_DEBUG(llvm::dbgs() << '\n');
977 
978   Operator &resultOp = tree.getDialectOp(opMap);
979   auto numOpArgs = resultOp.getNumArgs();
980   auto numPatArgs = tree.getNumArgs();
981 
982   bool hasLocationDirective;
983   std::string locToUse;
984   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
985 
986   auto inPattern = numPatArgs - hasLocationDirective;
987   if (numOpArgs != inPattern) {
988     PrintFatalError(loc,
989                     formatv("resultant op '{0}' argument number mismatch: "
990                             "{1} in pattern vs. {2} in definition",
991                             resultOp.getOperationName(), inPattern, numOpArgs));
992   }
993 
994   // A map to collect all nested DAG child nodes' names, with operand index as
995   // the key. This includes both bound and unbound child nodes.
996   ChildNodeIndexNameMap childNodeNames;
997 
998   // First go through all the child nodes who are nested DAG constructs to
999   // create ops for them and remember the symbol names for them, so that we can
1000   // use the results in the current node. This happens in a recursive manner.
1001   for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
1002     if (auto child = tree.getArgAsNestedDag(i))
1003       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
1004   }
1005 
1006   // The name of the local variable holding this op.
1007   std::string valuePackName;
1008   // The symbol for holding the result of this pattern. Note that the result of
1009   // this pattern is not necessarily the same as the variable created by this
1010   // pattern because we can use `__N` suffix to refer only a specific result if
1011   // the generated op is a multi-result op.
1012   std::string resultValue;
1013   if (tree.getSymbol().empty()) {
1014     // No symbol is explicitly bound to this op in the pattern. Generate a
1015     // unique name.
1016     valuePackName = resultValue = getUniqueSymbol(&resultOp);
1017   } else {
1018     resultValue = std::string(tree.getSymbol());
1019     // Strip the index to get the name for the value pack and use it to name the
1020     // local variable for the op.
1021     valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
1022   }
1023 
1024   // Create the local variable for this op.
1025   os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
1026                 valuePackName);
1027 
1028   // Right now ODS don't have general type inference support. Except a few
1029   // special cases listed below, DRR needs to supply types for all results
1030   // when building an op.
1031   bool isSameOperandsAndResultType =
1032       resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
1033   bool useFirstAttr =
1034       resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
1035 
1036   if (isSameOperandsAndResultType || useFirstAttr) {
1037     // We know how to deduce the result type for ops with these traits and we've
1038     // generated builders taking aggregate parameters. Use those builders to
1039     // create the ops.
1040 
1041     // First prepare local variables for op arguments used in builder call.
1042     createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1043 
1044     // Then create the op.
1045     os.scope("", "\n}\n").os << formatv(
1046         "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1047         valuePackName, resultOp.getQualCppClassName(), locToUse);
1048     return resultValue;
1049   }
1050 
1051   bool usePartialResults = valuePackName != resultValue;
1052 
1053   if (usePartialResults || depth > 0 || resultIndex < 0) {
1054     // For these cases (broadcastable ops, op results used both as auxiliary
1055     // values and replacement values, ops in nested patterns, auxiliary ops), we
1056     // still need to supply the result types when building the op. But because
1057     // we don't generate a builder automatically with ODS for them, it's the
1058     // developer's responsibility to make sure such a builder (with result type
1059     // deduction ability) exists. We go through the separate-parameter builder
1060     // here given that it's easier for developers to write compared to
1061     // aggregate-parameter builders.
1062     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
1063 
1064     os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
1065                              resultOp.getQualCppClassName(), locToUse);
1066     supplyValuesForOpArgs(tree, childNodeNames, depth);
1067     os << "\n  );\n}\n";
1068     return resultValue;
1069   }
1070 
1071   // If depth == 0 and resultIndex >= 0, it means we are replacing the values
1072   // generated from the source pattern root op. Then we can use the source
1073   // pattern's value types to determine the value type of the generated op
1074   // here.
1075 
1076   // First prepare local variables for op arguments used in builder call.
1077   createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1078 
1079   // Then prepare the result types. We need to specify the types for all
1080   // results.
1081   os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
1082                          "(void)tblgen_types;\n");
1083   int numResults = resultOp.getNumResults();
1084   if (numResults != 0) {
1085     for (int i = 0; i < numResults; ++i)
1086       os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
1087                     "  tblgen_types.push_back(v.getType());\n}\n",
1088                     resultIndex + i);
1089   }
1090   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1091                 "tblgen_values, tblgen_attrs);\n",
1092                 valuePackName, resultOp.getQualCppClassName(), locToUse);
1093   os.unindent() << "}\n";
1094   return resultValue;
1095 }
1096 
createSeparateLocalVarsForOpArgs(DagNode node,ChildNodeIndexNameMap & childNodeNames)1097 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1098     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1099   Operator &resultOp = node.getDialectOp(opMap);
1100 
1101   // Now prepare operands used for building this op:
1102   // * If the operand is non-variadic, we create a `Value` local variable.
1103   // * If the operand is variadic, we create a `SmallVector<Value>` local
1104   //   variable.
1105 
1106   int valueIndex = 0; // An index for uniquing local variable names.
1107   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1108     const auto *operand =
1109         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
1110     // We do not need special handling for attributes.
1111     if (!operand)
1112       continue;
1113 
1114     raw_indented_ostream::DelimitedScope scope(os);
1115     std::string varName;
1116     if (operand->isVariadic()) {
1117       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1118       os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
1119       std::string range;
1120       if (node.isNestedDagArg(argIndex)) {
1121         range = childNodeNames[argIndex];
1122       } else {
1123         range = std::string(node.getArgName(argIndex));
1124       }
1125       // Resolve the symbol for all range use so that we have a uniform way of
1126       // capturing the values.
1127       range = symbolInfoMap.getValueAndRangeUse(range);
1128       os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
1129                     varName);
1130     } else {
1131       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1132       os << formatv("::mlir::Value {0} = ", varName);
1133       if (node.isNestedDagArg(argIndex)) {
1134         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1135       } else {
1136         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1137         auto symbol =
1138             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1139         if (leaf.isNativeCodeCall()) {
1140           os << std::string(
1141               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1142         } else {
1143           os << symbol;
1144         }
1145       }
1146       os << ";\n";
1147     }
1148 
1149     // Update to use the newly created local variable for building the op later.
1150     childNodeNames[argIndex] = varName;
1151   }
1152 }
1153 
supplyValuesForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1154 void PatternEmitter::supplyValuesForOpArgs(
1155     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1156   Operator &resultOp = node.getDialectOp(opMap);
1157   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1158        argIndex != numOpArgs; ++argIndex) {
1159     // Start each argument on its own line.
1160     os << ",\n    ";
1161 
1162     Argument opArg = resultOp.getArg(argIndex);
1163     // Handle the case of operand first.
1164     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
1165       if (!operand->name.empty())
1166         os << "/*" << operand->name << "=*/";
1167       os << childNodeNames.lookup(argIndex);
1168       continue;
1169     }
1170 
1171     // The argument in the op definition.
1172     auto opArgName = resultOp.getArgName(argIndex);
1173     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1174       if (!subTree.isNativeCodeCall())
1175         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1176                              "for creating attribute");
1177       os << formatv("/*{0}=*/{1}", opArgName,
1178                     handleReplaceWithNativeCodeCall(subTree, depth));
1179     } else {
1180       auto leaf = node.getArgAsLeaf(argIndex);
1181       // The argument in the result DAG pattern.
1182       auto patArgName = node.getArgName(argIndex);
1183       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1184         // TODO: Refactor out into map to avoid recomputing these.
1185         if (!opArg.is<NamedAttribute *>())
1186           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1187         if (!patArgName.empty())
1188           os << "/*" << patArgName << "=*/";
1189       } else {
1190         os << "/*" << opArgName << "=*/";
1191       }
1192       os << handleOpArgument(leaf, patArgName);
1193     }
1194   }
1195 }
1196 
createAggregateLocalVarsForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1197 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1198     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1199   Operator &resultOp = node.getDialectOp(opMap);
1200 
1201   auto scope = os.scope();
1202   os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
1203                 "tblgen_values; (void)tblgen_values;\n");
1204   os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
1205                 "tblgen_attrs; (void)tblgen_attrs;\n");
1206 
1207   const char *addAttrCmd =
1208       "if (auto tmpAttr = {1}) {\n"
1209       "  tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
1210       "tmpAttr);\n}\n";
1211   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1212     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1213       // The argument in the op definition.
1214       auto opArgName = resultOp.getArgName(argIndex);
1215       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1216         if (!subTree.isNativeCodeCall())
1217           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1218                                "for creating attribute");
1219         os << formatv(addAttrCmd, opArgName,
1220                       handleReplaceWithNativeCodeCall(subTree, depth + 1));
1221       } else {
1222         auto leaf = node.getArgAsLeaf(argIndex);
1223         // The argument in the result DAG pattern.
1224         auto patArgName = node.getArgName(argIndex);
1225         os << formatv(addAttrCmd, opArgName,
1226                       handleOpArgument(leaf, patArgName));
1227       }
1228       continue;
1229     }
1230 
1231     const auto *operand =
1232         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1233     std::string varName;
1234     if (operand->isVariadic()) {
1235       std::string range;
1236       if (node.isNestedDagArg(argIndex)) {
1237         range = childNodeNames.lookup(argIndex);
1238       } else {
1239         range = std::string(node.getArgName(argIndex));
1240       }
1241       // Resolve the symbol for all range use so that we have a uniform way of
1242       // capturing the values.
1243       range = symbolInfoMap.getValueAndRangeUse(range);
1244       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
1245                     range);
1246     } else {
1247       os << formatv("tblgen_values.push_back(");
1248       if (node.isNestedDagArg(argIndex)) {
1249         os << symbolInfoMap.getValueAndRangeUse(
1250             childNodeNames.lookup(argIndex));
1251       } else {
1252         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1253         auto symbol =
1254             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1255         if (leaf.isNativeCodeCall()) {
1256           os << std::string(
1257               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1258         } else {
1259           os << symbol;
1260         }
1261       }
1262       os << ");\n";
1263     }
1264   }
1265 }
1266 
emitRewriters(const RecordKeeper & recordKeeper,raw_ostream & os)1267 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1268   emitSourceFileHeader("Rewriters", os);
1269 
1270   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1271   auto numPatterns = patterns.size();
1272 
1273   // We put the map here because it can be shared among multiple patterns.
1274   RecordOperatorMap recordOpMap;
1275 
1276   std::vector<std::string> rewriterNames;
1277   rewriterNames.reserve(numPatterns);
1278 
1279   std::string baseRewriterName = "GeneratedConvert";
1280   int rewriterIndex = 0;
1281 
1282   for (Record *p : patterns) {
1283     std::string name;
1284     if (p->isAnonymous()) {
1285       // If no name is provided, ensure unique rewriter names simply by
1286       // appending unique suffix.
1287       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1288     } else {
1289       name = std::string(p->getName());
1290     }
1291     LLVM_DEBUG(llvm::dbgs()
1292                << "=== start generating pattern '" << name << "' ===\n");
1293     PatternEmitter(p, &recordOpMap, os).emit(name);
1294     LLVM_DEBUG(llvm::dbgs()
1295                << "=== done generating pattern '" << name << "' ===\n");
1296     rewriterNames.push_back(std::move(name));
1297   }
1298 
1299   // Emit function to add the generated matchers to the pattern list.
1300   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext "
1301         "*context, ::mlir::OwningRewritePatternList &patterns) {\n";
1302   for (const auto &name : rewriterNames) {
1303     os << "  patterns.insert<" << name << ">(context);\n";
1304   }
1305   os << "}\n";
1306 }
1307 
1308 static mlir::GenRegistration
1309     genRewriters("gen-rewriters", "Generate pattern rewriters",
__anon1c196a6d0602(const RecordKeeper &records, raw_ostream &os) 1310                  [](const RecordKeeper &records, raw_ostream &os) {
1311                    emitRewriters(records, os);
1312                    return false;
1313                  });
1314