1 //===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/TableGen/OpClass.h"
10 
11 #include "mlir/TableGen/Format.h"
12 #include "llvm/ADT/Sequence.h"
13 #include "llvm/ADT/Twine.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/raw_ostream.h"
16 #include <unordered_set>
17 
18 #define DEBUG_TYPE "mlir-tblgen-opclass"
19 
20 using namespace mlir;
21 using namespace mlir::tblgen;
22 
23 namespace {
24 
25 // Returns space to be emitted after the given C++ `type`. return "" if the
26 // ends with '&' or '*', or is empty, else returns " ".
getSpaceAfterType(StringRef type)27 StringRef getSpaceAfterType(StringRef type) {
28   return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
29 }
30 
31 } // namespace
32 
33 //===----------------------------------------------------------------------===//
34 // OpMethodParameter definitions
35 //===----------------------------------------------------------------------===//
36 
writeTo(raw_ostream & os,bool emitDefault) const37 void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
38   if (properties & PP_Optional)
39     os << "/*optional*/";
40   os << type << getSpaceAfterType(type) << name;
41   if (emitDefault && !defaultValue.empty())
42     os << " = " << defaultValue;
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // OpMethodParameters definitions
47 //===----------------------------------------------------------------------===//
48 
49 // Factory methods to construct the correct type of `OpMethodParameters`
50 // object based on the arguments.
create()51 std::unique_ptr<OpMethodParameters> OpMethodParameters::create() {
52   return std::make_unique<OpMethodResolvedParameters>();
53 }
54 
55 std::unique_ptr<OpMethodParameters>
create(StringRef params)56 OpMethodParameters::create(StringRef params) {
57   return std::make_unique<OpMethodUnresolvedParameters>(params);
58 }
59 
60 std::unique_ptr<OpMethodParameters>
create(llvm::SmallVectorImpl<OpMethodParameter> && params)61 OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &&params) {
62   return std::make_unique<OpMethodResolvedParameters>(std::move(params));
63 }
64 
65 std::unique_ptr<OpMethodParameters>
create(StringRef type,StringRef name,StringRef defaultValue)66 OpMethodParameters::create(StringRef type, StringRef name,
67                            StringRef defaultValue) {
68   return std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue);
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // OpMethodUnresolvedParameters definitions
73 //===----------------------------------------------------------------------===//
writeDeclTo(raw_ostream & os) const74 void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const {
75   os << parameters;
76 }
77 
writeDefTo(raw_ostream & os) const78 void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const {
79   // We need to remove the default values for parameters in method definition.
80   // TODO: We are using '=' and ',' as delimiters for parameter
81   // initializers. This is incorrect for initializer list with more than one
82   // element. Change to a more robust approach.
83   llvm::SmallVector<StringRef, 4> tokens;
84   StringRef params = parameters;
85   while (!params.empty()) {
86     std::pair<StringRef, StringRef> parts = params.split("=");
87     tokens.push_back(parts.first);
88     params = parts.second.split(',').second;
89   }
90   llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; });
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // OpMethodResolvedParameters definitions
95 //===----------------------------------------------------------------------===//
96 
97 // Returns true if a method with these parameters makes a method with parameters
98 // `other` redundant. This should return true only if all possible calls to the
99 // other method can be replaced by calls to this method.
makesRedundant(const OpMethodResolvedParameters & other) const100 bool OpMethodResolvedParameters::makesRedundant(
101     const OpMethodResolvedParameters &other) const {
102   const size_t otherNumParams = other.getNumParameters();
103   const size_t thisNumParams = getNumParameters();
104 
105   // All calls to the other method can be replaced this method only if this
106   // method has the same or more arguments number of arguments as the other, and
107   // the common arguments have the same type.
108   if (thisNumParams < otherNumParams)
109     return false;
110   for (int idx : llvm::seq<int>(0, otherNumParams))
111     if (parameters[idx].getType() != other.parameters[idx].getType())
112       return false;
113 
114   // If all the common arguments have the same type, we can elide the other
115   // method if this method has the same number of arguments as other or the
116   // first argument after the common ones has a default value (and by C++
117   // requirement, all the later ones will also have a default value).
118   return thisNumParams == otherNumParams ||
119          parameters[otherNumParams].hasDefaultValue();
120 }
121 
writeDeclTo(raw_ostream & os) const122 void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const {
123   llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
124     param.writeDeclTo(os);
125   });
126 }
127 
writeDefTo(raw_ostream & os) const128 void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const {
129   llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) {
130     param.writeDefTo(os);
131   });
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // OpMethodSignature definitions
136 //===----------------------------------------------------------------------===//
137 
138 // Returns if a method with this signature makes a method with `other` signature
139 // redundant. Only supports resolved parameters.
makesRedundant(const OpMethodSignature & other) const140 bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const {
141   if (methodName != other.methodName)
142     return false;
143   auto *resolvedThis = dyn_cast<OpMethodResolvedParameters>(parameters.get());
144   auto *resolvedOther =
145       dyn_cast<OpMethodResolvedParameters>(other.parameters.get());
146   if (resolvedThis && resolvedOther)
147     return resolvedThis->makesRedundant(*resolvedOther);
148   return false;
149 }
150 
writeDeclTo(raw_ostream & os) const151 void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
152   os << returnType << getSpaceAfterType(returnType) << methodName << "(";
153   parameters->writeDeclTo(os);
154   os << ")";
155 }
156 
writeDefTo(raw_ostream & os,StringRef namePrefix) const157 void OpMethodSignature::writeDefTo(raw_ostream &os,
158                                    StringRef namePrefix) const {
159   os << returnType << getSpaceAfterType(returnType) << namePrefix
160      << (namePrefix.empty() ? "" : "::") << methodName << "(";
161   parameters->writeDefTo(os);
162   os << ")";
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // OpMethodBody definitions
167 //===----------------------------------------------------------------------===//
168 
OpMethodBody(bool declOnly)169 OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
170 
operator <<(Twine content)171 OpMethodBody &OpMethodBody::operator<<(Twine content) {
172   if (isEffective)
173     body.append(content.str());
174   return *this;
175 }
176 
operator <<(int content)177 OpMethodBody &OpMethodBody::operator<<(int content) {
178   if (isEffective)
179     body.append(std::to_string(content));
180   return *this;
181 }
182 
operator <<(const FmtObjectBase & content)183 OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
184   if (isEffective)
185     body.append(content.str());
186   return *this;
187 }
188 
writeTo(raw_ostream & os) const189 void OpMethodBody::writeTo(raw_ostream &os) const {
190   auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
191   os << bodyRef;
192   if (bodyRef.empty() || bodyRef.back() != '\n')
193     os << "\n";
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // OpMethod definitions
198 //===----------------------------------------------------------------------===//
199 
writeDeclTo(raw_ostream & os) const200 void OpMethod::writeDeclTo(raw_ostream &os) const {
201   os.indent(2);
202   if (isStatic())
203     os << "static ";
204   methodSignature.writeDeclTo(os);
205   os << ";";
206 }
207 
writeDefTo(raw_ostream & os,StringRef namePrefix) const208 void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
209   // Do not write definition if the method is decl only.
210   if (properties & MP_Declaration)
211     return;
212   methodSignature.writeDefTo(os, namePrefix);
213   os << " {\n";
214   methodBody.writeTo(os);
215   os << "}";
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // OpConstructor definitions
220 //===----------------------------------------------------------------------===//
221 
addMemberInitializer(StringRef name,StringRef value)222 void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
223   memberInitializers.append(std::string(llvm::formatv(
224       "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
225 }
226 
writeDefTo(raw_ostream & os,StringRef namePrefix) const227 void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
228   // Do not write definition if the method is decl only.
229   if (properties & MP_Declaration)
230     return;
231 
232   methodSignature.writeDefTo(os, namePrefix);
233   os << " " << memberInitializers << " {\n";
234   methodBody.writeTo(os);
235   os << "}";
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // Class definitions
240 //===----------------------------------------------------------------------===//
241 
Class(StringRef name)242 Class::Class(StringRef name) : className(name) {}
243 
newField(StringRef type,StringRef name,StringRef defaultValue)244 void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
245   std::string varName = formatv("{0} {1}", type, name).str();
246   std::string field = defaultValue.empty()
247                           ? varName
248                           : formatv("{0} = {1}", varName, defaultValue).str();
249   fields.push_back(std::move(field));
250 }
writeDeclTo(raw_ostream & os) const251 void Class::writeDeclTo(raw_ostream &os) const {
252   bool hasPrivateMethod = false;
253   os << "class " << className << " {\n";
254   os << "public:\n";
255 
256   forAllMethods([&](const OpMethod &method) {
257     if (!method.isPrivate()) {
258       method.writeDeclTo(os);
259       os << '\n';
260     } else {
261       hasPrivateMethod = true;
262     }
263   });
264 
265   os << '\n';
266   os << "private:\n";
267   if (hasPrivateMethod) {
268     forAllMethods([&](const OpMethod &method) {
269       if (method.isPrivate()) {
270         method.writeDeclTo(os);
271         os << '\n';
272       }
273     });
274     os << '\n';
275   }
276 
277   for (const auto &field : fields)
278     os.indent(2) << field << ";\n";
279   os << "};\n";
280 }
281 
writeDefTo(raw_ostream & os) const282 void Class::writeDefTo(raw_ostream &os) const {
283   forAllMethods([&](const OpMethod &method) {
284     method.writeDefTo(os, className);
285     os << "\n\n";
286   });
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // OpClass definitions
291 //===----------------------------------------------------------------------===//
292 
OpClass(StringRef name,StringRef extraClassDeclaration)293 OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
294     : Class(name), extraClassDeclaration(extraClassDeclaration) {}
295 
addTrait(Twine trait)296 void OpClass::addTrait(Twine trait) {
297   auto traitStr = trait.str();
298   if (traitsSet.insert(traitStr).second)
299     traitsVec.push_back(std::move(traitStr));
300 }
301 
writeDeclTo(raw_ostream & os) const302 void OpClass::writeDeclTo(raw_ostream &os) const {
303   os << "class " << className << " : public ::mlir::Op<" << className;
304   for (const auto &trait : traitsVec)
305     os << ", " << trait;
306   os << "> {\npublic:\n"
307      << "  using Op::Op;\n"
308      << "  using Op::print;\n"
309      << "  using Adaptor = " << className << "Adaptor;\n";
310 
311   bool hasPrivateMethod = false;
312   forAllMethods([&](const OpMethod &method) {
313     if (!method.isPrivate()) {
314       method.writeDeclTo(os);
315       os << "\n";
316     } else {
317       hasPrivateMethod = true;
318     }
319   });
320 
321   // TODO: Add line control markers to make errors easier to debug.
322   if (!extraClassDeclaration.empty())
323     os << extraClassDeclaration << "\n";
324 
325   if (hasPrivateMethod) {
326     os << "\nprivate:\n";
327     forAllMethods([&](const OpMethod &method) {
328       if (method.isPrivate()) {
329         method.writeDeclTo(os);
330         os << "\n";
331       }
332     });
333   }
334 
335   os << "};\n";
336 }
337