1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24 
25 #include "vktSubgroupsVoteTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27 
28 #include <string>
29 #include <vector>
30 
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35 
36 namespace
37 {
// Vote operation exercised by a generated test case. Test generation iterates
// the values 0 .. OPTYPE_LAST-1 directly, so they must stay contiguous and
// start at zero; OPTYPE_LAST is only the count sentinel.
enum OpType
{
	OPTYPE_ALL		= 0,	// subgroupAll()
	OPTYPE_ANY		= 1,	// subgroupAny()
	OPTYPE_ALLEQUAL	= 2,	// subgroupAllEqual()
	OPTYPE_LAST				// number of operations, not an operation itself
};
45 
checkVertexPipelineStages(std::vector<const void * > datas,deUint32 width,deUint32)46 static bool checkVertexPipelineStages(std::vector<const void*> datas,
47 									  deUint32 width, deUint32)
48 {
49 	return vkt::subgroups::check(datas, width, 0x1F);
50 }
51 
checkFragmentPipelineStages(std::vector<const void * > datas,deUint32 width,deUint32 height,deUint32)52 static bool checkFragmentPipelineStages(std::vector<const void*> datas,
53 									  deUint32 width, deUint32 height, deUint32)
54 {
55 	const deUint32* data =
56 		reinterpret_cast<const deUint32*>(datas[0]);
57 	for (deUint32 x = 0u; x < width; ++x)
58 	{
59 		for (deUint32 y = 0u; y < height; ++y)
60 		{
61 			const deUint32 ndx = (x * height + y);
62 			deUint32 val = data[ndx] & 0x1F;
63 
64 			if (data[ndx] & 0x40) //Helper fragment shader invocation was executed
65 			{
66 				if(val != 0x1F)
67 					return false;
68 			}
69 			else //Helper fragment shader invocation was not executed yet
70 			{
71 				if (val != 0x1E)
72 					return false;
73 			}
74 		}
75 	}
76 	return true;
77 }
78 
checkCompute(std::vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)79 static bool checkCompute(std::vector<const void*> datas,
80 						 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
81 						 deUint32)
82 {
83 	return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 0x1F);
84 }
85 
getOpTypeName(int opType)86 std::string getOpTypeName(int opType)
87 {
88 	switch (opType)
89 	{
90 		default:
91 			DE_FATAL("Unsupported op type");
92 			return "";
93 		case OPTYPE_ALL:
94 			return "subgroupAll";
95 		case OPTYPE_ANY:
96 			return "subgroupAny";
97 		case OPTYPE_ALLEQUAL:
98 			return "subgroupAllEqual";
99 	}
100 }
101 
102 struct CaseDefinition
103 {
104 	int					opType;
105 	VkShaderStageFlags	shaderStage;
106 	VkFormat			format;
107 };
108 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)109 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
110 {
111 	const vk::ShaderBuildOptions buildOptions	(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
112 	const bool formatIsBoolean =
113 		VK_FORMAT_R8_USCALED == caseDef.format || VK_FORMAT_R8G8_USCALED == caseDef.format || VK_FORMAT_R8G8B8_USCALED == caseDef.format || VK_FORMAT_R8G8B8A8_USCALED == caseDef.format;
114 
115 	if (VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage)
116 		subgroups::setFragmentShaderFrameBuffer(programCollection);
117 
118 	if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
119 	{
120 		const string vertex	= "#version 450\n"
121 			"void main (void)\n"
122 			"{\n"
123 			"  vec2 uv = vec2(float(gl_VertexIndex & 1), float((gl_VertexIndex >> 1) & 1));\n"
124 			"  gl_Position = vec4(uv * 4.0f -2.0f, 0.0f, 1.0f);\n"
125 			"  gl_PointSize = 1.0f;\n"
126 			"}\n";
127 		programCollection.glslSources.add("vert") << glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
128 	}
129 	else if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
130 		subgroups::setVertexShaderFrameBuffer(programCollection);
131 
132 	const string source =
133 		(OPTYPE_ALL == caseDef.opType) ?
134 			"  result = " + getOpTypeName(caseDef.opType) +
135 			"(true) ? 0x1 : 0;\n"
136 			"  result |= " + getOpTypeName(caseDef.opType) +
137 			"(false) ? 0 : 0x1A;\n"
138 			"  result |= 0x4;\n"
139 		: (OPTYPE_ANY == caseDef.opType) ?
140 				"  result = " + getOpTypeName(caseDef.opType) +
141 				"(true) ? 0x1 : 0;\n"
142 				"  result |= " + getOpTypeName(caseDef.opType) +
143 				"(false) ? 0 : 0x1A;\n"
144 				"  result |= 0x4;\n"
145 		: (OPTYPE_ALLEQUAL == caseDef.opType) ?
146 				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n" +
147 				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueNoEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + (formatIsBoolean ? "(subgroupElect())\n;" : "(12.0 * float(data[gl_SubgroupInvocationID]) + gl_SubgroupInvocationID);\n") +
148 				"  result = " + getOpTypeName(caseDef.opType) + "("
149 				+ subgroups::getFormatNameForGLSL(caseDef.format) + "(1)) ? 0x1 : 0;\n"
150 				"  result |= " + getOpTypeName(caseDef.opType) +
151 				"(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
152 				"  result |= " + getOpTypeName(caseDef.opType) +
153 				"(data[0]) ? 0x4 : 0;\n"
154 				"  result |= " + getOpTypeName(caseDef.opType) +
155 				"(valueEqual) ? 0x8 : 0x0;\n"
156 				"  result |= " + getOpTypeName(caseDef.opType) +
157 				"(valueNoEqual) ? 0x0 : 0x10;\n"
158 				"  if (subgroupElect()) result |= 0x2 | 0x10;\n"
159 		: "";
160 
161 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
162 	{
163 		std::ostringstream vertexSrc;
164 		vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
165 			<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
166 			<< "layout(location = 0) out vec4 out_color;\n"
167 			<< "layout(location = 0) in highp vec4 in_position;\n"
168 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
169 			<< "{\n"
170 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
171 			<< "};\n"
172 			<< "\n"
173 			<< "void main (void)\n"
174 			<< "{\n"
175 			<< "  uint result;\n"
176 			<< source
177 			<< "  out_color.r = float(result);\n"
178 			<< "  gl_Position = in_position;\n"
179 			<< "  gl_PointSize = 1.0f;\n"
180 			<< "}\n";
181 
182 		programCollection.glslSources.add("vert") << glu::VertexSource(vertexSrc.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
183 	}
184 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
185 	{
186 		std::ostringstream geometry;
187 
188 		geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
189 			<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
190 			<< "layout(points) in;\n"
191 			<< "layout(points, max_vertices = 1) out;\n"
192 			<< "layout(location = 0) out float out_color;\n"
193 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
194 			<< "{\n"
195 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
196 			<< "};\n"
197 			<< "\n"
198 			<< "void main (void)\n"
199 			<< "{\n"
200 			<< "  uint result;\n"
201 			<< source
202 			<< "  out_color = float(result);\n"
203 			<< "  gl_Position = gl_in[0].gl_Position;\n"
204 			<< "  EmitVertex();\n"
205 			<< "  EndPrimitive();\n"
206 			<< "}\n";
207 
208 		programCollection.glslSources.add("geometry")
209 			<< glu::GeometrySource(geometry.str()) << buildOptions;
210 	}
211 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
212 	{
213 		std::ostringstream controlSource;
214 		controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
215 			<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
216 			<< "layout(vertices = 2) out;\n"
217 			<< "layout(location = 0) out float out_color[];\n"
218 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
219 			<< "{\n"
220 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
221 			<< "};\n"
222 			<< "\n"
223 			<< "void main (void)\n"
224 			<< "{\n"
225 			<< "  uint result;\n"
226 			<< "  if (gl_InvocationID == 0)\n"
227 			<<"  {\n"
228 			<< "    gl_TessLevelOuter[0] = 1.0f;\n"
229 			<< "    gl_TessLevelOuter[1] = 1.0f;\n"
230 			<< "  }\n"
231 			<< source
232 			<< "  out_color[gl_InvocationID] = float(result);"
233 			<< "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
234 			<< "}\n";
235 
236 		programCollection.glslSources.add("tesc")
237 			<< glu::TessellationControlSource(controlSource.str()) << buildOptions;
238 		subgroups::setTesEvalShaderFrameBuffer(programCollection);
239 	}
240 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
241 	{
242 		std::ostringstream evaluationSource;
243 		evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
244 			<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
245 			<< "#extension GL_EXT_tessellation_shader : require\n"
246 			<< "layout(isolines, equal_spacing, ccw ) in;\n"
247 			<< "layout(location = 0) out float out_color;\n"
248 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
249 			<< "{\n"
250 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
251 			<< "};\n"
252 			<< "\n"
253 			<< "void main (void)\n"
254 			<< "{\n"
255 			<< "  uint result;\n"
256 			<< "  highp uint offset = gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5);\n"
257 			<< source
258 			<< "  out_color = float(result);\n"
259 			<< "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
260 			<< "}\n";
261 
262 		subgroups::setTesCtrlShaderFrameBuffer(programCollection);
263 		programCollection.glslSources.add("tese")
264 				<< glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
265 	}
266 	else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
267 	{
268 		const string sourceFragment =
269 		(OPTYPE_ALL == caseDef.opType) ?
270 			"  result |= " + getOpTypeName(caseDef.opType) +
271 			"(!gl_HelperInvocation) ? 0x0 : 0x1;\n"
272 			"  result |= " + getOpTypeName(caseDef.opType) +
273 			"(false) ? 0 : 0x1A;\n"
274 			"  result |= 0x4;\n"
275 		: (OPTYPE_ANY == caseDef.opType) ?
276 				"  result |= " + getOpTypeName(caseDef.opType) +
277 				"(gl_HelperInvocation) ? 0x1 : 0x0;\n"
278 				"  result |= " + getOpTypeName(caseDef.opType) +
279 				"(false) ? 0 : 0x1A;\n"
280 				"  result |= 0x4;\n"
281 		: (OPTYPE_ALLEQUAL == caseDef.opType) ?
282 				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n" +
283 				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueNoEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + (formatIsBoolean ? "(subgroupElect());\n" : "(12.0 * float(data[gl_SubgroupInvocationID]) + int(gl_FragCoord.x*gl_SubgroupInvocationID));\n") +
284 				"  result |= " + getOpTypeName(caseDef.opType) + "("
285 				+ subgroups::getFormatNameForGLSL(caseDef.format) + "(1)) ? 0x10 : 0;\n"
286 				"  result |= " + getOpTypeName(caseDef.opType) +
287 				"(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
288 				"  result |= " + getOpTypeName(caseDef.opType) +
289 				"(data[0]) ? 0x4 : 0;\n"
290 				"  result |= " + getOpTypeName(caseDef.opType) +
291 				"(valueEqual) ? 0x8 : 0x0;\n"
292 				"  result |= " + getOpTypeName(caseDef.opType) +
293 				"(gl_HelperInvocation) ? 0x0 : 0x1;\n"
294 				"  if (subgroupElect()) result |= 0x2 | 0x10;\n"
295 		: "";
296 
297 		std::ostringstream fragmentSource;
298 		fragmentSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
299 		<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
300 		<< "layout(location = 0) out uint out_color;\n"
301 		<< "layout(set = 0, binding = 0) uniform Buffer1\n"
302 		<< "{\n"
303 		<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
304 		<< "};\n"
305 		<< ""
306 		<< "void main()\n"
307 		<< "{\n"
308 		<< "  uint result = 0u;\n"
309 		<< "  if (dFdx(gl_SubgroupInvocationID * gl_FragCoord.x * gl_FragCoord.y) - dFdy(gl_SubgroupInvocationID * gl_FragCoord.x * gl_FragCoord.y) > 0.0f)\n"
310 		<< "  {\n"
311 		<< "    result |= 0x20;\n" // to be sure that compiler doesn't remove dFdx and dFdy executions
312 		<< "  }\n"
313 		<< "  bool helper = subgroupAny(gl_HelperInvocation);\n"
314 		<< "  if (helper)\n"
315 		<< "  {\n"
316 		<< "    result |= 0x40;\n"
317 		<< "  }\n"
318 		<< sourceFragment
319 		<< "  out_color = result;\n"
320 		<< "}\n";
321 
322 		programCollection.glslSources.add("fragment")
323 			<< glu::FragmentSource(fragmentSource.str())<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
324 	}
325 	else
326 	{
327 		DE_FATAL("Unsupported shader stage");
328 	}
329 }
330 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)331 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
332 {
333 	const bool formatIsBoolean =
334 		VK_FORMAT_R8_USCALED == caseDef.format || VK_FORMAT_R8G8_USCALED == caseDef.format || VK_FORMAT_R8G8B8_USCALED == caseDef.format || VK_FORMAT_R8G8B8A8_USCALED == caseDef.format;
335 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
336 	{
337 		std::ostringstream src;
338 
339 		src << "#version 450\n"
340 			<< "#extension GL_KHR_shader_subgroup_vote: enable\n"
341 			<< "layout (local_size_x_id = 0, local_size_y_id = 1, "
342 			"local_size_z_id = 2) in;\n"
343 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
344 			<< "{\n"
345 			<< "  uint result[];\n"
346 			<< "};\n"
347 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
348 			<< "{\n"
349 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
350 			<< "};\n"
351 			<< "\n"
352 			<< "void main (void)\n"
353 			<< "{\n"
354 			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
355 			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
356 			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
357 			"gl_GlobalInvocationID.x;\n";
358 		if (OPTYPE_ALL == caseDef.opType)
359 		{
360 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
361 				<< "(true) ? 0x1 : 0;\n"
362 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
363 				<< "(false) ? 0 : 0x1A;\n"
364 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
365 				<< "(data[gl_SubgroupInvocationID] > 0) ? 0x4 : 0;\n";
366 		}
367 		else if (OPTYPE_ANY == caseDef.opType)
368 		{
369 			src << "  result[offset] = " << getOpTypeName(caseDef.opType)
370 				<< "(true) ? 0x1 : 0;\n"
371 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
372 				<< "(false) ? 0 : 0x1A;\n"
373 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
374 				<< "(data[gl_SubgroupInvocationID] == data[0]) ? 0x4 : 0;\n";
375 		}
376 
377 		else if (OPTYPE_ALLEQUAL == caseDef.opType)
378 		{
379 			src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) <<" valueEqual = " << subgroups::getFormatNameForGLSL(caseDef.format) << "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n"
380 				<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) <<" valueNoEqual = " << subgroups::getFormatNameForGLSL(caseDef.format) << (formatIsBoolean ? "(subgroupElect());\n" : "(12.0 * float(data[gl_SubgroupInvocationID]) + offset);\n")
381 				<<"  result[offset] = " << getOpTypeName(caseDef.opType) << "("
382 				<< subgroups::getFormatNameForGLSL(caseDef.format) << "(1)) ? 0x1 : 0x0;\n"
383 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
384 				<< "(gl_SubgroupInvocationID) ? 0x0 : 0x2;\n"
385 				<< "  result[offset] |= " << getOpTypeName(caseDef.opType)
386 				<< "(data[0]) ? 0x4 : 0x0;\n"
387 				<< "  result[offset] |= "<< getOpTypeName(caseDef.opType)
388 				<< "(valueEqual) ? 0x8 : 0x0;\n"
389 				<< "  result[offset] |= "<< getOpTypeName(caseDef.opType)
390 				<< "(valueNoEqual) ? 0x0 : 0x10;\n"
391 				<< "  if (subgroupElect()) result[offset] |= 0x2 | 0x10;\n";
392 		}
393 
394 		src << "}\n";
395 
396 		programCollection.glslSources.add("comp")
397 				<< glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
398 	}
399 	else
400 	{
401 		const string source =
402 		(OPTYPE_ALL == caseDef.opType) ?
403 			"  result[offset] = " + getOpTypeName(caseDef.opType) +
404 			"(true) ? 0x1 : 0;\n"
405 			"  result[offset] |= " + getOpTypeName(caseDef.opType) +
406 			"(false) ? 0 : 0x1A;\n"
407 			"  result[offset] |= 0x4;\n"
408 		: (OPTYPE_ANY == caseDef.opType) ?
409 				"  result[offset] = " + getOpTypeName(caseDef.opType) +
410 				"(true) ? 0x1 : 0;\n"
411 				"  result[offset] |= " + getOpTypeName(caseDef.opType) +
412 				"(false) ? 0 : 0x1A;\n"
413 				"  result[offset] |= 0x4;\n"
414 		: (OPTYPE_ALLEQUAL == caseDef.opType) ?
415 				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n" +
416 				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueNoEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + (formatIsBoolean ? "(subgroupElect());\n" : "(12.0 * float(data[gl_SubgroupInvocationID]) + gl_SubgroupInvocationID);\n") +
417 				"  result[offset] = " + getOpTypeName(caseDef.opType) + "("
418 				+ subgroups::getFormatNameForGLSL(caseDef.format) + "(1)) ? 0x1 : 0;\n"
419 				"  result[offset] |= " + getOpTypeName(caseDef.opType) +
420 				"(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
421 				"  result[offset] |= " + getOpTypeName(caseDef.opType) +
422 				"(data[0]) ? 0x4 : 0;\n"
423 				"  result[offset] |= " + getOpTypeName(caseDef.opType) +
424 				"(valueEqual) ? 0x8 : 0x0;\n"
425 				"  result[offset] |= " + getOpTypeName(caseDef.opType) +
426 				"(valueNoEqual) ? 0x0 : 0x10;\n"
427 				"  if (subgroupElect()) result[offset] |= 0x2 | 0x10;\n"
428 		: "";
429 
430 		const string formatString = subgroups::getFormatNameForGLSL(caseDef.format);
431 
432 		{
433 			const string vertex =
434 				"#version 450\n"
435 				"#extension GL_KHR_shader_subgroup_vote: enable\n"
436 				"layout(set = 0, binding = 0, std430) buffer Buffer1\n"
437 				"{\n"
438 				"  uint result[];\n"
439 				"};\n"
440 				"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
441 				"{\n"
442 				"  " + formatString + " data[];\n"
443 				"};\n"
444 				"\n"
445 				"void main (void)\n"
446 				"{\n"
447 				"  highp uint offset = gl_VertexIndex;\n"
448 				+ source +
449 				"  float pixelSize = 2.0f/1024.0f;\n"
450 				"  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
451 				"  gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
452 				"  gl_PointSize = 1.0f;\n"
453 				"}\n";
454 			programCollection.glslSources.add("vert")
455 				<< glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
456 		}
457 
458 		{
459 			const string tesc =
460 				"#version 450\n"
461 				"#extension GL_KHR_shader_subgroup_vote: enable\n"
462 				"layout(vertices=1) out;\n"
463 				"layout(set = 0, binding = 1, std430) buffer Buffer1\n"
464 				"{\n"
465 				"  uint result[];\n"
466 				"};\n"
467 				"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
468 				"{\n"
469 				"  " + formatString + " data[];\n"
470 				"};\n"
471 				"\n"
472 				"void main (void)\n"
473 				"{\n"
474 				"  highp uint offset = gl_PrimitiveID;\n"
475 				+ source +
476 				"  if (gl_InvocationID == 0)\n"
477 				"  {\n"
478 				"    gl_TessLevelOuter[0] = 1.0f;\n"
479 				"    gl_TessLevelOuter[1] = 1.0f;\n"
480 				"  }\n"
481 				"  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
482 				"}\n";
483 
484 			programCollection.glslSources.add("tesc")
485 					<< glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
486 		}
487 
488 		{
489 			const string tese =
490 				"#version 450\n"
491 				"#extension GL_KHR_shader_subgroup_vote: enable\n"
492 				"layout(isolines) in;\n"
493 				"layout(set = 0, binding = 2, std430) buffer Buffer1\n"
494 				"{\n"
495 				"  uint result[];\n"
496 				"};\n"
497 				"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
498 				"{\n"
499 				"  " + formatString + " data[];\n"
500 				"};\n"
501 				"\n"
502 				"void main (void)\n"
503 				"{\n"
504 				"  highp uint offset = gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5);\n"
505 				+ source +
506 				"  float pixelSize = 2.0f/1024.0f;\n"
507 				"  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
508 				"}\n";
509 
510 			programCollection.glslSources.add("tese")
511 					<< glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
512 		}
513 
514 		{
515 			const string geometry =
516 				"#version 450\n"
517 				"#extension GL_KHR_shader_subgroup_vote: enable\n"
518 				"layout(${TOPOLOGY}) in;\n"
519 				"layout(points, max_vertices = 1) out;\n"
520 				"layout(set = 0, binding = 3, std430) buffer Buffer1\n"
521 				"{\n"
522 				"  uint result[];\n"
523 				"};\n"
524 				"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
525 				"{\n"
526 				"  " + formatString + " data[];\n"
527 				"};\n"
528 				"\n"
529 				"void main (void)\n"
530 				"{\n"
531 				"  highp uint offset = gl_PrimitiveIDIn;\n"
532 				+ source +
533 				"  gl_Position = gl_in[0].gl_Position;\n"
534 				"  EmitVertex();\n"
535 				"  EndPrimitive();\n"
536 				"}\n";
537 
538 			subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
539 													  programCollection.glslSources);
540 		}
541 
542 		{
543 			const string sourceFragment =
544 			(OPTYPE_ALL == caseDef.opType) ?
545 				"  result = " + getOpTypeName(caseDef.opType) +
546 				"(true) ? 0x1 : 0;\n"
547 				"  result |= " + getOpTypeName(caseDef.opType) +
548 				"(false) ? 0 : 0x1A;\n"
549 				"  result |= 0x4;\n"
550 			: (OPTYPE_ANY == caseDef.opType) ?
551 					"  result = " + getOpTypeName(caseDef.opType) +
552 					"(true) ? 0x1 : 0;\n"
553 					"  result |= " + getOpTypeName(caseDef.opType) +
554 					"(false) ? 0 : 0x1A;\n"
555 					"  result |= 0x4;\n"
556 			: (OPTYPE_ALLEQUAL == caseDef.opType) ?
557 					"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + "(1.25 * float(data[gl_SubgroupInvocationID]) + 5.0);\n" +
558 					"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " valueNoEqual = " + subgroups::getFormatNameForGLSL(caseDef.format) + (formatIsBoolean ? "(subgroupElect());\n" : "(12.0 * float(data[gl_SubgroupInvocationID]) + int(gl_FragCoord.x*gl_SubgroupInvocationID));\n") +
559 					"  result = " + getOpTypeName(caseDef.opType) + "("
560 					+ subgroups::getFormatNameForGLSL(caseDef.format) + "(1)) ? 0x1 : 0;\n"
561 					"  result |= " + getOpTypeName(caseDef.opType) +
562 					"(gl_SubgroupInvocationID) ? 0 : 0x2;\n"
563 					"  result |= " + getOpTypeName(caseDef.opType) +
564 					"(data[0]) ? 0x4 : 0;\n"
565 					"  result |= " + getOpTypeName(caseDef.opType) +
566 					"(valueEqual) ? 0x8 : 0x0;\n"
567 					"  result |= " + getOpTypeName(caseDef.opType) +
568 					"(valueNoEqual) ? 0x0 : 0x10;\n"
569 					"  if (subgroupElect()) result |= 0x2 | 0x10;\n"
570 			: "";
571 			const string fragment =
572 				"#version 450\n"
573 				"#extension GL_KHR_shader_subgroup_vote: enable\n"
574 				"layout(location = 0) out uint result;\n"
575 				"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
576 				"{\n"
577 				"  " + formatString + " data[];\n"
578 				"};\n"
579 				"void main (void)\n"
580 				"{\n"
581 				+ sourceFragment +
582 				"}\n";
583 
584 			programCollection.glslSources.add("fragment")
585 				<< glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
586 		}
587 
588 		subgroups::addNoSubgroupShader(programCollection);
589 	}
590 }
591 
supportedCheck(Context & context,CaseDefinition caseDef)592 void supportedCheck (Context& context, CaseDefinition caseDef)
593 {
594 	if (!subgroups::isSubgroupSupported(context))
595 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
596 
597 	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_VOTE_BIT))
598 	{
599 		TCU_THROW(NotSupportedError, "Device does not support subgroup vote operations");
600 	}
601 
602 	if (subgroups::isDoubleFormat(caseDef.format) &&
603 			!subgroups::isDoubleSupportedForDevice(context))
604 	{
605 		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
606 	}
607 }
608 
noSSBOtest(Context & context,const CaseDefinition caseDef)609 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
610 {
611 	if (!subgroups::areSubgroupOperationsSupportedForStage(
612 				context, caseDef.shaderStage))
613 	{
614 		if (subgroups::areSubgroupOperationsRequiredForStage(
615 					caseDef.shaderStage))
616 		{
617 			return tcu::TestStatus::fail(
618 					   "Shader stage " +
619 					   subgroups::getShaderStageName(caseDef.shaderStage) +
620 					   " is required to support subgroup operations!");
621 		}
622 		else
623 		{
624 			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
625 		}
626 	}
627 
628 	subgroups::SSBOData inputData;
629 	inputData.format = caseDef.format;
630 	inputData.layout = subgroups::SSBOData::LayoutStd140;
631 	inputData.numElements = subgroups::maxSupportedSubgroupSize();
632 	inputData.initializeType = OPTYPE_ALLEQUAL == caseDef.opType ? subgroups::SSBOData::InitializeZero : subgroups::SSBOData::InitializeNonZero;
633 
634 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
635 		return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
636 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
637 		return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
638 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
639 		return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
640 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
641 		return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
642 	else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
643 		return subgroups::makeFragmentFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkFragmentPipelineStages);
644 	else
645 		TCU_THROW(InternalError, "Unhandled shader stage");
646 }
647 
648 
test(Context & context,const CaseDefinition caseDef)649 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
650 {
651 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
652 	{
653 		if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
654 		{
655 			return tcu::TestStatus::fail(
656 					   "Shader stage " +
657 					   subgroups::getShaderStageName(caseDef.shaderStage) +
658 					   " is required to support subgroup operations!");
659 		}
660 
661 		subgroups::SSBOData inputData;
662 		inputData.format = caseDef.format;
663 		inputData.layout = subgroups::SSBOData::LayoutStd430;
664 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
665 		inputData.initializeType = OPTYPE_ALLEQUAL == caseDef.opType ? subgroups::SSBOData::InitializeZero : subgroups::SSBOData::InitializeNonZero;
666 
667 		return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData,
668 										  1, checkCompute);
669 	}
670 	else
671 	{
672 		VkPhysicalDeviceSubgroupProperties subgroupProperties;
673 		subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
674 		subgroupProperties.pNext = DE_NULL;
675 
676 		VkPhysicalDeviceProperties2 properties;
677 		properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
678 		properties.pNext = &subgroupProperties;
679 
680 		context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
681 
682 		VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage  & subgroupProperties.supportedStages);
683 
684 		if (VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
685 		{
686 			if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
687 				TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
688 			else
689 				stages = VK_SHADER_STAGE_FRAGMENT_BIT;
690 		}
691 
692 		if ((VkShaderStageFlagBits)0u == stages)
693 			TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
694 
695 		subgroups::SSBOData inputData;
696 		inputData.format			= caseDef.format;
697 		inputData.layout			= subgroups::SSBOData::LayoutStd430;
698 		inputData.numElements		= subgroups::maxSupportedSubgroupSize();
699 		inputData.initializeType	= OPTYPE_ALLEQUAL == caseDef.opType ? subgroups::SSBOData::InitializeZero : subgroups::SSBOData::InitializeNonZero;
700 		inputData.binding			= 4u;
701 		inputData.stages			= stages;
702 
703 		return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
704 	}
705 }
706 }
707 
708 namespace vkt
709 {
710 namespace subgroups
711 {
createSubgroupsVoteTests(tcu::TestContext & testCtx)712 tcu::TestCaseGroup* createSubgroupsVoteTests(tcu::TestContext& testCtx)
713 {
714 	de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
715 		testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
716 	de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
717 		testCtx, "compute", "Subgroup arithmetic category tests: compute"));
718 	de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
719 		testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
720 
721 	de::MovePtr<tcu::TestCaseGroup> fragHelperGroup(new tcu::TestCaseGroup(
722 		testCtx, "frag_helper", "Subgroup arithmetic category tests: fragment helper invocation"));
723 
724 	const VkShaderStageFlags stages[] =
725 	{
726 		VK_SHADER_STAGE_VERTEX_BIT,
727 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
728 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
729 		VK_SHADER_STAGE_GEOMETRY_BIT,
730 	};
731 
732 	const VkFormat formats[] =
733 	{
734 		VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
735 		VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
736 		VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
737 		VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
738 		VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
739 		VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
740 		VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
741 		VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
742 		VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
743 	};
744 
745 	for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
746 	{
747 		const VkFormat format = formats[formatIndex];
748 
749 		for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
750 		{
751 			// Skip the typed tests for all but subgroupAllEqual()
752 			if ((VK_FORMAT_R32_UINT != format) && (OPTYPE_ALLEQUAL != opTypeIndex))
753 			{
754 				continue;
755 			}
756 
757 			const std::string op = de::toLower(getOpTypeName(opTypeIndex));
758 
759 			{
760 				const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT, format};
761 				addFunctionCaseWithPrograms(computeGroup.get(),
762 											op + "_" + subgroups::getFormatNameForGLSL(format),
763 											"", supportedCheck, initPrograms, test, caseDef);
764 			}
765 
766 			{
767 				const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_ALL_GRAPHICS, format};
768 				addFunctionCaseWithPrograms(graphicGroup.get(),
769 											op + "_" + subgroups::getFormatNameForGLSL(format),
770 											"", supportedCheck, initPrograms, test, caseDef);
771 			}
772 
773 			for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
774 			{
775 				const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format};
776 				addFunctionCaseWithPrograms(framebufferGroup.get(),
777 							op + "_" +
778 							subgroups::getFormatNameForGLSL(format)
779 							+ "_" + getShaderStageName(caseDef.shaderStage), "",
780 							supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
781 			}
782 
783 			const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_FRAGMENT_BIT, format};
784 			addFunctionCaseWithPrograms(fragHelperGroup.get(),
785 						op + "_" +
786 						subgroups::getFormatNameForGLSL(format)
787 						+ "_" + getShaderStageName(caseDef.shaderStage), "",
788 						supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
789 		}
790 	}
791 
792 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
793 		testCtx, "vote", "Subgroup vote category tests"));
794 
795 	group->addChild(graphicGroup.release());
796 	group->addChild(computeGroup.release());
797 	group->addChild(framebufferGroup.release());
798 	group->addChild(fragHelperGroup.release());
799 
800 	return group.release();
801 }
802 
803 } // subgroups
804 } // vkt
805