1 //===- Attribute.cpp - Attribute wrapper class ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Attribute wrapper to simplify using TableGen Records that define an MLIR
// Attribute.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/TableGen/Record.h"
17 
18 using namespace mlir;
19 using namespace mlir::tblgen;
20 
21 using llvm::DefInit;
22 using llvm::Init;
23 using llvm::Record;
24 using llvm::StringInit;
25 
26 // Returns the initializer's value as string if the given TableGen initializer
27 // is a code or string initializer. Returns the empty StringRef otherwise.
getValueAsString(const Init * init)28 static StringRef getValueAsString(const Init *init) {
29   if (const auto *str = dyn_cast<StringInit>(init))
30     return str->getValue().trim();
31   return {};
32 }
33 
AttrConstraint(const Record * record)34 AttrConstraint::AttrConstraint(const Record *record)
35     : Constraint(Constraint::CK_Attr, record) {
36   assert(isSubClassOf("AttrConstraint") &&
37          "must be subclass of TableGen 'AttrConstraint' class");
38 }
39 
isSubClassOf(StringRef className) const40 bool AttrConstraint::isSubClassOf(StringRef className) const {
41   return def->isSubClassOf(className);
42 }
43 
Attribute(const Record * record)44 Attribute::Attribute(const Record *record) : AttrConstraint(record) {
45   assert(record->isSubClassOf("Attr") &&
46          "must be subclass of TableGen 'Attr' class");
47 }
48 
Attribute(const DefInit * init)49 Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
50 
isDerivedAttr() const51 bool Attribute::isDerivedAttr() const { return isSubClassOf("DerivedAttr"); }
52 
isTypeAttr() const53 bool Attribute::isTypeAttr() const { return isSubClassOf("TypeAttrBase"); }
54 
isSymbolRefAttr() const55 bool Attribute::isSymbolRefAttr() const {
56   StringRef defName = def->getName();
57   if (defName == "SymbolRefAttr" || defName == "FlatSymbolRefAttr")
58     return true;
59   return isSubClassOf("SymbolRefAttr") || isSubClassOf("FlatSymbolRefAttr");
60 }
61 
isEnumAttr() const62 bool Attribute::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
63 
getStorageType() const64 StringRef Attribute::getStorageType() const {
65   const auto *init = def->getValueInit("storageType");
66   auto type = getValueAsString(init);
67   if (type.empty())
68     return "Attribute";
69   return type;
70 }
71 
getReturnType() const72 StringRef Attribute::getReturnType() const {
73   const auto *init = def->getValueInit("returnType");
74   return getValueAsString(init);
75 }
76 
77 // Return the type constraint corresponding to the type of this attribute, or
78 // None if this is not a TypedAttr.
getValueType() const79 llvm::Optional<Type> Attribute::getValueType() const {
80   if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
81     return Type(defInit->getDef());
82   return llvm::None;
83 }
84 
getConvertFromStorageCall() const85 StringRef Attribute::getConvertFromStorageCall() const {
86   const auto *init = def->getValueInit("convertFromStorage");
87   return getValueAsString(init);
88 }
89 
isConstBuildable() const90 bool Attribute::isConstBuildable() const {
91   const auto *init = def->getValueInit("constBuilderCall");
92   return !getValueAsString(init).empty();
93 }
94 
getConstBuilderTemplate() const95 StringRef Attribute::getConstBuilderTemplate() const {
96   const auto *init = def->getValueInit("constBuilderCall");
97   return getValueAsString(init);
98 }
99 
getBaseAttr() const100 Attribute Attribute::getBaseAttr() const {
101   if (const auto *defInit =
102           llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
103     return Attribute(defInit).getBaseAttr();
104   }
105   return *this;
106 }
107 
hasDefaultValue() const108 bool Attribute::hasDefaultValue() const {
109   const auto *init = def->getValueInit("defaultValue");
110   return !getValueAsString(init).empty();
111 }
112 
getDefaultValue() const113 StringRef Attribute::getDefaultValue() const {
114   const auto *init = def->getValueInit("defaultValue");
115   return getValueAsString(init);
116 }
117 
isOptional() const118 bool Attribute::isOptional() const { return def->getValueAsBit("isOptional"); }
119 
getAttrDefName() const120 StringRef Attribute::getAttrDefName() const {
121   if (def->isAnonymous()) {
122     return getBaseAttr().def->getName();
123   }
124   return def->getName();
125 }
126 
getDerivedCodeBody() const127 StringRef Attribute::getDerivedCodeBody() const {
128   assert(isDerivedAttr() && "only derived attribute has 'body' field");
129   return def->getValueAsString("body");
130 }
131 
getDialect() const132 Dialect Attribute::getDialect() const {
133   const llvm::RecordVal *record = def->getValue("dialect");
134   if (record && record->getValue()) {
135     if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
136       return Dialect(init->getDef());
137   }
138   return Dialect(nullptr);
139 }
140 
ConstantAttr(const DefInit * init)141 ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
142   assert(def->isSubClassOf("ConstantAttr") &&
143          "must be subclass of TableGen 'ConstantAttr' class");
144 }
145 
getAttribute() const146 Attribute ConstantAttr::getAttribute() const {
147   return Attribute(def->getValueAsDef("attr"));
148 }
149 
getConstantValue() const150 StringRef ConstantAttr::getConstantValue() const {
151   return def->getValueAsString("value");
152 }
153 
EnumAttrCase(const llvm::Record * record)154 EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
155   assert(isSubClassOf("EnumAttrCaseInfo") &&
156          "must be subclass of TableGen 'EnumAttrInfo' class");
157 }
158 
EnumAttrCase(const llvm::DefInit * init)159 EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
160     : EnumAttrCase(init->getDef()) {}
161 
isStrCase() const162 bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); }
163 
getSymbol() const164 StringRef EnumAttrCase::getSymbol() const {
165   return def->getValueAsString("symbol");
166 }
167 
getStr() const168 StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
169 
getValue() const170 int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
171 
getDef() const172 const llvm::Record &EnumAttrCase::getDef() const { return *def; }
173 
EnumAttr(const llvm::Record * record)174 EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
175   assert(isSubClassOf("EnumAttrInfo") &&
176          "must be subclass of TableGen 'EnumAttr' class");
177 }
178 
EnumAttr(const llvm::Record & record)179 EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
180 
EnumAttr(const llvm::DefInit * init)181 EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
182 
classof(const Attribute * attr)183 bool EnumAttr::classof(const Attribute *attr) {
184   return attr->isSubClassOf("EnumAttrInfo");
185 }
186 
isBitEnum() const187 bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
188 
getEnumClassName() const189 StringRef EnumAttr::getEnumClassName() const {
190   return def->getValueAsString("className");
191 }
192 
getCppNamespace() const193 StringRef EnumAttr::getCppNamespace() const {
194   return def->getValueAsString("cppNamespace");
195 }
196 
getUnderlyingType() const197 StringRef EnumAttr::getUnderlyingType() const {
198   return def->getValueAsString("underlyingType");
199 }
200 
getUnderlyingToSymbolFnName() const201 StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
202   return def->getValueAsString("underlyingToSymbolFnName");
203 }
204 
getStringToSymbolFnName() const205 StringRef EnumAttr::getStringToSymbolFnName() const {
206   return def->getValueAsString("stringToSymbolFnName");
207 }
208 
getSymbolToStringFnName() const209 StringRef EnumAttr::getSymbolToStringFnName() const {
210   return def->getValueAsString("symbolToStringFnName");
211 }
212 
getSymbolToStringFnRetType() const213 StringRef EnumAttr::getSymbolToStringFnRetType() const {
214   return def->getValueAsString("symbolToStringFnRetType");
215 }
216 
getMaxEnumValFnName() const217 StringRef EnumAttr::getMaxEnumValFnName() const {
218   return def->getValueAsString("maxEnumValFnName");
219 }
220 
getAllCases() const221 std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
222   const auto *inits = def->getValueAsListInit("enumerants");
223 
224   std::vector<EnumAttrCase> cases;
225   cases.reserve(inits->size());
226 
227   for (const llvm::Init *init : *inits) {
228     cases.push_back(EnumAttrCase(cast<llvm::DefInit>(init)));
229   }
230 
231   return cases;
232 }
233 
StructFieldAttr(const llvm::Record * record)234 StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
235   assert(def->isSubClassOf("StructFieldAttr") &&
236          "must be subclass of TableGen 'StructFieldAttr' class");
237 }
238 
StructFieldAttr(const llvm::Record & record)239 StructFieldAttr::StructFieldAttr(const llvm::Record &record)
240     : StructFieldAttr(&record) {}
241 
StructFieldAttr(const llvm::DefInit * init)242 StructFieldAttr::StructFieldAttr(const llvm::DefInit *init)
243     : StructFieldAttr(init->getDef()) {}
244 
getName() const245 StringRef StructFieldAttr::getName() const {
246   return def->getValueAsString("name");
247 }
248 
getType() const249 Attribute StructFieldAttr::getType() const {
250   auto init = def->getValueInit("type");
251   return Attribute(cast<llvm::DefInit>(init));
252 }
253 
StructAttr(const llvm::Record * record)254 StructAttr::StructAttr(const llvm::Record *record) : Attribute(record) {
255   assert(isSubClassOf("StructAttr") &&
256          "must be subclass of TableGen 'StructAttr' class");
257 }
258 
StructAttr(const llvm::DefInit * init)259 StructAttr::StructAttr(const llvm::DefInit *init)
260     : StructAttr(init->getDef()) {}
261 
getStructClassName() const262 StringRef StructAttr::getStructClassName() const {
263   return def->getValueAsString("className");
264 }
265 
getCppNamespace() const266 StringRef StructAttr::getCppNamespace() const {
267   Dialect dialect(def->getValueAsDef("dialect"));
268   return dialect.getCppNamespace();
269 }
270 
getAllFields() const271 std::vector<StructFieldAttr> StructAttr::getAllFields() const {
272   std::vector<StructFieldAttr> attributes;
273 
274   const auto *inits = def->getValueAsListInit("fields");
275   attributes.reserve(inits->size());
276 
277   for (const llvm::Init *init : *inits) {
278     attributes.emplace_back(cast<llvm::DefInit>(init));
279   }
280 
281   return attributes;
282 }
283 
namespace mlir {
namespace tblgen {
// Definition of the globally visible name string "InferTypeOpInterface"
// (declared outside this file, in the mlir::tblgen namespace).
const char *inferTypeOpInterface = "InferTypeOpInterface";
} // namespace tblgen
} // namespace mlir
285