1 //===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_SUPPORT_STORAGEUNIQUER_H 10 #define MLIR_SUPPORT_STORAGEUNIQUER_H 11 12 #include "mlir/Support/LLVM.h" 13 #include "mlir/Support/LogicalResult.h" 14 #include "mlir/Support/TypeID.h" 15 #include "llvm/ADT/DenseSet.h" 16 #include "llvm/Support/Allocator.h" 17 18 namespace mlir { 19 namespace detail { 20 struct StorageUniquerImpl; 21 22 /// Trait to check if ImplTy provides a 'getKey' method with types 'Args'. 23 template <typename ImplTy, typename... Args> 24 using has_impltype_getkey_t = decltype(ImplTy::getKey(std::declval<Args>()...)); 25 26 /// Trait to check if ImplTy provides a 'hashKey' method for 'T'. 27 template <typename ImplTy, typename T> 28 using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>())); 29 } // namespace detail 30 31 /// A utility class to get or create instances of "storage classes". These 32 /// storage classes must derive from 'StorageUniquer::BaseStorage'. 33 /// 34 /// For non-parametric storage classes, i.e. singleton classes, nothing else is 35 /// needed. Instances of these classes can be created by calling `get` without 36 /// trailing arguments. 37 /// 38 /// Otherwise, the parametric storage classes may be created with `get`, 39 /// and must respect the following: 40 /// - Define a type alias, KeyTy, to a type that uniquely identifies the 41 /// instance of the storage class. 42 /// * The key type must be constructible from the values passed into the 43 /// getComplex call. 44 /// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the 45 /// storage class must define a hashing method: 46 /// 'static unsigned hashKey(const KeyTy &)' 47 /// 48 /// - Provide a method, 'bool operator==(const KeyTy &) const', to 49 /// compare the storage instance against an instance of the key type. 50 /// 51 /// - Provide a static construction method: 52 /// 'DerivedStorage *construct(StorageAllocator &, const KeyTy &key)' 53 /// that builds a unique instance of the derived storage. The arguments to 54 /// this function are an allocator to store any uniqued data and the key 55 /// type for this storage. 56 /// 57 /// - Provide a cleanup method: 58 /// 'void cleanup()' 59 /// that is called when erasing a storage instance. This should cleanup any 60 /// fields of the storage as necessary and not attempt to free the memory 61 /// of the storage itself. 62 /// 63 /// Storage classes may have an optional mutable component, which must not take 64 /// part in the unique immutable key. In this case, storage classes may be 65 /// mutated with `mutate` and must additionally respect the following: 66 /// - Provide a mutation method: 67 /// 'LogicalResult mutate(StorageAllocator &, <...>)' 68 /// that is called when mutating a storage instance. The first argument is 69 /// an allocator to store any mutable data, and the remaining arguments are 70 /// forwarded from the call site. The storage can be mutated at any time 71 /// after creation. Care must be taken to avoid excessive mutation since 72 /// the allocated storage can keep containing previous states. The return 73 /// value of the function is used to indicate whether the mutation was 74 /// successful, e.g., to limit the number of mutations or enable deferred 75 /// one-time assignment of the mutable component. 76 /// 77 /// All storage classes must be registered with the uniquer via 78 /// `registerStorageType` using an appropriate unique `TypeID` for the storage 79 /// class. 80 class StorageUniquer { 81 public: 82 /// This class acts as the base storage that all storage classes must derived 83 /// from. 84 class BaseStorage { 85 protected: 86 BaseStorage() = default; 87 }; 88 89 /// This is a utility allocator used to allocate memory for instances of 90 /// derived types. 91 class StorageAllocator { 92 public: 93 /// Copy the specified array of elements into memory managed by our bump 94 /// pointer allocator. This assumes the elements are all PODs. copyInto(ArrayRef<T> elements)95 template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) { 96 if (elements.empty()) 97 return llvm::None; 98 auto result = allocator.Allocate<T>(elements.size()); 99 std::uninitialized_copy(elements.begin(), elements.end(), result); 100 return ArrayRef<T>(result, elements.size()); 101 } 102 103 /// Copy the provided string into memory managed by our bump pointer 104 /// allocator. copyInto(StringRef str)105 StringRef copyInto(StringRef str) { 106 auto result = copyInto(ArrayRef<char>(str.data(), str.size())); 107 return StringRef(result.data(), str.size()); 108 } 109 110 /// Allocate an instance of the provided type. allocate()111 template <typename T> T *allocate() { return allocator.Allocate<T>(); } 112 113 /// Allocate 'size' bytes of 'alignment' aligned memory. allocate(size_t size,size_t alignment)114 void *allocate(size_t size, size_t alignment) { 115 return allocator.Allocate(size, alignment); 116 } 117 118 /// Returns true if this allocator allocated the provided object pointer. allocated(const void * ptr)119 bool allocated(const void *ptr) { 120 return allocator.identifyObject(ptr).hasValue(); 121 } 122 123 private: 124 /// The raw allocator for type storage objects. 125 llvm::BumpPtrAllocator allocator; 126 }; 127 128 StorageUniquer(); 129 ~StorageUniquer(); 130 131 /// Set the flag specifying if multi-threading is disabled within the uniquer. 132 void disableMultithreading(bool disable = true); 133 134 /// Register a new parametric storage class, this is necessary to create 135 /// instances of this class type. `id` is the type identifier that will be 136 /// used to identify this type when creating instances of it via 'get'. registerParametricStorageType(TypeID id)137 template <typename Storage> void registerParametricStorageType(TypeID id) { 138 registerParametricStorageTypeImpl(id); 139 } 140 /// Utility override when the storage type represents the type id. registerParametricStorageType()141 template <typename Storage> void registerParametricStorageType() { 142 registerParametricStorageType<Storage>(TypeID::get<Storage>()); 143 } 144 /// Register a new singleton storage class, this is necessary to get the 145 /// singletone instance. `id` is the type identifier that will be used to 146 /// access the singleton instance via 'get'. An optional initialization 147 /// function may also be provided to initialize the newly created storage 148 /// instance, and used when the singleton instance is created. 149 template <typename Storage> registerSingletonStorageType(TypeID id,function_ref<void (Storage *)> initFn)150 void registerSingletonStorageType(TypeID id, 151 function_ref<void(Storage *)> initFn) { 152 auto ctorFn = [&](StorageAllocator &allocator) { 153 auto *storage = new (allocator.allocate<Storage>()) Storage(); 154 if (initFn) 155 initFn(storage); 156 return storage; 157 }; 158 registerSingletonImpl(id, ctorFn); 159 } registerSingletonStorageType(TypeID id)160 template <typename Storage> void registerSingletonStorageType(TypeID id) { 161 registerSingletonStorageType<Storage>(id, llvm::None); 162 } 163 /// Utility override when the storage type represents the type id. 164 template <typename Storage> 165 void registerSingletonStorageType(function_ref<void(Storage *)> initFn = {}) { 166 registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn); 167 } 168 169 /// Gets a uniqued instance of 'Storage'. 'id' is the type id used when 170 /// registering the storage instance. 'initFn' is an optional parameter that 171 /// can be used to initialize a newly inserted storage instance. This function 172 /// is used for derived types that have complex storage or uniquing 173 /// constraints. 174 template <typename Storage, typename... Args> get(function_ref<void (Storage *)> initFn,TypeID id,Args &&...args)175 Storage *get(function_ref<void(Storage *)> initFn, TypeID id, 176 Args &&...args) { 177 // Construct a value of the derived key type. 178 auto derivedKey = getKey<Storage>(std::forward<Args>(args)...); 179 180 // Create a hash of the derived key. 181 unsigned hashValue = getHash<Storage>(derivedKey); 182 183 // Generate an equality function for the derived storage. 184 auto isEqual = [&derivedKey](const BaseStorage *existing) { 185 return static_cast<const Storage &>(*existing) == derivedKey; 186 }; 187 188 // Generate a constructor function for the derived storage. 189 auto ctorFn = [&](StorageAllocator &allocator) { 190 auto *storage = Storage::construct(allocator, derivedKey); 191 if (initFn) 192 initFn(storage); 193 return storage; 194 }; 195 196 // Get an instance for the derived storage. 197 return static_cast<Storage *>( 198 getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn)); 199 } 200 /// Utility override when the storage type represents the type id. 201 template <typename Storage, typename... Args> get(function_ref<void (Storage *)> initFn,Args &&...args)202 Storage *get(function_ref<void(Storage *)> initFn, Args &&...args) { 203 return get<Storage>(initFn, TypeID::get<Storage>(), 204 std::forward<Args>(args)...); 205 } 206 207 /// Gets a uniqued instance of 'Storage' which is a singleton storage type. 208 /// 'id' is the type id used when registering the storage instance. get(TypeID id)209 template <typename Storage> Storage *get(TypeID id) { 210 return static_cast<Storage *>(getSingletonImpl(id)); 211 } 212 /// Utility override when the storage type represents the type id. get()213 template <typename Storage> Storage *get() { 214 return get<Storage>(TypeID::get<Storage>()); 215 } 216 217 /// Test if there is a singleton storage uniquer initialized for the provided 218 /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer 219 /// is initialized when a dialect is loaded. 220 bool isSingletonStorageInitialized(TypeID id); 221 222 /// Test if there is a parametric storage uniquer initialized for the provided 223 /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer 224 /// is initialized when a dialect is loaded. 225 bool isParametricStorageInitialized(TypeID id); 226 227 /// Changes the mutable component of 'storage' by forwarding the trailing 228 /// arguments to the 'mutate' function of the derived class. 229 template <typename Storage, typename... Args> mutate(TypeID id,Storage * storage,Args &&...args)230 LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) { 231 auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult { 232 return static_cast<Storage &>(*storage).mutate( 233 allocator, std::forward<Args>(args)...); 234 }; 235 return mutateImpl(id, storage, mutationFn); 236 } 237 238 private: 239 /// Implementation for getting/creating an instance of a derived type with 240 /// parametric storage. 241 BaseStorage *getParametricStorageTypeImpl( 242 TypeID id, unsigned hashValue, 243 function_ref<bool(const BaseStorage *)> isEqual, 244 function_ref<BaseStorage *(StorageAllocator &)> ctorFn); 245 246 /// Implementation for registering an instance of a derived type with 247 /// parametric storage. 248 void registerParametricStorageTypeImpl(TypeID id); 249 250 /// Implementation for getting an instance of a derived type with default 251 /// storage. 252 BaseStorage *getSingletonImpl(TypeID id); 253 254 /// Implementation for registering an instance of a derived type with default 255 /// storage. 256 void 257 registerSingletonImpl(TypeID id, 258 function_ref<BaseStorage *(StorageAllocator &)> ctorFn); 259 260 /// Implementation for mutating an instance of a derived storage. 261 LogicalResult 262 mutateImpl(TypeID id, BaseStorage *storage, 263 function_ref<LogicalResult(StorageAllocator &)> mutationFn); 264 265 /// The internal implementation class. 266 std::unique_ptr<detail::StorageUniquerImpl> impl; 267 268 //===--------------------------------------------------------------------===// 269 // Key Construction 270 //===--------------------------------------------------------------------===// 271 272 /// Used to construct an instance of 'ImplTy::KeyTy' if there is an 273 /// 'ImplTy::getKey' function for the provided arguments. 274 template <typename ImplTy, typename... Args> 275 static typename std::enable_if< 276 llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value, 277 typename ImplTy::KeyTy>::type getKey(Args &&...args)278 getKey(Args &&...args) { 279 return ImplTy::getKey(args...); 280 } 281 /// If there is no 'ImplTy::getKey' method, then we try to directly construct 282 /// the 'ImplTy::KeyTy' with the provided arguments. 283 template <typename ImplTy, typename... Args> 284 static typename std::enable_if< 285 !llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value, 286 typename ImplTy::KeyTy>::type getKey(Args &&...args)287 getKey(Args &&...args) { 288 return typename ImplTy::KeyTy(args...); 289 } 290 291 //===--------------------------------------------------------------------===// 292 // Key Hashing 293 //===--------------------------------------------------------------------===// 294 295 /// Used to generate a hash for the 'ImplTy::KeyTy' of a storage instance if 296 /// there is an 'ImplTy::hashKey' overload for 'DerivedKey'. 297 template <typename ImplTy, typename DerivedKey> 298 static typename std::enable_if< 299 llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value, 300 ::llvm::hash_code>::type getHash(const DerivedKey & derivedKey)301 getHash(const DerivedKey &derivedKey) { 302 return ImplTy::hashKey(derivedKey); 303 } 304 /// If there is no 'ImplTy::hashKey' default to using the 'llvm::DenseMapInfo' 305 /// definition for 'DerivedKey' for generating a hash. 306 template <typename ImplTy, typename DerivedKey> 307 static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t, 308 ImplTy, DerivedKey>::value, 309 ::llvm::hash_code>::type getHash(const DerivedKey & derivedKey)310 getHash(const DerivedKey &derivedKey) { 311 return DenseMapInfo<DerivedKey>::getHashValue(derivedKey); 312 } 313 }; 314 } // end namespace mlir 315 316 #endif 317