//===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir-c/BuiltinAttributes.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" using namespace mlir; //===----------------------------------------------------------------------===// // Affine map attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAAffineMap(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { return wrap(AffineMapAttr::get(unwrap(map))); } MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { return wrap(unwrap(attr).cast().getValue()); } //===----------------------------------------------------------------------===// // Array attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAArray(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, MlirAttribute const *elements) { SmallVector attrs; return wrap(ArrayAttr::get( unwrapList(static_cast(numElements), elements, attrs), unwrap(ctx))); } intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { return static_cast(unwrap(attr).cast().size()); } MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { return wrap(unwrap(attr).cast().getValue()[pos]); } //===----------------------------------------------------------------------===// // Dictionary attribute. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsADictionary(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, MlirNamedAttribute const *elements) { SmallVector attributes; attributes.reserve(numElements); for (intptr_t i = 0; i < numElements; ++i) attributes.emplace_back( Identifier::get(unwrap(elements[i].name), unwrap(ctx)), unwrap(elements[i].attribute)); return wrap(DictionaryAttr::get(attributes, unwrap(ctx))); } intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { return static_cast(unwrap(attr).cast().size()); } MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos) { NamedAttribute attribute = unwrap(attr).cast().getValue()[pos]; return {wrap(attribute.first.strref()), wrap(attribute.second)}; } MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name) { return wrap(unwrap(attr).cast().get(unwrap(name))); } //===----------------------------------------------------------------------===// // Floating point attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAFloat(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, double value) { return wrap(FloatAttr::get(unwrap(type), value)); } MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value, MlirLocation loc) { return wrap(FloatAttr::getChecked(unwrap(type), value, unwrap(loc))); } double mlirFloatAttrGetValueDouble(MlirAttribute attr) { return unwrap(attr).cast().getValueAsDouble(); } //===----------------------------------------------------------------------===// // Integer attribute. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsAInteger(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { return wrap(IntegerAttr::get(unwrap(type), value)); } int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { return unwrap(attr).cast().getInt(); } //===----------------------------------------------------------------------===// // Bool attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsABool(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { return wrap(BoolAttr::get(value, unwrap(ctx))); } bool mlirBoolAttrGetValue(MlirAttribute attr) { return unwrap(attr).cast().getValue(); } //===----------------------------------------------------------------------===// // Integer set attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { return unwrap(attr).isa(); } //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAOpaque(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, intptr_t dataLength, const char *data, MlirType type) { return wrap( OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), StringRef(data, dataLength), unwrap(type), unwrap(ctx))); } MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { return wrap(unwrap(attr).cast().getDialectNamespace().strref()); } MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { return wrap(unwrap(attr).cast().getAttrData()); } //===----------------------------------------------------------------------===// // String attribute. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsAString(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { return wrap(StringAttr::get(unwrap(str), unwrap(ctx))); } MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { return wrap(StringAttr::get(unwrap(str), unwrap(type))); } MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { return wrap(unwrap(attr).cast().getValue()); } //===----------------------------------------------------------------------===// // SymbolRef attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsASymbolRef(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, intptr_t numReferences, MlirAttribute const *references) { SmallVector refs; refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) refs.push_back(unwrap(references[i]).cast()); return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx))); } MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { return wrap(unwrap(attr).cast().getRootReference()); } MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { return wrap(unwrap(attr).cast().getLeafReference()); } intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { return static_cast( unwrap(attr).cast().getNestedReferences().size()); } MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos) { return wrap(unwrap(attr).cast().getNestedReferences()[pos]); } //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx))); } MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { return wrap(unwrap(attr).cast().getValue()); } //===----------------------------------------------------------------------===// // Type attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAType(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirTypeAttrGet(MlirType type) { return wrap(TypeAttr::get(unwrap(type))); } MlirType mlirTypeAttrGetValue(MlirAttribute attr) { return wrap(unwrap(attr).cast().getValue()); } //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAUnit(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirUnitAttrGet(MlirContext ctx) { return wrap(UnitAttr::get(unwrap(ctx))); } //===----------------------------------------------------------------------===// // Elements attributes. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsAElements(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { return wrap(unwrap(attr).cast().getValue( llvm::makeArrayRef(idxs, rank))); } bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { return unwrap(attr).cast().isValidIndex( llvm::makeArrayRef(idxs, rank)); } int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { return unwrap(attr).cast().getNumElements(); } //===----------------------------------------------------------------------===// // Dense elements attribute. //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // IsA support. bool mlirAttributeIsADenseElements(MlirAttribute attr) { return unwrap(attr).isa(); } bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { return unwrap(attr).isa(); } bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { return unwrap(attr).isa(); } //===----------------------------------------------------------------------===// // Constructors. 
/// Creates a DenseElementsAttr of `shapedType` from the `numElements`
/// attributes pointed to by `elements`.
MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
                                       intptr_t numElements,
                                       MlirAttribute const *elements) {
  // Scratch storage that unwrapList fills with the unwrapped C++ attributes.
  SmallVector<Attribute, 8> attributes;
  return wrap(DenseElementsAttr::get(
      unwrap(shapedType).cast<ShapedType>(),
      unwrapList(static_cast<size_t>(numElements), elements, attributes)));
}

