1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Identifier.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 
24 using namespace mlir;
25 using namespace mlir::spirv;
26 
27 // Pull in all enum utility function definitions
28 #include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc"
29 // Pull in all enum type availability query function definitions
30 #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.cpp.inc"
31 
32 //===----------------------------------------------------------------------===//
33 // Availability relationship
34 //===----------------------------------------------------------------------===//
35 
getImpliedExtensions(Version version)36 ArrayRef<Extension> spirv::getImpliedExtensions(Version version) {
37   // Note: the following lists are from "Appendix A: Changes" of the spec.
38 
39 #define V_1_3_IMPLIED_EXTS                                                     \
40   Extension::SPV_KHR_shader_draw_parameters, Extension::SPV_KHR_16bit_storage, \
41       Extension::SPV_KHR_device_group, Extension::SPV_KHR_multiview,           \
42       Extension::SPV_KHR_storage_buffer_storage_class,                         \
43       Extension::SPV_KHR_variable_pointers
44 
45 #define V_1_4_IMPLIED_EXTS                                                     \
46   Extension::SPV_KHR_no_integer_wrap_decoration,                               \
47       Extension::SPV_GOOGLE_decorate_string,                                   \
48       Extension::SPV_GOOGLE_hlsl_functionality1,                               \
49       Extension::SPV_KHR_float_controls
50 
51 #define V_1_5_IMPLIED_EXTS                                                     \
52   Extension::SPV_KHR_8bit_storage, Extension::SPV_EXT_descriptor_indexing,     \
53       Extension::SPV_EXT_shader_viewport_index_layer,                          \
54       Extension::SPV_EXT_physical_storage_buffer,                              \
55       Extension::SPV_KHR_physical_storage_buffer,                              \
56       Extension::SPV_KHR_vulkan_memory_model
57 
58   switch (version) {
59   default:
60     return {};
61   case Version::V_1_3: {
62     // The following manual ArrayRef constructor call is to satisfy GCC 5.
63     static const Extension exts[] = {V_1_3_IMPLIED_EXTS};
64     return ArrayRef<Extension>(exts, llvm::array_lengthof(exts));
65   }
66   case Version::V_1_4: {
67     static const Extension exts[] = {V_1_3_IMPLIED_EXTS, V_1_4_IMPLIED_EXTS};
68     return ArrayRef<Extension>(exts, llvm::array_lengthof(exts));
69   }
70   case Version::V_1_5: {
71     static const Extension exts[] = {V_1_3_IMPLIED_EXTS, V_1_4_IMPLIED_EXTS,
72                                      V_1_5_IMPLIED_EXTS};
73     return ArrayRef<Extension>(exts, llvm::array_lengthof(exts));
74   }
75   }
76 
77 #undef V_1_5_IMPLIED_EXTS
78 #undef V_1_4_IMPLIED_EXTS
79 #undef V_1_3_IMPLIED_EXTS
80 }
81 
82 // Pull in utility function definition for implied capabilities
83 #include "mlir/Dialect/SPIRV/SPIRVCapabilityImplication.inc"
84 
85 SmallVector<Capability, 0>
getRecursiveImpliedCapabilities(Capability cap)86 spirv::getRecursiveImpliedCapabilities(Capability cap) {
87   ArrayRef<Capability> directCaps = getDirectImpliedCapabilities(cap);
88   llvm::SetVector<Capability, SmallVector<Capability, 0>> allCaps(
89       directCaps.begin(), directCaps.end());
90 
91   // TODO: This is insufficient; find a better way to handle this
92   // (e.g., using static lists) if this turns out to be a bottleneck.
93   for (unsigned i = 0; i < allCaps.size(); ++i)
94     for (Capability c : getDirectImpliedCapabilities(allCaps[i]))
95       allCaps.insert(c);
96 
97   return allCaps.takeVector();
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // ArrayType
102 //===----------------------------------------------------------------------===//
103 
104 struct spirv::detail::ArrayTypeStorage : public TypeStorage {
105   using KeyTy = std::tuple<Type, unsigned, unsigned>;
106 
constructspirv::detail::ArrayTypeStorage107   static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
108                                      const KeyTy &key) {
109     return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
110   }
111 
operator ==spirv::detail::ArrayTypeStorage112   bool operator==(const KeyTy &key) const {
113     return key == KeyTy(elementType, elementCount, stride);
114   }
115 
ArrayTypeStoragespirv::detail::ArrayTypeStorage116   ArrayTypeStorage(const KeyTy &key)
117       : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
118         stride(std::get<2>(key)) {}
119 
120   Type elementType;
121   unsigned elementCount;
122   unsigned stride;
123 };
124 
get(Type elementType,unsigned elementCount)125 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
126   assert(elementCount && "ArrayType needs at least one element");
127   return Base::get(elementType.getContext(), elementType, elementCount,
128                    /*stride=*/0);
129 }
130 
get(Type elementType,unsigned elementCount,unsigned stride)131 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
132                          unsigned stride) {
133   assert(elementCount && "ArrayType needs at least one element");
134   return Base::get(elementType.getContext(), elementType, elementCount, stride);
135 }
136 
getNumElements() const137 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
138 
getElementType() const139 Type ArrayType::getElementType() const { return getImpl()->elementType; }
140 
getArrayStride() const141 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
142 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)143 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
144                               Optional<StorageClass> storage) {
145   getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
146 }
147 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)148 void ArrayType::getCapabilities(
149     SPIRVType::CapabilityArrayRefVector &capabilities,
150     Optional<StorageClass> storage) {
151   getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
152 }
153 
getSizeInBytes()154 Optional<int64_t> ArrayType::getSizeInBytes() {
155   auto elementType = getElementType().cast<SPIRVType>();
156   Optional<int64_t> size = elementType.getSizeInBytes();
157   if (!size)
158     return llvm::None;
159   return (*size + getArrayStride()) * getNumElements();
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // CompositeType
164 //===----------------------------------------------------------------------===//
165 
classof(Type type)166 bool CompositeType::classof(Type type) {
167   if (auto vectorType = type.dyn_cast<VectorType>())
168     return isValid(vectorType);
169   return type
170       .isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
171            spirv::RuntimeArrayType, spirv::StructType>();
172 }
173 
isValid(VectorType type)174 bool CompositeType::isValid(VectorType type) {
175   switch (type.getNumElements()) {
176   case 2:
177   case 3:
178   case 4:
179   case 8:
180   case 16:
181     break;
182   default:
183     return false;
184   }
185   return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
186 }
187 
getElementType(unsigned index) const188 Type CompositeType::getElementType(unsigned index) const {
189   return TypeSwitch<Type, Type>(*this)
190       .Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
191           [](auto type) { return type.getElementType(); })
192       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
193       .Case<StructType>(
194           [index](StructType type) { return type.getElementType(index); })
195       .Default(
196           [](Type) -> Type { llvm_unreachable("invalid composite type"); });
197 }
198 
getNumElements() const199 unsigned CompositeType::getNumElements() const {
200   if (auto arrayType = dyn_cast<ArrayType>())
201     return arrayType.getNumElements();
202   if (auto matrixType = dyn_cast<MatrixType>())
203     return matrixType.getNumColumns();
204   if (auto structType = dyn_cast<StructType>())
205     return structType.getNumElements();
206   if (auto vectorType = dyn_cast<VectorType>())
207     return vectorType.getNumElements();
208   if (isa<CooperativeMatrixNVType>()) {
209     llvm_unreachable(
210         "invalid to query number of elements of spirv::CooperativeMatrix type");
211   }
212   if (isa<RuntimeArrayType>()) {
213     llvm_unreachable(
214         "invalid to query number of elements of spirv::RuntimeArray type");
215   }
216   llvm_unreachable("invalid composite type");
217 }
218 
hasCompileTimeKnownNumElements() const219 bool CompositeType::hasCompileTimeKnownNumElements() const {
220   return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
221 }
222 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)223 void CompositeType::getExtensions(
224     SPIRVType::ExtensionArrayRefVector &extensions,
225     Optional<StorageClass> storage) {
226   TypeSwitch<Type>(*this)
227       .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
228             StructType>(
229           [&](auto type) { type.getExtensions(extensions, storage); })
230       .Case<VectorType>([&](VectorType type) {
231         return type.getElementType().cast<ScalarType>().getExtensions(
232             extensions, storage);
233       })
234       .Default([](Type) { llvm_unreachable("invalid composite type"); });
235 }
236 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)237 void CompositeType::getCapabilities(
238     SPIRVType::CapabilityArrayRefVector &capabilities,
239     Optional<StorageClass> storage) {
240   TypeSwitch<Type>(*this)
241       .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
242             StructType>(
243           [&](auto type) { type.getCapabilities(capabilities, storage); })
244       .Case<VectorType>([&](VectorType type) {
245         auto vecSize = getNumElements();
246         if (vecSize == 8 || vecSize == 16) {
247           static const Capability caps[] = {Capability::Vector16};
248           ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
249           capabilities.push_back(ref);
250         }
251         return type.getElementType().cast<ScalarType>().getCapabilities(
252             capabilities, storage);
253       })
254       .Default([](Type) { llvm_unreachable("invalid composite type"); });
255 }
256 
getSizeInBytes()257 Optional<int64_t> CompositeType::getSizeInBytes() {
258   if (auto arrayType = dyn_cast<ArrayType>())
259     return arrayType.getSizeInBytes();
260   if (auto structType = dyn_cast<StructType>())
261     return structType.getSizeInBytes();
262   if (auto vectorType = dyn_cast<VectorType>()) {
263     Optional<int64_t> elementSize =
264         vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
265     if (!elementSize)
266       return llvm::None;
267     return *elementSize * vectorType.getNumElements();
268   }
269   return llvm::None;
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // CooperativeMatrixType
274 //===----------------------------------------------------------------------===//
275 
276 struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
277   using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
278 
279   static CooperativeMatrixTypeStorage *
constructspirv::detail::CooperativeMatrixTypeStorage280   construct(TypeStorageAllocator &allocator, const KeyTy &key) {
281     return new (allocator.allocate<CooperativeMatrixTypeStorage>())
282         CooperativeMatrixTypeStorage(key);
283   }
284 
operator ==spirv::detail::CooperativeMatrixTypeStorage285   bool operator==(const KeyTy &key) const {
286     return key == KeyTy(elementType, scope, rows, columns);
287   }
288 
CooperativeMatrixTypeStoragespirv::detail::CooperativeMatrixTypeStorage289   CooperativeMatrixTypeStorage(const KeyTy &key)
290       : elementType(std::get<0>(key)), rows(std::get<2>(key)),
291         columns(std::get<3>(key)), scope(std::get<1>(key)) {}
292 
293   Type elementType;
294   unsigned rows;
295   unsigned columns;
296   Scope scope;
297 };
298 
get(Type elementType,Scope scope,unsigned rows,unsigned columns)299 CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
300                                                      Scope scope, unsigned rows,
301                                                      unsigned columns) {
302   return Base::get(elementType.getContext(), elementType, scope, rows, columns);
303 }
304 
getElementType() const305 Type CooperativeMatrixNVType::getElementType() const {
306   return getImpl()->elementType;
307 }
308 
getScope() const309 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
310 
getRows() const311 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
312 
getColumns() const313 unsigned CooperativeMatrixNVType::getColumns() const {
314   return getImpl()->columns;
315 }
316 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)317 void CooperativeMatrixNVType::getExtensions(
318     SPIRVType::ExtensionArrayRefVector &extensions,
319     Optional<StorageClass> storage) {
320   getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
321   static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
322   ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
323   extensions.push_back(ref);
324 }
325 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)326 void CooperativeMatrixNVType::getCapabilities(
327     SPIRVType::CapabilityArrayRefVector &capabilities,
328     Optional<StorageClass> storage) {
329   getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
330   static const Capability caps[] = {Capability::CooperativeMatrixNV};
331   ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
332   capabilities.push_back(ref);
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // ImageType
337 //===----------------------------------------------------------------------===//
338 
getNumBits()339 template <typename T> static constexpr unsigned getNumBits() { return 0; }
getNumBits()340 template <> constexpr unsigned getNumBits<Dim>() {
341   static_assert((1 << 3) > getMaxEnumValForDim(),
342                 "Not enough bits to encode Dim value");
343   return 3;
344 }
getNumBits()345 template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
346   static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
347                 "Not enough bits to encode ImageDepthInfo value");
348   return 2;
349 }
getNumBits()350 template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
351   static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
352                 "Not enough bits to encode ImageArrayedInfo value");
353   return 1;
354 }
getNumBits()355 template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
356   static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
357                 "Not enough bits to encode ImageSamplingInfo value");
358   return 1;
359 }
getNumBits()360 template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
361   static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
362                 "Not enough bits to encode ImageSamplerUseInfo value");
363   return 2;
364 }
getNumBits()365 template <> constexpr unsigned getNumBits<ImageFormat>() {
366   static_assert((1 << 6) > getMaxEnumValForImageFormat(),
367                 "Not enough bits to encode ImageFormat value");
368   return 6;
369 }
370 
371 struct spirv::detail::ImageTypeStorage : public TypeStorage {
372 public:
373   using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
374                            ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
375 
constructspirv::detail::ImageTypeStorage376   static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
377                                      const KeyTy &key) {
378     return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
379   }
380 
operator ==spirv::detail::ImageTypeStorage381   bool operator==(const KeyTy &key) const {
382     return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
383                         samplerUseInfo, format);
384   }
385 
ImageTypeStoragespirv::detail::ImageTypeStorage386   ImageTypeStorage(const KeyTy &key)
387       : elementType(std::get<0>(key)), dim(std::get<1>(key)),
388         depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
389         samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
390         format(std::get<6>(key)) {}
391 
392   Type elementType;
393   Dim dim : getNumBits<Dim>();
394   ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
395   ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
396   ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
397   ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
398   ImageFormat format : getNumBits<ImageFormat>();
399 };
400 
401 ImageType
get(std::tuple<Type,Dim,ImageDepthInfo,ImageArrayedInfo,ImageSamplingInfo,ImageSamplerUseInfo,ImageFormat> value)402 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
403                           ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
404                    value) {
405   return Base::get(std::get<0>(value).getContext(), value);
406 }
407 
getElementType() const408 Type ImageType::getElementType() const { return getImpl()->elementType; }
409 
getDim() const410 Dim ImageType::getDim() const { return getImpl()->dim; }
411 
getDepthInfo() const412 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
413 
getArrayedInfo() const414 ImageArrayedInfo ImageType::getArrayedInfo() const {
415   return getImpl()->arrayedInfo;
416 }
417 
getSamplingInfo() const418 ImageSamplingInfo ImageType::getSamplingInfo() const {
419   return getImpl()->samplingInfo;
420 }
421 
getSamplerUseInfo() const422 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
423   return getImpl()->samplerUseInfo;
424 }
425 
getImageFormat() const426 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
427 
getExtensions(SPIRVType::ExtensionArrayRefVector &,Optional<StorageClass>)428 void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
429                               Optional<StorageClass>) {
430   // Image types do not require extra extensions thus far.
431 }
432 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass>)433 void ImageType::getCapabilities(
434     SPIRVType::CapabilityArrayRefVector &capabilities, Optional<StorageClass>) {
435   if (auto dimCaps = spirv::getCapabilities(getDim()))
436     capabilities.push_back(*dimCaps);
437 
438   if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
439     capabilities.push_back(*fmtCaps);
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // PointerType
444 //===----------------------------------------------------------------------===//
445 
446 struct spirv::detail::PointerTypeStorage : public TypeStorage {
447   // (Type, StorageClass) as the key: Type stored in this struct, and
448   // StorageClass stored as TypeStorage's subclass data.
449   using KeyTy = std::pair<Type, StorageClass>;
450 
constructspirv::detail::PointerTypeStorage451   static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
452                                        const KeyTy &key) {
453     return new (allocator.allocate<PointerTypeStorage>())
454         PointerTypeStorage(key);
455   }
456 
operator ==spirv::detail::PointerTypeStorage457   bool operator==(const KeyTy &key) const {
458     return key == KeyTy(pointeeType, storageClass);
459   }
460 
PointerTypeStoragespirv::detail::PointerTypeStorage461   PointerTypeStorage(const KeyTy &key)
462       : pointeeType(key.first), storageClass(key.second) {}
463 
464   Type pointeeType;
465   StorageClass storageClass;
466 };
467 
get(Type pointeeType,StorageClass storageClass)468 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
469   return Base::get(pointeeType.getContext(), pointeeType, storageClass);
470 }
471 
getPointeeType() const472 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
473 
getStorageClass() const474 StorageClass PointerType::getStorageClass() const {
475   return getImpl()->storageClass;
476 }
477 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)478 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
479                                 Optional<StorageClass> storage) {
480   // Use this pointer type's storage class because this pointer indicates we are
481   // using the pointee type in that specific storage class.
482   getPointeeType().cast<SPIRVType>().getExtensions(extensions,
483                                                    getStorageClass());
484 
485   if (auto scExts = spirv::getExtensions(getStorageClass()))
486     extensions.push_back(*scExts);
487 }
488 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)489 void PointerType::getCapabilities(
490     SPIRVType::CapabilityArrayRefVector &capabilities,
491     Optional<StorageClass> storage) {
492   // Use this pointer type's storage class because this pointer indicates we are
493   // using the pointee type in that specific storage class.
494   getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
495                                                      getStorageClass());
496 
497   if (auto scCaps = spirv::getCapabilities(getStorageClass()))
498     capabilities.push_back(*scCaps);
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // RuntimeArrayType
503 //===----------------------------------------------------------------------===//
504 
505 struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
506   using KeyTy = std::pair<Type, unsigned>;
507 
constructspirv::detail::RuntimeArrayTypeStorage508   static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
509                                             const KeyTy &key) {
510     return new (allocator.allocate<RuntimeArrayTypeStorage>())
511         RuntimeArrayTypeStorage(key);
512   }
513 
operator ==spirv::detail::RuntimeArrayTypeStorage514   bool operator==(const KeyTy &key) const {
515     return key == KeyTy(elementType, stride);
516   }
517 
RuntimeArrayTypeStoragespirv::detail::RuntimeArrayTypeStorage518   RuntimeArrayTypeStorage(const KeyTy &key)
519       : elementType(key.first), stride(key.second) {}
520 
521   Type elementType;
522   unsigned stride;
523 };
524 
get(Type elementType)525 RuntimeArrayType RuntimeArrayType::get(Type elementType) {
526   return Base::get(elementType.getContext(), elementType, /*stride=*/0);
527 }
528 
get(Type elementType,unsigned stride)529 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
530   return Base::get(elementType.getContext(), elementType, stride);
531 }
532 
getElementType() const533 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
534 
getArrayStride() const535 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
536 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)537 void RuntimeArrayType::getExtensions(
538     SPIRVType::ExtensionArrayRefVector &extensions,
539     Optional<StorageClass> storage) {
540   getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
541 }
542 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)543 void RuntimeArrayType::getCapabilities(
544     SPIRVType::CapabilityArrayRefVector &capabilities,
545     Optional<StorageClass> storage) {
546   {
547     static const Capability caps[] = {Capability::Shader};
548     ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
549     capabilities.push_back(ref);
550   }
551   getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
552 }
553 
554 //===----------------------------------------------------------------------===//
555 // ScalarType
556 //===----------------------------------------------------------------------===//
557 
classof(Type type)558 bool ScalarType::classof(Type type) {
559   if (auto floatType = type.dyn_cast<FloatType>()) {
560     return isValid(floatType);
561   }
562   if (auto intType = type.dyn_cast<IntegerType>()) {
563     return isValid(intType);
564   }
565   return false;
566 }
567 
isValid(FloatType type)568 bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
569 
isValid(IntegerType type)570 bool ScalarType::isValid(IntegerType type) {
571   switch (type.getWidth()) {
572   case 1:
573   case 8:
574   case 16:
575   case 32:
576   case 64:
577     return true;
578   default:
579     return false;
580   }
581 }
582 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)583 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
584                                Optional<StorageClass> storage) {
585   // 8- or 16-bit integer/floating-point numbers will require extra extensions
586   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
587   // SPV_KHR_8bit_storage for more details.
588   if (!storage)
589     return;
590 
591   switch (*storage) {
592   case StorageClass::PushConstant:
593   case StorageClass::StorageBuffer:
594   case StorageClass::Uniform:
595     if (getIntOrFloatBitWidth() == 8) {
596       static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
597       ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
598       extensions.push_back(ref);
599     }
600     LLVM_FALLTHROUGH;
601   case StorageClass::Input:
602   case StorageClass::Output:
603     if (getIntOrFloatBitWidth() == 16) {
604       static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
605       ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
606       extensions.push_back(ref);
607     }
608     break;
609   default:
610     break;
611   }
612 }
613 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)614 void ScalarType::getCapabilities(
615     SPIRVType::CapabilityArrayRefVector &capabilities,
616     Optional<StorageClass> storage) {
617   unsigned bitwidth = getIntOrFloatBitWidth();
618 
619   // 8- or 16-bit integer/floating-point numbers will require extra capabilities
620   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
621   // SPV_KHR_8bit_storage for more details.
622 
623 #define STORAGE_CASE(storage, cap8, cap16)                                     \
624   case StorageClass::storage: {                                                \
625     if (bitwidth == 8) {                                                       \
626       static const Capability caps[] = {Capability::cap8};                     \
627       ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
628       capabilities.push_back(ref);                                             \
629     } else if (bitwidth == 16) {                                               \
630       static const Capability caps[] = {Capability::cap16};                    \
631       ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
632       capabilities.push_back(ref);                                             \
633     }                                                                          \
634     /* No requirements for other bitwidths */                                  \
635     return;                                                                    \
636   }
637 
638   // This part only handles the cases where special bitwidths appearing in
639   // interface storage classes.
640   if (storage) {
641     switch (*storage) {
642       STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
643       STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
644                    StorageBuffer16BitAccess);
645       STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
646                    StorageUniform16);
647     case StorageClass::Input:
648     case StorageClass::Output: {
649       if (bitwidth == 16) {
650         static const Capability caps[] = {Capability::StorageInputOutput16};
651         ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
652         capabilities.push_back(ref);
653       }
654       return;
655     }
656     default:
657       break;
658     }
659   }
660 #undef STORAGE_CASE
661 
662   // For other non-interface storage classes, require a different set of
663   // capabilities for special bitwidths.
664 
665 #define WIDTH_CASE(type, width)                                                \
666   case width: {                                                                \
667     static const Capability caps[] = {Capability::type##width};                \
668     ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));                \
669     capabilities.push_back(ref);                                               \
670   } break
671 
672   if (auto intType = dyn_cast<IntegerType>()) {
673     switch (bitwidth) {
674     case 32:
675     case 1:
676       break;
677       WIDTH_CASE(Int, 8);
678       WIDTH_CASE(Int, 16);
679       WIDTH_CASE(Int, 64);
680     default:
681       llvm_unreachable("invalid bitwidth to getCapabilities");
682     }
683   } else {
684     assert(isa<FloatType>());
685     switch (bitwidth) {
686     case 32:
687       break;
688       WIDTH_CASE(Float, 16);
689       WIDTH_CASE(Float, 64);
690     default:
691       llvm_unreachable("invalid bitwidth to getCapabilities");
692     }
693   }
694 
695 #undef WIDTH_CASE
696 }
697 
getSizeInBytes()698 Optional<int64_t> ScalarType::getSizeInBytes() {
699   auto bitWidth = getIntOrFloatBitWidth();
700   // According to the SPIR-V spec:
701   // "There is no physical size or bit pattern defined for values with boolean
702   // type. If they are stored (in conjunction with OpVariable), they can only
703   // be used with logical addressing operations, not physical, and only with
704   // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
705   // Private, Function, Input, and Output."
706   if (bitWidth == 1)
707     return llvm::None;
708   return bitWidth / 8;
709 }
710 
711 //===----------------------------------------------------------------------===//
712 // SPIRVType
713 //===----------------------------------------------------------------------===//
714 
classof(Type type)715 bool SPIRVType::classof(Type type) {
716   // Allow SPIR-V dialect types
717   if (llvm::isa<SPIRVDialect>(type.getDialect()))
718     return true;
719   if (type.isa<ScalarType>())
720     return true;
721   if (auto vectorType = type.dyn_cast<VectorType>())
722     return CompositeType::isValid(vectorType);
723   return false;
724 }
725 
isScalarOrVector()726 bool SPIRVType::isScalarOrVector() {
727   return isIntOrFloat() || isa<VectorType>();
728 }
729 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)730 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
731                               Optional<StorageClass> storage) {
732   if (auto scalarType = dyn_cast<ScalarType>()) {
733     scalarType.getExtensions(extensions, storage);
734   } else if (auto compositeType = dyn_cast<CompositeType>()) {
735     compositeType.getExtensions(extensions, storage);
736   } else if (auto imageType = dyn_cast<ImageType>()) {
737     imageType.getExtensions(extensions, storage);
738   } else if (auto matrixType = dyn_cast<MatrixType>()) {
739     matrixType.getExtensions(extensions, storage);
740   } else if (auto ptrType = dyn_cast<PointerType>()) {
741     ptrType.getExtensions(extensions, storage);
742   } else {
743     llvm_unreachable("invalid SPIR-V Type to getExtensions");
744   }
745 }
746 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)747 void SPIRVType::getCapabilities(
748     SPIRVType::CapabilityArrayRefVector &capabilities,
749     Optional<StorageClass> storage) {
750   if (auto scalarType = dyn_cast<ScalarType>()) {
751     scalarType.getCapabilities(capabilities, storage);
752   } else if (auto compositeType = dyn_cast<CompositeType>()) {
753     compositeType.getCapabilities(capabilities, storage);
754   } else if (auto imageType = dyn_cast<ImageType>()) {
755     imageType.getCapabilities(capabilities, storage);
756   } else if (auto matrixType = dyn_cast<MatrixType>()) {
757     matrixType.getCapabilities(capabilities, storage);
758   } else if (auto ptrType = dyn_cast<PointerType>()) {
759     ptrType.getCapabilities(capabilities, storage);
760   } else {
761     llvm_unreachable("invalid SPIR-V Type to getCapabilities");
762   }
763 }
764 
getSizeInBytes()765 Optional<int64_t> SPIRVType::getSizeInBytes() {
766   if (auto scalarType = dyn_cast<ScalarType>())
767     return scalarType.getSizeInBytes();
768   if (auto compositeType = dyn_cast<CompositeType>())
769     return compositeType.getSizeInBytes();
770   return llvm::None;
771 }
772 
773 //===----------------------------------------------------------------------===//
774 // StructType
775 //===----------------------------------------------------------------------===//
776 
777 /// Type storage for SPIR-V structure types:
778 ///
779 /// Structures are uniqued using:
780 /// - for identified structs:
781 ///   - a string identifier;
782 /// - for literal structs:
783 ///   - a list of member types;
784 ///   - a list of member offset info;
785 ///   - a list of member decoration info.
786 ///
787 /// Identified structures only have a mutable component consisting of:
788 /// - a list of member types;
789 /// - a list of member offset info;
790 /// - a list of member decoration info.
791 struct spirv::detail::StructTypeStorage : public TypeStorage {
792   /// Construct a storage object for an identified struct type. A struct type
793   /// associated with such storage must call StructType::trySetBody(...) later
794   /// in order to mutate the storage object providing the actual content.
StructTypeStoragespirv::detail::StructTypeStorage795   StructTypeStorage(StringRef identifier)
796       : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
797         numMemberDecorations(0), memberDecorationsInfo(nullptr),
798         identifier(identifier) {}
799 
800   /// Construct a storage object for a literal struct type. A struct type
801   /// associated with such storage is immutable.
StructTypeStoragespirv::detail::StructTypeStorage802   StructTypeStorage(
803       unsigned numMembers, Type const *memberTypes,
804       StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
805       StructType::MemberDecorationInfo const *memberDecorationsInfo)
806       : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
807         numMembers(numMembers), numMemberDecorations(numMemberDecorations),
808         memberDecorationsInfo(memberDecorationsInfo), identifier(StringRef()) {}
809 
810   /// A storage key is divided into 2 parts:
811   /// - for identified structs:
812   ///   - a StringRef representing the struct identifier;
813   /// - for literal structs:
814   ///   - an ArrayRef<Type> for member types;
815   ///   - an ArrayRef<StructType::OffsetInfo> for member offset info;
816   ///   - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
817   ///     info.
818   ///
819   /// An identified struct type is uniqued only by the first part (field 0)
820   /// of the key.
821   ///
822   /// A literal struct type is unqiued only by the second part (fields 1, 2, and
823   /// 3) of the key. The identifier field (field 0) must be empty.
824   using KeyTy =
825       std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
826                  ArrayRef<StructType::MemberDecorationInfo>>;
827 
828   /// For identified structs, return true if the given key contains the same
829   /// identifier.
830   ///
831   /// For literal structs, return true if the given key contains a matching list
832   /// of member types + offset info + decoration info.
operator ==spirv::detail::StructTypeStorage833   bool operator==(const KeyTy &key) const {
834     if (isIdentified()) {
835       // Identified types are uniqued by their identifier.
836       return getIdentifier() == std::get<0>(key);
837     }
838 
839     return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
840                         getMemberDecorationsInfo());
841   }
842 
843   /// If the given key contains a non-empty identifier, this method constructs
844   /// an identified struct and leaves the rest of the struct type data to be set
845   /// through a later call to StructType::trySetBody(...).
846   ///
847   /// If, on the other hand, the key contains an empty identifier, a literal
848   /// struct is constructed using the other fields of the key.
constructspirv::detail::StructTypeStorage849   static StructTypeStorage *construct(TypeStorageAllocator &allocator,
850                                       const KeyTy &key) {
851     StringRef keyIdentifier = std::get<0>(key);
852 
853     if (!keyIdentifier.empty()) {
854       StringRef identifier = allocator.copyInto(keyIdentifier);
855 
856       // Identified StructType body/members will be set through trySetBody(...)
857       // later.
858       return new (allocator.allocate<StructTypeStorage>())
859           StructTypeStorage(identifier);
860     }
861 
862     ArrayRef<Type> keyTypes = std::get<1>(key);
863 
864     // Copy the member type and layout information into the bump pointer
865     const Type *typesList = nullptr;
866     if (!keyTypes.empty()) {
867       typesList = allocator.copyInto(keyTypes).data();
868     }
869 
870     const StructType::OffsetInfo *offsetInfoList = nullptr;
871     if (!std::get<2>(key).empty()) {
872       ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
873       assert(keyOffsetInfo.size() == keyTypes.size() &&
874              "size of offset information must be same as the size of number of "
875              "elements");
876       offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
877     }
878 
879     const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
880     unsigned numMemberDecorations = 0;
881     if (!std::get<3>(key).empty()) {
882       auto keyMemberDecorations = std::get<3>(key);
883       numMemberDecorations = keyMemberDecorations.size();
884       memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
885     }
886 
887     return new (allocator.allocate<StructTypeStorage>())
888         StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
889                           numMemberDecorations, memberDecorationList);
890   }
891 
getMemberTypesspirv::detail::StructTypeStorage892   ArrayRef<Type> getMemberTypes() const {
893     return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
894   }
895 
getOffsetInfospirv::detail::StructTypeStorage896   ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
897     if (offsetInfo) {
898       return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
899     }
900     return {};
901   }
902 
getMemberDecorationsInfospirv::detail::StructTypeStorage903   ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const {
904     if (memberDecorationsInfo) {
905       return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
906                                                         numMemberDecorations);
907     }
908     return {};
909   }
910 
getIdentifierspirv::detail::StructTypeStorage911   StringRef getIdentifier() const { return identifier; }
912 
isIdentifiedspirv::detail::StructTypeStorage913   bool isIdentified() const { return !identifier.empty(); }
914 
915   /// Sets the struct type content for identified structs. Calling this method
916   /// is only valid for identified structs.
917   ///
918   /// Fails under the following conditions:
919   /// - If called for a literal struct;
920   /// - If called for an identified struct whose body was set before (through a
921   /// call to this method) but with different contents from the passed
922   /// arguments.
mutatespirv::detail::StructTypeStorage923   LogicalResult mutate(
924       TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
925       ArrayRef<StructType::OffsetInfo> structOffsetInfo,
926       ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
927     if (!isIdentified())
928       return failure();
929 
930     if (memberTypesAndIsBodySet.getInt() &&
931         (getMemberTypes() != structMemberTypes ||
932          getOffsetInfo() != structOffsetInfo ||
933          getMemberDecorationsInfo() != structMemberDecorationInfo))
934       return failure();
935 
936     memberTypesAndIsBodySet.setInt(true);
937     numMembers = structMemberTypes.size();
938 
939     // Copy the member type and layout information into the bump pointer.
940     if (!structMemberTypes.empty())
941       memberTypesAndIsBodySet.setPointer(
942           allocator.copyInto(structMemberTypes).data());
943 
944     if (!structOffsetInfo.empty()) {
945       assert(structOffsetInfo.size() == structMemberTypes.size() &&
946              "size of offset information must be same as the size of number of "
947              "elements");
948       offsetInfo = allocator.copyInto(structOffsetInfo).data();
949     }
950 
951     if (!structMemberDecorationInfo.empty()) {
952       numMemberDecorations = structMemberDecorationInfo.size();
953       memberDecorationsInfo =
954           allocator.copyInto(structMemberDecorationInfo).data();
955     }
956 
957     return success();
958   }
959 
960   llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
961   StructType::OffsetInfo const *offsetInfo;
962   unsigned numMembers;
963   unsigned numMemberDecorations;
964   StructType::MemberDecorationInfo const *memberDecorationsInfo;
965   StringRef identifier;
966 };
967 
968 StructType
get(ArrayRef<Type> memberTypes,ArrayRef<StructType::OffsetInfo> offsetInfo,ArrayRef<StructType::MemberDecorationInfo> memberDecorations)969 StructType::get(ArrayRef<Type> memberTypes,
970                 ArrayRef<StructType::OffsetInfo> offsetInfo,
971                 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
972   assert(!memberTypes.empty() && "Struct needs at least one member type");
973   // Sort the decorations.
974   SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
975       memberDecorations.begin(), memberDecorations.end());
976   llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
977   return Base::get(memberTypes.vec().front().getContext(),
978                    /*identifier=*/StringRef(), memberTypes, offsetInfo,
979                    sortedDecorations);
980 }
981 
getIdentified(MLIRContext * context,StringRef identifier)982 StructType StructType::getIdentified(MLIRContext *context,
983                                      StringRef identifier) {
984   assert(!identifier.empty() &&
985          "StructType identifier must be non-empty string");
986 
987   return Base::get(context, identifier, ArrayRef<Type>(),
988                    ArrayRef<StructType::OffsetInfo>(),
989                    ArrayRef<StructType::MemberDecorationInfo>());
990 }
991 
getEmpty(MLIRContext * context,StringRef identifier)992 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
993   StructType newStructType = Base::get(
994       context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
995       ArrayRef<StructType::MemberDecorationInfo>());
996   // Set an empty body in case this is a identified struct.
997   if (newStructType.isIdentified() &&
998       failed(newStructType.trySetBody(
999           ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1000           ArrayRef<StructType::MemberDecorationInfo>())))
1001     return StructType();
1002 
1003   return newStructType;
1004 }
1005 
getIdentifier() const1006 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1007 
isIdentified() const1008 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1009 
getNumElements() const1010 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1011 
getElementType(unsigned index) const1012 Type StructType::getElementType(unsigned index) const {
1013   assert(getNumElements() > index && "member index out of range");
1014   return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1015 }
1016 
getElementTypes() const1017 StructType::ElementTypeRange StructType::getElementTypes() const {
1018   return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1019                           getNumElements());
1020 }
1021 
hasOffset() const1022 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1023 
getMemberOffset(unsigned index) const1024 uint64_t StructType::getMemberOffset(unsigned index) const {
1025   assert(getNumElements() > index && "member index out of range");
1026   return getImpl()->offsetInfo[index];
1027 }
1028 
getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo> & memberDecorations) const1029 void StructType::getMemberDecorations(
1030     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations)
1031     const {
1032   memberDecorations.clear();
1033   auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1034   memberDecorations.append(implMemberDecorations.begin(),
1035                            implMemberDecorations.end());
1036 }
1037 
getMemberDecorations(unsigned index,SmallVectorImpl<StructType::MemberDecorationInfo> & decorationsInfo) const1038 void StructType::getMemberDecorations(
1039     unsigned index,
1040     SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1041   assert(getNumElements() > index && "member index out of range");
1042   auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1043   decorationsInfo.clear();
1044   for (const auto &memberDecoration : memberDecorations) {
1045     if (memberDecoration.memberIndex == index) {
1046       decorationsInfo.push_back(memberDecoration);
1047     }
1048     if (memberDecoration.memberIndex > index) {
1049       // Early exit since the decorations are stored sorted.
1050       return;
1051     }
1052   }
1053 }
1054 
1055 LogicalResult
trySetBody(ArrayRef<Type> memberTypes,ArrayRef<OffsetInfo> offsetInfo,ArrayRef<MemberDecorationInfo> memberDecorations)1056 StructType::trySetBody(ArrayRef<Type> memberTypes,
1057                        ArrayRef<OffsetInfo> offsetInfo,
1058                        ArrayRef<MemberDecorationInfo> memberDecorations) {
1059   return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1060 }
1061 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1062 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1063                                Optional<StorageClass> storage) {
1064   for (Type elementType : getElementTypes())
1065     elementType.cast<SPIRVType>().getExtensions(extensions, storage);
1066 }
1067 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1068 void StructType::getCapabilities(
1069     SPIRVType::CapabilityArrayRefVector &capabilities,
1070     Optional<StorageClass> storage) {
1071   for (Type elementType : getElementTypes())
1072     elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
1073 }
1074 
hash_value(const StructType::MemberDecorationInfo & memberDecorationInfo)1075 llvm::hash_code spirv::hash_value(
1076     const StructType::MemberDecorationInfo &memberDecorationInfo) {
1077   return llvm::hash_combine(memberDecorationInfo.memberIndex,
1078                             memberDecorationInfo.decoration);
1079 }
1080 
1081 //===----------------------------------------------------------------------===//
1082 // MatrixType
1083 //===----------------------------------------------------------------------===//
1084 
1085 struct spirv::detail::MatrixTypeStorage : public TypeStorage {
MatrixTypeStoragespirv::detail::MatrixTypeStorage1086   MatrixTypeStorage(Type columnType, uint32_t columnCount)
1087       : TypeStorage(), columnType(columnType), columnCount(columnCount) {}
1088 
1089   using KeyTy = std::tuple<Type, uint32_t>;
1090 
constructspirv::detail::MatrixTypeStorage1091   static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
1092                                       const KeyTy &key) {
1093 
1094     // Initialize the memory using placement new.
1095     return new (allocator.allocate<MatrixTypeStorage>())
1096         MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1097   }
1098 
operator ==spirv::detail::MatrixTypeStorage1099   bool operator==(const KeyTy &key) const {
1100     return key == KeyTy(columnType, columnCount);
1101   }
1102 
1103   Type columnType;
1104   const uint32_t columnCount;
1105 };
1106 
get(Type columnType,uint32_t columnCount)1107 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1108   return Base::get(columnType.getContext(), columnType, columnCount);
1109 }
1110 
getChecked(Type columnType,uint32_t columnCount,Location location)1111 MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
1112                                   Location location) {
1113   return Base::getChecked(location, columnType, columnCount);
1114 }
1115 
verifyConstructionInvariants(Location loc,Type columnType,uint32_t columnCount)1116 LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
1117                                                        Type columnType,
1118                                                        uint32_t columnCount) {
1119   if (columnCount < 2 || columnCount > 4)
1120     return emitError(loc, "matrix can have 2, 3, or 4 columns only");
1121 
1122   if (!isValidColumnType(columnType))
1123     return emitError(loc, "matrix columns must be vectors of floats");
1124 
1125   /// The underlying vectors (columns) must be of size 2, 3, or 4
1126   ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
1127   if (columnShape.size() != 1)
1128     return emitError(loc, "matrix columns must be 1D vectors");
1129 
1130   if (columnShape[0] < 2 || columnShape[0] > 4)
1131     return emitError(loc, "matrix columns must be of size 2, 3, or 4");
1132 
1133   return success();
1134 }
1135 
1136 /// Returns true if the matrix elements are vectors of float elements
isValidColumnType(Type columnType)1137 bool MatrixType::isValidColumnType(Type columnType) {
1138   if (auto vectorType = columnType.dyn_cast<VectorType>()) {
1139     if (vectorType.getElementType().isa<FloatType>())
1140       return true;
1141   }
1142   return false;
1143 }
1144 
getColumnType() const1145 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1146 
getElementType() const1147 Type MatrixType::getElementType() const {
1148   return getImpl()->columnType.cast<VectorType>().getElementType();
1149 }
1150 
getNumColumns() const1151 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1152 
getNumRows() const1153 unsigned MatrixType::getNumRows() const {
1154   return getImpl()->columnType.cast<VectorType>().getShape()[0];
1155 }
1156 
getNumElements() const1157 unsigned MatrixType::getNumElements() const {
1158   return (getImpl()->columnCount) * getNumRows();
1159 }
1160 
getExtensions(SPIRVType::ExtensionArrayRefVector & extensions,Optional<StorageClass> storage)1161 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1162                                Optional<StorageClass> storage) {
1163   getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
1164 }
1165 
getCapabilities(SPIRVType::CapabilityArrayRefVector & capabilities,Optional<StorageClass> storage)1166 void MatrixType::getCapabilities(
1167     SPIRVType::CapabilityArrayRefVector &capabilities,
1168     Optional<StorageClass> storage) {
1169   {
1170     static const Capability caps[] = {Capability::Matrix};
1171     ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
1172     capabilities.push_back(ref);
1173   }
1174   // Add any capabilities associated with the underlying vectors (i.e., columns)
1175   getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);
1176 }
1177