1// RUN: mlir-opt -split-input-file -convert-linalg-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s
2
3//===----------------------------------------------------------------------===//
4// Single workgroup reduction
5//===----------------------------------------------------------------------===//
6
// Trait shared by the linalg.generic op in this test case. The single loop
// dimension `i` is a reduction iterator. The first indexing map addresses
// element `i` of the input operand; the second always addresses element 0 of
// the output operand, so every iteration accumulates into the same location.
7#single_workgroup_reduction_trait = {
8  iterator_types = ["reduction"],
9  indexing_maps = [
10    affine_map<(i) -> (i)>,
11    affine_map<(i) -> (0)>
12  ]
13}
14
// Positive case. The target environment attached to the module declares
// SPIR-V version 1.3 with the Shader and GroupNonUniformArithmetic
// capabilities and an empty extension list.
15module attributes {
16  spv.target_env = #spv.target_env<
17    #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, {}>
18} {
19
// Expected output, as matched by the FileCheck directives below:
//   * a global variable bound to the LocalInvocationId built-in;
//   * the x component of that id is used as the index into the input buffer;
//   * the loaded value is reduced with a subgroup-scope GroupNonUniformIAdd;
//   * inside an spv.selection guarded by the GroupNonUniformElect result, the
//     reduced value is atomically added to element 0 of the output buffer.
20// CHECK:      spv.globalVariable
21// CHECK-SAME: built_in("LocalInvocationId")
22
23// CHECK:      @single_workgroup_reduction
24// CHECK-SAME: (%[[INPUT:.+]]: !spv.ptr{{.+}}, %[[OUTPUT:.+]]: !spv.ptr{{.+}})
25
26// CHECK:        %[[ZERO:.+]] = spv.constant 0 : i32
27// CHECK:        %[[ID:.+]] = spv.Load "Input" %{{.+}} : vector<3xi32>
28// CHECK:        %[[X:.+]] = spv.CompositeExtract %[[ID]][0 : i32]
29
30// CHECK:        %[[INPTR:.+]] = spv.AccessChain %[[INPUT]][%[[ZERO]], %[[X]]]
31// CHECK:        %[[VAL:.+]] = spv.Load "StorageBuffer" %[[INPTR]] : i32
32// CHECK:        %[[ADD:.+]] = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %[[VAL]] : i32
33
34// CHECK:        %[[OUTPTR:.+]] = spv.AccessChain %[[OUTPUT]][%[[ZERO]], %[[ZERO]]]
35// CHECK:        %[[ELECT:.+]] = spv.GroupNonUniformElect "Subgroup" : i1
36
37// CHECK:        spv.selection {
38// CHECK:          spv.BranchConditional %[[ELECT]], ^bb1, ^bb2
39// CHECK:        ^bb1:
40// CHECK:          spv.AtomicIAdd "Device" "AcquireRelease" %[[OUTPTR]], %[[ADD]]
41// CHECK:          spv.Branch ^bb2
42// CHECK:        ^bb2:
43// CHECK:          spv.mlir.merge
44// CHECK:        }
45// CHECK:        spv.Return
46
// Input under test: an integer-add reduction of a 16-element memref into a
// 1-element memref. The entry point ABI requests a local size of 16 along x,
// which equals the number of input elements.
47func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
48  spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}
49} {
50  linalg.generic #single_workgroup_reduction_trait
51      ins(%input : memref<16xi32>)
52     outs(%output : memref<1xi32>) {
53    ^bb(%in: i32, %out: i32):
54      %sum = addi %in, %out : i32
55      linalg.yield %sum : i32
56  }
57  spv.Return
58}
59}
60
61// -----
62
63// Missing shader entry point ABI
64
// Same rank-1 reduction trait as in the first test case (input indexed by
// `i`, output always at element 0). The alias is repeated because the test is
// run with -split-input-file, so this chunk is processed on its own.
65#single_workgroup_reduction_trait = {
66  iterator_types = ["reduction"],
67  indexing_maps = [
68    affine_map<(i) -> (i)>,
69    affine_map<(i) -> (0)>
70  ]
71}
72
// Negative case. The module-level target environment is the same as in the
// positive case; only the function differs.
73module attributes {
74  spv.target_env = #spv.target_env<
75    #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, {}>
76} {
// The function has no `attributes` clause, i.e. no spv.entry_point_abi and
// therefore no local_size. The diagnostic expectation inside the body must
// stay on the line directly above the linalg.generic op it refers to.
77func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) {
78  // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
79  linalg.generic #single_workgroup_reduction_trait
80      ins(%input : memref<16xi32>)
81     outs(%output : memref<1xi32>) {
82    ^bb(%in: i32, %out: i32):
83      %sum = addi %in, %out : i32
84      linalg.yield %sum : i32
85  }
86  return
87}
88}
89
90// -----
91
92// Mismatch between shader entry point ABI and input memref shape
93
// Same rank-1 reduction trait as in the first test case (input indexed by
// `i`, output always at element 0). The alias is repeated because the test is
// run with -split-input-file, so this chunk is processed on its own.
94#single_workgroup_reduction_trait = {
95  iterator_types = ["reduction"],
96  indexing_maps = [
97    affine_map<(i) -> (i)>,
98    affine_map<(i) -> (0)>
99  ]
100}
101
// Negative case. The module-level target environment is the same as in the
// positive case; only the entry point ABI of the function differs.
102module attributes {
103  spv.target_env = #spv.target_env<
104    #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, {}>
105} {
// local_size requests 32 invocations along x while the input memref holds
// only 16 elements. The diagnostic expectation inside the body must stay on
// the line directly above the linalg.generic op it refers to.
106func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
107  spv.entry_point_abi = {local_size = dense<[32, 1, 1]>: vector<3xi32>}
108} {
109  // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
110  linalg.generic #single_workgroup_reduction_trait
111      ins(%input : memref<16xi32>)
112     outs(%output : memref<1xi32>) {
113    ^bb(%in: i32, %out: i32):
114      %sum = addi %in, %out : i32
115      linalg.yield %sum : i32
116  }
117  spv.Return
118}
119}
120
121// -----
122
123// Unsupported multi-dimension input memref
124
// Trait for this test case only: two loop dimensions, where `i` is parallel
// and `j` is a reduction. The input operand is indexed by (i, j) and the
// output operand by (i), so the reduction runs along the second input
// dimension. The alias name is reused; with -split-input-file this chunk is
// processed on its own, so it does not clash with the earlier definitions.
125#single_workgroup_reduction_trait = {
126  iterator_types = ["parallel", "reduction"],
127  indexing_maps = [
128    affine_map<(i, j) -> (i, j)>,
129    affine_map<(i, j) -> (i)>
130  ]
131}
132
// Negative case. The module-level target environment is the same as in the
// positive case; the function operates on a rank-2 input.
133module attributes {
134  spv.target_env = #spv.target_env<
135    #spv.vce<v1.3, [Shader, GroupNonUniformArithmetic], []>, {}>
136} {
// The input is a 16x8 memref reduced into a 16-element memref, with a
// local_size of [16, 8, 1]. Legalization of the linalg.generic op is expected
// to fail; the diagnostic expectation inside the body must stay on the line
// directly above the op it refers to.
137func @single_workgroup_reduction(%input: memref<16x8xi32>, %output: memref<16xi32>) attributes {
138  spv.entry_point_abi = {local_size = dense<[16, 8, 1]>: vector<3xi32>}
139} {
140  // expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
141  linalg.generic #single_workgroup_reduction_trait
142      ins(%input : memref<16x8xi32>)
143     outs(%output : memref<16xi32>) {
144    ^bb(%in: i32, %out: i32):
145      %sum = addi %in, %out : i32
146      linalg.yield %sum : i32
147  }
148  spv.Return
149}
150}
151