// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(parallel-loop-fusion)' -split-input-file | FileCheck %s

// Two adjacent empty loops with identical bounds and steps are fused into one.
func @fuse_empty_loops() {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.yield
  }
  return
}
// CHECK-LABEL: func @fuse_empty_loops
// CHECK:        [[C2:%.*]] = constant 2 : index
// CHECK:        [[C0:%.*]] = constant 0 : index
// CHECK:        [[C1:%.*]] = constant 1 : index
// CHECK:        scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:       to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:          scf.yield
// CHECK:        }
// CHECK-NOT:    scf.parallel

// -----

// The second loop reads %sum at exactly the indices the first loop writes it,
// so the two loops are fused and the bodies are concatenated in order.
func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
               %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %sum = alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = load %B[%i, %j] : memref<2x2xf32>
    %C_elem = load %C[%i, %j] : memref<2x2xf32>
    %sum_elem = addf %B_elem, %C_elem : f32
    store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = load %A[%i, %j] : memref<2x2xf32>
    %product_elem = mulf %sum_elem, %A_elem : f32
    store %product_elem, %result[%i, %j] : memref<2x2xf32>
    scf.yield
  }
  dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @fuse_two
// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
// CHECK:      [[C2:%.*]] = constant 2 : index
// CHECK:      [[C0:%.*]] = constant 0 : index
// CHECK:      [[C1:%.*]] = constant 1 : index
// CHECK:      [[SUM:%.*]] = alloc()
// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:        [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:        [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
// CHECK:        [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
// CHECK:        store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:        [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:        [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:        [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK:        store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK:        scf.yield
// CHECK:      }
// CHECK-NOT:  scf.parallel
// CHECK:      dealloc [[SUM]]

// -----

// A chain of three loops, each reading what the previous one wrote at the same
// indices, collapses into a single loop.
func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
                 %result: memref<100x10xf32>) {
  %c100 = constant 100 : index
  %c10 = constant 10 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %broadcast_rhs = alloc() : memref<100x10xf32>
  %diff = alloc() : memref<100x10xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
    %rhs_elem = load %rhs[%i] : memref<100xf32>
    store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
    %lhs_elem = load %lhs[%i, %j] : memref<100x10xf32>
    %broadcast_rhs_elem = load %broadcast_rhs[%i, %j] : memref<100x10xf32>
    %diff_elem = subf %lhs_elem, %broadcast_rhs_elem : f32
    store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
    %diff_elem = load %diff[%i, %j] : memref<100x10xf32>
    %exp_elem = exp %diff_elem : f32
    store %exp_elem, %result[%i, %j] : memref<100x10xf32>
    scf.yield
  }
  dealloc %broadcast_rhs : memref<100x10xf32>
  dealloc %diff : memref<100x10xf32>
  return
}
// CHECK-LABEL: func @fuse_three
// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
// CHECK-SAME:  [[RESULT:%.*]]: memref<100x10xf32>) {
// CHECK:      [[C100:%.*]] = constant 100 : index
// CHECK:      [[C10:%.*]] = constant 10 : index
// CHECK:      [[C0:%.*]] = constant 0 : index
// CHECK:      [[C1:%.*]] = constant 1 : index
// CHECK:      [[BROADCAST_RHS:%.*]] = alloc()
// CHECK:      [[DIFF:%.*]] = alloc()
// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:     to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
// CHECK:        [[RHS_ELEM:%.*]] = load [[RHS]]{{\[}}[[I]]]
// CHECK:        store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
// CHECK:        [[LHS_ELEM:%.*]] = load [[LHS]]{{\[}}[[I]], [[J]]]
// CHECK:        [[BROADCAST_RHS_ELEM:%.*]] = load [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
// CHECK:        [[DIFF_ELEM:%.*]] = subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
// CHECK:        store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK:        [[DIFF_ELEM_:%.*]] = load [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK:        [[EXP_ELEM:%.*]] = exp [[DIFF_ELEM_]]
// CHECK:        store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK:        scf.yield
// CHECK:      }
// CHECK-NOT:  scf.parallel
// CHECK:      dealloc [[BROADCAST_RHS]]
// CHECK:      dealloc [[DIFF]]

// -----

// The first loop contains a nested scf.parallel; all three loops are expected
// to survive unfused.
func @do_not_fuse_nested_ploop1() {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      scf.yield
    }
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_nested_ploop1
// CHECK:        scf.parallel
// CHECK:          scf.parallel
// CHECK:        scf.parallel

// -----

