//===- StandardToLLVM.cpp - Standard to LLVM dialect conversion -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to convert MLIR standard and builtin dialects // into the LLVM IR dialect. // //===----------------------------------------------------------------------===// #include "../PassDetail.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include using namespace mlir; #define PASS_NAME "convert-std-to-llvm" // Extract an LLVM IR type from the LLVM IR dialect type. static LLVM::LLVMType unwrap(Type type) { if (!type) return nullptr; auto *mlirContext = type.getContext(); auto wrappedLLVMType = type.dyn_cast(); if (!wrappedLLVMType) emitError(UnknownLoc::get(mlirContext), "conversion resulted in a non-LLVM type"); return wrappedLLVMType; } /// Callback to convert function argument types. 
It converts a MemRef function /// argument to a list of non-aggregate types containing descriptor /// information, and an UnrankedmemRef function argument to a list containing /// the rank and a pointer to a descriptor struct. LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result) { if (auto memref = type.dyn_cast()) { // In signatures, Memref descriptors are expanded into lists of // non-aggregate values. auto converted = converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true); if (converted.empty()) return failure(); result.append(converted.begin(), converted.end()); return success(); } if (type.isa()) { auto converted = converter.getUnrankedMemRefDescriptorFields(); if (converted.empty()) return failure(); result.append(converted.begin(), converted.end()); return success(); } auto converted = converter.convertType(type); if (!converted) return failure(); result.push_back(converted); return success(); } /// Callback to convert function argument types. It converts MemRef function /// arguments to bare pointers to the MemRef element type. LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result) { auto llvmTy = converter.convertCallingConventionType(type); if (!llvmTy) return failure(); result.push_back(llvmTy); return success(); } /// Create an LLVMTypeConverter using default LowerToLLVMOptions. LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) : LLVMTypeConverter(ctx, LowerToLLVMOptions::getDefaultOptions()) {} /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. 
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options) : llvmDialect(ctx->getOrLoadDialect()), options(options) { assert(llvmDialect && "LLVM IR dialect is not registered"); if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout) this->options.indexBitwidth = options.dataLayout.getPointerSizeInBits(); // Register conversions for the builtin types. addConversion([&](ComplexType type) { return convertComplexType(type); }); addConversion([&](FloatType type) { return convertFloatType(type); }); addConversion([&](FunctionType type) { return convertFunctionType(type); }); addConversion([&](IndexType type) { return convertIndexType(type); }); addConversion([&](IntegerType type) { return convertIntegerType(type); }); addConversion([&](MemRefType type) { return convertMemRefType(type); }); addConversion( [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); addConversion([&](VectorType type) { return convertVectorType(type); }); // LLVMType is legal, so add a pass-through conversion. addConversion([](LLVM::LLVMType type) { return type; }); // Materialization for memrefs creates descriptor structs from individual // values constituting them, when descriptors are used, i.e. more than one // value represents a memref. addArgumentMaterialization( [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc) -> Optional { if (inputs.size() == 1) return llvm::None; return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs); }); addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc) -> Optional { if (inputs.size() == 1) return llvm::None; return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs); }); // Add generic source and target materializations to handle cases where // non-LLVM types persist after an LLVM conversion. 
addSourceMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Optional { if (inputs.size() != 1) return llvm::None; // FIXME: These should check LLVM::DialectCastOp can actually be constructed // from the input and result. return builder.create(loc, resultType, inputs[0]) .getResult(); }); addTargetMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Optional { if (inputs.size() != 1) return llvm::None; // FIXME: These should check LLVM::DialectCastOp can actually be constructed // from the input and result. return builder.create(loc, resultType, inputs[0]) .getResult(); }); } /// Returns the MLIR context. MLIRContext &LLVMTypeConverter::getContext() { return *getDialect()->getContext(); } LLVM::LLVMType LLVMTypeConverter::getIndexType() { return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth()); } unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { return options.dataLayout.getPointerSizeInBits(addressSpace); } Type LLVMTypeConverter::convertIndexType(IndexType type) { return getIndexType(); } Type LLVMTypeConverter::convertIntegerType(IntegerType type) { return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth()); } Type LLVMTypeConverter::convertFloatType(FloatType type) { if (type.isa()) return LLVM::LLVMType::getFloatTy(&getContext()); if (type.isa()) return LLVM::LLVMType::getDoubleTy(&getContext()); if (type.isa()) return LLVM::LLVMType::getHalfTy(&getContext()); if (type.isa()) return LLVM::LLVMType::getBFloatTy(&getContext()); llvm_unreachable("non-float type in convertFloatType"); } // Convert a `ComplexType` to an LLVM type. The result is a complex number // struct with entries for the // 1. real part and for the // 2. imaginary part. 
static constexpr unsigned kRealPosInComplexNumberStruct = 0; static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; Type LLVMTypeConverter::convertComplexType(ComplexType type) { auto elementType = convertType(type.getElementType()).cast(); return LLVM::LLVMType::getStructTy(&getContext(), {elementType, elementType}); } // Except for signatures, MLIR function types are converted into LLVM // pointer-to-function types. Type LLVMTypeConverter::convertFunctionType(FunctionType type) { SignatureConversion conversion(type.getNumInputs()); LLVM::LLVMType converted = convertFunctionSignature(type, /*isVariadic=*/false, conversion); return converted.getPointerTo(); } // Function types are converted to LLVM Function types by recursively converting // argument and result types. If MLIR Function has zero results, the LLVM // Function has one VoidType result. If MLIR Function has more than one result, // they are into an LLVM StructType in their order of appearance. LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature( FunctionType funcTy, bool isVariadic, LLVMTypeConverter::SignatureConversion &result) { // Select the argument converter depending on the calling convention. auto funcArgConverter = options.useBarePtrCallConv ? barePtrFuncArgTypeConverter : structFuncArgTypeConverter; // Convert argument types one by one and check for errors. for (auto &en : llvm::enumerate(funcTy.getInputs())) { Type type = en.value(); SmallVector converted; if (failed(funcArgConverter(*this, type, converted))) return {}; result.addInputs(en.index(), converted); } SmallVector argTypes; argTypes.reserve(llvm::size(result.getConvertedTypes())); for (Type type : result.getConvertedTypes()) argTypes.push_back(unwrap(type)); // If function does not return anything, create the void result type, // if it returns on element, convert it, otherwise pack the result types into // a struct. LLVM::LLVMType resultType = funcTy.getNumResults() == 0 ? 
LLVM::LLVMType::getVoidTy(&getContext()) : unwrap(packFunctionResults(funcTy.getResults())); if (!resultType) return {}; return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic); } /// Converts the function type to a C-compatible format, in particular using /// pointers to memref descriptors for arguments. LLVM::LLVMType LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) { SmallVector inputs; for (Type t : type.getInputs()) { auto converted = convertType(t).dyn_cast_or_null(); if (!converted) return {}; if (t.isa()) converted = converted.getPointerTo(); inputs.push_back(converted); } LLVM::LLVMType resultType = type.getNumResults() == 0 ? LLVM::LLVMType::getVoidTy(&getContext()) : unwrap(packFunctionResults(type.getResults())); if (!resultType) return {}; return LLVM::LLVMType::getFunctionTy(resultType, inputs, false); } static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0; static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1; static constexpr unsigned kOffsetPosInMemRefDescriptor = 2; static constexpr unsigned kSizePosInMemRefDescriptor = 3; static constexpr unsigned kStridePosInMemRefDescriptor = 4; /// Convert a memref type into a list of LLVM IR types that will form the /// memref descriptor. The result contains the following types: /// 1. The pointer to the allocated data buffer, followed by /// 2. The pointer to the aligned data buffer, followed by /// 3. A lowered `index`-type integer containing the distance between the /// beginning of the buffer and the first element to be accessed through the /// view, followed by /// 4. An array containing as many `index`-type integers as the rank of the /// MemRef: the array represents the size, in number of elements, of the memref /// along the given dimension. For constant MemRef dimensions, the /// corresponding size entry is a constant whose runtime value must match the /// static value, followed by /// 5. 
A second array containing as many `index`-type integers as the rank of /// the MemRef: the second array represents the "stride" (in tensor abstraction /// sense), i.e. the number of consecutive elements of the underlying buffer. /// TODO: add assertions for the static cases. /// /// If `unpackAggregates` is set to true, the arrays described in (4) and (5) /// are expanded into individual index-type elements. /// /// template /// struct { /// Elem *allocatedPtr; /// Elem *alignedPtr; /// Index offset; /// Index sizes[Rank]; // omitted when rank == 0 /// Index strides[Rank]; // omitted when rank == 0 /// }; SmallVector LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type, bool unpackAggregates) { assert(isStrided(type) && "Non-strided layout maps must have been normalized away"); LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; auto ptrTy = elementType.getPointerTo(type.getMemorySpace()); auto indexTy = getIndexType(); SmallVector results = {ptrTy, ptrTy, indexTy}; auto rank = type.getRank(); if (rank == 0) return results; if (unpackAggregates) results.insert(results.end(), 2 * rank, indexTy); else results.insert(results.end(), 2, LLVM::LLVMType::getArrayTy(indexTy, rank)); return results; } /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that /// packs the descriptor fields as defined by `getMemRefDescriptorFields`. Type LLVMTypeConverter::convertMemRefType(MemRefType type) { // When converting a MemRefType to a struct with descriptor fields, do not // unpack the `sizes` and `strides` arrays. 
SmallVector types = getMemRefDescriptorFields(type, /*unpackAggregates=*/false); return LLVM::LLVMType::getStructTy(&getContext(), types); } static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0; static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1; /// Convert an unranked memref type into a list of non-aggregate LLVM IR types /// that will form the unranked memref descriptor. In particular, the fields /// for an unranked memref descriptor are: /// 1. index-typed rank, the dynamic rank of this MemRef /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be /// stack allocated (alloca) copy of a MemRef descriptor that got casted to /// be unranked. SmallVector LLVMTypeConverter::getUnrankedMemRefDescriptorFields() { return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())}; } Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { return LLVM::LLVMType::getStructTy(&getContext(), getUnrankedMemRefDescriptorFields()); } /// Convert a memref type to a bare pointer to the memref element type. Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) { if (type.isa()) // Unranked memref is not supported in the bare pointer calling convention. return {}; // Check that the memref has static shape, strides and offset. Otherwise, it // cannot be lowered to a bare pointer. auto memrefTy = type.cast(); if (!memrefTy.hasStaticShape()) return {}; int64_t offset = 0; SmallVector strides; if (failed(getStridesAndOffset(memrefTy, strides, offset))) return {}; for (int64_t stride : strides) if (ShapedType::isDynamicStrideOrOffset(stride)) return {}; if (ShapedType::isDynamicStrideOrOffset(offset)) return {}; LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; return elementType.getPointerTo(type.getMemorySpace()); } // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when // n > 1. 
// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and // `vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`. Type LLVMTypeConverter::convertVectorType(VectorType type) { auto elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; auto vectorType = LLVM::LLVMType::getVectorTy(elementType, type.getShape().back()); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]); return vectorType; } /// Convert a type in the context of the default or bare pointer calling /// convention. Calling convention sensitive types, such as MemRefType and /// UnrankedMemRefType, are converted following the specific rules for the /// calling convention. Calling convention independent types are converted /// following the default LLVM type conversions. Type LLVMTypeConverter::convertCallingConventionType(Type type) { if (options.useBarePtrCallConv) if (auto memrefTy = type.dyn_cast()) return convertMemRefToBarePtr(memrefTy); return convertType(type); } /// Promote the bare pointers in 'values' that resulted from memrefs to /// descriptors. 'stdTypes' holds they types of 'values' before the conversion /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). 
void LLVMTypeConverter::promoteBarePtrsToDescriptors( ConversionPatternRewriter &rewriter, Location loc, ArrayRef stdTypes, SmallVectorImpl &values) { assert(stdTypes.size() == values.size() && "The number of types and values doesn't match"); for (unsigned i = 0, end = values.size(); i < end; ++i) if (auto memrefTy = stdTypes[i].dyn_cast()) values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, memrefTy, values[i]); } ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, PatternBenefit benefit) : ConversionPattern(rootOpName, benefit, typeConverter, context) {} //===----------------------------------------------------------------------===// // StructBuilder implementation //===----------------------------------------------------------------------===// StructBuilder::StructBuilder(Value v) : value(v) { assert(value != nullptr && "value cannot be null"); structType = value.getType().dyn_cast(); assert(structType && "expected llvm type"); } Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, unsigned pos) { Type type = structType.cast().getStructElementType(pos); return builder.create(loc, type, value, builder.getI64ArrayAttr(pos)); } void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr) { value = builder.create(loc, structType, value, ptr, builder.getI64ArrayAttr(pos)); } //===----------------------------------------------------------------------===// // ComplexStructBuilder implementation //===----------------------------------------------------------------------===// ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, Location loc, Type type) { Value val = builder.create(loc, type.cast()); return ComplexStructBuilder(val); } void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, Value real) { setPtr(builder, loc, kRealPosInComplexNumberStruct, real); } Value ComplexStructBuilder::real(OpBuilder &builder, 
Location loc) { return extractPtr(builder, loc, kRealPosInComplexNumberStruct); } void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, Value imaginary) { setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); } Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); } //===----------------------------------------------------------------------===// // MemRefDescriptor implementation //===----------------------------------------------------------------------===// /// Construct a helper for the given descriptor value. MemRefDescriptor::MemRefDescriptor(Value descriptor) : StructBuilder(descriptor) { assert(value != nullptr && "value cannot be null"); indexType = value.getType().cast().getStructElementType( kOffsetPosInMemRefDescriptor); } /// Builds IR creating an `undef` value of the descriptor type. MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { Value descriptor = builder.create(loc, descriptorType.cast()); return MemRefDescriptor(descriptor); } /// Builds IR creating a MemRef descriptor that represents `type` and /// populates it with static shape and stride information extracted from the /// type. MemRefDescriptor MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory) { assert(type.hasStaticShape() && "unexpected dynamic shape"); // Extract all strides and offsets and verify they are static. 
int64_t offset; SmallVector strides; auto result = getStridesAndOffset(type, strides, offset); (void)result; assert(succeeded(result) && "unexpected failure in stride computation"); assert(offset != MemRefType::getDynamicStrideOrOffset() && "expected static offset"); assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) && "expected static strides"); auto convertedType = typeConverter.convertType(type); assert(convertedType && "unexpected failure in memref type conversion"); auto descr = MemRefDescriptor::undef(builder, loc, convertedType); descr.setAllocatedPtr(builder, loc, memory); descr.setAlignedPtr(builder, loc, memory); descr.setConstantOffset(builder, loc, offset); // Fill in sizes and strides for (unsigned i = 0, e = type.getRank(); i != e; ++i) { descr.setConstantSize(builder, loc, i, type.getDimSize(i)); descr.setConstantStride(builder, loc, i, strides[i]); } return descr; } /// Builds IR extracting the allocated pointer from the descriptor. Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); } /// Builds IR inserting the allocated pointer into the descriptor. void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr) { setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); } /// Builds IR extracting the aligned pointer from the descriptor. Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); } /// Builds IR inserting the aligned pointer into the descriptor. void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, Value ptr) { setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); } // Creates a constant Op producing a value of `resultType` from an index-typed // integer attribute. 
static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value) { return builder.create( loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); } /// Builds IR extracting the offset from the descriptor. Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); } /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, Value offset) { value = builder.create( loc, structType, value, offset, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); } /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset) { setOffset(builder, loc, createIndexAttrConstant(builder, loc, indexType, offset)); } /// Builds IR extracting the pos-th size from the descriptor. Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); } Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, int64_t rank) { auto indexTy = indexType.cast(); auto indexPtrTy = indexTy.getPointerTo(); auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank); auto arrayPtrTy = arrayTy.getPointerTo(); // Copy size values to stack-allocated memory. auto zero = createIndexAttrConstant(builder, loc, indexType, 0); auto one = createIndexAttrConstant(builder, loc, indexType, 1); auto sizes = builder.create( loc, arrayTy, value, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor})); auto sizesPtr = builder.create(loc, arrayPtrTy, one, /*alignment=*/0); builder.create(loc, sizes, sizesPtr); // Load an return size value of interest. 
auto resultPtr = builder.create(loc, indexPtrTy, sizesPtr, ValueRange({zero, pos})); return builder.create(loc, resultPtr); } /// Builds IR inserting the pos-th size into the descriptor void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, Value size) { value = builder.create( loc, structType, value, size, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); } void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, unsigned pos, uint64_t size) { setSize(builder, loc, pos, createIndexAttrConstant(builder, loc, indexType, size)); } /// Builds IR extracting the pos-th stride from the descriptor. Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); } /// Builds IR inserting the pos-th stride into the descriptor void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride) { value = builder.create( loc, structType, value, stride, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); } void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, unsigned pos, uint64_t stride) { setStride(builder, loc, pos, createIndexAttrConstant(builder, loc, indexType, stride)); } LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { return value.getType() .cast() .getStructElementType(kAlignedPtrPosInMemRefDescriptor) .cast(); } /// Creates a MemRef descriptor structure from a list of individual values /// composing that descriptor, in the following order: /// - allocated pointer; /// - aligned pointer; /// - offset; /// - sizes; /// - shapes; /// where is the MemRef rank as provided in `type`. 
Value MemRefDescriptor::pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, MemRefType type, ValueRange values) { Type llvmType = converter.convertType(type); auto d = MemRefDescriptor::undef(builder, loc, llvmType); d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]); d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]); d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]); int64_t rank = type.getRank(); for (unsigned i = 0; i < rank; ++i) { d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]); d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]); } return d; } /// Builds IR extracting individual elements of a MemRef descriptor structure /// and returning them as `results` list. void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl &results) { int64_t rank = type.getRank(); results.reserve(results.size() + getNumUnpackedValues(type)); MemRefDescriptor d(packed); results.push_back(d.allocatedPtr(builder, loc)); results.push_back(d.alignedPtr(builder, loc)); results.push_back(d.offset(builder, loc)); for (int64_t i = 0; i < rank; ++i) results.push_back(d.size(builder, loc, i)); for (int64_t i = 0; i < rank; ++i) results.push_back(d.stride(builder, loc, i)); } /// Returns the number of non-aggregate values that would be produced by /// `unpack`. unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) { // Two pointers, offset, sizes, shapes. return 3 + 2 * type.getRank(); } //===----------------------------------------------------------------------===// // MemRefDescriptorView implementation. 
//===----------------------------------------------------------------------===// MemRefDescriptorView::MemRefDescriptorView(ValueRange range) : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {} Value MemRefDescriptorView::allocatedPtr() { return elements[kAllocatedPtrPosInMemRefDescriptor]; } Value MemRefDescriptorView::alignedPtr() { return elements[kAlignedPtrPosInMemRefDescriptor]; } Value MemRefDescriptorView::offset() { return elements[kOffsetPosInMemRefDescriptor]; } Value MemRefDescriptorView::size(unsigned pos) { return elements[kSizePosInMemRefDescriptor + pos]; } Value MemRefDescriptorView::stride(unsigned pos) { return elements[kSizePosInMemRefDescriptor + rank + pos]; } //===----------------------------------------------------------------------===// // UnrankedMemRefDescriptor implementation //===----------------------------------------------------------------------===// /// Construct a helper for the given descriptor value. UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) : StructBuilder(descriptor) {} /// Builds IR creating an `undef` value of the descriptor type. 
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { Value descriptor = builder.create(loc, descriptorType.cast()); return UnrankedMemRefDescriptor(descriptor); } Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, Value v) { setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); } Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, Location loc, Value v) { setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); } /// Builds IR populating an unranked MemRef descriptor structure from a list /// of individual constituent values in the following order: /// - rank of the memref; /// - pointer to the memref descriptor. Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values) { Type llvmType = converter.convertType(type); auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType); d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]); d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]); return d; } /// Builds IR extracting individual elements that compose an unranked memref /// descriptor and returns them as `results` list. 
void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl &results) { UnrankedMemRefDescriptor d(packed); results.reserve(results.size() + 2); results.push_back(d.rank(builder, loc)); results.push_back(d.memRefDescPtr(builder, loc)); } void UnrankedMemRefDescriptor::computeSizes( OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, ArrayRef values, SmallVectorImpl &sizes) { if (values.empty()) return; // Cache the index type. LLVM::LLVMType indexType = typeConverter.getIndexType(); // Initialize shared constants. Value one = createIndexAttrConstant(builder, loc, indexType, 1); Value two = createIndexAttrConstant(builder, loc, indexType, 2); Value pointerSize = createIndexAttrConstant( builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8)); Value indexSize = createIndexAttrConstant(builder, loc, indexType, ceilDiv(typeConverter.getIndexTypeBitwidth(), 8)); sizes.reserve(sizes.size() + values.size()); for (UnrankedMemRefDescriptor desc : values) { // Emit IR computing the memory necessary to store the descriptor. This // assumes the descriptor to be // { type*, type*, index, index[rank], index[rank] } // and densely packed, so the total size is // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). // TODO: consider including the actual size (including eventual padding due // to data layout) into the unranked descriptor. Value doublePointerSize = builder.create(loc, indexType, two, pointerSize); // (1 + 2 * rank) * sizeof(index) Value rank = desc.rank(builder, loc); Value doubleRank = builder.create(loc, indexType, two, rank); Value doubleRankIncremented = builder.create(loc, indexType, doubleRank, one); Value rankIndexSize = builder.create( loc, indexType, doubleRankIncremented, indexSize); // Total allocation size. 
Value allocationSize = builder.create( loc, indexType, doublePointerSize, rankIndexSize); sizes.push_back(allocationSize); } } Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); return builder.create(loc, elementPtrPtr); } void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType, Value allocatedPtr) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); builder.create(loc, allocatedPtr, elementPtrPtr); } Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); Value one = createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); Value alignedGep = builder.create( loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); return builder.create(loc, alignedGep); } void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType, Value alignedPtr) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); Value one = createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); Value alignedGep = builder.create( loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); builder.create(loc, alignedPtr, alignedGep); } Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); Value two = createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); Value offsetGep = builder.create( loc, elemPtrPtrType, elementPtrPtr, 
ValueRange({two})); offsetGep = builder.create( loc, typeConverter.getIndexType().getPointerTo(), offsetGep); return builder.create(loc, offsetGep); } void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType, Value offset) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); Value two = createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); Value offsetGep = builder.create( loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); offsetGep = builder.create( loc, typeConverter.getIndexType().getPointerTo(), offsetGep); builder.create(loc, offset, offsetGep); } Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType) { LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy(); LLVM::LLVMType indexTy = typeConverter.getIndexType(); LLVM::LLVMType structPtrTy = LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy) .getPointerTo(); Value structPtr = builder.create(loc, structPtrTy, memRefDescPtr); LLVM::LLVMType int32_type = unwrap(typeConverter.convertType(builder.getI32Type())); Value zero = createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0); Value three = builder.create(loc, int32_type, builder.getI32IntegerAttr(3)); return builder.create(loc, indexTy.getPointerTo(), structPtr, ValueRange({zero, three})); } Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, LLVMTypeConverter typeConverter, Value sizeBasePtr, Value index) { LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({index})); return builder.create(loc, sizeStoreGep); } void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, LLVMTypeConverter typeConverter, Value sizeBasePtr, Value index, 
Value size) { LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({index})); builder.create(loc, size, sizeStoreGep); } Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank) { LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); return builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({rank})); } Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, LLVMTypeConverter typeConverter, Value strideBasePtr, Value index, Value stride) { LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); Value strideStoreGep = builder.create( loc, indexPtrTy, strideBasePtr, ValueRange({index})); return builder.create(loc, strideStoreGep); } void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, LLVMTypeConverter typeConverter, Value strideBasePtr, Value index, Value stride) { LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); Value strideStoreGep = builder.create( loc, indexPtrTy, strideBasePtr, ValueRange({index})); builder.create(loc, stride, strideStoreGep); } LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { return static_cast( ConversionPattern::getTypeConverter()); } LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { return *getTypeConverter()->getDialect(); } LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const { return getTypeConverter()->getIndexType(); } LLVM::LLVMType ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { return LLVM::LLVMType::getIntNTy( &getTypeConverter()->getContext(), getTypeConverter()->getPointerBitwidth(addressSpace)); } LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const { return LLVM::LLVMType::getVoidTy(&getTypeConverter()->getContext()); } LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const { return 
LLVM::LLVMType::getInt8PtrTy(&getTypeConverter()->getContext()); } Value ConvertToLLVMPattern::createIndexConstant( ConversionPatternRewriter &builder, Location loc, uint64_t value) const { return createIndexAttrConstant(builder, loc, getIndexType(), value); } Value ConvertToLLVMPattern::getStridedElementPtr( Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const { int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(type, strides, offset); assert(succeeded(successStrides) && "unexpected non-strided memref"); (void)successStrides; MemRefDescriptor memRefDescriptor(memRefDesc); Value base = memRefDescriptor.alignedPtr(rewriter, loc); Value index; if (offset != 0) // Skip if offset is zero. index = offset == MemRefType::getDynamicStrideOrOffset() ? memRefDescriptor.offset(rewriter, loc) : createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { Value increment = indices[i]; if (strides[i] != 1) { // Skip if stride is 1. Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset() ? memRefDescriptor.stride(rewriter, loc, i) : createIndexConstant(rewriter, loc, strides[i]); increment = rewriter.create(loc, increment, stride); } index = index ? rewriter.create(loc, index, increment) : increment; } LLVM::LLVMType elementPtrType = memRefDescriptor.getElementPtrType(); return index ? rewriter.create(loc, elementPtrType, base, index) : base; } Value ConvertToLLVMPattern::getDataPtr( Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const { return getStridedElementPtr(loc, type, memRefDesc, indices, rewriter); } // Check if the MemRefType `type` is supported by the lowering. We currently // only support memrefs with identity maps. 
bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const { if (!typeConverter->convertType(type.getElementType())) return false; return type.getAffineMaps().empty() || llvm::all_of(type.getAffineMaps(), [](AffineMap map) { return map.isIdentity(); }); } Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { auto elementType = type.getElementType(); auto structElementType = unwrap(typeConverter->convertType(elementType)); return structElementType.getPointerTo(type.getMemorySpace()); } void ConvertToLLVMPattern::getMemRefDescriptorSizes( Location loc, MemRefType memRefType, ArrayRef dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes, SmallVectorImpl &strides, Value &sizeBytes) const { assert(isSupportedMemRefType(memRefType) && "layout maps must have been normalized away"); sizes.reserve(memRefType.getRank()); unsigned dynamicIndex = 0; for (int64_t size : memRefType.getShape()) { sizes.push_back(size == ShapedType::kDynamicSize ? dynamicSizes[dynamicIndex++] : createIndexConstant(rewriter, loc, size)); } // Strides: iterate sizes in reverse order and multiply. int64_t stride = 1; Value runningStride = createIndexConstant(rewriter, loc, 1); strides.resize(memRefType.getRank()); for (auto i = memRefType.getRank(); i-- > 0;) { strides[i] = runningStride; int64_t size = memRefType.getShape()[i]; if (size == 0) continue; bool useSizeAsStride = stride == 1; if (size == ShapedType::kDynamicSize) stride = ShapedType::kDynamicSize; if (stride != ShapedType::kDynamicSize) stride *= size; if (useSizeAsStride) runningStride = sizes[i]; else if (stride == ShapedType::kDynamicSize) runningStride = rewriter.create(loc, runningStride, sizes[i]); else runningStride = createIndexConstant(rewriter, loc, stride); } // Buffer size in bytes. 
Type elementPtrType = getElementPtrType(memRefType); Value nullPtr = rewriter.create(loc, elementPtrType); Value gepPtr = rewriter.create( loc, elementPtrType, ArrayRef{nullPtr, runningStride}); sizeBytes = rewriter.create(loc, getIndexType(), gepPtr); } Value ConvertToLLVMPattern::getSizeInBytes( Location loc, Type type, ConversionPatternRewriter &rewriter) const { // Compute the size of an individual element. This emits the MLIR equivalent // of the following sizeof(...) implementation in LLVM IR: // %0 = getelementptr %elementType* null, %indexType 1 // %1 = ptrtoint %elementType* %0 to %indexType // which is a common pattern of getting the size of a type in bytes. auto convertedPtrType = typeConverter->convertType(type).cast().getPointerTo(); auto nullPtr = rewriter.create(loc, convertedPtrType); auto gep = rewriter.create( loc, convertedPtrType, ArrayRef{nullPtr, createIndexConstant(rewriter, loc, 1)}); return rewriter.create(loc, getIndexType(), gep); } Value ConvertToLLVMPattern::getNumElements( Location loc, ArrayRef shape, ConversionPatternRewriter &rewriter) const { // Compute the total number of memref elements. Value numElements = shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front(); for (unsigned i = 1, e = shape.size(); i < e; ++i) numElements = rewriter.create(loc, numElements, shape[i]); return numElements; } /// Creates and populates the memref descriptor struct given all its fields. MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, ArrayRef sizes, ArrayRef strides, ConversionPatternRewriter &rewriter) const { auto structType = typeConverter->convertType(memRefType); auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); // Field 1: Allocated pointer, used for malloc/free. memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); // Field 2: Actual aligned pointer to payload. 
memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); // Field 3: Offset in aligned pointer. memRefDescriptor.setOffset(rewriter, loc, createIndexConstant(rewriter, loc, 0)); // Fields 4: Sizes. for (auto en : llvm::enumerate(sizes)) memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); // Field 5: Strides. for (auto en : llvm::enumerate(strides)) memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); return memRefDescriptor; } /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. static void filterFuncAttributes(ArrayRef attrs, bool filterArgAttrs, SmallVectorImpl &result) { for (const auto &attr : attrs) { if (attr.first == SymbolTable::getSymbolAttrName() || attr.first == impl::getTypeAttrName() || attr.first == "std.varargs" || (filterArgAttrs && impl::isArgAttrName(attr.first.strref()))) continue; result.push_back(attr); } } /// Creates an auxiliary function with pointer-to-memref-descriptor-struct /// arguments instead of unpacked arguments. This function can be called from C /// by passing a pointer to a C struct corresponding to a memref descriptor. /// Internally, the auxiliary function unpacks the descriptor into individual /// components and forwards them to `newFuncOp`. 
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, LLVMTypeConverter &typeConverter, FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { auto type = funcOp.getType(); SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes); auto wrapperFuncOp = rewriter.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External, attributes); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock()); SmallVector args; for (auto &en : llvm::enumerate(type.getInputs())) { Value arg = wrapperFuncOp.getArgument(en.index()); if (auto memrefType = en.value().dyn_cast()) { Value loaded = rewriter.create(loc, arg); MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); continue; } if (en.value().isa()) { Value loaded = rewriter.create(loc, arg); UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); continue; } args.push_back(wrapperFuncOp.getArgument(en.index())); } auto call = rewriter.create(loc, newFuncOp, args); rewriter.create(loc, call.getResults()); } /// Creates an auxiliary function with pointer-to-memref-descriptor-struct /// arguments instead of unpacked arguments. Creates a body for the (external) /// `newFuncOp` that allocates a memref descriptor on stack, packs the /// individual arguments into this descriptor and passes a pointer to it into /// the auxiliary function. This auxiliary external function is now compatible /// with functions defined in C using pointers to C structs corresponding to a /// memref descriptor. 
static void wrapExternalFunction(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { OpBuilder::InsertionGuard guard(builder); LLVM::LLVMType wrapperType = typeConverter.convertFunctionTypeCWrapper(funcOp.getType()); // This conversion can only fail if it could not convert one of the argument // types. But since it has been applies to a non-wrapper function before, it // should have failed earlier and not reach this point at all. assert(wrapperType && "unexpected type conversion failure"); SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes); // Create the auxiliary function. auto wrapperFunc = builder.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperType, LLVM::Linkage::External, attributes); builder.setInsertionPointToStart(newFuncOp.addEntryBlock()); // Get a ValueRange containing arguments. FunctionType type = funcOp.getType(); SmallVector args; args.reserve(type.getNumInputs()); ValueRange wrapperArgsRange(newFuncOp.getArguments()); // Iterate over the inputs of the original function and pack values into // memref descriptors if the original type is a memref. for (auto &en : llvm::enumerate(type.getInputs())) { Value arg; int numToDrop = 1; auto memRefType = en.value().dyn_cast(); auto unrankedMemRefType = en.value().dyn_cast(); if (memRefType || unrankedMemRefType) { numToDrop = memRefType ? MemRefDescriptor::getNumUnpackedValues(memRefType) : UnrankedMemRefDescriptor::getNumUnpackedValues(); Value packed = memRefType ? 
MemRefDescriptor::pack(builder, loc, typeConverter, memRefType, wrapperArgsRange.take_front(numToDrop)) : UnrankedMemRefDescriptor::pack( builder, loc, typeConverter, unrankedMemRefType, wrapperArgsRange.take_front(numToDrop)); auto ptrTy = packed.getType().cast().getPointerTo(); Value one = builder.create( loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); Value allocated = builder.create(loc, ptrTy, one, /*alignment=*/0); builder.create(loc, packed, allocated); arg = allocated; } else { arg = wrapperArgsRange[0]; } args.push_back(arg); wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop); } assert(wrapperArgsRange.empty() && "did not map some of the arguments"); auto call = builder.create(loc, wrapperFunc, args); builder.create(loc, call.getResults()); } namespace { struct FuncOpConversionBase : public ConvertOpToLLVMPattern { protected: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided // to this legalization pattern. LLVM::LLVMFuncOp convertFuncOpToLLVMFuncOp(FuncOp funcOp, ConversionPatternRewriter &rewriter) const { // Convert the original function arguments. They are converted using the // LLVMTypeConverter provided to this legalization pattern. auto varargsAttr = funcOp->getAttrOfType("std.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); auto llvmType = getTypeConverter()->convertFunctionSignature( funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); if (!llvmType) return nullptr; // Propagate argument attributes to all converted arguments obtained after // converting a given original argument. 
SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/true, attributes); for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { auto attr = impl::getArgAttrDict(funcOp, i); if (!attr) continue; auto mapping = result.getInputMapping(i); assert(mapping.hasValue() && "unexpected deletion of function argument"); SmallString<8> name; for (size_t j = 0; j < mapping->size; ++j) { impl::getArgAttrName(mapping->inputNo + j, name); attributes.push_back(rewriter.getNamedAttr(name, attr)); } } // Create an LLVM function, use external linkage by default until MLIR // functions have linkage. auto newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter, &result))) return nullptr; return newFuncOp; } }; /// FuncOp legalization pattern that converts MemRef arguments to pointers to /// MemRef descriptors (LLVM struct data types) containing all the MemRef type /// information. 
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; struct FuncOpConversion : public FuncOpConversionBase { FuncOpConversion(LLVMTypeConverter &converter) : FuncOpConversionBase(converter) {} LogicalResult matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) return failure(); if (getTypeConverter()->getOptions().emitCWrappers || funcOp->getAttrOfType(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(), funcOp, newFuncOp); else wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(), funcOp, newFuncOp); } rewriter.eraseOp(funcOp); return success(); } }; /// FuncOp legalization pattern that converts MemRef arguments to bare pointers /// to the MemRef element type. This will impact the calling convention and ABI. struct BarePtrFuncOpConversion : public FuncOpConversionBase { using FuncOpConversionBase::FuncOpConversionBase; LogicalResult matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Store the type of memref-typed arguments before the conversion so that we // can promote them to MemRef descriptor at the beginning of the function. SmallVector oldArgTypes = llvm::to_vector<8>(funcOp.getType().getInputs()); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) return failure(); if (newFuncOp.getBody().empty()) { rewriter.eraseOp(funcOp); return success(); } // Promote bare pointers from memref arguments to memref descriptors at the // beginning of the function so that all the memrefs in the function have a // uniform representation. 
Block *entryBlock = &newFuncOp.getBody().front(); auto blockArgs = entryBlock->getArguments(); assert(blockArgs.size() == oldArgTypes.size() && "The number of arguments and types doesn't match"); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(entryBlock); for (auto it : llvm::zip(blockArgs, oldArgTypes)) { BlockArgument arg = std::get<0>(it); Type argTy = std::get<1>(it); // Unranked memrefs are not supported in the bare pointer calling // convention. We should have bailed out before in the presence of // unranked memrefs. assert(!argTy.isa() && "Unranked memref is not supported"); auto memrefTy = argTy.dyn_cast(); if (!memrefTy) continue; // Replace barePtr with a placeholder (undef), promote barePtr to a ranked // or unranked memref descriptor and replace placeholder with the last // instruction of the memref descriptor. // TODO: The placeholder is needed to avoid replacing barePtr uses in the // MemRef descriptor instructions. We may want to have a utility in the // rewriter to properly handle this use case. Location loc = funcOp.getLoc(); auto placeholder = rewriter.create(loc, memrefTy); rewriter.replaceUsesOfBlockArgument(arg, placeholder); Value desc = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), memrefTy, arg); rewriter.replaceOp(placeholder, {desc}); } rewriter.eraseOp(funcOp); return success(); } }; //////////////// Support for Lowering operations on n-D vectors //////////////// // Helper struct to "unroll" operations on n-D vectors in terms of operations on // 1-D LLVM vectors. struct NDVectorTypeInfo { // LLVM array struct which encodes n-D vectors. LLVM::LLVMType llvmArrayTy; // LLVM vector type which encodes the inner 1-D vector type. LLVM::LLVMType llvmVectorTy; // Multiplicity of llvmArrayTy to llvmVectorTy. 
SmallVector arraySizes; }; } // namespace // For >1-D vector types, extracts the necessary information to iterate over all // 1-D subvectors in the underlying llrepresentation of the n-D vector // Iterates on the llvm array type until we hit a non-array type (which is // asserted to be an llvm vector type). static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, LLVMTypeConverter &converter) { assert(vectorType.getRank() > 1 && "expected >1D vector type"); NDVectorTypeInfo info; info.llvmArrayTy = converter.convertType(vectorType).dyn_cast(); if (!info.llvmArrayTy) return info; info.arraySizes.reserve(vectorType.getRank() - 1); auto llvmTy = info.llvmArrayTy; while (llvmTy.isArrayTy()) { info.arraySizes.push_back(llvmTy.getArrayNumElements()); llvmTy = llvmTy.getArrayElementType(); } if (!llvmTy.isVectorTy()) return info; info.llvmVectorTy = llvmTy; return info; } // Express `linearIndex` in terms of coordinates of `basis`. // Returns the empty vector when linearIndex is out of the range [0, P] where // P is the product of all the basis coordinates. // // Prerequisites: // Basis is an array of nonnegative integers (signed type inherited from // vector shape type). static SmallVector getCoordinates(ArrayRef basis, unsigned linearIndex) { SmallVector res; res.reserve(basis.size()); for (unsigned basisElement : llvm::reverse(basis)) { res.push_back(linearIndex % basisElement); linearIndex = linearIndex / basisElement; } if (linearIndex > 0) return {}; std::reverse(res.begin(), res.end()); return res; } // Iterate of linear index, convert to coords space and insert splatted 1-D // vector in each position. template void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, Lambda fun) { unsigned ub = 1; for (auto s : info.arraySizes) ub *= s; for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { auto coords = getCoordinates(info.arraySizes, linearIndex); // Linear index is out of bounds, we are done. 
if (coords.empty()) break; assert(coords.size() == info.arraySizes.size()); auto position = builder.getI64ArrayAttr(coords); fun(position); } } ////////////// End Support for Lowering operations on n-D vectors ////////////// /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. LogicalResult LLVM::detail::oneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); Type packedType; if (numResults != 0) { packedType = typeConverter.packFunctionResults(op->getResultTypes()); if (!packedType) return failure(); } // Create the operation through state since we don't know its C++ type. OperationState state(op->getLoc(), targetOp); state.addTypes(packedType); state.addOperands(operands); state.addAttributes(op->getAttrs()); Operation *newOp = rewriter.createOperation(state); // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) return rewriter.eraseOp(op), success(); if (numResults == 1) return rewriter.replaceOp(op, newOp->getResult(0)), success(); // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. 
SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = typeConverter.convertType(op->getResult(i).getType()); results.push_back(rewriter.create( op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); } rewriter.replaceOp(op, results); return success(); } static LogicalResult handleMultidimensionalVectors( Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function createOperand, ConversionPatternRewriter &rewriter) { auto vectorType = op->getResult(0).getType().dyn_cast(); if (!vectorType) return failure(); auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter); auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; auto llvmArrayTy = operands[0].getType().cast(); if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy) return failure(); auto loc = op->getLoc(); Value desc = rewriter.create(loc, llvmArrayTy); nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors SmallVector extractedOperands; for (auto operand : operands) extractedOperands.push_back(rewriter.create( loc, llvmVectorTy, operand, position)); Value newVal = createOperand(llvmVectorTy, extractedOperands); desc = rewriter.create(loc, llvmArrayTy, desc, newVal, position); }); rewriter.replaceOp(op, desc); return success(); } LogicalResult LLVM::detail::vectorOneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. 
if (!llvm::all_of(operands.getTypes(), [](Type t) { return t.isa(); })) return failure(); auto llvmArrayTy = operands[0].getType().cast(); if (!llvmArrayTy.isArrayTy()) return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy, ValueRange operands) { OperationState state(op->getLoc(), targetOp); state.addTypes(llvmVectorTy); state.addOperands(operands); state.addAttributes(op->getAttrs()); return rewriter.createOperation(state)->getResult(0); }; return handleMultidimensionalVectors(op, operands, typeConverter, callback, rewriter); } namespace { // Straightforward lowerings. using AbsFOpLowering = VectorConvertToLLVMPattern; using AddFOpLowering = VectorConvertToLLVMPattern; using AddIOpLowering = VectorConvertToLLVMPattern; using AndOpLowering = VectorConvertToLLVMPattern; using CeilFOpLowering = VectorConvertToLLVMPattern; using CopySignOpLowering = VectorConvertToLLVMPattern; using CosOpLowering = VectorConvertToLLVMPattern; using DivFOpLowering = VectorConvertToLLVMPattern; using ExpOpLowering = VectorConvertToLLVMPattern; using Exp2OpLowering = VectorConvertToLLVMPattern; using FloorFOpLowering = VectorConvertToLLVMPattern; using Log10OpLowering = VectorConvertToLLVMPattern; using Log2OpLowering = VectorConvertToLLVMPattern; using LogOpLowering = VectorConvertToLLVMPattern; using MulFOpLowering = VectorConvertToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; using NegFOpLowering = VectorConvertToLLVMPattern; using OrOpLowering = VectorConvertToLLVMPattern; using RemFOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = OneToOneConvertToLLVMPattern; using ShiftLeftOpLowering = OneToOneConvertToLLVMPattern; using SignedDivIOpLowering = VectorConvertToLLVMPattern; using SignedRemIOpLowering = VectorConvertToLLVMPattern; using SignedShiftRightOpLowering = OneToOneConvertToLLVMPattern; using SinOpLowering = VectorConvertToLLVMPattern; using SqrtOpLowering 
= VectorConvertToLLVMPattern; using SubFOpLowering = VectorConvertToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; using UnsignedDivIOpLowering = VectorConvertToLLVMPattern; using UnsignedRemIOpLowering = VectorConvertToLLVMPattern; using UnsignedShiftRightOpLowering = OneToOneConvertToLLVMPattern; using XOrOpLowering = VectorConvertToLLVMPattern; /// Lower `std.assert`. The default lowering calls the `abort` function if the /// assertion is violated and has no effect otherwise. The failure message is /// ignored by the default lowering but should be propagated by any custom /// lowering. struct AssertOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(AssertOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); AssertOp::Adaptor transformed(operands); // Insert the `abort` declaration if necessary. auto module = op->getParentOfType(); auto abortFunc = module.lookupSymbol("abort"); if (!abortFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); auto abortFuncTy = LLVM::LLVMType::getFunctionTy(getVoidType(), {}, /*isVarArg=*/false); abortFunc = rewriter.create(rewriter.getUnknownLoc(), "abort", abortFuncTy); } // Split block at `assert` operation. Block *opBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); // Generate IR to call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); rewriter.create(loc, abortFunc, llvm::None); rewriter.create(loc); // Generate assertion test. rewriter.setInsertionPointToEnd(opBlock); rewriter.replaceOpWithNewOp( op, transformed.arg(), continuationBlock, failureBlock); return success(); } }; // Lowerings for operations on complex numbers. 
/// Lowers `CreateComplexOp` by building an undef value of the converted
/// complex struct type and filling in its real and imaginary parts from the
/// converted operands.
struct CreateComplexOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(CreateComplexOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto complexOp = cast(op);
    CreateComplexOp::Adaptor transformed(operands);

    // Pack real and imaginary part in a complex number struct.
    auto loc = op.getLoc();
    auto structType = typeConverter->convertType(complexOp.getType());
    auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
    complexStruct.setReal(rewriter, loc, transformed.real());
    complexStruct.setImaginary(rewriter, loc, transformed.imaginary());

    rewriter.replaceOp(op, {complexStruct});
    return success();
  }
};

