1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24 
25 #include "vktSubgroupsBallotOtherTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27 
28 #include <string>
29 #include <vector>
30 
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35 
36 namespace
37 {
// Ballot built-ins covered by this file. The values are contiguous from zero so
// that test-case creation can iterate over [0, OPTYPE_LAST); OPTYPE_LAST is only
// the count sentinel and does not name an operation.
enum OpType
{
	OPTYPE_INVERSE_BALLOT				= 0,
	OPTYPE_BALLOT_BIT_EXTRACT			= 1,
	OPTYPE_BALLOT_BIT_COUNT				= 2,
	OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT	= 3,
	OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT	= 4,
	OPTYPE_BALLOT_FIND_LSB				= 5,
	OPTYPE_BALLOT_FIND_MSB				= 6,
	OPTYPE_LAST							= 7
};
49 
checkVertexPipelineStages(std::vector<const void * > datas,deUint32 width,deUint32)50 static bool checkVertexPipelineStages(std::vector<const void*> datas,
51 									  deUint32 width, deUint32)
52 {
53 	return vkt::subgroups::check(datas, width, 0xf);
54 }
55 
checkCompute(std::vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)56 static bool checkCompute(std::vector<const void*> datas,
57 						 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
58 						 deUint32)
59 {
60 	return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 0xf);
61 }
62 
getOpTypeName(int opType)63 std::string getOpTypeName(int opType)
64 {
65 	switch (opType)
66 	{
67 		default:
68 			DE_FATAL("Unsupported op type");
69 			return "";
70 		case OPTYPE_INVERSE_BALLOT:
71 			return "subgroupInverseBallot";
72 		case OPTYPE_BALLOT_BIT_EXTRACT:
73 			return "subgroupBallotBitExtract";
74 		case OPTYPE_BALLOT_BIT_COUNT:
75 			return "subgroupBallotBitCount";
76 		case OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT:
77 			return "subgroupBallotInclusiveBitCount";
78 		case OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT:
79 			return "subgroupBallotExclusiveBitCount";
80 		case OPTYPE_BALLOT_FIND_LSB:
81 			return "subgroupBallotFindLSB";
82 		case OPTYPE_BALLOT_FIND_MSB:
83 			return "subgroupBallotFindMSB";
84 	}
85 }
86 
87 struct CaseDefinition
88 {
89 	int					opType;
90 	VkShaderStageFlags	shaderStage;
91 };
92 
getBodySource(CaseDefinition caseDef)93 std::string getBodySource(CaseDefinition caseDef)
94 {
95 	std::ostringstream bdy;
96 
97 	bdy << "  uvec4 allOnes = uvec4(0xFFFFFFFF);\n"
98 		<< "  uvec4 allZeros = uvec4(0);\n"
99 		<< "  uint tempResult = 0;\n"
100 		<< "#define MAKE_HIGH_BALLOT_RESULT(i) uvec4("
101 		<< "i >= 32 ? 0 : (0xFFFFFFFF << i), "
102 		<< "i >= 64 ? 0 : (0xFFFFFFFF << ((i < 32) ? 0 : (i - 32))), "
103 		<< "i >= 96 ? 0 : (0xFFFFFFFF << ((i < 64) ? 0 : (i - 64))), "
104 		<< " 0xFFFFFFFF << ((i < 96) ? 0 : (i - 96)))\n"
105 		<< "#define MAKE_SINGLE_BIT_BALLOT_RESULT(i) uvec4("
106 		<< "i >= 32 ? 0 : 0x1 << i, "
107 		<< "i < 32 || i >= 64 ? 0 : 0x1 << (i - 32), "
108 		<< "i < 64 || i >= 96 ? 0 : 0x1 << (i - 64), "
109 		<< "i < 96 ? 0 : 0x1 << (i - 96))\n";
110 
111 	switch (caseDef.opType)
112 	{
113 		default:
114 			DE_FATAL("Unknown op type!");
115 			break;
116 		case OPTYPE_INVERSE_BALLOT:
117 			bdy << "  tempResult |= subgroupInverseBallot(allOnes) ? 0x1 : 0;\n"
118 				<< "  tempResult |= subgroupInverseBallot(allZeros) ? 0 : 0x2;\n"
119 				<< "  tempResult |= subgroupInverseBallot(subgroupBallot(true)) ? 0x4 : 0;\n"
120 				<< "  tempResult |= 0x8;\n";
121 			break;
122 		case OPTYPE_BALLOT_BIT_EXTRACT:
123 			bdy << "  tempResult |= subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID) ? 0x1 : 0;\n"
124 				<< "  tempResult |= subgroupBallotBitExtract(allZeros, gl_SubgroupInvocationID) ? 0 : 0x2;\n"
125 				<< "  tempResult |= subgroupBallotBitExtract(subgroupBallot(true), gl_SubgroupInvocationID) ? 0x4 : 0;\n"
126 				<< "  tempResult |= 0x8;\n"
127 				<< "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
128 				<< "  {\n"
129 				<< "    if (!subgroupBallotBitExtract(allOnes, gl_SubgroupInvocationID))\n"
130 				<< "    {\n"
131 				<< "      tempResult &= ~0x8;\n"
132 				<< "    }\n"
133 				<< "  }\n";
134 			break;
135 		case OPTYPE_BALLOT_BIT_COUNT:
136 			bdy << "  tempResult |= gl_SubgroupSize == subgroupBallotBitCount(allOnes) ? 0x1 : 0;\n"
137 				<< "  tempResult |= 0 == subgroupBallotBitCount(allZeros) ? 0x2 : 0;\n"
138 				<< "  tempResult |= 0 < subgroupBallotBitCount(subgroupBallot(true)) ? 0x4 : 0;\n"
139 				<< "  tempResult |= 0 == subgroupBallotBitCount(MAKE_HIGH_BALLOT_RESULT(gl_SubgroupSize)) ? 0x8 : 0;\n";
140 			break;
141 		case OPTYPE_BALLOT_INCLUSIVE_BIT_COUNT:
142 			bdy << "  uint inclusiveOffset = gl_SubgroupInvocationID + 1;\n"
143 				<< "  tempResult |= inclusiveOffset == subgroupBallotInclusiveBitCount(allOnes) ? 0x1 : 0;\n"
144 				<< "  tempResult |= 0 == subgroupBallotInclusiveBitCount(allZeros) ? 0x2 : 0;\n"
145 				<< "  tempResult |= 0 < subgroupBallotInclusiveBitCount(subgroupBallot(true)) ? 0x4 : 0;\n"
146 				<< "  tempResult |= 0x8;\n"
147 				<< "  uvec4 inclusiveUndef = MAKE_HIGH_BALLOT_RESULT(inclusiveOffset);\n"
148 				<< "  bool undefTerritory = false;\n"
149 				<< "  for (uint i = 0; i <= 128; i++)\n"
150 				<< "  {\n"
151 				<< "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
152 				<< "    if (iUndef == inclusiveUndef)"
153 				<< "    {\n"
154 				<< "      undefTerritory = true;\n"
155 				<< "    }\n"
156 				<< "    uint inclusiveBitCount = subgroupBallotInclusiveBitCount(iUndef);\n"
157 				<< "    if (undefTerritory && (0 != inclusiveBitCount))\n"
158 				<< "    {\n"
159 				<< "      tempResult &= ~0x8;\n"
160 				<< "    }\n"
161 				<< "    else if (!undefTerritory && (0 == inclusiveBitCount))\n"
162 				<< "    {\n"
163 				<< "      tempResult &= ~0x8;\n"
164 				<< "    }\n"
165 				<< "  }\n";
166 			break;
167 		case OPTYPE_BALLOT_EXCLUSIVE_BIT_COUNT:
168 			bdy << "  uint exclusiveOffset = gl_SubgroupInvocationID;\n"
169 				<< "  tempResult |= exclusiveOffset == subgroupBallotExclusiveBitCount(allOnes) ? 0x1 : 0;\n"
170 				<< "  tempResult |= 0 == subgroupBallotExclusiveBitCount(allZeros) ? 0x2 : 0;\n"
171 				<< "  tempResult |= 0x4;\n"
172 				<< "  tempResult |= 0x8;\n"
173 				<< "  uvec4 exclusiveUndef = MAKE_HIGH_BALLOT_RESULT(exclusiveOffset);\n"
174 				<< "  bool undefTerritory = false;\n"
175 				<< "  for (uint i = 0; i <= 128; i++)\n"
176 				<< "  {\n"
177 				<< "    uvec4 iUndef = MAKE_HIGH_BALLOT_RESULT(i);\n"
178 				<< "    if (iUndef == exclusiveUndef)"
179 				<< "    {\n"
180 				<< "      undefTerritory = true;\n"
181 				<< "    }\n"
182 				<< "    uint exclusiveBitCount = subgroupBallotExclusiveBitCount(iUndef);\n"
183 				<< "    if (undefTerritory && (0 != exclusiveBitCount))\n"
184 				<< "    {\n"
185 				<< "      tempResult &= ~0x4;\n"
186 				<< "    }\n"
187 				<< "    else if (!undefTerritory && (0 == exclusiveBitCount))\n"
188 				<< "    {\n"
189 				<< "      tempResult &= ~0x8;\n"
190 				<< "    }\n"
191 				<< "  }\n";
192 			break;
193 		case OPTYPE_BALLOT_FIND_LSB:
194 			bdy << "  tempResult |= 0 == subgroupBallotFindLSB(allOnes) ? 0x1 : 0;\n"
195 				<< "  if (subgroupElect())\n"
196 				<< "  {\n"
197 				<< "    tempResult |= 0x2;\n"
198 				<< "  }\n"
199 				<< "  else\n"
200 				<< "  {\n"
201 				<< "    tempResult |= 0 < subgroupBallotFindLSB(subgroupBallot(true)) ? 0x2 : 0;\n"
202 				<< "  }\n"
203 				<< "  tempResult |= gl_SubgroupSize > subgroupBallotFindLSB(subgroupBallot(true)) ? 0x4 : 0;\n"
204 				<< "  tempResult |= 0x8;\n"
205 				<< "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
206 				<< "  {\n"
207 				<< "    if (i != subgroupBallotFindLSB(MAKE_HIGH_BALLOT_RESULT(i)))\n"
208 				<< "    {\n"
209 				<< "      tempResult &= ~0x8;\n"
210 				<< "    }\n"
211 				<< "  }\n";
212 			break;
213 		case OPTYPE_BALLOT_FIND_MSB:
214 			bdy << "  tempResult |= (gl_SubgroupSize - 1) == subgroupBallotFindMSB(allOnes) ? 0x1 : 0;\n"
215 				<< "  if (subgroupElect())\n"
216 				<< "  {\n"
217 				<< "    tempResult |= 0x2;\n"
218 				<< "  }\n"
219 				<< "  else\n"
220 				<< "  {\n"
221 				<< "    tempResult |= 0 < subgroupBallotFindMSB(subgroupBallot(true)) ? 0x2 : 0;\n"
222 				<< "  }\n"
223 				<< "  tempResult |= gl_SubgroupSize > subgroupBallotFindMSB(subgroupBallot(true)) ? 0x4 : 0;\n"
224 				<< "  tempResult |= 0x8;\n"
225 				<< "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
226 				<< "  {\n"
227 				<< "    if (i != subgroupBallotFindMSB(MAKE_SINGLE_BIT_BALLOT_RESULT(i)))\n"
228 				<< "    {\n"
229 				<< "      tempResult &= ~0x8;\n"
230 				<< "    }\n"
231 				<< "  }\n";
232 			break;
233 	}
234    return bdy.str();
235 }
236 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)237 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
238 {
239 	const vk::ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
240 
241 	subgroups::setFragmentShaderFrameBuffer(programCollection);
242 
243 	if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
244 		subgroups::setVertexShaderFrameBuffer(programCollection);
245 
246 	std::string bdyStr = getBodySource(caseDef);
247 
248 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
249 	{
250 		std::ostringstream				vertex;
251 		vertex << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
252 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
253 			<< "layout(location = 0) in highp vec4 in_position;\n"
254 			<< "layout(location = 0) out float out_color;\n"
255 			<< "\n"
256 			<< "void main (void)\n"
257 			<< "{\n"
258 			<< bdyStr
259 			<< "  out_color = float(tempResult);\n"
260 			<< "  gl_Position = in_position;\n"
261 			<< "  gl_PointSize = 1.0f;\n"
262 			<< "}\n";
263 		programCollection.glslSources.add("vert")
264 			<< glu::VertexSource(vertex.str()) << buildOptions;
265 	}
266 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
267 	{
268 		std::ostringstream geometry;
269 
270 		geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
271 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
272 			<< "layout(points) in;\n"
273 			<< "layout(points, max_vertices = 1) out;\n"
274 			<< "layout(location = 0) out float out_color;\n"
275 			<< "void main (void)\n"
276 			<< "{\n"
277 			<< bdyStr
278 			<< "  out_color = float(tempResult);\n"
279 			<< "  gl_Position = gl_in[0].gl_Position;\n"
280 			<< "  EmitVertex();\n"
281 			<< "  EndPrimitive();\n"
282 			<< "}\n";
283 
284 		programCollection.glslSources.add("geometry")
285 			<< glu::GeometrySource(geometry.str()) << buildOptions;
286 	}
287 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
288 	{
289 		std::ostringstream controlSource;
290 
291 		controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
292 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
293 			<< "layout(vertices = 2) out;\n"
294 			<< "layout(location = 0) out float out_color[];\n"
295 			<< "\n"
296 			<< "void main (void)\n"
297 			<< "{\n"
298 			<< "  if (gl_InvocationID == 0)\n"
299 			<< "  {\n"
300 			<< "    gl_TessLevelOuter[0] = 1.0f;\n"
301 			<< "    gl_TessLevelOuter[1] = 1.0f;\n"
302 			<< "  }\n"
303 			<< bdyStr
304 			<< "  out_color[gl_InvocationID ] = float(tempResult);\n"
305 			<< "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
306 			<< "}\n";
307 
308 		programCollection.glslSources.add("tesc")
309 			<< glu::TessellationControlSource(controlSource.str()) << buildOptions;
310 		subgroups::setTesEvalShaderFrameBuffer(programCollection);
311 	}
312 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
313 	{
314 		std::ostringstream evaluationSource;
315 		evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
316 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
317 			<< "layout(isolines, equal_spacing, ccw ) in;\n"
318 			<< "layout(location = 0) out float out_color;\n"
319 			<< "void main (void)\n"
320 			<< "{\n"
321 			<< bdyStr
322 			<< "  out_color  = float(tempResult);\n"
323 			<< "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
324 			<< "}\n";
325 
326 		subgroups::setTesCtrlShaderFrameBuffer(programCollection);
327 		programCollection.glslSources.add("tese")
328 			<< glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
329 	}
330 	else
331 	{
332 		DE_FATAL("Unsupported shader stage");
333 	}
334 }
335 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)336 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
337 {
338 	std::string bdyStr = getBodySource(caseDef);
339 
340 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
341 	{
342 		std::ostringstream src;
343 
344 		src << "#version 450\n"
345 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
346 			<< "layout (local_size_x_id = 0, local_size_y_id = 1, "
347 			"local_size_z_id = 2) in;\n"
348 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
349 			<< "{\n"
350 			<< "  uint result[];\n"
351 			<< "};\n"
352 			<< "\n"
353 			<< "void main (void)\n"
354 			<< "{\n"
355 			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
356 			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
357 			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
358 			"gl_GlobalInvocationID.x;\n"
359 			<< bdyStr
360 			<< "  result[offset] = tempResult;\n"
361 			<< "}\n";
362 
363 		programCollection.glslSources.add("comp")
364 				<< glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
365 	}
366 	else
367 	{
368 		const string vertex =
369 			"#version 450\n"
370 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
371 			"layout(set = 0, binding = 0, std430) buffer Buffer1\n"
372 			"{\n"
373 			"  uint result[];\n"
374 			"};\n"
375 			"\n"
376 			"void main (void)\n"
377 			"{\n"
378 			+ bdyStr +
379 			"  result[gl_VertexIndex] = tempResult;\n"
380 			"  float pixelSize = 2.0f/1024.0f;\n"
381 			"  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
382 			"  gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
383 			"  gl_PointSize = 1.0f;\n"
384 			"}\n";
385 
386 		const string tesc =
387 			"#version 450\n"
388 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
389 			"layout(vertices=1) out;\n"
390 			"layout(set = 0, binding = 1, std430) buffer Buffer1\n"
391 			"{\n"
392 			"  uint result[];\n"
393 			"};\n"
394 			"\n"
395 			"void main (void)\n"
396 			"{\n"
397 			+ bdyStr +
398 			"  result[gl_PrimitiveID] = tempResult;\n"
399 			"  if (gl_InvocationID == 0)\n"
400 			"  {\n"
401 			"    gl_TessLevelOuter[0] = 1.0f;\n"
402 			"    gl_TessLevelOuter[1] = 1.0f;\n"
403 			"  }\n"
404 			"  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
405 			"}\n";
406 
407 		const string tese =
408 			"#version 450\n"
409 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
410 			"layout(isolines) in;\n"
411 			"layout(set = 0, binding = 2, std430) buffer Buffer1\n"
412 			"{\n"
413 			"  uint result[];\n"
414 			"};\n"
415 			"\n"
416 			"void main (void)\n"
417 			"{\n"
418 			+ bdyStr +
419 			"  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
420 			"  float pixelSize = 2.0f/1024.0f;\n"
421 			"  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
422 			"}\n";
423 
424 		const string geometry =
425 			"#version 450\n"
426 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
427 			"layout(${TOPOLOGY}) in;\n"
428 			"layout(points, max_vertices = 1) out;\n"
429 			"layout(set = 0, binding = 3, std430) buffer Buffer1\n"
430 			"{\n"
431 			"  uint result[];\n"
432 			"};\n"
433 			"\n"
434 			"void main (void)\n"
435 			"{\n"
436 			+ bdyStr +
437 			"  result[gl_PrimitiveIDIn] = tempResult;\n"
438 			"  gl_Position = gl_in[0].gl_Position;\n"
439 			"  EmitVertex();\n"
440 			"  EndPrimitive();\n"
441 			"}\n";
442 
443 		const string fragment =
444 			"#version 450\n"
445 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
446 			"layout(location = 0) out uint result;\n"
447 			"void main (void)\n"
448 			"{\n"
449 			+ bdyStr +
450 			"  result = tempResult;\n"
451 			"}\n";
452 
453 		subgroups::addNoSubgroupShader(programCollection);
454 
455 		programCollection.glslSources.add("vert")
456 				<< glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
457 		programCollection.glslSources.add("tesc")
458 				<< glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
459 		programCollection.glslSources.add("tese")
460 				<< glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
461 		subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
462 												  programCollection.glslSources);
463 		programCollection.glslSources.add("fragment")
464 				<< glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
465 	}
466 }
467 
supportedCheck(Context & context,CaseDefinition caseDef)468 void supportedCheck (Context& context, CaseDefinition caseDef)
469 {
470 	DE_UNREF(caseDef);
471 	if (!subgroups::isSubgroupSupported(context))
472 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
473 
474 	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_BALLOT_BIT))
475 	{
476 		TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
477 	}
478 }
479 
noSSBOtest(Context & context,const CaseDefinition caseDef)480 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
481 {
482 	if (!subgroups::areSubgroupOperationsSupportedForStage(
483 			context, caseDef.shaderStage))
484 	{
485 		if (subgroups::areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
486 		{
487 			return tcu::TestStatus::fail(
488 					   "Shader stage " +
489 					   subgroups::getShaderStageName(caseDef.shaderStage) +
490 					   " is required to support subgroup operations!");
491 		}
492 		else
493 		{
494 			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
495 		}
496 	}
497 
498 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
499 		return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
500 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
501 		return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
502 	else if ((VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT | VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT) & caseDef.shaderStage)
503 		return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages);
504 	else
505 		TCU_THROW(InternalError, "Unhandled shader stage");
506 }
507 
test(Context & context,const CaseDefinition caseDef)508 tcu::TestStatus test (Context& context, const CaseDefinition caseDef)
509 {
510 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
511 	{
512 		if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
513 		{
514 			return tcu::TestStatus::fail(
515 					   "Shader stage " +
516 				subgroups::getShaderStageName(caseDef.shaderStage) +
517 				" is required to support subgroup operations!");
518 		}
519 		return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkCompute);
520 	}
521 	else
522 	{
523 		VkPhysicalDeviceSubgroupProperties subgroupProperties;
524 		subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
525 		subgroupProperties.pNext = DE_NULL;
526 
527 		VkPhysicalDeviceProperties2 properties;
528 		properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
529 		properties.pNext = &subgroupProperties;
530 
531 		context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
532 
533 		VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage  & subgroupProperties.supportedStages);
534 
535 		if ( VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
536 		{
537 			if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
538 				TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
539 			else
540 				stages = VK_SHADER_STAGE_FRAGMENT_BIT;
541 		}
542 
543 		if ((VkShaderStageFlagBits)0u == stages)
544 			TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
545 
546 		return subgroups::allStages(context, VK_FORMAT_R32_UINT, DE_NULL, 0, checkVertexPipelineStages, stages);
547 	}
548 	return tcu::TestStatus::pass("OK");
549 }
550 }
551 
552 namespace vkt
553 {
554 namespace subgroups
555 {
createSubgroupsBallotOtherTests(tcu::TestContext & testCtx)556 tcu::TestCaseGroup* createSubgroupsBallotOtherTests(tcu::TestContext& testCtx)
557 {
558 	de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
559 		testCtx, "graphics", "Subgroup ballot other category tests: graphics"));
560 	de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
561 		testCtx, "compute", "Subgroup ballot other category tests: compute"));
562 	de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
563 		testCtx, "framebuffer", "Subgroup ballot other category tests: framebuffer"));
564 
565 	const VkShaderStageFlags stages[] =
566 	{
567 		VK_SHADER_STAGE_VERTEX_BIT,
568 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
569 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
570 		VK_SHADER_STAGE_GEOMETRY_BIT,
571 	};
572 
573 	for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
574 	{
575 		const string	op		= de::toLower(getOpTypeName(opTypeIndex));
576 		{
577 			const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT};
578 			addFunctionCaseWithPrograms(computeGroup.get(), op, "", supportedCheck, initPrograms, test, caseDef);
579 		}
580 
581 		{
582 			const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_ALL_GRAPHICS};
583 			addFunctionCaseWithPrograms(graphicGroup.get(), op, "", supportedCheck, initPrograms, test, caseDef);
584 		}
585 
586 		for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
587 		{
588 			const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex]};
589 			addFunctionCaseWithPrograms(framebufferGroup.get(), op + "_" + getShaderStageName(caseDef.shaderStage), "", supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
590 		}
591 	}
592 
593 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
594 		testCtx, "ballot_other", "Subgroup ballot other category tests"));
595 
596 	group->addChild(graphicGroup.release());
597 	group->addChild(computeGroup.release());
598 	group->addChild(framebufferGroup.release());
599 
600 	return group.release();
601 }
602 
603 } // subgroups
604 } // vkt
605