// RUN: mlir-hlo-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP

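// Two pointwise 2-D linalg.generic ops (an addf producer feeding a mulf
// consumer through the temporary buffer %temp_result) should be tiled and
// fused into a single loop nest containing both generics.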
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
             %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
  %temp_result = alloc() : memref<6x6xf32>
  linalg.generic #pointwise_2d_trait
    ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
   outs(%temp_result : memref<6x6xf32>) {
  ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
    %out = addf %summand_1_in, %summand_2_in : f32
    linalg.yield %out : f32
  }
  linalg.generic #pointwise_2d_trait
    ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
   outs(%result : memref<6x6xf32>) {
  ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
    %out = mulf %temp_result_in, %multiplier_in : f32
    linalg.yield %out : f32
  }
  dealloc %temp_result : memref<6x6xf32>
  return
}
// CHECK-LABEL: func @fusion
//       CHECK:  %[[C1:.*]] = constant 1
//   CHECK-NOT:  linalg.generic
//       CHECK:  scf.for {{.*}} step %[[C1]]
//       CHECK:    scf.for {{.*}} step %[[C1]]
//   CHECK-NOT:  scf.for
//       CHECK:      linalg.generic
//       CHECK:        addf
//       CHECK:      linalg.generic
//       CHECK:        mulf

// TILED-LABEL: func @fusion
//   TILED-DAG:  %[[C2:.*]] = constant 2
//   TILED-DAG:  %[[C3:.*]] = constant 3
//   TILED-NOT:  linalg.generic
//       TILED:  scf.for {{.*}} step %[[C2]]
//       TILED:    scf.for {{.*}} step %[[C3]]
//   TILED-NOT:  scf.for
//       TILED:      linalg.generic
//       TILED:        addf
//       TILED:      linalg.generic
//       TILED:        mulf

// PLOOP-LABEL: func @fusion
//   PLOOP-NOT:  linalg.generic
//       PLOOP:  scf.parallel
//   PLOOP-NOT:  scf.parallel
//       PLOOP:      linalg.generic
//       PLOOP:        addf
//       PLOOP:      linalg.generic
//       PLOOP:        mulf

// -----

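// A chain of three pointwise ops (a broadcast of the 1-D operand, a subf, and
// a math.exp) connected through temporary buffers should be fused into one
// tiled loop nest containing all three generics.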
func @fusion_of_three(%arg0: memref<100x10xf32>,
                      %arg1: memref<100xf32>,
                      %arg2: memref<100x10xf32>) {
 %0 = alloc() : memref<100x10xf32>
 linalg.generic {
   indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                    affine_map<(d0, d1) -> (d0, d1)>],
   iterator_types = ["parallel", "parallel"]}
     ins(%arg1 : memref<100xf32>)
    outs(%0 : memref<100x10xf32>) {
   ^bb0(%arg3: f32, %arg4: f32): // no predecessors
     linalg.yield %arg3 : f32
   }
 %1 = alloc() : memref<100x10xf32>
 linalg.generic {
   indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                    affine_map<(d0, d1) -> (d0, d1)>,
                    affine_map<(d0, d1) -> (d0, d1)>],
   iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
   outs(%1 : memref<100x10xf32>) {
     ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
       %2 = subf %arg3, %arg4 : f32
       linalg.yield %2 : f32
     }
 dealloc %0 : memref<100x10xf32>
 linalg.generic {
   indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                    affine_map<(d0, d1) -> (d0, d1)>],
   iterator_types = ["parallel", "parallel"]}
     ins(%1 : memref<100x10xf32>)
    outs(%arg2 : memref<100x10xf32>) {
     ^bb0(%arg3: f32, %arg4: f32): // no predecessors
       %2 = math.exp %arg3 : f32
       linalg.yield %2 : f32
     }
 dealloc %1 : memref<100x10xf32>
 return
}
// CHECK-LABEL: func @fusion_of_three
//       CHECK:  %[[C1:.*]] = constant 1
//   CHECK-NOT:  linalg.generic
//       CHECK:  scf.for {{.*}} step %[[C1]]
//       CHECK:    scf.for {{.*}} step %[[C1]]
//   CHECK-NOT:  scf.for
//       CHECK:      linalg.generic
//       CHECK:      linalg.generic
//       CHECK:        subf
//       CHECK:      linalg.generic
//       CHECK:        exp

// TILED-LABEL: func @fusion_of_three
//   TILED-DAG:   %[[C2:.*]] = constant 2
//   TILED-DAG:   %[[C3:.*]] = constant 3
//   TILED-NOT:   linalg.generic
//       TILED:   scf.for {{.*}} step %[[C2]]
//       TILED:     scf.for {{.*}} step %[[C3]]
//   TILED-NOT:   scf.for
//       TILED:       linalg.generic
//       TILED:       linalg.generic
//       TILED:         subf
//       TILED:       linalg.generic
//       TILED:         exp

// PLOOP-LABEL: func @fusion_of_three
//   PLOOP-NOT:   linalg.generic
//       PLOOP:   scf.parallel
//   PLOOP-NOT:   scf.parallel
//       PLOOP:       linalg.generic
//       PLOOP:       linalg.generic
//       PLOOP:         subf
//       PLOOP:       linalg.generic
//       PLOOP:         exp

// -----

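// The same add/multiply fusion as in the first test, but on 4-D operands: the
// default run tiles every loop with unit step, while tile-sizes=2,3 only
// materializes loops for the first two dimensions.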
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel", "parallel", "parallel",
                                         "parallel"]}
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
             %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
  %temp_result = alloc() : memref<6x6x6x6xf32>
  linalg.generic #pointwise_4d_trait
    ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
   outs(%temp_result : memref<6x6x6x6xf32>) {
  ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
    %out = addf %summand_1_in, %summand_2_in : f32
    linalg.yield %out : f32
  }
  linalg.generic #pointwise_4d_trait
    ins(%temp_result, %multiplier : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
   outs(%result : memref<6x6x6x6xf32>) {
  ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
    %out = mulf %temp_result_in, %multiplier_in : f32
    linalg.yield %out : f32
  }
  dealloc %temp_result : memref<6x6x6x6xf32>
  return
}
// CHECK-LABEL: func @fusion_4d
//       CHECK:  %[[C1:.*]] = constant 1
//   CHECK-NOT:  linalg.generic
//       CHECK:  scf.for {{.*}} step %[[C1]]
//       CHECK:    scf.for {{.*}} step %[[C1]]
//       CHECK:      scf.for {{.*}} step %[[C1]]
//       CHECK:        scf.for {{.*}} step %[[C1]]
//   CHECK-NOT:  scf.for
//       CHECK:      linalg.generic
//       CHECK:        addf
//       CHECK:      linalg.generic
//       CHECK:        mulf

// TILED-LABEL: func @fusion_4d
//   TILED-DAG:  %[[C2:.*]] = constant 2
//   TILED-DAG:  %[[C3:.*]] = constant 3
//   TILED-NOT:  linalg.generic
//       TILED:  scf.for {{.*}} step %[[C2]]
//       TILED:    scf.for {{.*}} step %[[C3]]
//   TILED-NOT:  scf.for
//       TILED:      linalg.generic
//       TILED:        addf
//       TILED:      linalg.generic
//       TILED:        mulf

// PLOOP-LABEL: func @fusion_4d
//   PLOOP-NOT:  linalg.generic
//       PLOOP:  scf.parallel
//   PLOOP-NOT:  scf.parallel
//       PLOOP:      linalg.generic
//       PLOOP:        addf
//       PLOOP:      linalg.generic
//       PLOOP:        mulf

// -----

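// Same fusion as in the first test, except the consumer writes into a locally
// allocated buffer that is returned from the function rather than into an
// output argument.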
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
             %summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
  %temp_result = alloc() : memref<6x6xf32>
  linalg.generic #pointwise_2d_trait
    ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
   outs(%temp_result : memref<6x6xf32>) {
  ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
    %out = addf %summand_1_in, %summand_2_in : f32
    linalg.yield %out : f32
  }
  %result = alloc() : memref<6x6xf32>
  linalg.generic #pointwise_2d_trait
    ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
   outs(%result : memref<6x6xf32>) {
  ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
    %out = mulf %temp_result_in, %multiplier_in : f32
    linalg.yield %out : f32
  }
  dealloc %temp_result : memref<6x6xf32>
  return %result : memref<6x6xf32>
}

// CHECK-LABEL: func @fusion
//       CHECK:  %[[C1:.*]] = constant 1
//   CHECK-NOT:  linalg.generic
//       CHECK:  scf.for {{.*}} step %[[C1]]
//       CHECK:    scf.for {{.*}} step %[[C1]]
//   CHECK-NOT:  scf.for
//       CHECK:      linalg.generic
//       CHECK:        addf
//       CHECK:      linalg.generic
//       CHECK:        mulf

// TILED-LABEL: func @fusion
//   TILED-DAG:  %[[C2:.*]] = constant 2
//   TILED-DAG:  %[[C3:.*]] = constant 3
//   TILED-NOT:  linalg.generic
//       TILED:  scf.for {{.*}} step %[[C2]]
//       TILED:    scf.for {{.*}} step %[[C3]]
//   TILED-NOT:  scf.for
//       TILED:      linalg.generic
//       TILED:        addf
//       TILED:      linalg.generic
//       TILED:        mulf

// PLOOP-LABEL: func @fusion
//   PLOOP-NOT:  linalg.generic
//       PLOOP:  scf.parallel
//   PLOOP-NOT:  scf.parallel
//       PLOOP:      linalg.generic
//       PLOOP:        addf
//       PLOOP:      linalg.generic
//       PLOOP:        mulf

// -----

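// Confirm that tiling still triggers when the result buffer is only consumed
// through a memref_reshape view that is returned from the function.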
func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
    -> memref<*xf32> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %1 = alloc(%arg2) : memref<?xf32>
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                   affine_map<(d0) -> (d0)>],
                  iterator_types = ["parallel"]}
      ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
    %13 = absf %arg3 : f32
    linalg.yield %13 : f32
  }
  %2 = memref_reshape %1(%arg1)
      : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
  return %2 : memref<*xf32>
}

// CHECK-LABEL: func @view_result
//       CHECK:  %[[C1:.*]] = constant 1
//   CHECK-NOT:  linalg.generic
//       CHECK:  scf.for {{.*}} step %[[C1]]
//   CHECK-NOT:  scf.for
//       CHECK:      linalg.generic
//       CHECK:        absf
//       CHECK:  memref_reshape

// TILED-LABEL: func @view_result
//   TILED-DAG:  %[[C2:.*]] = constant 2
//   TILED-NOT:  linalg.generic
//       TILED:  scf.for {{.*}} step %[[C2]]
//   TILED-NOT:  scf.for
//       TILED:      linalg.generic
//       TILED:        absf
//       TILED:  memref_reshape

// PLOOP-LABEL: func @view_result
//   PLOOP-NOT:  linalg.generic
//       PLOOP:  scf.parallel
//   PLOOP-NOT:  scf.parallel
//       PLOOP:      linalg.generic
//       PLOOP:        absf
//       PLOOP:  memref_reshape

// -----

// Confirm that tiling information is passed through RegionBranchOpInterfaces.
// This test also uses memref_reshape, just to have a value to return through
// the if statement.
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
    -> memref<*xf32> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %1 = alloc(%arg2) : memref<?xf32>
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                   affine_map<(d0) -> (d0)>],
                  iterator_types = ["parallel"]}
      ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
    %13 = absf %arg3 : f32
    linalg.yield %13 : f32
  }
  %true = constant 1 : i1
  %3 = scf.if %true -> memref<*xf32> {
    %2 = memref_reshape %1(%arg1)
        : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
    scf.yield %2 : memref<*xf32>
  } else {
    %2 = memref_reshape %1(%arg1)
        : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
    scf.yield %2 : memref<*xf32>
  }
  return %3 : memref<*xf32>
}

// CHECK-LABEL: func @branching_result
//       CHECK:  %[[C1:.*]] = constant 1
//   CHECK-NOT:  linalg.generic
//       CHECK:  scf.for {{.*}} step %[[C1]]
//   CHECK-NOT:  scf.for
//       CHECK:      linalg.generic
//       CHECK:        absf
//       CHECK:  scf.if
//       CHECK:    memref_reshape
//       CHECK:    scf.yield
//       CHECK:  else
//       CHECK:    memref_reshape
//       CHECK:    scf.yield

// TILED-LABEL: func @branching_result
//   TILED-DAG:  %[[C2:.*]] = constant 2
//   TILED-NOT:  linalg.generic
//       TILED:  scf.for {{.*}} step %[[C2]]
//   TILED-NOT:  scf.for
//       TILED:      linalg.generic
//       TILED:        absf
//       TILED:  scf.if
//       TILED:    memref_reshape
//       TILED:    scf.yield
//       TILED:  else
//       TILED:    memref_reshape
//       TILED:    scf.yield

// PLOOP-LABEL: func @branching_result
//   PLOOP-NOT:  linalg.generic
//       PLOOP:  scf.parallel
//   PLOOP-NOT:  scf.parallel
//       PLOOP:      linalg.generic
//       PLOOP:        absf
//       PLOOP:  scf.if
//       PLOOP:    memref_reshape
//       PLOOP:    scf.yield
//       PLOOP:  else
//       PLOOP:    memref_reshape
//       PLOOP:    scf.yield

// -----

// Confirm that tiling information is passed through tensor_load, tensor.cast
// and tensor_to_memref operations.
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
    -> memref<?xf32> {
  %c1 = constant 1 : index
  %1 = alloc() : memref<32xf32>
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                   affine_map<(d0) -> (d0)>],
                  iterator_types = ["parallel"]}
      ins(%arg0 : memref<32xf32>) outs(%1 : memref<32xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
    %13 = absf %arg3 : f32
    linalg.yield %13 : f32
  }
  %2 = tensor_load %1 : memref<32xf32>
  %3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
  %4 = tensor_to_memref %3 : memref<?xf32>
  return %4 : memref<?xf32>
}

// CHECK-LABEL: func @tensor_ops
//       CHECK:  %[[C1:.*]] = constant 1
//   CHECK-NOT:  linalg.generic
//       CHECK:  scf.for {{.*}} step %[[C1]]
//   CHECK-NOT:  scf.for
//       CHECK:      linalg.generic
//       CHECK:        absf
//       CHECK:  tensor_load
//       CHECK:  tensor.cast
//       CHECK:  tensor_to_memref

// TILED-LABEL: func @tensor_ops
//   TILED-DAG:  %[[C2:.*]] = constant 2
//   TILED-NOT:  linalg.generic
//       TILED:  scf.for {{.*}} step %[[C2]]
//   TILED-NOT:  scf.for
//       TILED:      linalg.generic
//       TILED:        absf
//       TILED:  tensor_load
//       TILED:  tensor.cast
//       TILED:  tensor_to_memref

// PLOOP-LABEL: func @tensor_ops
//   PLOOP-NOT:  linalg.generic
//       PLOOP:  scf.parallel
//   PLOOP-NOT:  scf.parallel
//       PLOOP:      linalg.generic
//       PLOOP:        absf
//       PLOOP:  tensor_load
//       PLOOP:  tensor.cast
//       PLOOP:  tensor_to_memref