/// Lowers `ReOp` by replacing it with the real part read out of the converted
/// complex struct operand.
struct ReOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ReOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    ReOp::Adaptor transformed(operands);

    // Extract real part from the complex number struct.
    ComplexStructBuilder complexStruct(transformed.complex());
    Value real = complexStruct.real(rewriter, op.getLoc());
    rewriter.replaceOp(op, real);

    return success();
  }
};

/// Lowers `ImOp` by replacing it with the imaginary part read out of the
/// converted complex struct operand.
struct ImOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ImOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    ImOp::Adaptor transformed(operands);

    // Extract imaginary part from the complex number struct.
    ComplexStructBuilder complexStruct(transformed.complex());
    Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
    rewriter.replaceOp(op, imaginary);

    return success();
  }
};

/// Real and imaginary parts of the two operands of a binary complex op, as
/// filled in by `unpackBinaryComplexOperands`.
struct BinaryComplexOperands {
  std::complex lhs, rhs;
};

/// Extracts the real and imaginary parts of both (already converted) operands
/// of the binary complex op `op`. The extraction IR is emitted through
/// `rewriter` at the location of `op`.
template
BinaryComplexOperands
unpackBinaryComplexOperands(OpTy op, ArrayRef operands,
                            ConversionPatternRewriter &rewriter) {
  auto bop = cast(op);
  auto loc = bop.getLoc();
  typename OpTy::Adaptor transformed(operands);

  // Extract real and imaginary values from operands.
  BinaryComplexOperands unpacked;
  ComplexStructBuilder lhs(transformed.lhs());
  unpacked.lhs.real(lhs.real(rewriter, loc));
  unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
  ComplexStructBuilder rhs(transformed.rhs());
  unpacked.rhs.real(rhs.real(rewriter, loc));
  unpacked.rhs.imag(rhs.imaginary(rewriter, loc));

  return unpacked;
}

