1// RUN: mlir-opt -allow-unregistered-dialect --convert-gpu-to-nvvm --split-input-file %s | FileCheck --check-prefix=NVVM %s
2// RUN: mlir-opt -allow-unregistered-dialect --convert-gpu-to-rocdl --split-input-file %s | FileCheck --check-prefix=ROCDL %s
3
gpu.module @kernel {
  // Private attributions are lowered to allocas inside the function body.
  // NVVM-LABEL:  llvm.func @private
  // ROCDL-LABEL: llvm.func @private
  gpu.func @private(%arg0: f32) private(%arg1: memref<4xf32, 5>) {
    // Allocate private memory inside the function. For NVVM the alloca
    // result is in the default address space; for ROCDL it is in address
    // space 5.
    // NVVM: %[[size:.*]] = llvm.mlir.constant(4 : i64) : !llvm.i64
    // NVVM: %[[raw:.*]] = llvm.alloca %[[size]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float>

    // ROCDL: %[[size:.*]] = llvm.mlir.constant(4 : i64) : !llvm.i64
    // ROCDL: %[[raw:.*]] = llvm.alloca %[[size]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float, 5>

    // Populate the memref descriptor: allocated and aligned pointers both
    // point at the alloca, offset 0, size 4, stride 1.
    // NVVM: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
    // NVVM: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0]
    // NVVM: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
    // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // NVVM: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
    // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // NVVM: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
    // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // NVVM: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0]

    // ROCDL: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float, 5>, ptr<float, 5>, i64, array<1 x i64>, array<1 x i64>)>
    // ROCDL: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0]
    // ROCDL: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
    // ROCDL: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // ROCDL: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
    // ROCDL: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // ROCDL: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
    // ROCDL: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // ROCDL: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0]

    // The store is lowered like a store to any other memref, so only check
    // that the core instructions are emitted. The extractvalue must read the
    // descriptor populated above; reuse the captured descr6 here instead of
    // re-capturing it, which would accept any operand.
    // NVVM: llvm.extractvalue %[[descr6]]
    // NVVM: llvm.getelementptr
    // NVVM: llvm.store

    // ROCDL: llvm.extractvalue %[[descr6]]
    // ROCDL: llvm.getelementptr
    // ROCDL: llvm.store
    %c0 = constant 0 : index
    store %arg0, %arg1[%c0] : memref<4xf32, 5>

    "terminator"() : () -> ()
  }
}
50
51// -----
52
gpu.module @kernel {
  // Workgroup buffers are allocated as globals in address space 3.
  // NVVM: llvm.mlir.global internal @[[$buffer:.*]]()
  // NVVM-SAME:  addr_space = 3
  // NVVM-SAME:  !llvm.array<4 x float>

  // ROCDL: llvm.mlir.global internal @[[$buffer:.*]]()
  // ROCDL-SAME:  addr_space = 3
  // ROCDL-SAME:  !llvm.array<4 x float>

  // NVVM-LABEL: llvm.func @workgroup
  // NVVM-SAME: {

  // ROCDL-LABEL: llvm.func @workgroup
  // ROCDL-SAME: {
  gpu.func @workgroup(%arg0: f32) workgroup(%arg1: memref<4xf32, 3>) {
    // Get the address of the first element in the global array.
    // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
    // NVVM: %[[addr:.*]] = llvm.mlir.addressof @[[$buffer]] : !llvm.ptr<array<4 x float>, 3>
    // NVVM: %[[raw:.*]] = llvm.getelementptr %[[addr]][%[[c0]], %[[c0]]]
    // NVVM-SAME: !llvm.ptr<float, 3>

    // ROCDL: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
    // ROCDL: %[[addr:.*]] = llvm.mlir.addressof @[[$buffer]] : !llvm.ptr<array<4 x float>, 3>
    // ROCDL: %[[raw:.*]] = llvm.getelementptr %[[addr]][%[[c0]], %[[c0]]]
    // ROCDL-SAME: !llvm.ptr<float, 3>

    // Populate the memref descriptor: allocated and aligned pointers both
    // point at the first global element, offset 0, size 4, stride 1.
    // NVVM: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float, 3>, ptr<float, 3>, i64, array<1 x i64>, array<1 x i64>)>
    // NVVM: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0]
    // NVVM: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
    // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // NVVM: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
    // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // NVVM: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
    // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // NVVM: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0]

    // ROCDL: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float, 3>, ptr<float, 3>, i64, array<1 x i64>, array<1 x i64>)>
    // ROCDL: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0]
    // ROCDL: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1]
    // ROCDL: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // ROCDL: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2]
    // ROCDL: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // ROCDL: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0]
    // ROCDL: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // ROCDL: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0]

    // The store is lowered like a store to any other memref, so only check
    // that the core instructions are emitted. The extractvalue must read the
    // descriptor populated above; reuse the captured descr6 here instead of
    // re-capturing it, which would accept any operand.
    // NVVM: llvm.extractvalue %[[descr6]]
    // NVVM: llvm.getelementptr
    // NVVM: llvm.store

    // ROCDL: llvm.extractvalue %[[descr6]]
    // ROCDL: llvm.getelementptr
    // ROCDL: llvm.store
    %c0 = constant 0 : index
    store %arg0, %arg1[%c0] : memref<4xf32, 3>

    "terminator"() : () -> ()
  }
}
116
117// -----
118
gpu.module @kernel {
  // A 4x2x6 workgroup attribution is backed by one global array holding
  // all 4 * 2 * 6 = 48 elements.
  // NVVM: llvm.mlir.global internal @[[$GLOBAL:.*]]()
  // NVVM-SAME:  addr_space = 3
  // NVVM-SAME:  !llvm.array<48 x float>

  // ROCDL: llvm.mlir.global internal @[[$GLOBAL:.*]]()
  // ROCDL-SAME:  addr_space = 3
  // ROCDL-SAME:  !llvm.array<48 x float>

  // NVVM-LABEL: llvm.func @workgroup3d
  // ROCDL-LABEL: llvm.func @workgroup3d
  gpu.func @workgroup3d(%value: f32) workgroup(%buf: memref<4x2x6xf32, 3>) {
    // Take the address of element zero of the global array.
    // NVVM: %[[I32_ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
    // NVVM: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[$GLOBAL]] : !llvm.ptr<array<48 x float>, 3>
    // NVVM: %[[BASE:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][%[[I32_ZERO]], %[[I32_ZERO]]]
    // NVVM-SAME: !llvm.ptr<float, 3>

    // ROCDL: %[[I32_ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
    // ROCDL: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[$GLOBAL]] : !llvm.ptr<array<48 x float>, 3>
    // ROCDL: %[[BASE:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][%[[I32_ZERO]], %[[I32_ZERO]]]
    // ROCDL-SAME: !llvm.ptr<float, 3>

    // Build the descriptor. Both pointers go in first, then a zero offset,
    // then a (size, stride) pair per dimension. The sizes are 4, 2 and 6,
    // with row-major strides 12, 6 and 1.
    // NVVM: %[[D0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float, 3>, ptr<float, 3>, i64, array<3 x i64>, array<3 x i64>)>
    // NVVM: %[[D1:.*]] = llvm.insertvalue %[[BASE]], %[[D0]][0]
    // NVVM: %[[D2:.*]] = llvm.insertvalue %[[BASE]], %[[D1]][1]
    // NVVM: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // NVVM: %[[D3:.*]] = llvm.insertvalue %[[OFFSET]], %[[D2]][2]
    // NVVM: %[[SIZE0:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // NVVM: %[[D4:.*]] = llvm.insertvalue %[[SIZE0]], %[[D3]][3, 0]
    // NVVM: %[[STRIDE0:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
    // NVVM: %[[D5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[D4]][4, 0]
    // NVVM: %[[SIZE1:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
    // NVVM: %[[D6:.*]] = llvm.insertvalue %[[SIZE1]], %[[D5]][3, 1]
    // NVVM: %[[STRIDE1:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
    // NVVM: %[[D7:.*]] = llvm.insertvalue %[[STRIDE1]], %[[D6]][4, 1]
    // NVVM: %[[SIZE2:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
    // NVVM: %[[D8:.*]] = llvm.insertvalue %[[SIZE2]], %[[D7]][3, 2]
    // NVVM: %[[STRIDE2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // NVVM: %[[D9:.*]] = llvm.insertvalue %[[STRIDE2]], %[[D8]][4, 2]

    // ROCDL: %[[D0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<float, 3>, ptr<float, 3>, i64, array<3 x i64>, array<3 x i64>)>
    // ROCDL: %[[D1:.*]] = llvm.insertvalue %[[BASE]], %[[D0]][0]
    // ROCDL: %[[D2:.*]] = llvm.insertvalue %[[BASE]], %[[D1]][1]
    // ROCDL: %[[OFFSET:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
    // ROCDL: %[[D3:.*]] = llvm.insertvalue %[[OFFSET]], %[[D2]][2]
    // ROCDL: %[[SIZE0:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64
    // ROCDL: %[[D4:.*]] = llvm.insertvalue %[[SIZE0]], %[[D3]][3, 0]
    // ROCDL: %[[STRIDE0:.*]] = llvm.mlir.constant(12 : index) : !llvm.i64
    // ROCDL: %[[D5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[D4]][4, 0]
    // ROCDL: %[[SIZE1:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
    // ROCDL: %[[D6:.*]] = llvm.insertvalue %[[SIZE1]], %[[D5]][3, 1]
    // ROCDL: %[[STRIDE1:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
    // ROCDL: %[[D7:.*]] = llvm.insertvalue %[[STRIDE1]], %[[D6]][4, 1]
    // ROCDL: %[[SIZE2:.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
    // ROCDL: %[[D8:.*]] = llvm.insertvalue %[[SIZE2]], %[[D7]][3, 2]
    // ROCDL: %[[STRIDE2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
    // ROCDL: %[[D9:.*]] = llvm.insertvalue %[[STRIDE2]], %[[D8]][4, 2]

    %zero = constant 0 : index
    store %value, %buf[%zero, %zero, %zero] : memref<4x2x6xf32, 3>
    "terminator"() : () -> ()
  }
}
185
186// -----
187
gpu.module @kernel {
  // Check that several buffers are defined. Each workgroup attribution gets
  // its own global in address space 3, sized to match the attribution.
  // NVVM: llvm.mlir.global internal @[[$buffer1:.*]]()
  // NVVM-SAME:  addr_space = 3
  // NVVM-SAME:  !llvm.array<1 x float>
  // NVVM: llvm.mlir.global internal @[[$buffer2:.*]]()
  // NVVM-SAME:  addr_space = 3
  // NVVM-SAME:  !llvm.array<2 x float>

  // ROCDL: llvm.mlir.global internal @[[$buffer1:.*]]()
  // ROCDL-SAME:  addr_space = 3
  // ROCDL-SAME:  !llvm.array<1 x float>
  // ROCDL: llvm.mlir.global internal @[[$buffer2:.*]]()
  // ROCDL-SAME:  addr_space = 3
  // ROCDL-SAME:  !llvm.array<2 x float>

  // NVVM-LABEL: llvm.func @multiple
  // ROCDL-LABEL: llvm.func @multiple
  gpu.func @multiple(%arg0: f32)
      workgroup(%arg1: memref<1xf32, 3>, %arg2: memref<2xf32, 3>)
      private(%arg3: memref<3xf32, 5>, %arg4: memref<4xf32, 5>) {

    // Workgroup buffers: one address-of per global, in declaration order,
    // each yielding a pointer to the whole array in address space 3.
    // NVVM: llvm.mlir.addressof @[[$buffer1]] : !llvm.ptr<array<1 x float>, 3>
    // NVVM: llvm.mlir.addressof @[[$buffer2]] : !llvm.ptr<array<2 x float>, 3>

    // ROCDL: llvm.mlir.addressof @[[$buffer1]] : !llvm.ptr<array<1 x float>, 3>
    // ROCDL: llvm.mlir.addressof @[[$buffer2]] : !llvm.ptr<array<2 x float>, 3>

    // Private buffers: one alloca per private attribution, sized by its
    // element count. For NVVM the result is in the default address space;
    // for ROCDL it is in address space 5.
    // NVVM: %[[c3:.*]] = llvm.mlir.constant(3 : i64)
    // NVVM: llvm.alloca %[[c3]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float>
    // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : i64)
    // NVVM: llvm.alloca %[[c4]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float>

    // ROCDL: %[[c3:.*]] = llvm.mlir.constant(3 : i64)
    // ROCDL: llvm.alloca %[[c3]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float, 5>
    // ROCDL: %[[c4:.*]] = llvm.mlir.constant(4 : i64)
    // ROCDL: llvm.alloca %[[c4]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float, 5>

    %c0 = constant 0 : index
    store %arg0, %arg1[%c0] : memref<1xf32, 3>
    store %arg0, %arg2[%c0] : memref<2xf32, 3>
    store %arg0, %arg3[%c0] : memref<3xf32, 5>
    store %arg0, %arg4[%c0] : memref<4xf32, 5>
    "terminator"() : () -> ()
  }
}
232