1 //===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_BUILTINTYPES_H
10 #define MLIR_IR_BUILTINTYPES_H
11 
12 #include "mlir/IR/Types.h"
13 
// Only a forward declaration is required here: FloatType::getFloatSemantics()
// below returns it by reference, which does not need the complete type.
namespace llvm {
struct fltSemantics;
} // end namespace llvm
17 
18 namespace mlir {
19 class AffineExpr;
20 class AffineMap;
21 class FloatType;
22 class Identifier;
23 class IndexType;
24 class IntegerType;
25 class Location;
26 class MLIRContext;
27 class TypeRange;
28 
namespace detail {

// Storage classes backing the builtin types declared in this header. Only
// forward declarations appear here; the definitions are not in this header.

// Storage for the shaped types (vectors, tensors, memrefs).
struct ShapedTypeStorage;
struct VectorTypeStorage;
struct RankedTensorTypeStorage;
struct UnrankedTensorTypeStorage;
struct BaseMemRefTypeStorage;
struct MemRefTypeStorage;
struct UnrankedMemRefTypeStorage;

// Storage for the remaining parameterized builtin types.
struct ComplexTypeStorage;
struct FunctionTypeStorage;
struct IntegerTypeStorage;
struct OpaqueTypeStorage;
struct TupleTypeStorage;

} // namespace detail
45 
46 //===----------------------------------------------------------------------===//
47 // ComplexType
48 //===----------------------------------------------------------------------===//
49 
50 /// The 'complex' type represents a complex number with a parameterized element
51 /// type, which is composed of a real and imaginary value of that element type.
52 ///
53 /// The element must be a floating point or integer scalar type.
54 ///
55 class ComplexType
56     : public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> {
57 public:
58   using Base::Base;
59 
60   /// Get or create a ComplexType with the provided element type.
61   static ComplexType get(Type elementType);
62 
63   /// Get or create a ComplexType with the provided element type.  This emits
64   /// and error at the specified location and returns null if the element type
65   /// isn't supported.
66   static ComplexType getChecked(Type elementType, Location location);
67 
68   /// Verify the construction of an integer type.
69   static LogicalResult verifyConstructionInvariants(Location loc,
70                                                     Type elementType);
71 
72   Type getElementType();
73 };
74 
75 //===----------------------------------------------------------------------===//
76 // IndexType
77 //===----------------------------------------------------------------------===//
78 
79 /// Index is a special integer-like type with unknown platform-dependent bit
80 /// width.
81 class IndexType : public Type::TypeBase<IndexType, Type, TypeStorage> {
82 public:
83   using Base::Base;
84 
85   /// Get an instance of the IndexType.
86   static IndexType get(MLIRContext *context);
87 
88   /// Storage bit width used for IndexType by internal compiler data structures.
89   static constexpr unsigned kInternalStorageBitWidth = 64;
90 };
91 
92 //===----------------------------------------------------------------------===//
93 // IntegerType
94 //===----------------------------------------------------------------------===//
95 
96 /// Integer types can have arbitrary bitwidth up to a large fixed limit.
97 class IntegerType
98     : public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> {
99 public:
100   using Base::Base;
101 
102   /// Signedness semantics.
103   enum SignednessSemantics : uint32_t {
104     Signless, /// No signedness semantics
105     Signed,   /// Signed integer
106     Unsigned, /// Unsigned integer
107   };
108 
109   /// Get or create a new IntegerType of the given width within the context.
110   /// The created IntegerType is signless (i.e., no signedness semantics).
111   /// Assume the width is within the allowed range and assert on failures. Use
112   /// getChecked to handle failures gracefully.
113   static IntegerType get(unsigned width, MLIRContext *context);
114 
115   /// Get or create a new IntegerType of the given width within the context.
116   /// The created IntegerType has signedness semantics as indicated via
117   /// `signedness`. Assume the width is within the allowed range and assert on
118   /// failures. Use getChecked to handle failures gracefully.
119   static IntegerType get(unsigned width, SignednessSemantics signedness,
120                          MLIRContext *context);
121 
122   /// Get or create a new IntegerType of the given width within the context,
123   /// defined at the given, potentially unknown, location.  The created
124   /// IntegerType is signless (i.e., no signedness semantics). If the width is
125   /// outside the allowed range, emit errors and return a null type.
126   static IntegerType getChecked(unsigned width, Location location);
127 
128   /// Get or create a new IntegerType of the given width within the context,
129   /// defined at the given, potentially unknown, location. The created
130   /// IntegerType has signedness semantics as indicated via `signedness`. If the
131   /// width is outside the allowed range, emit errors and return a null type.
132   static IntegerType getChecked(unsigned width, SignednessSemantics signedness,
133                                 Location location);
134 
135   /// Verify the construction of an integer type.
136   static LogicalResult
137   verifyConstructionInvariants(Location loc, unsigned width,
138                                SignednessSemantics signedness);
139 
140   /// Return the bitwidth of this integer type.
141   unsigned getWidth() const;
142 
143   /// Return the signedness semantics of this integer type.
144   SignednessSemantics getSignedness() const;
145 
146   /// Return true if this is a signless integer type.
isSignless()147   bool isSignless() const { return getSignedness() == Signless; }
148   /// Return true if this is a signed integer type.
isSigned()149   bool isSigned() const { return getSignedness() == Signed; }
150   /// Return true if this is an unsigned integer type.
isUnsigned()151   bool isUnsigned() const { return getSignedness() == Unsigned; }
152 
153   /// Integer representation maximal bitwidth.
154   static constexpr unsigned kMaxWidth = 4096;
155 };
156 
157 //===----------------------------------------------------------------------===//
158 // FloatType
159 //===----------------------------------------------------------------------===//
160 
161 class FloatType : public Type {
162 public:
163   using Type::Type;
164 
165   // Convenience factories.
166   static FloatType getBF16(MLIRContext *ctx);
167   static FloatType getF16(MLIRContext *ctx);
168   static FloatType getF32(MLIRContext *ctx);
169   static FloatType getF64(MLIRContext *ctx);
170 
171   /// Methods for support type inquiry through isa, cast, and dyn_cast.
172   static bool classof(Type type);
173 
174   /// Return the bitwidth of this float type.
175   unsigned getWidth();
176 
177   /// Return the floating semantics of this float type.
178   const llvm::fltSemantics &getFloatSemantics();
179 };
180 
181 //===----------------------------------------------------------------------===//
182 // BFloat16Type
183 
184 class BFloat16Type
185     : public Type::TypeBase<BFloat16Type, FloatType, TypeStorage> {
186 public:
187   using Base::Base;
188 
189   /// Return an instance of the bfloat16 type.
190   static BFloat16Type get(MLIRContext *context);
191 };
192 
getBF16(MLIRContext * ctx)193 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
194   return BFloat16Type::get(ctx);
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // Float16Type
199 
200 class Float16Type : public Type::TypeBase<Float16Type, FloatType, TypeStorage> {
201 public:
202   using Base::Base;
203 
204   /// Return an instance of the float16 type.
205   static Float16Type get(MLIRContext *context);
206 };
207 
getF16(MLIRContext * ctx)208 inline FloatType FloatType::getF16(MLIRContext *ctx) {
209   return Float16Type::get(ctx);
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // Float32Type
214 
215 class Float32Type : public Type::TypeBase<Float32Type, FloatType, TypeStorage> {
216 public:
217   using Base::Base;
218 
219   /// Return an instance of the float32 type.
220   static Float32Type get(MLIRContext *context);
221 };
222 
getF32(MLIRContext * ctx)223 inline FloatType FloatType::getF32(MLIRContext *ctx) {
224   return Float32Type::get(ctx);
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // Float64Type
229 
230 class Float64Type : public Type::TypeBase<Float64Type, FloatType, TypeStorage> {
231 public:
232   using Base::Base;
233 
234   /// Return an instance of the float64 type.
235   static Float64Type get(MLIRContext *context);
236 };
237 
getF64(MLIRContext * ctx)238 inline FloatType FloatType::getF64(MLIRContext *ctx) {
239   return Float64Type::get(ctx);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // FunctionType
244 //===----------------------------------------------------------------------===//
245 
246 /// Function types map from a list of inputs to a list of results.
247 class FunctionType
248     : public Type::TypeBase<FunctionType, Type, detail::FunctionTypeStorage> {
249 public:
250   using Base::Base;
251 
252   static FunctionType get(TypeRange inputs, TypeRange results,
253                           MLIRContext *context);
254 
255   /// Input types.
256   unsigned getNumInputs() const;
getInput(unsigned i)257   Type getInput(unsigned i) const { return getInputs()[i]; }
258   ArrayRef<Type> getInputs() const;
259 
260   /// Result types.
261   unsigned getNumResults() const;
getResult(unsigned i)262   Type getResult(unsigned i) const { return getResults()[i]; }
263   ArrayRef<Type> getResults() const;
264 
265   /// Returns a new function type without the specified arguments and results.
266   FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
267                                         ArrayRef<unsigned> resultIndices);
268 };
269 
270 //===----------------------------------------------------------------------===//
271 // NoneType
272 //===----------------------------------------------------------------------===//
273 
274 /// NoneType is a unit type, i.e. a type with exactly one possible value, where
275 /// its value does not have a defined dynamic representation.
276 class NoneType : public Type::TypeBase<NoneType, Type, TypeStorage> {
277 public:
278   using Base::Base;
279 
280   /// Get an instance of the NoneType.
281   static NoneType get(MLIRContext *context);
282 };
283 
284 //===----------------------------------------------------------------------===//
285 // OpaqueType
286 //===----------------------------------------------------------------------===//
287 
288 /// Opaque types represent types of non-registered dialects. These are types
289 /// represented in their raw string form, and can only usefully be tested for
290 /// type equality.
291 class OpaqueType
292     : public Type::TypeBase<OpaqueType, Type, detail::OpaqueTypeStorage> {
293 public:
294   using Base::Base;
295 
296   /// Get or create a new OpaqueType with the provided dialect and string data.
297   static OpaqueType get(Identifier dialect, StringRef typeData,
298                         MLIRContext *context);
299 
300   /// Get or create a new OpaqueType with the provided dialect and string data.
301   /// If the given identifier is not a valid namespace for a dialect, then a
302   /// null type is returned.
303   static OpaqueType getChecked(Identifier dialect, StringRef typeData,
304                                MLIRContext *context, Location location);
305 
306   /// Returns the dialect namespace of the opaque type.
307   Identifier getDialectNamespace() const;
308 
309   /// Returns the raw type data of the opaque type.
310   StringRef getTypeData() const;
311 
312   /// Verify the construction of an opaque type.
313   static LogicalResult verifyConstructionInvariants(Location loc,
314                                                     Identifier dialect,
315                                                     StringRef typeData);
316 };
317 
318 //===----------------------------------------------------------------------===//
319 // ShapedType
320 //===----------------------------------------------------------------------===//
321 
322 /// This is a common base class between Vector, UnrankedTensor, RankedTensor,
323 /// and MemRef types because they share behavior and semantics around shape,
324 /// rank, and fixed element type. Any type with these semantics should inherit
325 /// from ShapedType.
326 class ShapedType : public Type {
327 public:
328   using ImplType = detail::ShapedTypeStorage;
329   using Type::Type;
330 
331   // TODO: merge these two special values in a single one used everywhere.
332   // Unfortunately, uses of `-1` have crept deep into the codebase now and are
333   // hard to track.
334   static constexpr int64_t kDynamicSize = -1;
335   static constexpr int64_t kDynamicStrideOrOffset =
336       std::numeric_limits<int64_t>::min();
337 
338   /// Return the element type.
339   Type getElementType() const;
340 
341   /// If an element type is an integer or a float, return its width. Otherwise,
342   /// abort.
343   unsigned getElementTypeBitWidth() const;
344 
345   /// If it has static shape, return the number of elements. Otherwise, abort.
346   int64_t getNumElements() const;
347 
348   /// If this is a ranked type, return the rank. Otherwise, abort.
349   int64_t getRank() const;
350 
351   /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
352   /// have a rank, while unranked tensors do not.
353   bool hasRank() const;
354 
355   /// If this is a ranked type, return the shape. Otherwise, abort.
356   ArrayRef<int64_t> getShape() const;
357 
358   /// If this is unranked type or any dimension has unknown size (<0), it
359   /// doesn't have static shape. If all dimensions have known size (>= 0), it
360   /// has static shape.
361   bool hasStaticShape() const;
362 
363   /// If this has a static shape and the shape is equal to `shape` return true.
364   bool hasStaticShape(ArrayRef<int64_t> shape) const;
365 
366   /// If this is a ranked type, return the number of dimensions with dynamic
367   /// size. Otherwise, abort.
368   int64_t getNumDynamicDims() const;
369 
370   /// If this is ranked type, return the size of the specified dimension.
371   /// Otherwise, abort.
372   int64_t getDimSize(unsigned idx) const;
373 
374   /// Returns true if this dimension has a dynamic size (for ranked types);
375   /// aborts for unranked types.
376   bool isDynamicDim(unsigned idx) const;
377 
378   /// Returns the position of the dynamic dimension relative to just the dynamic
379   /// dimensions, given its `index` within the shape.
380   unsigned getDynamicDimIndex(unsigned index) const;
381 
382   /// Get the total amount of bits occupied by a value of this type.  This does
383   /// not take into account any memory layout or widening constraints, e.g. a
384   /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
385   /// it will likely be stored as in a 4xi64 vector register.  Fail an assertion
386   /// if the size cannot be computed statically, i.e. if the type has a dynamic
387   /// shape or if its elemental type does not have a known bit width.
388   int64_t getSizeInBits() const;
389 
390   /// Methods for support type inquiry through isa, cast, and dyn_cast.
391   static bool classof(Type type);
392 
393   /// Whether the given dimension size indicates a dynamic dimension.
isDynamic(int64_t dSize)394   static constexpr bool isDynamic(int64_t dSize) {
395     return dSize == kDynamicSize;
396   }
isDynamicStrideOrOffset(int64_t dStrideOrOffset)397   static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
398     return dStrideOrOffset == kDynamicStrideOrOffset;
399   }
400 };
401 
402 //===----------------------------------------------------------------------===//
403 // VectorType
404 //===----------------------------------------------------------------------===//
405 
406 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
407 /// known constant shape with one or more dimension.
408 class VectorType
409     : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
410 public:
411   using Base::Base;
412 
413   /// Get or create a new VectorType of the provided shape and element type.
414   /// Assumes the arguments define a well-formed VectorType.
415   static VectorType get(ArrayRef<int64_t> shape, Type elementType);
416 
417   /// Get or create a new VectorType of the provided shape and element type
418   /// declared at the given, potentially unknown, location.  If the VectorType
419   /// defined by the arguments would be ill-formed, emit errors and return
420   /// nullptr-wrapping type.
421   static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
422                                Location location);
423 
424   /// Verify the construction of a vector type.
425   static LogicalResult verifyConstructionInvariants(Location loc,
426                                                     ArrayRef<int64_t> shape,
427                                                     Type elementType);
428 
429   /// Returns true of the given type can be used as an element of a vector type.
430   /// In particular, vectors can consist of integer or float primitives.
isValidElementType(Type t)431   static bool isValidElementType(Type t) {
432     return t.isa<IntegerType, FloatType>();
433   }
434 
435   ArrayRef<int64_t> getShape() const;
436 };
437 
438 //===----------------------------------------------------------------------===//
439 // TensorType
440 //===----------------------------------------------------------------------===//
441 
442 /// Tensor types represent multi-dimensional arrays, and have two variants:
443 /// RankedTensorType and UnrankedTensorType.
444 class TensorType : public ShapedType {
445 public:
446   using ShapedType::ShapedType;
447 
448   /// Return true if the specified element type is ok in a tensor.
449   static bool isValidElementType(Type type);
450 
451   /// Methods for support type inquiry through isa, cast, and dyn_cast.
452   static bool classof(Type type);
453 };
454 
455 //===----------------------------------------------------------------------===//
456 // RankedTensorType
457 
458 /// Ranked tensor types represent multi-dimensional arrays that have a shape
459 /// with a fixed number of dimensions. Each shape element can be a non-negative
460 /// integer or unknown (represented by -1).
461 class RankedTensorType
462     : public Type::TypeBase<RankedTensorType, TensorType,
463                             detail::RankedTensorTypeStorage> {
464 public:
465   using Base::Base;
466 
467   /// Get or create a new RankedTensorType of the provided shape and element
468   /// type. Assumes the arguments define a well-formed type.
469   static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType);
470 
471   /// Get or create a new RankedTensorType of the provided shape and element
472   /// type declared at the given, potentially unknown, location.  If the
473   /// RankedTensorType defined by the arguments would be ill-formed, emit errors
474   /// and return a nullptr-wrapping type.
475   static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType,
476                                      Location location);
477 
478   /// Verify the construction of a ranked tensor type.
479   static LogicalResult verifyConstructionInvariants(Location loc,
480                                                     ArrayRef<int64_t> shape,
481                                                     Type elementType);
482 
483   ArrayRef<int64_t> getShape() const;
484 };
485 
486 //===----------------------------------------------------------------------===//
487 // UnrankedTensorType
488 
489 /// Unranked tensor types represent multi-dimensional arrays that have an
490 /// unknown shape.
491 class UnrankedTensorType
492     : public Type::TypeBase<UnrankedTensorType, TensorType,
493                             detail::UnrankedTensorTypeStorage> {
494 public:
495   using Base::Base;
496 
497   /// Get or create a new UnrankedTensorType of the provided shape and element
498   /// type. Assumes the arguments define a well-formed type.
499   static UnrankedTensorType get(Type elementType);
500 
501   /// Get or create a new UnrankedTensorType of the provided shape and element
502   /// type declared at the given, potentially unknown, location.  If the
503   /// UnrankedTensorType defined by the arguments would be ill-formed, emit
504   /// errors and return a nullptr-wrapping type.
505   static UnrankedTensorType getChecked(Type elementType, Location location);
506 
507   /// Verify the construction of a unranked tensor type.
508   static LogicalResult verifyConstructionInvariants(Location loc,
509                                                     Type elementType);
510 
getShape()511   ArrayRef<int64_t> getShape() const { return llvm::None; }
512 };
513 
514 //===----------------------------------------------------------------------===//
515 // BaseMemRefType
516 //===----------------------------------------------------------------------===//
517 
518 /// Base MemRef for Ranked and Unranked variants
519 class BaseMemRefType : public ShapedType {
520 public:
521   using ImplType = detail::BaseMemRefTypeStorage;
522   using ShapedType::ShapedType;
523 
524   /// Return true if the specified element type is ok in a memref.
isValidElementType(Type type)525   static bool isValidElementType(Type type) {
526     return type.isIntOrIndexOrFloat() || type.isa<VectorType, ComplexType>();
527   }
528 
529   /// Methods for support type inquiry through isa, cast, and dyn_cast.
530   static bool classof(Type type);
531 
532   /// Returns the memory space in which data referred to by this memref resides.
533   unsigned getMemorySpace() const;
534 };
535 
536 //===----------------------------------------------------------------------===//
537 // MemRefType
538 
539 /// MemRef types represent a region of memory that have a shape with a fixed
540 /// number of dimensions. Each shape element can be a non-negative integer or
541 /// unknown (represented by -1). MemRef types also have an affine map
542 /// composition, represented as an array AffineMap pointers.
543 class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
544                                          detail::MemRefTypeStorage> {
545 public:
546   /// This is a builder type that keeps local references to arguments. Arguments
547   /// that are passed into the builder must out-live the builder.
548   class Builder {
549   public:
550     // Build from another MemRefType.
Builder(MemRefType other)551     explicit Builder(MemRefType other)
552         : shape(other.getShape()), elementType(other.getElementType()),
553           affineMaps(other.getAffineMaps()),
554           memorySpace(other.getMemorySpace()) {}
555 
556     // Build from scratch.
Builder(ArrayRef<int64_t> shape,Type elementType)557     Builder(ArrayRef<int64_t> shape, Type elementType)
558         : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {
559     }
560 
setShape(ArrayRef<int64_t> newShape)561     Builder &setShape(ArrayRef<int64_t> newShape) {
562       shape = newShape;
563       return *this;
564     }
565 
setElementType(Type newElementType)566     Builder &setElementType(Type newElementType) {
567       elementType = newElementType;
568       return *this;
569     }
570 
setAffineMaps(ArrayRef<AffineMap> newAffineMaps)571     Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
572       affineMaps = newAffineMaps;
573       return *this;
574     }
575 
setMemorySpace(unsigned newMemorySpace)576     Builder &setMemorySpace(unsigned newMemorySpace) {
577       memorySpace = newMemorySpace;
578       return *this;
579     }
580 
MemRefType()581     operator MemRefType() {
582       return MemRefType::get(shape, elementType, affineMaps, memorySpace);
583     }
584 
585   private:
586     ArrayRef<int64_t> shape;
587     Type elementType;
588     ArrayRef<AffineMap> affineMaps;
589     unsigned memorySpace;
590   };
591 
592   using Base::Base;
593 
594   /// Get or create a new MemRefType based on shape, element type, affine
595   /// map composition, and memory space.  Assumes the arguments define a
596   /// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
597   /// construction failures.
598   static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
599                         ArrayRef<AffineMap> affineMapComposition = {},
600                         unsigned memorySpace = 0);
601 
602   /// Get or create a new MemRefType based on shape, element type, affine
603   /// map composition, and memory space declared at the given location.
604   /// If the location is unknown, the last argument should be an instance of
605   /// UnknownLoc.  If the MemRefType defined by the arguments would be
606   /// ill-formed, emits errors (to the handler registered with the context or to
607   /// the error stream) and returns nullptr.
608   static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
609                                ArrayRef<AffineMap> affineMapComposition,
610                                unsigned memorySpace, Location location);
611 
612   ArrayRef<int64_t> getShape() const;
613 
614   /// Returns an array of affine map pointers representing the memref affine
615   /// map composition.
616   ArrayRef<AffineMap> getAffineMaps() const;
617 
618   // TODO: merge these two special values in a single one used everywhere.
619   // Unfortunately, uses of `-1` have crept deep into the codebase now and are
620   // hard to track.
getDynamicStrideOrOffset()621   static int64_t getDynamicStrideOrOffset() {
622     return ShapedType::kDynamicStrideOrOffset;
623   }
624 
625 private:
626   /// Get or create a new MemRefType defined by the arguments.  If the resulting
627   /// type would be ill-formed, return nullptr.  If the location is provided,
628   /// emit detailed error messages.
629   static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
630                             ArrayRef<AffineMap> affineMapComposition,
631                             unsigned memorySpace, Optional<Location> location);
632   using Base::getImpl;
633 };
634 
635 //===----------------------------------------------------------------------===//
636 // UnrankedMemRefType
637 
638 /// Unranked MemRef type represent multi-dimensional MemRefs that
639 /// have an unknown rank.
640 class UnrankedMemRefType
641     : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
642                             detail::UnrankedMemRefTypeStorage> {
643 public:
644   using Base::Base;
645 
646   /// Get or create a new UnrankedMemRefType of the provided element
647   /// type and memory space
648   static UnrankedMemRefType get(Type elementType, unsigned memorySpace);
649 
650   /// Get or create a new UnrankedMemRefType of the provided element
651   /// type and memory space declared at the given, potentially unknown,
652   /// location. If the UnrankedMemRefType defined by the arguments would be
653   /// ill-formed, emit errors and return a nullptr-wrapping type.
654   static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
655                                        Location location);
656 
657   /// Verify the construction of a unranked memref type.
658   static LogicalResult verifyConstructionInvariants(Location loc,
659                                                     Type elementType,
660                                                     unsigned memorySpace);
661 
getShape()662   ArrayRef<int64_t> getShape() const { return llvm::None; }
663 };
664 
665 //===----------------------------------------------------------------------===//
666 // TupleType
667 //===----------------------------------------------------------------------===//
668 
669 /// Tuple types represent a collection of other types. Note: This type merely
670 /// provides a common mechanism for representing tuples in MLIR. It is up to
671 /// dialect authors to provides operations for manipulating them, e.g.
672 /// extract_tuple_element. When possible, users should prefer multi-result
673 /// operations in the place of tuples.
674 class TupleType
675     : public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> {
676 public:
677   using Base::Base;
678 
679   /// Get or create a new TupleType with the provided element types. Assumes the
680   /// arguments define a well-formed type.
681   static TupleType get(TypeRange elementTypes, MLIRContext *context);
682 
683   /// Get or create an empty tuple type.
684   static TupleType get(MLIRContext *context);
685 
686   /// Return the elements types for this tuple.
687   ArrayRef<Type> getTypes() const;
688 
689   /// Accumulate the types contained in this tuple and tuples nested within it.
690   /// Note that this only flattens nested tuples, not any other container type,
691   /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
692   /// (i32, tensor<i32>, f32, i64)
693   void getFlattenedTypes(SmallVectorImpl<Type> &types);
694 
695   /// Return the number of held types.
696   size_t size() const;
697 
698   /// Iterate over the held elements.
699   using iterator = ArrayRef<Type>::iterator;
begin()700   iterator begin() const { return getTypes().begin(); }
end()701   iterator end() const { return getTypes().end(); }
702 
703   /// Return the element type at index 'index'.
getType(size_t index)704   Type getType(size_t index) const {
705     assert(index < size() && "invalid index for tuple type");
706     return getTypes()[index];
707   }
708 };
709 
710 //===----------------------------------------------------------------------===//
711 // Deferred Method Definitions
712 //===----------------------------------------------------------------------===//
713 
classof(Type type)714 inline bool BaseMemRefType::classof(Type type) {
715   return type.isa<MemRefType, UnrankedMemRefType>();
716 }
717 
classof(Type type)718 inline bool FloatType::classof(Type type) {
719   return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
720 }
721 
classof(Type type)722 inline bool ShapedType::classof(Type type) {
723   return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
724                   UnrankedMemRefType, MemRefType>();
725 }
726 
classof(Type type)727 inline bool TensorType::classof(Type type) {
728   return type.isa<RankedTensorType, UnrankedTensorType>();
729 }
730 
731 //===----------------------------------------------------------------------===//
732 // Type Utilities
733 //===----------------------------------------------------------------------===//
734 
735 /// Returns the strides of the MemRef if the layout map is in strided form.
736 /// MemRefs with layout maps in strided form include:
737 ///   1. empty or identity layout map, in which case the stride information is
738 ///      the canonical form computed from sizes;
739 ///   2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
740 ///      where K and ki's are constants or symbols.
741 ///
742 /// A stride specification is a list of integer values that are either static
743 /// or dynamic (encoded with getDynamicStrideOrOffset()). Strides encode the
744 /// distance in the number of elements between successive entries along a
745 /// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>`
746 /// specifies a view into a non-contiguous memory region of `42` by `16` `f32`
747 /// elements in which the distance between two consecutive elements along the
748 /// outer dimension is `1` and the distance between two consecutive elements
749 /// along the inner dimension is `64`.
750 ///
751 /// Returns whether a simple strided form can be extracted from the composition
752 /// of the layout map.
753 ///
754 /// The convention is that the strides for dimensions d0, .. dn appear in
755 /// order to make indexing intuitive into the result.
756 LogicalResult getStridesAndOffset(MemRefType t,
757                                   SmallVectorImpl<int64_t> &strides,
758                                   int64_t &offset);
759 LogicalResult getStridesAndOffset(MemRefType t,
760                                   SmallVectorImpl<AffineExpr> &strides,
761                                   AffineExpr &offset);
762 
763 /// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset()
764 /// represents a dynamic value), return the single result AffineMap which
765 /// represents the linearized strided layout map. Dimensions correspond to the
766 /// offset followed by the strides in order. Symbols are inserted for each
767 /// dynamic dimension in order. A stride cannot take value `0`.
768 ///
769 /// Examples:
770 /// =========
771 ///
772 ///   1. For offset: 0 strides: ?, ?, 1 return
773 ///         (i, j, k)[M, N]->(M * i + N * j + k)
774 ///
775 ///   2. For offset: 3 strides: 32, ?, 16 return
776 ///         (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
777 ///
778 ///   3. For offset: ? strides: ?, ?, ? return
779 ///         (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
780 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
781                                      MLIRContext *context);
782 
783 /// Return a version of `t` with identity layout if it can be determined
784 /// statically that the layout is the canonical contiguous strided layout.
785 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
786 /// `t` with simplified layout.
787 MemRefType canonicalizeStridedLayout(MemRefType t);
788 
789 /// Return a version of `t` with a layout that has all dynamic offset and
790 /// strides. This is used to erase the static layout.
791 MemRefType eraseStridedLayout(MemRefType t);
792 
793 /// Given MemRef `sizes` that are either static or dynamic, returns the
794 /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
795 /// once a dynamic dimension is encountered, all canonical strides become
796 /// dynamic and need to be encoded with a different symbol.
797 /// For canonical strides expressions, the offset is always 0 and and fastest
798 /// varying stride is always `1`.
799 ///
800 /// Examples:
801 ///   - memref<3x4x5xf32> has canonical stride expression
802 ///         `20*exprs[0] + 5*exprs[1] + exprs[2]`.
803 ///   - memref<3x?x5xf32> has canonical stride expression
804 ///         `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
805 ///   - memref<3x4x?xf32> has canonical stride expression
806 ///         `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
807 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
808                                           ArrayRef<AffineExpr> exprs,
809                                           MLIRContext *context);
810 
811 /// Return the result of makeCanonicalStrudedLayoutExpr for the common case
812 /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
813 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
814                                           MLIRContext *context);
815 
816 /// Return true if the layout for `t` is compatible with strided semantics.
817 bool isStrided(MemRefType t);
818 
819 } // end namespace mlir
820 
821 #endif // MLIR_IR_BUILTINTYPES_H
822