/// Lowers `AddCFOp` component-wise: one op combines the real parts, one the
/// imaginary parts, and both results are packed into a fresh complex struct.
struct AddCFOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(AddCFOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    BinaryComplexOperands arg =
        unpackBinaryComplexOperands(op, operands, rewriter);

    // Initialize complex number struct for result.
    auto structType = typeConverter->convertType(op.getType());
    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);

    // Emit IR to add complex numbers.
    Value real =
        rewriter.create(loc, arg.lhs.real(), arg.rhs.real());
    Value imag =
        rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag());
    result.setReal(rewriter, loc, real);
    result.setImaginary(rewriter, loc, imag);

    rewriter.replaceOp(op, {result});
    return success();
  }
};

/// Lowers `SubCFOp` component-wise: one op combines the real parts, one the
/// imaginary parts, and both results are packed into a fresh complex struct.
struct SubCFOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SubCFOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    BinaryComplexOperands arg =
        unpackBinaryComplexOperands(op, operands, rewriter);

    // Initialize complex number struct for result.
    auto structType = typeConverter->convertType(op.getType());
    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);

    // Emit IR to subtract complex numbers.
    Value real =
        rewriter.create(loc, arg.lhs.real(), arg.rhs.real());
    Value imag =
        rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag());
    result.setReal(rewriter, loc, real);
    result.setImaginary(rewriter, loc, imag);

    rewriter.replaceOp(op, {result});
    return success();
  }
};

