1 //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Dialect.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/BitVector.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 //===----------------------------------------------------------------------===//
// ComplexType
25 //===----------------------------------------------------------------------===//
26 
get(Type elementType)27 ComplexType ComplexType::get(Type elementType) {
28   return Base::get(elementType.getContext(), elementType);
29 }
30 
getChecked(Type elementType,Location location)31 ComplexType ComplexType::getChecked(Type elementType, Location location) {
32   return Base::getChecked(location, elementType);
33 }
34 
35 /// Verify the construction of an integer type.
verifyConstructionInvariants(Location loc,Type elementType)36 LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
37                                                         Type elementType) {
38   if (!elementType.isIntOrFloat())
39     return emitError(loc, "invalid element type for complex");
40   return success();
41 }
42 
getElementType()43 Type ComplexType::getElementType() { return getImpl()->elementType; }
44 
45 //===----------------------------------------------------------------------===//
46 // Integer Type
47 //===----------------------------------------------------------------------===//
48 
49 // static constexpr must have a definition (until in C++17 and inline variable).
50 constexpr unsigned IntegerType::kMaxWidth;
51 
52 /// Verify the construction of an integer type.
53 LogicalResult
verifyConstructionInvariants(Location loc,unsigned width,SignednessSemantics signedness)54 IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
55                                           SignednessSemantics signedness) {
56   if (width > IntegerType::kMaxWidth) {
57     return emitError(loc) << "integer bitwidth is limited to "
58                           << IntegerType::kMaxWidth << " bits";
59   }
60   return success();
61 }
62 
getWidth() const63 unsigned IntegerType::getWidth() const { return getImpl()->width; }
64 
getSignedness() const65 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
66   return getImpl()->signedness;
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // Float Type
71 //===----------------------------------------------------------------------===//
72 
getWidth()73 unsigned FloatType::getWidth() {
74   if (isa<Float16Type, BFloat16Type>())
75     return 16;
76   if (isa<Float32Type>())
77     return 32;
78   if (isa<Float64Type>())
79     return 64;
80   llvm_unreachable("unexpected float type");
81 }
82 
83 /// Returns the floating semantics for the given type.
getFloatSemantics()84 const llvm::fltSemantics &FloatType::getFloatSemantics() {
85   if (isa<BFloat16Type>())
86     return APFloat::BFloat();
87   if (isa<Float16Type>())
88     return APFloat::IEEEhalf();
89   if (isa<Float32Type>())
90     return APFloat::IEEEsingle();
91   if (isa<Float64Type>())
92     return APFloat::IEEEdouble();
93   llvm_unreachable("non-floating point type used");
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // FunctionType
98 //===----------------------------------------------------------------------===//
99 
get(TypeRange inputs,TypeRange results,MLIRContext * context)100 FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
101                                MLIRContext *context) {
102   return Base::get(context, inputs, results);
103 }
104 
getNumInputs() const105 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
106 
getInputs() const107 ArrayRef<Type> FunctionType::getInputs() const {
108   return getImpl()->getInputs();
109 }
110 
getNumResults() const111 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
112 
getResults() const113 ArrayRef<Type> FunctionType::getResults() const {
114   return getImpl()->getResults();
115 }
116 
117 /// Helper to call a callback once on each index in the range
118 /// [0, `totalIndices`), *except* for the indices given in `indices`.
119 /// `indices` is allowed to have duplicates and can be in any order.
iterateIndicesExcept(unsigned totalIndices,ArrayRef<unsigned> indices,function_ref<void (unsigned)> callback)120 inline void iterateIndicesExcept(unsigned totalIndices,
121                                  ArrayRef<unsigned> indices,
122                                  function_ref<void(unsigned)> callback) {
123   llvm::BitVector skipIndices(totalIndices);
124   for (unsigned i : indices)
125     skipIndices.set(i);
126 
127   for (unsigned i = 0; i < totalIndices; ++i)
128     if (!skipIndices.test(i))
129       callback(i);
130 }
131 
132 /// Returns a new function type without the specified arguments and results.
133 FunctionType
getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,ArrayRef<unsigned> resultIndices)134 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
135                                        ArrayRef<unsigned> resultIndices) {
136   ArrayRef<Type> newInputTypes = getInputs();
137   SmallVector<Type, 4> newInputTypesBuffer;
138   if (!argIndices.empty()) {
139     unsigned originalNumArgs = getNumInputs();
140     iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
141       newInputTypesBuffer.emplace_back(getInput(i));
142     });
143     newInputTypes = newInputTypesBuffer;
144   }
145 
146   ArrayRef<Type> newResultTypes = getResults();
147   SmallVector<Type, 4> newResultTypesBuffer;
148   if (!resultIndices.empty()) {
149     unsigned originalNumResults = getNumResults();
150     iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
151       newResultTypesBuffer.emplace_back(getResult(i));
152     });
153     newResultTypes = newResultTypesBuffer;
154   }
155 
156   return get(newInputTypes, newResultTypes, getContext());
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // OpaqueType
161 //===----------------------------------------------------------------------===//
162 
get(Identifier dialect,StringRef typeData,MLIRContext * context)163 OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
164                            MLIRContext *context) {
165   return Base::get(context, dialect, typeData);
166 }
167 
getChecked(Identifier dialect,StringRef typeData,MLIRContext * context,Location location)168 OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
169                                   MLIRContext *context, Location location) {
170   return Base::getChecked(location, dialect, typeData);
171 }
172 
173 /// Returns the dialect namespace of the opaque type.
getDialectNamespace() const174 Identifier OpaqueType::getDialectNamespace() const {
175   return getImpl()->dialectNamespace;
176 }
177 
178 /// Returns the raw type data of the opaque type.
getTypeData() const179 StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }
180 
181 /// Verify the construction of an opaque type.
verifyConstructionInvariants(Location loc,Identifier dialect,StringRef typeData)182 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
183                                                        Identifier dialect,
184                                                        StringRef typeData) {
185   if (!Dialect::isValidNamespace(dialect.strref()))
186     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
187   return success();
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // ShapedType
192 //===----------------------------------------------------------------------===//
193 constexpr int64_t ShapedType::kDynamicSize;
194 constexpr int64_t ShapedType::kDynamicStrideOrOffset;
195 
getElementType() const196 Type ShapedType::getElementType() const {
197   return static_cast<ImplType *>(impl)->elementType;
198 }
199 
getElementTypeBitWidth() const200 unsigned ShapedType::getElementTypeBitWidth() const {
201   return getElementType().getIntOrFloatBitWidth();
202 }
203 
getNumElements() const204 int64_t ShapedType::getNumElements() const {
205   assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
206   auto shape = getShape();
207   int64_t num = 1;
208   for (auto dim : shape)
209     num *= dim;
210   return num;
211 }
212 
getRank() const213 int64_t ShapedType::getRank() const { return getShape().size(); }
214 
hasRank() const215 bool ShapedType::hasRank() const {
216   return !isa<UnrankedMemRefType, UnrankedTensorType>();
217 }
218 
getDimSize(unsigned idx) const219 int64_t ShapedType::getDimSize(unsigned idx) const {
220   assert(idx < getRank() && "invalid index for shaped type");
221   return getShape()[idx];
222 }
223 
isDynamicDim(unsigned idx) const224 bool ShapedType::isDynamicDim(unsigned idx) const {
225   assert(idx < getRank() && "invalid index for shaped type");
226   return isDynamic(getShape()[idx]);
227 }
228 
getDynamicDimIndex(unsigned index) const229 unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
230   assert(index < getRank() && "invalid index");
231   assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
232   return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
233 }
234 
235 /// Get the number of bits require to store a value of the given shaped type.
236 /// Compute the value recursively since tensors are allowed to have vectors as
237 /// elements.
getSizeInBits() const238 int64_t ShapedType::getSizeInBits() const {
239   assert(hasStaticShape() &&
240          "cannot get the bit size of an aggregate with a dynamic shape");
241 
242   auto elementType = getElementType();
243   if (elementType.isIntOrFloat())
244     return elementType.getIntOrFloatBitWidth() * getNumElements();
245 
246   if (auto complexType = elementType.dyn_cast<ComplexType>()) {
247     elementType = complexType.getElementType();
248     return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
249   }
250 
251   // Tensors can have vectors and other tensors as elements, other shaped types
252   // cannot.
253   assert(isa<TensorType>() && "unsupported element type");
254   assert((elementType.isa<VectorType, TensorType>()) &&
255          "unsupported tensor element type");
256   return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
257 }
258 
getShape() const259 ArrayRef<int64_t> ShapedType::getShape() const {
260   if (auto vectorType = dyn_cast<VectorType>())
261     return vectorType.getShape();
262   if (auto tensorType = dyn_cast<RankedTensorType>())
263     return tensorType.getShape();
264   return cast<MemRefType>().getShape();
265 }
266 
getNumDynamicDims() const267 int64_t ShapedType::getNumDynamicDims() const {
268   return llvm::count_if(getShape(), isDynamic);
269 }
270 
hasStaticShape() const271 bool ShapedType::hasStaticShape() const {
272   return hasRank() && llvm::none_of(getShape(), isDynamic);
273 }
274 
hasStaticShape(ArrayRef<int64_t> shape) const275 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
276   return hasStaticShape() && getShape() == shape;
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // VectorType
281 //===----------------------------------------------------------------------===//
282 
get(ArrayRef<int64_t> shape,Type elementType)283 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
284   return Base::get(elementType.getContext(), shape, elementType);
285 }
286 
getChecked(ArrayRef<int64_t> shape,Type elementType,Location location)287 VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
288                                   Location location) {
289   return Base::getChecked(location, shape, elementType);
290 }
291 
verifyConstructionInvariants(Location loc,ArrayRef<int64_t> shape,Type elementType)292 LogicalResult VectorType::verifyConstructionInvariants(Location loc,
293                                                        ArrayRef<int64_t> shape,
294                                                        Type elementType) {
295   if (shape.empty())
296     return emitError(loc, "vector types must have at least one dimension");
297 
298   if (!isValidElementType(elementType))
299     return emitError(loc, "vector elements must be int or float type");
300 
301   if (any_of(shape, [](int64_t i) { return i <= 0; }))
302     return emitError(loc, "vector types must have positive constant sizes");
303 
304   return success();
305 }
306 
getShape() const307 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
308 
309 //===----------------------------------------------------------------------===//
310 // TensorType
311 //===----------------------------------------------------------------------===//
312 
313 // Check if "elementType" can be an element type of a tensor. Emit errors if
314 // location is not nullptr.  Returns failure if check failed.
checkTensorElementType(Location location,Type elementType)315 static LogicalResult checkTensorElementType(Location location,
316                                             Type elementType) {
317   if (!TensorType::isValidElementType(elementType))
318     return emitError(location, "invalid tensor element type: ") << elementType;
319   return success();
320 }
321 
322 /// Return true if the specified element type is ok in a tensor.
isValidElementType(Type type)323 bool TensorType::isValidElementType(Type type) {
324   // Note: Non standard/builtin types are allowed to exist within tensor
325   // types. Dialects are expected to verify that tensor types have a valid
326   // element type within that dialect.
327   return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
328                   IndexType>() ||
329          !type.getDialect().getNamespace().empty();
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // RankedTensorType
334 //===----------------------------------------------------------------------===//
335 
get(ArrayRef<int64_t> shape,Type elementType)336 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
337                                        Type elementType) {
338   return Base::get(elementType.getContext(), shape, elementType);
339 }
340 
getChecked(ArrayRef<int64_t> shape,Type elementType,Location location)341 RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
342                                               Type elementType,
343                                               Location location) {
344   return Base::getChecked(location, shape, elementType);
345 }
346 
verifyConstructionInvariants(Location loc,ArrayRef<int64_t> shape,Type elementType)347 LogicalResult RankedTensorType::verifyConstructionInvariants(
348     Location loc, ArrayRef<int64_t> shape, Type elementType) {
349   for (int64_t s : shape) {
350     if (s < -1)
351       return emitError(loc, "invalid tensor dimension size");
352   }
353   return checkTensorElementType(loc, elementType);
354 }
355 
getShape() const356 ArrayRef<int64_t> RankedTensorType::getShape() const {
357   return getImpl()->getShape();
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // UnrankedTensorType
362 //===----------------------------------------------------------------------===//
363 
get(Type elementType)364 UnrankedTensorType UnrankedTensorType::get(Type elementType) {
365   return Base::get(elementType.getContext(), elementType);
366 }
367 
getChecked(Type elementType,Location location)368 UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
369                                                   Location location) {
370   return Base::getChecked(location, elementType);
371 }
372 
373 LogicalResult
verifyConstructionInvariants(Location loc,Type elementType)374 UnrankedTensorType::verifyConstructionInvariants(Location loc,
375                                                  Type elementType) {
376   return checkTensorElementType(loc, elementType);
377 }
378 
379 //===----------------------------------------------------------------------===//
380 // BaseMemRefType
381 //===----------------------------------------------------------------------===//
382 
getMemorySpace() const383 unsigned BaseMemRefType::getMemorySpace() const {
384   return static_cast<ImplType *>(impl)->memorySpace;
385 }
386 
387 //===----------------------------------------------------------------------===//
388 // MemRefType
389 //===----------------------------------------------------------------------===//
390 
391 /// Get or create a new MemRefType based on shape, element type, affine
392 /// map composition, and memory space.  Assumes the arguments define a
393 /// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
394 /// construction failures.
get(ArrayRef<int64_t> shape,Type elementType,ArrayRef<AffineMap> affineMapComposition,unsigned memorySpace)395 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
396                            ArrayRef<AffineMap> affineMapComposition,
397                            unsigned memorySpace) {
398   auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
399                         /*location=*/llvm::None);
400   assert(result && "Failed to construct instance of MemRefType.");
401   return result;
402 }
403 
404 /// Get or create a new MemRefType based on shape, element type, affine
405 /// map composition, and memory space declared at the given location.
406 /// If the location is unknown, the last argument should be an instance of
407 /// UnknownLoc.  If the MemRefType defined by the arguments would be
408 /// ill-formed, emits errors (to the handler registered with the context or to
409 /// the error stream) and returns nullptr.
getChecked(ArrayRef<int64_t> shape,Type elementType,ArrayRef<AffineMap> affineMapComposition,unsigned memorySpace,Location location)410 MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
411                                   ArrayRef<AffineMap> affineMapComposition,
412                                   unsigned memorySpace, Location location) {
413   return getImpl(shape, elementType, affineMapComposition, memorySpace,
414                  location);
415 }
416 
417 /// Get or create a new MemRefType defined by the arguments.  If the resulting
418 /// type would be ill-formed, return nullptr.  If the location is provided,
419 /// emit detailed error messages.  To emit errors when the location is unknown,
420 /// pass in an instance of UnknownLoc.
getImpl(ArrayRef<int64_t> shape,Type elementType,ArrayRef<AffineMap> affineMapComposition,unsigned memorySpace,Optional<Location> location)421 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
422                                ArrayRef<AffineMap> affineMapComposition,
423                                unsigned memorySpace,
424                                Optional<Location> location) {
425   auto *context = elementType.getContext();
426 
427   if (!BaseMemRefType::isValidElementType(elementType))
428     return emitOptionalError(location, "invalid memref element type"),
429            MemRefType();
430 
431   for (int64_t s : shape) {
432     // Negative sizes are not allowed except for `-1` that means dynamic size.
433     if (s < -1)
434       return emitOptionalError(location, "invalid memref size"), MemRefType();
435   }
436 
437   // Check that the structure of the composition is valid, i.e. that each
438   // subsequent affine map has as many inputs as the previous map has results.
439   // Take the dimensionality of the MemRef for the first map.
440   auto dim = shape.size();
441   unsigned i = 0;
442   for (const auto &affineMap : affineMapComposition) {
443     if (affineMap.getNumDims() != dim) {
444       if (location)
445         emitError(*location)
446             << "memref affine map dimension mismatch between "
447             << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
448             << " and affine map" << i + 1 << ": " << dim
449             << " != " << affineMap.getNumDims();
450       return nullptr;
451     }
452 
453     dim = affineMap.getNumResults();
454     ++i;
455   }
456 
457   // Drop identity maps from the composition.
458   // This may lead to the composition becoming empty, which is interpreted as an
459   // implicit identity.
460   SmallVector<AffineMap, 2> cleanedAffineMapComposition;
461   for (const auto &map : affineMapComposition) {
462     if (map.isIdentity())
463       continue;
464     cleanedAffineMapComposition.push_back(map);
465   }
466 
467   return Base::get(context, shape, elementType, cleanedAffineMapComposition,
468                    memorySpace);
469 }
470 
getShape() const471 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
472 
getAffineMaps() const473 ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
474   return getImpl()->getAffineMaps();
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // UnrankedMemRefType
479 //===----------------------------------------------------------------------===//
480 
get(Type elementType,unsigned memorySpace)481 UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
482                                            unsigned memorySpace) {
483   return Base::get(elementType.getContext(), elementType, memorySpace);
484 }
485 
getChecked(Type elementType,unsigned memorySpace,Location location)486 UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
487                                                   unsigned memorySpace,
488                                                   Location location) {
489   return Base::getChecked(location, elementType, memorySpace);
490 }
491 
492 LogicalResult
verifyConstructionInvariants(Location loc,Type elementType,unsigned memorySpace)493 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
494                                                  unsigned memorySpace) {
495   if (!BaseMemRefType::isValidElementType(elementType))
496     return emitError(loc, "invalid memref element type");
497   return success();
498 }
499 
500 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
501 // i.e. single term). Accumulate the AffineExpr into the existing one.
extractStridesFromTerm(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)502 static void extractStridesFromTerm(AffineExpr e,
503                                    AffineExpr multiplicativeFactor,
504                                    MutableArrayRef<AffineExpr> strides,
505                                    AffineExpr &offset) {
506   if (auto dim = e.dyn_cast<AffineDimExpr>())
507     strides[dim.getPosition()] =
508         strides[dim.getPosition()] + multiplicativeFactor;
509   else
510     offset = offset + e * multiplicativeFactor;
511 }
512 
513 /// Takes a single AffineExpr `e` and populates the `strides` array with the
514 /// strides expressions for each dim position.
515 /// The convention is that the strides for dimensions d0, .. dn appear in
516 /// order to make indexing intuitive into the result.
extractStrides(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)517 static LogicalResult extractStrides(AffineExpr e,
518                                     AffineExpr multiplicativeFactor,
519                                     MutableArrayRef<AffineExpr> strides,
520                                     AffineExpr &offset) {
521   auto bin = e.dyn_cast<AffineBinaryOpExpr>();
522   if (!bin) {
523     extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
524     return success();
525   }
526 
527   if (bin.getKind() == AffineExprKind::CeilDiv ||
528       bin.getKind() == AffineExprKind::FloorDiv ||
529       bin.getKind() == AffineExprKind::Mod)
530     return failure();
531 
532   if (bin.getKind() == AffineExprKind::Mul) {
533     auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
534     if (dim) {
535       strides[dim.getPosition()] =
536           strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
537       return success();
538     }
539     // LHS and RHS may both contain complex expressions of dims. Try one path
540     // and if it fails try the other. This is guaranteed to succeed because
541     // only one path may have a `dim`, otherwise this is not an AffineExpr in
542     // the first place.
543     if (bin.getLHS().isSymbolicOrConstant())
544       return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
545                             strides, offset);
546     return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
547                           strides, offset);
548   }
549 
550   if (bin.getKind() == AffineExprKind::Add) {
551     auto res1 =
552         extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
553     auto res2 =
554         extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
555     return success(succeeded(res1) && succeeded(res2));
556   }
557 
558   llvm_unreachable("unexpected binary operation");
559 }
560 
getStridesAndOffset(MemRefType t,SmallVectorImpl<AffineExpr> & strides,AffineExpr & offset)561 LogicalResult mlir::getStridesAndOffset(MemRefType t,
562                                         SmallVectorImpl<AffineExpr> &strides,
563                                         AffineExpr &offset) {
564   auto affineMaps = t.getAffineMaps();
565   // For now strides are only computed on a single affine map with a single
566   // result (i.e. the closed subset of linearization maps that are compatible
567   // with striding semantics).
568   // TODO: support more forms on a per-need basis.
569   if (affineMaps.size() > 1)
570     return failure();
571   if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
572     return failure();
573 
574   auto zero = getAffineConstantExpr(0, t.getContext());
575   auto one = getAffineConstantExpr(1, t.getContext());
576   offset = zero;
577   strides.assign(t.getRank(), zero);
578 
579   AffineMap m;
580   if (!affineMaps.empty()) {
581     m = affineMaps.front();
582     assert(!m.isIdentity() && "unexpected identity map");
583   }
584 
585   // Canonical case for empty map.
586   if (!m) {
587     // 0-D corner case, offset is already 0.
588     if (t.getRank() == 0)
589       return success();
590     auto stridedExpr =
591         makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
592     if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
593       return success();
594     assert(false && "unexpected failure: extract strides in canonical layout");
595   }
596 
597   // Non-canonical case requires more work.
598   auto stridedExpr =
599       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
600   if (failed(extractStrides(stridedExpr, one, strides, offset))) {
601     offset = AffineExpr();
602     strides.clear();
603     return failure();
604   }
605 
606   // Simplify results to allow folding to constants and simple checks.
607   unsigned numDims = m.getNumDims();
608   unsigned numSymbols = m.getNumSymbols();
609   offset = simplifyAffineExpr(offset, numDims, numSymbols);
610   for (auto &stride : strides)
611     stride = simplifyAffineExpr(stride, numDims, numSymbols);
612 
613   /// In practice, a strided memref must be internally non-aliasing. Test
614   /// against 0 as a proxy.
615   /// TODO: static cases can have more advanced checks.
616   /// TODO: dynamic cases would require a way to compare symbolic
617   /// expressions and would probably need an affine set context propagated
618   /// everywhere.
619   if (llvm::any_of(strides, [](AffineExpr e) {
620         return e == getAffineConstantExpr(0, e.getContext());
621       })) {
622     offset = AffineExpr();
623     strides.clear();
624     return failure();
625   }
626 
627   return success();
628 }
629 
getStridesAndOffset(MemRefType t,SmallVectorImpl<int64_t> & strides,int64_t & offset)630 LogicalResult mlir::getStridesAndOffset(MemRefType t,
631                                         SmallVectorImpl<int64_t> &strides,
632                                         int64_t &offset) {
633   AffineExpr offsetExpr;
634   SmallVector<AffineExpr, 4> strideExprs;
635   if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
636     return failure();
637   if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
638     offset = cst.getValue();
639   else
640     offset = ShapedType::kDynamicStrideOrOffset;
641   for (auto e : strideExprs) {
642     if (auto c = e.dyn_cast<AffineConstantExpr>())
643       strides.push_back(c.getValue());
644     else
645       strides.push_back(ShapedType::kDynamicStrideOrOffset);
646   }
647   return success();
648 }
649 
650 //===----------------------------------------------------------------------===//
// TupleType
652 //===----------------------------------------------------------------------===//
653 
654 /// Get or create a new TupleType with the provided element types. Assumes the
655 /// arguments define a well-formed type.
get(TypeRange elementTypes,MLIRContext * context)656 TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
657   return Base::get(context, elementTypes);
658 }
659 
660 /// Get or create an empty tuple type.
get(MLIRContext * context)661 TupleType TupleType::get(MLIRContext *context) { return get({}, context); }
662 
663 /// Return the elements types for this tuple.
getTypes() const664 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
665 
666 /// Accumulate the types contained in this tuple and tuples nested within it.
667 /// Note that this only flattens nested tuples, not any other container type,
668 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
669 /// (i32, tensor<i32>, f32, i64)
getFlattenedTypes(SmallVectorImpl<Type> & types)670 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
671   for (Type type : getTypes()) {
672     if (auto nestedTuple = type.dyn_cast<TupleType>())
673       nestedTuple.getFlattenedTypes(types);
674     else
675       types.push_back(type);
676   }
677 }
678 
679 /// Return the number of element types.
size() const680 size_t TupleType::size() const { return getImpl()->size(); }
681 
makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,int64_t offset,MLIRContext * context)682 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
683                                            int64_t offset,
684                                            MLIRContext *context) {
685   AffineExpr expr;
686   unsigned nSymbols = 0;
687 
688   // AffineExpr for offset.
689   // Static case.
690   if (offset != MemRefType::getDynamicStrideOrOffset()) {
691     auto cst = getAffineConstantExpr(offset, context);
692     expr = cst;
693   } else {
694     // Dynamic case, new symbol for the offset.
695     auto sym = getAffineSymbolExpr(nSymbols++, context);
696     expr = sym;
697   }
698 
699   // AffineExpr for strides.
700   for (auto en : llvm::enumerate(strides)) {
701     auto dim = en.index();
702     auto stride = en.value();
703     assert(stride != 0 && "Invalid stride specification");
704     auto d = getAffineDimExpr(dim, context);
705     AffineExpr mult;
706     // Static case.
707     if (stride != MemRefType::getDynamicStrideOrOffset())
708       mult = getAffineConstantExpr(stride, context);
709     else
710       // Dynamic case, new symbol for each new stride.
711       mult = getAffineSymbolExpr(nSymbols++, context);
712     expr = expr + d * mult;
713   }
714 
715   return AffineMap::get(strides.size(), nSymbols, expr);
716 }
717 
718 /// Return a version of `t` with identity layout if it can be determined
719 /// statically that the layout is the canonical contiguous strided layout.
720 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
721 /// `t` with simplified layout.
722 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
canonicalizeStridedLayout(MemRefType t)723 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
724   auto affineMaps = t.getAffineMaps();
725   // Already in canonical form.
726   if (affineMaps.empty())
727     return t;
728 
729   // Can't reduce to canonical identity form, return in canonical form.
730   if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
731     return t;
732 
733   // If the canonical strided layout for the sizes of `t` is equal to the
734   // simplified layout of `t` we can just return an empty layout. Otherwise,
735   // just simplify the existing layout.
736   AffineExpr expr =
737       makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
738   auto m = affineMaps[0];
739   auto simplifiedLayoutExpr =
740       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
741   if (expr != simplifiedLayoutExpr)
742     return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
743         m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
744   return MemRefType::Builder(t).setAffineMaps({});
745 }
746 
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,ArrayRef<AffineExpr> exprs,MLIRContext * context)747 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
748                                                 ArrayRef<AffineExpr> exprs,
749                                                 MLIRContext *context) {
750   // Size 0 corner case is useful for canonicalizations.
751   if (llvm::is_contained(sizes, 0))
752     return getAffineConstantExpr(0, context);
753 
754   auto maps = AffineMap::inferFromExprList(exprs);
755   assert(!maps.empty() && "Expected one non-empty map");
756   unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
757 
758   AffineExpr expr;
759   bool dynamicPoisonBit = false;
760   int64_t runningSize = 1;
761   for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
762     int64_t size = std::get<1>(en);
763     // Degenerate case, no size =-> no stride
764     if (size == 0)
765       continue;
766     AffineExpr dimExpr = std::get<0>(en);
767     AffineExpr stride = dynamicPoisonBit
768                             ? getAffineSymbolExpr(nSymbols++, context)
769                             : getAffineConstantExpr(runningSize, context);
770     expr = expr ? expr + dimExpr * stride : dimExpr * stride;
771     if (size > 0)
772       runningSize *= size;
773     else
774       dynamicPoisonBit = true;
775   }
776   return simplifyAffineExpr(expr, numDims, nSymbols);
777 }
778 
779 /// Return a version of `t` with a layout that has all dynamic offset and
780 /// strides. This is used to erase the static layout.
eraseStridedLayout(MemRefType t)781 MemRefType mlir::eraseStridedLayout(MemRefType t) {
782   auto val = ShapedType::kDynamicStrideOrOffset;
783   return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
784       SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
785 }
786 
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,MLIRContext * context)787 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
788                                                 MLIRContext *context) {
789   SmallVector<AffineExpr, 4> exprs;
790   exprs.reserve(sizes.size());
791   for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
792     exprs.push_back(getAffineDimExpr(dim, context));
793   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
794 }
795 
796 /// Return true if the layout for `t` is compatible with strided semantics.
isStrided(MemRefType t)797 bool mlir::isStrided(MemRefType t) {
798   int64_t offset;
799   SmallVector<int64_t, 4> stridesAndOffset;
800   auto res = getStridesAndOffset(t, stridesAndOffset, offset);
801   return succeeded(res);
802 }
803