1 //===- Operator.cpp - Operator class --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Operator wrapper to simplify using a TableGen Record that defines an MLIR Op.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Operator.h"
14 #include "mlir/TableGen/OpTrait.h"
15 #include "mlir/TableGen/Predicate.h"
16 #include "mlir/TableGen/Type.h"
17 #include "llvm/ADT/EquivalenceClasses.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/Sequence.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27 
28 #define DEBUG_TYPE "mlir-tblgen-operator"
29 
30 using namespace mlir;
31 using namespace mlir::tblgen;
32 
33 using llvm::DagInit;
34 using llvm::DefInit;
35 using llvm::Record;
36 
Operator(const llvm::Record & def)37 Operator::Operator(const llvm::Record &def)
38     : dialect(def.getValueAsDef("opDialect")), def(def) {
39   // The first `_` in the op's TableGen def name is treated as separating the
40   // dialect prefix and the op class name. The dialect prefix will be ignored if
41   // not empty. Otherwise, if def name starts with a `_`, the `_` is considered
42   // as part of the class name.
43   StringRef prefix;
44   std::tie(prefix, cppClassName) = def.getName().split('_');
45   if (prefix.empty()) {
46     // Class name with a leading underscore and without dialect prefix
47     cppClassName = def.getName();
48   } else if (cppClassName.empty()) {
49     // Class name without dialect prefix
50     cppClassName = prefix;
51   }
52 
53   populateOpStructure();
54 }
55 
getOperationName() const56 std::string Operator::getOperationName() const {
57   auto prefix = dialect.getName();
58   auto opName = def.getValueAsString("opName");
59   if (prefix.empty())
60     return std::string(opName);
61   return std::string(llvm::formatv("{0}.{1}", prefix, opName));
62 }
63 
getAdaptorName() const64 std::string Operator::getAdaptorName() const {
65   return std::string(llvm::formatv("{0}Adaptor", getCppClassName()));
66 }
67 
getDialectName() const68 StringRef Operator::getDialectName() const { return dialect.getName(); }
69 
getCppClassName() const70 StringRef Operator::getCppClassName() const { return cppClassName; }
71 
getQualCppClassName() const72 std::string Operator::getQualCppClassName() const {
73   auto prefix = dialect.getCppNamespace();
74   if (prefix.empty())
75     return std::string(cppClassName);
76   return std::string(llvm::formatv("{0}::{1}", prefix, cppClassName));
77 }
78 
getNumResults() const79 int Operator::getNumResults() const {
80   DagInit *results = def.getValueAsDag("results");
81   return results->getNumArgs();
82 }
83 
getExtraClassDeclaration() const84 StringRef Operator::getExtraClassDeclaration() const {
85   constexpr auto attr = "extraClassDeclaration";
86   if (def.isValueUnset(attr))
87     return {};
88   return def.getValueAsString(attr);
89 }
90 
getDef() const91 const llvm::Record &Operator::getDef() const { return def; }
92 
skipDefaultBuilders() const93 bool Operator::skipDefaultBuilders() const {
94   return def.getValueAsBit("skipDefaultBuilders");
95 }
96 
result_begin()97 auto Operator::result_begin() -> value_iterator { return results.begin(); }
98 
result_end()99 auto Operator::result_end() -> value_iterator { return results.end(); }
100 
getResults()101 auto Operator::getResults() -> value_range {
102   return {result_begin(), result_end()};
103 }
104 
getResultTypeConstraint(int index) const105 TypeConstraint Operator::getResultTypeConstraint(int index) const {
106   DagInit *results = def.getValueAsDag("results");
107   return TypeConstraint(cast<DefInit>(results->getArg(index)));
108 }
109 
getResultName(int index) const110 StringRef Operator::getResultName(int index) const {
111   DagInit *results = def.getValueAsDag("results");
112   return results->getArgNameStr(index);
113 }
114 
getResultDecorators(int index) const115 auto Operator::getResultDecorators(int index) const -> var_decorator_range {
116   Record *result =
117       cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef();
118   if (!result->isSubClassOf("OpVariable"))
119     return var_decorator_range(nullptr, nullptr);
120   return *result->getValueAsListInit("decorators");
121 }
122 
getNumVariableLengthResults() const123 unsigned Operator::getNumVariableLengthResults() const {
124   return llvm::count_if(results, [](const NamedTypeConstraint &c) {
125     return c.constraint.isVariableLength();
126   });
127 }
128 
getNumVariableLengthOperands() const129 unsigned Operator::getNumVariableLengthOperands() const {
130   return llvm::count_if(operands, [](const NamedTypeConstraint &c) {
131     return c.constraint.isVariableLength();
132   });
133 }
134 
hasSingleVariadicArg() const135 bool Operator::hasSingleVariadicArg() const {
136   return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() &&
137          getOperand(0).isVariadic();
138 }
139 
arg_begin() const140 Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); }
141 
arg_end() const142 Operator::arg_iterator Operator::arg_end() const { return arguments.end(); }
143 
getArgs() const144 Operator::arg_range Operator::getArgs() const {
145   return {arg_begin(), arg_end()};
146 }
147 
getArgName(int index) const148 StringRef Operator::getArgName(int index) const {
149   DagInit *argumentValues = def.getValueAsDag("arguments");
150   return argumentValues->getArgNameStr(index);
151 }
152 
getArgDecorators(int index) const153 auto Operator::getArgDecorators(int index) const -> var_decorator_range {
154   Record *arg =
155       cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef();
156   if (!arg->isSubClassOf("OpVariable"))
157     return var_decorator_range(nullptr, nullptr);
158   return *arg->getValueAsListInit("decorators");
159 }
160 
getTrait(StringRef trait) const161 const OpTrait *Operator::getTrait(StringRef trait) const {
162   for (const auto &t : traits) {
163     if (const auto *opTrait = dyn_cast<NativeOpTrait>(&t)) {
164       if (opTrait->getTrait() == trait)
165         return opTrait;
166     } else if (const auto *opTrait = dyn_cast<InternalOpTrait>(&t)) {
167       if (opTrait->getTrait() == trait)
168         return opTrait;
169     } else if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&t)) {
170       if (opTrait->getTrait() == trait)
171         return opTrait;
172     }
173   }
174   return nullptr;
175 }
176 
region_begin() const177 auto Operator::region_begin() const -> const_region_iterator {
178   return regions.begin();
179 }
region_end() const180 auto Operator::region_end() const -> const_region_iterator {
181   return regions.end();
182 }
getRegions() const183 auto Operator::getRegions() const
184     -> llvm::iterator_range<const_region_iterator> {
185   return {region_begin(), region_end()};
186 }
187 
getNumRegions() const188 unsigned Operator::getNumRegions() const { return regions.size(); }
189 
getRegion(unsigned index) const190 const NamedRegion &Operator::getRegion(unsigned index) const {
191   return regions[index];
192 }
193 
getNumVariadicRegions() const194 unsigned Operator::getNumVariadicRegions() const {
195   return llvm::count_if(regions,
196                         [](const NamedRegion &c) { return c.isVariadic(); });
197 }
198 
successor_begin() const199 auto Operator::successor_begin() const -> const_successor_iterator {
200   return successors.begin();
201 }
successor_end() const202 auto Operator::successor_end() const -> const_successor_iterator {
203   return successors.end();
204 }
getSuccessors() const205 auto Operator::getSuccessors() const
206     -> llvm::iterator_range<const_successor_iterator> {
207   return {successor_begin(), successor_end()};
208 }
209 
getNumSuccessors() const210 unsigned Operator::getNumSuccessors() const { return successors.size(); }
211 
getSuccessor(unsigned index) const212 const NamedSuccessor &Operator::getSuccessor(unsigned index) const {
213   return successors[index];
214 }
215 
getNumVariadicSuccessors() const216 unsigned Operator::getNumVariadicSuccessors() const {
217   return llvm::count_if(successors,
218                         [](const NamedSuccessor &c) { return c.isVariadic(); });
219 }
220 
trait_begin() const221 auto Operator::trait_begin() const -> const_trait_iterator {
222   return traits.begin();
223 }
trait_end() const224 auto Operator::trait_end() const -> const_trait_iterator {
225   return traits.end();
226 }
getTraits() const227 auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> {
228   return {trait_begin(), trait_end()};
229 }
230 
attribute_begin() const231 auto Operator::attribute_begin() const -> attribute_iterator {
232   return attributes.begin();
233 }
attribute_end() const234 auto Operator::attribute_end() const -> attribute_iterator {
235   return attributes.end();
236 }
getAttributes() const237 auto Operator::getAttributes() const
238     -> llvm::iterator_range<attribute_iterator> {
239   return {attribute_begin(), attribute_end()};
240 }
241 
operand_begin()242 auto Operator::operand_begin() -> value_iterator { return operands.begin(); }
operand_end()243 auto Operator::operand_end() -> value_iterator { return operands.end(); }
getOperands()244 auto Operator::getOperands() -> value_range {
245   return {operand_begin(), operand_end()};
246 }
247 
getArg(int index) const248 auto Operator::getArg(int index) const -> Argument { return arguments[index]; }
249 
// Maps result number `index` into the index space shared with arguments.
// Arguments keep their non-negative getArg index; results are sent to
// -1, -2, ... so the two ranges never overlap.
static int resultIndex(int index) {
  const int firstResultSlot = -1;
  return firstResultSlot - index;
}
254 
isVariadic() const255 bool Operator::isVariadic() const {
256   return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
257                 [](const NamedTypeConstraint &op) { return op.isVariadic(); });
258 }
259 
populateTypeInferenceInfo(const llvm::StringMap<int> & argumentsAndResultsIndex)260 void Operator::populateTypeInferenceInfo(
261     const llvm::StringMap<int> &argumentsAndResultsIndex) {
262   // If the type inference op interface is not registered, then do not attempt
263   // to determine if the result types an be inferred.
264   auto &recordKeeper = def.getRecords();
265   auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface);
266   allResultsHaveKnownTypes = false;
267   if (!inferTrait)
268     return;
269 
270   // If there are no results, the skip this else the build method generated
271   // overlaps with another autogenerated builder.
272   if (getNumResults() == 0)
273     return;
274 
275   // Skip for ops with variadic operands/results.
276   // TODO: This can be relaxed.
277   if (isVariadic())
278     return;
279 
280   // Skip cases currently being custom generated.
281   // TODO: Remove special cases.
282   if (getTrait("::mlir::OpTrait::SameOperandsAndResultType"))
283     return;
284 
285   // We create equivalence classes of argument/result types where arguments
286   // and results are mapped into the same index space and indices corresponding
287   // to the same type are in the same equivalence class.
288   llvm::EquivalenceClasses<int> ecs;
289   resultTypeMapping.resize(getNumResults());
290   // Captures the argument whose type matches a given result type. Preference
291   // towards capturing operands first before attributes.
292   auto captureMapping = [&](int i) {
293     bool found = false;
294     ecs.insert(resultIndex(i));
295     auto mi = ecs.findLeader(resultIndex(i));
296     for (auto me = ecs.member_end(); mi != me; ++mi) {
297       if (*mi < 0) {
298         auto tc = getResultTypeConstraint(i);
299         if (tc.getBuilderCall().hasValue()) {
300           resultTypeMapping[i].emplace_back(tc);
301           found = true;
302         }
303         continue;
304       }
305 
306       if (getArg(*mi).is<NamedAttribute *>()) {
307         // TODO: Handle attributes.
308         continue;
309       } else {
310         resultTypeMapping[i].emplace_back(*mi);
311         found = true;
312       }
313     }
314     return found;
315   };
316 
317   for (const OpTrait &trait : traits) {
318     const llvm::Record &def = trait.getDef();
319     // If the infer type op interface was manually added, then treat it as
320     // intention that the op needs special handling.
321     // TODO: Reconsider whether to always generate, this is more conservative
322     // and keeps existing behavior so starting that way for now.
323     if (def.isSubClassOf(
324             llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
325       return;
326     if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&trait))
327       if (&opTrait->getDef() == inferTrait)
328         return;
329 
330     if (!def.isSubClassOf("AllTypesMatch"))
331       continue;
332 
333     auto values = def.getValueAsListOfStrings("values");
334     auto root = argumentsAndResultsIndex.lookup(values.front());
335     for (StringRef str : values)
336       ecs.unionSets(argumentsAndResultsIndex.lookup(str), root);
337   }
338 
339   // Verifies that all output types have a corresponding known input type
340   // and chooses matching operand or attribute (in that order) that
341   // matches it.
342   allResultsHaveKnownTypes =
343       all_of(llvm::seq<int>(0, getNumResults()), captureMapping);
344 
345   // If the types could be computed, then add type inference trait.
346   if (allResultsHaveKnownTypes)
347     traits.push_back(OpTrait::create(inferTrait->getDefInit()));
348 }
349 
populateOpStructure()350 void Operator::populateOpStructure() {
351   auto &recordKeeper = def.getRecords();
352   auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint");
353   auto *attrClass = recordKeeper.getClass("Attr");
354   auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr");
355   auto *opVarClass = recordKeeper.getClass("OpVariable");
356   numNativeAttributes = 0;
357 
358   DagInit *argumentValues = def.getValueAsDag("arguments");
359   unsigned numArgs = argumentValues->getNumArgs();
360 
361   // Mapping from name of to argument or result index. Arguments are indexed
362   // to match getArg index, while the results are negatively indexed.
363   llvm::StringMap<int> argumentsAndResultsIndex;
364 
365   // Handle operands and native attributes.
366   for (unsigned i = 0; i != numArgs; ++i) {
367     auto *arg = argumentValues->getArg(i);
368     auto givenName = argumentValues->getArgNameStr(i);
369     auto *argDefInit = dyn_cast<DefInit>(arg);
370     if (!argDefInit)
371       PrintFatalError(def.getLoc(),
372                       Twine("undefined type for argument #") + Twine(i));
373     Record *argDef = argDefInit->getDef();
374     if (argDef->isSubClassOf(opVarClass))
375       argDef = argDef->getValueAsDef("constraint");
376 
377     if (argDef->isSubClassOf(typeConstraintClass)) {
378       operands.push_back(
379           NamedTypeConstraint{givenName, TypeConstraint(argDef)});
380     } else if (argDef->isSubClassOf(attrClass)) {
381       if (givenName.empty())
382         PrintFatalError(argDef->getLoc(), "attributes must be named");
383       if (argDef->isSubClassOf(derivedAttrClass))
384         PrintFatalError(argDef->getLoc(),
385                         "derived attributes not allowed in argument list");
386       attributes.push_back({givenName, Attribute(argDef)});
387       ++numNativeAttributes;
388     } else {
389       PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
390                                     "from TypeConstraint or Attr are allowed");
391     }
392     if (!givenName.empty())
393       argumentsAndResultsIndex[givenName] = i;
394   }
395 
396   // Handle derived attributes.
397   for (const auto &val : def.getValues()) {
398     if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
399       if (!record->isSubClassOf(attrClass))
400         continue;
401       if (!record->isSubClassOf(derivedAttrClass))
402         PrintFatalError(def.getLoc(),
403                         "unexpected Attr where only DerivedAttr is allowed");
404 
405       if (record->getClasses().size() != 1) {
406         PrintFatalError(
407             def.getLoc(),
408             "unsupported attribute modelling, only single class expected");
409       }
410       attributes.push_back(
411           {cast<llvm::StringInit>(val.getNameInit())->getValue(),
412            Attribute(cast<DefInit>(val.getValue()))});
413     }
414   }
415 
416   // Populate `arguments`. This must happen after we've finalized `operands` and
417   // `attributes` because we will put their elements' pointers in `arguments`.
418   // SmallVector may perform re-allocation under the hood when adding new
419   // elements.
420   int operandIndex = 0, attrIndex = 0;
421   for (unsigned i = 0; i != numArgs; ++i) {
422     Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
423     if (argDef->isSubClassOf(opVarClass))
424       argDef = argDef->getValueAsDef("constraint");
425 
426     if (argDef->isSubClassOf(typeConstraintClass)) {
427       attrOrOperandMapping.push_back(
428           {OperandOrAttribute::Kind::Operand, operandIndex});
429       arguments.emplace_back(&operands[operandIndex++]);
430     } else {
431       assert(argDef->isSubClassOf(attrClass));
432       attrOrOperandMapping.push_back(
433           {OperandOrAttribute::Kind::Attribute, attrIndex});
434       arguments.emplace_back(&attributes[attrIndex++]);
435     }
436   }
437 
438   auto *resultsDag = def.getValueAsDag("results");
439   auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
440   if (!outsOp || outsOp->getDef()->getName() != "outs") {
441     PrintFatalError(def.getLoc(), "'results' must have 'outs' directive");
442   }
443 
444   // Handle results.
445   for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
446     auto name = resultsDag->getArgNameStr(i);
447     auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
448     if (!resultInit) {
449       PrintFatalError(def.getLoc(),
450                       Twine("undefined type for result #") + Twine(i));
451     }
452     auto *resultDef = resultInit->getDef();
453     if (resultDef->isSubClassOf(opVarClass))
454       resultDef = resultDef->getValueAsDef("constraint");
455     results.push_back({name, TypeConstraint(resultDef)});
456     if (!name.empty())
457       argumentsAndResultsIndex[name] = resultIndex(i);
458   }
459 
460   // Handle successors
461   auto *successorsDag = def.getValueAsDag("successors");
462   auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
463   if (!successorsOp || successorsOp->getDef()->getName() != "successor") {
464     PrintFatalError(def.getLoc(),
465                     "'successors' must have 'successor' directive");
466   }
467 
468   for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
469     auto name = successorsDag->getArgNameStr(i);
470     auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
471     if (!successorInit) {
472       PrintFatalError(def.getLoc(),
473                       Twine("undefined kind for successor #") + Twine(i));
474     }
475     Successor successor(successorInit->getDef());
476 
477     // Only support variadic successors if it is the last one for now.
478     if (i != e - 1 && successor.isVariadic())
479       PrintFatalError(def.getLoc(), "only the last successor can be variadic");
480     successors.push_back({name, successor});
481   }
482 
483   // Create list of traits, skipping over duplicates: appending to lists in
484   // tablegen is easy, making them unique less so, so dedupe here.
485   if (auto *traitList = def.getValueAsListInit("traits")) {
486     // This is uniquing based on pointers of the trait.
487     SmallPtrSet<const llvm::Init *, 32> traitSet;
488     traits.reserve(traitSet.size());
489     for (auto *traitInit : *traitList) {
490       // Keep traits in the same order while skipping over duplicates.
491       if (traitSet.insert(traitInit).second)
492         traits.push_back(OpTrait::create(traitInit));
493     }
494   }
495 
496   populateTypeInferenceInfo(argumentsAndResultsIndex);
497 
498   // Handle regions
499   auto *regionsDag = def.getValueAsDag("regions");
500   auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
501   if (!regionsOp || regionsOp->getDef()->getName() != "region") {
502     PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
503   }
504 
505   for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
506     auto name = regionsDag->getArgNameStr(i);
507     auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
508     if (!regionInit) {
509       PrintFatalError(def.getLoc(),
510                       Twine("undefined kind for region #") + Twine(i));
511     }
512     Region region(regionInit->getDef());
513     if (region.isVariadic()) {
514       // Only support variadic regions if it is the last one for now.
515       if (i != e - 1)
516         PrintFatalError(def.getLoc(), "only the last region can be variadic");
517       if (name.empty())
518         PrintFatalError(def.getLoc(), "variadic regions must be named");
519     }
520 
521     regions.push_back({name, region});
522   }
523 
524   LLVM_DEBUG(print(llvm::dbgs()));
525 }
526 
getSameTypeAsResult(int index) const527 auto Operator::getSameTypeAsResult(int index) const -> ArrayRef<ArgOrType> {
528   assert(allResultTypesKnown());
529   return resultTypeMapping[index];
530 }
531 
getLoc() const532 ArrayRef<llvm::SMLoc> Operator::getLoc() const { return def.getLoc(); }
533 
hasDescription() const534 bool Operator::hasDescription() const {
535   return def.getValue("description") != nullptr;
536 }
537 
getDescription() const538 StringRef Operator::getDescription() const {
539   return def.getValueAsString("description");
540 }
541 
hasSummary() const542 bool Operator::hasSummary() const { return def.getValue("summary") != nullptr; }
543 
getSummary() const544 StringRef Operator::getSummary() const {
545   return def.getValueAsString("summary");
546 }
547 
hasAssemblyFormat() const548 bool Operator::hasAssemblyFormat() const {
549   auto *valueInit = def.getValueInit("assemblyFormat");
550   return isa<llvm::StringInit>(valueInit);
551 }
552 
getAssemblyFormat() const553 StringRef Operator::getAssemblyFormat() const {
554   return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
555       .Case<llvm::StringInit>(
556           [&](auto *init) { return init->getValue(); });
557 }
558 
print(llvm::raw_ostream & os) const559 void Operator::print(llvm::raw_ostream &os) const {
560   os << "op '" << getOperationName() << "'\n";
561   for (Argument arg : arguments) {
562     if (auto *attr = arg.dyn_cast<NamedAttribute *>())
563       os << "[attribute] " << attr->name << '\n';
564     else
565       os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
566   }
567 }
568 
unwrap(llvm::Init * init)569 auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
570     -> VariableDecorator {
571   return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
572 }
573 
getArgToOperandOrAttribute(int index) const574 auto Operator::getArgToOperandOrAttribute(int index) const
575     -> OperandOrAttribute {
576   return attrOrOperandMapping[index];
577 }
578