/// Lowers `ConstantOp`. A constant whose value is a symbol reference is
/// handled specially (see below); every other constant goes through the
/// generic one-to-one rewrite.
struct ConstantOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ConstantOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    // If constant refers to a function, convert it to "addressof".
    if (auto symbolRef = op.getValue().dyn_cast()) {
      auto type = typeConverter->convertType(op.getResult().getType())
                      .dyn_cast_or_null();
      if (!type)
        return rewriter.notifyMatchFailure(op, "failed to convert result type");

      // Forward every attribute except "value", which is replaced by the
      // referenced symbol name passed explicitly to the new op.
      MutableDictionaryAttr attrs(op.getAttrs());
      attrs.remove(rewriter.getIdentifier("value"));
      rewriter.replaceOpWithNewOp(
          op, type.cast(), symbolRef.getValue(), attrs.getAttrs());
      return success();
    }

    // Calling into other scopes (non-flat reference) is not supported in LLVM.
if (op.getValue().isa())
      return rewriter.notifyMatchFailure(
          op, "referring to a symbol outside of the current module");

    // Everything else is a plain constant: rewrite it one-to-one.
    return LLVM::detail::oneToOneRewrite(
        op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(),
        rewriter);
  }
};

/// Lowering for AllocOp and AllocaOp.
///
/// Shared skeleton for allocation-like ops: `rewrite` computes the sizes and
/// strides of the result memref, asks the subclass for the buffer through
/// `allocateBuffer`, and replaces the op with the resulting memref descriptor.
struct AllocLikeOpLowering : public ConvertToLLVMPattern {
  using ConvertToLLVMPattern::createIndexConstant;
  using ConvertToLLVMPattern::getIndexType;
  using ConvertToLLVMPattern::getVoidPtrType;

  explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter)
      : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}

protected:
  // Returns 'input' aligned up to 'alignment'. Computes
  //   bumped = input + alignment - 1
  //   aligned = bumped - bumped % alignment
  static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                             Value input, Value alignment) {
    Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
    Value bump = rewriter.create(loc, alignment, one);
    Value bumped = rewriter.create(loc, input, bump);
    Value mod = rewriter.create(loc, bumped, alignment);
    return rewriter.create(loc, bumped, mod);
  }

  // Creates a call to an allocation function with params and casts the
  // resulting void pointer to ptrType.
  //
  // The function `name` is looked up in `module`; if it is missing, a
  // declaration returning a void pointer and taking the types of `params` is
  // inserted at the start of the module body.
  Value createAllocCall(Location loc, StringRef name, Type ptrType,
                        ArrayRef params, ModuleOp module,
                        ConversionPatternRewriter &rewriter) const {
    SmallVector paramTypes;
    auto allocFuncOp = module.lookupSymbol(name);
    if (!allocFuncOp) {
      for (Value param : params)
        paramTypes.push_back(param.getType().cast());
      auto allocFuncType =
          LLVM::LLVMType::getFunctionTy(getVoidPtrType(), paramTypes,
                                        /*isVarArg=*/false);
      // The guard keeps the declaration from moving the caller's insertion
      // point.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      allocFuncOp = rewriter.create(rewriter.getUnknownLoc(),
                                                      name, allocFuncType);
    }
    auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp);
    auto allocatedPtr = rewriter
                            .create(loc, getVoidPtrType(),
                                                  allocFuncSymbol, params)
                            .getResult(0);
    return rewriter.create(loc, ptrType, allocatedPtr);
  }

  /// Allocates the underlying buffer. Returns the allocated pointer and the
  /// aligned pointer.
  virtual std::tuple
  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
                 Value sizeBytes, Operation *op) const = 0;

private:
  /// Returns the type of the first result of `op` as a memref type.
  static MemRefType getMemRefResultType(Operation *op) {
    return op->getResult(0).getType().cast();
  }

  /// Matches only ops whose result memref type is supported by the converter.
  LogicalResult match(Operation *op) const override {
    MemRefType memRefType = getMemRefResultType(op);
    return success(isSupportedMemRefType(memRefType));
  }

  // An `alloc` is converted into a definition of a memref descriptor value and
  // a call to `malloc` to allocate the underlying data buffer.  The memref
  // descriptor is of the LLVM structure type where:
  //   1. the first element is a pointer to the allocated (typed) data buffer,
  //   2. the second element is a pointer to the (typed) payload, aligned to the
  //      specified alignment,
  //   3. the remaining elements serve to store all the sizes and strides of the
  //      memref using LLVM-converted `index` type.
  //
  // Alignment is performed by allocating `alignment` more bytes than
  // requested and shifting the aligned pointer relative to the allocated
  // memory. Note: strictly fewer extra bytes (`alignment` minus the alignment
  // `malloc` already guarantees, presumably) would actually be sufficient. If
  // alignment is unspecified, the two pointers are equal.

  // An `alloca` is converted into a definition of a memref descriptor value and
  // an llvm.alloca to allocate the underlying data buffer.
  void rewrite(Operation *op, ArrayRef operands,
               ConversionPatternRewriter &rewriter) const override {
    MemRefType memRefType = getMemRefResultType(op);
    auto loc = op->getLoc();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands.  In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector sizes;
    SmallVector strides;
    Value sizeBytes;
    this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes,
                                   strides, sizeBytes);

    // Allocate the underlying buffer.
    Value allocatedPtr;
    Value alignedPtr;
    std::tie(allocatedPtr, alignedPtr) =
        this->allocateBuffer(rewriter, loc, sizeBytes, op);

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
  }
};

/// Lowers `AllocOp` to a call to `malloc`. When an alignment applies, the
/// allocation is enlarged by `alignment` bytes and the aligned pointer is
/// computed from the allocated one.
struct AllocOpLowering : public AllocLikeOpLowering {
  AllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}

  std::tuple allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    AllocOp allocOp = cast(op);
    MemRefType memRefType = allocOp.getType();

    // `alignment` stays null when the op carries no alignment attribute and
    // the element type is a signless int, index or float.
    Value alignment;
    if (auto alignmentAttr = allocOp.alignment()) {
      alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }

    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer and store a pointer to it in the MemRef
    // descriptor.
    Type elementPtrType = this->getElementPtrType(memRefType);
    Value allocatedPtr =
        createAllocCall(loc, "malloc", elementPtrType, {sizeBytes},
                        allocOp->getParentOfType(), rewriter);

    // Without an alignment the aligned pointer is simply the allocated one.
    Value alignedPtr = allocatedPtr;
    if (alignment) {
      auto intPtrType = getIntPtrType(memRefType.getMemorySpace());
      // Compute the aligned type pointer: convert the allocated pointer to an
      // integer, round it up to `alignment`, and convert back to a pointer.
      Value allocatedInt =
          rewriter.create(loc, intPtrType, allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create(loc, elementPtrType, alignmentInt);
    }

    return std::make_tuple(allocatedPtr, alignedPtr);
  }
};

