1 // Copyright 2021 The SwiftShader Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //    http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "Device.hpp"
16 #include "Driver.hpp"
17 
18 #include "gmock/gmock.h"
19 #include "gtest/gtest.h"
20 
21 #include "spirv-tools/libspirv.hpp"
22 
23 #include <cstring>
24 #include <sstream>
25 
namespace {
// Returns the smallest multiple of `alignment` that is >= `val` (barring
// unsigned overflow of `val + alignment - 1`).
// `alignment` is used as a divisor, so it must be non-zero; it does not need to
// be a power of two.
size_t alignUp(size_t val, size_t alignment)
{
	// Integer ceil(val / alignment), then scaled back up by `alignment`.
	const size_t chunks = (val + alignment - 1) / alignment;
	return chunks * alignment;
}
}  // anonymous namespace
32 
// Parameters for one instantiation of the buffer-to-buffer compute tests:
// how many uint32 elements to process and the shader's LocalSize (X, Y, Z).
struct ComputeParams
{
	size_t numElements;
	int localSizeX;
	int localSizeY;
	int localSizeZ;

	// Writes a human-readable description, e.g.
	// "ComputeParams{numElements: 512, localSizeX: 2, localSizeY: 1, localSizeZ: 1}",
	// so test output can show which parameter set is running.
	friend std::ostream &operator<<(std::ostream &os, const ComputeParams &params)
	{
		os << "ComputeParams{";
		os << "numElements: " << params.numElements << ", ";
		os << "localSizeX: " << params.localSizeX << ", ";
		os << "localSizeY: " << params.localSizeY << ", ";
		os << "localSizeZ: " << params.localSizeZ;
		os << "}";
		return os;
	}
};
49 
// Value-parameterized (ComputeParams) fixture for compute tests. A single
// static Driver is loaded once before the suite's first test and unloaded
// after its last one, rather than per test.
class ComputeTest : public testing::TestWithParam<ComputeParams>
{
protected:
	// Shared by every test in the suite; defined out-of-line below.
	static Driver driver;

	// gtest per-suite setup hook: loads the SwiftShader driver.
	static void SetUpTestSuite()
	{
		ASSERT_TRUE(driver.loadSwiftShader());
	}

	// gtest per-suite teardown hook: unloads the driver.
	static void TearDownTestSuite()
	{
		driver.unload();
	}
};

