1// RUN: mlir-opt -normalize-memrefs -allow-unregistered-dialect %s | FileCheck %s
2
// This file tests that memref types with non-trivial layout maps are
// normalized to memref types with trivial (identity) layouts.
5
// CHECK-LABEL: func @permute()
func @permute() {
  %A = alloc() : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
  affine.for %i = 0 to 64 {
    affine.for %j = 0 to 256 {
      %1 = affine.load %A[%i, %j] : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  dealloc %A : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
  return
}
// The transposing layout map (d0, d1) -> (d1, d0) is folded into the shape:
// the alloc must become an identity-layout memref<256x64xf32>, the load
// indices are swapped, and the dealloc must use the new type as well.
// The old alloc type must not survive. If it did, it would print with a
// layout map attached (e.g. "memref<64x256xf32, #map0>"), so the negative
// pattern stops at the comma; a pattern ending in ">" could never match it.
// CHECK-NOT:  memref<64x256xf32,
// CHECK:      [[MEM:%[0-9]+]] = alloc() : memref<256x64xf32>
// CHECK-NEXT: affine.for %[[I:arg[0-9]+]] = 0 to 64 {
// CHECK-NEXT:   affine.for %[[J:arg[0-9]+]] = 0 to 256 {
// CHECK-NEXT:     affine.load [[MEM]][%[[J]], %[[I]]] : memref<256x64xf32>
// CHECK-NEXT:     "prevent.dce"
// CHECK-NEXT:   }
// CHECK-NEXT: }
// CHECK-NEXT: dealloc [[MEM]] : memref<256x64xf32>
// CHECK-NEXT: return
29
// CHECK-LABEL: func @shift
func @shift(%idx : index) {
  // The layout map (d0) -> (d0 + 1) offsets every index by one, so the
  // normalized buffer needs 64 + 1 = 65 elements and every access gains "+ 1".
  // CHECK-NEXT: alloc() : memref<65xf32>
  %buf = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  // A function argument used as an index is printed as a symbol operand.
  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
  affine.load %buf[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  affine.for %iv = 0 to 64 {
    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
    %elt = affine.load %buf[%iv] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
    "prevent.dce"(%elt) : (f32) -> ()
  }
  return
}
43
// CHECK-LABEL: func @high_dim_permute()
func @high_dim_permute() {
  // The map (d0, d1, d2) -> (d2, d0, d1) rotates the dimensions, so the
  // normalized shape is 256x64x128 and the load indices rotate the same way.
  // CHECK-NOT: memref<64x128x256xf32,
  %buf = alloc() : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
  // CHECK: %[[I:arg[0-9]+]]
  // CHECK: %[[J:arg[0-9]+]]
  // CHECK: %[[K:arg[0-9]+]]
  // CHECK: %{{.*}} = affine.load %{{.*}}[%[[K]], %[[I]], %[[J]]] : memref<256x64x128xf32>
  affine.for %x = 0 to 64 {
    affine.for %y = 0 to 128 {
      affine.for %z = 0 to 256 {
        %elt = affine.load %buf[%x, %y, %z] : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
        "prevent.dce"(%elt) : (f32) -> ()
      }
    }
  }
  return
}
62
// CHECK-LABEL: func @invalid_map
func @invalid_map() {
  // The second map result, -d1 - 10, is negative for every d1 in [0, 128).
  // The alloc is expected to keep a non-identity layout, so the pattern
  // below ends at the comma that precedes the layout map.
  // CHECK: %{{.*}} = alloc() : memref<64x128xf32,
  %buf = alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (d0, -d1 - 10)>>
  return
}
69
70// A tiled layout.
// CHECK-LABEL: func @data_tiling
func @data_tiling(%idx : index) {
  // 8x16 tiles over a 64x512 buffer: 64/8 = 8 tile rows, 512/16 = 32 tile
  // columns, each tile 8x16, hence an 8x32x8x16 identity-layout buffer.
  // CHECK: alloc() : memref<8x32x8x16xf32>
  %tiled = alloc() : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
  // CHECK: affine.load %{{.*}}[symbol(%arg0) floordiv 8, symbol(%arg0) floordiv 16, symbol(%arg0) mod 8, symbol(%arg0) mod 16]
  %elt = affine.load %tiled[%idx, %idx] : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
  "prevent.dce"(%elt) : (f32) -> ()
  return
}
80
81// Strides 2 and 4 along respective dimensions.
// CHECK-LABEL: func @strided
func @strided() {
  // The largest mapped indices are 2*63 = 126 and 4*127 = 508, so the
  // normalized buffer is 127x509.
  %buf = alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
  affine.for %row = 0 to 64 {
    affine.for %col = 0 to 128 {
      %elt = affine.load %buf[%row, %col] : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
      "prevent.dce"(%elt) : (f32) -> ()
    }
  }
  return
}
// CHECK: affine.for %[[IV0:.*]] =
// CHECK: affine.for %[[IV1:.*]] =
// CHECK: affine.load %{{.*}}[%[[IV0]] * 2, %[[IV1]] * 4] : memref<127x509xf32>
96
97// Strided, but the strides are in the linearized space.
// CHECK-LABEL: func @strided_cumulative
func @strided_cumulative() {
  // Both dimensions collapse into one linear index with strides 3 and 17.
  // The largest offset is 3*1 + 17*4 = 71, so the 1-D buffer holds 72 elements.
  %buf = alloc() : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
  affine.for %row = 0 to 2 {
    affine.for %col = 0 to 5 {
      %elt = affine.load %buf[%row, %col]  : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
      "prevent.dce"(%elt) : (f32) -> ()
    }
  }
  return
}
// CHECK: affine.for %[[IV0:.*]] =
// CHECK: affine.for %[[IV1:.*]] =
// CHECK: affine.load %{{.*}}[%[[IV0]] * 3 + %[[IV1]] * 17] : memref<72xf32>
112
// Symbolic operand for alloc, even though the layout map does not use it.
// Tests replaceAllMemRefUsesWith when the index remap has symbols.
// CHECK-LABEL: func @symbolic_operands
func @symbolic_operands(%s : index) {
  // s0 is bound but never appears in the map results. The largest linear
  // index is 10*9 + 9 = 99, so a 100-element buffer suffices.
  // CHECK: alloc() : memref<100xf32>
  %buf = alloc()[%s] : memref<10x10xf32, affine_map<(d0,d1)[s0] -> (10*d0 + d1)>>
  // CHECK: affine.load %{{.*}}[%{{.*}} * 10 + %{{.*}}] : memref<100xf32>
  affine.for %row = 0 to 10 {
    affine.for %col = 0 to 10 {
      %elt = affine.load %buf[%row, %col] : memref<10x10xf32, affine_map<(d0,d1)[s0] -> (10*d0 + d1)>>
      "prevent.dce"(%elt) : (f32) -> ()
    }
  }
  return
}
128
// Semi-affine layout maps: normalization is not implemented yet, so the layout map is expected to be kept.
// CHECK-LABEL: func @semi_affine_layout_map
func @semi_affine_layout_map(%s0: index, %s1: index) {
  // d0*s0 and d1*s1 multiply dimensions by symbols. The type is expected to
  // keep its layout map, printed through a #mapN alias.
  // CHECK: memref<256x1024xf32, #map{{[0-9]+}}>
  %buf = alloc()[%s0, %s1] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
  affine.for %row = 0 to 256 {
    affine.for %col = 0 to 1024 {
      affine.load %buf[%row, %col] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
    }
  }
  return
}
141
// CHECK-LABEL: func @alignment
func @alignment() {
  // The alignment attribute must be carried over to the normalized alloc,
  // whose shape is permuted by (d0, d1, d2) -> (d2, d0, d1).
  // CHECK-NEXT: alloc() {alignment = 32 : i64} : memref<256x64x128xf32>
  %buf = alloc() {alignment = 32 : i64} : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
  return
}
148
// Layout map shared by the inter-procedural tests below. It splits a flat
// index into (index floordiv 4, index mod 4), i.e. rows of 4 elements; for
// example memref<16xf64, #tile> is expected to normalize to memref<4x4xf64>.
#tile = affine_map<(d0) -> (d0 floordiv 4, d0 mod 4)>
150
// The following test cases check inter-procedural memref normalization.
152
153// Test case 1: Check normalization for multiple memrefs in a function argument list.
// CHECK-LABEL: func @multiple_argument_type
// CHECK-SAME:  (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>, %[[D:arg[0-9]+]]: memref<24xf64>) -> f64
func @multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>, %D: memref<24xf64>) -> f64 {
  // Only the #tile arguments change type; %D already has an identity layout.
  // CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
  %first = affine.load %A[0] : memref<16xf64, #tile>
  // CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
  %sq = mulf %first, %first : f64
  // Flat index 10 maps to (10 floordiv 4, 10 mod 4) = (2, 2).
  // CHECK: affine.store %[[p]], %[[A]][2, 2] : memref<4x4xf64>
  affine.store %sq, %A[10] : memref<16xf64, #tile>
  // CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
  call @single_argument_type(%C): (memref<8xf64, #tile>) -> ()
  // CHECK: return %[[B]] : f64
  return %B : f64
}
169
170// Test case 2: Check normalization for single memref argument in a function.
// CHECK-LABEL: func @single_argument_type
// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>)
func @single_argument_type(%C : memref<8xf64, #tile>) {
  // CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
  %small = alloc(): memref<8xf64, #tile>
  // CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
  %big = alloc(): memref<16xf64, #tile>
  // CHECK: %cst = constant 2.300000e+01 : f64
  %scalar = constant 23.0 : f64
  // CHECK: %[[e:[0-9]+]] = alloc() : memref<24xf64>
  %plain = alloc(): memref<24xf64>
  // Self-recursive calls, first on a local buffer, then on the argument.
  // CHECK: call @single_argument_type(%[[a]]) : (memref<2x4xf64>) -> ()
  call @single_argument_type(%small): (memref<8xf64, #tile>) -> ()
  // CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
  call @single_argument_type(%C): (memref<8xf64, #tile>) -> ()
  // Mixed argument list: only the two #tile buffers change type.
  // CHECK: call @multiple_argument_type(%[[b]], %cst, %[[a]], %[[e]]) : (memref<4x4xf64>, f64, memref<2x4xf64>, memref<24xf64>) -> f64
  call @multiple_argument_type(%big, %scalar, %small, %plain): (memref<16xf64, #tile>, f64, memref<8xf64, #tile>, memref<24xf64>) -> f64
  return
}
191
// Test case 3: Check a function that takes a memref argument but returns a non-memref type.
// CHECK-LABEL: func @non_memref_ret
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>) -> i1
func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 {
  // Only the memref argument type is normalized; the i1 result is untouched.
  %flag = constant 1 : i1
  return %flag : i1
}
199
// The test cases from here onwards check normalization of memrefs in function signatures and at call sites.
201
// Test case 4: Check successful memref normalization with self-recursive and mutually recursive calls.
// CHECK-LABEL: func @ret_multiple_argument_type
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<2x4xf64>, f64)
func @ret_multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64) {
  %a = affine.load %A[0] : memref<16xf64, #tile>
  %p = mulf %a, %a : f64
  %cond = constant 1 : i1
  cond_br %cond, ^bb1, ^bb2
  ^bb1:
    // Mutually recursive: @ret_single_argument_type calls back into this
    // function.
    %res1, %res2 = call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
    return %res2, %p: memref<8xf64, #tile>, f64
  ^bb2:
    return %C, %p: memref<8xf64, #tile>, f64
}

// Both returns must yield the normalized memref<2x4xf64>, and the operands of
// each return are pinned exactly. The first returns the call's second result;
// the second forwards the normalized argument together with the product.
// CHECK:   %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
// CHECK:   %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
// CHECK:   %true = constant true
// CHECK:   cond_br %true, ^bb1, ^bb2
// CHECK: ^bb1:  // pred: ^bb0
// CHECK:   %[[res:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK:   return %[[res]]#1, %[[p]] : memref<2x4xf64>, f64
// CHECK: ^bb2:  // pred: ^bb0
// CHECK:   return %[[C]], %[[p]] : memref<2x4xf64>, f64
226
// CHECK-LABEL: func @ret_single_argument_type
// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
func @ret_single_argument_type(%C: memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>){
  %a = alloc() : memref<8xf64, #tile>
  %b = alloc() : memref<16xf64, #tile>
  %d = constant 23.0 : f64
  // Self-recursive calls, on a local buffer and on the argument.
  call @ret_single_argument_type(%a) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  // Call back into @ret_multiple_argument_type (mutual recursion), then
  // recurse on its first (memref) result.
  %res1, %res2 = call @ret_multiple_argument_type(%b, %d, %a) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64)
  %res3, %res4 = call @ret_single_argument_type(%res1) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  return %b, %a: memref<16xf64, #tile>, memref<8xf64, #tile>
}

// The final return must hand back the two local allocs in (%b, %a) order,
// not merely any two values of the normalized types.
// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
// CHECK: %cst = constant 2.300000e+01 : f64
// CHECK: %[[resA:[0-9]+]]:2 = call @ret_single_argument_type(%[[a]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: %[[resB:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: %[[resC:[0-9]+]]:2 = call @ret_multiple_argument_type(%[[b]], %cst, %[[a]]) : (memref<4x4xf64>, f64, memref<2x4xf64>) -> (memref<2x4xf64>, f64)
// CHECK: %[[resD:[0-9]+]]:2 = call @ret_single_argument_type(%[[resC]]#0) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: return %[[b]], %[[a]] : memref<4x4xf64>, memref<2x4xf64>
248
// Test case set #5: Check normalization in a chain of calls A -> B -> C.
// CHECK-LABEL: func @func_A
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
func @func_A(%A: memref<8xf64, #tile>) {
  // Head of the chain @func_A -> @func_B -> @func_C; it only forwards %A.
  // CHECK: call @func_B(%[[A]]) : (memref<2x4xf64>) -> ()
  call @func_B(%A) : (memref<8xf64, #tile>) -> ()
  return
}
257
// CHECK-LABEL: func @func_B
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
func @func_B(%A: memref<8xf64, #tile>) {
  // Middle link: both its signature and the call it makes are normalized.
  // CHECK: call @func_C(%[[A]]) : (memref<2x4xf64>) -> ()
  call @func_C(%A) : (memref<8xf64, #tile>) -> ()
  return
}
265
// CHECK-LABEL: func @func_C
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
func @func_C(%A: memref<8xf64, #tile>) {
  // Tail of the chain: %A is unused, yet its type is still expected to be
  // normalized to match the callers.
  return
}
271
// Test case set #6: Check that no normalization takes place in the call chain A -> B -> C when B also passes the memref to an op ("test.test") that the pass cannot rewrite.
// CHECK-LABEL: func @some_func_A
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
func @some_func_A(%A: memref<8xf64, #tile>) {
  // Head of the chain @some_func_A -> @some_func_B -> @some_func_C. The
  // argument and the call are both expected to keep the layout map.
  // CHECK: call @some_func_B(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
  call @some_func_B(%A) : (memref<8xf64, #tile>) -> ()
  return
}
280
// CHECK-LABEL: func @some_func_B
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
func @some_func_B(%A: memref<8xf64, #tile>) {
  // This use by the unregistered "test.test" op is what sets this chain apart
  // from @func_A/@func_B/@func_C, which do get normalized.
  "test.test"(%A) : (memref<8xf64, #tile>) -> ()
  // CHECK: call @some_func_C(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()
  call @some_func_C(%A) : (memref<8xf64, #tile>) -> ()
  return
}
289
// CHECK-LABEL: func @some_func_C
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
func @some_func_C(%A: memref<8xf64, #tile>) {
  // Nothing here uses %A, yet the expected signature keeps the layout map,
  // consistent with its caller @some_func_B.
  return
}
295
// Test case set #7: Check normalization of external (declaration-only) functions and of the calls to them.
// CHECK-LABEL: func private @external_func_A
// CHECK-SAME: (memref<4x4xf64>)
// Declaration without a body whose memref argument type is still normalized.
func private @external_func_A(memref<16xf64, #tile>)
300
// CHECK-LABEL: func private @external_func_B
// CHECK-SAME: (memref<4x4xf64>, f64) -> memref<2x4xf64>
// Declaration without a body. Both the memref argument and the memref result
// are expected to be normalized; the f64 argument is left alone.
func private @external_func_B(memref<16xf64, #tile>, f64) -> memref<8xf64, #tile>
304
// CHECK-LABEL: func @simply_call_external()
func @simply_call_external() {
  // A #tile buffer passed to a declaration-only callee: the alloc and the
  // call's operand type are both expected to become memref<4x4xf64>.
  // CHECK: %[[a:[0-9]+]] = alloc() : memref<4x4xf64>
  %buf = alloc() : memref<16xf64, #tile>
  // CHECK: call @external_func_A(%[[a]]) : (memref<4x4xf64>) -> ()
  call @external_func_A(%buf) : (memref<16xf64, #tile>) -> ()
  return
}
313
// CHECK-LABEL: func @use_value_of_external
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64) -> memref<2x4xf64>
func @use_value_of_external(%A: memref<16xf64, #tile>, %B: f64) -> (memref<8xf64, #tile>) {
  // The external callee's memref result is normalized too, so the value
  // returned here must be exactly that call result, carrying the new type.
  %res = call @external_func_B(%A, %B) : (memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
  return %res : memref<8xf64, #tile>
}
// CHECK: %[[res:[0-9]+]] = call @external_func_B(%[[A]], %[[B]]) : (memref<4x4xf64>, f64) -> memref<2x4xf64>
// CHECK: return %[[res]] : memref<2x4xf64>
322