/// Lowers `AllocOp` to a call to `aligned_alloc`. The allocated pointer is
/// used as the aligned pointer as well.
struct AlignedAllocOpLowering : public AllocLikeOpLowering {
  AlignedAllocOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {}

  /// Returns the memref's element size in bytes.
  // TODO: there are other places where this is used. Expose publicly?
  // NOTE(review): any element type that is not an int or float is cast
  // unconditionally (the variable name suggests a vector type); other element
  // types would presumably trip the cast -- confirm callers never pass them.
  static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
    auto elementType = memRefType.getElementType();

    unsigned sizeInBits;
    if (elementType.isIntOrFloat()) {
      sizeInBits = elementType.getIntOrFloatBitWidth();
    } else {
      auto vectorType = elementType.cast();
      sizeInBits =
          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
    }
    // Round up so that sub-byte widths still occupy a whole byte.
    return llvm::divideCeil(sizeInBits, 8);
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// factor. Only the element size and the static dimensions contribute;
  /// dynamic dimensions are skipped.
  static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) {
    uint64_t sizeDivisor = getMemRefEltSizeInBytes(type);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamic(type.getDimSize(i)))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

  /// Returns the alignment to be used for the allocation call itself.
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of alignment.
  /// An alignment set on the op is returned unchanged.
  int64_t getAllocationAlignment(AllocOp allocOp) const {
    if (Optional alignment = allocOp.alignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a
    // power of two, we will bump to the next power of two if it already isn't.
    auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType());
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  std::tuple allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    // Heap allocations.
    AllocOp allocOp = cast(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    Value allocatedPtr = createAllocCall(
        loc, "aligned_alloc", elementPtrType, {allocAlignment, sizeBytes},
        allocOp->getParentOfType(), rewriter);

    // No separate aligned pointer is computed: the returned pointer serves as
    // both.
    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
};

// Out of line definition, required till C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

/// Lowers `AllocaOp` to a single stack allocation of `sizeBytes`.
struct AllocaOpLowering : public AllocLikeOpLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLowering(AllocaOp::getOperationName(), converter) {}

  /// Allocates the underlying buffer using the right call. The op's alignment
  /// attribute, or 0 when it is absent, is forwarded to the allocation, and
  /// the resulting pointer is returned as both the allocated and the aligned
  /// pointer.
  std::tuple allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

/// Copies the shaped descriptor part to (if `toDynamic` is set) or from
/// (otherwise) the dynamically allocated memory for any operands that were
/// unranked descriptors originally.
///
/// `origTypes` and `operands` must have the same length. `operands` is updated
/// in place: every entry whose original type is an unranked memref is replaced
/// by a new descriptor that points at the copy. Returns failure if such an
/// original type cannot be converted.
static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
                                             LLVMTypeConverter &typeConverter,
                                             TypeRange origTypes,
                                             SmallVectorImpl &operands,
                                             bool toDynamic) {
  assert(origTypes.size() == operands.size() &&
         "expected as may original types as operands");

  // Find operands of unranked memref type and store them.
  SmallVector unrankedMemrefs;
  for (unsigned i = 0, e = operands.size(); i < e; ++i)
    if (origTypes[i].isa())
      unrankedMemrefs.emplace_back(operands[i]);

  // Nothing to copy: leave `operands` untouched.
  if (unrankedMemrefs.empty())
    return success();

  // Compute allocation sizes.
  SmallVector sizes;
  UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter,
                                         unrankedMemrefs, sizes);

  // Get frequently used types.
MLIRContext *context = builder.getContext();
  auto voidType = LLVM::LLVMType::getVoidTy(context);
  auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context);
  auto i1Type = LLVM::LLVMType::getInt1Ty(context);
  LLVM::LLVMType indexType = typeConverter.getIndexType();

  // Find the malloc and free, or declare them if necessary. `malloc` is only
  // declared when copying to dynamic memory and `free` only when copying from
  // it, matching the single use of each further down.
  auto module = builder.getInsertionPoint()->getParentOfType();
  auto mallocFunc = module.lookupSymbol("malloc");
  if (!mallocFunc && toDynamic) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(module.getBody());
    mallocFunc = builder.create(
        builder.getUnknownLoc(), "malloc",
        LLVM::LLVMType::getFunctionTy(
            voidPtrType, llvm::makeArrayRef(indexType), /*isVarArg=*/false));
  }
  auto freeFunc = module.lookupSymbol("free");
  if (!freeFunc && !toDynamic) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(module.getBody());
    freeFunc = builder.create(
        builder.getUnknownLoc(), "free",
        LLVM::LLVMType::getFunctionTy(voidType, llvm::makeArrayRef(voidPtrType),
                                      /*isVarArg=*/false));
  }

  // Initialize shared constants. `zero` is a boolean false passed as the last
  // argument of every copy created in the loop.
  Value zero =
      builder.create(loc, i1Type, builder.getBoolAttr(false));

  // `sizes` holds one entry per unranked operand, in operand order, so it is
  // walked with its own cursor.
  unsigned unrankedMemrefPos = 0;
  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
    Type type = origTypes[i];
    if (!type.isa())
      continue;
    Value allocationSize = sizes[unrankedMemrefPos++];
    UnrankedMemRefDescriptor desc(operands[i]);

    // Allocate memory, copy, and free the source if necessary. The destination
    // comes from a call to `malloc` when `toDynamic` is set, and otherwise
    // from an allocation op with alignment 0 (its op type is not visible here;
    // presumably a stack allocation).
    Value memory =
        toDynamic
            ? builder.create(loc, mallocFunc, allocationSize)
                  .getResult(0)
            : builder.create(loc, voidPtrType, allocationSize,
                                             /*alignment=*/0);

    Value source = desc.memRefDescPtr(builder, loc);
    builder.create(loc, memory, source, allocationSize, zero);
    if (!toDynamic)
      builder.create(loc, freeFunc, source);

    // Create a new descriptor. The same descriptor can be returned multiple
    // times, attempting to modify its pointer can lead to memory leaks
    // (allocated twice and overwritten) or double frees (the caller does not
    // know if the descriptor points to the same memory).
    Type descriptorType = typeConverter.convertType(type);
    if (!descriptorType)
      return failure();
    auto updatedDesc =
        UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
    Value rank = desc.rank(builder, loc);
    updatedDesc.setRank(builder, loc, rank);
    updatedDesc.setMemRefDescPtr(builder, loc, memory);

    operands[i] = updatedDesc;
  }

  return success();
}

// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
// passes the pointer to the MemRef across function boundaries.
//
// Results are packed into a single type by the type converter. With two or
// more results the individual values are extracted again from the packed
// result of the new call.
template
struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  using Super = CallOpInterfaceLowering;
  using Base = ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(CallOpType callOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    typename CallOpType::Adaptor transformed(operands);

    // Pack the result types into a struct. `packedResult` stays null for a
    // call without results; a failed packing aborts the rewrite.
    Type packedResult = nullptr;
    unsigned numResults = callOp.getNumResults();
    auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());

    if (numResults != 0) {
      if (!(packedResult =
                this->getTypeConverter()->packFunctionResults(resultTypes)))
        return failure();
    }

    // The new call receives the promoted operands and all attributes of the
    // original call.
    auto promoted = this->getTypeConverter()->promoteOperands(
        callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands,
        rewriter);
    auto newOp = rewriter.create(
        callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
        promoted, callOp.getAttrs());

    SmallVector results;
    if (numResults < 2) {
      // If < 2 results, packing did not do anything and we can just return.
      results.append(newOp.result_begin(), newOp.result_end());
    } else {
      // Otherwise, it had been converted to an operation producing a structure.
// Extract individual results from the structure and return them as list. results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = this->typeConverter->convertType(callOp.getResult(i).getType()); results.push_back(rewriter.create( callOp.getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); } } if (this->getTypeConverter()->getOptions().useBarePtrCallConv) { // For the bare-ptr calling convention, promote memref results to // descriptors. assert(results.size() == resultTypes.size() && "The number of arguments and types doesn't match"); this->getTypeConverter()->promoteBarePtrsToDescriptors( rewriter, callOp.getLoc(), resultTypes, results); } else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(), *this->getTypeConverter(), resultTypes, results, /*toDynamic=*/false))) { return failure(); } rewriter.replaceOp(callOp, results); return success(); } }; struct CallOpLowering : public CallOpInterfaceLowering { using Super::Super; }; struct CallIndirectOpLowering : public CallOpInterfaceLowering { using Super::Super; }; // A `dealloc` is converted into a call to `free` on the underlying data buffer. // The memref descriptor being an SSA value, there is no need to clean it up // in any way. struct DeallocOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; explicit DeallocOpLowering(LLVMTypeConverter &converter) : ConvertOpToLLVMPattern(converter) {} LogicalResult matchAndRewrite(DeallocOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 1 && "dealloc takes one operand"); DeallocOp::Adaptor transformed(operands); // Insert the `free` declaration if it is not already present. 
auto freeFunc =
        op->getParentOfType().lookupSymbol("free");
    if (!freeFunc) {
      // Declare `free` at the start of the enclosing module body; the guard
      // keeps the declaration from moving the current insertion point.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(
          op->getParentOfType().getBody());
      freeFunc = rewriter.create(
          rewriter.getUnknownLoc(), "free",
          LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
                                        /*isVarArg=*/false));
    }

    // `free` is called on the allocated pointer of the descriptor (not the
    // aligned one), cast to a void pointer.
    MemRefDescriptor memref(transformed.memref());
    Value casted = rewriter.create(
        op.getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op.getLoc()));
    rewriter.replaceOpWithNewOp(
        op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
/// For a rank-0 memref no array is wrapped around it, so the result is the
/// converted element type itself.
static LLVM::LLVMType
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimension array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for global_memref's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimension array.
  LLVM::LLVMType elementType =
      unwrap(typeConverter.convertType(type.getElementType()));
  LLVM::LLVMType arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMType::getArrayTy(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
///
/// Public globals get external linkage, all others private linkage. An
/// initial value is attached only when the global is neither external nor
/// uninitialized.
struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(GlobalMemrefOp global, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.type().cast();
    if (!isSupportedMemRefType(type))
      return failure();

    LLVM::LLVMType arrayTy =
        convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element type,
      // so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getValue({});
    }

    rewriter.replaceOpWithNewOp(
        global, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, type.getMemorySpace());
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLowering(GetGlobalMemrefOp::getOperationName(), converter) {}

  /// Buffer "allocation" for get_global_memref op is getting the address of
  /// the global variable referenced. `sizeBytes` is not used.
  std::tuple allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast(op);
    MemRefType type = getGlobalOp.result().getType().cast();
    unsigned memSpace = type.getMemorySpace();

    LLVM::LLVMType arrayTy =
        convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto addressOf = rewriter.create(
        loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
LLVM::LLVMType elementType =
        unwrap(typeConverter->convertType(type.getElementType()));
    LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);

    // Operand list: the base address followed by (rank + 1) copies of one
    // zero index constant.
    SmallVector operands = {addressOf};
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep = rewriter.create(loc, elementPtrType, operands);

    // We do not expect the memref obtained using `get_global_memref` to be
    // ever deallocated. Set the allocated pointer to be known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create(loc, elementPtrType, deadBeefConst);

    // The allocated pointer is the 0xdeadbeef sentinel and the aligned pointer
    // is the address of the first element. We could potentially stash a
    // nullptr for the allocated pointer instead since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// A `rsqrt` is converted into `1 / sqrt`.
//
// Scalar and one-dimensional vector operands are rewritten directly with a
// constant one (a splat for vectors). Array-typed operands are handled per
// innermost vector through `handleMultidimensionalVectors`.
struct RsqrtOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(RsqrtOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    RsqrtOp::Adaptor transformed(operands);
    auto operandType = transformed.operand().getType().dyn_cast();
    if (!operandType)
      return failure();

    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();
    auto floatType = getElementTypeOrSelf(resultType).cast();
    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);

    if (!operandType.isArrayTy()) {
      LLVM::ConstantOp one;
      if (operandType.isVectorTy()) {
        one = rewriter.create(
            loc, operandType,
            SplatElementsAttr::get(resultType.cast(), floatOne));
      } else {
        one = rewriter.create(loc, operandType, floatOne);
      }
      auto sqrt = rewriter.create(loc, transformed.operand());
      rewriter.replaceOpWithNewOp(op, operandType, one, sqrt);
      return success();
    }

    // Array-typed operands are only supported when the original result type
    // passes this cast; the cast result is used for the check alone.
    auto vectorType = resultType.dyn_cast();
    if (!vectorType)
      return failure();

    return handleMultidimensionalVectors(
        op.getOperation(), operands, *getTypeConverter(),
        [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
          // Same `1 / sqrt` sequence as above, built for one innermost vector
          // with a splat of one sized to that vector.
          auto splatAttr = SplatElementsAttr::get(
              mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
                                    floatType),
              floatOne);
          auto one =
              rewriter.create(loc, llvmVectorTy, splatAttr);
          auto sqrt =
              rewriter.create(loc, llvmVectorTy, operands[0]);
          return rewriter.create(loc, llvmVectorTy, one, sqrt);
        },
        rewriter);
  }
};

