1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2017 The Khronos Group Inc.
6  * Copyright (c) 2017 Codeplay Software Ltd.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  */ /*!
21  * \file
22  * \brief Subgroups Tests
23  */ /*--------------------------------------------------------------------*/
24 
25 #include "vktSubgroupsQuadTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27 
28 #include <string>
29 #include <vector>
30 
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35 
36 namespace
37 {
// Quad operations covered by this file. The values are used as indices into
// per-operation tables and as the loop range when test cases are generated,
// so OPTYPE_LAST must stay the final enumerator (it is the element count).
enum OpType
{
	OPTYPE_QUAD_BROADCAST = 0,		// subgroupQuadBroadcast(value, id)
	OPTYPE_QUAD_SWAP_HORIZONTAL,	// subgroupQuadSwapHorizontal(value)
	OPTYPE_QUAD_SWAP_VERTICAL,		// subgroupQuadSwapVertical(value)
	OPTYPE_QUAD_SWAP_DIAGONAL,		// subgroupQuadSwapDiagonal(value)
	OPTYPE_LAST
};
46 
checkVertexPipelineStages(std::vector<const void * > datas,deUint32 width,deUint32)47 static bool checkVertexPipelineStages(std::vector<const void*> datas,
48 									  deUint32 width, deUint32)
49 {
50 	return vkt::subgroups::check(datas, width, 1);
51 }
52 
checkCompute(std::vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)53 static bool checkCompute(std::vector<const void*> datas,
54 						 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
55 						 deUint32)
56 {
57 	return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
58 }
59 
getOpTypeName(int opType)60 std::string getOpTypeName(int opType)
61 {
62 	switch (opType)
63 	{
64 		default:
65 			DE_FATAL("Unsupported op type");
66 			return "";
67 		case OPTYPE_QUAD_BROADCAST:
68 			return "subgroupQuadBroadcast";
69 		case OPTYPE_QUAD_SWAP_HORIZONTAL:
70 			return "subgroupQuadSwapHorizontal";
71 		case OPTYPE_QUAD_SWAP_VERTICAL:
72 			return "subgroupQuadSwapVertical";
73 		case OPTYPE_QUAD_SWAP_DIAGONAL:
74 			return "subgroupQuadSwapDiagonal";
75 	}
76 }
77 
// Parameters of a single generated test case; passed by value to the
// program-init, support-check and test functions of that case.
struct CaseDefinition
{
	int					opType;			// One of the OpType values.
	VkShaderStageFlags	shaderStage;	// Stage(s) under test: a single stage, compute, or VK_SHADER_STAGE_ALL_GRAPHICS.
	VkFormat			format;			// Format of the input data; translated to a GLSL type with getFormatNameForGLSL().
	int					direction;		// Quad lane index (0..3) passed to subgroupQuadBroadcast; not used by the swap operations.
};
85 
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)86 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
87 {
88 	const vk::ShaderBuildOptions	buildOptions	(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
89 	std::string			swapTable[OPTYPE_LAST];
90 
91 	subgroups::setFragmentShaderFrameBuffer(programCollection);
92 
93 	if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
94 		subgroups::setVertexShaderFrameBuffer(programCollection);
95 
96 	swapTable[OPTYPE_QUAD_BROADCAST] = "";
97 	swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = {1, 0, 3, 2};\n";
98 	swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = "  const uint swapTable[4] = {2, 3, 0, 1};\n";
99 	swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = "  const uint swapTable[4] = {3, 2, 1, 0};\n";
100 
101 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
102 	{
103 		std::ostringstream	vertexSrc;
104 		vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
105 			<< "#extension GL_KHR_shader_subgroup_quad: enable\n"
106 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
107 			<< "layout(location = 0) in highp vec4 in_position;\n"
108 			<< "layout(location = 0) out float result;\n"
109 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
110 			<< "{\n"
111 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
112 			<< "};\n"
113 			<< "\n"
114 			<< "void main (void)\n"
115 			<< "{\n"
116 			<< "  uvec4 mask = subgroupBallot(true);\n"
117 			<< swapTable[caseDef.opType];
118 
119 		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
120 		{
121 			vertexSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
122 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
123 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
124 		}
125 		else
126 		{
127 			vertexSrc << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
128 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
129 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
130 		}
131 
132 		vertexSrc << "  if (subgroupBallotBitExtract(mask, otherID))\n"
133 			<< "  {\n"
134 			<< "    result = (op == data[otherID]) ? 1.0f : 0.0f;\n"
135 			<< "  }\n"
136 			<< "  else\n"
137 			<< "  {\n"
138 			<< "    result = 1.0f;\n" // Invocation we read from was inactive, so we can't verify results!
139 			<< "  }\n"
140 			<< "  gl_Position = in_position;\n"
141 			<< "  gl_PointSize = 1.0f;\n"
142 			<< "}\n";
143 		programCollection.glslSources.add("vert")
144 			<< glu::VertexSource(vertexSrc.str()) << buildOptions;
145 	}
146 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
147 	{
148 		std::ostringstream geometry;
149 
150 		geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
151 			<< "#extension GL_KHR_shader_subgroup_quad: enable\n"
152 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
153 			<< "layout(points) in;\n"
154 			<< "layout(points, max_vertices = 1) out;\n"
155 			<< "layout(location = 0) out float out_color;\n"
156 
157 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
158 			<< "{\n"
159 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
160 			<< "};\n"
161 			<< "\n"
162 			<< "void main (void)\n"
163 			<< "{\n"
164 			<< "  uvec4 mask = subgroupBallot(true);\n"
165 			<< swapTable[caseDef.opType];
166 
167 		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
168 		{
169 			geometry << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
170 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
171 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
172 		}
173 		else
174 		{
175 			geometry << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
176 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
177 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
178 		}
179 
180 		geometry << "  if (subgroupBallotBitExtract(mask, otherID))\n"
181 			<< "  {\n"
182 			<< "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
183 			<< "  }\n"
184 			<< "  else\n"
185 			<< "  {\n"
186 			<< "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
187 			<< "  }\n"
188 			<< "  gl_Position = gl_in[0].gl_Position;\n"
189 			<< "  EmitVertex();\n"
190 			<< "  EndPrimitive();\n"
191 			<< "}\n";
192 
193 		programCollection.glslSources.add("geometry")
194 			<< glu::GeometrySource(geometry.str()) << buildOptions;
195 	}
196 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
197 	{
198 		std::ostringstream controlSource;
199 
200 		controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
201 			<< "#extension GL_KHR_shader_subgroup_quad: enable\n"
202 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
203 			<< "layout(vertices = 2) out;\n"
204 			<< "layout(location = 0) out float out_color[];\n"
205 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
206 			<< "{\n"
207 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
208 			<< "};\n"
209 			<< "\n"
210 			<< "void main (void)\n"
211 			<< "{\n"
212 			<< "  if (gl_InvocationID == 0)\n"
213 			<<"  {\n"
214 			<< "    gl_TessLevelOuter[0] = 1.0f;\n"
215 			<< "    gl_TessLevelOuter[1] = 1.0f;\n"
216 			<< "  }\n"
217 			<< "  uvec4 mask = subgroupBallot(true);\n"
218 			<< swapTable[caseDef.opType];
219 
220 		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
221 		{
222 			controlSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
223 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
224 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
225 		}
226 		else
227 		{
228 			controlSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
229 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
230 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
231 		}
232 
233 		controlSource << "  if (subgroupBallotBitExtract(mask, otherID))\n"
234 			<< "  {\n"
235 			<< "    out_color[gl_InvocationID] = (op == data[otherID]) ? 1.0 : 0.0;\n"
236 			<< "  }\n"
237 			<< "  else\n"
238 			<< "  {\n"
239 			<< "    out_color[gl_InvocationID] = 1.0; \n"// Invocation we read from was inactive, so we can't verify results!
240 			<< "  }\n"
241 			<< "  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
242 			<< "}\n";
243 
244 		programCollection.glslSources.add("tesc")
245 			<< glu::TessellationControlSource(controlSource.str()) << buildOptions;
246 		subgroups::setTesEvalShaderFrameBuffer(programCollection);
247 	}
248 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
249 	{
250 		ostringstream evaluationSource;
251 		evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
252 			<< "#extension GL_KHR_shader_subgroup_quad: enable\n"
253 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
254 			<< "layout(isolines, equal_spacing, ccw ) in;\n"
255 			<< "layout(location = 0) out float out_color;\n"
256 			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
257 			<< "{\n"
258 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
259 			<< "};\n"
260 			<< "\n"
261 			<< "void main (void)\n"
262 			<< "{\n"
263 			<< "  uvec4 mask = subgroupBallot(true);\n"
264 			<< swapTable[caseDef.opType];
265 
266 		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
267 		{
268 			evaluationSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
269 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
270 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
271 		}
272 		else
273 		{
274 			evaluationSource << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
275 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
276 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
277 		}
278 
279 		evaluationSource << "  if (subgroupBallotBitExtract(mask, otherID))\n"
280 			<< "  {\n"
281 			<< "    out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
282 			<< "  }\n"
283 			<< "  else\n"
284 			<< "  {\n"
285 			<< "    out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
286 			<< "  }\n"
287 			<< "  gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
288 			<< "}\n";
289 
290 		subgroups::setTesCtrlShaderFrameBuffer(programCollection);
291 		programCollection.glslSources.add("tese")
292 				<< glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
293 	}
294 	else
295 	{
296 		DE_FATAL("Unsupported shader stage");
297 	}
298 }
299 
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)300 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
301 {
302 	std::string swapTable[OPTYPE_LAST];
303 	swapTable[OPTYPE_QUAD_BROADCAST] = "";
304 	swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = "  const uint swapTable[4] = {1, 0, 3, 2};\n";
305 	swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = "  const uint swapTable[4] = {2, 3, 0, 1};\n";
306 	swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = "  const uint swapTable[4] = {3, 2, 1, 0};\n";
307 
308 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
309 	{
310 		std::ostringstream src;
311 
312 		src << "#version 450\n"
313 			<< "#extension GL_KHR_shader_subgroup_quad: enable\n"
314 			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
315 			<< "layout (local_size_x_id = 0, local_size_y_id = 1, "
316 			"local_size_z_id = 2) in;\n"
317 			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
318 			<< "{\n"
319 			<< "  uint result[];\n"
320 			<< "};\n"
321 			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
322 			<< "{\n"
323 			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
324 			<< "};\n"
325 			<< "\n"
326 			<< "void main (void)\n"
327 			<< "{\n"
328 			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
329 			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
330 			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
331 			"gl_GlobalInvocationID.x;\n"
332 			<< "  uvec4 mask = subgroupBallot(true);\n"
333 			<< swapTable[caseDef.opType];
334 
335 
336 		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
337 		{
338 			src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
339 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
340 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
341 		}
342 		else
343 		{
344 			src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
345 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
346 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
347 		}
348 
349 		src << "  if (subgroupBallotBitExtract(mask, otherID))\n"
350 			<< "  {\n"
351 			<< "    result[offset] = (op == data[otherID]) ? 1 : 0;\n"
352 			<< "  }\n"
353 			<< "  else\n"
354 			<< "  {\n"
355 			<< "    result[offset] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
356 			<< "  }\n"
357 			<< "}\n";
358 
359 		programCollection.glslSources.add("comp")
360 				<< glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
361 	}
362 	else
363 	{
364 		std::ostringstream src;
365 		if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
366 		{
367 			src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
368 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
369 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
370 		}
371 		else
372 		{
373 			src << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
374 				<< getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
375 				<< "  uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
376 		}
377 		const string sourceType = src.str();
378 
379 		{
380 			const string vertex =
381 				"#version 450\n"
382 				"#extension GL_KHR_shader_subgroup_quad: enable\n"
383 				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
384 				"layout(set = 0, binding = 0, std430) buffer Buffer1\n"
385 				"{\n"
386 				"  uint result[];\n"
387 				"};\n"
388 				"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
389 				"{\n"
390 				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
391 				"};\n"
392 				"\n"
393 				"void main (void)\n"
394 				"{\n"
395 				"  uvec4 mask = subgroupBallot(true);\n"
396 				+ swapTable[caseDef.opType]
397 				+ sourceType +
398 				"  if (subgroupBallotBitExtract(mask, otherID))\n"
399 				"  {\n"
400 				"    result[gl_VertexIndex] = (op == data[otherID]) ? 1 : 0;\n"
401 				"  }\n"
402 				"  else\n"
403 				"  {\n"
404 				"    result[gl_VertexIndex] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
405 				"  }\n"
406 				"  float pixelSize = 2.0f/1024.0f;\n"
407 				"  float pixelPosition = pixelSize/2.0f - 1.0f;\n"
408 				"  gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
409 				"  gl_PointSize = 1.0f;\n"
410 				"}\n";
411 			programCollection.glslSources.add("vert")
412 				<< glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
413 		}
414 
415 		{
416 			const string tesc =
417 				"#version 450\n"
418 				"#extension GL_KHR_shader_subgroup_quad: enable\n"
419 				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
420 				"layout(vertices=1) out;\n"
421 				"layout(set = 0, binding = 1, std430) buffer Buffer1\n"
422 				"{\n"
423 				"  uint result[];\n"
424 				"};\n"
425 				"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
426 				"{\n"
427 				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
428 				"};\n"
429 				"\n"
430 				"void main (void)\n"
431 				"{\n"
432 				"  uvec4 mask = subgroupBallot(true);\n"
433 				+ swapTable[caseDef.opType]
434 				+ sourceType +
435 				"  if (subgroupBallotBitExtract(mask, otherID))\n"
436 				"  {\n"
437 				"    result[gl_PrimitiveID] = (op == data[otherID]) ? 1 : 0;\n"
438 				"  }\n"
439 				"  else\n"
440 				"  {\n"
441 				"    result[gl_PrimitiveID] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
442 				"  }\n"
443 				"  if (gl_InvocationID == 0)\n"
444 				"  {\n"
445 				"    gl_TessLevelOuter[0] = 1.0f;\n"
446 				"    gl_TessLevelOuter[1] = 1.0f;\n"
447 				"  }\n"
448 				"  gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
449 				"}\n";
450 			programCollection.glslSources.add("tesc")
451 					<< glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
452 		}
453 
454 		{
455 			const string tese =
456 				"#version 450\n"
457 				"#extension GL_KHR_shader_subgroup_quad: enable\n"
458 				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
459 				"layout(isolines) in;\n"
460 				"layout(set = 0, binding = 2, std430)  buffer Buffer1\n"
461 				"{\n"
462 				"  uint result[];\n"
463 				"};\n"
464 				"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
465 				"{\n"
466 				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
467 				"};\n"
468 				"\n"
469 				"void main (void)\n"
470 				"{\n"
471 				"  uvec4 mask = subgroupBallot(true);\n"
472 				+ swapTable[caseDef.opType]
473 				+ sourceType +
474 				"  if (subgroupBallotBitExtract(mask, otherID))\n"
475 				"  {\n"
476 				"    result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = (op == data[otherID]) ? 1 : 0;\n"
477 				"  }\n"
478 				"  else\n"
479 				"  {\n"
480 				"    result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
481 				"  }\n"
482 				"  float pixelSize = 2.0f/1024.0f;\n"
483 				"  gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
484 				"}\n";
485 			programCollection.glslSources.add("tese")
486 					<< glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
487 		}
488 
489 		{
490 			const string geometry =
491 				"#version 450\n"
492 				"#extension GL_KHR_shader_subgroup_quad: enable\n"
493 				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
494 				"layout(${TOPOLOGY}) in;\n"
495 				"layout(points, max_vertices = 1) out;\n"
496 				"layout(set = 0, binding = 3, std430) buffer Buffer1\n"
497 				"{\n"
498 				"  uint result[];\n"
499 				"};\n"
500 				"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
501 				"{\n"
502 				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
503 				"};\n"
504 				"\n"
505 				"void main (void)\n"
506 				"{\n"
507 				"  uvec4 mask = subgroupBallot(true);\n"
508 				+ swapTable[caseDef.opType]
509 				+ sourceType +
510 				"  if (subgroupBallotBitExtract(mask, otherID))\n"
511 				"  {\n"
512 				"    result[gl_PrimitiveIDIn] = (op == data[otherID]) ? 1 : 0;\n"
513 				"  }\n"
514 				"  else\n"
515 				"  {\n"
516 				"    result[gl_PrimitiveIDIn] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
517 				"  }\n"
518 				"  gl_Position = gl_in[0].gl_Position;\n"
519 				"  EmitVertex();\n"
520 				"  EndPrimitive();\n"
521 				"}\n";
522 			subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
523 													  programCollection.glslSources);
524 		}
525 
526 		{
527 			const string fragment =
528 				"#version 450\n"
529 				"#extension GL_KHR_shader_subgroup_quad: enable\n"
530 				"#extension GL_KHR_shader_subgroup_ballot: enable\n"
531 				"layout(location = 0) out uint result;\n"
532 				"layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
533 				"{\n"
534 				"  " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
535 				"};\n"
536 				"void main (void)\n"
537 				"{\n"
538 				"  uvec4 mask = subgroupBallot(true);\n"
539 				+ swapTable[caseDef.opType]
540 				+ sourceType +
541 				"  if (subgroupBallotBitExtract(mask, otherID))\n"
542 				"  {\n"
543 				"    result = (op == data[otherID]) ? 1 : 0;\n"
544 				"  }\n"
545 				"  else\n"
546 				"  {\n"
547 				"    result = 1; // Invocation we read from was inactive, so we can't verify results!\n"
548 				"  }\n"
549 				"}\n";
550 			programCollection.glslSources.add("fragment")
551 				<< glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
552 		}
553 		subgroups::addNoSubgroupShader(programCollection);
554 	}
555 }
556 
supportedCheck(Context & context,CaseDefinition caseDef)557 void supportedCheck (Context& context, CaseDefinition caseDef)
558 {
559 	if (!subgroups::isSubgroupSupported(context))
560 		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
561 
562 	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_QUAD_BIT))
563 		TCU_THROW(NotSupportedError, "Device does not support subgroup quad operations");
564 
565 
566 	if (subgroups::isDoubleFormat(caseDef.format) &&
567 			!subgroups::isDoubleSupportedForDevice(context))
568 	{
569 		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
570 	}
571 }
572 
noSSBOtest(Context & context,const CaseDefinition caseDef)573 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
574 {
575 	if (!subgroups::areSubgroupOperationsSupportedForStage(
576 				context, caseDef.shaderStage))
577 	{
578 		if (subgroups::areSubgroupOperationsRequiredForStage(
579 					caseDef.shaderStage))
580 		{
581 			return tcu::TestStatus::fail(
582 					   "Shader stage " +
583 					   subgroups::getShaderStageName(caseDef.shaderStage) +
584 					   " is required to support subgroup operations!");
585 		}
586 		else
587 		{
588 			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
589 		}
590 	}
591 
592 	subgroups::SSBOData inputData;
593 	inputData.format = caseDef.format;
594 	inputData.layout = subgroups::SSBOData::LayoutStd140;
595 	inputData.numElements = subgroups::maxSupportedSubgroupSize();
596 	inputData.initializeType = subgroups::SSBOData::InitializeNonZero;;
597 
598 	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
599 		return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
600 	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
601 		return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
602 	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
603 		return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
604 	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
605 		return subgroups::makeTessellationEvaluationFrameBufferTest(context,  VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
606 	else
607 		TCU_THROW(InternalError, "Unhandled shader stage");
608 }
609 
610 
test(Context & context,const CaseDefinition caseDef)611 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
612 {
613 	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
614 	{
615 		if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
616 		{
617 			return tcu::TestStatus::fail(
618 					   "Shader stage " +
619 					   subgroups::getShaderStageName(caseDef.shaderStage) +
620 					   " is required to support subgroup operations!");
621 		}
622 		subgroups::SSBOData inputData;
623 		inputData.format = caseDef.format;
624 		inputData.layout = subgroups::SSBOData::LayoutStd430;
625 		inputData.numElements = subgroups::maxSupportedSubgroupSize();
626 		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
627 
628 		return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkCompute);
629 	}
630 	else
631 	{
632 		VkPhysicalDeviceSubgroupProperties subgroupProperties;
633 		subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
634 		subgroupProperties.pNext = DE_NULL;
635 
636 		VkPhysicalDeviceProperties2 properties;
637 		properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
638 		properties.pNext = &subgroupProperties;
639 
640 		context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
641 
642 		VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage  & subgroupProperties.supportedStages);
643 
644 		if (VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
645 		{
646 			if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
647 				TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
648 			else
649 				stages = VK_SHADER_STAGE_FRAGMENT_BIT;
650 		}
651 
652 		if ((VkShaderStageFlagBits)0u == stages)
653 			TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
654 
655 		subgroups::SSBOData inputData;
656 		inputData.format			= caseDef.format;
657 		inputData.layout			= subgroups::SSBOData::LayoutStd430;
658 		inputData.numElements		= subgroups::maxSupportedSubgroupSize();
659 		inputData.initializeType	= subgroups::SSBOData::InitializeNonZero;
660 		inputData.binding			= 4u;
661 		inputData.stages			= stages;
662 
663 		return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
664 	}
665 }
666 }
667 
668 namespace vkt
669 {
670 namespace subgroups
671 {
createSubgroupsQuadTests(tcu::TestContext & testCtx)672 tcu::TestCaseGroup* createSubgroupsQuadTests(tcu::TestContext& testCtx)
673 {
674 	de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
675 		testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
676 	de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
677 		testCtx, "compute", "Subgroup arithmetic category tests: compute"));
678 	de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
679 		testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
680 
681 	const VkFormat formats[] =
682 	{
683 		VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
684 		VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
685 		VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
686 		VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
687 		VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
688 		VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
689 		VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
690 		VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
691 		VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
692 	};
693 
694 	const VkShaderStageFlags stages[] =
695 	{
696 		VK_SHADER_STAGE_VERTEX_BIT,
697 		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
698 		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
699 		VK_SHADER_STAGE_GEOMETRY_BIT,
700 	};
701 
702 	for (int direction = 0; direction < 4; ++direction)
703 	{
704 		for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
705 		{
706 			const VkFormat format = formats[formatIndex];
707 
708 			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
709 			{
710 				const std::string op = de::toLower(getOpTypeName(opTypeIndex));
711 				std::ostringstream name;
712 				name << de::toLower(op);
713 
714 				if (OPTYPE_QUAD_BROADCAST == opTypeIndex)
715 				{
716 					name << "_" << direction;
717 				}
718 				else
719 				{
720 					if (0 != direction)
721 					{
722 						// We don't need direction for swap operations.
723 						continue;
724 					}
725 				}
726 
727 				name << "_" << subgroups::getFormatNameForGLSL(format);
728 
729 				{
730 					const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT, format, direction};
731 					addFunctionCaseWithPrograms(computeGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
732 				}
733 
734 				{
735 					const CaseDefinition caseDef =
736 					{
737 						opTypeIndex,
738 						VK_SHADER_STAGE_ALL_GRAPHICS,
739 						format,
740 						direction
741 					};
742 					addFunctionCaseWithPrograms(graphicGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
743 				}
744 				for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
745 				{
746 					const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format, direction};
747 					addFunctionCaseWithPrograms(framebufferGroup.get(), name.str()+"_"+ getShaderStageName(caseDef.shaderStage), "",
748 												supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
749 				}
750 
751 			}
752 		}
753 	}
754 
755 	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
756 		testCtx, "quad", "Subgroup quad category tests"));
757 
758 	group->addChild(graphicGroup.release());
759 	group->addChild(computeGroup.release());
760 	group->addChild(framebufferGroup.release());
761 
762 	return group.release();
763 }
764 } // subgroups
765 } // vkt
766