1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2017 The Khronos Group Inc.
6 * Copyright (c) 2017 Codeplay Software Ltd.
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 */ /*!
21 * \file
22 * \brief Subgroups Tests
23 */ /*--------------------------------------------------------------------*/
24
25 #include "vktSubgroupsQuadTests.hpp"
26 #include "vktSubgroupsTestsUtils.hpp"
27
28 #include <string>
29 #include <vector>
30
31 using namespace tcu;
32 using namespace std;
33 using namespace vk;
34 using namespace vkt;
35
36 namespace
37 {
// Quad subgroup operations covered by this file. The values are used as
// indices into per-operation tables, and OPTYPE_LAST is the number of
// operations (it sizes those tables and bounds the case-generation loop).
enum OpType
{
	OPTYPE_QUAD_BROADCAST		= 0,
	OPTYPE_QUAD_SWAP_HORIZONTAL	= 1,
	OPTYPE_QUAD_SWAP_VERTICAL	= 2,
	OPTYPE_QUAD_SWAP_DIAGONAL	= 3,
	OPTYPE_LAST					= 4
};
46
checkVertexPipelineStages(std::vector<const void * > datas,deUint32 width,deUint32)47 static bool checkVertexPipelineStages(std::vector<const void*> datas,
48 deUint32 width, deUint32)
49 {
50 return vkt::subgroups::check(datas, width, 1);
51 }
52
checkCompute(std::vector<const void * > datas,const deUint32 numWorkgroups[3],const deUint32 localSize[3],deUint32)53 static bool checkCompute(std::vector<const void*> datas,
54 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
55 deUint32)
56 {
57 return vkt::subgroups::checkCompute(datas, numWorkgroups, localSize, 1);
58 }
59
getOpTypeName(int opType)60 std::string getOpTypeName(int opType)
61 {
62 switch (opType)
63 {
64 default:
65 DE_FATAL("Unsupported op type");
66 return "";
67 case OPTYPE_QUAD_BROADCAST:
68 return "subgroupQuadBroadcast";
69 case OPTYPE_QUAD_SWAP_HORIZONTAL:
70 return "subgroupQuadSwapHorizontal";
71 case OPTYPE_QUAD_SWAP_VERTICAL:
72 return "subgroupQuadSwapVertical";
73 case OPTYPE_QUAD_SWAP_DIAGONAL:
74 return "subgroupQuadSwapDiagonal";
75 }
76 }
77
78 struct CaseDefinition
79 {
80 int opType;
81 VkShaderStageFlags shaderStage;
82 VkFormat format;
83 int direction;
84 };
85
initFrameBufferPrograms(SourceCollections & programCollection,CaseDefinition caseDef)86 void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
87 {
88 const vk::ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
89 std::string swapTable[OPTYPE_LAST];
90
91 subgroups::setFragmentShaderFrameBuffer(programCollection);
92
93 if (VK_SHADER_STAGE_VERTEX_BIT != caseDef.shaderStage)
94 subgroups::setVertexShaderFrameBuffer(programCollection);
95
96 swapTable[OPTYPE_QUAD_BROADCAST] = "";
97 swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = " const uint swapTable[4] = {1, 0, 3, 2};\n";
98 swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = " const uint swapTable[4] = {2, 3, 0, 1};\n";
99 swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = " const uint swapTable[4] = {3, 2, 1, 0};\n";
100
101 if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
102 {
103 std::ostringstream vertexSrc;
104 vertexSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
105 << "#extension GL_KHR_shader_subgroup_quad: enable\n"
106 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
107 << "layout(location = 0) in highp vec4 in_position;\n"
108 << "layout(location = 0) out float result;\n"
109 << "layout(set = 0, binding = 0) uniform Buffer1\n"
110 << "{\n"
111 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
112 << "};\n"
113 << "\n"
114 << "void main (void)\n"
115 << "{\n"
116 << " uvec4 mask = subgroupBallot(true);\n"
117 << swapTable[caseDef.opType];
118
119 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
120 {
121 vertexSrc << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
122 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
123 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
124 }
125 else
126 {
127 vertexSrc << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
128 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
129 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
130 }
131
132 vertexSrc << " if (subgroupBallotBitExtract(mask, otherID))\n"
133 << " {\n"
134 << " result = (op == data[otherID]) ? 1.0f : 0.0f;\n"
135 << " }\n"
136 << " else\n"
137 << " {\n"
138 << " result = 1.0f;\n" // Invocation we read from was inactive, so we can't verify results!
139 << " }\n"
140 << " gl_Position = in_position;\n"
141 << " gl_PointSize = 1.0f;\n"
142 << "}\n";
143 programCollection.glslSources.add("vert")
144 << glu::VertexSource(vertexSrc.str()) << buildOptions;
145 }
146 else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
147 {
148 std::ostringstream geometry;
149
150 geometry << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
151 << "#extension GL_KHR_shader_subgroup_quad: enable\n"
152 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
153 << "layout(points) in;\n"
154 << "layout(points, max_vertices = 1) out;\n"
155 << "layout(location = 0) out float out_color;\n"
156
157 << "layout(set = 0, binding = 0) uniform Buffer1\n"
158 << "{\n"
159 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
160 << "};\n"
161 << "\n"
162 << "void main (void)\n"
163 << "{\n"
164 << " uvec4 mask = subgroupBallot(true);\n"
165 << swapTable[caseDef.opType];
166
167 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
168 {
169 geometry << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
170 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
171 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
172 }
173 else
174 {
175 geometry << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
176 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
177 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
178 }
179
180 geometry << " if (subgroupBallotBitExtract(mask, otherID))\n"
181 << " {\n"
182 << " out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
183 << " }\n"
184 << " else\n"
185 << " {\n"
186 << " out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
187 << " }\n"
188 << " gl_Position = gl_in[0].gl_Position;\n"
189 << " EmitVertex();\n"
190 << " EndPrimitive();\n"
191 << "}\n";
192
193 programCollection.glslSources.add("geometry")
194 << glu::GeometrySource(geometry.str()) << buildOptions;
195 }
196 else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
197 {
198 std::ostringstream controlSource;
199
200 controlSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
201 << "#extension GL_KHR_shader_subgroup_quad: enable\n"
202 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
203 << "layout(vertices = 2) out;\n"
204 << "layout(location = 0) out float out_color[];\n"
205 << "layout(set = 0, binding = 0) uniform Buffer1\n"
206 << "{\n"
207 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
208 << "};\n"
209 << "\n"
210 << "void main (void)\n"
211 << "{\n"
212 << " if (gl_InvocationID == 0)\n"
213 <<" {\n"
214 << " gl_TessLevelOuter[0] = 1.0f;\n"
215 << " gl_TessLevelOuter[1] = 1.0f;\n"
216 << " }\n"
217 << " uvec4 mask = subgroupBallot(true);\n"
218 << swapTable[caseDef.opType];
219
220 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
221 {
222 controlSource << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
223 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
224 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
225 }
226 else
227 {
228 controlSource << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
229 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
230 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
231 }
232
233 controlSource << " if (subgroupBallotBitExtract(mask, otherID))\n"
234 << " {\n"
235 << " out_color[gl_InvocationID] = (op == data[otherID]) ? 1.0 : 0.0;\n"
236 << " }\n"
237 << " else\n"
238 << " {\n"
239 << " out_color[gl_InvocationID] = 1.0; \n"// Invocation we read from was inactive, so we can't verify results!
240 << " }\n"
241 << " gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
242 << "}\n";
243
244 programCollection.glslSources.add("tesc")
245 << glu::TessellationControlSource(controlSource.str()) << buildOptions;
246 subgroups::setTesEvalShaderFrameBuffer(programCollection);
247 }
248 else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
249 {
250 ostringstream evaluationSource;
251 evaluationSource << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
252 << "#extension GL_KHR_shader_subgroup_quad: enable\n"
253 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
254 << "layout(isolines, equal_spacing, ccw ) in;\n"
255 << "layout(location = 0) out float out_color;\n"
256 << "layout(set = 0, binding = 0) uniform Buffer1\n"
257 << "{\n"
258 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
259 << "};\n"
260 << "\n"
261 << "void main (void)\n"
262 << "{\n"
263 << " uvec4 mask = subgroupBallot(true);\n"
264 << swapTable[caseDef.opType];
265
266 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
267 {
268 evaluationSource << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
269 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
270 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
271 }
272 else
273 {
274 evaluationSource << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
275 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
276 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
277 }
278
279 evaluationSource << " if (subgroupBallotBitExtract(mask, otherID))\n"
280 << " {\n"
281 << " out_color = (op == data[otherID]) ? 1.0 : 0.0;\n"
282 << " }\n"
283 << " else\n"
284 << " {\n"
285 << " out_color = 1.0;\n" // Invocation we read from was inactive, so we can't verify results!
286 << " }\n"
287 << " gl_Position = mix(gl_in[0].gl_Position, gl_in[1].gl_Position, gl_TessCoord.x);\n"
288 << "}\n";
289
290 subgroups::setTesCtrlShaderFrameBuffer(programCollection);
291 programCollection.glslSources.add("tese")
292 << glu::TessellationEvaluationSource(evaluationSource.str()) << buildOptions;
293 }
294 else
295 {
296 DE_FATAL("Unsupported shader stage");
297 }
298 }
299
initPrograms(SourceCollections & programCollection,CaseDefinition caseDef)300 void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
301 {
302 std::string swapTable[OPTYPE_LAST];
303 swapTable[OPTYPE_QUAD_BROADCAST] = "";
304 swapTable[OPTYPE_QUAD_SWAP_HORIZONTAL] = " const uint swapTable[4] = {1, 0, 3, 2};\n";
305 swapTable[OPTYPE_QUAD_SWAP_VERTICAL] = " const uint swapTable[4] = {2, 3, 0, 1};\n";
306 swapTable[OPTYPE_QUAD_SWAP_DIAGONAL] = " const uint swapTable[4] = {3, 2, 1, 0};\n";
307
308 if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
309 {
310 std::ostringstream src;
311
312 src << "#version 450\n"
313 << "#extension GL_KHR_shader_subgroup_quad: enable\n"
314 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
315 << "layout (local_size_x_id = 0, local_size_y_id = 1, "
316 "local_size_z_id = 2) in;\n"
317 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
318 << "{\n"
319 << " uint result[];\n"
320 << "};\n"
321 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
322 << "{\n"
323 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
324 << "};\n"
325 << "\n"
326 << "void main (void)\n"
327 << "{\n"
328 << " uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
329 << " highp uint offset = globalSize.x * ((globalSize.y * "
330 "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
331 "gl_GlobalInvocationID.x;\n"
332 << " uvec4 mask = subgroupBallot(true);\n"
333 << swapTable[caseDef.opType];
334
335
336 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
337 {
338 src << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
339 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
340 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
341 }
342 else
343 {
344 src << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
345 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
346 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
347 }
348
349 src << " if (subgroupBallotBitExtract(mask, otherID))\n"
350 << " {\n"
351 << " result[offset] = (op == data[otherID]) ? 1 : 0;\n"
352 << " }\n"
353 << " else\n"
354 << " {\n"
355 << " result[offset] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
356 << " }\n"
357 << "}\n";
358
359 programCollection.glslSources.add("comp")
360 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
361 }
362 else
363 {
364 std::ostringstream src;
365 if (OPTYPE_QUAD_BROADCAST == caseDef.opType)
366 {
367 src << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
368 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID], " << caseDef.direction << ");\n"
369 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + " << caseDef.direction << ";\n";
370 }
371 else
372 {
373 src << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " op = "
374 << getOpTypeName(caseDef.opType) << "(data[gl_SubgroupInvocationID]);\n"
375 << " uint otherID = (gl_SubgroupInvocationID & ~0x3) + swapTable[gl_SubgroupInvocationID & 0x3];\n";
376 }
377 const string sourceType = src.str();
378
379 {
380 const string vertex =
381 "#version 450\n"
382 "#extension GL_KHR_shader_subgroup_quad: enable\n"
383 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
384 "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
385 "{\n"
386 " uint result[];\n"
387 "};\n"
388 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
389 "{\n"
390 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
391 "};\n"
392 "\n"
393 "void main (void)\n"
394 "{\n"
395 " uvec4 mask = subgroupBallot(true);\n"
396 + swapTable[caseDef.opType]
397 + sourceType +
398 " if (subgroupBallotBitExtract(mask, otherID))\n"
399 " {\n"
400 " result[gl_VertexIndex] = (op == data[otherID]) ? 1 : 0;\n"
401 " }\n"
402 " else\n"
403 " {\n"
404 " result[gl_VertexIndex] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
405 " }\n"
406 " float pixelSize = 2.0f/1024.0f;\n"
407 " float pixelPosition = pixelSize/2.0f - 1.0f;\n"
408 " gl_Position = vec4(float(gl_VertexIndex) * pixelSize + pixelPosition, 0.0f, 0.0f, 1.0f);\n"
409 " gl_PointSize = 1.0f;\n"
410 "}\n";
411 programCollection.glslSources.add("vert")
412 << glu::VertexSource(vertex) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
413 }
414
415 {
416 const string tesc =
417 "#version 450\n"
418 "#extension GL_KHR_shader_subgroup_quad: enable\n"
419 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
420 "layout(vertices=1) out;\n"
421 "layout(set = 0, binding = 1, std430) buffer Buffer1\n"
422 "{\n"
423 " uint result[];\n"
424 "};\n"
425 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
426 "{\n"
427 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
428 "};\n"
429 "\n"
430 "void main (void)\n"
431 "{\n"
432 " uvec4 mask = subgroupBallot(true);\n"
433 + swapTable[caseDef.opType]
434 + sourceType +
435 " if (subgroupBallotBitExtract(mask, otherID))\n"
436 " {\n"
437 " result[gl_PrimitiveID] = (op == data[otherID]) ? 1 : 0;\n"
438 " }\n"
439 " else\n"
440 " {\n"
441 " result[gl_PrimitiveID] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
442 " }\n"
443 " if (gl_InvocationID == 0)\n"
444 " {\n"
445 " gl_TessLevelOuter[0] = 1.0f;\n"
446 " gl_TessLevelOuter[1] = 1.0f;\n"
447 " }\n"
448 " gl_out[gl_InvocationID].gl_Position = gl_in[gl_InvocationID].gl_Position;\n"
449 "}\n";
450 programCollection.glslSources.add("tesc")
451 << glu::TessellationControlSource(tesc) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
452 }
453
454 {
455 const string tese =
456 "#version 450\n"
457 "#extension GL_KHR_shader_subgroup_quad: enable\n"
458 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
459 "layout(isolines) in;\n"
460 "layout(set = 0, binding = 2, std430) buffer Buffer1\n"
461 "{\n"
462 " uint result[];\n"
463 "};\n"
464 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
465 "{\n"
466 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
467 "};\n"
468 "\n"
469 "void main (void)\n"
470 "{\n"
471 " uvec4 mask = subgroupBallot(true);\n"
472 + swapTable[caseDef.opType]
473 + sourceType +
474 " if (subgroupBallotBitExtract(mask, otherID))\n"
475 " {\n"
476 " result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = (op == data[otherID]) ? 1 : 0;\n"
477 " }\n"
478 " else\n"
479 " {\n"
480 " result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
481 " }\n"
482 " float pixelSize = 2.0f/1024.0f;\n"
483 " gl_Position = gl_in[0].gl_Position + gl_TessCoord.x * pixelSize / 2.0f;\n"
484 "}\n";
485 programCollection.glslSources.add("tese")
486 << glu::TessellationEvaluationSource(tese) << vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
487 }
488
489 {
490 const string geometry =
491 "#version 450\n"
492 "#extension GL_KHR_shader_subgroup_quad: enable\n"
493 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
494 "layout(${TOPOLOGY}) in;\n"
495 "layout(points, max_vertices = 1) out;\n"
496 "layout(set = 0, binding = 3, std430) buffer Buffer1\n"
497 "{\n"
498 " uint result[];\n"
499 "};\n"
500 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
501 "{\n"
502 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
503 "};\n"
504 "\n"
505 "void main (void)\n"
506 "{\n"
507 " uvec4 mask = subgroupBallot(true);\n"
508 + swapTable[caseDef.opType]
509 + sourceType +
510 " if (subgroupBallotBitExtract(mask, otherID))\n"
511 " {\n"
512 " result[gl_PrimitiveIDIn] = (op == data[otherID]) ? 1 : 0;\n"
513 " }\n"
514 " else\n"
515 " {\n"
516 " result[gl_PrimitiveIDIn] = 1; // Invocation we read from was inactive, so we can't verify results!\n"
517 " }\n"
518 " gl_Position = gl_in[0].gl_Position;\n"
519 " EmitVertex();\n"
520 " EndPrimitive();\n"
521 "}\n";
522 subgroups::addGeometryShadersFromTemplate(geometry, vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u),
523 programCollection.glslSources);
524 }
525
526 {
527 const string fragment =
528 "#version 450\n"
529 "#extension GL_KHR_shader_subgroup_quad: enable\n"
530 "#extension GL_KHR_shader_subgroup_ballot: enable\n"
531 "layout(location = 0) out uint result;\n"
532 "layout(set = 0, binding = 4, std430) readonly buffer Buffer2\n"
533 "{\n"
534 " " + subgroups::getFormatNameForGLSL(caseDef.format) + " data[];\n"
535 "};\n"
536 "void main (void)\n"
537 "{\n"
538 " uvec4 mask = subgroupBallot(true);\n"
539 + swapTable[caseDef.opType]
540 + sourceType +
541 " if (subgroupBallotBitExtract(mask, otherID))\n"
542 " {\n"
543 " result = (op == data[otherID]) ? 1 : 0;\n"
544 " }\n"
545 " else\n"
546 " {\n"
547 " result = 1; // Invocation we read from was inactive, so we can't verify results!\n"
548 " }\n"
549 "}\n";
550 programCollection.glslSources.add("fragment")
551 << glu::FragmentSource(fragment)<< vk::ShaderBuildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
552 }
553 subgroups::addNoSubgroupShader(programCollection);
554 }
555 }
556
supportedCheck(Context & context,CaseDefinition caseDef)557 void supportedCheck (Context& context, CaseDefinition caseDef)
558 {
559 if (!subgroups::isSubgroupSupported(context))
560 TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
561
562 if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_QUAD_BIT))
563 TCU_THROW(NotSupportedError, "Device does not support subgroup quad operations");
564
565
566 if (subgroups::isDoubleFormat(caseDef.format) &&
567 !subgroups::isDoubleSupportedForDevice(context))
568 {
569 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
570 }
571 }
572
noSSBOtest(Context & context,const CaseDefinition caseDef)573 tcu::TestStatus noSSBOtest (Context& context, const CaseDefinition caseDef)
574 {
575 if (!subgroups::areSubgroupOperationsSupportedForStage(
576 context, caseDef.shaderStage))
577 {
578 if (subgroups::areSubgroupOperationsRequiredForStage(
579 caseDef.shaderStage))
580 {
581 return tcu::TestStatus::fail(
582 "Shader stage " +
583 subgroups::getShaderStageName(caseDef.shaderStage) +
584 " is required to support subgroup operations!");
585 }
586 else
587 {
588 TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
589 }
590 }
591
592 subgroups::SSBOData inputData;
593 inputData.format = caseDef.format;
594 inputData.layout = subgroups::SSBOData::LayoutStd140;
595 inputData.numElements = subgroups::maxSupportedSubgroupSize();
596 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;;
597
598 if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
599 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
600 else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
601 return subgroups::makeGeometryFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
602 else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
603 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT);
604 else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
605 return subgroups::makeTessellationEvaluationFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT);
606 else
607 TCU_THROW(InternalError, "Unhandled shader stage");
608 }
609
610
test(Context & context,const CaseDefinition caseDef)611 tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
612 {
613 if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
614 {
615 if (!subgroups::areSubgroupOperationsSupportedForStage(context, caseDef.shaderStage))
616 {
617 return tcu::TestStatus::fail(
618 "Shader stage " +
619 subgroups::getShaderStageName(caseDef.shaderStage) +
620 " is required to support subgroup operations!");
621 }
622 subgroups::SSBOData inputData;
623 inputData.format = caseDef.format;
624 inputData.layout = subgroups::SSBOData::LayoutStd430;
625 inputData.numElements = subgroups::maxSupportedSubgroupSize();
626 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
627
628 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkCompute);
629 }
630 else
631 {
632 VkPhysicalDeviceSubgroupProperties subgroupProperties;
633 subgroupProperties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;
634 subgroupProperties.pNext = DE_NULL;
635
636 VkPhysicalDeviceProperties2 properties;
637 properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
638 properties.pNext = &subgroupProperties;
639
640 context.getInstanceInterface().getPhysicalDeviceProperties2(context.getPhysicalDevice(), &properties);
641
642 VkShaderStageFlagBits stages = (VkShaderStageFlagBits)(caseDef.shaderStage & subgroupProperties.supportedStages);
643
644 if (VK_SHADER_STAGE_FRAGMENT_BIT != stages && !subgroups::isVertexSSBOSupportedForDevice(context))
645 {
646 if ( (stages & VK_SHADER_STAGE_FRAGMENT_BIT) == 0)
647 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
648 else
649 stages = VK_SHADER_STAGE_FRAGMENT_BIT;
650 }
651
652 if ((VkShaderStageFlagBits)0u == stages)
653 TCU_THROW(NotSupportedError, "Subgroup operations are not supported for any graphic shader");
654
655 subgroups::SSBOData inputData;
656 inputData.format = caseDef.format;
657 inputData.layout = subgroups::SSBOData::LayoutStd430;
658 inputData.numElements = subgroups::maxSupportedSubgroupSize();
659 inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
660 inputData.binding = 4u;
661 inputData.stages = stages;
662
663 return subgroups::allStages(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages, stages);
664 }
665 }
666 }
667
668 namespace vkt
669 {
670 namespace subgroups
671 {
createSubgroupsQuadTests(tcu::TestContext & testCtx)672 tcu::TestCaseGroup* createSubgroupsQuadTests(tcu::TestContext& testCtx)
673 {
674 de::MovePtr<tcu::TestCaseGroup> graphicGroup(new tcu::TestCaseGroup(
675 testCtx, "graphics", "Subgroup arithmetic category tests: graphics"));
676 de::MovePtr<tcu::TestCaseGroup> computeGroup(new tcu::TestCaseGroup(
677 testCtx, "compute", "Subgroup arithmetic category tests: compute"));
678 de::MovePtr<tcu::TestCaseGroup> framebufferGroup(new tcu::TestCaseGroup(
679 testCtx, "framebuffer", "Subgroup arithmetic category tests: framebuffer"));
680
681 const VkFormat formats[] =
682 {
683 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
684 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
685 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
686 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
687 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
688 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
689 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
690 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
691 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
692 };
693
694 const VkShaderStageFlags stages[] =
695 {
696 VK_SHADER_STAGE_VERTEX_BIT,
697 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
698 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
699 VK_SHADER_STAGE_GEOMETRY_BIT,
700 };
701
702 for (int direction = 0; direction < 4; ++direction)
703 {
704 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
705 {
706 const VkFormat format = formats[formatIndex];
707
708 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
709 {
710 const std::string op = de::toLower(getOpTypeName(opTypeIndex));
711 std::ostringstream name;
712 name << de::toLower(op);
713
714 if (OPTYPE_QUAD_BROADCAST == opTypeIndex)
715 {
716 name << "_" << direction;
717 }
718 else
719 {
720 if (0 != direction)
721 {
722 // We don't need direction for swap operations.
723 continue;
724 }
725 }
726
727 name << "_" << subgroups::getFormatNameForGLSL(format);
728
729 {
730 const CaseDefinition caseDef = {opTypeIndex, VK_SHADER_STAGE_COMPUTE_BIT, format, direction};
731 addFunctionCaseWithPrograms(computeGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
732 }
733
734 {
735 const CaseDefinition caseDef =
736 {
737 opTypeIndex,
738 VK_SHADER_STAGE_ALL_GRAPHICS,
739 format,
740 direction
741 };
742 addFunctionCaseWithPrograms(graphicGroup.get(), name.str(), "", supportedCheck, initPrograms, test, caseDef);
743 }
744 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
745 {
746 const CaseDefinition caseDef = {opTypeIndex, stages[stageIndex], format, direction};
747 addFunctionCaseWithPrograms(framebufferGroup.get(), name.str()+"_"+ getShaderStageName(caseDef.shaderStage), "",
748 supportedCheck, initFrameBufferPrograms, noSSBOtest, caseDef);
749 }
750
751 }
752 }
753 }
754
755 de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
756 testCtx, "quad", "Subgroup quad category tests"));
757
758 group->addChild(graphicGroup.release());
759 group->addChild(computeGroup.release());
760 group->addChild(framebufferGroup.release());
761
762 return group.release();
763 }
764 } // subgroups
765 } // vkt
766