/// Lowers `MemRefCastOp`. `match` decides which casts are supported and
/// `rewrite` emits the conversion.
// NOTE(review): the type arguments of the `isa` checks in this struct are not
// visible in this text, so several conditions read identically; the comments
// describe which ranked/unranked combination each one is meant to test --
// confirm against the original source.
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult match(MemRefCastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // MemRefCastOp reduce to bitcast in the ranked MemRef case and can be used
    // for type erasure. For now they must preserve underlying element type and
    // require source and result type to have the same rank. Therefore, perform
    // a sanity check that the underlying structs are the same. Once op
    // semantics are relaxed we can revisit.
    if (srcType.isa() && dstType.isa())
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is unranked type
    assert(srcType.isa() ||
           dstType.isa());

    // Unranked to unranked cast is disallowed
    return !(srcType.isa() &&
             dstType.isa())
               ? success()
               : failure();
  }

  void rewrite(MemRefCastOp memRefCastOp, ArrayRef operands,
               ConversionPatternRewriter &rewriter) const override {
    MemRefCastOp::Adaptor transformed(operands);

    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
if (srcType.isa() && dstType.isa())
      return rewriter.replaceOp(memRefCastOp, {transformed.source()});

    // NOTE(review): the two conditions below read identically in this text
    // because their type arguments are not visible; per the comments, the
    // first handles ranked -> unranked and the second unranked -> ranked.
    if (srcType.isa() && dstType.isa()) {
      // Casting ranked to unranked memref type
      // Set the rank in the destination from the memref type
      // Allocate space on the stack and copy the src memref descriptor
      // Set the ptr in the destination to the stack space
      auto srcMemRefType = srcType.cast();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, transformed.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create(
          loc, typeConverter->convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (srcType.isa() && dstType.isa()) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked the type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(transformed.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create(
                  loc, targetStructType.cast().getPointerTo(), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create(loc, castPtr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      // `match` rejects unranked -> unranked casts, so this is not expected
      // to be reached.
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In unranked case, the fields are extracted from the underlying
/// ranked descriptor.
///
/// `originalOperand` supplies the pre-conversion type that selects the path;
/// `convertedOperand` is the converted value the fields are read from.
/// `allocatedPtr` and `alignedPtr` are always written; `offset` is written
/// only when it is non-null.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa()) {
    // Per the function comment this is the ranked case: read the fields
    // straight from the descriptor value.
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast().getMemorySpace();
  Type elementType = operandType.cast().getElementType();
  LLVM::LLVMType llvmElementType =
      unwrap(typeConverter.convertType(elementType));
  LLVM::LLVMType elementPtrPtrType =
      llvmElementType.getPointerTo(memorySpace).getPointerTo();

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType);
  }
}

/// Lowers `MemRefReinterpretCastOp` by building a new memref descriptor of the
/// converted result type. The allocated and aligned pointers come from the
/// source memref; the offset, sizes and strides come from the op's own static
/// or dynamic values.
struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(MemRefReinterpretCastOp castOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefReinterpretCastOp::Adaptor adaptor(operands,
                                             castOp->getAttrDictionary());
    Type srcType = castOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  /// Builds the result descriptor for `castOp` and stores it in `*descriptor`.
  /// Fails if the result memref type does not convert to a struct type.
  /// `srcType` is not used here.
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, MemRefReinterpretCastOp castOp,
                                  MemRefReinterpretCastOp::Adaptor adaptor,
                                  Value *descriptor) const {
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast();
    auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType)
                                      .dyn_cast_or_null();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers. The source offset is not requested:
    // the offset of the result is set separately below.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.source(), adaptor.source(), &allocatedPtr,
                             &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.offsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides. The dynamic operand lists only hold the dynamic
    // entries, so each has its own cursor that advances when a dynamic value
    // is consumed.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(MemRefReshapeOp reshapeOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto *op = reshapeOp.getOperation();
    MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary());
    Type srcType = reshapeOp.source().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(op, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, MemRefReshapeOp reshapeOp,
                                  MemRefReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    // Conversion for statically-known shape args is performed via
    // `memref_reinterpret_cast`.
    auto shapeMemRefType = reshapeOp.shape().getType().cast();
    if (shapeMemRefType.hasStaticShape())
      return failure();

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.shape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
auto targetType = reshapeOp.getResult().getType().cast(); unsigned addressSpace = targetType.getMemorySpace(); Type elementType = targetType.getElementType(); // Create the unranked memref descriptor that holds the ranked one. The // inner descriptor is allocated on stack. auto targetDesc = UnrankedMemRefDescriptor::undef( rewriter, loc, unwrap(typeConverter->convertType(targetType))); targetDesc.setRank(rewriter, loc, resultRank); SmallVector sizes; UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), targetDesc, sizes); Value underlyingDescPtr = rewriter.create( loc, getVoidPtrType(), sizes.front(), llvm::None); targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); // Extract pointers and offset from the source memref. Value allocatedPtr, alignedPtr, offset; extractPointersAndOffset(loc, rewriter, *getTypeConverter(), reshapeOp.source(), adaptor.source(), &allocatedPtr, &alignedPtr, &offset); // Set pointers and offset. LLVM::LLVMType llvmElementType = unwrap(typeConverter->convertType(elementType)); LLVM::LLVMType elementPtrPtrType = llvmElementType.getPointerTo(addressSpace).getPointerTo(); UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, elementPtrPtrType, allocatedPtr); UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrPtrType, alignedPtr); UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrPtrType, offset); // Use the offset pointer as base for further addressing. Copy over the new // shape and compute strides. For this, we create a loop from rank-1 to 0. 
Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrPtrType); Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank); Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); Value oneIndex = createIndexConstant(rewriter, loc, 1); Value resultRankMinusOne = rewriter.create(loc, resultRank, oneIndex); Block *initBlock = rewriter.getInsertionBlock(); LLVM::LLVMType indexType = getTypeConverter()->getIndexType(); Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, {indexType, indexType}); // Iterate over the remaining ops in initBlock and move them to condBlock. BlockAndValueMapping map; for (auto it = remainingOpsIt, e = initBlock->end(); it != e; ++it) { rewriter.clone(*it, map); rewriter.eraseOp(&*it); } rewriter.setInsertionPointToEnd(initBlock); rewriter.create(loc, ValueRange({resultRankMinusOne, oneIndex}), condBlock); rewriter.setInsertionPointToStart(condBlock); Value indexArg = condBlock->getArgument(0); Value strideArg = condBlock->getArgument(1); Value zeroIndex = createIndexConstant(rewriter, loc, 0); Value pred = rewriter.create( loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()), LLVM::ICmpPredicate::sge, indexArg, zeroIndex); Block *bodyBlock = rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); rewriter.setInsertionPointToStart(bodyBlock); // Copy size from shape to descriptor. LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo(); Value sizeLoadGep = rewriter.create( loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); Value size = rewriter.create(loc, sizeLoadGep); UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(), targetSizesBase, indexArg, size); // Write stride value and compute next one. 
UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(), targetStridesBase, indexArg, strideArg); Value nextStride = rewriter.create(loc, strideArg, size); // Decrement loop counter and branch back. Value decrement = rewriter.create(loc, indexArg, oneIndex); rewriter.create(loc, ValueRange({decrement, nextStride}), condBlock); Block *remainder = rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); // Hook up the cond exit to the remainder. rewriter.setInsertionPointToEnd(condBlock); rewriter.create(loc, pred, bodyBlock, llvm::None, remainder, llvm::None); // Reset position to beginning of new remainder block. rewriter.setInsertionPointToStart(remainder); *descriptor = targetDesc; return success(); } }; struct DialectCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(LLVM::DialectCastOp castOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { LLVM::DialectCastOp::Adaptor transformed(operands); if (transformed.in().getType() != typeConverter->convertType(castOp.getType())) { return failure(); } rewriter.replaceOp(castOp, transformed.in()); return success(); } }; // A `dim` is converted to a constant for static sizes and to an access to the // size stored in the memref descriptor for dynamic sizes. 
struct DimOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  /// Replaces `dimOp` with the requested size. The first branch handles the
  /// unranked-memref operand, the second the ranked-memref operand; any other
  /// operand type makes the pattern fail.
  LogicalResult
  matchAndRewrite(DimOp dimOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.memrefOrTensor().getType();
    if (operandType.isa()) {
      rewriter.replaceOp(dimOp, {extractSizeOfUnrankedMemRef(
                                    operandType, dimOp, operands, rewriter)});

      return success();
    }
    if (operandType.isa()) {
      rewriter.replaceOp(dimOp, {extractSizeOfRankedMemRef(
                                    operandType, dimOp, operands, rewriter)});
      return success();
    }
    return failure();
  }

private:
  /// Loads the size at the (possibly dynamic) index from the ranked descriptor
  /// that the unranked descriptor points to, and returns the loaded value.
  Value extractSizeOfUnrankedMemRef(Type operandType, DimOp dimOp,
                                    ArrayRef operands,
                                    ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();
    DimOp::Adaptor transformed(operands);

    auto unrankedMemRefType = operandType.cast();
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    unsigned addressSpace = unrankedMemRefType.getMemorySpace();

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // rank-0 memref descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(transformed.memrefOrTensor());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
    Value scalarMemRefDescPtr = rewriter.create(
        loc,
        typeConverter->convertType(scalarMemRefType)
            .cast()
            .getPointerTo(addressSpace),
        underlyingRankedDesc);

    // Get pointer to offset field of memref descriptor (struct field #2).
    Type indexPtrTy =
        getTypeConverter()->getIndexType().getPointerTo(addressSpace);
    Value two = rewriter.create(
        loc, typeConverter->convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using GEPop with
    // `dimOp.index() + 1` index argument.
    Value idxPlusOne = rewriter.create(
        loc, createIndexConstant(rewriter, loc, 1), transformed.index());
    Value sizePtr = rewriter.create(loc, indexPtrTy, offsetPtr,
                                    ValueRange({idxPlusOne}));
    return rewriter.create(loc, sizePtr);
  }

  /// Returns the size of a ranked memref dimension: an index constant for a
  /// static dimension with a constant index, otherwise a value read from the
  /// converted memref descriptor.
  Value extractSizeOfRankedMemRef(Type operandType, DimOp dimOp,
                                  ArrayRef operands,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();
    DimOp::Adaptor transformed(operands);
    // Take advantage if index is constant.
    MemRefType memRefType = operandType.cast();
    if (Optional index = dimOp.getConstantIndex()) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // extract dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(transformed.memrefOrTensor());
        return descriptor.size(rewriter, loc, i);
      }
      // Use constant for static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    // Dynamic index: use the converted operand, as the unranked path does,
    // rather than the original `dimOp.index()` value, so the emitted ops only
    // consume already-converted values.
    Value index = transformed.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Lowers `RankOp`: for an unranked memref operand the rank is read from the
/// converted unranked descriptor; for a ranked memref operand it becomes an
/// index constant. Other operand types make the pattern fail.
struct RankOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(RankOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.memrefOrTensor().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast()) {
      UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types.
// Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
// `match` succeeds only when `isSupportedMemRefType` accepts the op's memref
// type; derived patterns reach this class through the `Base` alias.
template
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern::isSupportedMemRefType;
  using Base = LoadStoreOpLowering;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isSupportedMemRefType(type) ? success() : failure();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering {
  using Base::Base;

  LogicalResult
  matchAndRewrite(LoadOp loadOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    LoadOp::Adaptor transformed(operands);
    auto type = loadOp.getMemRefType();

    // Address of the element selected by the converted indices.
    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, transformed.memref(),
                             transformed.indices(), rewriter);
    rewriter.replaceOpWithNewOp(loadOp, dataPtr);
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering {
  using Base::Base;

  LogicalResult
  matchAndRewrite(StoreOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();
    StoreOp::Adaptor transformed(operands);

    // Address of the element selected by the converted indices.
    Value dataPtr =
        getStridedElementPtr(op.getLoc(), type, transformed.memref(),
                             transformed.indices(), rewriter);
    rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr);
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering {
  using Base::Base;

  /// Computes the address of the prefetched element and replaces the op with
  /// one taking that address plus three i32 constants built from the op's
  /// `isWrite`, `localityHint` and `isDataCache` values.
  LogicalResult
  matchAndRewrite(PrefetchOp prefetchOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    PrefetchOp::Adaptor transformed(operands);
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, transformed.memref(),
                                         transformed.indices(), rewriter);

    // Replace with llvm.prefetch.
    auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    auto isWrite = rewriter.create(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite()));
    auto localityHint = rewriter.create(
        loc, llvmI32Type,
        rewriter.getI32IntegerAttr(prefetchOp.localityHint()));
    auto isData = rewriter.create(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache()));

    rewriter.replaceOpWithNewOp(prefetchOp, dataPtr, isWrite, localityHint,
                                isData);
    return success();
  }
};

// The lowering of index_cast becomes an integer conversion since index becomes
// an integer. If the bit width of the source and target integer types is the
// same, just erase the cast. If the target type is wider, sign-extend the
// value, otherwise truncate it.
struct IndexCastOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(IndexCastOp indexCastOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    IndexCastOpAdaptor transformed(operands);

    auto targetType =
        typeConverter->convertType(indexCastOp.getResult().getType())
            .cast();
    auto sourceType = transformed.in().getType().cast();
    unsigned targetBits = targetType.getIntegerBitWidth();
    unsigned sourceBits = sourceType.getIntegerBitWidth();

    // Same width: forward the operand. Narrower target: first replacement
    // below; wider target: second replacement below.
    if (targetBits == sourceBits)
      rewriter.replaceOp(indexCastOp, transformed.in());
    else if (targetBits < sourceBits)
      rewriter.replaceOpWithNewOp(indexCastOp, targetType, transformed.in());
    else
      rewriter.replaceOpWithNewOp(indexCastOp, targetType, transformed.in());
    return success();
  }
};

// Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two
// enums share the numerical values so just cast.
template
static LLVMPredType convertCmpPredicate(StdPredType pred) {
  return static_cast(pred);
}

