/*------------------------------------------------------------------------ * Vulkan Conformance Tests * ------------------------ * * Copyright (c) 2017 The Khronos Group Inc. * Copyright (c) 2017 Codeplay Software Ltd. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ /*! * \file * \brief Subgroups Tests */ /*--------------------------------------------------------------------*/ /* NOTE(review): this copy is collapsed onto eight physical lines and every angle-bracketed span appears to have been stripped in extraction: the two bare include directives, the element type of std::vector in the check callbacks, the template argument of de::MovePtr, and everything between the end of the vertex-shader branch of initFrameBufferPrograms and the first TestCaseGroup of the test-creation function. End-of-line comments now also swallow the remainder of their physical line. The code tokens are left untouched here; restore the lost spans from a pristine copy before building. */#include "vktSubgroupsClusteredTests.hpp" #include "vktSubgroupsTestsUtils.hpp" #include #include using namespace tcu; using namespace std; using namespace vk; using namespace vkt; namespace { /* Clustered subgroup operations covered by these tests. OPTYPE_CLUSTERED_LAST is not an operation: it is the exclusive upper bound of the opTypeIndex loop in the test-creation code. */enum OpType { OPTYPE_CLUSTERED_ADD = 0, OPTYPE_CLUSTERED_MUL, OPTYPE_CLUSTERED_MIN, OPTYPE_CLUSTERED_MAX, OPTYPE_CLUSTERED_AND, OPTYPE_CLUSTERED_OR, OPTYPE_CLUSTERED_XOR, OPTYPE_CLUSTERED_LAST }; /* Result-check callback: forwards datas and width to vkt::subgroups::check with the literal 1 as last argument (presumably the value every invocation must have written; confirm in vktSubgroupsTestsUtils). The unnamed third parameter is ignored. */static bool checkVertexPipelineStages(std::vector datas, deUint32 width, deUint32) { return vkt::subgroups::check(datas, width, 1); } /* Compute counterpart: forwards datas, numWorkgroups and localSize to vkt::subgroups::checkCompute, again with the literal 1; the unnamed fourth parameter is ignored. */static bool checkCompute(std::vector datas, const deUint32 numWorkgroups[3], const deUint32 localSize[3], deUint32) { return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 1); } /* Returns the GLSL built-in name for an OpType value, e.g. subgroupClusteredAdd for OPTYPE_CLUSTERED_ADD. Any other value triggers DE_FATAL and yields an empty string. */std::string getOpTypeName(int opType) { switch (opType) { default: DE_FATAL("Unsupported op type"); return ""; case OPTYPE_CLUSTERED_ADD: return "subgroupClusteredAdd"; case OPTYPE_CLUSTERED_MUL: return "subgroupClusteredMul"; case OPTYPE_CLUSTERED_MIN: return "subgroupClusteredMin"; case OPTYPE_CLUSTERED_MAX: return "subgroupClusteredMax"; case OPTYPE_CLUSTERED_AND: return "subgroupClusteredAnd"; 
case OPTYPE_CLUSTERED_OR: return "subgroupClusteredOr"; case OPTYPE_CLUSTERED_XOR: return "subgroupClusteredXor"; } } /* Returns GLSL source combining lhs and rhs the way opType does, for use in the reference loop of getBodySource. Add and mul use + and *. Min and max use min() and max(), except that for float formats a NaN operand is replaced by the other operand (ternaries for scalars, mix() with isnan() for vectors). And, or and xor use the bitwise operators, except for the R8 USCALED formats (treated as bool types in this file), where the logical operators are applied per component. */std::string getOpTypeOperation(int opType, vk::VkFormat format, std::string lhs, std::string rhs) { switch (opType) { default: DE_FATAL("Unsupported op type"); return ""; case OPTYPE_CLUSTERED_ADD: return lhs + " + " + rhs; case OPTYPE_CLUSTERED_MUL: return lhs + " * " + rhs; case OPTYPE_CLUSTERED_MIN: switch (format) { default: return "min(" + lhs + ", " + rhs + ")"; case VK_FORMAT_R32_SFLOAT: case VK_FORMAT_R64_SFLOAT: return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : min(" + lhs + ", " + rhs + ")))"; case VK_FORMAT_R32G32_SFLOAT: case VK_FORMAT_R32G32B32_SFLOAT: case VK_FORMAT_R32G32B32A32_SFLOAT: case VK_FORMAT_R64G64_SFLOAT: case VK_FORMAT_R64G64B64_SFLOAT: case VK_FORMAT_R64G64B64A64_SFLOAT: return "mix(mix(min(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))"; } case OPTYPE_CLUSTERED_MAX: switch (format) { default: return "max(" + lhs + ", " + rhs + ")"; case VK_FORMAT_R32_SFLOAT: case VK_FORMAT_R64_SFLOAT: return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? 
" + lhs + " : max(" + lhs + ", " + rhs + ")))"; case VK_FORMAT_R32G32_SFLOAT: case VK_FORMAT_R32G32B32_SFLOAT: case VK_FORMAT_R32G32B32A32_SFLOAT: case VK_FORMAT_R64G64_SFLOAT: case VK_FORMAT_R64G64B64_SFLOAT: case VK_FORMAT_R64G64B64A64_SFLOAT: return "mix(mix(max(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))"; } case OPTYPE_CLUSTERED_AND: switch (format) { default: return lhs + " & " + rhs; case VK_FORMAT_R8_USCALED: return lhs + " && " + rhs; case VK_FORMAT_R8G8_USCALED: return "bvec2(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y)"; case VK_FORMAT_R8G8B8_USCALED: return "bvec3(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z)"; case VK_FORMAT_R8G8B8A8_USCALED: return "bvec4(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z, " + lhs + ".w && " + rhs + ".w)"; } case OPTYPE_CLUSTERED_OR: switch (format) { default: return lhs + " | " + rhs; case VK_FORMAT_R8_USCALED: return lhs + " || " + rhs; case VK_FORMAT_R8G8_USCALED: return "bvec2(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y)"; case VK_FORMAT_R8G8B8_USCALED: return "bvec3(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z)"; case VK_FORMAT_R8G8B8A8_USCALED: return "bvec4(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z, " + lhs + ".w || " + rhs + ".w)"; } case OPTYPE_CLUSTERED_XOR: switch (format) { default: return lhs + " ^ " + rhs; case VK_FORMAT_R8_USCALED: return lhs + " ^^ " + rhs; case VK_FORMAT_R8G8_USCALED: return "bvec2(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y)"; case VK_FORMAT_R8G8B8_USCALED: return "bvec3(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z)"; case VK_FORMAT_R8G8B8A8_USCALED: return "bvec4(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + 
".z ^^ " + rhs + ".z, " + lhs + ".w ^^ " + rhs + ".w)"; } } } /* Returns GLSL source for the value the reference accumulator starts from: the format type constructed from 0 for add, or and xor, from 1 for mul and from ~0 for and. For min it is intBitsToFloat(0x7f800000) for floats, 0x7fffffff for signed and 0xffffffffu for unsigned formats; for max it is intBitsToFloat(0xff800000), 0x80000000 and 0 respectively. Min or max on an R8 USCALED format, and any unlisted format, reaches DE_FATAL. */std::string getIdentity(int opType, vk::VkFormat format) { bool isFloat = false; bool isInt = false; bool isUnsigned = false; switch (format) { default: DE_FATAL("Unhandled format!"); break; case VK_FORMAT_R32_SINT: case VK_FORMAT_R32G32_SINT: case VK_FORMAT_R32G32B32_SINT: case VK_FORMAT_R32G32B32A32_SINT: isInt = true; break; case VK_FORMAT_R32_UINT: case VK_FORMAT_R32G32_UINT: case VK_FORMAT_R32G32B32_UINT: case VK_FORMAT_R32G32B32A32_UINT: isUnsigned = true; break; case VK_FORMAT_R32_SFLOAT: case VK_FORMAT_R32G32_SFLOAT: case VK_FORMAT_R32G32B32_SFLOAT: case VK_FORMAT_R32G32B32A32_SFLOAT: case VK_FORMAT_R64_SFLOAT: case VK_FORMAT_R64G64_SFLOAT: case VK_FORMAT_R64G64B64_SFLOAT: case VK_FORMAT_R64G64B64A64_SFLOAT: isFloat = true; break; case VK_FORMAT_R8_USCALED: case VK_FORMAT_R8G8_USCALED: case VK_FORMAT_R8G8B8_USCALED: case VK_FORMAT_R8G8B8A8_USCALED: break; // bool types are not anything } switch (opType) { default: DE_FATAL("Unsupported op type"); return ""; case OPTYPE_CLUSTERED_ADD: return subgroups::getFormatNameForGLSL(format) + "(0)"; case OPTYPE_CLUSTERED_MUL: return subgroups::getFormatNameForGLSL(format) + "(1)"; case OPTYPE_CLUSTERED_MIN: if (isFloat) { return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))"; } else if (isInt) { return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)"; } else if (isUnsigned) { return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)"; } else { DE_FATAL("Unhandled case"); return ""; } case OPTYPE_CLUSTERED_MAX: if (isFloat) { return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))"; } else if (isInt) { return subgroups::getFormatNameForGLSL(format) + "(0x80000000)"; } else if (isUnsigned) { return subgroups::getFormatNameForGLSL(format) + "(0)"; } else { DE_FATAL("Unhandled case"); return ""; } case OPTYPE_CLUSTERED_AND: return subgroups::getFormatNameForGLSL(format) + "(~0)"; case 
OPTYPE_CLUSTERED_OR: return subgroups::getFormatNameForGLSL(format) + "(0)"; case OPTYPE_CLUSTERED_XOR: return subgroups::getFormatNameForGLSL(format) + "(0)"; } } /* Returns a GLSL boolean expression that is true when lhs and rhs match. Scalar bool, uint and int formats use ==; float formats use == (all(equal()) for vectors) for min and max and an absolute difference below 0.00001 for the other ops; every remaining format uses all(equal()). */std::string getCompare(int opType, vk::VkFormat format, std::string lhs, std::string rhs) { std::string formatName = subgroups::getFormatNameForGLSL(format); switch (format) { default: return "all(equal(" + lhs + ", " + rhs + "))"; case VK_FORMAT_R8_USCALED: case VK_FORMAT_R32_UINT: case VK_FORMAT_R32_SINT: return "(" + lhs + " == " + rhs + ")"; case VK_FORMAT_R32_SFLOAT: case VK_FORMAT_R64_SFLOAT: switch (opType) { default: return "(abs(" + lhs + " - " + rhs + ") < 0.00001)"; case OPTYPE_CLUSTERED_MIN: case OPTYPE_CLUSTERED_MAX: return "(" + lhs + " == " + rhs + ")"; } case VK_FORMAT_R32G32_SFLOAT: case VK_FORMAT_R32G32B32_SFLOAT: case VK_FORMAT_R32G32B32A32_SFLOAT: case VK_FORMAT_R64G64_SFLOAT: case VK_FORMAT_R64G64B64_SFLOAT: case VK_FORMAT_R64G64B64A64_SFLOAT: switch (opType) { default: return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))"; case OPTYPE_CLUSTERED_MIN: case OPTYPE_CLUSTERED_MAX: return "all(equal(" + lhs + ", " + rhs + "))"; } } } /* One test case: the OpType value, the shader stage flags and the data format. */struct CaseDefinition { int opType; VkShaderStageFlags shaderStage; VkFormat format; }; /* Builds the GLSL statements that validate the clustered built-in. For each power-of-two clusterSize from 1 up to subgroups::maxSupportedSubgroupSize(), guarded at run time by clusterSize <= gl_SubgroupSize, the shader calls the built-in once, then for every cluster folds data[index] of the invocations set in mask into ref starting from getIdentity(), and clears tempResult if the cluster containing the current invocation disagrees according to getCompare(). The surrounding shader has to declare data and mask. NOTE(review): initFrameBufferPrograms, which follows on the next physical line, calls subgroups::setFragmentShaderFrameBuffer, calls subgroups::setVertexShaderFrameBuffer unless the stage is the vertex stage, and for the vertex stage wraps this body in a GLSL 450 vertex shader; its remainder is missing from this copy. */std::string getBodySource(CaseDefinition caseDef) { std::ostringstream bdy; bdy << " bool tempResult = true;\n"; for (deUint32 i = 1; i <= subgroups::maxSupportedSubgroupSize(); i *= 2) { bdy << " {\n" << " const uint clusterSize = " << i << ";\n" << " if (clusterSize <= gl_SubgroupSize)\n" << " {\n" << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = " << getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID], clusterSize);\n" << " for (uint clusterOffset = 0; clusterOffset < gl_SubgroupSize; clusterOffset += clusterSize)\n" << " {\n" << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = " << getIdentity(caseDef.opType, caseDef.format) << ";\n" << " for (uint index = clusterOffset; index < 
(clusterOffset + clusterSize); index++)\n" << " {\n" << " if (subgroupBallotBitExtract(mask, index))\n" << " {\n" << " ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n" << " }\n" << " }\n" << " if ((clusterOffset <= gl_SubgroupInvocationID) && (gl_SubgroupInvocationID < (clusterOffset + clusterSize)))\n" << " {\n" << " if (!" << getCompare(caseDef.opType, caseDef.format, "ref", "op") << ")\n" << " {\n" << " tempResult = false;\n" << " }\n" << " }\n" << " }\n" << " }\n" << " }\n"; } return bdy.str(); } void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef) { const vk::ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u); subgroups::setFragmentShaderFrameBuffer(programCollection); if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage) subgroups::setVertexShaderFrameBuffer(programCollection); std::string bdy = getBodySource(caseDef); if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage) { std::ostringstream vertexSrc; vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450 )<< "\n" << "#extension GL_KHR_shader_subgroup_clustered: enable\n" << "#extension GL_KHR_shader_subgroup_ballot: enable\n" << "layout(location = 0) in highp vec4 in_position;\n" << "layout(location = 0) out float out_color;\n" << "layout(set = 0, binding = 0) uniform Buffer1\n" << "{\n" << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n" << "};\n" << "\n" << "void main (void)\n" << "{\n" << " uvec4 mask = subgroupBallot(true);\n" << bdy << " out_color = float(tempResult ? 
1 : 0);\n" << " gl_Position = in_position;\n" << " gl_PointSize = 1.0f;\n" << "}\n"; programCollection.glslSources.add("vert") << glu::VertexSource(vertexSrc.str()) < graphicGroup(new tcu::TestCaseGroup( testCtx, "graphics", "Subgroup clustered category tests: graphics")); de::MovePtr computeGroup(new tcu::TestCaseGroup( testCtx, "compute", "Subgroup clustered category tests: compute")); de::MovePtr framebufferGroup(new tcu::TestCaseGroup( testCtx, "framebuffer", "Subgroup clustered category tests: framebuffer")); const VkShaderStageFlags stages[] = { VK_SHADER_STAGE_VERTEX_BIT, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, VK_SHADER_STAGE_GEOMETRY_BIT }; const VkFormat formats[] = { VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT, VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT, VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT, VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT, VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT, VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT, VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT, VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED, VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED, }; for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex) { const VkFormat format = formats[formatIndex]; for (int opTypeIndex = 0; opTypeIndex < OPTYPE_CLUSTERED_LAST; ++opTypeIndex) { bool isBool = false; bool isFloat = false; switch (format) { default: break; case VK_FORMAT_R32_SFLOAT: case VK_FORMAT_R32G32_SFLOAT: case VK_FORMAT_R32G32B32_SFLOAT: case VK_FORMAT_R32G32B32A32_SFLOAT: case VK_FORMAT_R64_SFLOAT: case VK_FORMAT_R64G64_SFLOAT: case VK_FORMAT_R64G64B64_SFLOAT: case VK_FORMAT_R64G64B64A64_SFLOAT: isFloat = true; break; case VK_FORMAT_R8_USCALED: case VK_FORMAT_R8G8_USCALED: case VK_FORMAT_R8G8B8_USCALED: case VK_FORMAT_R8G8B8A8_USCALED: isBool = true; break; } bool isBitwiseOp = false; switch (opTypeIndex) 
{ default: break; case OPTYPE_CLUSTERED_AND: case OPTYPE_CLUSTERED_OR: case OPTYPE_CLUSTERED_XOR: isBitwiseOp = true; break; } /* Float formats skip and, or and xor; bool formats run only those three. Each remaining format and op pair then registers one compute case, one VK_SHADER_STAGE_ALL_GRAPHICS case and one framebuffer case per entry of stages[]. */if (isFloat && isBitwiseOp) { // Skip float with bitwise category. continue; } if (isBool && !isBitwiseOp) { // Skip bool when its not the bitwise category. continue; } const std::string name = de::toLower(getOpTypeName(opTypeIndex)) +"_" + subgroups::getFormatNameForGLSL(format); { const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT, format}; addFunctionCaseWithPrograms(computeGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef); } { const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_ALL_GRAPHICS, format}; addFunctionCaseWithPrograms(graphicGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef); } for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex) { const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format}; addFunctionCaseWithPrograms(framebufferGroup.get(), name +"_" + getShaderStageName(caseDef.shaderStage), "", supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef); } } } de::MovePtr group(new tcu::TestCaseGroup( testCtx, "clustered", "Subgroup clustered category tests")); group->addChild(graphicGroup.release()); group->addChild(computeGroup.release()); group->addChild(framebufferGroup.release()); return group.release(); } } // subgroups } // vkt