// Definition of the suite-wide driver instance declared in ComputeTest.
Driver ComputeTest::driver;
67 
compileSpirv(const char * assembly)68 std::vector<uint32_t> compileSpirv(const char *assembly)
69 {
70 	spvtools::SpirvTools core(SPV_ENV_VULKAN_1_0);
71 
72 	core.SetMessageConsumer([](spv_message_level_t, const char *, const spv_position_t &p, const char *m) {
73 		FAIL() << p.line << ":" << p.column << ": " << m;
74 	});
75 
76 	std::vector<uint32_t> spirv;
77 	EXPECT_TRUE(core.Assemble(assembly, &spirv));
78 	EXPECT_TRUE(core.Validate(spirv));
79 
80 	// Warn if the disassembly does not match the source assembly.
81 	// We do this as debugging tests in the debugger is often made much harder
82 	// if the SSA names (%X) in the debugger do not match the source.
83 	std::string disassembled;
84 	core.Disassemble(spirv, &disassembled, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
85 	if(disassembled != assembly)
86 	{
87 		printf("-- WARNING: Disassembly does not match assembly: ---\n\n");
88 
89 		auto splitLines = [](const std::string &str) -> std::vector<std::string> {
90 			std::stringstream ss(str);
91 			std::vector<std::string> out;
92 			std::string line;
93 			while(std::getline(ss, line, '\n')) { out.push_back(line); }
94 			return out;
95 		};
96 
97 		auto srcLines = splitLines(std::string(assembly));
98 		auto disLines = splitLines(disassembled);
99 
100 		for(size_t line = 0; line < srcLines.size() && line < disLines.size(); line++)
101 		{
102 			auto srcLine = (line < srcLines.size()) ? srcLines[line] : "<missing>";
103 			auto disLine = (line < disLines.size()) ? disLines[line] : "<missing>";
104 			if(srcLine != disLine)
105 			{
106 				printf("%zu: '%s' != '%s'\n", line, srcLine.c_str(), disLine.c_str());
107 			}
108 		}
109 		printf("\n\n---\nExpected:\n\n%s", disassembled.c_str());
110 	}
111 
112 	return spirv;
113 }
114 
// Asserts that the Vulkan call `x` returned VK_SUCCESS. Expands to gtest's
// ASSERT_EQ, so on failure it returns from the enclosing function; it can only
// be used in functions returning void.
#define VK_ASSERT(x) ASSERT_EQ(x, VK_SUCCESS)
116 
117 // Base class for compute tests that read from an input buffer and write to an
118 // output buffer of same length.
119 class SwiftShaderVulkanBufferToBufferComputeTest : public ComputeTest
120 {
121 public:
122 	void test(const std::string &shader,
123 	          std::function<uint32_t(uint32_t idx)> input,
124 	          std::function<uint32_t(uint32_t idx)> expected);
125 };
126 
test(const std::string & shader,std::function<uint32_t (uint32_t idx)> input,std::function<uint32_t (uint32_t idx)> expected)127 void SwiftShaderVulkanBufferToBufferComputeTest::test(
128     const std::string &shader,
129     std::function<uint32_t(uint32_t idx)> input,
130     std::function<uint32_t(uint32_t idx)> expected)
131 {
132 	auto code = compileSpirv(shader.c_str());
133 
134 	const VkInstanceCreateInfo createInfo = {
135 		VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,  // sType
136 		nullptr,                                 // pNext
137 		0,                                       // flags
138 		nullptr,                                 // pApplicationInfo
139 		0,                                       // enabledLayerCount
140 		nullptr,                                 // ppEnabledLayerNames
141 		0,                                       // enabledExtensionCount
142 		nullptr,                                 // ppEnabledExtensionNames
143 	};
144 
145 	VkInstance instance = VK_NULL_HANDLE;
146 	VK_ASSERT(driver.vkCreateInstance(&createInfo, nullptr, &instance));
147 
148 	ASSERT_TRUE(driver.resolve(instance));
149 
150 	std::unique_ptr<Device> device;
151 	VK_ASSERT(Device::CreateComputeDevice(&driver, instance, device));
152 	ASSERT_TRUE(device->IsValid());
153 
154 	// struct Buffers
155 	// {
156 	//     uint32_t pad0[63];
157 	//     uint32_t magic0;
158 	//     uint32_t in[NUM_ELEMENTS]; // Aligned to 0x100
159 	//     uint32_t magic1;
160 	//     uint32_t pad1[N];
161 	//     uint32_t magic2;
162 	//     uint32_t out[NUM_ELEMENTS]; // Aligned to 0x100
163 	//     uint32_t magic3;
164 	// };
165 	static constexpr uint32_t magic0 = 0x01234567;
166 	static constexpr uint32_t magic1 = 0x89abcdef;
167 	static constexpr uint32_t magic2 = 0xfedcba99;
168 	static constexpr uint32_t magic3 = 0x87654321;
169 	size_t numElements = GetParam().numElements;
170 	size_t alignElements = 0x100 / sizeof(uint32_t);
171 	size_t magic0Offset = alignElements - 1;
172 	size_t inOffset = 1 + magic0Offset;
173 	size_t magic1Offset = numElements + inOffset;
174 	size_t magic2Offset = alignUp(magic1Offset + 1, alignElements) - 1;
175 	size_t outOffset = 1 + magic2Offset;
176 	size_t magic3Offset = numElements + outOffset;
177 	size_t buffersTotalElements = alignUp(1 + magic3Offset, alignElements);
178 	size_t buffersSize = sizeof(uint32_t) * buffersTotalElements;
179 
180 	VkDeviceMemory memory;
181 	VK_ASSERT(device->AllocateMemory(buffersSize,
182 	                                 VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
183 	                                 &memory));
184 
185 	uint32_t *buffers;
186 	VK_ASSERT(device->MapMemory(memory, 0, buffersSize, 0, (void **)&buffers));
187 
188 	buffers[magic0Offset] = magic0;
189 	buffers[magic1Offset] = magic1;
190 	buffers[magic2Offset] = magic2;
191 	buffers[magic3Offset] = magic3;
192 
193 	for(size_t i = 0; i < numElements; i++)
194 	{
195 		buffers[inOffset + i] = input((uint32_t)i);
196 	}
197 
198 	device->UnmapMemory(memory);
199 	buffers = nullptr;
200 
201 	VkBuffer bufferIn;
202 	VK_ASSERT(device->CreateStorageBuffer(memory,
203 	                                      sizeof(uint32_t) * numElements,
204 	                                      sizeof(uint32_t) * inOffset,
205 	                                      &bufferIn));
206 
207 	VkBuffer bufferOut;
208 	VK_ASSERT(device->CreateStorageBuffer(memory,
209 	                                      sizeof(uint32_t) * numElements,
210 	                                      sizeof(uint32_t) * outOffset,
211 	                                      &bufferOut));
212 
213 	VkShaderModule shaderModule;
214 	VK_ASSERT(device->CreateShaderModule(code, &shaderModule));
215 
216 	std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings = {
217 		{
218 		    0,                                  // binding
219 		    VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,  // descriptorType
220 		    1,                                  // descriptorCount
221 		    VK_SHADER_STAGE_COMPUTE_BIT,        // stageFlags
222 		    0,                                  // pImmutableSamplers
223 		},
224 		{
225 		    1,                                  // binding
226 		    VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,  // descriptorType
227 		    1,                                  // descriptorCount
228 		    VK_SHADER_STAGE_COMPUTE_BIT,        // stageFlags
229 		    0,                                  // pImmutableSamplers
230 		}
231 	};
232 
233 	VkDescriptorSetLayout descriptorSetLayout;
234 	VK_ASSERT(device->CreateDescriptorSetLayout(descriptorSetLayoutBindings, &descriptorSetLayout));
235 
236 	VkPipelineLayout pipelineLayout;
237 	VK_ASSERT(device->CreatePipelineLayout(descriptorSetLayout, &pipelineLayout));
238 
239 	VkPipeline pipeline;
240 	VK_ASSERT(device->CreateComputePipeline(shaderModule, pipelineLayout, &pipeline));
241 
242 	VkDescriptorPool descriptorPool;
243 	VK_ASSERT(device->CreateStorageBufferDescriptorPool(2, &descriptorPool));
244 
245 	VkDescriptorSet descriptorSet;
246 	VK_ASSERT(device->AllocateDescriptorSet(descriptorPool, descriptorSetLayout, &descriptorSet));
247 
248 	std::vector<VkDescriptorBufferInfo> descriptorBufferInfos = {
249 		{
250 		    bufferIn,       // buffer
251 		    0,              // offset
252 		    VK_WHOLE_SIZE,  // range
253 		},
254 		{
255 		    bufferOut,      // buffer
256 		    0,              // offset
257 		    VK_WHOLE_SIZE,  // range
258 		}
259 	};
260 	device->UpdateStorageBufferDescriptorSets(descriptorSet, descriptorBufferInfos);
261 
262 	VkCommandPool commandPool;
263 	VK_ASSERT(device->CreateCommandPool(&commandPool));
264 
265 	VkCommandBuffer commandBuffer;
266 	VK_ASSERT(device->AllocateCommandBuffer(commandPool, &commandBuffer));
267 
268 	VK_ASSERT(device->BeginCommandBuffer(VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT, commandBuffer));
269 
270 	driver.vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
271 
272 	driver.vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, 1, &descriptorSet,
273 	                               0, nullptr);
274 
275 	driver.vkCmdDispatch(commandBuffer, (uint32_t)(numElements / GetParam().localSizeX), 1, 1);
276 
277 	VK_ASSERT(driver.vkEndCommandBuffer(commandBuffer));
278 
279 	VK_ASSERT(device->QueueSubmitAndWait(commandBuffer));
280 
281 	VK_ASSERT(device->MapMemory(memory, 0, buffersSize, 0, (void **)&buffers));
282 
283 	for(size_t i = 0; i < numElements; ++i)
284 	{
285 		auto got = buffers[i + outOffset];
286 		EXPECT_EQ(expected((uint32_t)i), got) << "Unexpected output at " << i;
287 	}
288 
289 	// Check for writes outside of bounds.
290 	EXPECT_EQ(buffers[magic0Offset], magic0);
291 	EXPECT_EQ(buffers[magic1Offset], magic1);
292 	EXPECT_EQ(buffers[magic2Offset], magic2);
293 	EXPECT_EQ(buffers[magic3Offset], magic3);
294 
295 	device->UnmapMemory(memory);
296 	buffers = nullptr;
297 
298 	device->FreeCommandBuffer(commandPool, commandBuffer);
299 	device->FreeMemory(memory);
300 	device->DestroyPipeline(pipeline);
301 	device->DestroyCommandPool(commandPool);
302 	device->DestroyPipelineLayout(pipelineLayout);
303 	device->DestroyDescriptorSetLayout(descriptorSetLayout);
304 	device->DestroyDescriptorPool(descriptorPool);
305 	device->DestroyBuffer(bufferIn);
306 	device->DestroyBuffer(bufferOut);
307 	device->DestroyShaderModule(shaderModule);
308 	device.reset(nullptr);
309 	driver.vkDestroyInstance(instance, nullptr);
310 }
311 
312 INSTANTIATE_TEST_SUITE_P(ComputeParams, SwiftShaderVulkanBufferToBufferComputeTest, testing::Values(ComputeParams{ 512, 1, 1, 1 }, ComputeParams{ 512, 2, 1, 1 }, ComputeParams{ 512, 4, 1, 1 }, ComputeParams{ 512, 8, 1, 1 }, ComputeParams{ 512, 16, 1, 1 }, ComputeParams{ 512, 32, 1, 1 },
313 
314                                                                                                     // Non-multiple of SIMD-lane.
315                                                                                                     ComputeParams{ 3, 1, 1, 1 }, ComputeParams{ 2, 1, 1, 1 }));
316 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,Memcpy)317 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, Memcpy)
318 {
319 	std::stringstream src;
320 	// #version 450
321 	// layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
322 	// layout(binding = 0, std430) buffer InBuffer
323 	// {
324 	//     int Data[];
325 	// } In;
326 	// layout(binding = 1, std430) buffer OutBuffer
327 	// {
328 	//     int Data[];
329 	// } Out;
330 	// void main()
331 	// {
332 	//     Out.Data[gl_GlobalInvocationID.x] = In.Data[gl_GlobalInvocationID.x];
333 	// }
334 	// clang-format off
335     src <<
336         "OpCapability Shader\n"
337         "OpMemoryModel Logical GLSL450\n"
338         "OpEntryPoint GLCompute %1 \"main\" %2\n"
339         "OpExecutionMode %1 LocalSize " <<
340         GetParam().localSizeX << " " <<
341         GetParam().localSizeY << " " <<
342         GetParam().localSizeZ << "\n" <<
343         "OpDecorate %3 ArrayStride 4\n"
344         "OpMemberDecorate %4 0 Offset 0\n"
345         "OpDecorate %4 BufferBlock\n"
346         "OpDecorate %5 DescriptorSet 0\n"
347         "OpDecorate %5 Binding 1\n"
348         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
349         "OpDecorate %6 DescriptorSet 0\n"
350         "OpDecorate %6 Binding 0\n"
351         "%7 = OpTypeVoid\n"
352         "%8 = OpTypeFunction %7\n"             // void()
353         "%9 = OpTypeInt 32 1\n"                // int32
354         "%10 = OpTypeInt 32 0\n"                // uint32
355         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
356         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
357         "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
358         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
359         "%12 = OpConstant %9 0\n"               // int32(0)
360         "%13 = OpConstant %10 0\n"              // uint32(0)
361         "%14 = OpTypeVector %10 3\n"            // vec3<int32>
362         "%15 = OpTypePointer Input %14\n"       // vec3<int32>*
363         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId
364         "%16 = OpTypePointer Input %10\n"       // uint32*
365         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
366         "%17 = OpTypePointer Uniform %9\n"      // int32*
367         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
368         "%18 = OpLabel\n"
369         "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x
370         "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x
371         "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]
372         "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]
373         "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]
374         "OpStore %23 %22\n"               // out.arr[gl_GlobalInvocationId.x] = in[gl_GlobalInvocationId.x]
375         "OpReturn\n"
376         "OpFunctionEnd\n";
377 	// clang-format on
378 
379 	test(
380 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
381 }
382 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,GlobalInvocationId)383 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, GlobalInvocationId)
384 {
385 	std::stringstream src;
386 	// clang-format off
387     src <<
388         "OpCapability Shader\n"
389         "OpMemoryModel Logical GLSL450\n"
390         "OpEntryPoint GLCompute %1 \"main\" %2\n"
391         "OpExecutionMode %1 LocalSize " <<
392         GetParam().localSizeX << " " <<
393         GetParam().localSizeY << " " <<
394         GetParam().localSizeZ << "\n" <<
395         "OpDecorate %3 ArrayStride 4\n"
396         "OpMemberDecorate %4 0 Offset 0\n"
397         "OpDecorate %4 BufferBlock\n"
398         "OpDecorate %5 DescriptorSet 0\n"
399         "OpDecorate %5 Binding 1\n"
400         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
401         "OpDecorate %6 DescriptorSet 0\n"
402         "OpDecorate %6 Binding 0\n"
403         "%7 = OpTypeVoid\n"
404         "%8 = OpTypeFunction %7\n"             // void()
405         "%9 = OpTypeInt 32 1\n"                // int32
406         "%10 = OpTypeInt 32 0\n"                // uint32
407         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
408         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
409         "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
410         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
411         "%12 = OpConstant %9 0\n"               // int32(0)
412         "%13 = OpConstant %9 1\n"               // int32(1)
413         "%14 = OpConstant %10 0\n"              // uint32(0)
414         "%15 = OpConstant %10 1\n"              // uint32(1)
415         "%16 = OpConstant %10 2\n"              // uint32(2)
416         "%17 = OpTypeVector %10 3\n"            // vec3<int32>
417         "%18 = OpTypePointer Input %17\n"       // vec3<int32>*
418         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
419         "%19 = OpTypePointer Input %10\n"       // uint32*
420         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
421         "%20 = OpTypePointer Uniform %9\n"      // int32*
422         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
423         "%21 = OpLabel\n"
424         "%22 = OpAccessChain %19 %2 %14\n"      // &gl_GlobalInvocationId.x
425         "%23 = OpAccessChain %19 %2 %15\n"      // &gl_GlobalInvocationId.y
426         "%24 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.z
427         "%25 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
428         "%26 = OpLoad %10 %23\n"                // gl_GlobalInvocationId.y
429         "%27 = OpLoad %10 %24\n"                // gl_GlobalInvocationId.z
430         "%28 = OpAccessChain %20 %6 %12 %25\n"  // &in.arr[gl_GlobalInvocationId.x]
431         "%29 = OpLoad %9 %28\n"                 // out.arr[gl_GlobalInvocationId.x]
432         "%30 = OpIAdd %9 %29 %26\n"             // in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y
433         "%31 = OpIAdd %9 %30 %27\n"             // in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y + gl_GlobalInvocationId.z
434         "%32 = OpAccessChain %20 %5 %12 %25\n"  // &out.arr[gl_GlobalInvocationId.x]
435         "OpStore %32 %31\n"               // out.arr[gl_GlobalInvocationId.x] = in[gl_GlobalInvocationId.x] + gl_GlobalInvocationId.y + gl_GlobalInvocationId.z
436         "OpReturn\n"
437         "OpFunctionEnd\n";
438 	// clang-format on
439 
440 	// gl_GlobalInvocationId.y and gl_GlobalInvocationId.z should both be zero.
441 	test(
442 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
443 }
444 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchSimple)445 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchSimple)
446 {
447 	std::stringstream src;
448 	// clang-format off
449     src <<
450         "OpCapability Shader\n"
451         "OpMemoryModel Logical GLSL450\n"
452         "OpEntryPoint GLCompute %1 \"main\" %2\n"
453         "OpExecutionMode %1 LocalSize " <<
454         GetParam().localSizeX << " " <<
455         GetParam().localSizeY << " " <<
456         GetParam().localSizeZ << "\n" <<
457         "OpDecorate %3 ArrayStride 4\n"
458         "OpMemberDecorate %4 0 Offset 0\n"
459         "OpDecorate %4 BufferBlock\n"
460         "OpDecorate %5 DescriptorSet 0\n"
461         "OpDecorate %5 Binding 1\n"
462         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
463         "OpDecorate %6 DescriptorSet 0\n"
464         "OpDecorate %6 Binding 0\n"
465         "%7 = OpTypeVoid\n"
466         "%8 = OpTypeFunction %7\n"             // void()
467         "%9 = OpTypeInt 32 1\n"                // int32
468         "%10 = OpTypeInt 32 0\n"                // uint32
469         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
470         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
471         "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
472         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
473         "%12 = OpConstant %9 0\n"               // int32(0)
474         "%13 = OpConstant %10 0\n"              // uint32(0)
475         "%14 = OpTypeVector %10 3\n"            // vec3<int32>
476         "%15 = OpTypePointer Input %14\n"       // vec3<int32>*
477         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId
478         "%16 = OpTypePointer Input %10\n"       // uint32*
479         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
480         "%17 = OpTypePointer Uniform %9\n"      // int32*
481         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
482         "%18 = OpLabel\n"
483         "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x
484         "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x
485         "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]
486         "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]
487         "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]
488                                                 // Start of branch logic
489                                                 // %22 = in value
490         "OpBranch %24\n"
491         "%24 = OpLabel\n"
492         "OpBranch %25\n"
493         "%25 = OpLabel\n"
494         "OpBranch %26\n"
495         "%26 = OpLabel\n"
496         // %22 = out value
497         // End of branch logic
498         "OpStore %23 %22\n"
499         "OpReturn\n"
500         "OpFunctionEnd\n";
501 	// clang-format on
502 
503 	test(
504 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
505 }
506 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchDeclareSSA)507 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchDeclareSSA)
508 {
509 	std::stringstream src;
510 	// clang-format off
511     src <<
512         "OpCapability Shader\n"
513         "OpMemoryModel Logical GLSL450\n"
514         "OpEntryPoint GLCompute %1 \"main\" %2\n"
515         "OpExecutionMode %1 LocalSize " <<
516         GetParam().localSizeX << " " <<
517         GetParam().localSizeY << " " <<
518         GetParam().localSizeZ << "\n" <<
519         "OpDecorate %3 ArrayStride 4\n"
520         "OpMemberDecorate %4 0 Offset 0\n"
521         "OpDecorate %4 BufferBlock\n"
522         "OpDecorate %5 DescriptorSet 0\n"
523         "OpDecorate %5 Binding 1\n"
524         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
525         "OpDecorate %6 DescriptorSet 0\n"
526         "OpDecorate %6 Binding 0\n"
527         "%7 = OpTypeVoid\n"
528         "%8 = OpTypeFunction %7\n"             // void()
529         "%9 = OpTypeInt 32 1\n"                // int32
530         "%10 = OpTypeInt 32 0\n"                // uint32
531         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
532         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
533         "%11 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
534         "%5 = OpVariable %11 Uniform\n"        // struct{ int32[] }* in
535         "%12 = OpConstant %9 0\n"               // int32(0)
536         "%13 = OpConstant %10 0\n"              // uint32(0)
537         "%14 = OpTypeVector %10 3\n"            // vec3<int32>
538         "%15 = OpTypePointer Input %14\n"       // vec3<int32>*
539         "%2 = OpVariable %15 Input\n"          // gl_GlobalInvocationId
540         "%16 = OpTypePointer Input %10\n"       // uint32*
541         "%6 = OpVariable %11 Uniform\n"        // struct{ int32[] }* out
542         "%17 = OpTypePointer Uniform %9\n"      // int32*
543         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
544         "%18 = OpLabel\n"
545         "%19 = OpAccessChain %16 %2 %13\n"      // &gl_GlobalInvocationId.x
546         "%20 = OpLoad %10 %19\n"                // gl_GlobalInvocationId.x
547         "%21 = OpAccessChain %17 %6 %12 %20\n"  // &in.arr[gl_GlobalInvocationId.x]
548         "%22 = OpLoad %9 %21\n"                 // in.arr[gl_GlobalInvocationId.x]
549         "%23 = OpAccessChain %17 %5 %12 %20\n"  // &out.arr[gl_GlobalInvocationId.x]
550                                                 // Start of branch logic
551                                                 // %22 = in value
552         "OpBranch %24\n"
553         "%24 = OpLabel\n"
554         "%25 = OpIAdd %9 %22 %22\n"             // %25 = in*2
555         "OpBranch %26\n"
556         "%26 = OpLabel\n"
557         "OpBranch %27\n"
558         "%27 = OpLabel\n"
559         // %25 = out value
560         // End of branch logic
561         "OpStore %23 %25\n"               // use SSA value from previous block
562         "OpReturn\n"
563         "OpFunctionEnd\n";
564 	// clang-format on
565 
566 	test(
567 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i * 2; });
568 }
569 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchConditionalSimple)570 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalSimple)
571 {
572 	std::stringstream src;
573 	// clang-format off
574     src <<
575         "OpCapability Shader\n"
576         "OpMemoryModel Logical GLSL450\n"
577         "OpEntryPoint GLCompute %1 \"main\" %2\n"
578         "OpExecutionMode %1 LocalSize " <<
579         GetParam().localSizeX << " " <<
580         GetParam().localSizeY << " " <<
581         GetParam().localSizeZ << "\n" <<
582         "OpDecorate %3 ArrayStride 4\n"
583         "OpMemberDecorate %4 0 Offset 0\n"
584         "OpDecorate %4 BufferBlock\n"
585         "OpDecorate %5 DescriptorSet 0\n"
586         "OpDecorate %5 Binding 1\n"
587         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
588         "OpDecorate %6 DescriptorSet 0\n"
589         "OpDecorate %6 Binding 0\n"
590         "%7 = OpTypeVoid\n"
591         "%8 = OpTypeFunction %7\n"             // void()
592         "%9 = OpTypeInt 32 1\n"                // int32
593         "%10 = OpTypeInt 32 0\n"                // uint32
594         "%11 = OpTypeBool\n"
595         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
596         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
597         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
598         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
599         "%13 = OpConstant %9 0\n"               // int32(0)
600         "%14 = OpConstant %9 2\n"               // int32(2)
601         "%15 = OpConstant %10 0\n"              // uint32(0)
602         "%16 = OpTypeVector %10 3\n"            // vec4<int32>
603         "%17 = OpTypePointer Input %16\n"       // vec4<int32>*
604         "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId
605         "%18 = OpTypePointer Input %10\n"       // uint32*
606         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
607         "%19 = OpTypePointer Uniform %9\n"      // int32*
608         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
609         "%20 = OpLabel\n"
610         "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x
611         "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x
612         "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]
613         "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]
614         "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]
615                                                 // Start of branch logic
616                                                 // %24 = in value
617         "%26 = OpSMod %9 %24 %14\n"             // in % 2
618         "%27 = OpIEqual %11 %26 %13\n"          // (in % 2) == 0
619         "OpSelectionMerge %28 None\n"
620         "OpBranchConditional %27 %28 %28\n" // Both go to %28
621         "%28 = OpLabel\n"
622         // %26 = out value
623         // End of branch logic
624         "OpStore %25 %26\n"               // use SSA value from previous block
625         "OpReturn\n"
626         "OpFunctionEnd\n";
627 	// clang-format on
628 
629 	test(
630 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });
631 }
632 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchConditionalTwoEmptyBlocks)633 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalTwoEmptyBlocks)
634 {
635 	std::stringstream src;
636 	// clang-format off
637     src <<
638         "OpCapability Shader\n"
639         "OpMemoryModel Logical GLSL450\n"
640         "OpEntryPoint GLCompute %1 \"main\" %2\n"
641         "OpExecutionMode %1 LocalSize " <<
642         GetParam().localSizeX << " " <<
643         GetParam().localSizeY << " " <<
644         GetParam().localSizeZ << "\n" <<
645         "OpDecorate %3 ArrayStride 4\n"
646         "OpMemberDecorate %4 0 Offset 0\n"
647         "OpDecorate %4 BufferBlock\n"
648         "OpDecorate %5 DescriptorSet 0\n"
649         "OpDecorate %5 Binding 1\n"
650         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
651         "OpDecorate %6 DescriptorSet 0\n"
652         "OpDecorate %6 Binding 0\n"
653         "%7 = OpTypeVoid\n"
654         "%8 = OpTypeFunction %7\n"             // void()
655         "%9 = OpTypeInt 32 1\n"                // int32
656         "%10 = OpTypeInt 32 0\n"                // uint32
657         "%11 = OpTypeBool\n"
658         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
659         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
660         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
661         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
662         "%13 = OpConstant %9 0\n"               // int32(0)
663         "%14 = OpConstant %9 2\n"               // int32(2)
664         "%15 = OpConstant %10 0\n"              // uint32(0)
665         "%16 = OpTypeVector %10 3\n"            // vec4<int32>
666         "%17 = OpTypePointer Input %16\n"       // vec4<int32>*
667         "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId
668         "%18 = OpTypePointer Input %10\n"       // uint32*
669         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
670         "%19 = OpTypePointer Uniform %9\n"      // int32*
671         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
672         "%20 = OpLabel\n"
673         "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x
674         "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x
675         "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]
676         "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]
677         "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]
678                                                 // Start of branch logic
679                                                 // %24 = in value
680         "%26 = OpSMod %9 %24 %14\n"             // in % 2
681         "%27 = OpIEqual %11 %26 %13\n"          // (in % 2) == 0
682         "OpSelectionMerge %28 None\n"
683         "OpBranchConditional %27 %29 %30\n"
684         "%29 = OpLabel\n"                       // (in % 2) == 0
685         "OpBranch %28\n"
686         "%30 = OpLabel\n"                       // (in % 2) != 0
687         "OpBranch %28\n"
688         "%28 = OpLabel\n"
689         // %26 = out value
690         // End of branch logic
691         "OpStore %25 %26\n"               // use SSA value from previous block
692         "OpReturn\n"
693         "OpFunctionEnd\n";
694 	// clang-format on
695 
696 	test(
697 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });
698 }
699 
700 // TODO: Test for parallel assignment
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchConditionalStore)701 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalStore)
702 {
703 	std::stringstream src;
704 	// clang-format off
705     src <<
706         "OpCapability Shader\n"
707         "OpMemoryModel Logical GLSL450\n"
708         "OpEntryPoint GLCompute %1 \"main\" %2\n"
709         "OpExecutionMode %1 LocalSize " <<
710         GetParam().localSizeX << " " <<
711         GetParam().localSizeY << " " <<
712         GetParam().localSizeZ << "\n" <<
713         "OpDecorate %3 ArrayStride 4\n"
714         "OpMemberDecorate %4 0 Offset 0\n"
715         "OpDecorate %4 BufferBlock\n"
716         "OpDecorate %5 DescriptorSet 0\n"
717         "OpDecorate %5 Binding 1\n"
718         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
719         "OpDecorate %6 DescriptorSet 0\n"
720         "OpDecorate %6 Binding 0\n"
721         "%7 = OpTypeVoid\n"
722         "%8 = OpTypeFunction %7\n"             // void()
723         "%9 = OpTypeInt 32 1\n"                // int32
724         "%10 = OpTypeInt 32 0\n"                // uint32
725         "%11 = OpTypeBool\n"
726         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
727         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
728         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
729         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
730         "%13 = OpConstant %9 0\n"               // int32(0)
731         "%14 = OpConstant %9 1\n"               // int32(1)
732         "%15 = OpConstant %9 2\n"               // int32(2)
733         "%16 = OpConstant %10 0\n"              // uint32(0)
734         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
735         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
736         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
737         "%19 = OpTypePointer Input %10\n"       // uint32*
738         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
739         "%20 = OpTypePointer Uniform %9\n"      // int32*
740         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
741         "%21 = OpLabel\n"
742         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
743         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
744         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
745         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
746         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
747                                                 // Start of branch logic
748                                                 // %25 = in value
749         "%27 = OpSMod %9 %25 %15\n"             // in % 2
750         "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0
751         "OpSelectionMerge %29 None\n"
752         "OpBranchConditional %28 %30 %31\n"
753         "%30 = OpLabel\n"                       // (in % 2) == 0
754         "OpStore %26 %14\n"               // write 1
755         "OpBranch %29\n"
756         "%31 = OpLabel\n"                       // (in % 2) != 0
757         "OpStore %26 %15\n"               // write 2
758         "OpBranch %29\n"
759         "%29 = OpLabel\n"
760         // End of branch logic
761         "OpReturn\n"
762         "OpFunctionEnd\n";
763 	// clang-format on
764 
765 	test(
766 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 1 : 2; });
767 }
768 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchConditionalReturnTrue)769 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalReturnTrue)
770 {
771 	std::stringstream src;
772 	// clang-format off
773     src <<
774         "OpCapability Shader\n"
775         "OpMemoryModel Logical GLSL450\n"
776         "OpEntryPoint GLCompute %1 \"main\" %2\n"
777         "OpExecutionMode %1 LocalSize " <<
778         GetParam().localSizeX << " " <<
779         GetParam().localSizeY << " " <<
780         GetParam().localSizeZ << "\n" <<
781         "OpDecorate %3 ArrayStride 4\n"
782         "OpMemberDecorate %4 0 Offset 0\n"
783         "OpDecorate %4 BufferBlock\n"
784         "OpDecorate %5 DescriptorSet 0\n"
785         "OpDecorate %5 Binding 1\n"
786         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
787         "OpDecorate %6 DescriptorSet 0\n"
788         "OpDecorate %6 Binding 0\n"
789         "%7 = OpTypeVoid\n"
790         "%8 = OpTypeFunction %7\n"             // void()
791         "%9 = OpTypeInt 32 1\n"                // int32
792         "%10 = OpTypeInt 32 0\n"                // uint32
793         "%11 = OpTypeBool\n"
794         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
795         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
796         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
797         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
798         "%13 = OpConstant %9 0\n"               // int32(0)
799         "%14 = OpConstant %9 1\n"               // int32(1)
800         "%15 = OpConstant %9 2\n"               // int32(2)
801         "%16 = OpConstant %10 0\n"              // uint32(0)
802         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
803         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
804         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
805         "%19 = OpTypePointer Input %10\n"       // uint32*
806         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
807         "%20 = OpTypePointer Uniform %9\n"      // int32*
808         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
809         "%21 = OpLabel\n"
810         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
811         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
812         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
813         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
814         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
815                                                 // Start of branch logic
816                                                 // %25 = in value
817         "%27 = OpSMod %9 %25 %15\n"             // in % 2
818         "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0
819         "OpSelectionMerge %29 None\n"
820         "OpBranchConditional %28 %30 %29\n"
821         "%30 = OpLabel\n"                       // (in % 2) == 0
822         "OpReturn\n"
823         "%29 = OpLabel\n"                       // merge
824         "OpStore %26 %15\n"               // write 2
825                                           // End of branch logic
826         "OpReturn\n"
827         "OpFunctionEnd\n";
828 	// clang-format on
829 
830 	test(
831 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 0 : 2; });
832 }
833 
834 // TODO: Test for parallel assignment
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,BranchConditionalPhi)835 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchConditionalPhi)
836 {
837 	std::stringstream src;
838 	// clang-format off
839     src <<
840         "OpCapability Shader\n"
841         "OpMemoryModel Logical GLSL450\n"
842         "OpEntryPoint GLCompute %1 \"main\" %2\n"
843         "OpExecutionMode %1 LocalSize " <<
844         GetParam().localSizeX << " " <<
845         GetParam().localSizeY << " " <<
846         GetParam().localSizeZ << "\n" <<
847         "OpDecorate %3 ArrayStride 4\n"
848         "OpMemberDecorate %4 0 Offset 0\n"
849         "OpDecorate %4 BufferBlock\n"
850         "OpDecorate %5 DescriptorSet 0\n"
851         "OpDecorate %5 Binding 1\n"
852         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
853         "OpDecorate %6 DescriptorSet 0\n"
854         "OpDecorate %6 Binding 0\n"
855         "%7 = OpTypeVoid\n"
856         "%8 = OpTypeFunction %7\n"             // void()
857         "%9 = OpTypeInt 32 1\n"                // int32
858         "%10 = OpTypeInt 32 0\n"                // uint32
859         "%11 = OpTypeBool\n"
860         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
861         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
862         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
863         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
864         "%13 = OpConstant %9 0\n"               // int32(0)
865         "%14 = OpConstant %9 1\n"               // int32(1)
866         "%15 = OpConstant %9 2\n"               // int32(2)
867         "%16 = OpConstant %10 0\n"              // uint32(0)
868         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
869         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
870         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
871         "%19 = OpTypePointer Input %10\n"       // uint32*
872         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
873         "%20 = OpTypePointer Uniform %9\n"      // int32*
874         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
875         "%21 = OpLabel\n"
876         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
877         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
878         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
879         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
880         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
881                                                 // Start of branch logic
882                                                 // %25 = in value
883         "%27 = OpSMod %9 %25 %15\n"             // in % 2
884         "%28 = OpIEqual %11 %27 %13\n"          // (in % 2) == 0
885         "OpSelectionMerge %29 None\n"
886         "OpBranchConditional %28 %30 %31\n"
887         "%30 = OpLabel\n"                       // (in % 2) == 0
888         "OpBranch %29\n"
889         "%31 = OpLabel\n"                       // (in % 2) != 0
890         "OpBranch %29\n"
891         "%29 = OpLabel\n"
892         "%32 = OpPhi %9 %14 %30 %15 %31\n"      // (in % 2) == 0 ? 1 : 2
893                                                 // End of branch logic
894         "OpStore %26 %32\n"
895         "OpReturn\n"
896         "OpFunctionEnd\n";
897 	// clang-format on
898 
899 	test(
900 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 1 : 2; });
901 }
902 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchEmptyCases)903 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchEmptyCases)
904 {
905 	std::stringstream src;
906 	// clang-format off
907     src <<
908         "OpCapability Shader\n"
909         "OpMemoryModel Logical GLSL450\n"
910         "OpEntryPoint GLCompute %1 \"main\" %2\n"
911         "OpExecutionMode %1 LocalSize " <<
912         GetParam().localSizeX << " " <<
913         GetParam().localSizeY << " " <<
914         GetParam().localSizeZ << "\n" <<
915         "OpDecorate %3 ArrayStride 4\n"
916         "OpMemberDecorate %4 0 Offset 0\n"
917         "OpDecorate %4 BufferBlock\n"
918         "OpDecorate %5 DescriptorSet 0\n"
919         "OpDecorate %5 Binding 1\n"
920         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
921         "OpDecorate %6 DescriptorSet 0\n"
922         "OpDecorate %6 Binding 0\n"
923         "%7 = OpTypeVoid\n"
924         "%8 = OpTypeFunction %7\n"             // void()
925         "%9 = OpTypeInt 32 1\n"                // int32
926         "%10 = OpTypeInt 32 0\n"                // uint32
927         "%11 = OpTypeBool\n"
928         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
929         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
930         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
931         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
932         "%13 = OpConstant %9 0\n"               // int32(0)
933         "%14 = OpConstant %9 2\n"               // int32(2)
934         "%15 = OpConstant %10 0\n"              // uint32(0)
935         "%16 = OpTypeVector %10 3\n"            // vec4<int32>
936         "%17 = OpTypePointer Input %16\n"       // vec4<int32>*
937         "%2 = OpVariable %17 Input\n"          // gl_GlobalInvocationId
938         "%18 = OpTypePointer Input %10\n"       // uint32*
939         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
940         "%19 = OpTypePointer Uniform %9\n"      // int32*
941         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
942         "%20 = OpLabel\n"
943         "%21 = OpAccessChain %18 %2 %15\n"      // &gl_GlobalInvocationId.x
944         "%22 = OpLoad %10 %21\n"                // gl_GlobalInvocationId.x
945         "%23 = OpAccessChain %19 %6 %13 %22\n"  // &in.arr[gl_GlobalInvocationId.x]
946         "%24 = OpLoad %9 %23\n"                 // in.arr[gl_GlobalInvocationId.x]
947         "%25 = OpAccessChain %19 %5 %13 %22\n"  // &out.arr[gl_GlobalInvocationId.x]
948                                                 // Start of branch logic
949                                                 // %24 = in value
950         "%26 = OpSMod %9 %24 %14\n"             // in % 2
951         "OpSelectionMerge %27 None\n"
952         "OpSwitch %26 %27 0 %28 1 %29\n"
953         "%28 = OpLabel\n"                       // (in % 2) == 0
954         "OpBranch %27\n"
955         "%29 = OpLabel\n"                       // (in % 2) == 1
956         "OpBranch %27\n"
957         "%27 = OpLabel\n"
958         // %26 = out value
959         // End of branch logic
960         "OpStore %25 %26\n"               // use SSA value from previous block
961         "OpReturn\n"
962         "OpFunctionEnd\n";
963 	// clang-format on
964 
965 	test(
966 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i % 2; });
967 }
968 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchStore)969 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchStore)
970 {
971 	std::stringstream src;
972 	// clang-format off
973     src <<
974         "OpCapability Shader\n"
975         "OpMemoryModel Logical GLSL450\n"
976         "OpEntryPoint GLCompute %1 \"main\" %2\n"
977         "OpExecutionMode %1 LocalSize " <<
978         GetParam().localSizeX << " " <<
979         GetParam().localSizeY << " " <<
980         GetParam().localSizeZ << "\n" <<
981         "OpDecorate %3 ArrayStride 4\n"
982         "OpMemberDecorate %4 0 Offset 0\n"
983         "OpDecorate %4 BufferBlock\n"
984         "OpDecorate %5 DescriptorSet 0\n"
985         "OpDecorate %5 Binding 1\n"
986         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
987         "OpDecorate %6 DescriptorSet 0\n"
988         "OpDecorate %6 Binding 0\n"
989         "%7 = OpTypeVoid\n"
990         "%8 = OpTypeFunction %7\n"             // void()
991         "%9 = OpTypeInt 32 1\n"                // int32
992         "%10 = OpTypeInt 32 0\n"                // uint32
993         "%11 = OpTypeBool\n"
994         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
995         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
996         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
997         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
998         "%13 = OpConstant %9 0\n"               // int32(0)
999         "%14 = OpConstant %9 1\n"               // int32(1)
1000         "%15 = OpConstant %9 2\n"               // int32(2)
1001         "%16 = OpConstant %10 0\n"              // uint32(0)
1002         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1003         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1004         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1005         "%19 = OpTypePointer Input %10\n"       // uint32*
1006         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1007         "%20 = OpTypePointer Uniform %9\n"      // int32*
1008         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1009         "%21 = OpLabel\n"
1010         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1011         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1012         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1013         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1014         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1015                                                 // Start of branch logic
1016                                                 // %25 = in value
1017         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1018         "OpSelectionMerge %28 None\n"
1019         "OpSwitch %27 %28 0 %29 1 %30\n"
1020         "%29 = OpLabel\n"                       // (in % 2) == 0
1021         "OpStore %26 %15\n"               // write 2
1022         "OpBranch %28\n"
1023         "%30 = OpLabel\n"                       // (in % 2) == 1
1024         "OpStore %26 %14\n"               // write 1
1025         "OpBranch %28\n"
1026         "%28 = OpLabel\n"
1027         // End of branch logic
1028         "OpReturn\n"
1029         "OpFunctionEnd\n";
1030 	// clang-format on
1031 
1032 	test(
1033 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 0 ? 2 : 1; });
1034 }
1035 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchCaseReturn)1036 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseReturn)
1037 {
1038 	std::stringstream src;
1039 	// clang-format off
1040     src <<
1041         "OpCapability Shader\n"
1042         "OpMemoryModel Logical GLSL450\n"
1043         "OpEntryPoint GLCompute %1 \"main\" %2\n"
1044         "OpExecutionMode %1 LocalSize " <<
1045         GetParam().localSizeX << " " <<
1046         GetParam().localSizeY << " " <<
1047         GetParam().localSizeZ << "\n" <<
1048         "OpDecorate %3 ArrayStride 4\n"
1049         "OpMemberDecorate %4 0 Offset 0\n"
1050         "OpDecorate %4 BufferBlock\n"
1051         "OpDecorate %5 DescriptorSet 0\n"
1052         "OpDecorate %5 Binding 1\n"
1053         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
1054         "OpDecorate %6 DescriptorSet 0\n"
1055         "OpDecorate %6 Binding 0\n"
1056         "%7 = OpTypeVoid\n"
1057         "%8 = OpTypeFunction %7\n"             // void()
1058         "%9 = OpTypeInt 32 1\n"                // int32
1059         "%10 = OpTypeInt 32 0\n"                // uint32
1060         "%11 = OpTypeBool\n"
1061         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
1062         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
1063         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
1064         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
1065         "%13 = OpConstant %9 0\n"               // int32(0)
1066         "%14 = OpConstant %9 1\n"               // int32(1)
1067         "%15 = OpConstant %9 2\n"               // int32(2)
1068         "%16 = OpConstant %10 0\n"              // uint32(0)
1069         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1070         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1071         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1072         "%19 = OpTypePointer Input %10\n"       // uint32*
1073         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1074         "%20 = OpTypePointer Uniform %9\n"      // int32*
1075         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1076         "%21 = OpLabel\n"
1077         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1078         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1079         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1080         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1081         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1082                                                 // Start of branch logic
1083                                                 // %25 = in value
1084         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1085         "OpSelectionMerge %28 None\n"
1086         "OpSwitch %27 %28 0 %29 1 %30\n"
1087         "%29 = OpLabel\n"                       // (in % 2) == 0
1088         "OpBranch %28\n"
1089         "%30 = OpLabel\n"                       // (in % 2) == 1
1090         "OpReturn\n"
1091         "%28 = OpLabel\n"
1092         "OpStore %26 %14\n"               // write 1
1093                                           // End of branch logic
1094         "OpReturn\n"
1095         "OpFunctionEnd\n";
1096 	// clang-format on
1097 
1098 	test(
1099 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 0 : 1; });
1100 }
1101 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchDefaultReturn)1102 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultReturn)
1103 {
1104 	std::stringstream src;
1105 	// clang-format off
1106     src <<
1107         "OpCapability Shader\n"
1108         "OpMemoryModel Logical GLSL450\n"
1109         "OpEntryPoint GLCompute %1 \"main\" %2\n"
1110         "OpExecutionMode %1 LocalSize " <<
1111         GetParam().localSizeX << " " <<
1112         GetParam().localSizeY << " " <<
1113         GetParam().localSizeZ << "\n" <<
1114         "OpDecorate %3 ArrayStride 4\n"
1115         "OpMemberDecorate %4 0 Offset 0\n"
1116         "OpDecorate %4 BufferBlock\n"
1117         "OpDecorate %5 DescriptorSet 0\n"
1118         "OpDecorate %5 Binding 1\n"
1119         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
1120         "OpDecorate %6 DescriptorSet 0\n"
1121         "OpDecorate %6 Binding 0\n"
1122         "%7 = OpTypeVoid\n"
1123         "%8 = OpTypeFunction %7\n"             // void()
1124         "%9 = OpTypeInt 32 1\n"                // int32
1125         "%10 = OpTypeInt 32 0\n"                // uint32
1126         "%11 = OpTypeBool\n"
1127         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
1128         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
1129         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
1130         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
1131         "%13 = OpConstant %9 0\n"               // int32(0)
1132         "%14 = OpConstant %9 1\n"               // int32(1)
1133         "%15 = OpConstant %9 2\n"               // int32(2)
1134         "%16 = OpConstant %10 0\n"              // uint32(0)
1135         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1136         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1137         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1138         "%19 = OpTypePointer Input %10\n"       // uint32*
1139         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1140         "%20 = OpTypePointer Uniform %9\n"      // int32*
1141         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1142         "%21 = OpLabel\n"
1143         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1144         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1145         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1146         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1147         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1148                                                 // Start of branch logic
1149                                                 // %25 = in value
1150         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1151         "OpSelectionMerge %28 None\n"
1152         "OpSwitch %27 %29 1 %30\n"
1153         "%30 = OpLabel\n"                       // (in % 2) == 1
1154         "OpBranch %28\n"
1155         "%29 = OpLabel\n"                       // (in % 2) != 1
1156         "OpReturn\n"
1157         "%28 = OpLabel\n"                       // merge
1158         "OpStore %26 %14\n"               // write 1
1159                                           // End of branch logic
1160         "OpReturn\n"
1161         "OpFunctionEnd\n";
1162 	// clang-format on
1163 
1164 	test(
1165 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 0; });
1166 }
1167 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchCaseFallthrough)1168 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchCaseFallthrough)
1169 {
1170 	std::stringstream src;
1171 	// clang-format off
1172     src <<
1173         "OpCapability Shader\n"
1174         "OpMemoryModel Logical GLSL450\n"
1175         "OpEntryPoint GLCompute %1 \"main\" %2\n"
1176         "OpExecutionMode %1 LocalSize " <<
1177         GetParam().localSizeX << " " <<
1178         GetParam().localSizeY << " " <<
1179         GetParam().localSizeZ << "\n" <<
1180         "OpDecorate %3 ArrayStride 4\n"
1181         "OpMemberDecorate %4 0 Offset 0\n"
1182         "OpDecorate %4 BufferBlock\n"
1183         "OpDecorate %5 DescriptorSet 0\n"
1184         "OpDecorate %5 Binding 1\n"
1185         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
1186         "OpDecorate %6 DescriptorSet 0\n"
1187         "OpDecorate %6 Binding 0\n"
1188         "%7 = OpTypeVoid\n"
1189         "%8 = OpTypeFunction %7\n"             // void()
1190         "%9 = OpTypeInt 32 1\n"                // int32
1191         "%10 = OpTypeInt 32 0\n"                // uint32
1192         "%11 = OpTypeBool\n"
1193         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
1194         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
1195         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
1196         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
1197         "%13 = OpConstant %9 0\n"               // int32(0)
1198         "%14 = OpConstant %9 1\n"               // int32(1)
1199         "%15 = OpConstant %9 2\n"               // int32(2)
1200         "%16 = OpConstant %10 0\n"              // uint32(0)
1201         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1202         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1203         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1204         "%19 = OpTypePointer Input %10\n"       // uint32*
1205         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1206         "%20 = OpTypePointer Uniform %9\n"      // int32*
1207         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1208         "%21 = OpLabel\n"
1209         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1210         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1211         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1212         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1213         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1214                                                 // Start of branch logic
1215                                                 // %25 = in value
1216         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1217         "OpSelectionMerge %28 None\n"
1218         "OpSwitch %27 %29 0 %30 1 %31\n"
1219         "%30 = OpLabel\n"                       // (in % 2) == 0
1220         "%32 = OpIAdd %9 %27 %14\n"             // generate an intermediate
1221         "OpStore %26 %32\n"               // write a value (overwritten later)
1222         "OpBranch %31\n"                  // fallthrough
1223         "%31 = OpLabel\n"                       // (in % 2) == 1
1224         "OpStore %26 %15\n"               // write 2
1225         "OpBranch %28\n"
1226         "%29 = OpLabel\n"                       // unreachable
1227         "OpUnreachable\n"
1228         "%28 = OpLabel\n"                       // merge
1229                                                 // End of branch logic
1230         "OpReturn\n"
1231         "OpFunctionEnd\n";
1232 	// clang-format on
1233 
1234 	test(
1235 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });
1236 }
1237 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchDefaultFallthrough)1238 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchDefaultFallthrough)
1239 {
1240 	std::stringstream src;
1241 	// clang-format off
1242     src <<
1243         "OpCapability Shader\n"
1244         "OpMemoryModel Logical GLSL450\n"
1245         "OpEntryPoint GLCompute %1 \"main\" %2\n"
1246         "OpExecutionMode %1 LocalSize " <<
1247         GetParam().localSizeX << " " <<
1248         GetParam().localSizeY << " " <<
1249         GetParam().localSizeZ << "\n" <<
1250         "OpDecorate %3 ArrayStride 4\n"
1251         "OpMemberDecorate %4 0 Offset 0\n"
1252         "OpDecorate %4 BufferBlock\n"
1253         "OpDecorate %5 DescriptorSet 0\n"
1254         "OpDecorate %5 Binding 1\n"
1255         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
1256         "OpDecorate %6 DescriptorSet 0\n"
1257         "OpDecorate %6 Binding 0\n"
1258         "%7 = OpTypeVoid\n"
1259         "%8 = OpTypeFunction %7\n"             // void()
1260         "%9 = OpTypeInt 32 1\n"                // int32
1261         "%10 = OpTypeInt 32 0\n"                // uint32
1262         "%11 = OpTypeBool\n"
1263         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
1264         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
1265         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
1266         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
1267         "%13 = OpConstant %9 0\n"               // int32(0)
1268         "%14 = OpConstant %9 1\n"               // int32(1)
1269         "%15 = OpConstant %9 2\n"               // int32(2)
1270         "%16 = OpConstant %10 0\n"              // uint32(0)
1271         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1272         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1273         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1274         "%19 = OpTypePointer Input %10\n"       // uint32*
1275         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1276         "%20 = OpTypePointer Uniform %9\n"      // int32*
1277         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1278         "%21 = OpLabel\n"
1279         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1280         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1281         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1282         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1283         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1284                                                 // Start of branch logic
1285                                                 // %25 = in value
1286         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1287         "OpSelectionMerge %28 None\n"
1288         "OpSwitch %27 %29 0 %30 1 %31\n"
1289         "%30 = OpLabel\n"                       // (in % 2) == 0
1290         "%32 = OpIAdd %9 %27 %14\n"             // generate an intermediate
1291         "OpStore %26 %32\n"               // write a value (overwritten later)
1292         "OpBranch %29\n"                  // fallthrough
1293         "%29 = OpLabel\n"                       // default
1294         "%33 = OpIAdd %9 %27 %14\n"             // generate an intermediate
1295         "OpStore %26 %33\n"               // write a value (overwritten later)
1296         "OpBranch %31\n"                  // fallthrough
1297         "%31 = OpLabel\n"                       // (in % 2) == 1
1298         "OpStore %26 %15\n"               // write 2
1299         "OpBranch %28\n"
1300         "%28 = OpLabel\n"                       // merge
1301                                                 // End of branch logic
1302         "OpReturn\n"
1303         "OpFunctionEnd\n";
1304 	// clang-format on
1305 
1306 	test(
1307 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return 2; });
1308 }
1309 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,SwitchPhi)1310 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, SwitchPhi)
1311 {
1312 	std::stringstream src;
1313 	// clang-format off
1314     src <<
1315         "OpCapability Shader\n"
1316         "OpMemoryModel Logical GLSL450\n"
1317         "OpEntryPoint GLCompute %1 \"main\" %2\n"
1318         "OpExecutionMode %1 LocalSize " <<
1319         GetParam().localSizeX << " " <<
1320         GetParam().localSizeY << " " <<
1321         GetParam().localSizeZ << "\n" <<
1322         "OpDecorate %3 ArrayStride 4\n"
1323         "OpMemberDecorate %4 0 Offset 0\n"
1324         "OpDecorate %4 BufferBlock\n"
1325         "OpDecorate %5 DescriptorSet 0\n"
1326         "OpDecorate %5 Binding 1\n"
1327         "OpDecorate %2 BuiltIn GlobalInvocationId\n"
1328         "OpDecorate %6 DescriptorSet 0\n"
1329         "OpDecorate %6 Binding 0\n"
1330         "%7 = OpTypeVoid\n"
1331         "%8 = OpTypeFunction %7\n"             // void()
1332         "%9 = OpTypeInt 32 1\n"                // int32
1333         "%10 = OpTypeInt 32 0\n"                // uint32
1334         "%11 = OpTypeBool\n"
1335         "%3 = OpTypeRuntimeArray %9\n"         // int32[]
1336         "%4 = OpTypeStruct %3\n"               // struct{ int32[] }
1337         "%12 = OpTypePointer Uniform %4\n"      // struct{ int32[] }*
1338         "%5 = OpVariable %12 Uniform\n"        // struct{ int32[] }* in
1339         "%13 = OpConstant %9 0\n"               // int32(0)
1340         "%14 = OpConstant %9 1\n"               // int32(1)
1341         "%15 = OpConstant %9 2\n"               // int32(2)
1342         "%16 = OpConstant %10 0\n"              // uint32(0)
1343         "%17 = OpTypeVector %10 3\n"            // vec4<int32>
1344         "%18 = OpTypePointer Input %17\n"       // vec4<int32>*
1345         "%2 = OpVariable %18 Input\n"          // gl_GlobalInvocationId
1346         "%19 = OpTypePointer Input %10\n"       // uint32*
1347         "%6 = OpVariable %12 Uniform\n"        // struct{ int32[] }* out
1348         "%20 = OpTypePointer Uniform %9\n"      // int32*
1349         "%1 = OpFunction %7 None %8\n"         // -- Function begin --
1350         "%21 = OpLabel\n"
1351         "%22 = OpAccessChain %19 %2 %16\n"      // &gl_GlobalInvocationId.x
1352         "%23 = OpLoad %10 %22\n"                // gl_GlobalInvocationId.x
1353         "%24 = OpAccessChain %20 %6 %13 %23\n"  // &in.arr[gl_GlobalInvocationId.x]
1354         "%25 = OpLoad %9 %24\n"                 // in.arr[gl_GlobalInvocationId.x]
1355         "%26 = OpAccessChain %20 %5 %13 %23\n"  // &out.arr[gl_GlobalInvocationId.x]
1356                                                 // Start of branch logic
1357                                                 // %25 = in value
1358         "%27 = OpSMod %9 %25 %15\n"             // in % 2
1359         "OpSelectionMerge %28 None\n"
1360         "OpSwitch %27 %29 1 %30\n"
1361         "%30 = OpLabel\n"                       // (in % 2) == 1
1362         "OpBranch %28\n"
1363         "%29 = OpLabel\n"                       // (in % 2) != 1
1364         "OpBranch %28\n"
1365         "%28 = OpLabel\n"                       // merge
1366         "%31 = OpPhi %9 %14 %30 %15 %29\n"      // (in % 2) == 1 ? 1 : 2
1367         "OpStore %26 %31\n"
1368         // End of branch logic
1369         "OpReturn\n"
1370         "OpFunctionEnd\n";
1371 	// clang-format on
1372 
1373 	test(
1374 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return (i % 2) == 1 ? 1 : 2; });
1375 }
1376 
TEST_P(SwiftShaderVulkanBufferToBufferComputeTest,LoopDivergentMergePhi)1377 TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, LoopDivergentMergePhi)
1378 {
1379 	// #version 450
1380 	// layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
1381 	// layout(binding = 0, std430) buffer InBuffer
1382 	// {
1383 	//     int Data[];
1384 	// } In;
1385 	// layout(binding = 1, std430) buffer OutBuffer
1386 	// {
1387 	//     int Data[];
1388 	// } Out;
1389 	// void main()
1390 	// {
1391 	//     int phi = 0;
1392 	//     uint lane = gl_GlobalInvocationID.x % 4;
1393 	//     for (uint i = 0; i < 4; i++)
1394 	//     {
1395 	//         if (lane == i)
1396 	//         {
1397 	//             phi = In.Data[gl_GlobalInvocationID.x];
1398 	//             break;
1399 	//         }
1400 	//     }
1401 	//     Out.Data[gl_GlobalInvocationID.x] = phi;
1402 	// }
1403 	std::stringstream src;
1404 	// clang-format off
1405     src <<
1406         "OpCapability Shader\n"
1407         "%1 = OpExtInstImport \"GLSL.std.450\"\n"
1408         "OpMemoryModel Logical GLSL450\n"
1409         "OpEntryPoint GLCompute %2 \"main\" %3\n"
1410         "OpExecutionMode %2 LocalSize " <<
1411         GetParam().localSizeX << " " <<
1412         GetParam().localSizeY << " " <<
1413         GetParam().localSizeZ << "\n" <<
1414         "OpDecorate %3 BuiltIn GlobalInvocationId\n"
1415         "OpDecorate %4 ArrayStride 4\n"
1416         "OpMemberDecorate %5 0 Offset 0\n"
1417         "OpDecorate %5 BufferBlock\n"
1418         "OpDecorate %6 DescriptorSet 0\n"
1419         "OpDecorate %6 Binding 0\n"
1420         "OpDecorate %7 ArrayStride 4\n"
1421         "OpMemberDecorate %8 0 Offset 0\n"
1422         "OpDecorate %8 BufferBlock\n"
1423         "OpDecorate %9 DescriptorSet 0\n"
1424         "OpDecorate %9 Binding 1\n"
1425         "%10 = OpTypeVoid\n"
1426         "%11 = OpTypeFunction %10\n"
1427         "%12 = OpTypeInt 32 1\n"
1428         "%13 = OpConstant %12 0\n"
1429         "%14 = OpTypeInt 32 0\n"
1430         "%15 = OpTypeVector %14 3\n"
1431         "%16 = OpTypePointer Input %15\n"
1432         "%3 = OpVariable %16 Input\n"
1433         "%17 = OpConstant %14 0\n"
1434         "%18 = OpTypePointer Input %14\n"
1435         "%19 = OpConstant %14 4\n"
1436         "%20 = OpTypeBool\n"
1437         "%4 = OpTypeRuntimeArray %12\n"
1438         "%5 = OpTypeStruct %4\n"
1439         "%21 = OpTypePointer Uniform %5\n"
1440         "%6 = OpVariable %21 Uniform\n"
1441         "%22 = OpTypePointer Uniform %12\n"
1442         "%23 = OpConstant %12 1\n"
1443         "%7 = OpTypeRuntimeArray %12\n"
1444         "%8 = OpTypeStruct %7\n"
1445         "%24 = OpTypePointer Uniform %8\n"
1446         "%9 = OpVariable %24 Uniform\n"
1447         "%2 = OpFunction %10 None %11\n"
1448         "%25 = OpLabel\n"
1449         "%26 = OpAccessChain %18 %3 %17\n"
1450         "%27 = OpLoad %14 %26\n"
1451         "%28 = OpUMod %14 %27 %19\n"
1452         "OpBranch %29\n"
1453         "%29 = OpLabel\n"
1454         "%30 = OpPhi %14 %17 %25 %31 %32\n"
1455         "%33 = OpULessThan %20 %30 %19\n"
1456         "OpLoopMerge %34 %32 None\n"
1457         "OpBranchConditional %33 %35 %34\n"
1458         "%35 = OpLabel\n"
1459         "%36 = OpIEqual %20 %28 %30\n"
1460         "OpSelectionMerge %37 None\n"
1461         "OpBranchConditional %36 %38 %37\n"
1462         "%38 = OpLabel\n"
1463         "%39 = OpAccessChain %22 %6 %13 %27\n"
1464         "%40 = OpLoad %12 %39\n"
1465         "OpBranch %34\n"
1466         "%37 = OpLabel\n"
1467         "OpBranch %32\n"
1468         "%32 = OpLabel\n"
1469         "%31 = OpIAdd %14 %30 %23\n"
1470         "OpBranch %29\n"
1471         "%34 = OpLabel\n"
1472         "%41 = OpPhi %12 %13 %29 %40 %38\n" // %40: phi
1473         "%42 = OpAccessChain %22 %9 %13 %27\n"
1474         "OpStore %42 %41\n"
1475         "OpReturn\n"
1476         "OpFunctionEnd\n";
1477 	// clang-format on
1478 
1479 	test(
1480 	    src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; });
1481 }
1482