// RUN: mlir-opt -normalize-memrefs -allow-unregistered-dialect %s | FileCheck %s

// This file tests whether the memref type having non-trivial map layouts
// are normalized to trivial (identity) layouts.

// CHECK-LABEL: func @permute()
func @permute() {
  %A = alloc() : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
  affine.for %i = 0 to 64 {
    affine.for %j = 0 to 256 {
      %1 = affine.load %A[%i, %j] : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  dealloc %A : memref<64x256xf32, affine_map<(d0, d1) -> (d1, d0)>>
  return
}
// The old memref alloc should disappear.
// CHECK-NOT:  memref<64x256xf32>
// CHECK:      [[MEM:%[0-9]+]] = alloc() : memref<256x64xf32>
// CHECK-NEXT: affine.for %[[I:arg[0-9]+]] = 0 to 64 {
// CHECK-NEXT:   affine.for %[[J:arg[0-9]+]] = 0 to 256 {
// CHECK-NEXT:     affine.load [[MEM]][%[[J]], %[[I]]] : memref<256x64xf32>
// CHECK-NEXT:     "prevent.dce"
// CHECK-NEXT:   }
// CHECK-NEXT: }
// CHECK-NEXT: dealloc [[MEM]]
// CHECK-NEXT: return

// CHECK-LABEL: func @shift
func @shift(%idx : index) {
  // CHECK-NEXT: alloc() : memref<65xf32>
  %A = alloc() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
  affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
  affine.for %i = 0 to 64 {
    %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
    "prevent.dce"(%1) : (f32) -> ()
    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
  }
  return
}

// CHECK-LABEL: func @high_dim_permute()
func @high_dim_permute() {
  // CHECK-NOT: memref<64x128x256xf32,
  %A = alloc() : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
  // CHECK: %[[I:arg[0-9]+]]
  affine.for %i = 0 to 64 {
    // CHECK: %[[J:arg[0-9]+]]
    affine.for %j = 0 to 128 {
      // CHECK: %[[K:arg[0-9]+]]
      affine.for %k = 0 to 256 {
        %1 = affine.load %A[%i, %j, %k] : memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
        // CHECK: %{{.*}} = affine.load %{{.*}}[%[[K]], %[[I]], %[[J]]] : memref<256x64x128xf32>
        "prevent.dce"(%1) : (f32) -> ()
      }
    }
  }
  return
}

// CHECK-LABEL: func @invalid_map
func @invalid_map() {
  %A = alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (d0, -d1 - 10)>>
  // CHECK: %{{.*}} = alloc() : memref<64x128xf32,
  return
}

// A tiled layout.
// CHECK-LABEL: func @data_tiling
func @data_tiling(%idx : index) {
  // CHECK: alloc() : memref<8x32x8x16xf32>
  %A = alloc() : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
  // CHECK: affine.load %{{.*}}[symbol(%arg0) floordiv 8, symbol(%arg0) floordiv 16, symbol(%arg0) mod 8, symbol(%arg0) mod 16]
  %1 = affine.load %A[%idx, %idx] : memref<64x512xf32, affine_map<(d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>>
  "prevent.dce"(%1) : (f32) -> ()
  return
}

// Strides 2 and 4 along respective dimensions.
// CHECK-LABEL: func @strided
func @strided() {
  %A = alloc() : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
  // CHECK: affine.for %[[IV0:.*]] =
  affine.for %i = 0 to 64 {
    // CHECK: affine.for %[[IV1:.*]] =
    affine.for %j = 0 to 128 {
      // CHECK: affine.load %{{.*}}[%[[IV0]] * 2, %[[IV1]] * 4] : memref<127x509xf32>
      %1 = affine.load %A[%i, %j] : memref<64x128xf32, affine_map<(d0, d1) -> (2*d0, 4*d1)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  return
}

// Strided, but the strides are in the linearized space.
// CHECK-LABEL: func @strided_cumulative
func @strided_cumulative() {
  %A = alloc() : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
  // CHECK: affine.for %[[IV0:.*]] =
  affine.for %i = 0 to 2 {
    // CHECK: affine.for %[[IV1:.*]] =
    affine.for %j = 0 to 5 {
      // CHECK: affine.load %{{.*}}[%[[IV0]] * 3 + %[[IV1]] * 17] : memref<72xf32>
      %1 = affine.load %A[%i, %j] : memref<2x5xf32, affine_map<(d0, d1) -> (3*d0 + 17*d1)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  return
}

// Symbolic operand for alloc, although unused. Tests replaceAllMemRefUsesWith
// when the index remap has symbols.
// CHECK-LABEL: func @symbolic_operands
func @symbolic_operands(%s : index) {
  // CHECK: alloc() : memref<100xf32>
  %A = alloc()[%s] : memref<10x10xf32, affine_map<(d0,d1)[s0] -> (10*d0 + d1)>>
  affine.for %i = 0 to 10 {
    affine.for %j = 0 to 10 {
      // CHECK: affine.load %{{.*}}[%{{.*}} * 10 + %{{.*}}] : memref<100xf32>
      %1 = affine.load %A[%i, %j] : memref<10x10xf32, affine_map<(d0,d1)[s0] -> (10*d0 + d1)>>
      "prevent.dce"(%1) : (f32) -> ()
    }
  }
  return
}

// Semi-affine maps, normalization not implemented yet.
// CHECK-LABEL: func @semi_affine_layout_map
func @semi_affine_layout_map(%s0: index, %s1: index) {
  %A = alloc()[%s0, %s1] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
  affine.for %i = 0 to 256 {
    affine.for %j = 0 to 1024 {
      // CHECK: memref<256x1024xf32, #map{{[0-9]+}}>
      affine.load %A[%i, %j] : memref<256x1024xf32, affine_map<(d0, d1)[s0, s1] -> (d0*s0 + d1*s1)>>
    }
  }
  return
}

// CHECK-LABEL: func @alignment
func @alignment() {
  %A = alloc() {alignment = 32 : i64}: memref<64x128x256xf32, affine_map<(d0, d1, d2) -> (d2, d0, d1)>>
  // CHECK-NEXT: alloc() {alignment = 32 : i64} : memref<256x64x128xf32>
  return
}

#tile = affine_map < (i)->(i floordiv 4, i mod 4) >

// Following test cases check the inter-procedural memref normalization.

// Test case 1: Check normalization for multiple memrefs in a function argument list.
// CHECK-LABEL: func @multiple_argument_type
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>, %[[D:arg[0-9]+]]: memref<24xf64>) -> f64
func @multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>, %D: memref<24xf64>) -> f64 {
  %a = affine.load %A[0] : memref<16xf64, #tile>
  %p = mulf %a, %a : f64
  affine.store %p, %A[10] : memref<16xf64, #tile>
  call @single_argument_type(%C): (memref<8xf64, #tile>) -> ()
  return %B : f64
}

// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
// CHECK: affine.store %[[p]], %[[A]][2, 2] : memref<4x4xf64>
// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
// CHECK: return %[[B]] : f64

// Test case 2: Check normalization for single memref argument in a function.
// CHECK-LABEL: func @single_argument_type
// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>)
func @single_argument_type(%C : memref<8xf64, #tile>) {
  %a = alloc(): memref<8xf64, #tile>
  %b = alloc(): memref<16xf64, #tile>
  %d = constant 23.0 : f64
  %e = alloc(): memref<24xf64>
  call @single_argument_type(%a): (memref<8xf64, #tile>) -> ()
  call @single_argument_type(%C): (memref<8xf64, #tile>) -> ()
  call @multiple_argument_type(%b, %d, %a, %e): (memref<16xf64, #tile>, f64, memref<8xf64, #tile>, memref<24xf64>) -> f64
  return
}

// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
// CHECK: %cst = constant 2.300000e+01 : f64
// CHECK: %[[e:[0-9]+]] = alloc() : memref<24xf64>
// CHECK: call @single_argument_type(%[[a]]) : (memref<2x4xf64>) -> ()
// CHECK: call @single_argument_type(%[[C]]) : (memref<2x4xf64>) -> ()
// CHECK: call @multiple_argument_type(%[[b]], %cst, %[[a]], %[[e]]) : (memref<4x4xf64>, f64, memref<2x4xf64>, memref<24xf64>) -> f64

// Test case 3: Check function returning any other type except memref.
// CHECK-LABEL: func @non_memref_ret
// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> i1
func @non_memref_ret(%A: memref<8xf64, #tile>) -> i1 {
  %d = constant 1 : i1
  return %d : i1
}

// Test cases here onwards deal with normalization of memref in function signature, caller site.

// Test case 4: Check successful memref normalization in case of inter/intra-recursive calls.
// CHECK-LABEL: func @ret_multiple_argument_type
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64, %[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<2x4xf64>, f64)
func @ret_multiple_argument_type(%A: memref<16xf64, #tile>, %B: f64, %C: memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64) {
  %a = affine.load %A[0] : memref<16xf64, #tile>
  %p = mulf %a, %a : f64
  %cond = constant 1 : i1
  cond_br %cond, ^bb1, ^bb2
  ^bb1:
    %res1, %res2 = call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
    return %res2, %p: memref<8xf64, #tile>, f64
  ^bb2:
    return %C, %p: memref<8xf64, #tile>, f64
}

// CHECK: %[[a:[0-9]+]] = affine.load %[[A]][0, 0] : memref<4x4xf64>
// CHECK: %[[p:[0-9]+]] = mulf %[[a]], %[[a]] : f64
// CHECK: %true = constant true
// CHECK: cond_br %true, ^bb1, ^bb2
// CHECK: ^bb1: // pred: ^bb0
// CHECK: %[[res:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: return %[[res]]#1, %[[p]] : memref<2x4xf64>, f64
// CHECK: ^bb2: // pred: ^bb0
// CHECK: return %{{.*}}, %{{.*}} : memref<2x4xf64>, f64

// CHECK-LABEL: func @ret_single_argument_type
// CHECK-SAME: (%[[C:arg[0-9]+]]: memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
func @ret_single_argument_type(%C: memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>){
  %a = alloc() : memref<8xf64, #tile>
  %b = alloc() : memref<16xf64, #tile>
  %d = constant 23.0 : f64
  call @ret_single_argument_type(%a) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  call @ret_single_argument_type(%C) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  %res1, %res2 = call @ret_multiple_argument_type(%b, %d, %a) : (memref<16xf64, #tile>, f64, memref<8xf64, #tile>) -> (memref<8xf64, #tile>, f64)
  %res3, %res4 = call @ret_single_argument_type(%res1) : (memref<8xf64, #tile>) -> (memref<16xf64, #tile>, memref<8xf64, #tile>)
  return %b, %a: memref<16xf64, #tile>, memref<8xf64, #tile>
}

// CHECK: %[[a:[0-9]+]] = alloc() : memref<2x4xf64>
// CHECK: %[[b:[0-9]+]] = alloc() : memref<4x4xf64>
// CHECK: %cst = constant 2.300000e+01 : f64
// CHECK: %[[resA:[0-9]+]]:2 = call @ret_single_argument_type(%[[a]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: %[[resB:[0-9]+]]:2 = call @ret_single_argument_type(%[[C]]) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: %[[resC:[0-9]+]]:2 = call @ret_multiple_argument_type(%[[b]], %cst, %[[a]]) : (memref<4x4xf64>, f64, memref<2x4xf64>) -> (memref<2x4xf64>, f64)
// CHECK: %[[resD:[0-9]+]]:2 = call @ret_single_argument_type(%[[resC]]#0) : (memref<2x4xf64>) -> (memref<4x4xf64>, memref<2x4xf64>)
// CHECK: return %{{.*}}, %{{.*}} : memref<4x4xf64>, memref<2x4xf64>

// Test case set #5: To check normalization in a chain of interconnected functions.
// CHECK-LABEL: func @func_A
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
func @func_A(%A: memref<8xf64, #tile>) {
  call @func_B(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @func_B(%[[A]]) : (memref<2x4xf64>) -> ()

// CHECK-LABEL: func @func_B
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
func @func_B(%A: memref<8xf64, #tile>) {
  call @func_C(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @func_C(%[[A]]) : (memref<2x4xf64>) -> ()

// CHECK-LABEL: func @func_C
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<2x4xf64>)
func @func_C(%A: memref<8xf64, #tile>) {
  return
}

// Test case set #6: Checking if no normalization takes place in a scenario: A -> B -> C and B has an unsupported type.
// CHECK-LABEL: func @some_func_A
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
func @some_func_A(%A: memref<8xf64, #tile>) {
  call @some_func_B(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @some_func_B(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()

// CHECK-LABEL: func @some_func_B
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
func @some_func_B(%A: memref<8xf64, #tile>) {
  "test.test"(%A) : (memref<8xf64, #tile>) -> ()
  call @some_func_C(%A) : (memref<8xf64, #tile>) -> ()
  return
}
// CHECK: call @some_func_C(%[[A]]) : (memref<8xf64, #map{{[0-9]+}}>) -> ()

// CHECK-LABEL: func @some_func_C
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<8xf64, #map{{[0-9]+}}>)
func @some_func_C(%A: memref<8xf64, #tile>) {
  return
}

// Test case set #7: Check normalization in case of external functions.
// CHECK-LABEL: func private @external_func_A
// CHECK-SAME: (memref<4x4xf64>)
func private @external_func_A(memref<16xf64, #tile>) -> ()

// CHECK-LABEL: func private @external_func_B
// CHECK-SAME: (memref<4x4xf64>, f64) -> memref<2x4xf64>
func private @external_func_B(memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)

// CHECK-LABEL: func @simply_call_external()
func @simply_call_external() {
  %a = alloc() : memref<16xf64, #tile>
  call @external_func_A(%a) : (memref<16xf64, #tile>) -> ()
  return
}
// CHECK: %[[a:[0-9]+]] = alloc() : memref<4x4xf64>
// CHECK: call @external_func_A(%[[a]]) : (memref<4x4xf64>) -> ()

// CHECK-LABEL: func @use_value_of_external
// CHECK-SAME: (%[[A:arg[0-9]+]]: memref<4x4xf64>, %[[B:arg[0-9]+]]: f64) -> memref<2x4xf64>
func @use_value_of_external(%A: memref<16xf64, #tile>, %B: f64) -> (memref<8xf64, #tile>) {
  %res = call @external_func_B(%A, %B) : (memref<16xf64, #tile>, f64) -> (memref<8xf64, #tile>)
  return %res : memref<8xf64, #tile>
}
// CHECK: %[[res:[0-9]+]] = call @external_func_B(%[[A]], %[[B]]) : (memref<4x4xf64>, f64) -> memref<2x4xf64>
// CHECK: return %{{.*}} : memref<2x4xf64>