1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24 
25 #include "vktSubgroupsBallotBroadcastTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27 
28 #include <string>
29 #include <vector>
30 
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35 
36 namespace
37 {
//! Subgroup broadcast flavours exercised by this file.
//! OPTYPE_LAST is not an operation: it is the enumerator count, used as a loop bound.
enum OpType
{
	OPTYPE_BROADCAST		= 0,
	OPTYPE_BROADCAST_FIRST	= 1,
	OPTYPE_LAST				= 2
};
44 
checkVertexPipelineStages(std::vector<const void * > datas,deUint32 width,deUint32)45 static bool checkVertexPipelineStages(std::vector<const void*> datas,
46 									  deUint32 width, deUint32)
47 {
48 	return vkt::subgroups::check(datas, width, 3);
49 }
50 
checkCompute(std::vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)51 static bool checkCompute(std::vector<const void*> datas,
52 						 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
53 						 deUint32)
54 {
55 	return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 3);
56 }
57 
getOpTypeName(int opType)58 std::string getOpTypeName(int opType)
59 {
60 	switch (opType)
61 	{
62 		default:
63 			DE_FATAL("Unsupported op type");
64 			return "";
65 		case OPTYPE_BROADCAST:
66 			return "subgroupBroadcast";
67 		case OPTYPE_BROADCAST_FIRST:
68 			return "subgroupBroadcastFirst";
69 	}
70 }
71 
72 struct CaseDefinition
73 {
74 	int					opType;
75 	VkShaderStageFlags	shaderStage;
76 	VkFormat			format;
77 };
78 
getBodySource(CaseDefinition caseDef)79 std::string getBodySource(CaseDefinition caseDef)
80 {
81 	std::ostringstream bdy;
82 
83 	bdy << "  uvec4 mask = subgroupBallot(true);\n";
84 	bdy << "  uint tempResult = 0;\n";
85 
86 	if (OPTYPE_BROADCAST == caseDef.opType)
87 	{
88 		bdy	<< "  tempResult = 0x3;\n";
89 		for (int i = 0; i < (int)subgroups::maxSupportedSubgroupSize(); i++)
90 		{
91 			bdy << "  {\n"
92 			<< "    const uint id = "<< i << ";\n"
93 			<< "    " << subgroups::getFormatNameForGLSL(caseDef.format)
94 			<< " op = subgroupBroadcast(data1[gl_SubgroupInvocationID], id);\n"
95 			<< "    if ((id < gl_SubgroupSize) && subgroupBallotBitExtract(mask, id))\n"
96 			<< "    {\n"
97 			<< "      if (op != data1[id])\n"
98 			<< "      {\n"
99 			<< "        tempResult = 0;\n"
100 			<< "      }\n"
101 			<< "    }\n"
102 			<< "  }\n";
103 		}
104 	}
105 	else
106 	{
107 		bdy	<< "  uint firstActive = 0;\n"
108 			<< "  for (uint i = 0; i < gl_SubgroupSize; i++)\n"
109 			<< "  {\n"
110 			<< "    if (subgroupBallotBitExtract(mask, i))\n"
111 			<< "    {\n"
112 			<< "      firstActive = i;\n"
113 			<< "      break;\n"
114 			<< "    }\n"
115 			<< "  }\n"
116 			<< "  tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x1 : 0;\n"
117 			<< "  // make the firstActive invocation inactive now\n"
118 			<< "  if (firstActive == gl_SubgroupInvocationID)\n"
119 			<< "  {\n"
120 			<< "    for (uint i = 0; i < gl_SubgroupSize; i++)\n"
121 			<< "    {\n"
122 			<< "      if (subgroupBallotBitExtract(mask, i))\n"
123 			<< "      {\n"
124 			<< "        firstActive = i;\n"
125 			<< "        break;\n"
126 			<< "      }\n"
127 			<< "    }\n"
128 			<< "    tempResult |= (subgroupBroadcastFirst(data1[gl_SubgroupInvocationID]) == data1[firstActive]) ? 0x2 : 0;\n"
129 			<< "  }\n"
130 			<< "  else\n"
131 			<< "  {\n"
132 			<< "    // the firstActive invocation didn't partake in the second result so set it to true\n"
133 			<< "    tempResult |= 0x2;\n"
134 			<< "  }\n";
135 	}
136    return bdy.str();
137 }
138 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)139 void initFrameBufferPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
140 {
141 	const vk::ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
142 
143 	subgroups::setFragmentShaderFrameBuffer(programCollection);
144 
145 	if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
146 		subgroups::setVertexShaderFrameBuffer(programCollection);
147 
148 	std::string bdyStr = getBodySource(caseDef);
149 
150 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
151 	{
152 		std::ostringstream				vertex;
153 		vertex << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
154 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
155 			<< "layout(location = 0) in highp vec4 in_position;\n"
156 			<< "layout(location = 0) out float out_color;\n"
157 			<< "layout(set = 0, binding = 0) uniform  Buffer1\n"
158 			<< "{\n"
159 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" << subgroups::maxSupportedSubgroupSize() << "];\n"
160 			<< "};\n"
161 			<< "\n"
162 			<< "void main (void)\n"
163 			<< "{\n"
164 			<< bdyStr
165 			<< "  out_color = float(tempResult);\n"
166 			<< "  gl_Position = in_position;\n"
167 			<< "  gl_PointSize = 1.0f;\n"
168 			<< "}\n";
169 		programCollection.glslSources.add("vert")
170 			<< glu::VertexSource(vertex.str()) << buildOptions;
171 	}
172 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
173 	{
174 		std::ostringstream geometry;
175 
176 		geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
177 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
178 			<< "layout(points) in;\n"
179 			<< "layout(points, max_vertices = 1) out;\n"
180 			<< "layout(location = 0) out float out_color;\n"
181 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
182 			<< "{\n"
183 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" <<subgroups::maxSupportedSubgroupSize() << "];\n"
184 			<< "};\n"
185 			<< "\n"
186 			<< "void main (void)\n"
187 			<< "{\n"
188 			<< bdyStr
189 			<< "  out_color = float(tempResult);\n"
190 			<< "  gl_Position = gl_in[0].gl_Position;\n"
191 			<< "  EmitVertex();\n"
192 			<< "  EndPrimitive();\n"
193 			<< "}\n";
194 
195 		programCollection.glslSources.add("geometry")
196 			<< glu::GeometrySource(geometry.str()) << buildOptions;
197 	}
198 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
199 	{
200 		std::ostringstream controlSource;
201 
202 		controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
203 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
204 			<< "layout(vertices = 2) out;\n"
205 			<< "layout(location = 0) out float out_color[];\n"
206 			<< "layout(set = 0, binding = 0) uniform Buffer2\n"
207 			<< "{\n"
208 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" <<subgroups::maxSupportedSubgroupSize() << "];\n"
209 			<< "};\n"
210 			<< "\n"
211 			<< "void main (void)\n"
212 			<< "{\n"
213 			<< "  if (gl_InvocationID == 0)\n"
214 			<< "  {\n"
215 			<< "    gl_TessLevelOuter[0] = 1.0f;\n"
216 			<< "    gl_TessLevelOuter[1] = 1.0f;\n"
217 			<< "  }\n"
218 			<< bdyStr
219 			<< "  out_color[gl_InvocationID ] = float(tempResult);\n"
220 			<< "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
221 			<< "}\n";
222 
223 		programCollection.glslSources.add("tesc")
224 			<< glu::TessellationControlSource(controlSource.str()) << buildOptions;
225 		subgroups::setTesEvalShaderFrameBuffer(programCollection);
226 	}
227 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
228 	{
229 		std::ostringstream evaluationSource;
230 		evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
231 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
232 			<< "layout(isolines, equal_spacing, ccw ) in;\n"
233 			<< "layout(location = 0) out float out_color;\n"
234 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
235 			<< "{\n"
236 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[" <<subgroups::maxSupportedSubgroupSize() << "];\n"
237 			<< "};\n"
238 			<< "\n"
239 			<< "void main (void)\n"
240 			<< "{\n"
241 			<< bdyStr
242 			<< "  out_color  = float(tempResult);\n"
243 			<< "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
244 			<< "}\n";
245 
246 		subgroups::setTesCtrlShaderFrameBuffer(programCollection);
247 		programCollection.glslSources.add("tese")
248 			<< glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
249 	}
250 	else
251 	{
252 		DE_FATAL("Unsupported shader stage");
253 	}
254 }
255 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)256 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
257 {
258 	std::string bdyStr = getBodySource(caseDef);
259 
260 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
261 	{
262 		std::ostringstream src;
263 
264 		src << "#version 450\n"
265 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
266 			<< "layout (local_size_x_id = 0, local_size_y_id = 1, "
267 			"local_size_z_id = 2) in;\n"
268 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
269 			<< "{\n"
270 			<< "  uint result[];\n"
271 			<< "};\n"
272 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
273 			<< "{\n"
274 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data1[];\n"
275 			<< "};\n"
276 			<< "\n"
277 			<< "void main (void)\n"
278 			<< "{\n"
279 			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
280 			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
281 			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
282 			"gl_GlobalInvocationID.x;\n"
283 			<< bdyStr
284 			<< "  result[offset] = tempResult;\n"
285 			<< "}\n";
286 
287 		programCollection.glslSources.add("comp")
288 				<< glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
289 	}
290 	else
291 	{
292 		const string vertex =
293 			"#version 450\n"
294 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
295 			"layout(set = 0, binding = 0, std430) buffer Buffer1\n"
296 			"{\n"
297 			"  uint result[];\n"
298 			"};\n"
299 			"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
300 			"{\n"
301 			"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data1[];\n"
302 			"};\n"
303 			"\n"
304 			"void main (void)\n"
305 			"{\n"
306 			+ bdyStr +
307 			"  result[gl_VertexIndex] = tempResult;\n"
308 			"  float pixelSize = 2.0f/1024.0f;\n"
309 			"  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
310 			"  gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
311 			"  gl_PointSize = 1.0f;\n"
312 			"}\n";
313 
314 		const string tesc =
315 			"#version 450\n"
316 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
317 			"layout(vertices=1) out;\n"
318 			"layout(set = 0, binding = 1, std430) buffer Buffer1\n"
319 			"{\n"
320 			"  uint result[];\n"
321 			"};\n"
322 			"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
323 			"{\n"
324 			"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data1[];\n"
325 			"};\n"
326 			"\n"
327 			"void main (void)\n"
328 			"{\n"
329 			+ bdyStr +
330 			"  result[gl_PrimitiveID] = tempResult;\n"
331 			"  if (gl_InvocationID == 0)\n"
332 			"  {\n"
333 			"    gl_TessLevelOuter[0] = 1.0f;\n"
334 			"    gl_TessLevelOuter[1] = 1.0f;\n"
335 			"  }\n"
336 			"  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
337 			"}\n";
338 
339 		const string tese =
340 			"#version 450\n"
341 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
342 			"layout(isolines) in;\n"
343 			"layout(set = 0, binding = 2, std430) buffer Buffer1\n"
344 			"{\n"
345 			"  uint result[];\n"
346 			"};\n"
347 			"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
348 			"{\n"
349 			"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data1[];\n"
350 			"};\n"
351 			"\n"
352 			"void main (void)\n"
353 			"{\n"
354 			+ bdyStr +
355 			"  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
356 			"  float pixelSize = 2.0f/1024.0f;\n"
357 			"  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
358 			"}\n";
359 
360 		const string geometry =
361 			"#version 450\n"
362 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
363 			"layout(${TOPOLOGY}) in;\n"
364 			"layout(points, max_vertices = 1) out;\n"
365 			"layout(set = 0, binding = 3, std430) buffer Buffer1\n"
366 			"{\n"
367 			"  uint result[];\n"
368 			"};\n"
369 			"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
370 			"{\n"
371 			"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data1[];\n"
372 			"};\n"
373 			"\n"
374 			"void main (void)\n"
375 			"{\n"
376 			+ bdyStr +
377 			"  result[gl_PrimitiveIDIn] = tempResult;\n"
378 			"  gl_Position = gl_in[0].gl_Position;\n"
379 			"  EmitVertex();\n"
380 			"  EndPrimitive();\n"
381 			"}\n";
382 
383 		const string fragment =
384 			"#version 450\n"
385 			"#extension GL_KHR_shader_subgroup_ballot: enable\n"
386 			"layout(location = 0) out uint result;\n"
387 			"layout(set = 0, binding = 4, std430) readonly buffer Buffer1\n"
388 			"{\n"
389 			"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data1[];\n"
390 			"};\n"
391 			"void main (void)\n"
392 			"{\n"
393 			+ bdyStr +
394 			"  result = tempResult;\n"
395 			"}\n";
396 
397 		subgroups::addNoSubgroupShader(programCollection);
398 
399 		programCollection.glslSources.add("vert")
400 				<< glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
401 		programCollection.glslSources.add("tesc")
402 				<< glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
403 		programCollection.glslSources.add("tese")
404 				<< glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
405 		subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
406 												  programCollection.glslSources);
407 		programCollection.glslSources.add("fragment")
408 				<< glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
409 	}
410 }
411 
supportedCheck(Context & context,CaseDefinition caseDef)412 void supportedCheck (Context& context, CaseDefinition caseDef)
413 {
414 	if (!subgroups::isSubgroupSupported(context))
415 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
416 
417 	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_BALLOT_BIT))
418 	{
419 		TCU_THROW(NotSupportedError, "Device does not support subgroup ballot operations");
420 	}
421 
422 	if (subgroups::isDoubleFormat(caseDef.format) &&
423 		!subgroups::isDoubleSupportedForDevice(context))
424 	{
425 		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
426 	}
427 }
428 
noSSBOtest(Context & context,const CaseDefinition caseDef)429 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
430 {
431 	if (!subgroups::areSubgroupOperationsSupportedForStage(
432 			context, caseDef.shaderStage))
433 	{
434 		if (subgroups::areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
435 		{
436 			return tcu::TestStatus::fail(
437 					   "Shader stage " +
438 					   subgroups::getShaderStageName(caseDef.shaderStage) +
439 					   " is required to support subgroup operations!");
440 		}
441 		else
442 		{
443 			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
444 		}
445 	}
446 
447 	subgroups::SSBOData inputData[1];
448 	inputData[0].format = caseDef.format;
449 	inputData[0].layout = subgroups::SSBOData::LayoutStd140;
450 	inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
451 	inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
452 
453 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
454 		return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 1, checkVertexPipelineStages);
455 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
456 		return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 1, checkVertexPipelineStages);
457 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
458 		return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
459 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
460 		return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
461 	else
462 		TCU_THROW(InternalError, "Unhandled shader stage");
463 }
464 
465 
test(Context & context,const CaseDefinition caseDef)466 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
467 {
468 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
469 	{
470 		if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
471 		{
472 			if (subgroups::areSubgroupOperationsRequiredForStage(caseDef.shaderStage))
473 			{
474 				return tcu::TestStatus::fail(
475 						   "Shader stage " +
476 						   subgroups::getShaderStageName(caseDef.shaderStage) +
477 						   " is required to support subgroup operations!");
478 			}
479 			else
480 			{
481 				TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
482 			}
483 		}
484 		subgroups::SSBOData inputData[1];
485 		inputData[0].format = caseDef.format;
486 		inputData[0].layout = subgroups::SSBOData::LayoutStd430;
487 		inputData[0].numElements = subgroups::maxSupportedSubgroupSize();
488 		inputData[0].initializeType = subgroups::SSBOData::InitializeNonZero;
489 
490 		return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, inputData, 1, checkCompute);
491 	}
492 	else
493 	{
494 		VkPhysicalDeviceSubgroupProperties subgroupProperties;
495 		subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
496 		subgroupProperties.pNext = DE_NULL;
497 
498 		VkPhysicalDeviceProperties2 properties;
499 		properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
500 		properties.pNext = &subgroupProperties;
501 
502 		context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
503 
504 		VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage  & subgroupProperties.supportedStages);
505 
506 		if ( VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
507 		{
508 			if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
509 				TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
510 			else
511 				stages = VK_SHADER_STAGE_FRAGMENT_BIT;
512 		}
513 
514 		if ((VkShaderStageFlagBits)0u == stages)
515 			TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
516 
517 		subgroups::SSBOData inputData;
518 		inputData.format			= caseDef.format;
519 		inputData.layout			= subgroups::SSBOData::LayoutStd430;
520 		inputData.numElements		= subgroups::maxSupportedSubgroupSize();
521 		inputData.initializeType	= subgroups::SSBOData::InitializeNonZero;
522 		inputData.binding			= 4u;
523 		inputData.stages			= stages;
524 
525 		return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
526 	}
527 }
528 }
529 
530 namespace vkt
531 {
532 namespace subgroups
533 {
createSubgroupsBallotBroadcastTests(tcu::TestContext & testCtx)534 tcu::TestCaseGroup* createSubgroupsBallotBroadcastTests(tcu::TestContext& testCtx)
535 {
536 	de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
537 		testCtx, "graphics", "Subgroup ballot broadcast category tests: graphics"));
538 	de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
539 		testCtx, "compute", "Subgroup ballot broadcast category tests: compute"));
540 	de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
541 		testCtx, "framebuffer", "Subgroup ballot broadcast category tests: framebuffer"));
542 
543 	const VkShaderStageFlags stages[] =
544 	{
545 		VK_SHADER_STAGE_VERTEX_BIT,
546 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
547 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
548 		VK_SHADER_STAGE_GEOMETRY_BIT,
549 	};
550 
551 	const VkFormat formats[] =
552 	{
553 		VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
554 		VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
555 		VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
556 		VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
557 		VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
558 		VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
559 		VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
560 		VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
561 		VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
562 	};
563 
564 	for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
565 	{
566 		const VkFormat format = formats[formatIndex];
567 
568 		for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
569 		{
570 			const std::string op = de::toLower(getOpTypeName(opTypeIndex));
571 			const std::string name = op + "_" + subgroups::getFormatNameForGLSL(format);
572 
573 			{
574 				CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT, format};
575 				addFunctionCaseWithPrograms(computeGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
576 			}
577 
578 			{
579 				const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_ALL_GRAPHICS, format};
580 				addFunctionCaseWithPrograms(graphicGroup.get(), name, "", supportedCheck, initPrograms, test, caseDef);
581 			}
582 
583 			for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
584 			{
585 				const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format};
586 				addFunctionCaseWithPrograms(framebufferGroup.get(), name + getShaderStageName(caseDef.shaderStage), "",
587 							supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
588 			}
589 		}
590 	}
591 
592 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
593 		testCtx, "ballot_broadcast", "Subgroup ballot broadcast category tests"));
594 
595 	group->addChild(graphicGroup.release());
596 	group->addChild(computeGroup.release());
597 	group->addChild(framebufferGroup.release());
598 	return group.release();
599 }
600 
601 } // subgroups
602 } // vkt
603