// RUN: mlir-opt -split-input-file -convert-linalg-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s

//===----------------------------------------------------------------------===//
// Single workgroup reduction
//===----------------------------------------------------------------------===//

// Positive case: a 1-D reduction over a 16-element input with a matching
// local_size of [16, 1, 1]. The patterns below look for a subgroup-level
// reduce of the loaded value, followed by an atomic add into the output that
// is guarded by a subgroup elect.

#single_workgroup_reduction_trait = {
  iterator_types = ["reduction"],
  indexing_maps = [
    affine_map<(i) -> (i)>,
    affine_map<(i) -> (0)>
  ]
}

module attributes {
  spv.target_env = #spv.target_env<
    #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, {}>
} {

// CHECK:      spv.globalVariable
// CHECK-SAME: built_in("LocalInvocationId")

// CHECK:      @single_workgroup_reduction
// CHECK-SAME: (%[[INPUT:.+]]: !spv.ptr{{.+}}, %[[OUTPUT:.+]]: !spv.ptr{{.+}})

// CHECK:        %[[ZERO:.+]] = spv.constant 0 : i32
// CHECK:        %[[ID:.+]] = spv.Load "Input" %{{.+}} : vector<3xi32>
// CHECK:        %[[X:.+]] = spv.CompositeExtract %[[ID]][0 : i32]

// CHECK:        %[[INPTR:.+]] = spv.AccessChain %[[INPUT]][%[[ZERO]], %[[X]]]
// CHECK:        %[[VAL:.+]] = spv.Load "StorageBuffer" %[[INPTR]] : i32
// CHECK:        %[[ADD:.+]] = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %[[VAL]] : i32

// CHECK:        %[[OUTPTR:.+]] = spv.AccessChain %[[OUTPUT]][%[[ZERO]], %[[ZERO]]]
// CHECK:        %[[ELECT:.+]] = spv.GroupNonUniformElect "Subgroup" : i1

// CHECK:        spv.selection {
// CHECK:          spv.BranchConditional %[[ELECT]], ^bb1, ^bb2
// CHECK:        ^bb1:
// CHECK:          spv.AtomicIAdd "Device" "AcquireRelease" %[[OUTPTR]], %[[ADD]]
// CHECK:          spv.Branch ^bb2
// CHECK:        ^bb2:
// CHECK:          spv.mlir.merge
// CHECK:        }
// CHECK:        spv.Return

func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
  spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}
} {
  linalg.generic #single_workgroup_reduction_trait
      ins(%input : memref<16xi32>)
     outs(%output : memref<1xi32>) {
    ^bb(%in: i32, %out: i32):
      %sum = addi %in, %out : i32
      linalg.yield %sum : i32
  }
  spv.Return
}
}

// -----

// Missing shader entry point ABI
//
// The function below carries no spv.entry_point_abi attribute; the
// diagnostic annotation on linalg.generic records that legalization fails.

#single_workgroup_reduction_trait = {
  iterator_types = ["reduction"],
  indexing_maps = [
    affine_map<(i) -> (i)>,
    affine_map<(i) -> (0)>
  ]
}

module attributes {
  spv.target_env = #spv.target_env<
    #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, {}>
} {
func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) {
  // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
  linalg.generic #single_workgroup_reduction_trait
      ins(%input : memref<16xi32>)
     outs(%output : memref<1xi32>) {
    ^bb(%in: i32, %out: i32):
      %sum = addi %in, %out : i32
      linalg.yield %sum : i32
  }
  return
}
}

// -----

// Mismatch between shader entry point ABI and input memref shape
//
// local_size is [32, 1, 1] while the input memref holds 16 elements.

#single_workgroup_reduction_trait = {
  iterator_types = ["reduction"],
  indexing_maps = [
    affine_map<(i) -> (i)>,
    affine_map<(i) -> (0)>
  ]
}

module attributes {
  spv.target_env = #spv.target_env<
    #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, {}>
} {
func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
  spv.entry_point_abi = {local_size = dense<[32, 1, 1]>: vector<3xi32>}
} {
  // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
  linalg.generic #single_workgroup_reduction_trait
      ins(%input : memref<16xi32>)
     outs(%output : memref<1xi32>) {
    ^bb(%in: i32, %out: i32):
      %sum = addi %in, %out : i32
      linalg.yield %sum : i32
  }
  spv.Return
}
}

// -----

// Unsupported multi-dimension input memref
//
// The input is a 2-D memref (16x8) reduced along its second dimension into a
// 16-element output.

#single_workgroup_reduction_trait = {
  iterator_types = ["parallel", "reduction"],
  indexing_maps = [
    affine_map<(i, j) -> (i, j)>,
    affine_map<(i, j) -> (i)>
  ]
}

module attributes {
  spv.target_env = #spv.target_env<
    #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, {}>
} {
func @single_workgroup_reduction(%input: memref<16x8xi32>, %output: memref<16xi32>) attributes {
  spv.entry_point_abi = {local_size = dense<[16, 8, 1]>: vector<3xi32>}
} {
  // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
  linalg.generic #single_workgroup_reduction_trait
      ins(%input : memref<16x8xi32>)
     outs(%output : memref<16xi32>) {
    ^bb(%in: i32, %out: i32):
      %sum = addi %in, %out : i32
      linalg.yield %sum : i32
  }
  spv.Return
}
}