1 //===- StorageUniquer.h - Common Storage Class Uniquer ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_SUPPORT_STORAGEUNIQUER_H
10 #define MLIR_SUPPORT_STORAGEUNIQUER_H
11 
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/Support/LogicalResult.h"
14 #include "mlir/Support/TypeID.h"
15 #include "llvm/ADT/DenseSet.h"
16 #include "llvm/Support/Allocator.h"
17 
18 namespace mlir {
namespace detail {
/// Opaque implementation type behind StorageUniquer; it is only
/// forward-declared here, its definition is not part of this header.
struct StorageUniquerImpl;

/// Detection alias: names the result type of `StorageT::getKey` invoked with
/// arguments of types `KeyArgs...`, and is ill-formed when no such call is
/// valid. Intended for use with a detection idiom such as `is_detected`.
template <typename StorageT, typename... KeyArgs>
using has_impltype_getkey_t =
    decltype(StorageT::getKey(std::declval<KeyArgs>()...));

/// Detection alias: names the result type of `StorageT::hashKey` invoked with
/// a single argument of type `KeyT`, and is ill-formed when no such call is
/// valid.
template <typename StorageT, typename KeyT>
using has_impltype_hash_t =
    decltype(StorageT::hashKey(std::declval<KeyT>()));
} // namespace detail
30 
31 /// A utility class to get or create instances of "storage classes". These
32 /// storage classes must derive from 'StorageUniquer::BaseStorage'.
33 ///
34 /// For non-parametric storage classes, i.e. singleton classes, nothing else is
35 /// needed. Instances of these classes can be created by calling `get` without
36 /// trailing arguments.
37 ///
38 /// Otherwise, the parametric storage classes may be created with `get`,
39 /// and must respect the following:
40 ///    - Define a type alias, KeyTy, to a type that uniquely identifies the
41 ///      instance of the storage class.
///      * The key type must be constructible from the values passed into the
///        `get` call, or the storage class must provide a static 'getKey'
///        method that builds a KeyTy from those values.
44 ///      * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
45 ///        storage class must define a hashing method:
46 ///         'static unsigned hashKey(const KeyTy &)'
47 ///
48 ///    - Provide a method, 'bool operator==(const KeyTy &) const', to
49 ///      compare the storage instance against an instance of the key type.
50 ///
51 ///    - Provide a static construction method:
52 ///        'DerivedStorage *construct(StorageAllocator &, const KeyTy &key)'
53 ///      that builds a unique instance of the derived storage. The arguments to
54 ///      this function are an allocator to store any uniqued data and the key
55 ///      type for this storage.
56 ///
57 ///    - Provide a cleanup method:
58 ///        'void cleanup()'
59 ///      that is called when erasing a storage instance. This should cleanup any
60 ///      fields of the storage as necessary and not attempt to free the memory
61 ///      of the storage itself.
62 ///
63 /// Storage classes may have an optional mutable component, which must not take
64 /// part in the unique immutable key. In this case, storage classes may be
65 /// mutated with `mutate` and must additionally respect the following:
66 ///    - Provide a mutation method:
67 ///        'LogicalResult mutate(StorageAllocator &, <...>)'
68 ///      that is called when mutating a storage instance. The first argument is
69 ///      an allocator to store any mutable data, and the remaining arguments are
70 ///      forwarded from the call site. The storage can be mutated at any time
71 ///      after creation. Care must be taken to avoid excessive mutation since
72 ///      the allocated storage can keep containing previous states. The return
73 ///      value of the function is used to indicate whether the mutation was
74 ///      successful, e.g., to limit the number of mutations or enable deferred
75 ///      one-time assignment of the mutable component.
76 ///
/// All storage classes must be registered with the uniquer via
/// `registerParametricStorageType` or `registerSingletonStorageType`, using an
/// appropriate unique `TypeID` for the storage class.
80 class StorageUniquer {
81 public:
82   /// This class acts as the base storage that all storage classes must derived
83   /// from.
84   class BaseStorage {
85   protected:
86     BaseStorage() = default;
87   };
88 
89   /// This is a utility allocator used to allocate memory for instances of
90   /// derived types.
91   class StorageAllocator {
92   public:
93     /// Copy the specified array of elements into memory managed by our bump
94     /// pointer allocator.  This assumes the elements are all PODs.
copyInto(ArrayRef<T> elements)95     template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
96       if (elements.empty())
97         return llvm::None;
98       auto result = allocator.Allocate<T>(elements.size());
99       std::uninitialized_copy(elements.begin(), elements.end(), result);
100       return ArrayRef<T>(result, elements.size());
101     }
102 
103     /// Copy the provided string into memory managed by our bump pointer
104     /// allocator.
copyInto(StringRef str)105     StringRef copyInto(StringRef str) {
106       auto result = copyInto(ArrayRef<char>(str.data(), str.size()));
107       return StringRef(result.data(), str.size());
108     }
109 
110     /// Allocate an instance of the provided type.
allocate()111     template <typename T> T *allocate() { return allocator.Allocate<T>(); }
112 
113     /// Allocate 'size' bytes of 'alignment' aligned memory.
allocate(size_t size,size_t alignment)114     void *allocate(size_t size, size_t alignment) {
115       return allocator.Allocate(size, alignment);
116     }
117 
118     /// Returns true if this allocator allocated the provided object pointer.
allocated(const void * ptr)119     bool allocated(const void *ptr) {
120       return allocator.identifyObject(ptr).hasValue();
121     }
122 
123   private:
124     /// The raw allocator for type storage objects.
125     llvm::BumpPtrAllocator allocator;
126   };
127 
128   StorageUniquer();
129   ~StorageUniquer();
130 
131   /// Set the flag specifying if multi-threading is disabled within the uniquer.
132   void disableMultithreading(bool disable = true);
133 
134   /// Register a new parametric storage class, this is necessary to create
135   /// instances of this class type. `id` is the type identifier that will be
136   /// used to identify this type when creating instances of it via 'get'.
registerParametricStorageType(TypeID id)137   template <typename Storage> void registerParametricStorageType(TypeID id) {
138     registerParametricStorageTypeImpl(id);
139   }
140   /// Utility override when the storage type represents the type id.
registerParametricStorageType()141   template <typename Storage> void registerParametricStorageType() {
142     registerParametricStorageType<Storage>(TypeID::get<Storage>());
143   }
144   /// Register a new singleton storage class, this is necessary to get the
145   /// singletone instance. `id` is the type identifier that will be used to
146   /// access the singleton instance via 'get'. An optional initialization
147   /// function may also be provided to initialize the newly created storage
148   /// instance, and used when the singleton instance is created.
149   template <typename Storage>
registerSingletonStorageType(TypeID id,function_ref<void (Storage *)> initFn)150   void registerSingletonStorageType(TypeID id,
151                                     function_ref<void(Storage *)> initFn) {
152     auto ctorFn = [&](StorageAllocator &allocator) {
153       auto *storage = new (allocator.allocate<Storage>()) Storage();
154       if (initFn)
155         initFn(storage);
156       return storage;
157     };
158     registerSingletonImpl(id, ctorFn);
159   }
registerSingletonStorageType(TypeID id)160   template <typename Storage> void registerSingletonStorageType(TypeID id) {
161     registerSingletonStorageType<Storage>(id, llvm::None);
162   }
163   /// Utility override when the storage type represents the type id.
164   template <typename Storage>
165   void registerSingletonStorageType(function_ref<void(Storage *)> initFn = {}) {
166     registerSingletonStorageType<Storage>(TypeID::get<Storage>(), initFn);
167   }
168 
169   /// Gets a uniqued instance of 'Storage'. 'id' is the type id used when
170   /// registering the storage instance. 'initFn' is an optional parameter that
171   /// can be used to initialize a newly inserted storage instance. This function
172   /// is used for derived types that have complex storage or uniquing
173   /// constraints.
174   template <typename Storage, typename... Args>
get(function_ref<void (Storage *)> initFn,TypeID id,Args &&...args)175   Storage *get(function_ref<void(Storage *)> initFn, TypeID id,
176                Args &&...args) {
177     // Construct a value of the derived key type.
178     auto derivedKey = getKey<Storage>(std::forward<Args>(args)...);
179 
180     // Create a hash of the derived key.
181     unsigned hashValue = getHash<Storage>(derivedKey);
182 
183     // Generate an equality function for the derived storage.
184     auto isEqual = [&derivedKey](const BaseStorage *existing) {
185       return static_cast<const Storage &>(*existing) == derivedKey;
186     };
187 
188     // Generate a constructor function for the derived storage.
189     auto ctorFn = [&](StorageAllocator &allocator) {
190       auto *storage = Storage::construct(allocator, derivedKey);
191       if (initFn)
192         initFn(storage);
193       return storage;
194     };
195 
196     // Get an instance for the derived storage.
197     return static_cast<Storage *>(
198         getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn));
199   }
200   /// Utility override when the storage type represents the type id.
201   template <typename Storage, typename... Args>
get(function_ref<void (Storage *)> initFn,Args &&...args)202   Storage *get(function_ref<void(Storage *)> initFn, Args &&...args) {
203     return get<Storage>(initFn, TypeID::get<Storage>(),
204                         std::forward<Args>(args)...);
205   }
206 
207   /// Gets a uniqued instance of 'Storage' which is a singleton storage type.
208   /// 'id' is the type id used when registering the storage instance.
get(TypeID id)209   template <typename Storage> Storage *get(TypeID id) {
210     return static_cast<Storage *>(getSingletonImpl(id));
211   }
212   /// Utility override when the storage type represents the type id.
get()213   template <typename Storage> Storage *get() {
214     return get<Storage>(TypeID::get<Storage>());
215   }
216 
217   /// Test if there is a singleton storage uniquer initialized for the provided
218   /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer
219   /// is initialized when a dialect is loaded.
220   bool isSingletonStorageInitialized(TypeID id);
221 
222   /// Test if there is a parametric storage uniquer initialized for the provided
223   /// TypeID. This is only useful for debugging/diagnostic purpose: the uniquer
224   /// is initialized when a dialect is loaded.
225   bool isParametricStorageInitialized(TypeID id);
226 
227   /// Changes the mutable component of 'storage' by forwarding the trailing
228   /// arguments to the 'mutate' function of the derived class.
229   template <typename Storage, typename... Args>
mutate(TypeID id,Storage * storage,Args &&...args)230   LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) {
231     auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
232       return static_cast<Storage &>(*storage).mutate(
233           allocator, std::forward<Args>(args)...);
234     };
235     return mutateImpl(id, storage, mutationFn);
236   }
237 
238 private:
239   /// Implementation for getting/creating an instance of a derived type with
240   /// parametric storage.
241   BaseStorage *getParametricStorageTypeImpl(
242       TypeID id, unsigned hashValue,
243       function_ref<bool(const BaseStorage *)> isEqual,
244       function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
245 
246   /// Implementation for registering an instance of a derived type with
247   /// parametric storage.
248   void registerParametricStorageTypeImpl(TypeID id);
249 
250   /// Implementation for getting an instance of a derived type with default
251   /// storage.
252   BaseStorage *getSingletonImpl(TypeID id);
253 
254   /// Implementation for registering an instance of a derived type with default
255   /// storage.
256   void
257   registerSingletonImpl(TypeID id,
258                         function_ref<BaseStorage *(StorageAllocator &)> ctorFn);
259 
260   /// Implementation for mutating an instance of a derived storage.
261   LogicalResult
262   mutateImpl(TypeID id, BaseStorage *storage,
263              function_ref<LogicalResult(StorageAllocator &)> mutationFn);
264 
265   /// The internal implementation class.
266   std::unique_ptr<detail::StorageUniquerImpl> impl;
267 
268   //===--------------------------------------------------------------------===//
269   // Key Construction
270   //===--------------------------------------------------------------------===//
271 
272   /// Used to construct an instance of 'ImplTy::KeyTy' if there is an
273   /// 'ImplTy::getKey' function for the provided arguments.
274   template <typename ImplTy, typename... Args>
275   static typename std::enable_if<
276       llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
277       typename ImplTy::KeyTy>::type
getKey(Args &&...args)278   getKey(Args &&...args) {
279     return ImplTy::getKey(args...);
280   }
281   /// If there is no 'ImplTy::getKey' method, then we try to directly construct
282   /// the 'ImplTy::KeyTy' with the provided arguments.
283   template <typename ImplTy, typename... Args>
284   static typename std::enable_if<
285       !llvm::is_detected<detail::has_impltype_getkey_t, ImplTy, Args...>::value,
286       typename ImplTy::KeyTy>::type
getKey(Args &&...args)287   getKey(Args &&...args) {
288     return typename ImplTy::KeyTy(args...);
289   }
290 
291   //===--------------------------------------------------------------------===//
292   // Key Hashing
293   //===--------------------------------------------------------------------===//
294 
295   /// Used to generate a hash for the 'ImplTy::KeyTy' of a storage instance if
296   /// there is an 'ImplTy::hashKey' overload for 'DerivedKey'.
297   template <typename ImplTy, typename DerivedKey>
298   static typename std::enable_if<
299       llvm::is_detected<detail::has_impltype_hash_t, ImplTy, DerivedKey>::value,
300       ::llvm::hash_code>::type
getHash(const DerivedKey & derivedKey)301   getHash(const DerivedKey &derivedKey) {
302     return ImplTy::hashKey(derivedKey);
303   }
304   /// If there is no 'ImplTy::hashKey' default to using the 'llvm::DenseMapInfo'
305   /// definition for 'DerivedKey' for generating a hash.
306   template <typename ImplTy, typename DerivedKey>
307   static typename std::enable_if<!llvm::is_detected<detail::has_impltype_hash_t,
308                                                     ImplTy, DerivedKey>::value,
309                                  ::llvm::hash_code>::type
getHash(const DerivedKey & derivedKey)310   getHash(const DerivedKey &derivedKey) {
311     return DenseMapInfo<DerivedKey>::getHashValue(derivedKey);
312   }
313 };
314 } // end namespace mlir
315 
316 #endif
317