//===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/TargetAndABI.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/FunctionSupport.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" using namespace mlir; //===----------------------------------------------------------------------===// // TargetEnv //===----------------------------------------------------------------------===// spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr) : targetAttr(targetAttr) { for (spirv::Extension ext : targetAttr.getExtensions()) givenExtensions.insert(ext); // Add extensions implied by the current version. for (spirv::Extension ext : spirv::getImpliedExtensions(targetAttr.getVersion())) givenExtensions.insert(ext); for (spirv::Capability cap : targetAttr.getCapabilities()) { givenCapabilities.insert(cap); // Add capabilities implied by the current capability. for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap)) givenCapabilities.insert(c); } } spirv::Version spirv::TargetEnv::getVersion() const { return targetAttr.getVersion(); } bool spirv::TargetEnv::allows(spirv::Capability capability) const { return givenCapabilities.count(capability); } Optional spirv::TargetEnv::allows(ArrayRef caps) const { const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) { return givenCapabilities.count(cap); }); if (chosen != caps.end()) return *chosen; return llvm::None; } bool spirv::TargetEnv::allows(spirv::Extension extension) const { return givenExtensions.count(extension); } Optional spirv::TargetEnv::allows(ArrayRef exts) const { const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) { return givenExtensions.count(ext); }); if (chosen != exts.end()) return *chosen; return llvm::None; } spirv::Vendor spirv::TargetEnv::getVendorID() const { return targetAttr.getVendorID(); } spirv::DeviceType spirv::TargetEnv::getDeviceType() const { return targetAttr.getDeviceType(); } uint32_t spirv::TargetEnv::getDeviceID() const { return targetAttr.getDeviceID(); } spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const { return targetAttr.getResourceLimits(); } MLIRContext *spirv::TargetEnv::getContext() const { return targetAttr.getContext(); } //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// StringRef spirv::getInterfaceVarABIAttrName() { return "spv.interface_var_abi"; } spirv::InterfaceVarABIAttr spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, Optional storageClass, MLIRContext *context) { return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass, context); } bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) { for (spirv::Capability cap : targetAttr.getCapabilities()) { if (cap == spirv::Capability::Kernel) return false; if (cap == spirv::Capability::Shader) return true; } return false; } StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; } spirv::EntryPointABIAttr spirv::getEntryPointABIAttr(ArrayRef localSize, MLIRContext *context) { assert(localSize.size() == 3); return spirv::EntryPointABIAttr::get( DenseElementsAttr::get( VectorType::get(3, IntegerType::get(32, context)), localSize) .cast(), context); } spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) { while (op && !op->hasTrait()) op = op->getParentOp(); if (!op) return {}; if (auto attr = op->getAttrOfType( spirv::getEntryPointABIAttrName())) return attr; return {}; } DenseIntElementsAttr spirv::lookupLocalWorkGroupSize(Operation *op) { if (auto entryPoint = spirv::lookupEntryPointABI(op)) return entryPoint.local_size(); return {}; } spirv::ResourceLimitsAttr spirv::getDefaultResourceLimits(MLIRContext *context) { // All the fields have default values. Here we just provide a nicer way to // construct a default resource limit attribute. return spirv::ResourceLimitsAttr ::get( /*max_compute_shared_memory_size=*/nullptr, /*max_compute_workgroup_invocations=*/nullptr, /*max_compute_workgroup_size=*/nullptr, /*subgroup_size=*/nullptr, /*cooperative_matrix_properties_nv=*/nullptr, context); } StringRef spirv::getTargetEnvAttrName() { return "spv.target_env"; } spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) { auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0, {spirv::Capability::Shader}, ArrayRef(), context); return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown, spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID, spirv::getDefaultResourceLimits(context)); } spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) { while (op) { op = SymbolTable::getNearestSymbolTable(op); if (!op) break; if (auto attr = op->getAttrOfType( spirv::getTargetEnvAttrName())) return attr; op = op->getParentOp(); } return {}; } spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) { if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) return attr; return getDefaultTargetEnv(op->getContext()); } spirv::AddressingModel spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr) { for (spirv::Capability cap : targetAttr.getCapabilities()) { // TODO: Physical64 is hard-coded here, but some information should come // from TargetEnvAttr to selected between Physical32 and Physical64. if (cap == Capability::Kernel) return spirv::AddressingModel::Physical64; } // Logical addressing doesn't need any capabilities so return it as default. return spirv::AddressingModel::Logical; } FailureOr spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) { for (spirv::Capability cap : targetAttr.getCapabilities()) { if (cap == spirv::Capability::Kernel) return spirv::ExecutionModel::Kernel; if (cap == spirv::Capability::Shader) return spirv::ExecutionModel::GLCompute; } return failure(); } FailureOr spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) { for (spirv::Capability cap : targetAttr.getCapabilities()) { if (cap == spirv::Capability::Addresses) return spirv::MemoryModel::OpenCL; if (cap == spirv::Capability::Shader) return spirv::MemoryModel::GLSL450; } return failure(); }