/// Lowers `CmpIOp` to a comparison on the converted operands; the predicate
/// is carried over via `convertCmpPredicate` as an i64 integer attribute.
struct CmpIOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(CmpIOp cmpiOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    CmpIOpAdaptor transformed(operands);

    rewriter.replaceOpWithNewOp(
        cmpiOp, typeConverter->convertType(cmpiOp.getResult().getType()),
        rewriter.getI64IntegerAttr(static_cast(
            convertCmpPredicate(cmpiOp.getPredicate()))),
        transformed.lhs(), transformed.rhs());

    return success();
  }
};

/// Lowers `CmpFOp` the same way `CmpIOpLowering` lowers `CmpIOp`.
struct CmpFOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(CmpFOp cmpfOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    CmpFOpAdaptor transformed(operands);

    rewriter.replaceOpWithNewOp(
        cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
        rewriter.getI64IntegerAttr(static_cast(
            convertCmpPredicate(cmpfOp.getPredicate()))),
        transformed.lhs(), transformed.rhs());

    return success();
  }
};

// One-to-one cast lowerings. Each pattern only inherits the constructors of
// `OneToOneConvertToLLVMPattern` and adds no logic of its own.
struct SIToFPLowering : public OneToOneConvertToLLVMPattern {
  using Super::Super;
};

struct UIToFPLowering : public OneToOneConvertToLLVMPattern {
  using Super::Super;
};

struct FPExtLowering : public OneToOneConvertToLLVMPattern {
  using Super::Super;
};

struct FPToSILowering : public OneToOneConvertToLLVMPattern {
  using Super::Super;
};

struct FPToUILowering : public OneToOneConvertToLLVMPattern {
  using Super::Super;
};

struct FPTruncLowering : public OneToOneConvertToLLVMPattern {
  using Super::Super;
};

struct SignExtendIOpLowering : public OneToOneConvertToLLVMPattern {
  using Super::Super;
};

struct TruncateIOpLowering : public OneToOneConvertToLLVMPattern {
  using Super::Super;
};

struct ZeroExtendIOpLowering : public OneToOneConvertToLLVMPattern {
  using Super::Super;
};

// Base class for LLVM IR lowering terminator operations with successors.
// The replacement op receives the converted operands together with the
// original op's successors and attributes.
template
struct OneToOneLLVMTerminatorLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
  using Super = OneToOneLLVMTerminatorLowering;

  LogicalResult
  matchAndRewrite(SourceOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp(op, operands, op->getSuccessors(),
                                op.getAttrs());
    return success();
  }
};