// Same as above, but the nested scf.parallel lives in the second loop.
func @do_not_fuse_nested_ploop2() {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      scf.yield
    }
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_nested_ploop2
// CHECK:        scf.parallel
// CHECK:        scf.parallel
// CHECK:          scf.parallel

// -----

// A 2-D loop followed by a 1-D loop: the number of induction variables differs.
func @do_not_fuse_loops_unmatching_num_loops() {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.yield
  }
  scf.parallel (%i) = (%c0) to (%c2) step (%c1) {
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_unmatching_num_loops
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// An alloc sits between the two loops, so they are not adjacent and must stay
// separate.
func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.yield
  }
  %buffer = alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_with_side_effecting_ops_in_between
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// Upper bounds and steps differ between the two loops.
func @do_not_fuse_loops_unmatching_iteration_space() {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %c4 = constant 4 : index
  scf.parallel (%i, %j) = (%c0, %c0) to (%c4, %c4) step (%c2, %c2) {
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_unmatching_iteration_space
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// Write in the first loop, read in the second at a different index: the first
// loop stores %common_buf[%i, %j] while the second loads %common_buf[%i + 1, %j].
func @do_not_fuse_unmatching_write_read_patterns(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>,
    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %common_buf = alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = load %B[%i, %j] : memref<2x2xf32>
    %C_elem = load %C[%i, %j] : memref<2x2xf32>
    %sum_elem = addf %B_elem, %C_elem : f32
    store %sum_elem, %common_buf[%i, %j] : memref<2x2xf32>
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %k = addi %i, %c1 : index
    %sum_elem = load %common_buf[%k, %j] : memref<2x2xf32>
    %A_elem = load %A[%i, %j] : memref<2x2xf32>
    %product_elem = mulf %sum_elem, %A_elem : f32
    store %product_elem, %result[%i, %j] : memref<2x2xf32>
    scf.yield
  }
  dealloc %common_buf : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @do_not_fuse_unmatching_write_read_patterns
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// Read in the first loop, write in the second at a different index: the first
// loop loads %common_buf[%i, %j] while the second stores %common_buf[%j, %i].
// The %sum dependency is deliberately index-aligned (written and read at
// [%i, %j]) so that the transposed %common_buf access is the only reason the
// loops must not be fused.
func @do_not_fuse_unmatching_read_write_patterns(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>, %common_buf: memref<2x2xf32>) {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %sum = alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = load %B[%i, %j] : memref<2x2xf32>
    %C_elem = load %common_buf[%i, %j] : memref<2x2xf32>
    %sum_elem = addf %B_elem, %C_elem : f32
    store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = load %A[%i, %j] : memref<2x2xf32>
    %product_elem = mulf %sum_elem, %A_elem : f32
    store %product_elem, %common_buf[%j, %i] : memref<2x2xf32>
    scf.yield
  }
  dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @do_not_fuse_unmatching_read_write_patterns
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// The second loop loads from a subview created inside its own body; these
// loops are expected to stay separate.
func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %buffer = alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %A = subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1]
      : memref<2x2xf32> to memref<?x?xf32, offset: ?, strides:[?, ?]>
    %A_elem = load %A[%i, %j] : memref<?x?xf32, offset: ?, strides:[?, ?]>
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// Sibling loops nested inside an outer scf.parallel are fused; the outer loop
// itself is kept.
func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
                  %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %sum = alloc() : memref<2x2xf32>
  scf.parallel (%k) = (%c0) to (%c2) step (%c1) {
    scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      %B_elem = load %B[%i, %j] : memref<2x2xf32>
      %C_elem = load %C[%i, %j] : memref<2x2xf32>
      %sum_elem = addf %B_elem, %C_elem : f32
      store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
      scf.yield
    }
    scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
      %A_elem = load %A[%i, %j] : memref<2x2xf32>
      %product_elem = mulf %sum_elem, %A_elem : f32
      store %product_elem, %result[%i, %j] : memref<2x2xf32>
      scf.yield
    }
  }
  dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @nested_fuse
// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
// CHECK:      [[C2:%.*]] = constant 2 : index
// CHECK:      [[C0:%.*]] = constant 0 : index
// CHECK:      [[C1:%.*]] = constant 1 : index
// CHECK:      [[SUM:%.*]] = alloc()
// CHECK:      scf.parallel
// CHECK:        scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:       to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:          [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:          [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
// CHECK:          [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
// CHECK:          store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:          [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:          [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:          [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK:          store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK:          scf.yield
// CHECK:        }
// CHECK-NOT:    scf.parallel
// CHECK:      }
// CHECK:      dealloc [[SUM]]