// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(parallel-loop-fusion)' -split-input-file | FileCheck %s

// Two back-to-back empty scf.parallel loops over an identical 2x2 iteration
// space are expected to collapse into a single loop.
func @fuse_empty_loops() {
  %two = constant 2 : index
  %zero = constant 0 : index
  %one = constant 1 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.yield
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.yield
  }
  return
}
// CHECK-LABEL: func @fuse_empty_loops
// CHECK:        [[TWO:%.*]] = constant 2 : index
// CHECK:        [[ZERO:%.*]] = constant 0 : index
// CHECK:        [[ONE:%.*]] = constant 1 : index
// CHECK:        scf.parallel ([[X:%.*]], [[Y:%.*]]) = ([[ZERO]], [[ZERO]])
// CHECK-SAME:       to ([[TWO]], [[TWO]]) step ([[ONE]], [[ONE]]) {
// CHECK:          scf.yield
// CHECK:        }
// CHECK-NOT:    scf.parallel

// -----

// A producer loop that stores %sum[i, j] and a consumer loop that loads
// %sum[i, j] over the same iteration space are expected to be fused. In the
// single resulting loop, the consumer's body follows the producer's body and
// uses the producer's induction variables. No second scf.parallel remains.
func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
               %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %sum = alloc()  : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = load %B[%i, %j] : memref<2x2xf32>
    %C_elem = load %C[%i, %j] : memref<2x2xf32>
    %sum_elem = addf %B_elem, %C_elem : f32
    store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = load %A[%i, %j] : memref<2x2xf32>
    %product_elem = mulf %sum_elem, %A_elem : f32
    store %product_elem, %result[%i, %j] : memref<2x2xf32>
    scf.yield
  }
  dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @fuse_two
// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
// CHECK:      [[C2:%.*]] = constant 2 : index
// CHECK:      [[C0:%.*]] = constant 0 : index
// CHECK:      [[C1:%.*]] = constant 1 : index
// CHECK:      [[SUM:%.*]] = alloc()
// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:        [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:        [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
// CHECK:        [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
// CHECK:        store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:        [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:        [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:        [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK:        store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK:        scf.yield
// CHECK:      }
// CHECK-NOT:  scf.parallel
// CHECK:      dealloc [[SUM]]

// -----

// A chain of three loops over the same 100x10 space is expected to be fused
// into one loop:
//   1. broadcast %rhs into %broadcast_rhs;
//   2. subtract it from %lhs into %diff;
//   3. exponentiate %diff into %result.
// Every intermediate buffer is written and read back at [i, j], and no
// second scf.parallel remains.
func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
                 %result: memref<100x10xf32>) {
  %c100 = constant 100 : index
  %c10 = constant 10 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %broadcast_rhs = alloc() : memref<100x10xf32>
  %diff = alloc() : memref<100x10xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
    %rhs_elem = load %rhs[%i] : memref<100xf32>
    store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
    %lhs_elem = load %lhs[%i, %j] : memref<100x10xf32>
    %broadcast_rhs_elem = load %broadcast_rhs[%i, %j] : memref<100x10xf32>
    %diff_elem = subf %lhs_elem, %broadcast_rhs_elem : f32
    store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
    %diff_elem = load %diff[%i, %j] : memref<100x10xf32>
    %exp_elem = exp %diff_elem : f32
    store %exp_elem, %result[%i, %j] : memref<100x10xf32>
    scf.yield
  }
  dealloc %broadcast_rhs : memref<100x10xf32>
  dealloc %diff : memref<100x10xf32>
  return
}
// CHECK-LABEL: func @fuse_three
// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
// CHECK-SAME:  [[RESULT:%.*]]: memref<100x10xf32>) {
// CHECK:      [[C100:%.*]] = constant 100 : index
// CHECK:      [[C10:%.*]] = constant 10 : index
// CHECK:      [[C0:%.*]] = constant 0 : index
// CHECK:      [[C1:%.*]] = constant 1 : index
// CHECK:      [[BROADCAST_RHS:%.*]] = alloc()
// CHECK:      [[DIFF:%.*]] = alloc()
// CHECK:      scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:     to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
// CHECK:        [[RHS_ELEM:%.*]] = load [[RHS]]{{\[}}[[I]]]
// CHECK:        store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
// CHECK:        [[LHS_ELEM:%.*]] = load [[LHS]]{{\[}}[[I]], [[J]]]
// CHECK:        [[BROADCAST_RHS_ELEM:%.*]] = load [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
// CHECK:        [[DIFF_ELEM:%.*]] = subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
// CHECK:        store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK:        [[DIFF_ELEM_:%.*]] = load [[DIFF]]{{\[}}[[I]], [[J]]]
// CHECK:        [[EXP_ELEM:%.*]] = exp [[DIFF_ELEM_]]
// CHECK:        store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK:        scf.yield
// CHECK:      }
// CHECK-NOT:  scf.parallel
// CHECK:      dealloc [[BROADCAST_RHS]]
// CHECK:      dealloc [[DIFF]]