// Special lowering pattern for `ReturnOps`. Unlike all other operations,
// `ReturnOp` interacts with the function signature and must have as many
// operands as the function has return values. Because in LLVM IR, functions
// can only return 0 or 1 value, we pack multiple values into a structure type.
// Emit `UndefOp` followed by `InsertValueOp`s to create such structure if // necessary before returning it struct ReturnOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(ReturnOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); unsigned numArguments = op.getNumOperands(); SmallVector updatedOperands; if (getTypeConverter()->getOptions().useBarePtrCallConv) { // For the bare-ptr calling convention, extract the aligned pointer to // be returned from the memref descriptor. for (auto it : llvm::zip(op->getOperands(), operands)) { Type oldTy = std::get<0>(it).getType(); Value newOperand = std::get<1>(it); if (oldTy.isa()) { MemRefDescriptor memrefDesc(newOperand); newOperand = memrefDesc.alignedPtr(rewriter, loc); } else if (oldTy.isa()) { // Unranked memref is not supported in the bare pointer calling // convention. return failure(); } updatedOperands.push_back(newOperand); } } else { updatedOperands = llvm::to_vector<4>(operands); copyUnrankedDescriptors(rewriter, loc, *getTypeConverter(), op.getOperands().getTypes(), updatedOperands, /*toDynamic=*/true); } // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), op.getAttrs()); return success(); } if (numArguments == 1) { rewriter.replaceOpWithNewOp( op, TypeRange(), updatedOperands, op.getAttrs()); return success(); } // Otherwise, we need to pack the arguments into an LLVM struct type before // returning. 
auto packedType = getTypeConverter()->packFunctionResults( llvm::to_vector<4>(op.getOperandTypes())); Value packed = rewriter.create(loc, packedType); for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create( loc, packedType, packed, updatedOperands[i], rewriter.getI64ArrayAttr(i)); } rewriter.replaceOpWithNewOp(op, TypeRange(), packed, op.getAttrs()); return success(); } }; // FIXME: this should be tablegen'ed as well. struct BranchOpLowering : public OneToOneLLVMTerminatorLowering { using Super::Super; }; struct CondBranchOpLowering : public OneToOneLLVMTerminatorLowering { using Super::Super; }; // The Splat operation is lowered to an insertelement + a shufflevector // operation. Splat to only 1-d vector result types are lowered. struct SplatOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(SplatOp splatOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() != 1) return failure(); // First insert it into an undef vector so we can shuffle it. auto vectorType = typeConverter->convertType(splatOp.getType()); Value undef = rewriter.create(splatOp.getLoc(), vectorType); auto zero = rewriter.create( splatOp.getLoc(), typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); auto v = rewriter.create( splatOp.getLoc(), vectorType, undef, splatOp.getOperand(), zero); int64_t width = splatOp.getType().cast().getDimSize(0); SmallVector zeroValues(width, 0); // Shuffle the value across the desired number of elements. ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); rewriter.replaceOpWithNewOp(splatOp, v, undef, zeroAttrs); return success(); } }; // The Splat operation is lowered to an insertelement + a shufflevector // operation. 
// Splat to only 2+-d vector result types are lowered by the
// SplatNdOpLowering, the 1-d case is handled by SplatOpLowering.
struct SplatNdOpLowering : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  /// Fails when the result is not a vector, is a rank-1 vector, or when
  /// `extractNDVectorTypeInfo` yields a null array or vector type.
  LogicalResult
  matchAndRewrite(SplatOp splatOp, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    SplatOp::Adaptor adaptor(operands);
    VectorType resultType = splatOp.getType().dyn_cast();
    if (!resultType || resultType.getRank() == 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = splatOp.getLoc();
    auto vectorTypeInfo =
        extractNDVectorTypeInfo(resultType, *getTypeConverter());
    auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
    if (!llvmArrayTy || !llvmVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create(loc, llvmArrayTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create(loc, llvmVectorTy);
    auto zero = rewriter.create(
        loc, typeConverter->convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create(loc, llvmVectorTy, vdesc, adaptor.input(), zero);

    // Shuffle the value across the desired number of elements. The width is
    // the size of the innermost (last) dimension of the result vector.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector zeroValues(width, 0);
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    v = rewriter.create(loc, v, v, zeroAttrs);

    // Iterate over the linear index, convert to coords space and insert the
    // splatted 1-D vector in each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
      desc = rewriter.create(loc, llvmArrayTy, desc, v, position);
    });
    rewriter.replaceOp(splatOp, desc);
    return success();
  }
};

/// Helper function that extracts the int64_t values from `attr`, which is
/// assumed to be an ArrayAttr of IntegerAttr.
static SmallVector extractFromI64ArrayAttr(Attribute attr) { return llvm::to_vector<4>( llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { return a.cast().getInt(); })); } /// Conversion pattern that transforms a subview op into: /// 1. An `llvm.mlir.undef` operation to create a memref descriptor /// 2. Updates to the descriptor to introduce the data ptr, offset, size /// and stride. /// The subview op is replaced by the descriptor. struct SubViewOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(SubViewOp subViewOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = subViewOp.getLoc(); auto sourceMemRefType = subViewOp.source().getType().cast(); auto sourceElementTy = typeConverter->convertType(sourceMemRefType.getElementType()) .dyn_cast_or_null(); auto viewMemRefType = subViewOp.getType(); auto inferredType = SubViewOp::inferResultType( subViewOp.getSourceType(), extractFromI64ArrayAttr(subViewOp.static_offsets()), extractFromI64ArrayAttr(subViewOp.static_sizes()), extractFromI64ArrayAttr(subViewOp.static_strides())) .cast(); auto targetElementTy = typeConverter->convertType(viewMemRefType.getElementType()) .dyn_cast(); auto targetDescTy = typeConverter->convertType(viewMemRefType) .dyn_cast_or_null(); if (!sourceElementTy || !targetDescTy) return failure(); // Extract the offset and strides from the type. int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(inferredType, strides, offset); if (failed(successStrides)) return failure(); // Create the descriptor. if (!operands.front().getType().isa()) return failure(); MemRefDescriptor sourceMemRef(operands.front()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. 
Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); Value bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); // Copy the buffer pointer from the old descriptor to the new one. extracted = sourceMemRef.alignedPtr(rewriter, loc); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()), extracted); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); auto shape = viewMemRefType.getShape(); auto inferredShape = inferredType.getShape(); size_t inferredShapeRank = inferredShape.size(); size_t resultShapeRank = shape.size(); SmallVector mask = computeRankReductionMask(inferredShape, shape).getValue(); // Extract strides needed to compute offset. SmallVector strideValues; strideValues.reserve(inferredShapeRank); for (unsigned i = 0; i < inferredShapeRank; ++i) strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); // Offset. auto llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); if (!ShapedType::isDynamicStrideOrOffset(offset)) { targetMemRef.setConstantOffset(rewriter, loc, offset); } else { Value baseOffset = sourceMemRef.offset(rewriter, loc); for (unsigned i = 0; i < inferredShapeRank; ++i) { Value offset = subViewOp.isDynamicOffset(i) ? operands[subViewOp.getIndexOfDynamicOffset(i)] : rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i))); Value mul = rewriter.create(loc, offset, strideValues[i]); baseOffset = rewriter.create(loc, baseOffset, mul); } targetMemRef.setOffset(rewriter, loc, baseOffset); } // Update sizes and strides. for (int i = inferredShapeRank - 1, j = resultShapeRank - 1; i >= 0 && j >= 0; --i) { if (!mask[i]) continue; Value size = subViewOp.isDynamicSize(i) ? 
operands[subViewOp.getIndexOfDynamicSize(i)] : rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i))); targetMemRef.setSize(rewriter, loc, j, size); Value stride; if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { stride = rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); } else { stride = subViewOp.isDynamicStride(i) ? operands[subViewOp.getIndexOfDynamicStride(i)] : rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(subViewOp.getStaticStride(i))); stride = rewriter.create(loc, stride, strideValues[i]); } targetMemRef.setStride(rewriter, loc, j, stride); j--; } rewriter.replaceOp(subViewOp, {targetMemRef}); return success(); } }; /// Conversion pattern that transforms a transpose op into: /// 1. A function entry `alloca` operation to allocate a ViewDescriptor. /// 2. A load of the ViewDescriptor from the pointer allocated in 1. /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size /// and stride. Size and stride are permutations of the original values. /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. /// The transpose op is replaced by the alloca'ed pointer. class TransposeOpLowering : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(TransposeOp transposeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = transposeOp.getLoc(); TransposeOpAdaptor adaptor(operands); MemRefDescriptor viewMemRef(adaptor.in()); // No permutation, early exit. if (transposeOp.permutation().isIdentity()) return rewriter.replaceOp(transposeOp, {viewMemRef}), success(); auto targetMemRef = MemRefDescriptor::undef( rewriter, loc, typeConverter->convertType(transposeOp.getShapedType())); // Copy the base and aligned pointers from the old descriptor to the new // one. 
targetMemRef.setAllocatedPtr(rewriter, loc, viewMemRef.allocatedPtr(rewriter, loc)); targetMemRef.setAlignedPtr(rewriter, loc, viewMemRef.alignedPtr(rewriter, loc)); // Copy the offset pointer from the old descriptor to the new one. targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc)); // Iterate over the dimensions and apply size/stride permutation. for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) { int sourcePos = en.index(); int targetPos = en.value().cast().getPosition(); targetMemRef.setSize(rewriter, loc, targetPos, viewMemRef.size(rewriter, loc, sourcePos)); targetMemRef.setStride(rewriter, loc, targetPos, viewMemRef.stride(rewriter, loc, sourcePos)); } rewriter.replaceOp(transposeOp, {targetMemRef}); return success(); } }; /// Conversion pattern that transforms an op into: /// 1. An `llvm.mlir.undef` operation to create a memref descriptor /// 2. Updates to the descriptor to introduce the data ptr, offset, size /// and stride. /// The view op is replaced by the descriptor. struct ViewOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; // Build and return the value for the idx^th shape dimension, either by // returning the constant shape dimension or counting the proper dynamic size. Value getSize(ConversionPatternRewriter &rewriter, Location loc, ArrayRef shape, ValueRange dynamicSizes, unsigned idx) const { assert(idx < shape.size()); if (!ShapedType::isDynamic(shape[idx])) return createIndexConstant(rewriter, loc, shape[idx]); // Count the number of dynamic dims in range [0, idx] unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) { return ShapedType::isDynamic(v); }); return dynamicSizes[nDynamic]; } // Build and return the idx^th stride, either by returning the constant stride // or by computing the dynamic stride from the current `runningStride` and // `nextSize`. 
The caller should keep a running stride and update it with the // result returned by this function. Value getStride(ConversionPatternRewriter &rewriter, Location loc, ArrayRef strides, Value nextSize, Value runningStride, unsigned idx) const { assert(idx < strides.size()); if (strides[idx] != MemRefType::getDynamicStrideOrOffset()) return createIndexConstant(rewriter, loc, strides[idx]); if (nextSize) return runningStride ? rewriter.create(loc, runningStride, nextSize) : nextSize; assert(!runningStride); return createIndexConstant(rewriter, loc, 1); } LogicalResult matchAndRewrite(ViewOp viewOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = viewOp.getLoc(); ViewOpAdaptor adaptor(operands); auto viewMemRefType = viewOp.getType(); auto targetElementTy = typeConverter->convertType(viewMemRefType.getElementType()) .dyn_cast(); auto targetDescTy = typeConverter->convertType(viewMemRefType).dyn_cast(); if (!targetDescTy) return viewOp.emitWarning("Target descriptor type not converted to LLVM"), failure(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); if (failed(successStrides)) return viewOp.emitWarning("cannot cast to non-strided shape"), failure(); assert(offset == 0 && "expected offset to be 0"); // Create the descriptor. MemRefDescriptor sourceMemRef(adaptor.source()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Field 1: Copy the allocated pointer, used for malloc/free. Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); auto srcMemRefType = viewOp.source().getType().cast(); Value bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()), allocatedPtr); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); // Field 2: Copy the actual aligned pointer to payload. 
Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc); alignedPtr = rewriter.create(loc, alignedPtr.getType(), alignedPtr, adaptor.byte_shift()); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()), alignedPtr); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); // Field 3: The offset in the resulting type must be 0. This is because of // the type change: an offset on srcType* may not be expressible as an // offset on dstType*. targetMemRef.setOffset(rewriter, loc, createIndexConstant(rewriter, loc, offset)); // Early exit for 0-D corner case. if (viewMemRefType.getRank() == 0) return rewriter.replaceOp(viewOp, {targetMemRef}), success(); // Fields 4 and 5: Update sizes and strides. if (strides.back() != 1) return viewOp.emitWarning("cannot cast to non-contiguous shape"), failure(); Value stride = nullptr, nextSize = nullptr; for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. Value size = getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i); targetMemRef.setSize(rewriter, loc, i, size); // Update stride. stride = getStride(rewriter, loc, strides, nextSize, stride, i); targetMemRef.setStride(rewriter, loc, i, stride); nextSize = size; } rewriter.replaceOp(viewOp, {targetMemRef}); return success(); } }; struct AssumeAlignmentOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(AssumeAlignmentOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { AssumeAlignmentOp::Adaptor transformed(operands); Value memref = transformed.memref(); unsigned alignment = op.alignment(); auto loc = op.getLoc(); MemRefDescriptor memRefDescriptor(memref); Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc()); // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). 
Notice that // the asserted memref.alignedPtr isn't used anywhere else, as the real // users like load/store/views always re-extract memref.alignedPtr as they // get lowered. // // This relies on LLVM's CSE optimization (potentially after SROA), since // after CSE all memref.alignedPtr instances get de-duplicated into the same // pointer SSA value. auto intPtrType = getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace()); Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0); Value mask = createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1); Value ptrValue = rewriter.create(loc, intPtrType, ptr); rewriter.create( loc, rewriter.create( loc, LLVM::ICmpPredicate::eq, rewriter.create(loc, ptrValue, mask), zero)); rewriter.eraseOp(op); return success(); } }; } // namespace /// Try to match the kind of a std.atomic_rmw to determine whether to use a /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg. static Optional matchSimpleAtomicOp(AtomicRMWOp atomicOp) { switch (atomicOp.kind()) { case AtomicRMWKind::addf: return LLVM::AtomicBinOp::fadd; case AtomicRMWKind::addi: return LLVM::AtomicBinOp::add; case AtomicRMWKind::assign: return LLVM::AtomicBinOp::xchg; case AtomicRMWKind::maxs: return LLVM::AtomicBinOp::max; case AtomicRMWKind::maxu: return LLVM::AtomicBinOp::umax; case AtomicRMWKind::mins: return LLVM::AtomicBinOp::min; case AtomicRMWKind::minu: return LLVM::AtomicBinOp::umin; default: return llvm::None; } llvm_unreachable("Invalid AtomicRMWKind"); } namespace { struct AtomicRMWOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (failed(match(atomicOp))) return failure(); auto maybeKind = matchSimpleAtomicOp(atomicOp); if (!maybeKind) return failure(); AtomicRMWOp::Adaptor adaptor(operands); auto resultType = adaptor.value().getType(); auto memRefType = atomicOp.getMemRefType(); auto 
dataPtr = getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.memref(), adaptor.indices(), rewriter); rewriter.replaceOpWithNewOp( atomicOp, resultType, *maybeKind, dataPtr, adaptor.value(), LLVM::AtomicOrdering::acq_rel); return success(); } }; /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be /// retried until it succeeds in atomically storing a new value into memory. /// /// +---------------------------------+ /// | | /// | | /// | br loop(%loaded) | /// +---------------------------------+ /// | /// -------| | /// | v v /// | +--------------------------------+ /// | | loop(%loaded): | /// | | | /// | | %pair = cmpxchg | /// | | %ok = %pair[0] | /// | | %new = %pair[1] | /// | | cond_br %ok, end, loop(%new) | /// | +--------------------------------+ /// | | | /// |----------- | /// v /// +--------------------------------+ /// | end: | /// | | /// +--------------------------------+ /// struct GenericAtomicRMWOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = atomicOp.getLoc(); GenericAtomicRMWOp::Adaptor adaptor(operands); LLVM::LLVMType valueType = typeConverter->convertType(atomicOp.getResult().getType()) .cast(); // Split the block into initial, loop, and ending parts. auto *initBlock = rewriter.getInsertionBlock(); auto *loopBlock = rewriter.createBlock(initBlock->getParent(), std::next(Region::iterator(initBlock)), valueType); auto *endBlock = rewriter.createBlock( loopBlock->getParent(), std::next(Region::iterator(loopBlock))); // Operations range to be moved to `endBlock`. auto opsToMoveStart = atomicOp->getIterator(); auto opsToMoveEnd = initBlock->back().getIterator(); // Compute the loaded value and branch to the loop block. 
rewriter.setInsertionPointToEnd(initBlock); auto memRefType = atomicOp.memref().getType().cast(); auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter); Value init = rewriter.create(loc, dataPtr); rewriter.create(loc, init, loopBlock); // Prepare the body of the loop block. rewriter.setInsertionPointToStart(loopBlock); // Clone the GenericAtomicRMWOp region and extract the result. auto loopArgument = loopBlock->getArgument(0); BlockAndValueMapping mapping; mapping.map(atomicOp.getCurrentValue(), loopArgument); Block &entryBlock = atomicOp.body().front(); for (auto &nestedOp : entryBlock.without_terminator()) { Operation *clone = rewriter.clone(nestedOp, mapping); mapping.map(nestedOp.getResults(), clone->getResults()); } Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0)); // Prepare the epilog of the loop block. // Append the cmpxchg op to the end of the loop block. auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext()); auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType); auto cmpxchg = rewriter.create( loc, pairType, dataPtr, loopArgument, result, successOrdering, failureOrdering); // Extract the %new_loaded and %ok values from the pair. Value newLoaded = rewriter.create( loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0})); Value ok = rewriter.create( loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1})); // Conditionally branch to the end or back to the loop depending on %ok. rewriter.create(loc, ok, endBlock, ArrayRef(), loopBlock, newLoaded); rewriter.setInsertionPointToEnd(endBlock); moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart), std::next(opsToMoveEnd), rewriter); // The 'result' of the atomic_rmw op is the newly loaded value. 
rewriter.replaceOp(atomicOp, {newLoaded}); return success(); } private: // Clones a segment of ops [start, end) and erases the original. void moveOpsRange(ValueRange oldResult, ValueRange newResult, Block::iterator start, Block::iterator end, ConversionPatternRewriter &rewriter) const { BlockAndValueMapping mapping; mapping.map(oldResult, newResult); SmallVector opsToErase; for (auto it = start; it != end; ++it) { rewriter.clone(*it, mapping); opsToErase.push_back(&*it); } for (auto *it : opsToErase) rewriter.eraseOp(it); } }; } // namespace /// Collect a set of patterns to convert from the Standard dialect to LLVM. void mlir::populateStdToLLVMNonMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed // clang-format off patterns.insert< AbsFOpLowering, AddCFOpLowering, AddFOpLowering, AddIOpLowering, AllocaOpLowering, AndOpLowering, AssertOpLowering, AtomicRMWOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CeilFOpLowering, CmpFOpLowering, CmpIOpLowering, CondBranchOpLowering, CopySignOpLowering, CosOpLowering, ConstantOpLowering, CreateComplexOpLowering, DialectCastOpLowering, DivFOpLowering, ExpOpLowering, Exp2OpLowering, FloorFOpLowering, GenericAtomicRMWOpLowering, LogOpLowering, Log10OpLowering, Log2OpLowering, FPExtLowering, FPToSILowering, FPToUILowering, FPTruncLowering, ImOpLowering, IndexCastOpLowering, MulFOpLowering, MulIOpLowering, NegFOpLowering, OrOpLowering, PrefetchOpLowering, ReOpLowering, RemFOpLowering, ReturnOpLowering, RsqrtOpLowering, SIToFPLowering, SelectOpLowering, ShiftLeftOpLowering, SignExtendIOpLowering, SignedDivIOpLowering, SignedRemIOpLowering, SignedShiftRightOpLowering, SinOpLowering, SplatOpLowering, SplatNdOpLowering, SqrtOpLowering, SubCFOpLowering, SubFOpLowering, SubIOpLowering, TruncateIOpLowering, UIToFPLowering, UnsignedDivIOpLowering, UnsignedRemIOpLowering, UnsignedShiftRightOpLowering, XOrOpLowering, 
ZeroExtendIOpLowering>(converter); // clang-format on } void mlir::populateStdToLLVMMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // clang-format off patterns.insert< AssumeAlignmentOpLowering, DeallocOpLowering, DimOpLowering, GlobalMemrefOpLowering, GetGlobalMemrefOpLowering, LoadOpLowering, MemRefCastOpLowering, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, RankOpLowering, StoreOpLowering, SubViewOpLowering, TransposeOpLowering, ViewOpLowering>(converter); // clang-format on if (converter.getOptions().useAlignedAlloc) patterns.insert(converter); else patterns.insert(converter); } void mlir::populateStdToLLVMFuncOpConversionPattern( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { if (converter.getOptions().useBarePtrCallConv) patterns.insert(converter); else patterns.insert(converter); } void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { populateStdToLLVMFuncOpConversionPattern(converter, patterns); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); populateStdToLLVMMemoryConversionPatterns(converter, patterns); } /// Convert a non-empty list of types to be returned from a function into a /// supported LLVM IR type. In particular, if more than one value is returned, /// create an LLVM IR structure type with elements that correspond to each of /// the MLIR types converted with `convertType`. 
Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { assert(!types.empty() && "expected non-empty list of type"); if (types.size() == 1) return convertCallingConventionType(types.front()); SmallVector resultTypes; resultTypes.reserve(types.size()); for (auto t : types) { auto converted = convertCallingConventionType(t).dyn_cast_or_null(); if (!converted) return {}; resultTypes.push_back(converted); } return LLVM::LLVMType::getStructTy(&getContext(), resultTypes); } Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) { auto *context = builder.getContext(); auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext()); auto indexType = IndexType::get(context); // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. auto ptrType = operand.getType().cast().getPointerTo(); Value one = builder.create(loc, int64Ty, IntegerAttr::get(indexType, 1)); Value allocated = builder.create(loc, ptrType, one, /*alignment=*/0); // Store into the alloca'ed descriptor. builder.create(loc, operand, allocated); return allocated; } SmallVector LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder) { SmallVector promotedOperands; promotedOperands.reserve(operands.size()); for (auto it : llvm::zip(opOperands, operands)) { auto operand = std::get<0>(it); auto llvmOperand = std::get<1>(it); if (options.useBarePtrCallConv) { // For the bare-ptr calling convention, we only have to extract the // aligned pointer of a memref. 
if (auto memrefType = operand.getType().dyn_cast()) { MemRefDescriptor desc(llvmOperand); llvmOperand = desc.alignedPtr(builder, loc); } else if (operand.getType().isa()) { llvm_unreachable("Unranked memrefs are not supported"); } } else { if (operand.getType().isa()) { UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, promotedOperands); continue; } if (auto memrefType = operand.getType().dyn_cast()) { MemRefDescriptor::unpack(builder, loc, llvmOperand, operand.getType().cast(), promotedOperands); continue; } } promotedOperands.push_back(llvmOperand); } return promotedOperands; } namespace { /// A pass converting MLIR operations into the LLVM IR dialect. struct LLVMLoweringPass : public ConvertStandardToLLVMBase { LLVMLoweringPass() = default; LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers, unsigned indexBitwidth, bool useAlignedAlloc, const llvm::DataLayout &dataLayout) { this->useBarePtrCallConv = useBarePtrCallConv; this->emitCWrappers = emitCWrappers; this->indexBitwidth = indexBitwidth; this->useAlignedAlloc = useAlignedAlloc; this->dataLayout = dataLayout.getStringRepresentation(); } /// Run the dialect converter on the module. 
void runOnOperation() override { if (useBarePtrCallConv && emitCWrappers) { getOperation().emitError() << "incompatible conversion options: bare-pointer calling convention " "and C wrapper emission"; signalPassFailure(); return; } if (failed(LLVM::LLVMDialect::verifyDataLayoutString( this->dataLayout, [this](const Twine &message) { getOperation().emitError() << message.str(); }))) { signalPassFailure(); return; } ModuleOp m = getOperation(); LowerToLLVMOptions options = {useBarePtrCallConv, emitCWrappers, indexBitwidth, useAlignedAlloc, llvm::DataLayout(this->dataLayout)}; LLVMTypeConverter typeConverter(&getContext(), options); OwningRewritePatternList patterns; populateStdToLLVMConversionPatterns(typeConverter, patterns); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, std::move(patterns)))) signalPassFailure(); m.setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), StringAttr::get(this->dataLayout, m.getContext())); } }; } // end namespace mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { this->addLegalDialect(); this->addIllegalOp(); this->addIllegalOp(); } std::unique_ptr> mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) { return std::make_unique( options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth, options.useAlignedAlloc, options.dataLayout); }