/// Creates a splat DenseElementsAttr of `shapedType` whose single repeated
/// value is the attribute `element`.
MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
                                            MlirAttribute element) {
  return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
                                     unwrap(element)));
}

/// Creates a splat DenseElementsAttr of `shapedType` from the scalar
/// `element`. The DenseElementsAttr::get overload used is chosen by the C++
/// type `T`.
template <typename T>
static MlirAttribute getDenseSplatAttribute(MlirType shapedType, T element) {
  return wrap(
      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
}

// Typed splat constructors: each forwards its scalar to getDenseSplatAttribute
// with its exact C type.

MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
                                                bool element) {
  return getDenseSplatAttribute(shapedType, element);
}

MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
                                                  uint32_t element) {
  return getDenseSplatAttribute(shapedType, element);
}

MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
                                                 int32_t element) {
  return getDenseSplatAttribute(shapedType, element);
}

MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
                                                  uint64_t element) {
  return getDenseSplatAttribute(shapedType, element);
}

MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
                                                 int64_t element) {
  return getDenseSplatAttribute(shapedType, element);
}

MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
                                                 float element) {
  return getDenseSplatAttribute(shapedType, element);
}

MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
                                                  double element) {
  return getDenseSplatAttribute(shapedType, element);
}

/// Creates a boolean DenseElementsAttr of `shapedType` from `numElements`
/// C-style integers.
MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
                                           intptr_t numElements,
                                           const int *elements) {
  // The C API passes booleans as ints; copy them into bool storage before
  // building the attribute.
  SmallVector<bool, 8> values(elements, elements + numElements);
  return wrap(
      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
}

/// Creates a DenseElementsAttr of `shapedType` from the `numElements` values
/// at `elements`; the element C++ type T is deduced from the pointer type.
template <typename T>
static MlirAttribute getDenseAttribute(MlirType shapedType,
                                       intptr_t numElements,
                                       const T *elements) {
  // `elements` must point to at least `numElements` values of type T.
  return wrap(
      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
                             llvm::makeArrayRef(elements, numElements)));
}

// Typed dense constructors: each reads `numElements` values of its C type from
// `elements` and builds a DenseElementsAttr of `shapedType`.

MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
                                             intptr_t numElements,
                                             const uint32_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}

MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
                                            intptr_t numElements,
                                            const int32_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}

MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
                                             intptr_t numElements,
                                             const uint64_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}

MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
                                            intptr_t numElements,
                                            const int64_t *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}

MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
                                            intptr_t numElements,
                                            const float *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}

MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
                                             intptr_t numElements,
                                             const double *elements) {
  return getDenseAttribute(shapedType, numElements, elements);
}

/// Creates a string DenseElementsAttr of `shapedType` from the `numElements`
/// C string references in `strs`.
MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
                                             intptr_t numElements,
                                             MlirStringRef *strs) {
  SmallVector<StringRef, 8> values;
  values.reserve(numElements);
  for (intptr_t i = 0; i < numElements; ++i)
    values.push_back(unwrap(strs[i]));
  return wrap(
      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
}

/// Returns the result of DenseElementsAttr::reshape on `attr` (a
/// DenseElementsAttr) with the new type `shapedType`.
MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
                                              MlirType shapedType) {
  return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape(
      unwrap(shapedType).cast<ShapedType>()));
}

//===----------------------------------------------------------------------===//
// Splat accessors.
bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { return unwrap(attr).cast().isSplat(); } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { return wrap(unwrap(attr).cast().getSplatValue()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { return wrap( unwrap(attr).cast().getSplatValue()); } //===----------------------------------------------------------------------===// // Indexed accessors. 
bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { return *( unwrap(attr).cast().getValues().begin() + pos); } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { return *( unwrap(attr).cast().getValues().begin() + pos); } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( *(unwrap(attr).cast().getValues().begin() + pos)); } //===----------------------------------------------------------------------===// // Raw data accessors. const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { return static_cast( unwrap(attr).cast().getRawData().data()); } //===----------------------------------------------------------------------===// // Opaque elements attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) { return unwrap(attr).isa(); } //===----------------------------------------------------------------------===// // Sparse elements attribute. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsASparseElements(MlirAttribute attr) { return unwrap(attr).isa(); } MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, MlirAttribute denseIndices, MlirAttribute denseValues) { return wrap( SparseElementsAttr::get(unwrap(shapedType).cast(), unwrap(denseIndices).cast(), unwrap(denseValues).cast())); } MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { return wrap(unwrap(attr).cast().getIndices()); } MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { return wrap(unwrap(attr).cast().getValues()); }