1 /*-------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2020 The Khronos Group Inc.
6 * Copyright (c) 2020 Valve Corporation.
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 *//*!
21 * \file
22 * \brief Ray Tracing Data Spill tests
23 *//*--------------------------------------------------------------------*/
24 #include "vktRayTracingDataSpillTests.hpp"
25 #include "vktTestCase.hpp"
26
27 #include "vkRayTracingUtil.hpp"
28 #include "vkObjUtil.hpp"
29 #include "vkBufferWithMemory.hpp"
30 #include "vkImageWithMemory.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkCmdUtil.hpp"
33 #include "vkTypeUtil.hpp"
34 #include "vkBarrierUtil.hpp"
35
36 #include "tcuStringTemplate.hpp"
37 #include "tcuFloat.hpp"
38
39 #include "deUniquePtr.hpp"
40 #include "deSTLUtil.hpp"
41
42 #include <sstream>
43 #include <string>
44 #include <map>
45 #include <vector>
46 #include <array>
47 #include <utility>
48
49 using namespace vk;
50
51 namespace vkt
52 {
53 namespace RayTracing
54 {
55
56 namespace
57 {
58
59 // The type of shader call that will be used.
60 enum class CallType
61 {
62 TRACE_RAY = 0,
63 EXECUTE_CALLABLE,
64 REPORT_INTERSECTION,
65 };
66
67 // The type of data that will be checked.
68 enum class DataType
69 {
70 // These can be made an array or vector.
71 INT32 = 0,
72 UINT32,
73 INT64,
74 UINT64,
75 INT16,
76 UINT16,
77 INT8,
78 UINT8,
79 FLOAT32,
80 FLOAT64,
81 FLOAT16,
82
83 // These are standalone, so the vector type should be scalar.
84 STRUCT,
85 IMAGE,
86 SAMPLER,
87 SAMPLED_IMAGE,
88 PTR_IMAGE,
89 PTR_SAMPLER,
90 PTR_SAMPLED_IMAGE,
91 PTR_TEXEL,
92 OP_NULL,
93 OP_UNDEF,
94 };
95
96 // The type of vector in use.
97 enum class VectorType
98 {
99 SCALAR = 1,
100 V2 = 2,
101 V3 = 3,
102 V4 = 4,
103 A5 = 5,
104 };
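// Note: A5 is not a vector but an array of 5 scalar elements; it reuses the scalar
// component type and gets a "[5]" suffix when declared in GLSL (see getGLSLInputValDecl).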
105
106 struct InputStruct
107 {
108 deUint32 uintPart;
109 float floatPart;
110 };
111
112 constexpr auto kImageFormat = VK_FORMAT_R32_UINT;
113 const auto kImageExtent = makeExtent3D(1u, 1u, 1u);
114
115 // For samplers.
116 const VkImageUsageFlags kSampledImageUsage = (VK_IMAGE_USAGE_TRANSFER_DST_BIT | VK_IMAGE_USAGE_SAMPLED_BIT);
117 constexpr size_t kNumImages = 4u;
118 constexpr size_t kNumSamplers = 4u;
119 constexpr size_t kNumCombined = 2u;
120 constexpr size_t kNumAloneImages = kNumImages - kNumCombined;
121 constexpr size_t kNumAloneSamplers = kNumSamplers - kNumCombined;
122
123 // For storage images.
124 const VkImageUsageFlags kStorageImageUsage = (VK_IMAGE_USAGE_TRANSFER_DST_BIT | VK_IMAGE_USAGE_STORAGE_BIT);
125
126 // For the pipeline interface tests.
127 constexpr size_t kNumStorageValues = 6u;
128 constexpr deUint32 kShaderRecordSize = sizeof(tcu::UVec4);
129
130 // Get the effective vector length in memory.
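// Note: 3-component vectors are given an effective length of 4, presumably to account
// for the padding/alignment of 3-component vectors in buffer layouts when sizing and
// filling the input buffer.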
size_t getEffectiveVectorLength (VectorType vectorType)
132 {
133 return ((vectorType == VectorType::V3) ? static_cast<size_t>(4) : static_cast<size_t>(vectorType));
134 }
135
136 // Get the corresponding element size.
VkDeviceSize getElementSize (DataType dataType, VectorType vectorType)
138 {
139 const size_t length = getEffectiveVectorLength(vectorType);
140 size_t dataSize = 0u;
141
142 switch (dataType)
143 {
144 case DataType::INT32: dataSize = sizeof(deInt32); break;
145 case DataType::UINT32: dataSize = sizeof(deUint32); break;
146 case DataType::INT64: dataSize = sizeof(deInt64); break;
147 case DataType::UINT64: dataSize = sizeof(deUint64); break;
148 case DataType::INT16: dataSize = sizeof(deInt16); break;
149 case DataType::UINT16: dataSize = sizeof(deUint16); break;
150 case DataType::INT8: dataSize = sizeof(deInt8); break;
151 case DataType::UINT8: dataSize = sizeof(deUint8); break;
152 case DataType::FLOAT32: dataSize = sizeof(tcu::Float32); break;
153 case DataType::FLOAT64: dataSize = sizeof(tcu::Float64); break;
154 case DataType::FLOAT16: dataSize = sizeof(tcu::Float16); break;
155 case DataType::STRUCT: dataSize = sizeof(InputStruct); break;
156 case DataType::IMAGE: // fallthrough.
157 case DataType::SAMPLER: // fallthrough.
158 case DataType::SAMPLED_IMAGE: // fallthrough.
159 case DataType::PTR_IMAGE: // fallthrough.
160 case DataType::PTR_SAMPLER: // fallthrough.
161 case DataType::PTR_SAMPLED_IMAGE: // fallthrough.
162 dataSize = sizeof(tcu::Float32); break;
163 case DataType::PTR_TEXEL: dataSize = sizeof(deInt32); break;
164 case DataType::OP_NULL: // fallthrough.
165 case DataType::OP_UNDEF: // fallthrough.
166 dataSize = sizeof(deUint32); break;
167 default: DE_ASSERT(false); break;
168 }
169
170 return static_cast<VkDeviceSize>(dataSize * length);
171 }
172
173 // Proper stage for generating default geometry.
VkShaderStageFlagBits getShaderStageForGeometry (CallType type_)
175 {
176 VkShaderStageFlagBits bits = VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM;
177
178 switch (type_)
179 {
180 case CallType::TRACE_RAY: bits = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR; break;
181 case CallType::EXECUTE_CALLABLE: bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR; break;
182 case CallType::REPORT_INTERSECTION: bits = VK_SHADER_STAGE_INTERSECTION_BIT_KHR; break;
183 default: DE_ASSERT(false); break;
184 }
185
186 DE_ASSERT(bits != VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM);
187 return bits;
188 }
189
VkShaderStageFlags getShaderStages (CallType type_)
191 {
192 VkShaderStageFlags flags = VK_SHADER_STAGE_RAYGEN_BIT_KHR;
193
194 switch (type_)
195 {
196 case CallType::EXECUTE_CALLABLE:
197 flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
198 break;
199 case CallType::TRACE_RAY:
200 flags |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
201 break;
202 case CallType::REPORT_INTERSECTION:
203 flags |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
204 flags |= VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
205 break;
206 default:
207 DE_ASSERT(false);
208 break;
209 }
210
211 return flags;
212 }
213
214 // Some test types need additional descriptors with samplers, images and combined image samplers.
bool samplersNeeded (DataType dataType)
216 {
217 bool needed = false;
218
219 switch (dataType)
220 {
221 case DataType::IMAGE:
222 case DataType::SAMPLER:
223 case DataType::SAMPLED_IMAGE:
224 case DataType::PTR_IMAGE:
225 case DataType::PTR_SAMPLER:
226 case DataType::PTR_SAMPLED_IMAGE:
227 needed = true;
228 break;
229 default:
230 break;
231 }
232
233 return needed;
234 }
235
236 // Some test types need an additional descriptor with a storage image.
bool storageImageNeeded (DataType dataType)
238 {
239 return (dataType == DataType::PTR_TEXEL);
240 }
241
242 // Returns two strings:
243 // .first is an optional GLSL additional type declaration (for structs, basically).
244 // .second is the value declaration inside the input block.
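// For example, (FLOAT32, A5) yields {"", "float32_t val[5];"}, while STRUCT yields the
// InputStruct definition in .first and "InputStruct val;" in .second.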
std::pair<std::string, std::string> getGLSLInputValDecl (DataType dataType, VectorType vectorType)
246 {
247 using TypePair = std::pair<DataType, VectorType>;
248 using TypeMap = std::map<TypePair, std::string>;
249
250 const std::string varName = "val";
251 const auto dataTypeIdx = static_cast<int>(dataType);
252
253 if (dataTypeIdx >= static_cast<int>(DataType::INT32) && dataTypeIdx <= static_cast<int>(DataType::FLOAT16))
254 {
255 // Note: A5 uses the same type as the scalar version. The array suffix will be added below.
256 const TypeMap map =
257 {
258 std::make_pair(std::make_pair(DataType::INT32, VectorType::SCALAR), "int32_t"),
259 std::make_pair(std::make_pair(DataType::INT32, VectorType::V2), "i32vec2"),
260 std::make_pair(std::make_pair(DataType::INT32, VectorType::V3), "i32vec3"),
261 std::make_pair(std::make_pair(DataType::INT32, VectorType::V4), "i32vec4"),
262 std::make_pair(std::make_pair(DataType::INT32, VectorType::A5), "int32_t"),
263 std::make_pair(std::make_pair(DataType::UINT32, VectorType::SCALAR), "uint32_t"),
264 std::make_pair(std::make_pair(DataType::UINT32, VectorType::V2), "u32vec2"),
265 std::make_pair(std::make_pair(DataType::UINT32, VectorType::V3), "u32vec3"),
266 std::make_pair(std::make_pair(DataType::UINT32, VectorType::V4), "u32vec4"),
267 std::make_pair(std::make_pair(DataType::UINT32, VectorType::A5), "uint32_t"),
268 std::make_pair(std::make_pair(DataType::INT64, VectorType::SCALAR), "int64_t"),
269 std::make_pair(std::make_pair(DataType::INT64, VectorType::V2), "i64vec2"),
270 std::make_pair(std::make_pair(DataType::INT64, VectorType::V3), "i64vec3"),
271 std::make_pair(std::make_pair(DataType::INT64, VectorType::V4), "i64vec4"),
272 std::make_pair(std::make_pair(DataType::INT64, VectorType::A5), "int64_t"),
273 std::make_pair(std::make_pair(DataType::UINT64, VectorType::SCALAR), "uint64_t"),
274 std::make_pair(std::make_pair(DataType::UINT64, VectorType::V2), "u64vec2"),
275 std::make_pair(std::make_pair(DataType::UINT64, VectorType::V3), "u64vec3"),
276 std::make_pair(std::make_pair(DataType::UINT64, VectorType::V4), "u64vec4"),
277 std::make_pair(std::make_pair(DataType::UINT64, VectorType::A5), "uint64_t"),
278 std::make_pair(std::make_pair(DataType::INT16, VectorType::SCALAR), "int16_t"),
279 std::make_pair(std::make_pair(DataType::INT16, VectorType::V2), "i16vec2"),
280 std::make_pair(std::make_pair(DataType::INT16, VectorType::V3), "i16vec3"),
281 std::make_pair(std::make_pair(DataType::INT16, VectorType::V4), "i16vec4"),
282 std::make_pair(std::make_pair(DataType::INT16, VectorType::A5), "int16_t"),
283 std::make_pair(std::make_pair(DataType::UINT16, VectorType::SCALAR), "uint16_t"),
284 std::make_pair(std::make_pair(DataType::UINT16, VectorType::V2), "u16vec2"),
285 std::make_pair(std::make_pair(DataType::UINT16, VectorType::V3), "u16vec3"),
286 std::make_pair(std::make_pair(DataType::UINT16, VectorType::V4), "u16vec4"),
287 std::make_pair(std::make_pair(DataType::UINT16, VectorType::A5), "uint16_t"),
288 std::make_pair(std::make_pair(DataType::INT8, VectorType::SCALAR), "int8_t"),
289 std::make_pair(std::make_pair(DataType::INT8, VectorType::V2), "i8vec2"),
290 std::make_pair(std::make_pair(DataType::INT8, VectorType::V3), "i8vec3"),
291 std::make_pair(std::make_pair(DataType::INT8, VectorType::V4), "i8vec4"),
292 std::make_pair(std::make_pair(DataType::INT8, VectorType::A5), "int8_t"),
293 std::make_pair(std::make_pair(DataType::UINT8, VectorType::SCALAR), "uint8_t"),
294 std::make_pair(std::make_pair(DataType::UINT8, VectorType::V2), "u8vec2"),
295 std::make_pair(std::make_pair(DataType::UINT8, VectorType::V3), "u8vec3"),
296 std::make_pair(std::make_pair(DataType::UINT8, VectorType::V4), "u8vec4"),
297 std::make_pair(std::make_pair(DataType::UINT8, VectorType::A5), "uint8_t"),
298 std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::SCALAR), "float32_t"),
299 std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::V2), "f32vec2"),
300 std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::V3), "f32vec3"),
301 std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::V4), "f32vec4"),
302 std::make_pair(std::make_pair(DataType::FLOAT32, VectorType::A5), "float32_t"),
303 std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::SCALAR), "float64_t"),
304 std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::V2), "f64vec2"),
305 std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::V3), "f64vec3"),
306 std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::V4), "f64vec4"),
307 std::make_pair(std::make_pair(DataType::FLOAT64, VectorType::A5), "float64_t"),
308 std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::SCALAR), "float16_t"),
309 std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::V2), "f16vec2"),
310 std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::V3), "f16vec3"),
311 std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::V4), "f16vec4"),
312 std::make_pair(std::make_pair(DataType::FLOAT16, VectorType::A5), "float16_t"),
313 };
314
315 const auto key = std::make_pair(dataType, vectorType);
316 const auto found = map.find(key);
317
318 DE_ASSERT(found != end(map));
319
320 const auto baseType = found->second;
321 const std::string decl = baseType + " " + varName + ((vectorType == VectorType::A5) ? "[5]" : "") + ";";
322
323 return std::make_pair(std::string(), decl);
324 }
325 else if (dataType == DataType::STRUCT)
326 {
327 return std::make_pair(std::string("struct InputStruct { uint val1; float val2; };\n"), std::string("InputStruct val;"));
328 }
329 else if (samplersNeeded(dataType))
330 {
331 return std::make_pair(std::string(), std::string("float val;"));
332 }
333 else if (storageImageNeeded(dataType))
334 {
335 return std::make_pair(std::string(), std::string("int val;"));
336 }
337 else if (dataType == DataType::OP_NULL || dataType == DataType::OP_UNDEF)
338 {
339 return std::make_pair(std::string(), std::string("uint val;"));
340 }
341
342 // Unreachable.
343 DE_ASSERT(false);
344 return std::make_pair(std::string(), std::string());
345 }
346
347 class DataSpillTestCase : public vkt::TestCase
348 {
349 public:
350 struct TestParams
351 {
352 CallType callType;
353 DataType dataType;
354 VectorType vectorType;
355 };
356
357 DataSpillTestCase (tcu::TestContext& testCtx, const std::string& name, const std::string& description, const TestParams& testParams);
	virtual				~DataSpillTestCase		(void) {}
359
360 virtual void initPrograms (vk::SourceCollections& programCollection) const;
361 virtual TestInstance* createInstance (Context& context) const;
362 virtual void checkSupport (Context& context) const;
363
364 private:
365 TestParams m_params;
366 };
367
368 class DataSpillTestInstance : public vkt::TestInstance
369 {
370 public:
371 using TestParams = DataSpillTestCase::TestParams;
372
373 DataSpillTestInstance (Context& context, const TestParams& testParams);
	virtual				~DataSpillTestInstance	(void) {}
375
376 virtual tcu::TestStatus iterate (void);
377
378 private:
379 TestParams m_params;
380 };
381
382
DataSpillTestCase::DataSpillTestCase (tcu::TestContext& testCtx, const std::string& name, const std::string& description, const TestParams& testParams)
384 : vkt::TestCase (testCtx, name, description)
385 , m_params (testParams)
386 {
387 switch (m_params.dataType)
388 {
389 case DataType::STRUCT:
390 case DataType::IMAGE:
391 case DataType::SAMPLER:
392 case DataType::SAMPLED_IMAGE:
393 case DataType::PTR_IMAGE:
394 case DataType::PTR_SAMPLER:
395 case DataType::PTR_SAMPLED_IMAGE:
396 case DataType::PTR_TEXEL:
397 case DataType::OP_NULL:
398 case DataType::OP_UNDEF:
399 DE_ASSERT(m_params.vectorType == VectorType::SCALAR);
400 break;
401 default:
402 break;
403 }
404
405 // The code assumes at most one of these is needed.
406 DE_ASSERT(!(samplersNeeded(m_params.dataType) && storageImageNeeded(m_params.dataType)));
407 }
408
TestInstance* DataSpillTestCase::createInstance (Context& context) const
410 {
411 return new DataSpillTestInstance(context, m_params);
412 }
413
DataSpillTestInstance::DataSpillTestInstance (Context& context, const TestParams& testParams)
415 : vkt::TestInstance (context)
416 , m_params (testParams)
417 {
418 }
419
420 // General checks for all tests.
void commonCheckSupport (Context& context)
422 {
423 context.requireDeviceFunctionality("VK_KHR_acceleration_structure");
424 context.requireDeviceFunctionality("VK_KHR_ray_tracing_pipeline");
425
426 const auto& rtFeatures = context.getRayTracingPipelineFeatures();
427 if (!rtFeatures.rayTracingPipeline)
428 TCU_THROW(NotSupportedError, "Ray Tracing pipelines not supported");
429
430 const auto& asFeatures = context.getAccelerationStructureFeatures();
431 if (!asFeatures.accelerationStructure)
432 TCU_FAIL("VK_KHR_acceleration_structure supported without accelerationStructure support");
433
434 }
435
void DataSpillTestCase::checkSupport (Context& context) const
437 {
438 // General checks first.
439 commonCheckSupport(context);
440
441 const auto& features = context.getDeviceFeatures();
442 const auto& featuresStorage16 = context.get16BitStorageFeatures();
443 const auto& featuresF16I8 = context.getShaderFloat16Int8Features();
444 const auto& featuresStorage8 = context.get8BitStorageFeatures();
445
446 if (m_params.dataType == DataType::INT64 || m_params.dataType == DataType::UINT64)
447 {
448 if (!features.shaderInt64)
449 TCU_THROW(NotSupportedError, "64-bit integers not supported");
450 }
451 else if (m_params.dataType == DataType::INT16 || m_params.dataType == DataType::UINT16)
452 {
453 context.requireDeviceFunctionality("VK_KHR_16bit_storage");
454
455 if (!features.shaderInt16)
456 TCU_THROW(NotSupportedError, "16-bit integers not supported");
457
458 if (!featuresStorage16.storageBuffer16BitAccess)
459 TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
460 }
461 else if (m_params.dataType == DataType::INT8 || m_params.dataType == DataType::UINT8)
462 {
463 context.requireDeviceFunctionality("VK_KHR_shader_float16_int8");
464 context.requireDeviceFunctionality("VK_KHR_8bit_storage");
465
466 if (!featuresF16I8.shaderInt8)
467 TCU_THROW(NotSupportedError, "8-bit integers not supported");
468
469 if (!featuresStorage8.storageBuffer8BitAccess)
470 TCU_THROW(NotSupportedError, "8-bit storage buffer access not supported");
471 }
472 else if (m_params.dataType == DataType::FLOAT64)
473 {
474 if (!features.shaderFloat64)
475 TCU_THROW(NotSupportedError, "64-bit floats not supported");
476 }
477 else if (m_params.dataType == DataType::FLOAT16)
478 {
479 context.requireDeviceFunctionality("VK_KHR_shader_float16_int8");
480 context.requireDeviceFunctionality("VK_KHR_16bit_storage");
481
482 if (!featuresF16I8.shaderFloat16)
483 TCU_THROW(NotSupportedError, "16-bit floats not supported");
484
485 if (!featuresStorage16.storageBuffer16BitAccess)
486 TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
487 }
488 else if (samplersNeeded(m_params.dataType))
489 {
490 context.requireDeviceFunctionality("VK_EXT_descriptor_indexing");
491 const auto indexingFeatures = context.getDescriptorIndexingFeatures();
492 if (!indexingFeatures.shaderSampledImageArrayNonUniformIndexing)
493 TCU_THROW(NotSupportedError, "No support for non-uniform sampled image arrays");
494 }
495 }
496
void DataSpillTestCase::initPrograms (vk::SourceCollections& programCollection) const
498 {
499 const vk::ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
500 const vk::SpirVAsmBuildOptions spvBuildOptions (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, true);
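	// Ray tracing shaders are built as SPIR-V 1.4. Note that, starting with SPIR-V 1.4, the
	// OpEntryPoint interface must list every global variable the shader uses (including the
	// storage buffers), which is why the template appends extra IDs via MAIN_INTERFACE_EXTRAS.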
501
502 std::ostringstream spvTemplateStream;
503
	// This SPIR-V template will be used to generate shaders for different
	// stages (raygen, callable, etc.). The basic mechanism uses 3 SSBOs: one
	// used strictly as an input, one to write the check result, and one to
	// verify the shader call has taken place. The latter two SSBOs contain just
	// a single uint, but the input SSBO typically contains other types of data
	// that will be filled from the test instance with predetermined values. The
	// shader will expect this data to have specific values that can be combined
	// in some way to give an expected result (e.g. by adding the 4 components if
	// it's a vec4). This result will be used in the shader call to make sure
	// input values are read *before* the call. After the shader call has taken
	// place, the shader will attempt to read the input buffer again and verify
	// the value is still correct and matches the previous one. If the result
	// matches, it will write a confirmation value in the check buffer. In the
	// meantime, the callee will write a confirmation value in the callee
	// buffer to verify the shader call took place.
	//
	// Some test variants use samplers, images or sampled images. These need
	// additional bindings of different types, and the interesting value is
	// typically placed in the image instead of the input buffer, while the
	// input buffer is used for sampling coordinates instead.
	//
	// Some important SPIR-V template variables:
	//
	// - INPUT_BUFFER_VALUE_TYPE will contain the type of the input buffer data.
	// - CALC_ZERO_FOR_CALLABLE is expected to contain instructions that will
	//   calculate a value of zero to be used in the shader call instruction.
	//   This value should be derived from the input data.
	// - CALL_STATEMENTS will contain the shader call instructions.
	// - CALC_EQUAL_STATEMENT is expected to contain instructions that will
	//   set %equal to true as a %bool if the before and after data match.
	//
	// - %input_val_ptr contains the pointer to the input value.
	// - %input_val_before contains the value read before the call.
	// - %input_val_after contains the value read after the call.
538
539 spvTemplateStream
540 << " OpCapability RayTracingKHR\n"
541 << "${EXTRA_CAPABILITIES}"
542 << " OpExtension \"SPV_KHR_ray_tracing\"\n"
543 << "${EXTRA_EXTENSIONS}"
544 << " OpMemoryModel Logical GLSL450\n"
545 << " OpEntryPoint ${ENTRY_POINT} %main \"main\" %topLevelAS %calleeBuffer %outputBuffer %inputBuffer${MAIN_INTERFACE_EXTRAS}\n"
546 << "${INTERFACE_DECORATIONS}"
547 << " OpMemberDecorate %InputBlock 0 Offset 0\n"
548 << " OpDecorate %InputBlock Block\n"
549 << " OpDecorate %inputBuffer DescriptorSet 0\n"
550 << " OpDecorate %inputBuffer Binding 3\n"
551 << " OpMemberDecorate %OutputBlock 0 Offset 0\n"
552 << " OpDecorate %OutputBlock Block\n"
553 << " OpDecorate %outputBuffer DescriptorSet 0\n"
554 << " OpDecorate %outputBuffer Binding 2\n"
555 << " OpMemberDecorate %CalleeBlock 0 Offset 0\n"
556 << " OpDecorate %CalleeBlock Block\n"
557 << " OpDecorate %calleeBuffer DescriptorSet 0\n"
558 << " OpDecorate %calleeBuffer Binding 1\n"
559 << " OpDecorate %topLevelAS DescriptorSet 0\n"
560 << " OpDecorate %topLevelAS Binding 0\n"
561 << "${EXTRA_BINDINGS}"
562 << " %void = OpTypeVoid\n"
563 << " %void_func = OpTypeFunction %void\n"
564 << " %int = OpTypeInt 32 1\n"
565 << " %uint = OpTypeInt 32 0\n"
566 << " %int_0 = OpConstant %int 0\n"
567 << " %uint_0 = OpConstant %uint 0\n"
568 << " %uint_1 = OpConstant %uint 1\n"
569 << " %uint_2 = OpConstant %uint 2\n"
570 << " %uint_3 = OpConstant %uint 3\n"
571 << " %uint_4 = OpConstant %uint 4\n"
572 << " %uint_5 = OpConstant %uint 5\n"
573 << " %uint_255 = OpConstant %uint 255\n"
574 << " %bool = OpTypeBool\n"
575 << " %float = OpTypeFloat 32\n"
576 << " %float_0 = OpConstant %float 0\n"
577 << " %float_1 = OpConstant %float 1\n"
578 << " %float_9 = OpConstant %float 9\n"
579 << " %float_0_5 = OpConstant %float 0.5\n"
580 << " %float_n1 = OpConstant %float -1\n"
581 << " %v3float = OpTypeVector %float 3\n"
582 << " %origin_const = OpConstantComposite %v3float %float_0_5 %float_0_5 %float_0\n"
583 << " %direction_const = OpConstantComposite %v3float %float_0 %float_0 %float_n1\n"
584 << "${EXTRA_TYPES_AND_CONSTANTS}"
585 << " %data_func_ptr = OpTypePointer Function ${INPUT_BUFFER_VALUE_TYPE}\n"
586 << "${INTERFACE_TYPES_AND_VARIABLES}"
587 << " %InputBlock = OpTypeStruct ${INPUT_BUFFER_VALUE_TYPE}\n"
588 << " %_ptr_StorageBuffer_InputBlock = OpTypePointer StorageBuffer %InputBlock\n"
589 << " %inputBuffer = OpVariable %_ptr_StorageBuffer_InputBlock StorageBuffer\n"
590 << " %data_storagebuffer_ptr = OpTypePointer StorageBuffer ${INPUT_BUFFER_VALUE_TYPE}\n"
591 << " %OutputBlock = OpTypeStruct %uint\n"
592 << "%_ptr_StorageBuffer_OutputBlock = OpTypePointer StorageBuffer %OutputBlock\n"
593 << " %outputBuffer = OpVariable %_ptr_StorageBuffer_OutputBlock StorageBuffer\n"
594 << " %_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint\n"
595 << " %CalleeBlock = OpTypeStruct %uint\n"
596 << "%_ptr_StorageBuffer_CalleeBlock = OpTypePointer StorageBuffer %CalleeBlock\n"
597 << " %calleeBuffer = OpVariable %_ptr_StorageBuffer_CalleeBlock StorageBuffer\n"
598 << " %as_type = OpTypeAccelerationStructureKHR\n"
599 << " %as_uniformconstant_ptr = OpTypePointer UniformConstant %as_type\n"
600 << " %topLevelAS = OpVariable %as_uniformconstant_ptr UniformConstant\n"
601 << "${EXTRA_BINDING_VARIABLES}"
602 << " %main = OpFunction %void None %void_func\n"
603 << " %main_label = OpLabel\n"
604 << "${EXTRA_FUNCTION_VARIABLES}"
605 << " %input_val_ptr = OpAccessChain %data_storagebuffer_ptr %inputBuffer %int_0\n"
606 << " %output_val_ptr = OpAccessChain %_ptr_StorageBuffer_uint %outputBuffer %int_0\n"
607 // Note we use Volatile to load the input buffer value before and after the call statements.
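		// Volatile keeps the two loads from being combined or cached away, forcing an actual
		// read of the input value both before and after the call.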
608 << " %input_val_before = OpLoad ${INPUT_BUFFER_VALUE_TYPE} %input_val_ptr Volatile\n"
609 << "${CALC_ZERO_FOR_CALLABLE}"
610 << "${CALL_STATEMENTS}"
611 << " %input_val_after = OpLoad ${INPUT_BUFFER_VALUE_TYPE} %input_val_ptr Volatile\n"
612 << "${CALC_EQUAL_STATEMENT}"
613 << " %output_val = OpSelect %uint %equal %uint_1 %uint_0\n"
614 << " OpStore %output_val_ptr %output_val\n"
615 << " OpReturn\n"
616 << " OpFunctionEnd\n"
617 ;
618
619 const tcu::StringTemplate spvTemplate (spvTemplateStream.str());
620
621 std::map<std::string, std::string> subs;
622 std::string componentTypeName;
623 std::string opEqual;
624 const int numComponents = static_cast<int>(m_params.vectorType);
625 const auto isArray = (numComponents > static_cast<int>(VectorType::V4));
626 const auto numComponentsStr = de::toString(numComponents);
627
628 subs["EXTRA_CAPABILITIES"] = "";
629 subs["EXTRA_EXTENSIONS"] = "";
630 subs["EXTRA_TYPES_AND_CONSTANTS"] = "";
631 subs["EXTRA_FUNCTION_VARIABLES"] = "";
632 subs["EXTRA_BINDINGS"] = "";
633 subs["EXTRA_BINDING_VARIABLES"] = "";
634 subs["EXTRA_FUNCTIONS"] = "";
635
	// Note that some of these substitutions will be updated after the if-block below.
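	// Each branch below sets INPUT_BUFFER_VALUE_TYPE and builds CALC_ZERO_FOR_CALLABLE so that
	// the zero fed to the call is derived from the data read before the call, typically by
	// subtracting the expected constant 37 (the image, null and undef variants use small
	// variations of the same idea).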
637
638 if (m_params.dataType == DataType::INT32)
639 {
640 componentTypeName = "int";
641
642 subs["INPUT_BUFFER_VALUE_TYPE"] = "%int";
643 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %int_37 = OpConstant %int 37\n";
644 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_int = OpISub %int %input_val_before %int_37\n"
645 " %zero_for_callable = OpBitcast %uint %zero_int\n";
646 }
647 else if (m_params.dataType == DataType::UINT32)
648 {
649 componentTypeName = "uint";
650
651 subs["INPUT_BUFFER_VALUE_TYPE"] = "%uint";
652 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %uint_37 = OpConstant %uint 37\n";
653 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_for_callable = OpISub %uint %input_val_before %uint_37\n";
654 }
655 else if (m_params.dataType == DataType::INT64)
656 {
657 componentTypeName = "long";
658
659 subs["EXTRA_CAPABILITIES"] += " OpCapability Int64\n";
660 subs["INPUT_BUFFER_VALUE_TYPE"] = "%long";
661 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %long = OpTypeInt 64 1\n"
662 " %long_37 = OpConstant %long 37\n";
663 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_long = OpISub %long %input_val_before %long_37\n"
664 " %zero_for_callable = OpSConvert %uint %zero_long\n";
665 }
666 else if (m_params.dataType == DataType::UINT64)
667 {
668 componentTypeName = "ulong";
669
670 subs["EXTRA_CAPABILITIES"] += " OpCapability Int64\n";
671 subs["INPUT_BUFFER_VALUE_TYPE"] = "%ulong";
672 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %ulong = OpTypeInt 64 0\n"
673 " %ulong_37 = OpConstant %ulong 37\n";
674 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_ulong = OpISub %ulong %input_val_before %ulong_37\n"
675 " %zero_for_callable = OpUConvert %uint %zero_ulong\n";
676 }
677 else if (m_params.dataType == DataType::INT16)
678 {
679 componentTypeName = "short";
680
681 subs["EXTRA_CAPABILITIES"] += " OpCapability Int16\n"
682 " OpCapability StorageBuffer16BitAccess\n";
683 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_KHR_16bit_storage\"\n";
684 subs["INPUT_BUFFER_VALUE_TYPE"] = "%short";
685 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %short = OpTypeInt 16 1\n"
686 " %short_37 = OpConstant %short 37\n";
687 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_short = OpISub %short %input_val_before %short_37\n"
688 " %zero_for_callable = OpSConvert %uint %zero_short\n";
689 }
690 else if (m_params.dataType == DataType::UINT16)
691 {
692 componentTypeName = "ushort";
693
694 subs["EXTRA_CAPABILITIES"] += " OpCapability Int16\n"
695 " OpCapability StorageBuffer16BitAccess\n";
696 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_KHR_16bit_storage\"\n";
697 subs["INPUT_BUFFER_VALUE_TYPE"] = "%ushort";
698 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %ushort = OpTypeInt 16 0\n"
699 " %ushort_37 = OpConstant %ushort 37\n";
700 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_ushort = OpISub %ushort %input_val_before %ushort_37\n"
701 " %zero_for_callable = OpUConvert %uint %zero_ushort\n";
702 }
703 else if (m_params.dataType == DataType::INT8)
704 {
705 componentTypeName = "char";
706
707 subs["EXTRA_CAPABILITIES"] += " OpCapability Int8\n"
708 " OpCapability StorageBuffer8BitAccess\n";
709 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_KHR_8bit_storage\"\n";
710 subs["INPUT_BUFFER_VALUE_TYPE"] = "%char";
711 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %char = OpTypeInt 8 1\n"
712 " %char_37 = OpConstant %char 37\n";
713 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_char = OpISub %char %input_val_before %char_37\n"
714 " %zero_for_callable = OpSConvert %uint %zero_char\n";
715 }
716 else if (m_params.dataType == DataType::UINT8)
717 {
718 componentTypeName = "uchar";
719
720 subs["EXTRA_CAPABILITIES"] += " OpCapability Int8\n"
721 " OpCapability StorageBuffer8BitAccess\n";
722 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_KHR_8bit_storage\"\n";
723 subs["INPUT_BUFFER_VALUE_TYPE"] = "%uchar";
724 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %uchar = OpTypeInt 8 0\n"
725 " %uchar_37 = OpConstant %uchar 37\n";
726 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_uchar = OpISub %uchar %input_val_before %uchar_37\n"
727 " %zero_for_callable = OpUConvert %uint %zero_uchar\n";
728 }
729 else if (m_params.dataType == DataType::FLOAT32)
730 {
731 componentTypeName = "float";
732
733 subs["INPUT_BUFFER_VALUE_TYPE"] = "%float";
734 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %float_37 = OpConstant %float 37\n";
735 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_float = OpFSub %float %input_val_before %float_37\n"
736 " %zero_for_callable = OpConvertFToU %uint %zero_float\n";
737 }
738 else if (m_params.dataType == DataType::FLOAT64)
739 {
740 componentTypeName = "double";
741
742 subs["EXTRA_CAPABILITIES"] += " OpCapability Float64\n";
743 subs["INPUT_BUFFER_VALUE_TYPE"] = "%double";
744 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %double = OpTypeFloat 64\n"
745 " %double_37 = OpConstant %double 37\n";
746 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_double = OpFSub %double %input_val_before %double_37\n"
747 " %zero_for_callable = OpConvertFToU %uint %zero_double\n";
748 }
749 else if (m_params.dataType == DataType::FLOAT16)
750 {
751 componentTypeName = "half";
752
753 subs["EXTRA_CAPABILITIES"] += " OpCapability Float16\n"
754 " OpCapability StorageBuffer16BitAccess\n";
755 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_KHR_16bit_storage\"\n";
756 subs["INPUT_BUFFER_VALUE_TYPE"] = "%half";
757 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %half = OpTypeFloat 16\n"
758 " %half_37 = OpConstant %half 37\n";
759 subs["CALC_ZERO_FOR_CALLABLE"] = " %zero_half = OpFSub %half %input_val_before %half_37\n"
760 " %zero_for_callable = OpConvertFToU %uint %zero_half\n";
761 }
762 else if (m_params.dataType == DataType::STRUCT)
763 {
764 componentTypeName = "InputStruct";
765
766 subs["INPUT_BUFFER_VALUE_TYPE"] = "%InputStruct";
767 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %InputStruct = OpTypeStruct %uint %float\n"
768 " %float_37 = OpConstant %float 37\n"
769 " %uint_part_ptr_type = OpTypePointer StorageBuffer %uint\n"
770 " %float_part_ptr_type = OpTypePointer StorageBuffer %float\n"
771 " %uint_part_func_ptr_type = OpTypePointer Function %uint\n"
772 " %float_part_func_ptr_type = OpTypePointer Function %float\n"
773 " %input_struct_func_ptr_type = OpTypePointer Function %InputStruct\n"
774 ;
775 subs["INTERFACE_DECORATIONS"] = " OpMemberDecorate %InputStruct 0 Offset 0\n"
776 " OpMemberDecorate %InputStruct 1 Offset 4\n";
777
		// Sum the struct members, then subtract the constant and convert to uint.
779 subs["CALC_ZERO_FOR_CALLABLE"] = " %uint_part_ptr = OpAccessChain %uint_part_ptr_type %input_val_ptr %uint_0\n"
780 " %float_part_ptr = OpAccessChain %float_part_ptr_type %input_val_ptr %uint_1\n"
781 " %uint_part = OpLoad %uint %uint_part_ptr\n"
782 " %float_part = OpLoad %float %float_part_ptr\n"
783 " %uint_as_float = OpConvertUToF %float %uint_part\n"
784 " %member_sum = OpFAdd %float %float_part %uint_as_float\n"
785 " %zero_float = OpFSub %float %member_sum %float_37\n"
786 " %zero_for_callable = OpConvertFToU %uint %zero_float\n"
787 ;
788 }
789 else if (samplersNeeded(m_params.dataType))
790 {
791 // These tests will use additional bindings as arrays of 2 elements:
792 // - 1 array of samplers.
793 // - 1 array of images.
794 // - 1 array of combined image samplers.
		// Input values are typically used as texture coordinates (normally zeros).
		// The expected values are stored in the image pixels instead of in the input buffer.
797
798 subs["INPUT_BUFFER_VALUE_TYPE"] = "%float";
799 subs["EXTRA_CAPABILITIES"] += " OpCapability SampledImageArrayNonUniformIndexing\n";
800 subs["EXTRA_EXTENSIONS"] += " OpExtension \"SPV_EXT_descriptor_indexing\"\n";
801 subs["MAIN_INTERFACE_EXTRAS"] += " %sampledTexture %textureSampler %combinedImageSampler";
802 subs["EXTRA_BINDINGS"] += " OpDecorate %sampledTexture DescriptorSet 0\n"
803 " OpDecorate %sampledTexture Binding 4\n"
804 " OpDecorate %textureSampler DescriptorSet 0\n"
805 " OpDecorate %textureSampler Binding 5\n"
806 " OpDecorate %combinedImageSampler DescriptorSet 0\n"
807 " OpDecorate %combinedImageSampler Binding 6\n";
808 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %uint_37 = OpConstant %uint 37\n"
809 " %v4uint = OpTypeVector %uint 4\n"
810 " %v2float = OpTypeVector %float 2\n"
811 " %image_type = OpTypeImage %uint 2D 0 0 0 1 Unknown\n"
812 " %image_array_type = OpTypeArray %image_type %uint_2\n"
813 " %image_array_type_uniform_ptr = OpTypePointer UniformConstant %image_array_type\n"
814 " %image_type_uniform_ptr = OpTypePointer UniformConstant %image_type\n"
815 " %sampler_type = OpTypeSampler\n"
816 " %sampler_array_type = OpTypeArray %sampler_type %uint_2\n"
817 "%sampler_array_type_uniform_ptr = OpTypePointer UniformConstant %sampler_array_type\n"
818 " %sampler_type_uniform_ptr = OpTypePointer UniformConstant %sampler_type\n"
819 " %sampled_image_type = OpTypeSampledImage %image_type\n"
820 " %sampled_image_array_type = OpTypeArray %sampled_image_type %uint_2\n"
821 "%sampled_image_array_type_uniform_ptr = OpTypePointer UniformConstant %sampled_image_array_type\n"
822 "%sampled_image_type_uniform_ptr = OpTypePointer UniformConstant %sampled_image_type\n"
823 ;
824 subs["EXTRA_BINDING_VARIABLES"] += " %sampledTexture = OpVariable %image_array_type_uniform_ptr UniformConstant\n"
825 " %textureSampler = OpVariable %sampler_array_type_uniform_ptr UniformConstant\n"
826 " %combinedImageSampler = OpVariable %sampled_image_array_type_uniform_ptr UniformConstant\n"
827 ;
828
829 if (m_params.dataType == DataType::IMAGE || m_params.dataType == DataType::SAMPLER)
830 {
831 // Use the first sampler and sample from the first image.
832 subs["CALC_ZERO_FOR_CALLABLE"] += "%image_0_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_0\n"
833 "%sampler_0_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_0\n"
834 "%sampler_0 = OpLoad %sampler_type %sampler_0_ptr\n"
835 "%image_0 = OpLoad %image_type %image_0_ptr\n"
836 "%sampled_image_0 = OpSampledImage %sampled_image_type %image_0 %sampler_0\n"
837 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
838 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend %float_0\n"
839 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
840 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n"
841 ;
842 }
843 else if (m_params.dataType == DataType::SAMPLED_IMAGE)
844 {
845 // Use the first combined image sampler.
846 subs["CALC_ZERO_FOR_CALLABLE"] += "%sampled_image_0_ptr = OpAccessChain %sampled_image_type_uniform_ptr %combinedImageSampler %uint_0\n"
847 "%sampled_image_0 = OpLoad %sampled_image_type %sampled_image_0_ptr\n"
848 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
849 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend %float_0\n"
850 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
851 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n"
852 ;
853 }
854 else if (m_params.dataType == DataType::PTR_IMAGE)
855 {
856 // We attempt to create the second pointer before the call.
857 subs["CALC_ZERO_FOR_CALLABLE"] += "%image_0_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_0\n"
858 "%image_1_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_1\n"
859 "%image_0 = OpLoad %image_type %image_0_ptr\n"
860 "%sampler_0_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_0\n"
861 "%sampler_0 = OpLoad %sampler_type %sampler_0_ptr\n"
862 "%sampled_image_0 = OpSampledImage %sampled_image_type %image_0 %sampler_0\n"
863 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
864 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend %float_0\n"
865 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
866 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n"
867 ;
868 }
869 else if (m_params.dataType == DataType::PTR_SAMPLER)
870 {
871 // We attempt to create the second pointer before the call.
872 subs["CALC_ZERO_FOR_CALLABLE"] += "%sampler_0_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_0\n"
873 "%sampler_1_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_1\n"
874 "%sampler_0 = OpLoad %sampler_type %sampler_0_ptr\n"
875 "%image_0_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_0\n"
876 "%image_0 = OpLoad %image_type %image_0_ptr\n"
877 "%sampled_image_0 = OpSampledImage %sampled_image_type %image_0 %sampler_0\n"
878 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
879 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend %float_0\n"
880 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
881 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n"
882 ;
883 }
884 else if (m_params.dataType == DataType::PTR_SAMPLED_IMAGE)
885 {
886 // We attempt to create the second pointer before the call.
887 subs["CALC_ZERO_FOR_CALLABLE"] += "%sampled_image_0_ptr = OpAccessChain %sampled_image_type_uniform_ptr %combinedImageSampler %uint_0\n"
888 "%sampled_image_1_ptr = OpAccessChain %sampled_image_type_uniform_ptr %combinedImageSampler %uint_1\n"
889 "%sampled_image_0 = OpLoad %sampled_image_type %sampled_image_0_ptr\n"
890 "%texture_coords_0 = OpCompositeConstruct %v2float %input_val_before %input_val_before\n"
891 "%pixel_vec_0 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_0 Lod|ZeroExtend %float_0\n"
892 "%pixel_0 = OpCompositeExtract %uint %pixel_vec_0 0\n"
893 "%zero_for_callable = OpISub %uint %pixel_0 %uint_37\n"
894 ;
895 }
896 else
897 {
898 DE_ASSERT(false);
899 }
900 }
901 else if (storageImageNeeded(m_params.dataType))
902 {
903 subs["INPUT_BUFFER_VALUE_TYPE"] = "%int";
904 subs["MAIN_INTERFACE_EXTRAS"] += " %storageImage";
905 subs["EXTRA_BINDINGS"] += " OpDecorate %storageImage DescriptorSet 0\n"
906 " OpDecorate %storageImage Binding 4\n"
907 ;
908 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %uint_37 = OpConstant %uint 37\n"
909 " %v2int = OpTypeVector %int 2\n"
910 " %image_type = OpTypeImage %uint 2D 0 0 0 2 R32ui\n"
911 " %image_type_uniform_ptr = OpTypePointer UniformConstant %image_type\n"
912 " %uint_img_ptr = OpTypePointer Image %uint\n"
913 ;
914 subs["EXTRA_BINDING_VARIABLES"] += " %storageImage = OpVariable %image_type_uniform_ptr UniformConstant\n"
915 ;
916
		// Load the value from the image, expecting it to be 37, and swap it with 5.
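		// In the OpAtomicCompareExchange below, %uint_1 is the Device memory scope, the two
		// %uint_0 operands are the (relaxed) memory semantics for the equal and unequal cases,
		// %uint_5 is the value to store and %uint_37 is the comparator.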
918 subs["CALC_ZERO_FOR_CALLABLE"] += "%coords = OpCompositeConstruct %v2int %input_val_before %input_val_before\n"
919 "%texel_ptr = OpImageTexelPointer %uint_img_ptr %storageImage %coords %uint_0\n"
920 "%texel_value = OpAtomicCompareExchange %uint %texel_ptr %uint_1 %uint_0 %uint_0 %uint_5 %uint_37\n"
921 "%zero_for_callable = OpISub %uint %texel_value %uint_37\n"
922 ;
923 }
924 else if (m_params.dataType == DataType::OP_NULL)
925 {
926 subs["INPUT_BUFFER_VALUE_TYPE"] = "%uint";
927 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %uint_37 = OpConstant %uint 37\n"
928 " %constant_null = OpConstantNull %uint\n"
929 ;
930
931 // Create a local copy of the null constant global object to work with it.
932 subs["CALC_ZERO_FOR_CALLABLE"] += "%constant_null_copy = OpCopyObject %uint %constant_null\n"
933 "%is_37_before = OpIEqual %bool %input_val_before %uint_37\n"
934 "%zero_for_callable = OpSelect %uint %is_37_before %constant_null_copy %uint_5\n"
935 ;
936 }
937 else if (m_params.dataType == DataType::OP_UNDEF)
938 {
939 subs["INPUT_BUFFER_VALUE_TYPE"] = "%uint";
940 subs["EXTRA_TYPES_AND_CONSTANTS"] += " %uint_37 = OpConstant %uint 37\n"
941 ;
942
943 // Extract an undef value and write it to the output buffer to make sure it's used before the call. The value will be overwritten later.
944 subs["CALC_ZERO_FOR_CALLABLE"] += "%undef_var = OpUndef %uint\n"
945 "%undef_val_before = OpCopyObject %uint %undef_var\n"
946 "OpStore %output_val_ptr %undef_val_before Volatile\n"
947 "%zero_for_callable = OpISub %uint %uint_37 %input_val_before\n"
948 ;
949 }
950 else
951 {
952 DE_ASSERT(false);
953 }
954
955 // Comparison statement for data before and after the call.
956 switch (m_params.dataType)
957 {
958 case DataType::INT32:
959 case DataType::UINT32:
960 case DataType::INT64:
961 case DataType::UINT64:
962 case DataType::INT16:
963 case DataType::UINT16:
964 case DataType::INT8:
965 case DataType::UINT8:
966 opEqual = "OpIEqual";
967 break;
968 case DataType::FLOAT32:
969 case DataType::FLOAT64:
970 case DataType::FLOAT16:
971 opEqual = "OpFOrdEqual";
972 break;
973 case DataType::STRUCT:
974 case DataType::IMAGE:
975 case DataType::SAMPLER:
976 case DataType::SAMPLED_IMAGE:
977 case DataType::PTR_IMAGE:
978 case DataType::PTR_SAMPLER:
979 case DataType::PTR_SAMPLED_IMAGE:
980 case DataType::PTR_TEXEL:
981 case DataType::OP_NULL:
982 case DataType::OP_UNDEF:
		// These need special code for the comparison.
984 opEqual = "INVALID";
985 break;
986 default:
987 DE_ASSERT(false);
988 break;
989 }
990
991 if (m_params.dataType == DataType::STRUCT)
992 {
993 // We need to store the before and after values in a variable in order to be able to access each member individually without accessing the StorageBuffer again.
994 subs["EXTRA_FUNCTION_VARIABLES"] = " %input_val_func_before = OpVariable %input_struct_func_ptr_type Function\n"
995 " %input_val_func_after = OpVariable %input_struct_func_ptr_type Function\n"
996 ;
997 subs["CALC_EQUAL_STATEMENT"] = " OpStore %input_val_func_before %input_val_before\n"
998 " OpStore %input_val_func_after %input_val_after\n"
999 " %uint_part_func_before_ptr = OpAccessChain %uint_part_func_ptr_type %input_val_func_before %uint_0\n"
1000 " %float_part_func_before_ptr = OpAccessChain %float_part_func_ptr_type %input_val_func_before %uint_1\n"
1001 " %uint_part_func_after_ptr = OpAccessChain %uint_part_func_ptr_type %input_val_func_after %uint_0\n"
1002 " %float_part_func_after_ptr = OpAccessChain %float_part_func_ptr_type %input_val_func_after %uint_1\n"
1003 " %uint_part_before = OpLoad %uint %uint_part_func_before_ptr\n"
1004 " %float_part_before = OpLoad %float %float_part_func_before_ptr\n"
1005 " %uint_part_after = OpLoad %uint %uint_part_func_after_ptr\n"
1006 " %float_part_after = OpLoad %float %float_part_func_after_ptr\n"
1007 " %uint_equal = OpIEqual %bool %uint_part_before %uint_part_after\n"
1008 " %float_equal = OpFOrdEqual %bool %float_part_before %float_part_after\n"
1009 " %equal = OpLogicalAnd %bool %uint_equal %float_equal\n"
1010 ;
1011 }
1012 else if (m_params.dataType == DataType::IMAGE)
1013 {
1014 // Use the same image and the second sampler with different coordinates (actually the same).
1015 subs["CALC_EQUAL_STATEMENT"] += "%sampler_1_ptr = OpAccessChain %sampler_type_uniform_ptr %textureSampler %uint_1\n"
1016 "%sampler_1 = OpLoad %sampler_type %sampler_1_ptr\n"
1017 "%sampled_image_1 = OpSampledImage %sampled_image_type %image_0 %sampler_1\n"
1018 "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1019 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend %float_0\n"
1020 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1021 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n"
1022 ;
1023 }
1024 else if (m_params.dataType == DataType::SAMPLER)
1025 {
1026 // Use the same sampler and sample from the second image with different coordinates (but actually the same).
1027 subs["CALC_EQUAL_STATEMENT"] += "%image_1_ptr = OpAccessChain %image_type_uniform_ptr %sampledTexture %uint_1\n"
1028 "%image_1 = OpLoad %image_type %image_1_ptr\n"
1029 "%sampled_image_1 = OpSampledImage %sampled_image_type %image_1 %sampler_0\n"
1030 "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1031 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend %float_0\n"
1032 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1033 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n"
1034 ;
1035 }
1036 else if (m_params.dataType == DataType::SAMPLED_IMAGE)
1037 {
1038 // Reuse the same combined image sampler with different coordinates (actually the same).
1039 subs["CALC_EQUAL_STATEMENT"] += "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1040 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_0 %texture_coords_1 Lod|ZeroExtend %float_0\n"
1041 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1042 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n"
1043 ;
1044 }
1045 else if (m_params.dataType == DataType::PTR_IMAGE)
1046 {
1047 // We attempt to use the second pointer only after the call.
1048 subs["CALC_EQUAL_STATEMENT"] += "%image_1 = OpLoad %image_type %image_1_ptr\n"
1049 "%sampled_image_1 = OpSampledImage %sampled_image_type %image_1 %sampler_0\n"
1050 "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1051 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend %float_0\n"
1052 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1053 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n"
1054 ;
1055
1056 }
1057 else if (m_params.dataType == DataType::PTR_SAMPLER)
1058 {
1059 // We attempt to use the second pointer only after the call.
1060 subs["CALC_EQUAL_STATEMENT"] += "%sampler_1 = OpLoad %sampler_type %sampler_1_ptr\n"
1061 "%sampled_image_1 = OpSampledImage %sampled_image_type %image_0 %sampler_1\n"
1062 "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1063 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend %float_0\n"
1064 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1065 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n"
1066 ;
1067 }
1068 else if (m_params.dataType == DataType::PTR_SAMPLED_IMAGE)
1069 {
1070 // We attempt to use the second pointer only after the call.
1071 subs["CALC_EQUAL_STATEMENT"] += "%sampled_image_1 = OpLoad %sampled_image_type %sampled_image_1_ptr\n"
1072 "%texture_coords_1 = OpCompositeConstruct %v2float %input_val_after %input_val_after\n"
1073 "%pixel_vec_1 = OpImageSampleExplicitLod %v4uint %sampled_image_1 %texture_coords_1 Lod|ZeroExtend %float_0\n"
1074 "%pixel_1 = OpCompositeExtract %uint %pixel_vec_1 0\n"
1075 "%equal = OpIEqual %bool %pixel_0 %pixel_1\n"
1076 ;
1077 }
1078 else if (m_params.dataType == DataType::PTR_TEXEL)
1079 {
1080 // Check value 5 was stored properly.
1081 subs["CALC_EQUAL_STATEMENT"] += "%stored_val = OpAtomicLoad %uint %texel_ptr %uint_1 %uint_0\n"
1082 "%equal = OpIEqual %bool %stored_val %uint_5\n"
1083 ;
1084 }
1085 else if (m_params.dataType == DataType::OP_NULL)
1086 {
1087 // Reuse the null constant after the call.
1088 subs["CALC_EQUAL_STATEMENT"] += "%is_37_after = OpIEqual %bool %input_val_after %uint_37\n"
1089 "%writeback_val = OpSelect %uint %is_37_after %constant_null_copy %uint_5\n"
1090 "OpStore %input_val_ptr %writeback_val Volatile\n"
1091 "%readback_val = OpLoad %uint %input_val_ptr Volatile\n"
1092 "%equal = OpIEqual %bool %readback_val %uint_0\n"
1093 ;
1094 }
1095 else if (m_params.dataType == DataType::OP_UNDEF)
1096 {
1097 // Extract another undef value and write it to the input buffer. It will not be checked later.
1098 subs["CALC_EQUAL_STATEMENT"] += "%undef_val_after = OpCopyObject %uint %undef_var\n"
1099 "OpStore %input_val_ptr %undef_val_after Volatile\n"
1100 "%equal = OpIEqual %bool %input_val_after %input_val_before\n"
1101 ;
1102 }
1103 else
1104 {
1105 subs["CALC_EQUAL_STATEMENT"] += " %equal = " + opEqual + " %bool %input_val_before %input_val_after\n";
1106 }
1107
1108 // Modifications for vectors and arrays.
1109 if (numComponents > 1)
1110 {
1111 const std::string vectorTypeName = "v" + numComponentsStr + componentTypeName;
1112 const std::string opType = (isArray ? "OpTypeArray" : "OpTypeVector");
1113 const std::string componentCountStr = (isArray ? ("%uint_" + numComponentsStr) : numComponentsStr);
1114
1115 // Some extra types are needed.
1116 if (!(m_params.dataType == DataType::FLOAT32 && m_params.vectorType == VectorType::V3))
1117 {
1118 // Note: v3float is already defined in the shader by default.
1119 subs["EXTRA_TYPES_AND_CONSTANTS"] += "%" + vectorTypeName + " = " + opType + " %" + componentTypeName + " " + componentCountStr + "\n";
1120 }
1121 subs["EXTRA_TYPES_AND_CONSTANTS"] += "%v" + numComponentsStr + "bool = " + opType + " %bool " + componentCountStr + "\n";
1122 subs["EXTRA_TYPES_AND_CONSTANTS"] += "%comp_ptr = OpTypePointer StorageBuffer %" + componentTypeName + "\n";
1123
1124 // The input value in the buffer has a different type.
1125 subs["INPUT_BUFFER_VALUE_TYPE"] = "%" + vectorTypeName;
1126
1127 // Overwrite the way we calculate the zero used in the call.
1128
		// Proper operations for adding, subtracting and converting components.
1130 std::string opAdd;
1131 std::string opSub;
1132 std::string opConvert;
1133
1134 switch (m_params.dataType)
1135 {
1136 case DataType::INT32:
1137 case DataType::UINT32:
1138 case DataType::INT64:
1139 case DataType::UINT64:
1140 case DataType::INT16:
1141 case DataType::UINT16:
1142 case DataType::INT8:
1143 case DataType::UINT8:
1144 opAdd = "OpIAdd";
1145 opSub = "OpISub";
1146 break;
1147 case DataType::FLOAT32:
1148 case DataType::FLOAT64:
1149 case DataType::FLOAT16:
1150 opAdd = "OpFAdd";
1151 opSub = "OpFSub";
1152 break;
1153 default:
1154 DE_ASSERT(false);
1155 break;
1156 }
1157
1158 switch (m_params.dataType)
1159 {
1160 case DataType::UINT32:
1161 opConvert = "OpCopyObject";
1162 break;
1163 case DataType::INT32:
1164 opConvert = "OpBitcast";
1165 break;
1166 case DataType::INT64:
1167 case DataType::INT16:
1168 case DataType::INT8:
1169 opConvert = "OpSConvert";
1170 break;
1171 case DataType::UINT64:
1172 case DataType::UINT16:
1173 case DataType::UINT8:
1174 opConvert = "OpUConvert";
1175 break;
1176 case DataType::FLOAT32:
1177 case DataType::FLOAT64:
1178 case DataType::FLOAT16:
1179 opConvert = "OpConvertFToU";
1180 break;
1181 default:
1182 DE_ASSERT(false);
1183 break;
1184 }
1185
1186 std::ostringstream zeroForCallable;
1187
1188 // Create pointers to components and load components.
1189 for (int i = 0; i < numComponents; ++i)
1190 {
1191 zeroForCallable
1192 << "%component_ptr_" << i << " = OpAccessChain %comp_ptr %input_val_ptr %uint_" << i << "\n"
1193 << "%component_" << i << " = OpLoad %" << componentTypeName << " %component_ptr_" << i << "\n"
1194 ;
1195 }
1196
1197 // Sum components together in %total_sum.
1198 for (int i = 1; i < numComponents; ++i)
1199 {
1200 const std::string previous = ((i == 1) ? "%component_0" : ("%partial_" + de::toString(i-1)));
1201 const std::string resultName = ((i == (numComponents - 1)) ? "%total_sum" : ("%partial_" + de::toString(i)));
1202 zeroForCallable << resultName << " = " << opAdd << " %" << componentTypeName << " %component_" << i << " " << previous << "\n";
1203 }
1204
1205 // Recalculate the zero.
1206 zeroForCallable
1207 << "%zero_" << componentTypeName << " = " << opSub << " %" << componentTypeName << " %total_sum %" << componentTypeName << "_37\n"
1208 << "%zero_for_callable = " << opConvert << " %uint %zero_" << componentTypeName << "\n"
1209 ;
1210
1211 // Finally replace the zero_for_callable statements with the special version for vectors.
1212 subs["CALC_ZERO_FOR_CALLABLE"] = zeroForCallable.str();
1213
1214 // Rework comparison statements.
1215 if (isArray)
1216 {
1217 // Arrays need to be compared per-component.
1218 std::ostringstream calcEqual;
1219
1220 for (int i = 0; i < numComponents; ++i)
1221 {
1222 calcEqual
1223 << "%component_after_" << i << " = OpLoad %" << componentTypeName << " %component_ptr_" << i << "\n"
1224 << "%equal_" << i << " = " << opEqual << " %bool %component_" << i << " %component_after_" << i << "\n";
1225 ;
1226 if (i > 0)
1227 calcEqual << "%and_" << i << " = OpLogicalAnd %bool %equal_" << (i - 1) << " %equal_" << i << "\n";
1228 if (i == numComponents - 1)
1229 calcEqual << "%equal = OpCopyObject %bool %and_" << i << "\n";
1230 }
1231
1232 subs["CALC_EQUAL_STATEMENT"] = calcEqual.str();
1233 }
1234 else
1235 {
1236 // Vectors can be compared using a bool vector and OpAll.
1237 subs["CALC_EQUAL_STATEMENT"] = " %equal_vector = " + opEqual + " %v" + numComponentsStr + "bool %input_val_before %input_val_after\n";
1238 subs["CALC_EQUAL_STATEMENT"] += " %equal = OpAll %bool %equal_vector\n";
1239 }
1240 }
1241
1242 if (isArray)
1243 {
1244 // Arrays need an ArrayStride decoration.
1245 std::ostringstream interfaceDecorations;
1246 interfaceDecorations << "OpDecorate %v" << numComponentsStr << componentTypeName << " ArrayStride " << getElementSize(m_params.dataType, VectorType::SCALAR) << "\n";
1247 subs["INTERFACE_DECORATIONS"] = interfaceDecorations.str();
1248 }
1249
1250 const auto inputBlockDecls = getGLSLInputValDecl(m_params.dataType, m_params.vectorType);
1251
1252 std::ostringstream glslBindings;
1253 glslBindings
1254 << inputBlockDecls.first // Additional data types needed.
1255 << "layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
1256 << "layout(set = 0, binding = 1) buffer CalleeBlock { uint val; } calleeBuffer;\n"
1257 << "layout(set = 0, binding = 2) buffer OutputBlock { uint val; } outputBuffer;\n"
1258 << "layout(set = 0, binding = 3) buffer InputBlock { " << inputBlockDecls.second << " } inputBuffer;\n"
1259 ;
1260
1261 if (samplersNeeded(m_params.dataType))
1262 {
1263 glslBindings
1264 << "layout(set = 0, binding = 4) uniform utexture2D sampledTexture[2];\n"
1265 << "layout(set = 0, binding = 5) uniform sampler textureSampler[2];\n"
1266 << "layout(set = 0, binding = 6) uniform usampler2D combinedImageSampler[2];\n"
1267 ;
1268 }
1269 else if (storageImageNeeded(m_params.dataType))
1270 {
1271 glslBindings
1272 << "layout(set = 0, binding = 4, r32ui) uniform uimage2D storageImage;\n"
1273 ;
1274 }
1275
1276 const auto glslBindingsStr = glslBindings.str();
1277 const auto glslHeaderStr = "#version 460 core\n"
1278 "#extension GL_EXT_ray_tracing : require\n"
1279 "#extension GL_EXT_shader_explicit_arithmetic_types : require\n";
1280
1281
1282 if (m_params.callType == CallType::TRACE_RAY)
1283 {
1284 subs["ENTRY_POINT"] = "RayGenerationKHR";
1285 subs["MAIN_INTERFACE_EXTRAS"] += " %hitValue";
1286 subs["INTERFACE_DECORATIONS"] += " OpDecorate %hitValue Location 0\n";
1287 subs["INTERFACE_TYPES_AND_VARIABLES"] = " %payload_ptr = OpTypePointer RayPayloadKHR %v3float\n"
1288 " %hitValue = OpVariable %payload_ptr RayPayloadKHR\n";
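		// Note %zero_for_callable feeds the SBT offset, SBT stride and miss index operands of OpTraceRayKHR below.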
1289 subs["CALL_STATEMENTS"] = " %as_value = OpLoad %as_type %topLevelAS\n"
1290 " OpTraceRayKHR %as_value %uint_0 %uint_255 %zero_for_callable %zero_for_callable %zero_for_callable %origin_const %float_0 %direction_const %float_9 %hitValue\n";
1291
1292 const auto rgen = spvTemplate.specialize(subs);
1293 programCollection.spirvAsmSources.add("rgen") << rgen << spvBuildOptions;
1294
1295 std::stringstream chit;
1296 chit
1297 << glslHeaderStr
1298 << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1299 << "hitAttributeEXT vec3 attribs;\n"
1300 << glslBindingsStr
1301 << "void main()\n"
1302 << "{\n"
1303 << " calleeBuffer.val = 1u;\n"
1304 			<< "}\n";
1306 programCollection.glslSources.add("chit") << glu::ClosestHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
1307 }
1308 else if (m_params.callType == CallType::EXECUTE_CALLABLE)
1309 {
1310 subs["ENTRY_POINT"] = "RayGenerationKHR";
1311 subs["MAIN_INTERFACE_EXTRAS"] += " %callableData";
1312 subs["INTERFACE_DECORATIONS"] += " OpDecorate %callableData Location 0\n";
1313 subs["INTERFACE_TYPES_AND_VARIABLES"] = " %callable_data_ptr = OpTypePointer CallableDataKHR %float\n"
1314 " %callableData = OpVariable %callable_data_ptr CallableDataKHR\n";
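		// Note %zero_for_callable is used as the SBT index operand of OpExecuteCallableKHR below.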
1315 subs["CALL_STATEMENTS"] = " OpExecuteCallableKHR %zero_for_callable %callableData\n";
1316
1317 const auto rgen = spvTemplate.specialize(subs);
1318 programCollection.spirvAsmSources.add("rgen") << rgen << spvBuildOptions;
1319
1320 std::ostringstream call;
1321 call
1322 << glslHeaderStr
1323 << "layout(location = 0) callableDataInEXT float callableData;\n"
1324 << glslBindingsStr
1325 << "void main()\n"
1326 << "{\n"
1327 << " calleeBuffer.val = 1u;\n"
1328 << "}\n"
1329 ;
1330
1331 programCollection.glslSources.add("call") << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
1332 }
1333 else if (m_params.callType == CallType::REPORT_INTERSECTION)
1334 {
1335 subs["ENTRY_POINT"] = "IntersectionKHR";
1336 subs["MAIN_INTERFACE_EXTRAS"] += " %attribs";
1337 subs["INTERFACE_DECORATIONS"] += "";
1338 subs["INTERFACE_TYPES_AND_VARIABLES"] = " %hit_attribute_ptr = OpTypePointer HitAttributeKHR %v3float\n"
1339 " %attribs = OpVariable %hit_attribute_ptr HitAttributeKHR\n";
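		// Note %zero_for_callable is used as the hit kind operand of OpReportIntersectionKHR below.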
1340 subs["CALL_STATEMENTS"] = " %intersection_ret = OpReportIntersectionKHR %bool %float_1 %zero_for_callable\n";
1341
1342 const auto rint = spvTemplate.specialize(subs);
1343 programCollection.spirvAsmSources.add("rint") << rint << spvBuildOptions;
1344
1345 std::ostringstream rgen;
1346 rgen
1347 << glslHeaderStr
1348 << "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
1349 << glslBindingsStr
1350 << "void main()\n"
1351 << "{\n"
1352 << " traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, 0);\n"
1353 << "}\n"
1354 ;
1355 programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
1356
1357 std::stringstream ahit;
1358 ahit
1359 << glslHeaderStr
1360 << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
1361 << "hitAttributeEXT vec3 attribs;\n"
1362 << glslBindingsStr
1363 << "void main()\n"
1364 << "{\n"
1365 << " calleeBuffer.val = 1u;\n"
1366 			<< "}\n";
1368 programCollection.glslSources.add("ahit") << glu::AnyHitSource(updateRayTracingGLSL(ahit.str())) << buildOptions;
1369 }
1370 else
1371 {
1372 DE_ASSERT(false);
1373 }
1374 }
1375
1376 using v2i32 = tcu::Vector<deInt32, 2>;
1377 using v3i32 = tcu::Vector<deInt32, 3>;
1378 using v4i32 = tcu::Vector<deInt32, 4>;
1379 using a5i32 = std::array<deInt32, 5>;
1380
1381 using v2u32 = tcu::Vector<deUint32, 2>;
1382 using v3u32 = tcu::Vector<deUint32, 3>;
1383 using v4u32 = tcu::Vector<deUint32, 4>;
1384 using a5u32 = std::array<deUint32, 5>;
1385
1386 using v2i64 = tcu::Vector<deInt64, 2>;
1387 using v3i64 = tcu::Vector<deInt64, 3>;
1388 using v4i64 = tcu::Vector<deInt64, 4>;
1389 using a5i64 = std::array<deInt64, 5>;
1390
1391 using v2u64 = tcu::Vector<deUint64, 2>;
1392 using v3u64 = tcu::Vector<deUint64, 3>;
1393 using v4u64 = tcu::Vector<deUint64, 4>;
1394 using a5u64 = std::array<deUint64, 5>;
1395
1396 using v2i16 = tcu::Vector<deInt16, 2>;
1397 using v3i16 = tcu::Vector<deInt16, 3>;
1398 using v4i16 = tcu::Vector<deInt16, 4>;
1399 using a5i16 = std::array<deInt16, 5>;
1400
1401 using v2u16 = tcu::Vector<deUint16, 2>;
1402 using v3u16 = tcu::Vector<deUint16, 3>;
1403 using v4u16 = tcu::Vector<deUint16, 4>;
1404 using a5u16 = std::array<deUint16, 5>;
1405
1406 using v2i8 = tcu::Vector<deInt8, 2>;
1407 using v3i8 = tcu::Vector<deInt8, 3>;
1408 using v4i8 = tcu::Vector<deInt8, 4>;
1409 using a5i8 = std::array<deInt8, 5>;
1410
1411 using v2u8 = tcu::Vector<deUint8, 2>;
1412 using v3u8 = tcu::Vector<deUint8, 3>;
1413 using v4u8 = tcu::Vector<deUint8, 4>;
1414 using a5u8 = std::array<deUint8, 5>;
1415
1416 using v2f32 = tcu::Vector<tcu::Float32, 2>;
1417 using v3f32 = tcu::Vector<tcu::Float32, 3>;
1418 using v4f32 = tcu::Vector<tcu::Float32, 4>;
1419 using a5f32 = std::array<tcu::Float32, 5>;
1420
1421 using v2f64 = tcu::Vector<tcu::Float64, 2>;
1422 using v3f64 = tcu::Vector<tcu::Float64, 3>;
1423 using v4f64 = tcu::Vector<tcu::Float64, 4>;
1424 using a5f64 = std::array<tcu::Float64, 5>;
1425
1426 using v2f16 = tcu::Vector<tcu::Float16, 2>;
1427 using v3f16 = tcu::Vector<tcu::Float16, 3>;
1428 using v4f16 = tcu::Vector<tcu::Float16, 4>;
1429 using a5f16 = std::array<tcu::Float16, 5>;
1430
1431 // Scalar types get filled with value 37, matching the value that will be subtracted in the shader.
1432 #define GEN_SCALAR_FILL(DATA_TYPE) \
1433 do { \
1434 const auto inputBufferValue = static_cast<DATA_TYPE>(37.0); \
1435 deMemcpy(bufferPtr, &inputBufferValue, sizeof(inputBufferValue)); \
1436 } while (0)
1437
1438 // Vector types get filled with values that add up to 37, matching the value that will be subtracted in the shader.
1439 #define GEN_V2_FILL(DATA_TYPE) \
1440 do { \
1441 DATA_TYPE inputBufferValue; \
1442 inputBufferValue.x() = static_cast<DATA_TYPE::Element>(21.0); \
1443 inputBufferValue.y() = static_cast<DATA_TYPE::Element>(16.0); \
1444 deMemcpy(bufferPtr, &inputBufferValue, sizeof(inputBufferValue)); \
1445 } while (0)
1446
1447 #define GEN_V3_FILL(DATA_TYPE) \
1448 do { \
1449 DATA_TYPE inputBufferValue; \
1450 inputBufferValue.x() = static_cast<DATA_TYPE::Element>(11.0); \
1451 inputBufferValue.y() = static_cast<DATA_TYPE::Element>(19.0); \
1452 inputBufferValue.z() = static_cast<DATA_TYPE::Element>(7.0); \
1453 deMemcpy(bufferPtr, &inputBufferValue, sizeof(inputBufferValue)); \
1454 } while (0)
1455
1456 #define GEN_V4_FILL(DATA_TYPE) \
1457 do { \
1458 DATA_TYPE inputBufferValue; \
1459 inputBufferValue.x() = static_cast<DATA_TYPE::Element>(9.0); \
1460 inputBufferValue.y() = static_cast<DATA_TYPE::Element>(11.0); \
1461 inputBufferValue.z() = static_cast<DATA_TYPE::Element>(3.0); \
1462 inputBufferValue.w() = static_cast<DATA_TYPE::Element>(14.0); \
1463 deMemcpy(bufferPtr, &inputBufferValue, sizeof(inputBufferValue)); \
1464 } while (0)
1465
1466 #define GEN_A5_FILL(DATA_TYPE) \
1467 do { \
1468 DATA_TYPE inputBufferValue; \
1469 inputBufferValue[0] = static_cast<DATA_TYPE::value_type>(13.0); \
1470 inputBufferValue[1] = static_cast<DATA_TYPE::value_type>(6.0); \
1471 inputBufferValue[2] = static_cast<DATA_TYPE::value_type>(2.0); \
1472 inputBufferValue[3] = static_cast<DATA_TYPE::value_type>(5.0); \
1473 inputBufferValue[4] = static_cast<DATA_TYPE::value_type>(11.0); \
1474 deMemcpy(bufferPtr, inputBufferValue.data(), de::dataSize(inputBufferValue)); \
1475 } while (0)
1476
1477 void fillInputBuffer (DataType dataType, VectorType vectorType, void* bufferPtr)
1478 {
1479 if (vectorType == VectorType::SCALAR)
1480 {
1481 if (dataType == DataType::INT32) GEN_SCALAR_FILL(deInt32);
1482 else if (dataType == DataType::UINT32) GEN_SCALAR_FILL(deUint32);
1483 else if (dataType == DataType::INT64) GEN_SCALAR_FILL(deInt64);
1484 else if (dataType == DataType::UINT64) GEN_SCALAR_FILL(deUint64);
1485 else if (dataType == DataType::INT16) GEN_SCALAR_FILL(deInt16);
1486 else if (dataType == DataType::UINT16) GEN_SCALAR_FILL(deUint16);
1487 else if (dataType == DataType::INT8) GEN_SCALAR_FILL(deInt8);
1488 else if (dataType == DataType::UINT8) GEN_SCALAR_FILL(deUint8);
1489 else if (dataType == DataType::FLOAT32) GEN_SCALAR_FILL(tcu::Float32);
1490 else if (dataType == DataType::FLOAT64) GEN_SCALAR_FILL(tcu::Float64);
1491 else if (dataType == DataType::FLOAT16) GEN_SCALAR_FILL(tcu::Float16);
1492 else if (dataType == DataType::STRUCT)
1493 {
1494 InputStruct data = { 12u, 25.0f };
1495 deMemcpy(bufferPtr, &data, sizeof(data));
1496 }
1497 else if (dataType == DataType::OP_NULL) GEN_SCALAR_FILL(deUint32);
1498 else if (dataType == DataType::OP_UNDEF) GEN_SCALAR_FILL(deUint32);
1499 else
1500 {
1501 DE_ASSERT(false);
1502 }
1503 }
1504 else if (vectorType == VectorType::V2)
1505 {
1506 if (dataType == DataType::INT32) GEN_V2_FILL(v2i32);
1507 else if (dataType == DataType::UINT32) GEN_V2_FILL(v2u32);
1508 else if (dataType == DataType::INT64) GEN_V2_FILL(v2i64);
1509 else if (dataType == DataType::UINT64) GEN_V2_FILL(v2u64);
1510 else if (dataType == DataType::INT16) GEN_V2_FILL(v2i16);
1511 else if (dataType == DataType::UINT16) GEN_V2_FILL(v2u16);
1512 else if (dataType == DataType::INT8) GEN_V2_FILL(v2i8);
1513 else if (dataType == DataType::UINT8) GEN_V2_FILL(v2u8);
1514 else if (dataType == DataType::FLOAT32) GEN_V2_FILL(v2f32);
1515 else if (dataType == DataType::FLOAT64) GEN_V2_FILL(v2f64);
1516 else if (dataType == DataType::FLOAT16) GEN_V2_FILL(v2f16);
1517 else
1518 {
1519 DE_ASSERT(false);
1520 }
1521 }
1522 else if (vectorType == VectorType::V3)
1523 {
1524 if (dataType == DataType::INT32) GEN_V3_FILL(v3i32);
1525 else if (dataType == DataType::UINT32) GEN_V3_FILL(v3u32);
1526 else if (dataType == DataType::INT64) GEN_V3_FILL(v3i64);
1527 else if (dataType == DataType::UINT64) GEN_V3_FILL(v3u64);
1528 else if (dataType == DataType::INT16) GEN_V3_FILL(v3i16);
1529 else if (dataType == DataType::UINT16) GEN_V3_FILL(v3u16);
1530 else if (dataType == DataType::INT8) GEN_V3_FILL(v3i8);
1531 else if (dataType == DataType::UINT8) GEN_V3_FILL(v3u8);
1532 else if (dataType == DataType::FLOAT32) GEN_V3_FILL(v3f32);
1533 else if (dataType == DataType::FLOAT64) GEN_V3_FILL(v3f64);
1534 else if (dataType == DataType::FLOAT16) GEN_V3_FILL(v3f16);
1535 else
1536 {
1537 DE_ASSERT(false);
1538 }
1539 }
1540 else if (vectorType == VectorType::V4)
1541 {
1542 if (dataType == DataType::INT32) GEN_V4_FILL(v4i32);
1543 else if (dataType == DataType::UINT32) GEN_V4_FILL(v4u32);
1544 else if (dataType == DataType::INT64) GEN_V4_FILL(v4i64);
1545 else if (dataType == DataType::UINT64) GEN_V4_FILL(v4u64);
1546 else if (dataType == DataType::INT16) GEN_V4_FILL(v4i16);
1547 else if (dataType == DataType::UINT16) GEN_V4_FILL(v4u16);
1548 else if (dataType == DataType::INT8) GEN_V4_FILL(v4i8);
1549 else if (dataType == DataType::UINT8) GEN_V4_FILL(v4u8);
1550 else if (dataType == DataType::FLOAT32) GEN_V4_FILL(v4f32);
1551 else if (dataType == DataType::FLOAT64) GEN_V4_FILL(v4f64);
1552 else if (dataType == DataType::FLOAT16) GEN_V4_FILL(v4f16);
1553 else
1554 {
1555 DE_ASSERT(false);
1556 }
1557 }
1558 else if (vectorType == VectorType::A5)
1559 {
1560 if (dataType == DataType::INT32) GEN_A5_FILL(a5i32);
1561 else if (dataType == DataType::UINT32) GEN_A5_FILL(a5u32);
1562 else if (dataType == DataType::INT64) GEN_A5_FILL(a5i64);
1563 else if (dataType == DataType::UINT64) GEN_A5_FILL(a5u64);
1564 else if (dataType == DataType::INT16) GEN_A5_FILL(a5i16);
1565 else if (dataType == DataType::UINT16) GEN_A5_FILL(a5u16);
1566 else if (dataType == DataType::INT8) GEN_A5_FILL(a5i8);
1567 else if (dataType == DataType::UINT8) GEN_A5_FILL(a5u8);
1568 else if (dataType == DataType::FLOAT32) GEN_A5_FILL(a5f32);
1569 else if (dataType == DataType::FLOAT64) GEN_A5_FILL(a5f64);
1570 else if (dataType == DataType::FLOAT16) GEN_A5_FILL(a5f16);
1571 else
1572 {
1573 DE_ASSERT(false);
1574 }
1575 }
1576 else
1577 {
1578 DE_ASSERT(false);
1579 }
1580 }
1581
1582 tcu::TestStatus DataSpillTestInstance::iterate (void)
1583 {
1584 const auto& vki = m_context.getInstanceInterface();
1585 const auto physicalDevice = m_context.getPhysicalDevice();
1586 const auto& vkd = m_context.getDeviceInterface();
1587 const auto device = m_context.getDevice();
1588 const auto queue = m_context.getUniversalQueue();
1589 const auto familyIndex = m_context.getUniversalQueueFamilyIndex();
1590 auto& alloc = m_context.getDefaultAllocator();
1591 const auto shaderStages = getShaderStages(m_params.callType);
1592
1593 // Command buffer.
1594 const auto cmdPool = makeCommandPool(vkd, device, familyIndex);
1595 const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1596 const auto cmdBuffer = cmdBufferPtr.get();
1597
1598 beginCommandBuffer(vkd, cmdBuffer);
1599
1600 // Callee, input and output buffers.
1601 const auto calleeBufferSize = getElementSize(DataType::UINT32, VectorType::SCALAR);
1602 const auto outputBufferSize = getElementSize(DataType::UINT32, VectorType::SCALAR);
1603 const auto inputBufferSize = getElementSize(m_params.dataType, m_params.vectorType);
1604
1605 const auto calleeBufferInfo = makeBufferCreateInfo(calleeBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1606 const auto outputBufferInfo = makeBufferCreateInfo(outputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1607 const auto inputBufferInfo = makeBufferCreateInfo(inputBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
1608
1609 BufferWithMemory calleeBuffer (vkd, device, alloc, calleeBufferInfo, MemoryRequirement::HostVisible);
1610 BufferWithMemory outputBuffer (vkd, device, alloc, outputBufferInfo, MemoryRequirement::HostVisible);
1611 BufferWithMemory inputBuffer (vkd, device, alloc, inputBufferInfo, MemoryRequirement::HostVisible);
1612
1613 // Fill buffers with values.
1614 auto& calleeBufferAlloc = calleeBuffer.getAllocation();
1615 auto* calleeBufferPtr = calleeBufferAlloc.getHostPtr();
1616 auto& outputBufferAlloc = outputBuffer.getAllocation();
1617 auto* outputBufferPtr = outputBufferAlloc.getHostPtr();
1618 auto& inputBufferAlloc = inputBuffer.getAllocation();
1619 auto* inputBufferPtr = inputBufferAlloc.getHostPtr();
1620
1621 deMemset(calleeBufferPtr, 0, static_cast<size_t>(calleeBufferSize));
1622 deMemset(outputBufferPtr, 0, static_cast<size_t>(outputBufferSize));
1623
1624 if (samplersNeeded(m_params.dataType) || storageImageNeeded(m_params.dataType))
1625 {
1626 // The input buffer for these cases will be filled with zeros (sampling coordinates), and the input textures will contain the interesting input value.
1627 deMemset(inputBufferPtr, 0, static_cast<size_t>(inputBufferSize));
1628 }
1629 else
1630 {
1631 // We want to fill the input buffer with values that will be consistently used in the shader to obtain a result of zero.
1632 fillInputBuffer(m_params.dataType, m_params.vectorType, inputBufferPtr);
1633 }
1634
1635 flushAlloc(vkd, device, calleeBufferAlloc);
1636 flushAlloc(vkd, device, outputBufferAlloc);
1637 flushAlloc(vkd, device, inputBufferAlloc);
1638
1639 // Acceleration structures.
1640 de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure;
1641 de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure;
1642
1643 bottomLevelAccelerationStructure = makeBottomLevelAccelerationStructure();
1644 bottomLevelAccelerationStructure->setDefaultGeometryData(getShaderStageForGeometry(m_params.callType), VK_GEOMETRY_NO_DUPLICATE_ANY_HIT_INVOCATION_BIT_KHR);
1645 bottomLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
1646
1647 topLevelAccelerationStructure = makeTopLevelAccelerationStructure();
1648 topLevelAccelerationStructure->setInstanceCount(1);
1649 topLevelAccelerationStructure->addInstance(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
1650 topLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
1651
1652 // Get some ray tracing properties.
1653 deUint32 shaderGroupHandleSize = 0u;
1654 deUint32 shaderGroupBaseAlignment = 1u;
1655 {
1656 const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
1657 shaderGroupHandleSize = rayTracingPropertiesKHR->getShaderGroupHandleSize();
1658 shaderGroupBaseAlignment = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
1659 }
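	// The handle size doubles as the stride and size of each single-entry SBT region below, and the base alignment is passed when creating each table.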
1660
1661 // Textures and samplers if needed.
1662 de::MovePtr<BufferWithMemory> textureData;
1663 std::vector<de::MovePtr<ImageWithMemory>> textures;
1664 std::vector<Move<VkImageView>> textureViews;
1665 std::vector<Move<VkSampler>> samplers;
1666
1667 if (samplersNeeded(m_params.dataType) || storageImageNeeded(m_params.dataType))
1668 {
1669 // Create texture data with the expected contents.
1670 {
1671 const auto textureDataSize = static_cast<VkDeviceSize>(sizeof(deUint32));
1672 const auto textureDataCreateInfo = makeBufferCreateInfo(textureDataSize, VK_BUFFER_USAGE_TRANSFER_SRC_BIT);
1673
1674 textureData = de::MovePtr<BufferWithMemory>(new BufferWithMemory(vkd, device, alloc, textureDataCreateInfo, MemoryRequirement::HostVisible));
1675 auto& textureDataAlloc = textureData->getAllocation();
1676 auto* textureDataPtr = textureDataAlloc.getHostPtr();
1677
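			// Reuse the scalar uint fill, which writes a single 37, so the texel holds the value the shaders will subtract.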
1678 fillInputBuffer(DataType::UINT32, VectorType::SCALAR, textureDataPtr);
1679 flushAlloc(vkd, device, textureDataAlloc);
1680 }
1681
1682 // Images will be created like this with different usages.
1683 VkImageCreateInfo imageCreateInfo =
1684 {
1685 VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // VkStructureType sType;
1686 nullptr, // const void* pNext;
1687 0u, // VkImageCreateFlags flags;
1688 VK_IMAGE_TYPE_2D, // VkImageType imageType;
1689 kImageFormat, // VkFormat format;
1690 kImageExtent, // VkExtent3D extent;
1691 1u, // deUint32 mipLevels;
1692 1u, // deUint32 arrayLayers;
1693 VK_SAMPLE_COUNT_1_BIT, // VkSampleCountFlagBits samples;
1694 VK_IMAGE_TILING_OPTIMAL, // VkImageTiling tiling;
1695 kSampledImageUsage, // VkImageUsageFlags usage;
1696 VK_SHARING_MODE_EXCLUSIVE, // VkSharingMode sharingMode;
1697 0u, // deUint32 queueFamilyIndexCount;
1698 nullptr, // const deUint32* pQueueFamilyIndices;
1699 VK_IMAGE_LAYOUT_UNDEFINED, // VkImageLayout initialLayout;
1700 };
1701
1702 const auto imageSubresourceRange = makeImageSubresourceRange(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 1u, 0u, 1u);
1703 const auto imageSubresourceLayers = makeImageSubresourceLayers(VK_IMAGE_ASPECT_COLOR_BIT, 0u, 0u, 1u);
1704
1705 if (samplersNeeded(m_params.dataType))
1706 {
1707 // All samplers will be created like this.
1708 const VkSamplerCreateInfo samplerCreateInfo =
1709 {
1710 VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO, // VkStructureType sType;
1711 nullptr, // const void* pNext;
1712 0u, // VkSamplerCreateFlags flags;
1713 VK_FILTER_NEAREST, // VkFilter magFilter;
1714 VK_FILTER_NEAREST, // VkFilter minFilter;
1715 VK_SAMPLER_MIPMAP_MODE_NEAREST, // VkSamplerMipmapMode mipmapMode;
1716 VK_SAMPLER_ADDRESS_MODE_REPEAT, // VkSamplerAddressMode addressModeU;
1717 VK_SAMPLER_ADDRESS_MODE_REPEAT, // VkSamplerAddressMode addressModeV;
1718 VK_SAMPLER_ADDRESS_MODE_REPEAT, // VkSamplerAddressMode addressModeW;
1719 0.0, // float mipLodBias;
1720 VK_FALSE, // VkBool32 anisotropyEnable;
1721 1.0f, // float maxAnisotropy;
1722 VK_FALSE, // VkBool32 compareEnable;
1723 VK_COMPARE_OP_ALWAYS, // VkCompareOp compareOp;
1724 0.0f, // float minLod;
1725 1.0f, // float maxLod;
1726 VK_BORDER_COLOR_INT_OPAQUE_BLACK, // VkBorderColor borderColor;
1727 VK_FALSE, // VkBool32 unnormalizedCoordinates;
1728 };
1729
1730 // Create textures and samplers.
1731 for (size_t i = 0; i < kNumImages; ++i)
1732 {
1733 textures.emplace_back(new ImageWithMemory(vkd, device, alloc, imageCreateInfo, MemoryRequirement::Any));
1734 textureViews.emplace_back(makeImageView(vkd, device, textures.back()->get(), VK_IMAGE_VIEW_TYPE_2D, kImageFormat, imageSubresourceRange));
1735 }
1736
1737 for (size_t i = 0; i < kNumSamplers; ++i)
1738 samplers.emplace_back(createSampler(vkd, device, &samplerCreateInfo));
1739
1740 // Make sure texture data is available in the transfer stage.
1741 const auto textureDataBarrier = makeMemoryBarrier(VK_ACCESS_HOST_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
1742 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_HOST_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u, 1u, &textureDataBarrier, 0u, nullptr, 0u, nullptr);
1743
1744 const auto bufferImageCopy = makeBufferImageCopy(kImageExtent, imageSubresourceLayers);
1745
1746 // Fill textures with data and prepare them for the ray tracing pipeline stages.
1747 for (size_t i = 0; i < kNumImages; ++i)
1748 {
1749 const auto texturePreCopyBarrier = makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, textures[i]->get(), imageSubresourceRange);
1750 const auto texturePostCopyBarrier = makeImageMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, VK_ACCESS_SHADER_READ_BIT, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL, textures[i]->get(), imageSubresourceRange);
1751
1752 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u, 0u, nullptr, 0u, nullptr, 1u, &texturePreCopyBarrier);
1753 vkd.cmdCopyBufferToImage(cmdBuffer, textureData->get(), textures[i]->get(), VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1u, &bufferImageCopy);
1754 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, 0u, 0u, nullptr, 0u, nullptr, 1u, &texturePostCopyBarrier);
1755 }
1756 }
1757 else if (storageImageNeeded(m_params.dataType))
1758 {
1759 // Image will be used for storage.
1760 imageCreateInfo.usage = kStorageImageUsage;
1761
1762 textures.emplace_back(new ImageWithMemory(vkd, device, alloc, imageCreateInfo, MemoryRequirement::Any));
1763 textureViews.emplace_back(makeImageView(vkd, device, textures.back()->get(), VK_IMAGE_VIEW_TYPE_2D, kImageFormat, imageSubresourceRange));
1764
1765 // Make sure texture data is available in the transfer stage.
1766 const auto textureDataBarrier = makeMemoryBarrier(VK_ACCESS_HOST_WRITE_BIT, VK_ACCESS_TRANSFER_READ_BIT);
1767 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_HOST_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u, 1u, &textureDataBarrier, 0u, nullptr, 0u, nullptr);
1768
1769 const auto bufferImageCopy = makeBufferImageCopy(kImageExtent, imageSubresourceLayers);
1770 const auto texturePreCopyBarrier = makeImageMemoryBarrier(0u, VK_ACCESS_TRANSFER_WRITE_BIT, VK_IMAGE_LAYOUT_UNDEFINED, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, textures.back()->get(), imageSubresourceRange);
1771 const auto texturePostCopyBarrier = makeImageMemoryBarrier(VK_ACCESS_TRANSFER_WRITE_BIT, (VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT), VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, VK_IMAGE_LAYOUT_GENERAL, textures.back()->get(), imageSubresourceRange);
1772
1773 			// Fill the texture with data and prepare it for the ray tracing pipeline stages.
1774 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0u, 0u, nullptr, 0u, nullptr, 1u, &texturePreCopyBarrier);
1775 vkd.cmdCopyBufferToImage(cmdBuffer, textureData->get(), textures.back()->get(), VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1u, &bufferImageCopy);
1776 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_TRANSFER_BIT, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, 0u, 0u, nullptr, 0u, nullptr, 1u, &texturePostCopyBarrier);
1777 }
1778 else
1779 {
1780 DE_ASSERT(false);
1781 }
1782 }
1783
1784 // Descriptor set layout.
1785 DescriptorSetLayoutBuilder dslBuilder;
1786 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, 1u, shaderStages, nullptr);
1787 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr); // Callee buffer.
1788 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr); // Output buffer.
1789 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr); // Input buffer.
1790 if (samplersNeeded(m_params.dataType))
1791 {
1792 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, 2u, shaderStages, nullptr);
1793 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_SAMPLER, 2u, shaderStages, nullptr);
1794 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, 2u, shaderStages, nullptr);
1795 }
1796 else if (storageImageNeeded(m_params.dataType))
1797 {
1798 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, 1u, shaderStages, nullptr);
1799 }
1800 const auto descriptorSetLayout = dslBuilder.build(vkd, device);
1801
1802 // Pipeline layout.
1803 const auto pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());
1804
1805 // Descriptor pool and set.
1806 DescriptorPoolBuilder poolBuilder;
1807 poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
1808 poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 3u);
1809 if (samplersNeeded(m_params.dataType))
1810 {
1811 poolBuilder.addType(VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, 2u);
1812 poolBuilder.addType(VK_DESCRIPTOR_TYPE_SAMPLER, 2u);
1813 poolBuilder.addType(VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, 2u);
1814 }
1815 else if (storageImageNeeded(m_params.dataType))
1816 {
1817 poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, 1u);
1818 }
1819 const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
1820 const auto descriptorSet = makeDescriptorSet(vkd, device, descriptorPool.get(), descriptorSetLayout.get());
1821
1822 // Update descriptor set.
1823 {
1824 const VkWriteDescriptorSetAccelerationStructureKHR writeASInfo =
1825 {
1826 VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
1827 nullptr,
1828 1u,
1829 topLevelAccelerationStructure.get()->getPtr(),
1830 };
1831
1832 DescriptorSetUpdateBuilder updateBuilder;
1833
1834 const auto ds = descriptorSet.get();
1835
1836 const auto calleeBufferDescriptorInfo = makeDescriptorBufferInfo(calleeBuffer.get(), 0ull, VK_WHOLE_SIZE);
1837 const auto outputBufferDescriptorInfo = makeDescriptorBufferInfo(outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
1838 const auto inputBufferDescriptorInfo = makeDescriptorBufferInfo(inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
1839
1840 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &writeASInfo);
1841 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &calleeBufferDescriptorInfo);
1842 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(2u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &outputBufferDescriptorInfo);
1843 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(3u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inputBufferDescriptorInfo);
1844
1845 if (samplersNeeded(m_params.dataType))
1846 {
1847 // Update textures, samplers and combined image samplers.
1848 std::vector<VkDescriptorImageInfo> textureDescInfos;
1849 std::vector<VkDescriptorImageInfo> textureSamplerInfos;
1850 std::vector<VkDescriptorImageInfo> combinedSamplerInfos;
1851
1852 for (size_t i = 0; i < kNumAloneImages; ++i)
1853 textureDescInfos.push_back(makeDescriptorImageInfo(DE_NULL, textureViews[i].get(), VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL));
1854 for (size_t i = 0; i < kNumAloneSamplers; ++i)
1855 textureSamplerInfos.push_back(makeDescriptorImageInfo(samplers[i].get(), DE_NULL, VK_IMAGE_LAYOUT_UNDEFINED));
1856
1857 for (size_t i = 0; i < kNumCombined; ++i)
1858 combinedSamplerInfos.push_back(makeDescriptorImageInfo(samplers[i + kNumAloneSamplers].get(), textureViews[i + kNumAloneImages].get(), VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL));
1859
1860 updateBuilder.writeArray(ds, DescriptorSetUpdateBuilder::Location::binding(4u), VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, kNumAloneImages, textureDescInfos.data());
1861 updateBuilder.writeArray(ds, DescriptorSetUpdateBuilder::Location::binding(5u), VK_DESCRIPTOR_TYPE_SAMPLER, kNumAloneSamplers, textureSamplerInfos.data());
1862 updateBuilder.writeArray(ds, DescriptorSetUpdateBuilder::Location::binding(6u), VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, kNumCombined, combinedSamplerInfos.data());
1863 }
1864 else if (storageImageNeeded(m_params.dataType))
1865 {
1866 const auto storageImageDescriptorInfo = makeDescriptorImageInfo(DE_NULL, textureViews.back().get(), VK_IMAGE_LAYOUT_GENERAL);
1867 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(4u), VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, &storageImageDescriptorInfo);
1868 }
1869
1870 updateBuilder.update(vkd, device);
1871 }
1872
1873 // Create raytracing pipeline and shader binding tables.
1874 Move<VkPipeline> pipeline;
1875
1876 de::MovePtr<BufferWithMemory> raygenShaderBindingTable;
1877 de::MovePtr<BufferWithMemory> missShaderBindingTable;
1878 de::MovePtr<BufferWithMemory> hitShaderBindingTable;
1879 de::MovePtr<BufferWithMemory> callableShaderBindingTable;
1880
1881 VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
1882 VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
1883 VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
1884 VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
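	// Note the miss shader binding table region is left empty on purpose: none of these pipelines uses a miss shader.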
1885
1886 {
1887 const auto rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();
1888 const auto callType = m_params.callType;
1889
1890 // Every case uses a ray generation shader.
1891 rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0), 0);
1892
1893 if (callType == CallType::TRACE_RAY)
1894 {
1895 rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
1896 }
1897 else if (callType == CallType::EXECUTE_CALLABLE)
1898 {
1899 rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 1);
1900 }
1901 else if (callType == CallType::REPORT_INTERSECTION)
1902 {
1903 rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("rint"), 0), 1);
1904 rayTracingPipeline->addShader(VK_SHADER_STAGE_ANY_HIT_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("ahit"), 0), 1);
1905 }
1906 else
1907 {
1908 DE_ASSERT(false);
1909 }
1910
1911 pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());
1912
1913 raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
1914 raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
1915
1916 if (callType == CallType::EXECUTE_CALLABLE)
1917 {
1918 callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
1919 callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
1920 }
1921 else if (callType == CallType::TRACE_RAY || callType == CallType::REPORT_INTERSECTION)
1922 {
1923 hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
1924 hitShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
1925 }
1926 else
1927 {
1928 DE_ASSERT(false);
1929 }
1930 }
1931
1932 // Use ray tracing pipeline.
1933 vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
1934 vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u, &descriptorSet.get(), 0u, nullptr);
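	// A single ray generation invocation (1x1x1) is enough for these tests.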
1935 vkd.cmdTraceRaysKHR(cmdBuffer, &raygenShaderBindingTableRegion, &missShaderBindingTableRegion, &hitShaderBindingTableRegion, &callableShaderBindingTableRegion, 1u, 1u, 1u);
1936
1937 // Synchronize output and callee buffers.
1938 const auto memBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
1939 vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u, &memBarrier, 0u, nullptr, 0u, nullptr);
1940
1941 endCommandBuffer(vkd, cmdBuffer);
1942 submitCommandsAndWait(vkd, device, queue, cmdBuffer);
1943
1944 // Verify output and callee buffers.
1945 invalidateAlloc(vkd, device, outputBufferAlloc);
1946 invalidateAlloc(vkd, device, calleeBufferAlloc);
1947
1948 std::map<std::string, void*> bufferPtrs;
1949 bufferPtrs["output"] = outputBufferPtr;
1950 bufferPtrs["callee"] = calleeBufferPtr;
1951
1952 for (const auto& ptr : bufferPtrs)
1953 {
1954 const auto& bufferName = ptr.first;
1955 const auto& bufferPtr = ptr.second;
1956
1957 deUint32 outputVal;
1958 deMemcpy(&outputVal, bufferPtr, sizeof(outputVal));
1959
1960 if (outputVal != 1u)
1961 return tcu::TestStatus::fail("Unexpected value found in " + bufferName + " buffer: " + de::toString(outputVal));
1962 }
1963
1964 return tcu::TestStatus::pass("Pass");
1965 }
1966
1967 enum class InterfaceType
1968 {
1969 RAY_PAYLOAD = 0,
1970 CALLABLE_DATA,
1971 HIT_ATTRIBUTES,
1972 SHADER_RECORD_BUFFER_RGEN,
1973 SHADER_RECORD_BUFFER_CALL,
1974 SHADER_RECORD_BUFFER_MISS,
1975 SHADER_RECORD_BUFFER_HIT,
1976 };
1977
1978 // Separate class to ease testing pipeline interface variables.
1979 class DataSpillPipelineInterfaceTestCase : public vkt::TestCase
1980 {
1981 public:
1982 struct TestParams
1983 {
1984 InterfaceType interfaceType;
1985 };
1986
1987 DataSpillPipelineInterfaceTestCase (tcu::TestContext& testCtx, const std::string& name, const std::string& description, const TestParams& testParams);
1988 	virtual				~DataSpillPipelineInterfaceTestCase	(void) {}
1989
1990 virtual void initPrograms (vk::SourceCollections& programCollection) const;
1991 virtual TestInstance* createInstance (Context& context) const;
1992 virtual void checkSupport (Context& context) const;
1993
1994 private:
1995 TestParams m_params;
1996 };
1997
1998 class DataSpillPipelineInterfaceTestInstance : public vkt::TestInstance
1999 {
2000 public:
2001 using TestParams = DataSpillPipelineInterfaceTestCase::TestParams;
2002
2003 DataSpillPipelineInterfaceTestInstance (Context& context, const TestParams& testParams);
2004 						~DataSpillPipelineInterfaceTestInstance	(void) {}
2005
2006 tcu::TestStatus iterate (void);
2007
2008 private:
2009 TestParams m_params;
2010 };
2011
2012 DataSpillPipelineInterfaceTestCase::DataSpillPipelineInterfaceTestCase (tcu::TestContext& testCtx, const std::string& name, const std::string& description, const TestParams& testParams)
2013 : vkt::TestCase (testCtx, name, description)
2014 , m_params (testParams)
2015 {
2016 }
2017
2018 TestInstance* DataSpillPipelineInterfaceTestCase::createInstance (Context& context) const
2019 {
2020 return new DataSpillPipelineInterfaceTestInstance (context, m_params);
2021 }
2022
2023 DataSpillPipelineInterfaceTestInstance::DataSpillPipelineInterfaceTestInstance (Context& context, const TestParams& testParams)
2024 : vkt::TestInstance (context)
2025 , m_params (testParams)
2026 {
2027 }
2028
2029 void DataSpillPipelineInterfaceTestCase::checkSupport (Context& context) const
2030 {
2031 commonCheckSupport(context);
2032 }
2033
2034 void DataSpillPipelineInterfaceTestCase::initPrograms (vk::SourceCollections& programCollection) const
2035 {
2036 const vk::ShaderBuildOptions buildOptions (programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_4, 0u, true);
2037
2038 const std::string glslHeader =
2039 "#version 460 core\n"
2040 "#extension GL_EXT_ray_tracing : require\n"
2041 ;
2042
2043 const std::string glslBindings =
2044 "layout(set = 0, binding = 0) uniform accelerationStructureEXT topLevelAS;\n"
2045 "layout(set = 0, binding = 1) buffer StorageBlock { uint val[" + std::to_string(kNumStorageValues) + "]; } storageBuffer;\n"
2046 ;
2047
2048 if (m_params.interfaceType == InterfaceType::RAY_PAYLOAD)
2049 {
2050 // The closest hit shader will store 100 in the second array position.
2051 		// The ray gen shader will store 103 in the first array position using the hitValue after the traceRayEXT() call.
2052
2053 std::ostringstream rgen;
2054 rgen
2055 << glslHeader
2056 << "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
2057 << glslBindings
2058 << "void main()\n"
2059 << "{\n"
2060 << " hitValue = vec3(10.0, 30.0, 60.0);\n"
2061 << " traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, 0);\n"
2062 << " storageBuffer.val[0] = uint(hitValue.x + hitValue.y + hitValue.z);\n"
2063 << "}\n"
2064 ;
2065 programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2066
2067 std::stringstream chit;
2068 chit
2069 << glslHeader
2070 << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
2071 << "hitAttributeEXT vec3 attribs;\n"
2072 << glslBindings
2073 << "void main()\n"
2074 << "{\n"
2075 << " storageBuffer.val[1] = uint(hitValue.x + hitValue.y + hitValue.z);\n"
2076 << " hitValue = vec3(hitValue.x + 1.0, hitValue.y + 1.0, hitValue.z + 1.0);\n"
2077 			<< "}\n";
2079 programCollection.glslSources.add("chit") << glu::ClosestHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
2080 }
2081 else if (m_params.interfaceType == InterfaceType::CALLABLE_DATA)
2082 {
2083 		// The callable shader will store 100 in the second array position.
2084 // The ray gen shader will store 200 in the first array position using the callable data after the executeCallableEXT() call.
2085
2086 std::ostringstream rgen;
2087 rgen
2088 << glslHeader
2089 << "layout(location = 0) callableDataEXT float callableData;\n"
2090 << glslBindings
2091 << "void main()\n"
2092 << "{\n"
2093 << " callableData = 100.0;\n"
2094 << " executeCallableEXT(0, 0);\n"
2095 << " storageBuffer.val[0] = uint(callableData);\n"
2096 << "}\n"
2097 ;
2098 programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2099
2100 std::ostringstream call;
2101 call
2102 << glslHeader
2103 << "layout(location = 0) callableDataInEXT float callableData;\n"
2104 << glslBindings
2105 << "void main()\n"
2106 << "{\n"
2107 << " storageBuffer.val[1] = uint(callableData);\n"
2108 << " callableData = callableData * 2.0;\n"
2109 << "}\n"
2110 ;
2111
2112 programCollection.glslSources.add("call") << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2113 }
2114 else if (m_params.interfaceType == InterfaceType::HIT_ATTRIBUTES)
2115 {
2116 // The ray gen shader will store value 300 in the first storage buffer position.
2117 // The intersection shader will store value 315 in the second storage buffer position.
2118 		// The closest hit shader will store value 330 in the third storage buffer position using the hit attributes.
2119
2120 std::ostringstream rgen;
2121 rgen
2122 << glslHeader
2123 << "layout(location = 0) rayPayloadEXT vec3 hitValue;\n"
2124 << glslBindings
2125 << "void main()\n"
2126 << "{\n"
2127 << " traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, 0);\n"
2128 << " storageBuffer.val[0] = 300u;\n"
2129 << "}\n"
2130 ;
2131 programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2132
2133 std::stringstream rint;
2134 rint
2135 << glslHeader
2136 << "hitAttributeEXT vec3 attribs;\n"
2137 << glslBindings
2138 << "void main()\n"
2139 << "{\n"
2140 << " attribs = vec3(140.0, 160.0, 30.0);\n"
2141 << " storageBuffer.val[1] = 315u;\n"
2142 << " reportIntersectionEXT(1.0f, 0);\n"
2143 << "}\n"
2144 ;
2145
2146 programCollection.glslSources.add("rint") << glu::IntersectionSource(updateRayTracingGLSL(rint.str())) << buildOptions;
2147
2148 std::stringstream chit;
2149 chit
2150 << glslHeader
2151 << "layout(location = 0) rayPayloadInEXT vec3 hitValue;\n"
2152 << "hitAttributeEXT vec3 attribs;\n"
2153 << glslBindings
2154 << "void main()\n"
2155 << "{\n"
2156 << " storageBuffer.val[2] = uint(attribs.x + attribs.y + attribs.z);\n"
2157 			<< "}\n";
2159 programCollection.glslSources.add("chit") << glu::ClosestHitSource(updateRayTracingGLSL(chit.str())) << buildOptions;
2160
2161 }
2162 else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
2163 {
2164 // The ray gen shader will have a uvec4 in the shader record buffer with contents 400, 401, 402, 403.
2165 // The shader will call a callable shader indicating a position in that vec4 (0, 1, 2, 3). For example, let's use position 1.
2166 // The callable shader will return the indicated position+1 modulo 4, so it will return 2 in our case.
2167 // *After* returning from the callable shader, the raygen shader will use that reply to access position 2 and write a 402 in the first output buffer position.
2168 // The callable shader will store 450 in the second output buffer position.
2169
2170 std::ostringstream rgen;
2171 rgen
2172 << glslHeader
2173 << "layout(shaderRecordEXT) buffer ShaderRecordStruct {\n"
2174 << " uvec4 info;\n"
2175 << "};\n"
2176 << "layout(location = 0) callableDataEXT uint callableData;\n"
2177 << glslBindings
2178 << "void main()\n"
2179 << "{\n"
2180 << " callableData = 1u;"
2181 << " executeCallableEXT(0, 0);\n"
2182 << " if (callableData == 0u) storageBuffer.val[0] = info.x;\n"
2183 << " else if (callableData == 1u) storageBuffer.val[0] = info.y;\n"
2184 << " else if (callableData == 2u) storageBuffer.val[0] = info.z;\n"
2185 << " else if (callableData == 3u) storageBuffer.val[0] = info.w;\n"
2186 << "}\n"
2187 ;
2188 programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2189
2190 std::ostringstream call;
2191 call
2192 << glslHeader
2193 << "layout(location = 0) callableDataInEXT uint callableData;\n"
2194 << glslBindings
2195 << "void main()\n"
2196 << "{\n"
2197 << " storageBuffer.val[1] = 450u;\n"
2198 << " callableData = (callableData + 1u) % 4u;\n"
2199 << "}\n"
2200 ;
2201
2202 programCollection.glslSources.add("call") << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2203 }
2204 else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
2205 {
2206 // Similar to the previous case, with a twist:
2207 // * rgen passes the vector position.
2208 // * call increases that by one.
2209 // * subcall increases again and does the modulo operation, also writing 450 in the third output buffer value.
2210 // * call is the one accessing the vector at the returned position, writing 403 in this case to the second output buffer value.
2211 // * call passes this value back doubled to rgen, which writes it to the first output buffer value (806).
2212
2213 std::ostringstream rgen;
2214 rgen
2215 << glslHeader
2216 << "layout(location = 0) callableDataEXT uint callableData;\n"
2217 << glslBindings
2218 << "void main()\n"
2219 << "{\n"
2220 << " callableData = 1u;\n"
2221 << " executeCallableEXT(0, 0);\n"
2222 << " storageBuffer.val[0] = callableData;\n"
2223 << "}\n"
2224 ;
2225 programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2226
2227 std::ostringstream call;
2228 call
2229 << glslHeader
2230 << "layout(shaderRecordEXT) buffer ShaderRecordStruct {\n"
2231 << " uvec4 info;\n"
2232 << "};\n"
2233 << "layout(location = 0) callableDataInEXT uint callableDataIn;\n"
2234 << "layout(location = 1) callableDataEXT uint callableDataOut;\n"
2235 << glslBindings
2236 << "void main()\n"
2237 << "{\n"
2238 << " callableDataOut = callableDataIn + 1u;\n"
2239 << " executeCallableEXT(1, 1);\n"
2240 << " uint outputBufferValue = 777u;\n"
2241 << " if (callableDataOut == 0u) outputBufferValue = info.x;\n"
2242 << " else if (callableDataOut == 1u) outputBufferValue = info.y;\n"
2243 << " else if (callableDataOut == 2u) outputBufferValue = info.z;\n"
2244 << " else if (callableDataOut == 3u) outputBufferValue = info.w;\n"
2245 << " storageBuffer.val[1] = outputBufferValue;\n"
2246 << " callableDataIn = outputBufferValue * 2u;\n"
2247 << "}\n"
2248 ;
2249
2250 programCollection.glslSources.add("call") << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2251
2252 std::ostringstream subcall;
2253 subcall
2254 << glslHeader
2255 << "layout(location = 1) callableDataInEXT uint callableData;\n"
2256 << glslBindings
2257 << "void main()\n"
2258 << "{\n"
2259 << " callableData = (callableData + 1u) % 4u;\n"
2260 << " storageBuffer.val[2] = 450u;\n"
2261 << "}\n"
2262 ;
2263
2264 programCollection.glslSources.add("subcall") << glu::CallableSource(updateRayTracingGLSL(subcall.str())) << buildOptions;
2265 }
2266 else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS || m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
2267 {
2268 // Similar to the previous one, but the intermediate call shader has been replaced with a miss or closest hit shader.
2269 // The rgen shader will communicate with the miss/chit shader using the ray payload instead of the callable data.
2270 // Also, the initial position will be 2, so it will wrap around in this case. The numbers will also change.
2271
2272 std::ostringstream rgen;
2273 rgen
2274 << glslHeader
2275 << "layout(location = 0) rayPayloadEXT uint rayPayload;\n"
2276 << glslBindings
2277 << "void main()\n"
2278 << "{\n"
2279 << " rayPayload = 2u;\n"
2280 << " traceRayEXT(topLevelAS, 0u, 0xFFu, 0, 0, 0, vec3(0.5, 0.5, 0.0), 0.0, vec3(0.0, 0.0, -1.0), 9.0, 0);\n"
2281 << " storageBuffer.val[0] = rayPayload;\n"
2282 << "}\n"
2283 ;
2284 programCollection.glslSources.add("rgen") << glu::RaygenSource(updateRayTracingGLSL(rgen.str())) << buildOptions;
2285
2286 std::ostringstream chitOrMiss;
2287 chitOrMiss
2288 << glslHeader
2289 << "layout(shaderRecordEXT) buffer ShaderRecordStruct {\n"
2290 << " uvec4 info;\n"
2291 << "};\n"
2292 << "layout(location = 0) rayPayloadInEXT uint rayPayload;\n"
2293 << "layout(location = 0) callableDataEXT uint callableData;\n"
2294 << glslBindings
2295 << "void main()\n"
2296 << "{\n"
2297 << " callableData = rayPayload + 1u;\n"
2298 << " executeCallableEXT(0, 0);\n"
2299 << " uint outputBufferValue = 777u;\n"
2300 << " if (callableData == 0u) outputBufferValue = info.x;\n"
2301 << " else if (callableData == 1u) outputBufferValue = info.y;\n"
2302 << " else if (callableData == 2u) outputBufferValue = info.z;\n"
2303 << " else if (callableData == 3u) outputBufferValue = info.w;\n"
2304 << " storageBuffer.val[1] = outputBufferValue;\n"
2305 << " rayPayload = outputBufferValue * 3u;\n"
2306 << "}\n"
2307 ;
2308
2309 if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS)
2310 programCollection.glslSources.add("miss") << glu::MissSource(updateRayTracingGLSL(chitOrMiss.str())) << buildOptions;
2311 else if (m_params.interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
2312 programCollection.glslSources.add("chit") << glu::ClosestHitSource(updateRayTracingGLSL(chitOrMiss.str())) << buildOptions;
2313 else
2314 DE_ASSERT(false);
2315
2316 std::ostringstream call;
2317 call
2318 << glslHeader
2319 << "layout(location = 0) callableDataInEXT uint callableData;\n"
2320 << glslBindings
2321 << "void main()\n"
2322 << "{\n"
2323 << " storageBuffer.val[2] = 490u;\n"
2324 << " callableData = (callableData + 1u) % 4u;\n"
2325 << "}\n"
2326 ;
2327
2328 programCollection.glslSources.add("call") << glu::CallableSource(updateRayTracingGLSL(call.str())) << buildOptions;
2329 }
2330 else
2331 {
2332 DE_ASSERT(false);
2333 }
2334 }
2335
2336 VkShaderStageFlags getShaderStages (InterfaceType type_)
2337 {
2338 VkShaderStageFlags flags = VK_SHADER_STAGE_RAYGEN_BIT_KHR;
2339
2340 switch (type_)
2341 {
2342 case InterfaceType::HIT_ATTRIBUTES:
2343 flags |= VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
2344 // fallthrough.
2345 case InterfaceType::RAY_PAYLOAD:
2346 flags |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2347 break;
2348 case InterfaceType::CALLABLE_DATA:
2349 case InterfaceType::SHADER_RECORD_BUFFER_RGEN:
2350 case InterfaceType::SHADER_RECORD_BUFFER_CALL:
2351 flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2352 break;
2353 case InterfaceType::SHADER_RECORD_BUFFER_MISS:
2354 flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2355 flags |= VK_SHADER_STAGE_MISS_BIT_KHR;
2356 break;
2357 case InterfaceType::SHADER_RECORD_BUFFER_HIT:
2358 flags |= VK_SHADER_STAGE_CALLABLE_BIT_KHR;
2359 flags |= VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
2360 break;
2361 default:
2362 DE_ASSERT(false);
2363 break;
2364 }
2365
2366 return flags;
2367 }
2368
2369 // Proper stage for generating default geometry.
2370 VkShaderStageFlagBits getShaderStageForGeometry (InterfaceType type_)
2371 {
2372 VkShaderStageFlagBits bits = VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM;
2373
2374 switch (type_)
2375 {
2376 case InterfaceType::HIT_ATTRIBUTES: bits = VK_SHADER_STAGE_INTERSECTION_BIT_KHR; break;
2377 case InterfaceType::RAY_PAYLOAD: bits = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR; break;
2378 case InterfaceType::CALLABLE_DATA: bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR; break;
2379 case InterfaceType::SHADER_RECORD_BUFFER_RGEN: bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR; break;
2380 case InterfaceType::SHADER_RECORD_BUFFER_CALL: bits = VK_SHADER_STAGE_CALLABLE_BIT_KHR; break;
2381 case InterfaceType::SHADER_RECORD_BUFFER_MISS: bits = VK_SHADER_STAGE_MISS_BIT_KHR; break;
2382 case InterfaceType::SHADER_RECORD_BUFFER_HIT: bits = VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR; break;
2383 default: DE_ASSERT(false); break;
2384 }
2385
2386 DE_ASSERT(bits != VK_SHADER_STAGE_FLAG_BITS_MAX_ENUM);
2387 return bits;
2388 }
2389
2390 void createSBTWithShaderRecord (const DeviceInterface& vkd, VkDevice device, vk::Allocator &alloc,
2391 VkPipeline pipeline, RayTracingPipeline* rayTracingPipeline,
2392 deUint32 shaderGroupHandleSize, deUint32 shaderGroupBaseAlignment,
2393 deUint32 firstGroup, deUint32 groupCount,
2394 de::MovePtr<BufferWithMemory>& shaderBindingTable,
2395 VkStridedDeviceAddressRegionKHR& shaderBindingTableRegion)
2396 {
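	// Each SBT entry stores the shader group handle followed by a uvec4 shader record (kShaderRecordSize bytes),
	// and the entry size is rounded up to a multiple of the handle size.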
2397 const auto alignedSize = de::roundUp(shaderGroupHandleSize + kShaderRecordSize, shaderGroupHandleSize);
2398 shaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline, alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, firstGroup, groupCount, 0u, 0u, MemoryRequirement::Any, 0u, 0u, kShaderRecordSize);
2399 shaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, shaderBindingTable->get(), 0), alignedSize, groupCount * alignedSize);
2400
2401 // Fill shader record buffer data.
2402 // Note we will only fill the first shader record after the handle.
2403 const tcu::UVec4 shaderRecordData (400u, 401u, 402u, 403u);
2404 auto& sbtAlloc = shaderBindingTable->getAllocation();
2405 auto* dataPtr = reinterpret_cast<deUint8*>(sbtAlloc.getHostPtr()) + shaderGroupHandleSize;
2406
2407 DE_STATIC_ASSERT(sizeof(shaderRecordData) == static_cast<size_t>(kShaderRecordSize));
2408 deMemcpy(dataPtr, &shaderRecordData, sizeof(shaderRecordData));
2409 }
2410
2411 tcu::TestStatus DataSpillPipelineInterfaceTestInstance::iterate (void)
2412 {
2413 const auto& vki = m_context.getInstanceInterface();
2414 const auto physicalDevice = m_context.getPhysicalDevice();
2415 const auto& vkd = m_context.getDeviceInterface();
2416 const auto device = m_context.getDevice();
2417 const auto queue = m_context.getUniversalQueue();
2418 const auto familyIndex = m_context.getUniversalQueueFamilyIndex();
2419 auto& alloc = m_context.getDefaultAllocator();
2420 const auto shaderStages = getShaderStages(m_params.interfaceType);
2421
2422 // Command buffer.
2423 const auto cmdPool = makeCommandPool(vkd, device, familyIndex);
2424 const auto cmdBufferPtr = allocateCommandBuffer(vkd, device, cmdPool.get(), VK_COMMAND_BUFFER_LEVEL_PRIMARY);
2425 const auto cmdBuffer = cmdBufferPtr.get();
2426
2427 beginCommandBuffer(vkd, cmdBuffer);
2428
2429 // Storage buffer.
2430 std::array<deUint32, kNumStorageValues> storageBufferData;
2431 const auto storageBufferSize = de::dataSize(storageBufferData);
2432 const auto storagebufferInfo = makeBufferCreateInfo(storageBufferSize, VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
2433 BufferWithMemory storageBuffer (vkd, device, alloc, storagebufferInfo, MemoryRequirement::HostVisible);
2434
2435 // Zero-out buffer.
2436 auto& storageBufferAlloc = storageBuffer.getAllocation();
2437 auto* storageBufferPtr = storageBufferAlloc.getHostPtr();
2438 deMemset(storageBufferPtr, 0, storageBufferSize);
2439 flushAlloc(vkd, device, storageBufferAlloc);
2440
2441 // Acceleration structures.
2442 de::MovePtr<BottomLevelAccelerationStructure> bottomLevelAccelerationStructure;
2443 de::MovePtr<TopLevelAccelerationStructure> topLevelAccelerationStructure;
2444
2445 bottomLevelAccelerationStructure = makeBottomLevelAccelerationStructure();
2446 bottomLevelAccelerationStructure->setDefaultGeometryData(getShaderStageForGeometry(m_params.interfaceType), VK_GEOMETRY_NO_DUPLICATE_ANY_HIT_INVOCATION_BIT_KHR);
2447 bottomLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
2448
2449 topLevelAccelerationStructure = makeTopLevelAccelerationStructure();
2450 topLevelAccelerationStructure->setInstanceCount(1);
2451 topLevelAccelerationStructure->addInstance(de::SharedPtr<BottomLevelAccelerationStructure>(bottomLevelAccelerationStructure.release()));
2452 topLevelAccelerationStructure->createAndBuild(vkd, device, cmdBuffer, alloc);
2453
2454 // Get some ray tracing properties.
2455 deUint32 shaderGroupHandleSize = 0u;
2456 deUint32 shaderGroupBaseAlignment = 1u;
2457 {
2458 const auto rayTracingPropertiesKHR = makeRayTracingProperties(vki, physicalDevice);
2459 shaderGroupHandleSize = rayTracingPropertiesKHR->getShaderGroupHandleSize();
2460 shaderGroupBaseAlignment = rayTracingPropertiesKHR->getShaderGroupBaseAlignment();
2461 }
2462
2463 // Descriptor set layout.
2464 DescriptorSetLayoutBuilder dslBuilder;
2465 dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, 1u, shaderStages, nullptr);
2466 	dslBuilder.addBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1u, shaderStages, nullptr);	// Storage buffer.
2467 const auto descriptorSetLayout = dslBuilder.build(vkd, device);
2468
2469 // Pipeline layout.
2470 const auto pipelineLayout = makePipelineLayout(vkd, device, descriptorSetLayout.get());
2471
2472 // Descriptor pool and set.
2473 DescriptorPoolBuilder poolBuilder;
2474 poolBuilder.addType(VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR);
2475 poolBuilder.addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
2476 const auto descriptorPool = poolBuilder.build(vkd, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
2477 const auto descriptorSet = makeDescriptorSet(vkd, device, descriptorPool.get(), descriptorSetLayout.get());
2478
2479 // Update descriptor set.
2480 {
2481 const VkWriteDescriptorSetAccelerationStructureKHR writeASInfo =
2482 {
2483 VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET_ACCELERATION_STRUCTURE_KHR,
2484 nullptr,
2485 1u,
2486 topLevelAccelerationStructure.get()->getPtr(),
2487 };
2488
2489 const auto ds = descriptorSet.get();
2490 const auto storageBufferDescriptorInfo = makeDescriptorBufferInfo(storageBuffer.get(), 0ull, VK_WHOLE_SIZE);
2491
2492 DescriptorSetUpdateBuilder updateBuilder;
2493 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(0u), VK_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE_KHR, &writeASInfo);
2494 updateBuilder.writeSingle(ds, DescriptorSetUpdateBuilder::Location::binding(1u), VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &storageBufferDescriptorInfo);
2495 updateBuilder.update(vkd, device);
2496 }
2497
    // Create the ray tracing pipeline and shader binding tables.
    const auto interfaceType = m_params.interfaceType;
    Move<VkPipeline> pipeline;

    de::MovePtr<BufferWithMemory> raygenShaderBindingTable;
    de::MovePtr<BufferWithMemory> missShaderBindingTable;
    de::MovePtr<BufferWithMemory> hitShaderBindingTable;
    de::MovePtr<BufferWithMemory> callableShaderBindingTable;

    VkStridedDeviceAddressRegionKHR raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
    VkStridedDeviceAddressRegionKHR missShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
    VkStridedDeviceAddressRegionKHR hitShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
    VkStridedDeviceAddressRegionKHR callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(DE_NULL, 0, 0);
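    // Tables not used by the current interface type keep their zeroed regions, which vkCmdTraceRaysKHR permits as long as the corresponding stages are never invoked.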

    {
        const auto rayTracingPipeline = de::newMovePtr<RayTracingPipeline>();

        // Every case uses a ray generation shader.
        rayTracingPipeline->addShader(VK_SHADER_STAGE_RAYGEN_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("rgen"), 0), 0);

        if (interfaceType == InterfaceType::RAY_PAYLOAD)
        {
            rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
        }
        else if (interfaceType == InterfaceType::CALLABLE_DATA || interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
        {
            rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 1);
        }
        else if (interfaceType == InterfaceType::HIT_ATTRIBUTES)
        {
            rayTracingPipeline->addShader(VK_SHADER_STAGE_INTERSECTION_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("rint"), 0), 1);
            rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
        {
            rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 1);
            rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("subcall"), 0), 2);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS)
        {
            rayTracingPipeline->addShader(VK_SHADER_STAGE_MISS_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("miss"), 0), 1);
            rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 2);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
        {
            rayTracingPipeline->addShader(VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("chit"), 0), 1);
            rayTracingPipeline->addShader(VK_SHADER_STAGE_CALLABLE_BIT_KHR, createShaderModule(vkd, device, m_context.getBinaryCollection().get("call"), 0), 2);
        }
        else
        {
            DE_ASSERT(false);
        }

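        // Group 0 always holds the ray generation shader; groups 1 and 2 hold the stage-specific shaders registered above, and the group indices used below follow that order.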
        pipeline = rayTracingPipeline->createPipeline(vkd, device, pipelineLayout.get());

        if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
        {
            createSBTWithShaderRecord (vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(), shaderGroupHandleSize, shaderGroupBaseAlignment,
                                       0u, 1u, raygenShaderBindingTable, raygenShaderBindingTableRegion);
        }
        else
        {
            raygenShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 0, 1);
            raygenShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, raygenShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
        }

        if (interfaceType == InterfaceType::CALLABLE_DATA || interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
        {
            callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
            callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
        }
        else if (interfaceType == InterfaceType::RAY_PAYLOAD || interfaceType == InterfaceType::HIT_ATTRIBUTES)
        {
            hitShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 1, 1);
            hitShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, hitShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
        {
            createSBTWithShaderRecord (vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(), shaderGroupHandleSize, shaderGroupBaseAlignment,
                                       1u, 2u, callableShaderBindingTable, callableShaderBindingTableRegion);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS)
        {
            createSBTWithShaderRecord (vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(), shaderGroupHandleSize, shaderGroupBaseAlignment,
                                       1u, 1u, missShaderBindingTable, missShaderBindingTableRegion);

            // Callable shader table.
            callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
            callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
        }
        else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
        {
            createSBTWithShaderRecord (vkd, device, alloc, pipeline.get(), rayTracingPipeline.get(), shaderGroupHandleSize, shaderGroupBaseAlignment,
                                       1u, 1u, hitShaderBindingTable, hitShaderBindingTableRegion);

            // Callable shader table.
            callableShaderBindingTable = rayTracingPipeline->createShaderBindingTable(vkd, device, pipeline.get(), alloc, shaderGroupHandleSize, shaderGroupBaseAlignment, 2, 1);
            callableShaderBindingTableRegion = makeStridedDeviceAddressRegionKHR(getBufferDeviceAddress(vkd, device, callableShaderBindingTable->get(), 0), shaderGroupHandleSize, shaderGroupHandleSize);
        }
        else
        {
            DE_ASSERT(false);
        }
    }

    // Trace a single ray (1x1x1 dispatch) using the ray tracing pipeline.
    vkd.cmdBindPipeline(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipeline.get());
    vkd.cmdBindDescriptorSets(cmdBuffer, VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, pipelineLayout.get(), 0u, 1u, &descriptorSet.get(), 0u, nullptr);
    vkd.cmdTraceRaysKHR(cmdBuffer, &raygenShaderBindingTableRegion, &missShaderBindingTableRegion, &hitShaderBindingTableRegion, &callableShaderBindingTableRegion, 1u, 1u, 1u);

    // Make the shader writes to the storage buffer visible to the host.
    const auto memBarrier = makeMemoryBarrier(VK_ACCESS_SHADER_WRITE_BIT, VK_ACCESS_HOST_READ_BIT);
    vkd.cmdPipelineBarrier(cmdBuffer, VK_PIPELINE_STAGE_RAY_TRACING_SHADER_BIT_KHR, VK_PIPELINE_STAGE_HOST_BIT, 0u, 1u, &memBarrier, 0u, nullptr, 0u, nullptr);

    endCommandBuffer(vkd, cmdBuffer);
    submitCommandsAndWait(vkd, device, queue, cmdBuffer);

    // Verify storage buffer.
    invalidateAlloc(vkd, device, storageBufferAlloc);
    deMemcpy(storageBufferData.data(), storageBufferPtr, storageBufferSize);

    // These values must match what the shaders store.
    std::vector<deUint32> expectedData;
    if (interfaceType == InterfaceType::RAY_PAYLOAD)
    {
        expectedData.push_back(103u);
        expectedData.push_back(100u);
    }
    else if (interfaceType == InterfaceType::CALLABLE_DATA)
    {
        expectedData.push_back(200u);
        expectedData.push_back(100u);
    }
    else if (interfaceType == InterfaceType::HIT_ATTRIBUTES)
    {
        expectedData.push_back(300u);
        expectedData.push_back(315u);
        expectedData.push_back(330u);
    }
    else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_RGEN)
    {
        expectedData.push_back(402u);
        expectedData.push_back(450u);
    }
    else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_CALL)
    {
        expectedData.push_back(806u);
        expectedData.push_back(403u);
        expectedData.push_back(450u);
    }
    else if (interfaceType == InterfaceType::SHADER_RECORD_BUFFER_MISS || interfaceType == InterfaceType::SHADER_RECORD_BUFFER_HIT)
    {
        expectedData.push_back(1200u);
        expectedData.push_back( 400u);
        expectedData.push_back( 490u);
    }
    else
    {
        DE_ASSERT(false);
    }

    size_t pos;
    for (pos = 0u; pos < expectedData.size(); ++pos)
    {
        const auto& stored = storageBufferData.at(pos);
        const auto& expected = expectedData.at(pos);
        if (stored != expected)
        {
            std::ostringstream msg;
            msg << "Unexpected output value found at position " << pos << " (expected " << expected << " but got " << stored << ")";
            return tcu::TestStatus::fail(msg.str());
        }
    }

    // Expect zeros in unused positions, as filled on the host.
    for (; pos < storageBufferData.size(); ++pos)
    {
        const auto& stored = storageBufferData.at(pos);
        if (stored != 0u)
        {
            std::ostringstream msg;
            msg << "Unexpected output value found at position " << pos << " (expected 0 but got " << stored << ")";
            return tcu::TestStatus::fail(msg.str());
        }
    }

    return tcu::TestStatus::pass("Pass");
}

} // anonymous namespace

tcu::TestCaseGroup* createDataSpillTests (tcu::TestContext& testCtx)
{
    de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "data_spill", "Ray tracing tests for data spilling and unspilling around shader calls"));

    struct
    {
        CallType callType;
        const char* name;
    } callTypes[] =
    {
        { CallType::EXECUTE_CALLABLE,    "execute_callable"    },
        { CallType::TRACE_RAY,           "trace_ray"           },
        { CallType::REPORT_INTERSECTION, "report_intersection" },
    };

    struct
    {
        DataType dataType;
        const char* name;
    } dataTypes[] =
    {
        { DataType::INT32,             "int32"        },
        { DataType::UINT32,            "uint32"       },
        { DataType::INT64,             "int64"        },
        { DataType::UINT64,            "uint64"       },
        { DataType::INT16,             "int16"        },
        { DataType::UINT16,            "uint16"       },
        { DataType::INT8,              "int8"         },
        { DataType::UINT8,             "uint8"        },
        { DataType::FLOAT32,           "float32"      },
        { DataType::FLOAT64,           "float64"      },
        { DataType::FLOAT16,           "float16"      },
        { DataType::STRUCT,            "struct"       },
        { DataType::SAMPLER,           "sampler"      },
        { DataType::IMAGE,             "image"        },
        { DataType::SAMPLED_IMAGE,     "combined"     },
        { DataType::PTR_IMAGE,         "ptr_image"    },
        { DataType::PTR_SAMPLER,       "ptr_sampler"  },
        { DataType::PTR_SAMPLED_IMAGE, "ptr_combined" },
        { DataType::PTR_TEXEL,         "ptr_texel"    },
        { DataType::OP_NULL,           "op_null"      },
        { DataType::OP_UNDEF,          "op_undef"     },
    };

    struct
    {
        VectorType vectorType;
        const char* prefix;
    } vectorTypes[] =
    {
        { VectorType::SCALAR, ""   },
        { VectorType::V2,     "v2" },
        { VectorType::V3,     "v3" },
        { VectorType::V4,     "v4" },
        { VectorType::A5,     "a5" },
    };

    for (int callTypeIdx = 0; callTypeIdx < DE_LENGTH_OF_ARRAY(callTypes); ++callTypeIdx)
    {
        const auto& entryCallTypes = callTypes[callTypeIdx];

        de::MovePtr<tcu::TestCaseGroup> callTypeGroup(new tcu::TestCaseGroup(testCtx, entryCallTypes.name, ""));
        for (int dataTypeIdx = 0; dataTypeIdx < DE_LENGTH_OF_ARRAY(dataTypes); ++dataTypeIdx)
        {
            const auto& entryDataTypes = dataTypes[dataTypeIdx];

            for (int vectorTypeIdx = 0; vectorTypeIdx < DE_LENGTH_OF_ARRAY(vectorTypes); ++vectorTypeIdx)
            {
                const auto& entryVectorTypes = vectorTypes[vectorTypeIdx];

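                // Images, samplers, structs and OpNull/OpUndef operands are only exercised as scalars; skip their vector and array variants.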
                if ((samplersNeeded(entryDataTypes.dataType)
                     || storageImageNeeded(entryDataTypes.dataType)
                     || entryDataTypes.dataType == DataType::STRUCT
                     || entryDataTypes.dataType == DataType::OP_NULL
                     || entryDataTypes.dataType == DataType::OP_UNDEF)
                    && entryVectorTypes.vectorType != VectorType::SCALAR)
                {
                    continue;
                }

                DataSpillTestCase::TestParams params;
                params.callType = entryCallTypes.callType;
                params.dataType = entryDataTypes.dataType;
                params.vectorType = entryVectorTypes.vectorType;

                const auto testName = std::string(entryVectorTypes.prefix) + entryDataTypes.name;

                callTypeGroup->addChild(new DataSpillTestCase(testCtx, testName, "", params));
            }
        }

        group->addChild(callTypeGroup.release());
    }

    // Pipeline interface tests.
    de::MovePtr<tcu::TestCaseGroup> pipelineInterfaceGroup(new tcu::TestCaseGroup(testCtx, "pipeline_interface", "Test data spilling and unspilling of pipeline interface variables"));

    struct
    {
        InterfaceType interfaceType;
        const char* name;
    } interfaceTypes[] =
    {
        { InterfaceType::RAY_PAYLOAD,               "ray_payload"               },
        { InterfaceType::CALLABLE_DATA,             "callable_data"             },
        { InterfaceType::HIT_ATTRIBUTES,            "hit_attributes"            },
        { InterfaceType::SHADER_RECORD_BUFFER_RGEN, "shader_record_buffer_rgen" },
        { InterfaceType::SHADER_RECORD_BUFFER_CALL, "shader_record_buffer_call" },
        { InterfaceType::SHADER_RECORD_BUFFER_MISS, "shader_record_buffer_miss" },
        { InterfaceType::SHADER_RECORD_BUFFER_HIT,  "shader_record_buffer_hit"  },
    };

    for (int idx = 0; idx < DE_LENGTH_OF_ARRAY(interfaceTypes); ++idx)
    {
        const auto& entry = interfaceTypes[idx];
        DataSpillPipelineInterfaceTestCase::TestParams params;

        params.interfaceType = entry.interfaceType;

        pipelineInterfaceGroup->addChild(new DataSpillPipelineInterfaceTestCase(testCtx, entry.name, "", params));
    }

    group->addChild(pipelineInterfaceGroup.release());

    return group.release();
}

} // RayTracing
} // vkt