// -----

// The first top-level loop contains a nested scf.parallel. Both top-level
// loops, and the nested one, are expected to stay as they are.
func @do_not_fuse_nested_ploop1() {
  %two = constant 2 : index
  %zero = constant 0 : index
  %one = constant 1 : index
  scf.parallel (%outer_x, %outer_y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.parallel (%inner_x, %inner_y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
      scf.yield
    }
    scf.yield
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_nested_ploop1
// CHECK:        scf.parallel
// CHECK:          scf.parallel
// CHECK:        scf.parallel

// -----

// Here the second top-level loop holds a nested scf.parallel. Both top-level
// loops, and the nested one, are expected to remain unfused.
func @do_not_fuse_nested_ploop2() {
  %two = constant 2 : index
  %zero = constant 0 : index
  %one = constant 1 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.yield
  }
  scf.parallel (%outer_x, %outer_y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.parallel (%inner_x, %inner_y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
      scf.yield
    }
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_nested_ploop2
// CHECK:        scf.parallel
// CHECK:        scf.parallel
// CHECK:          scf.parallel

// -----

// A 2-D loop followed by a 1-D loop: the number of induction variables
// differs, so both loops are expected to survive.
func @do_not_fuse_loops_unmatching_num_loops() {
  %two = constant 2 : index
  %zero = constant 0 : index
  %one = constant 1 : index
  scf.parallel (%row, %col) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.yield
  }
  scf.parallel (%idx) = (%zero) to (%two) step (%one) {
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_unmatching_num_loops
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// Two otherwise identical empty loops are separated by an alloc. That
// intervening op is expected to block fusion, so both loops survive.
func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
  %two = constant 2 : index
  %zero = constant 0 : index
  %one = constant 1 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.yield
  }
  %buf = alloc() : memref<2x2xf32>
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_with_side_effecting_ops_in_between
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// The loops have different upper bounds and steps: [0, 4) step 2 versus
// [0, 2) step 1. Their iteration spaces do not match, so they must stay
// separate.
func @do_not_fuse_loops_unmatching_iteration_space() {
  %zero = constant 0 : index
  %one = constant 1 : index
  %two = constant 2 : index
  %four = constant 4 : index
  scf.parallel (%x, %y) = (%zero, %zero) to (%four, %four) step (%two, %two) {
    scf.yield
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_unmatching_iteration_space
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// The first loop stores into the shared buffer at [i, j]. The second loop
// loads from it at [i + 1, j], which is not the index written in the same
// iteration. This write-then-read mismatch is expected to prevent fusion.
func @do_not_fuse_unmatching_write_read_patterns(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>,
    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %two = constant 2 : index
  %zero = constant 0 : index
  %one = constant 1 : index
  %shared = alloc() : memref<2x2xf32>
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %b_val = load %B[%x, %y] : memref<2x2xf32>
    %c_val = load %C[%x, %y] : memref<2x2xf32>
    %b_plus_c = addf %b_val, %c_val : f32
    store %b_plus_c, %shared[%x, %y] : memref<2x2xf32>
    scf.yield
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %x_next = addi %x, %one : index
    %shared_val = load %shared[%x_next, %y] : memref<2x2xf32>
    %a_val = load %A[%x, %y] : memref<2x2xf32>
    %prod = mulf %shared_val, %a_val : f32
    store %prod, %result[%x, %y] : memref<2x2xf32>
    scf.yield
  }
  dealloc %shared : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @do_not_fuse_unmatching_write_read_patterns
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// The first loop reads %common_buf[i, j]. The second loop writes
// %common_buf[j, i], a transposed index, which is a read-then-write
// mismatch. The loops are expected to stay separate.
// %sum is deliberately read back at the same [i, j] it was stored at, the
// same pattern @fuse_two fuses. That way the transposed %common_buf access is
// the only thing that can block fusion here.
func @do_not_fuse_unmatching_read_write_patterns(
    %A: memref<2x2xf32>, %B: memref<2x2xf32>, %common_buf: memref<2x2xf32>) {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %sum = alloc() : memref<2x2xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    %B_elem = load %B[%i, %j] : memref<2x2xf32>
    %C_elem = load %common_buf[%i, %j] : memref<2x2xf32>
    %sum_elem = addf %B_elem, %C_elem : f32
    store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    scf.yield
  }
  scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    // Same-index read of %sum: this access on its own would not block fusion.
    %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
    %A_elem = load %A[%i, %j] : memref<2x2xf32>
    %product_elem = mulf %sum_elem, %A_elem : f32
    // Transposed write into the buffer the first loop reads at [i, j].
    store %product_elem, %common_buf[%j, %i] : memref<2x2xf32>
    scf.yield
  }
  dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @do_not_fuse_unmatching_read_write_patterns
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// The second loop loads through a subview that is created inside its own
// body. Because the loaded memref is defined in the loop body, the two loops
// are expected to stay separate.
func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
  %two = constant 2 : index
  %zero = constant 0 : index
  %one = constant 1 : index
  %buf = alloc() : memref<2x2xf32>
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    scf.yield
  }
  scf.parallel (%x, %y) = (%zero, %zero) to (%two, %two) step (%one, %one) {
    %view = subview %buf[%zero, %zero][%two, %two][%one, %one]
      : memref<2x2xf32> to memref<?x?xf32, offset: ?, strides:[?, ?]>
    %elem = load %view[%x, %y] : memref<?x?xf32, offset: ?, strides:[?, ?]>
    scf.yield
  }
  return
}
// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
// CHECK:        scf.parallel
// CHECK:        scf.parallel

// -----

// Fusion also applies to sibling loops nested inside an outer 1-D
// scf.parallel. The outer loop is kept, and its two inner loops are expected
// to become a single loop; no second inner scf.parallel remains.
func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
                  %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %c2 = constant 2 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %sum = alloc()  : memref<2x2xf32>
  scf.parallel (%k) = (%c0) to (%c2) step (%c1) {
    scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      %B_elem = load %B[%i, %j] : memref<2x2xf32>
      %C_elem = load %C[%i, %j] : memref<2x2xf32>
      %sum_elem = addf %B_elem, %C_elem : f32
      store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
      scf.yield
    }
    scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
      %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
      %A_elem = load %A[%i, %j] : memref<2x2xf32>
      %product_elem = mulf %sum_elem, %A_elem : f32
      store %product_elem, %result[%i, %j] : memref<2x2xf32>
      scf.yield
    }
  }
  dealloc %sum : memref<2x2xf32>
  return
}
// CHECK-LABEL: func @nested_fuse
// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
// CHECK:      [[C2:%.*]] = constant 2 : index
// CHECK:      [[C0:%.*]] = constant 0 : index
// CHECK:      [[C1:%.*]] = constant 1 : index
// CHECK:      [[SUM:%.*]] = alloc()
// CHECK:      scf.parallel ([[K:%.*]]) = ([[C0]]) to ([[C2]]) step ([[C1]]) {
// CHECK:        scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME:       to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
// CHECK:          [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
// CHECK:          [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
// CHECK:          [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
// CHECK:          store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:          [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
// CHECK:          [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
// CHECK:          [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
// CHECK:          store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
// CHECK:          scf.yield
// CHECK:        }
// CHECK-NOT:    scf.parallel
// CHECK:      }
// CHECK:      dealloc [[SUM]]