1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "mlir/IR/DialectInterface.h"
13 #include "mlir/IR/MLIRContext.h"
14 #include "mlir/IR/Operation.h"
15 #include "llvm/ADT/MapVector.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/ManagedStatic.h"
18 #include "llvm/Support/Regex.h"
19 
20 using namespace mlir;
21 using namespace detail;
22 
~DialectAsmParser()23 DialectAsmParser::~DialectAsmParser() {}
24 
loadByName(StringRef name,MLIRContext * context)25 Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
26   auto it = registry.find(name.str());
27   if (it == registry.end())
28     return nullptr;
29   return it->second.second(context);
30 }
31 
insert(TypeID typeID,StringRef name,DialectAllocatorFunction ctor)32 void DialectRegistry::insert(TypeID typeID, StringRef name,
33                              DialectAllocatorFunction ctor) {
34   auto inserted = registry.insert(
35       std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
36   if (!inserted.second && inserted.first->second.first != typeID) {
37     llvm::report_fatal_error(
38         "Trying to register different dialects for the same namespace: " +
39         name);
40   }
41 }
42 
43 //===----------------------------------------------------------------------===//
44 // Dialect
45 //===----------------------------------------------------------------------===//
46 
Dialect(StringRef name,MLIRContext * context,TypeID id)47 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
48     : name(name), dialectID(id), context(context) {
49   assert(isValidNamespace(name) && "invalid dialect namespace");
50 }
51 
~Dialect()52 Dialect::~Dialect() {}
53 
54 /// Verify an attribute from this dialect on the argument at 'argIndex' for
55 /// the region at 'regionIndex' on the given operation. Returns failure if
56 /// the verification failed, success otherwise. This hook may optionally be
57 /// invoked from any operation containing a region.
verifyRegionArgAttribute(Operation *,unsigned,unsigned,NamedAttribute)58 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
59                                                 NamedAttribute) {
60   return success();
61 }
62 
63 /// Verify an attribute from this dialect on the result at 'resultIndex' for
64 /// the region at 'regionIndex' on the given operation. Returns failure if
65 /// the verification failed, success otherwise. This hook may optionally be
66 /// invoked from any operation containing a region.
verifyRegionResultAttribute(Operation *,unsigned,unsigned,NamedAttribute)67 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
68                                                    unsigned, NamedAttribute) {
69   return success();
70 }
71 
72 /// Parse an attribute registered to this dialect.
parseAttribute(DialectAsmParser & parser,Type type) const73 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
74   parser.emitError(parser.getNameLoc())
75       << "dialect '" << getNamespace()
76       << "' provides no attribute parsing hook";
77   return Attribute();
78 }
79 
80 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const81 Type Dialect::parseType(DialectAsmParser &parser) const {
82   // If this dialect allows unknown types, then represent this with OpaqueType.
83   if (allowsUnknownTypes()) {
84     auto ns = Identifier::get(getNamespace(), getContext());
85     return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
86   }
87 
88   parser.emitError(parser.getNameLoc())
89       << "dialect '" << getNamespace() << "' provides no type parsing hook";
90   return Type();
91 }
92 
93 /// Utility function that returns if the given string is a valid dialect
94 /// namespace.
isValidNamespace(StringRef str)95 bool Dialect::isValidNamespace(StringRef str) {
96   if (str.empty())
97     return true;
98   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
99   return dialectNameRegex.match(str);
100 }
101 
102 /// Register a set of dialect interfaces with this dialect instance.
addInterface(std::unique_ptr<DialectInterface> interface)103 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
104   auto it = registeredInterfaces.try_emplace(interface->getID(),
105                                              std::move(interface));
106   (void)it;
107   assert(it.second && "interface kind has already been registered");
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // Dialect Interface
112 //===----------------------------------------------------------------------===//
113 
~DialectInterface()114 DialectInterface::~DialectInterface() {}
115 
DialectInterfaceCollectionBase(MLIRContext * ctx,TypeID interfaceKind)116 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
117     MLIRContext *ctx, TypeID interfaceKind) {
118   for (auto *dialect : ctx->getLoadedDialects()) {
119     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
120       interfaces.insert(interface);
121       orderedInterfaces.push_back(interface);
122     }
123   }
124 }
125 
~DialectInterfaceCollectionBase()126 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() {}
127 
128 /// Get the interface for the dialect of given operation, or null if one
129 /// is not registered.
130 const DialectInterface *
getInterfaceFor(Operation * op) const131 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
132   return getInterfaceFor(op->getDialect());
133 }
134