1 //===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

#include <cassert>
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // Integer types.
22 //===----------------------------------------------------------------------===//
23 
mlirTypeIsAInteger(MlirType type)24 bool mlirTypeIsAInteger(MlirType type) {
25   return unwrap(type).isa<IntegerType>();
26 }
27 
mlirIntegerTypeGet(MlirContext ctx,unsigned bitwidth)28 MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
29   return wrap(IntegerType::get(bitwidth, unwrap(ctx)));
30 }
31 
mlirIntegerTypeSignedGet(MlirContext ctx,unsigned bitwidth)32 MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
33   return wrap(IntegerType::get(bitwidth, IntegerType::Signed, unwrap(ctx)));
34 }
35 
mlirIntegerTypeUnsignedGet(MlirContext ctx,unsigned bitwidth)36 MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
37   return wrap(IntegerType::get(bitwidth, IntegerType::Unsigned, unwrap(ctx)));
38 }
39 
mlirIntegerTypeGetWidth(MlirType type)40 unsigned mlirIntegerTypeGetWidth(MlirType type) {
41   return unwrap(type).cast<IntegerType>().getWidth();
42 }
43 
mlirIntegerTypeIsSignless(MlirType type)44 bool mlirIntegerTypeIsSignless(MlirType type) {
45   return unwrap(type).cast<IntegerType>().isSignless();
46 }
47 
mlirIntegerTypeIsSigned(MlirType type)48 bool mlirIntegerTypeIsSigned(MlirType type) {
49   return unwrap(type).cast<IntegerType>().isSigned();
50 }
51 
mlirIntegerTypeIsUnsigned(MlirType type)52 bool mlirIntegerTypeIsUnsigned(MlirType type) {
53   return unwrap(type).cast<IntegerType>().isUnsigned();
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // Index type.
58 //===----------------------------------------------------------------------===//
59 
mlirTypeIsAIndex(MlirType type)60 bool mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa<IndexType>(); }
61 
mlirIndexTypeGet(MlirContext ctx)62 MlirType mlirIndexTypeGet(MlirContext ctx) {
63   return wrap(IndexType::get(unwrap(ctx)));
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // Floating-point types.
68 //===----------------------------------------------------------------------===//
69 
mlirTypeIsABF16(MlirType type)70 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
71 
mlirBF16TypeGet(MlirContext ctx)72 MlirType mlirBF16TypeGet(MlirContext ctx) {
73   return wrap(FloatType::getBF16(unwrap(ctx)));
74 }
75 
mlirTypeIsAF16(MlirType type)76 bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
77 
mlirF16TypeGet(MlirContext ctx)78 MlirType mlirF16TypeGet(MlirContext ctx) {
79   return wrap(FloatType::getF16(unwrap(ctx)));
80 }
81 
mlirTypeIsAF32(MlirType type)82 bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
83 
mlirF32TypeGet(MlirContext ctx)84 MlirType mlirF32TypeGet(MlirContext ctx) {
85   return wrap(FloatType::getF32(unwrap(ctx)));
86 }
87 
mlirTypeIsAF64(MlirType type)88 bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
89 
mlirF64TypeGet(MlirContext ctx)90 MlirType mlirF64TypeGet(MlirContext ctx) {
91   return wrap(FloatType::getF64(unwrap(ctx)));
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // None type.
96 //===----------------------------------------------------------------------===//
97 
mlirTypeIsANone(MlirType type)98 bool mlirTypeIsANone(MlirType type) { return unwrap(type).isa<NoneType>(); }
99 
mlirNoneTypeGet(MlirContext ctx)100 MlirType mlirNoneTypeGet(MlirContext ctx) {
101   return wrap(NoneType::get(unwrap(ctx)));
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // Complex type.
106 //===----------------------------------------------------------------------===//
107 
mlirTypeIsAComplex(MlirType type)108 bool mlirTypeIsAComplex(MlirType type) {
109   return unwrap(type).isa<ComplexType>();
110 }
111 
mlirComplexTypeGet(MlirType elementType)112 MlirType mlirComplexTypeGet(MlirType elementType) {
113   return wrap(ComplexType::get(unwrap(elementType)));
114 }
115 
mlirComplexTypeGetElementType(MlirType type)116 MlirType mlirComplexTypeGetElementType(MlirType type) {
117   return wrap(unwrap(type).cast<ComplexType>().getElementType());
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // Shaped type.
122 //===----------------------------------------------------------------------===//
123 
mlirTypeIsAShaped(MlirType type)124 bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa<ShapedType>(); }
125 
mlirShapedTypeGetElementType(MlirType type)126 MlirType mlirShapedTypeGetElementType(MlirType type) {
127   return wrap(unwrap(type).cast<ShapedType>().getElementType());
128 }
129 
mlirShapedTypeHasRank(MlirType type)130 bool mlirShapedTypeHasRank(MlirType type) {
131   return unwrap(type).cast<ShapedType>().hasRank();
132 }
133 
mlirShapedTypeGetRank(MlirType type)134 int64_t mlirShapedTypeGetRank(MlirType type) {
135   return unwrap(type).cast<ShapedType>().getRank();
136 }
137 
mlirShapedTypeHasStaticShape(MlirType type)138 bool mlirShapedTypeHasStaticShape(MlirType type) {
139   return unwrap(type).cast<ShapedType>().hasStaticShape();
140 }
141 
mlirShapedTypeIsDynamicDim(MlirType type,intptr_t dim)142 bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
143   return unwrap(type).cast<ShapedType>().isDynamicDim(
144       static_cast<unsigned>(dim));
145 }
146 
mlirShapedTypeGetDimSize(MlirType type,intptr_t dim)147 int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
148   return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
149 }
150 
mlirShapedTypeIsDynamicSize(int64_t size)151 bool mlirShapedTypeIsDynamicSize(int64_t size) {
152   return ShapedType::isDynamic(size);
153 }
154 
mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)155 bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
156   return ShapedType::isDynamicStrideOrOffset(val);
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // Vector type.
161 //===----------------------------------------------------------------------===//
162 
mlirTypeIsAVector(MlirType type)163 bool mlirTypeIsAVector(MlirType type) { return unwrap(type).isa<VectorType>(); }
164 
mlirVectorTypeGet(intptr_t rank,const int64_t * shape,MlirType elementType)165 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
166                            MlirType elementType) {
167   return wrap(
168       VectorType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
169                       unwrap(elementType)));
170 }
171 
mlirVectorTypeGetChecked(intptr_t rank,const int64_t * shape,MlirType elementType,MlirLocation loc)172 MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape,
173                                   MlirType elementType, MlirLocation loc) {
174   return wrap(VectorType::getChecked(
175       llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
176       unwrap(loc)));
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // Ranked / Unranked tensor type.
181 //===----------------------------------------------------------------------===//
182 
mlirTypeIsATensor(MlirType type)183 bool mlirTypeIsATensor(MlirType type) { return unwrap(type).isa<TensorType>(); }
184 
mlirTypeIsARankedTensor(MlirType type)185 bool mlirTypeIsARankedTensor(MlirType type) {
186   return unwrap(type).isa<RankedTensorType>();
187 }
188 
mlirTypeIsAUnrankedTensor(MlirType type)189 bool mlirTypeIsAUnrankedTensor(MlirType type) {
190   return unwrap(type).isa<UnrankedTensorType>();
191 }
192 
mlirRankedTensorTypeGet(intptr_t rank,const int64_t * shape,MlirType elementType)193 MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
194                                  MlirType elementType) {
195   return wrap(RankedTensorType::get(
196       llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
197       unwrap(elementType)));
198 }
199 
mlirRankedTensorTypeGetChecked(intptr_t rank,const int64_t * shape,MlirType elementType,MlirLocation loc)200 MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, const int64_t *shape,
201                                         MlirType elementType,
202                                         MlirLocation loc) {
203   return wrap(RankedTensorType::getChecked(
204       llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
205       unwrap(loc)));
206 }
207 
mlirUnrankedTensorTypeGet(MlirType elementType)208 MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
209   return wrap(UnrankedTensorType::get(unwrap(elementType)));
210 }
211 
mlirUnrankedTensorTypeGetChecked(MlirType elementType,MlirLocation loc)212 MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
213                                           MlirLocation loc) {
214   return wrap(UnrankedTensorType::getChecked(unwrap(elementType), unwrap(loc)));
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // Ranked / Unranked MemRef type.
219 //===----------------------------------------------------------------------===//
220 
mlirTypeIsAMemRef(MlirType type)221 bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
222 
mlirMemRefTypeGet(MlirType elementType,intptr_t rank,const int64_t * shape,intptr_t numMaps,MlirAffineMap const * affineMaps,unsigned memorySpace)223 MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
224                            const int64_t *shape, intptr_t numMaps,
225                            MlirAffineMap const *affineMaps,
226                            unsigned memorySpace) {
227   SmallVector<AffineMap, 1> maps;
228   (void)unwrapList(numMaps, affineMaps, maps);
229   return wrap(
230       MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
231                       unwrap(elementType), maps, memorySpace));
232 }
233 
mlirMemRefTypeContiguousGet(MlirType elementType,intptr_t rank,const int64_t * shape,unsigned memorySpace)234 MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
235                                      const int64_t *shape,
236                                      unsigned memorySpace) {
237   return wrap(
238       MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
239                       unwrap(elementType), llvm::None, memorySpace));
240 }
241 
mlirMemRefTypeContiguousGetChecked(MlirType elementType,intptr_t rank,const int64_t * shape,unsigned memorySpace,MlirLocation loc)242 MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
243                                             const int64_t *shape,
244                                             unsigned memorySpace,
245                                             MlirLocation loc) {
246   return wrap(MemRefType::getChecked(
247       llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
248       llvm::None, memorySpace, unwrap(loc)));
249 }
250 
mlirMemRefTypeGetNumAffineMaps(MlirType type)251 intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
252   return static_cast<intptr_t>(
253       unwrap(type).cast<MemRefType>().getAffineMaps().size());
254 }
255 
mlirMemRefTypeGetAffineMap(MlirType type,intptr_t pos)256 MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
257   return wrap(unwrap(type).cast<MemRefType>().getAffineMaps()[pos]);
258 }
259 
mlirMemRefTypeGetMemorySpace(MlirType type)260 unsigned mlirMemRefTypeGetMemorySpace(MlirType type) {
261   return unwrap(type).cast<MemRefType>().getMemorySpace();
262 }
263 
mlirTypeIsAUnrankedMemRef(MlirType type)264 bool mlirTypeIsAUnrankedMemRef(MlirType type) {
265   return unwrap(type).isa<UnrankedMemRefType>();
266 }
267 
mlirUnrankedMemRefTypeGet(MlirType elementType,unsigned memorySpace)268 MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
269   return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace));
270 }
271 
mlirUnrankedMemRefTypeGetChecked(MlirType elementType,unsigned memorySpace,MlirLocation loc)272 MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
273                                           unsigned memorySpace,
274                                           MlirLocation loc) {
275   return wrap(UnrankedMemRefType::getChecked(unwrap(elementType), memorySpace,
276                                              unwrap(loc)));
277 }
278 
mlirUnrankedMemrefGetMemorySpace(MlirType type)279 unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
280   return unwrap(type).cast<UnrankedMemRefType>().getMemorySpace();
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // Tuple type.
285 //===----------------------------------------------------------------------===//
286 
mlirTypeIsATuple(MlirType type)287 bool mlirTypeIsATuple(MlirType type) { return unwrap(type).isa<TupleType>(); }
288 
mlirTupleTypeGet(MlirContext ctx,intptr_t numElements,MlirType const * elements)289 MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
290                           MlirType const *elements) {
291   SmallVector<Type, 4> types;
292   ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
293   return wrap(TupleType::get(typeRef, unwrap(ctx)));
294 }
295 
mlirTupleTypeGetNumTypes(MlirType type)296 intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
297   return unwrap(type).cast<TupleType>().size();
298 }
299 
mlirTupleTypeGetType(MlirType type,intptr_t pos)300 MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
301   return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // Function type.
306 //===----------------------------------------------------------------------===//
307 
mlirTypeIsAFunction(MlirType type)308 bool mlirTypeIsAFunction(MlirType type) {
309   return unwrap(type).isa<FunctionType>();
310 }
311 
mlirFunctionTypeGet(MlirContext ctx,intptr_t numInputs,MlirType const * inputs,intptr_t numResults,MlirType const * results)312 MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
313                              MlirType const *inputs, intptr_t numResults,
314                              MlirType const *results) {
315   SmallVector<Type, 4> inputsList;
316   SmallVector<Type, 4> resultsList;
317   (void)unwrapList(numInputs, inputs, inputsList);
318   (void)unwrapList(numResults, results, resultsList);
319   return wrap(FunctionType::get(inputsList, resultsList, unwrap(ctx)));
320 }
321 
mlirFunctionTypeGetNumInputs(MlirType type)322 intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
323   return unwrap(type).cast<FunctionType>().getNumInputs();
324 }
325 
mlirFunctionTypeGetNumResults(MlirType type)326 intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
327   return unwrap(type).cast<FunctionType>().getNumResults();
328 }
329 
mlirFunctionTypeGetInput(MlirType type,intptr_t pos)330 MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
331   assert(pos >= 0 && "pos in array must be positive");
332   return wrap(
333       unwrap(type).cast<FunctionType>().getInput(static_cast<unsigned>(pos)));
334 }
335 
mlirFunctionTypeGetResult(MlirType type,intptr_t pos)336 MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
337   assert(pos >= 0 && "pos in array must be positive");
338   return wrap(
339       unwrap(type).cast<FunctionType>().getResult(static_cast<unsigned>(pos)));
340 }
341