//===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the types in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
14 #define MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
15 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/IR/Location.h"
19 #include "mlir/IR/TypeSupport.h"
20 #include "mlir/IR/Types.h"
21 
22 #include <tuple>
23 
// Forward declare enum classes related to op availability. Their definitions
// are in the TableGen'erated SPIRVEnums.h.inc and can be referenced by other
// declarations in SPIRVEnums.h.inc.
namespace mlir {
namespace spirv {
enum class Version : uint32_t;
enum class Extension;
enum class Capability : uint32_t;
} // namespace spirv
} // namespace mlir
34 
35 // Pull in all enum type definitions and utility function declarations
36 #include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc"
37 // Pull in all enum type availability query function declarations
38 #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.h.inc"
39 
40 namespace mlir {
41 namespace spirv {
42 /// Returns the implied extensions for the given version. These extensions are
43 /// incorporated into the current version so they are implicitly declared when
44 /// targeting the given version.
45 ArrayRef<Extension> getImpliedExtensions(Version version);
46 
47 /// Returns the directly implied capabilities for the given capability. These
48 /// capabilities are implicitly declared by the given capability.
49 ArrayRef<Capability> getDirectImpliedCapabilities(Capability cap);
50 /// Returns the recursively implied capabilities for the given capability. These
51 /// capabilities are implicitly declared by the given capability. Compared to
52 /// the above function, this function collects implied capabilities recursively:
53 /// if an implicitly declared capability implicitly declares a third one, the
54 /// third one will also be returned.
55 SmallVector<Capability, 0> getRecursiveImpliedCapabilities(Capability cap);
56 
// Forward declarations of the storage classes named as the storage parameter
// of the `Type::TypeBase<...>` instantiations below. Their definitions are not
// part of this header.
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct StructTypeStorage;
} // namespace detail
67 
68 // Base SPIR-V type for providing availability queries.
69 class SPIRVType : public Type {
70 public:
71   using Type::Type;
72 
73   static bool classof(Type type);
74 
75   bool isScalarOrVector();
76 
77   /// The extension requirements for each type are following the
78   /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
79   /// convention.
80   using ExtensionArrayRefVector = SmallVectorImpl<ArrayRef<Extension>>;
81 
82   /// Appends to `extensions` the extensions needed for this type to appear in
83   /// the given `storage` class. This method does not guarantee the uniqueness
84   /// of extensions; the same extension may be appended multiple times.
85   void getExtensions(ExtensionArrayRefVector &extensions,
86                      Optional<StorageClass> storage = llvm::None);
87 
88   /// The capability requirements for each type are following the
89   /// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
90   /// convention.
91   using CapabilityArrayRefVector = SmallVectorImpl<ArrayRef<Capability>>;
92 
93   /// Appends to `capabilities` the capabilities needed for this type to appear
94   /// in the given `storage` class. This method does not guarantee the
95   /// uniqueness of capabilities; the same capability may be appended multiple
96   /// times.
97   void getCapabilities(CapabilityArrayRefVector &capabilities,
98                        Optional<StorageClass> storage = llvm::None);
99 
100   /// Returns the size in bytes for each type. If no size can be calculated,
101   /// returns `llvm::None`. Note that if the type has explicit layout, it is
102   /// also taken into account in calculation.
103   Optional<int64_t> getSizeInBytes();
104 };
105 
106 // SPIR-V scalar type: bool type, integer type, floating point type.
107 class ScalarType : public SPIRVType {
108 public:
109   using SPIRVType::SPIRVType;
110 
111   static bool classof(Type type);
112 
113   /// Returns true if the given integer type is valid for the SPIR-V dialect.
114   static bool isValid(FloatType);
115   /// Returns true if the given float type is valid for the SPIR-V dialect.
116   static bool isValid(IntegerType);
117 
118   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
119                      Optional<StorageClass> storage = llvm::None);
120   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
121                        Optional<StorageClass> storage = llvm::None);
122 
123   Optional<int64_t> getSizeInBytes();
124 };
125 
126 // SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
127 class CompositeType : public SPIRVType {
128 public:
129   using SPIRVType::SPIRVType;
130 
131   static bool classof(Type type);
132 
133   /// Returns true if the given vector type is valid for the SPIR-V dialect.
134   static bool isValid(VectorType);
135 
136   /// Return the number of elements of the type. This should only be called if
137   /// hasCompileTimeKnownNumElements is true.
138   unsigned getNumElements() const;
139 
140   Type getElementType(unsigned) const;
141 
142   /// Return true if the number of elements is known at compile time and is not
143   /// implementation dependent.
144   bool hasCompileTimeKnownNumElements() const;
145 
146   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
147                      Optional<StorageClass> storage = llvm::None);
148   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
149                        Optional<StorageClass> storage = llvm::None);
150 
151   Optional<int64_t> getSizeInBytes();
152 };
153 
154 // SPIR-V array type
155 class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
156                                         detail::ArrayTypeStorage> {
157 public:
158   using Base::Base;
159 
160   static ArrayType get(Type elementType, unsigned elementCount);
161 
162   /// Returns an array type with the given stride in bytes.
163   static ArrayType get(Type elementType, unsigned elementCount,
164                        unsigned stride);
165 
166   unsigned getNumElements() const;
167 
168   Type getElementType() const;
169 
170   /// Returns the array stride in bytes. 0 means no stride decorated on this
171   /// type.
172   unsigned getArrayStride() const;
173 
174   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
175                      Optional<StorageClass> storage = llvm::None);
176   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
177                        Optional<StorageClass> storage = llvm::None);
178 
179   /// Returns the array size in bytes. Since array type may have an explicit
180   /// stride declaration (in bytes), we also include it in the calculation.
181   Optional<int64_t> getSizeInBytes();
182 };
183 
184 // SPIR-V image type
185 class ImageType
186     : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
187 public:
188   using Base::Base;
189 
190   static ImageType
191   get(Type elementType, Dim dim,
192       ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
193       ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed,
194       ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled,
195       ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown,
196       ImageFormat format = ImageFormat::Unknown) {
197     return ImageType::get(
198         std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
199                    ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>(
200             elementType, dim, depth, arrayed, samplingInfo, samplerUse,
201             format));
202   }
203 
204   static ImageType
205       get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
206                      ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
207 
208   Type getElementType() const;
209   Dim getDim() const;
210   ImageDepthInfo getDepthInfo() const;
211   ImageArrayedInfo getArrayedInfo() const;
212   ImageSamplingInfo getSamplingInfo() const;
213   ImageSamplerUseInfo getSamplerUseInfo() const;
214   ImageFormat getImageFormat() const;
215   // TODO: Add support for Access qualifier
216 
217   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
218                      Optional<StorageClass> storage = llvm::None);
219   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
220                        Optional<StorageClass> storage = llvm::None);
221 };
222 
223 // SPIR-V pointer type
224 class PointerType : public Type::TypeBase<PointerType, SPIRVType,
225                                           detail::PointerTypeStorage> {
226 public:
227   using Base::Base;
228 
229   static PointerType get(Type pointeeType, StorageClass storageClass);
230 
231   Type getPointeeType() const;
232 
233   StorageClass getStorageClass() const;
234 
235   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
236                      Optional<StorageClass> storage = llvm::None);
237   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
238                        Optional<StorageClass> storage = llvm::None);
239 };
240 
241 // SPIR-V run-time array type
242 class RuntimeArrayType
243     : public Type::TypeBase<RuntimeArrayType, SPIRVType,
244                             detail::RuntimeArrayTypeStorage> {
245 public:
246   using Base::Base;
247 
248   static RuntimeArrayType get(Type elementType);
249 
250   /// Returns a runtime array type with the given stride in bytes.
251   static RuntimeArrayType get(Type elementType, unsigned stride);
252 
253   Type getElementType() const;
254 
255   /// Returns the array stride in bytes. 0 means no stride decorated on this
256   /// type.
257   unsigned getArrayStride() const;
258 
259   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
260                      Optional<StorageClass> storage = llvm::None);
261   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
262                        Optional<StorageClass> storage = llvm::None);
263 };
264 
265 /// SPIR-V struct type. Two kinds of struct types are supported:
266 /// - Literal: a literal struct type is uniqued by its fields (types + offset
267 /// info + decoration info).
268 /// - Identified: an indentified struct type is uniqued by its string identifier
269 /// (name). This is useful in representing recursive structs. For example, the
270 /// following C struct:
271 ///
272 /// struct A {
273 ///   A* next;
274 /// };
275 ///
276 /// would be represented in MLIR as:
277 ///
278 /// !spv.struct<A, (!spv.ptr<!spv.struct<A>, Generic>)>
279 ///
280 /// In the above, expressing recursive struct types is accomplished by giving a
281 /// recursive struct a unique identified and using that identifier in the struct
282 /// definition for recursive references.
283 class StructType : public Type::TypeBase<StructType, CompositeType,
284                                          detail::StructTypeStorage> {
285 public:
286   using Base::Base;
287 
288   // Type for specifying the offset of the struct members
289   using OffsetInfo = uint32_t;
290 
291   // Type for specifying the decoration(s) on struct members
292   struct MemberDecorationInfo {
293     uint32_t memberIndex : 31;
294     uint32_t hasValue : 1;
295     Decoration decoration;
296     uint32_t decorationValue;
297 
MemberDecorationInfoMemberDecorationInfo298     MemberDecorationInfo(uint32_t index, uint32_t hasValue,
299                          Decoration decoration, uint32_t decorationValue)
300         : memberIndex(index), hasValue(hasValue), decoration(decoration),
301           decorationValue(decorationValue) {}
302 
303     bool operator==(const MemberDecorationInfo &other) const {
304       return (this->memberIndex == other.memberIndex) &&
305              (this->decoration == other.decoration) &&
306              (this->decorationValue == other.decorationValue);
307     }
308 
309     bool operator<(const MemberDecorationInfo &other) const {
310       return this->memberIndex < other.memberIndex ||
311              (this->memberIndex == other.memberIndex &&
312               static_cast<uint32_t>(this->decoration) <
313                   static_cast<uint32_t>(other.decoration));
314     }
315   };
316 
317   /// Construct a literal StructType with at least one member.
318   static StructType get(ArrayRef<Type> memberTypes,
319                         ArrayRef<OffsetInfo> offsetInfo = {},
320                         ArrayRef<MemberDecorationInfo> memberDecorations = {});
321 
322   /// Construct an identified StructType. This creates a StructType whose body
323   /// (member types, offset info, and decorations) is not set yet. A call to
324   /// StructType::trySetBody(...) must follow when the StructType contents are
325   /// available (e.g. parsed or deserialized).
326   ///
327   /// Note: If another thread creates (or had already created) a struct with the
328   /// same identifier, that struct will be returned as a result.
329   static StructType getIdentified(MLIRContext *context, StringRef identifier);
330 
331   /// Construct a (possibly identified) StructType with no members.
332   ///
333   /// Note: this method might fail in a multi-threaded setup if another thread
334   /// created an identified struct with the same identifier but with different
335   /// contents before returning. In which case, an empty (default-constructed)
336   /// StructType is returned.
337   static StructType getEmpty(MLIRContext *context, StringRef identifier = "");
338 
339   /// For literal structs, return an empty string.
340   /// For identified structs, return the struct's identifier.
341   StringRef getIdentifier() const;
342 
343   /// Returns true if the StructType is identified.
344   bool isIdentified() const;
345 
346   unsigned getNumElements() const;
347 
348   Type getElementType(unsigned) const;
349 
350   /// Range class for element types.
351   class ElementTypeRange
352       : public ::llvm::detail::indexed_accessor_range_base<
353             ElementTypeRange, const Type *, Type, Type, Type> {
354   private:
355     using RangeBaseT::RangeBaseT;
356 
357     /// See `llvm::detail::indexed_accessor_range_base` for details.
offset_base(const Type * object,ptrdiff_t index)358     static const Type *offset_base(const Type *object, ptrdiff_t index) {
359       return object + index;
360     }
361     /// See `llvm::detail::indexed_accessor_range_base` for details.
dereference_iterator(const Type * object,ptrdiff_t index)362     static Type dereference_iterator(const Type *object, ptrdiff_t index) {
363       return object[index];
364     }
365 
366     /// Allow base class access to `offset_base` and `dereference_iterator`.
367     friend RangeBaseT;
368   };
369 
370   ElementTypeRange getElementTypes() const;
371 
372   bool hasOffset() const;
373 
374   uint64_t getMemberOffset(unsigned) const;
375 
376   // Returns in `memberDecorations` the Decorations (apart from Offset)
377   // associated with all members of the StructType.
378   void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo>
379                                 &memberDecorations) const;
380 
381   // Returns in `decorationsInfo` all the Decorations (apart from Offset)
382   // associated with the `i`-th member of the StructType.
383   void getMemberDecorations(unsigned i,
384                             SmallVectorImpl<StructType::MemberDecorationInfo>
385                                 &decorationsInfo) const;
386 
387   /// Sets the contents of an incomplete identified StructType. This method must
388   /// be called only for identified StructTypes and it must be called only once
389   /// per instance. Otherwise, failure() is returned.
390   LogicalResult
391   trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
392              ArrayRef<MemberDecorationInfo> memberDecorations = {});
393 
394   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
395                      Optional<StorageClass> storage = llvm::None);
396   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
397                        Optional<StorageClass> storage = llvm::None);
398 };
399 
400 llvm::hash_code
401 hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
402 
403 // SPIR-V cooperative matrix type
404 class CooperativeMatrixNVType
405     : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
406                             detail::CooperativeMatrixTypeStorage> {
407 public:
408   using Base::Base;
409 
410   static CooperativeMatrixNVType get(Type elementType, Scope scope,
411                                      unsigned rows, unsigned columns);
412   Type getElementType() const;
413 
414   /// Return the scope of the cooperative matrix.
415   Scope getScope() const;
416   /// return the number of rows of the matrix.
417   unsigned getRows() const;
418   /// return the number of columns of the matrix.
419   unsigned getColumns() const;
420 
421   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
422                      Optional<StorageClass> storage = llvm::None);
423   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
424                        Optional<StorageClass> storage = llvm::None);
425 };
426 
427 // SPIR-V matrix type
428 class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
429                                          detail::MatrixTypeStorage> {
430 public:
431   using Base::Base;
432 
433   static MatrixType get(Type columnType, uint32_t columnCount);
434 
435   static MatrixType getChecked(Type columnType, uint32_t columnCount,
436                                Location location);
437 
438   static LogicalResult verifyConstructionInvariants(Location loc,
439                                                     Type columnType,
440                                                     uint32_t columnCount);
441 
442   /// Returns true if the matrix elements are vectors of float elements.
443   static bool isValidColumnType(Type columnType);
444 
445   Type getColumnType() const;
446 
447   /// Returns the number of rows.
448   unsigned getNumRows() const;
449 
450   /// Returns the number of columns.
451   unsigned getNumColumns() const;
452 
453   /// Returns total number of elements (rows*columns).
454   unsigned getNumElements() const;
455 
456   /// Returns the elements' type (i.e, single element type).
457   Type getElementType() const;
458 
459   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
460                      Optional<StorageClass> storage = llvm::None);
461   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
462                        Optional<StorageClass> storage = llvm::None);
463 };
464 
465 } // end namespace spirv
466 } // end namespace mlir
467 
468 #endif // MLIR_DIALECT_SPIRV_SPIRVTYPES_H_
469