//===- SPIRVLowering.cpp - SPIR-V lowering utilities ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities used to lower to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"

#include <functional>

#define DEBUG_TYPE "mlir-spirv-lowering"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Checks that the `candidates` extension requirements can be satisfied with
/// the given `targetEnv`.
///
/// `candidates` is a vector of vectors of extension requirements, following
/// the ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
/// convention.
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
    LabelT label, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
  for (const auto &ors : candidates) {
    if (targetEnv.allows(ors))
      continue;

    SmallVector<StringRef, 4> extStrings;
    for (spirv::Extension ext : ors)
      extStrings.push_back(spirv::stringifyExtension(ext));

    LLVM_DEBUG(llvm::dbgs()
               << label << " illegal: requires at least one extension in ["
               << llvm::join(extStrings, ", ")
               << "] but none allowed in target environment\n");
    return failure();
  }
  return success();
}

/// Checks that the `candidates` capability requirements can be satisfied with
/// the given `targetEnv`.
///
/// `candidates` is a vector of vectors of capability requirements, following
/// the ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
/// convention.
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
    LabelT label, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
  for (const auto &ors : candidates) {
    if (targetEnv.allows(ors))
      continue;

    SmallVector<StringRef, 4> capStrings;
    for (spirv::Capability cap : ors)
      capStrings.push_back(spirv::stringifyCapability(cap));

    LLVM_DEBUG(llvm::dbgs()
               << label << " illegal: requires at least one capability in ["
               << llvm::join(capStrings, ", ")
               << "] but none allowed in target environment\n");
    return failure();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//

Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
  // Convert to 32-bit integers for now. Might need a way to control this in
  // the future.
  // TODO: It is probably better to make it 64-bit integers. To do this, some
  // support is needed in the SPIR-V dialect for Conversion instructions. The
  // Vulkan spec requires builtins like GlobalInvocationID, etc. to be 32-bit
  // (unsigned) integers, which should be SExtended to 64-bit for index
  // computations.
  return IntegerType::get(32, context);
}
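// Illustrative note (not part of the lowering itself): given the i32 choice
// above, an `index`-typed value such as `%c0 = constant 0 : index` on the
// standard side is expected to materialize as an i32 value, e.g.
// `%c0 = spv.constant 0 : i32`, after conversion.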
/// Mapping from SPIR-V storage classes to memref memory spaces.
///
/// Note: memref does not have defined semantics for each memory space; they
/// depend on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow the NVVM conventions and
/// largely give common storage classes a smaller number. The hope is to use a
/// symbolic memory space representation eventually, once memref supports it.
// TODO: swap the Generic and StorageBuffer assignments to be more akin to
// NVVM.
#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
  MAP_FN(spirv::StorageClass::Image, 13)                                       \
  MAP_FN(spirv::StorageClass::CallableDataNV, 14)                              \
  MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15)                      \
  MAP_FN(spirv::StorageClass::RayPayloadNV, 16)                                \
  MAP_FN(spirv::StorageClass::HitAttributeNV, 17)                              \
  MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18)                        \
  MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19)                        \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)

unsigned
SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
  case storage:                                                                \
    return space;

  switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
#undef STORAGE_SPACE_MAP_FN
  llvm_unreachable("unhandled storage class!");
}

Optional<spirv::StorageClass>
SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
  case space:                                                                  \
    return storage;

  switch (space) {
    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
  default:
    return llvm::None;
  }
#undef STORAGE_SPACE_MAP_FN
}

#undef STORAGE_SPACE_MAP_LIST
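// Illustrative example of the mapping above: a memref in the default memory
// space 0, e.g. `memref<16xf32>`, is treated as residing in the StorageBuffer
// storage class, while `memref<16xf32, 3>` maps to the Workgroup storage
// class. The numbers carry no semantics beyond this table.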
// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
static Optional<int64_t> getTypeNumBytes(Type t) {
  if (t.isa<spirv::ScalarType>()) {
    auto bitWidth = t.getIntOrFloatBitWidth();
    // According to the SPIR-V spec:
    // "There is no physical size or bit pattern defined for values with boolean
    // type. If they are stored (in conjunction with OpVariable), they can only
    // be used with logical addressing operations, not physical, and only with
    // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
    // Private, Function, Input, and Output."
    if (bitWidth == 1) {
      return llvm::None;
    }
    return bitWidth / 8;
  }

  if (auto vecType = t.dyn_cast<VectorType>()) {
    auto elementSize = getTypeNumBytes(vecType.getElementType());
    if (!elementSize)
      return llvm::None;
    return vecType.getNumElements() * *elementSize;
  }

  if (auto memRefType = t.dyn_cast<MemRefType>()) {
    // TODO: Layout should also be controlled by the ABI attributes. For now
    // using the layout from MemRef.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (!memRefType.hasStaticShape() ||
        failed(getStridesAndOffset(memRefType, strides, offset))) {
      return llvm::None;
    }

    // To get the size of the memref object in memory, the total size is the
    // max(stride * dimension-size) computed for all dimensions times the size
    // of the element.
    auto elementSize = getTypeNumBytes(memRefType.getElementType());
    if (!elementSize) {
      return llvm::None;
    }

    if (memRefType.getRank() == 0) {
      return elementSize;
    }

    auto dims = memRefType.getShape();
    if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
        offset == MemRefType::getDynamicStrideOrOffset() ||
        llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
      return llvm::None;
    }

    int64_t memrefSize = -1;
    for (auto shape : enumerate(dims)) {
      memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
    }

    return (offset + memrefSize) * elementSize.getValue();
  } else if (auto tensorType = t.dyn_cast<TensorType>()) {
    if (!tensorType.hasStaticShape()) {
      return llvm::None;
    }
    auto elementSize = getTypeNumBytes(tensorType.getElementType());
    if (!elementSize) {
      return llvm::None;
    }

    int64_t size = elementSize.getValue();
    for (auto shape : tensorType.getShape()) {
      size *= shape;
    }
    return size;
  }

  // TODO: Add size computation for other types.
  return llvm::None;
}

Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
  return getTypeNumBytes(t);
}
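// Illustrative size computation following getTypeNumBytes(): for
// `memref<2x3xf32>` the default strides are [3, 1] and the offset is 0, so
// memrefSize = max(2 * 3, 3 * 1) = 6 and the result is (0 + 6) * 4 = 24 bytes;
// `vector<4xf32>` is simply 4 * 4 = 16 bytes.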
/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
static Optional<Type>
convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
                  Optional<spirv::StorageClass> storageClass = {}) {
  // Get extension and capability requirements for the given type.
  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
  type.getExtensions(extensions, storageClass);
  type.getCapabilities(capabilities, storageClass);

  // If all requirements are met, then we can accept this type as-is.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
    return type;

  // Otherwise we need to adjust the type, which really means adjusting the
  // bitwidth given this is a scalar type.
  //
  // TODO: We are unconditionally converting the bitwidth here. This might be
  // okay for non-interface types (i.e., types used in Private/Function storage
  // classes), but not for interface types (i.e., types used in
  // StorageBuffer/Uniform/PushConstant/etc. storage classes), because the
  // latter actually affects the ABI contract with the runtime. So we may want
  // to expose a control on SPIRVTypeConverter to fail the conversion if we
  // cannot change the bitwidth there.

  if (auto floatType = type.dyn_cast<FloatType>()) {
    LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
    return Builder(targetEnv.getContext()).getF32Type();
  }

  auto intType = type.cast<IntegerType>();
  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
  return IntegerType::get(/*width=*/32, intType.getSignedness(),
                          targetEnv.getContext());
}

/// Converts a vector `type` to a suitable type under the given `targetEnv`.
static Optional<Type>
convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
                  Optional<spirv::StorageClass> storageClass = {}) {
  if (!spirv::CompositeType::isValid(type)) {
    // TODO: One-element vector types can be translated into scalar types.
    // Vector types with more than four elements can be translated into array
    // types.
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: 1- and > 4-element unimplemented\n");
    return llvm::None;
  }

  // Get extension and capability requirements for the given type.
  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
  type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
  type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);

  // If all requirements are met, then we can accept this type as-is.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
    return type;

  auto elementType = convertScalarType(
      targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
  if (elementType)
    return VectorType::get(type.getShape(), *elementType);
  return llvm::None;
}

/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
///
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
/// create composite constants with OpConstantComposite to embed relatively
/// large constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate them, as we do for vectors.
static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
                                        TensorType type) {
  // TODO: Handle dynamic shapes.
  if (!type.hasStaticShape()) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: dynamic shape unimplemented\n");
    return llvm::None;
  }

  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
  if (!scalarType) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert non-scalar element type\n");
    return llvm::None;
  }

  Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
  Optional<int64_t> tensorSize = getTypeNumBytes(type);
  if (!scalarSize || !tensorSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return llvm::None;
  }

  auto arrayElemCount = *tensorSize / *scalarSize;
  auto arrayElemType = convertScalarType(targetEnv, scalarType);
  if (!arrayElemType)
    return llvm::None;
  Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return llvm::None;
  }

  return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
}
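// Illustrative example for convertTensorType(): assuming f32 is supported by
// the target environment, `tensor<2x3xf32>` is 24 bytes with a 4-byte element
// type, so it becomes a 6-element array with a 4-byte stride, i.e. roughly
// `!spv.array<6 x f32, stride=4>` in SPIR-V dialect syntax.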
static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
                                        MemRefType type) {
  Optional<spirv::StorageClass> storageClass =
      SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
  if (!storageClass) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert memory space\n");
    return llvm::None;
  }

  Optional<Type> arrayElemType;
  Type elementType = type.getElementType();
  if (auto vecType = elementType.dyn_cast<VectorType>()) {
    arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
  } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
    arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
  } else {
    LLVM_DEBUG(
        llvm::dbgs()
        << type
        << " unhandled: can only convert scalar or vector element type\n");
    return llvm::None;
  }
  if (!arrayElemType)
    return llvm::None;

  Optional<int64_t> elementSize = getTypeNumBytes(elementType);
  if (!elementSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element size\n");
    return llvm::None;
  }

  if (!type.hasStaticShape()) {
    auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
    // Wrap in a struct to satisfy Vulkan interface requirements.
    auto structType = spirv::StructType::get(arrayType, 0);
    return spirv::PointerType::get(structType, *storageClass);
  }

  Optional<int64_t> memrefSize = getTypeNumBytes(type);
  if (!memrefSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return llvm::None;
  }

  auto arrayElemCount = *memrefSize / *elementSize;
  Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return llvm::None;
  }

  auto arrayType =
      spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);

  // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
  // workgroup storage class do not need the struct to be laid out explicitly.
  auto structType = *storageClass == spirv::StorageClass::Workgroup
                        ? spirv::StructType::get(arrayType)
                        : spirv::StructType::get(arrayType, 0);
  return spirv::PointerType::get(structType, *storageClass);
}
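// Illustrative example for convertMemrefType(): with memory space 0 mapping to
// StorageBuffer, a static `memref<2x3xf32>` becomes roughly
//   !spv.ptr<!spv.struct<!spv.array<6 x f32, stride=4> [0]>, StorageBuffer>
// while a dynamically shaped `memref<?xf32>` uses a runtime array instead:
//   !spv.ptr<!spv.struct<!spv.rtarray<f32, stride=4> [0]>, StorageBuffer>
// (exact syntax may differ; see the SPIR-V dialect type definitions).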
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
    : targetEnv(targetAttr) {
  // Add conversions. The order matters here: later ones will be tried earlier.

  // All other cases failed; we cannot convert this type.
  addConversion([](Type type) { return llvm::None; });

  // Allow all SPIR-V dialect specific types. This assumes all builtin types
  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
  // were tried before.
  //
  // TODO: this assumes that the SPIR-V types are valid to use in the given
  // target environment, which should be the case if the whole pipeline is
  // driven by the same target environment. Still, we probably want to validate
  // and convert to be safe.
  addConversion([](spirv::SPIRVType type) { return type; });

  addConversion([](IndexType indexType) {
    return SPIRVTypeConverter::getIndexType(indexType.getContext());
  });

  addConversion([this](IntegerType intType) -> Optional<Type> {
    if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  addConversion([this](FloatType floatType) -> Optional<Type> {
    if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  addConversion([this](VectorType vectorType) {
    return convertVectorType(targetEnv, vectorType);
  });

  addConversion([this](TensorType tensorType) {
    return convertTensorType(targetEnv, tensorType);
  });

  addConversion([this](MemRefType memRefType) {
    return convertMemrefType(targetEnv, memRefType);
  });
}

//===----------------------------------------------------------------------===//
// FuncOp Conversion Patterns
//===----------------------------------------------------------------------===//

namespace {
/// A pattern for rewriting function signatures to convert the arguments of
/// functions to valid SPIR-V types.
class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
public:
  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  auto fnType = funcOp.getType();
  // TODO: support converting functions with one result.
  if (fnType.getNumResults())
    return failure();

  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
  for (auto argType : enumerate(funcOp.getType().getInputs())) {
    auto convertedType = typeConverter.convertType(argType.value());
    if (!convertedType)
      return failure();
    signatureConverter.addInputs(argType.index(), convertedType);
  }

  // Create the converted spv.func op.
  auto newFuncOp = rewriter.create<spirv::FuncOp>(
      funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                               llvm::None));

  // Copy over all attributes other than the function name and type.
  for (const auto &namedAttr : funcOp.getAttrs()) {
    if (namedAttr.first != impl::getTypeAttrName() &&
        namedAttr.first != SymbolTable::getSymbolAttrName())
      newFuncOp.setAttr(namedAttr.first, namedAttr.second);
  }

  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                              newFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
                                         &signatureConverter)))
    return failure();
  rewriter.eraseOp(funcOp);
  return success();
}

void mlir::populateBuiltinFuncToSPIRVPatterns(
    MLIRContext *context, SPIRVTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<FuncOpConversion>(context, typeConverter);
}

//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//

static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
                                                  spirv::BuiltIn builtin) {
  // Look through all global variables in the given `body` block and check if
  // there is a spv.globalVariable that has the same `builtin` attribute.
  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
    if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
            spirv::SPIRVDialect::getAttributeName(
                spirv::Decoration::BuiltIn))) {
      auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
      if (varBuiltIn && varBuiltIn.getValue() == builtin) {
        return varOp;
      }
    }
  }
  return nullptr;
}

/// Gets the name of the global variable for a builtin.
static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
  return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
}
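// Illustrative example: getBuiltinVarName(spirv::BuiltIn::GlobalInvocationId)
// returns "__builtin_var_GlobalInvocationId__", which becomes the symbol name
// of the generated spv.globalVariable.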
/// Gets or inserts a global variable for a builtin within the `body` block.
static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
                           OpBuilder &builder) {
  if (auto varOp = getBuiltinVariable(body, builtin))
    return varOp;

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&body);

  spirv::GlobalVariableOp newVarOp;
  switch (builtin) {
  case spirv::BuiltIn::NumWorkgroups:
  case spirv::BuiltIn::WorkgroupSize:
  case spirv::BuiltIn::WorkgroupId:
  case spirv::BuiltIn::LocalInvocationId:
  case spirv::BuiltIn::GlobalInvocationId: {
    auto ptrType = spirv::PointerType::get(
        VectorType::get({3}, builder.getIntegerType(32)),
        spirv::StorageClass::Input);
    std::string name = getBuiltinVarName(builtin);
    newVarOp =
        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
    break;
  }
  case spirv::BuiltIn::SubgroupId:
  case spirv::BuiltIn::NumSubgroups:
  case spirv::BuiltIn::SubgroupSize: {
    auto ptrType = spirv::PointerType::get(builder.getIntegerType(32),
                                           spirv::StorageClass::Input);
    std::string name = getBuiltinVarName(builtin);
    newVarOp =
        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
    break;
  }
  default:
    emitError(loc, "unimplemented builtin variable generation for ")
        << stringifyBuiltIn(builtin);
  }
  return newVarOp;
}

Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                           spirv::BuiltIn builtin,
                                           OpBuilder &builder) {
  Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
  if (!parent) {
    op->emitError("expected operation to be within a module-like op");
    return nullptr;
  }

  spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
      *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
}

//===----------------------------------------------------------------------===//
// Index calculation
//===----------------------------------------------------------------------===//

spirv::AccessChainOp mlir::spirv::getElementPtr(
    SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
    ValueRange indices, Location loc, OpBuilder &builder) {
  // Get the strides and offset of the MemRefType and verify they are static.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
      offset == MemRefType::getDynamicStrideOrOffset()) {
    return nullptr;
  }

  auto indexType = typeConverter.getIndexType(builder.getContext());

  SmallVector<Value, 2> linearizedIndices;
  // Add a '0' at the start to index into the struct.
  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
  linearizedIndices.push_back(zero);

  if (baseType.getRank() == 0) {
    linearizedIndices.push_back(zero);
  } else {
    // TODO: Instead of this logic, use affine.apply and add patterns for
    // lowering affine.apply to standard ops. These will get lowered to SPIR-V
    // ops by the DialectConversion framework.
    Value ptrLoc = builder.create<spirv::ConstantOp>(
        loc, indexType, IntegerAttr::get(indexType, offset));

    assert(indices.size() == strides.size() &&
           "must provide indices for all dimensions");
    for (auto index : llvm::enumerate(indices)) {
      Value strideVal = builder.create<spirv::ConstantOp>(
          loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
      Value update =
          builder.create<spirv::IMulOp>(loc, strideVal, index.value());
      ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update);
    }
    linearizedIndices.push_back(ptrLoc);
  }
  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
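// Illustrative example for getElementPtr(): for a `memref<2x3xf32>` base with
// strides [3, 1] and offset 0, and indices (%i, %j), the linearized index is
// 0 + %i * 3 + %j * 1, and the generated op is roughly
//   spv.AccessChain %base[%zero, %linearized]
// where the leading zero steps into the wrapping struct.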
//===----------------------------------------------------------------------===//
// Set ABI attributes for lowering entry functions.
//===----------------------------------------------------------------------===//

LogicalResult
mlir::spirv::setABIAttrs(spirv::FuncOp funcOp,
                         spirv::EntryPointABIAttr entryPointInfo,
                         ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
  // Set the attributes for the arguments and the function.
  StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
  for (auto argIndex : llvm::seq<int64_t>(0, argABIInfo.size())) {
    funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
  }
  funcOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
  return success();
}

//===----------------------------------------------------------------------===//
// SPIR-V ConversionTarget
//===----------------------------------------------------------------------===//

std::unique_ptr<spirv::SPIRVConversionTarget>
spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
  std::unique_ptr<SPIRVConversionTarget> target(
      // std::make_unique does not work here because the constructor is
      // private.
      new SPIRVConversionTarget(targetAttr));
  SPIRVConversionTarget *targetPtr = target.get();
  target->addDynamicallyLegalDialect<SPIRVDialect>(
      // We need to capture the raw pointer here because it is stable: the
      // local `target` will be destroyed once this function returns.
      [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
  return target;
}
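// Illustrative usage sketch (not code in this file; `module`, `context`, and
// `targetAttr` are assumed to be provided by the surrounding conversion pass):
//
//   SPIRVTypeConverter typeConverter(targetAttr);
//   OwningRewritePatternList patterns;
//   populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
//   auto target = spirv::SPIRVConversionTarget::get(targetAttr);
//   if (failed(applyPartialConversion(module, *target, patterns)))
//     signalPassFailure();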
spirv::SPIRVConversionTarget::SPIRVConversionTarget(
    spirv::TargetEnvAttr targetAttr)
    : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}

bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
  // Make sure this op is available at the given version. Ops not implementing
  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
  // SPIR-V versions.
  if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
    if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring min version "
                 << spirv::stringifyVersion(minVersion.getMinVersion())
                 << "\n");
      return false;
    }
  if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
    if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring max version "
                 << spirv::stringifyVersion(maxVersion.getMaxVersion())
                 << "\n");
      return false;
    }

  // Make sure this op's required extensions are allowed to use. Ops not
  // implementing QueryExtensionInterface do not require extensions to be
  // available.
  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          extensions.getExtensions())))
      return false;

  // Make sure this op's required capabilities are allowed to use. Ops not
  // implementing QueryCapabilityInterface do not require capabilities to be
  // available.
  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           capabilities.getCapabilities())))
      return false;

  SmallVector<Type, 4> valueTypes;
  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
  valueTypes.append(op->result_type_begin(), op->result_type_end());

  // Special treatment for global variables, whose type requirements are
  // conveyed by type attributes.
  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
    valueTypes.push_back(globalVar.type());

  // Make sure the op's operands/results use types that are allowed by the
  // target environment.
  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
  for (Type valueType : valueTypes) {
    typeExtensions.clear();
    valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          typeExtensions)))
      return false;

    typeCapabilities.clear();
    valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           typeCapabilities)))
      return false;
  }

  return true;
}