1// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
2
// Test input: a fill, a matmul and an elementwise generic op chained through
// the temporary buffer %0. The FileCheck expectations that follow this module
// look for all three Linalg ops, operating on subviews, inside one
// two-dimensional scf.parallel loop.
3module {
4  func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
5                        %arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
6    %cst = constant 0.000000e+00 : f32
7    %c0 = constant 0 : index
8    %c1 = constant 1 : index
    // The temporary is sized from dimension 0 of %arg0 and dimension 1 of
    // %arg1.
9    %d0 = dim %arg0, %c0 : memref<?x?xf32>
10    %d1 = dim %arg1, %c1 : memref<?x?xf32>
11    %0 = alloc(%d0, %d1) : memref<?x?xf32>
    // Fill the temporary with 0.0, then use it as the matmul output.
12    linalg.fill(%0, %cst) : memref<?x?xf32>, f32
13    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
14      outs(%0 : memref<?x?xf32>)
    // Elementwise consumer of the matmul result: the 1-D operand %arg2 is
    // indexed by d1 only, the 2-D operands by (d0, d1); the result goes to
    // %arg3.
15    linalg.generic
16      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
17                        affine_map<(d0, d1) -> (d1)>,
18                        affine_map<(d0, d1) -> (d0, d1)>],
19       iterator_types = ["parallel", "parallel"]}
20      ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
21      outs(%arg3 : memref<?x?xf32>) {
22      ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
        // Sum of the two inputs; the incoming output element %arg6 is unused.
23        %5 = addf %arg4, %arg5 : f32
24        linalg.yield %5 : f32
25      }
26    return
27  }
28}
29
30//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
31//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
32//       CHECK: func @three_op_fusion
33//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
34//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
35//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
36//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
37//       CHECK:   %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
38//       CHECK:   scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
39//   CHECK-DAG:     %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]]
40//   CHECK-DAG:     %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]]
41//   CHECK-DAG:     %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]]
42//   CHECK-DAG:     %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
43//   CHECK-DAG:     %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]]
44//       CHECK:     linalg.fill(%[[SV_TEMP]], %{{.+}})
45//       CHECK:     linalg.matmul
46//  CHECK-SAME:       ins(%[[SV_ARG0]], %[[SV_ARG1]]
47//  CHECK-SAME:         : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
48//  CHECK-SAME:       outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
49//       CHECK:     linalg.generic
50//  CHECK-SAME:       ins(%[[SV_TEMP]], %[[SV_ARG2]]
51//  CHECK-SAME:         : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
52//  CHECK-SAME:       outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
53//       CHECK:     scf.yield
54//       CHECK:   }
55
56// -----
57
// Test input: three matmuls in sequence, each preceded by a fill of its
// output with 0.0. The first two write to temporaries (%0, %1) that feed the
// next matmul; the last writes to %arg4. The FileCheck expectations that
// follow this module look for all six Linalg ops inside one single-induction-
// variable scf.parallel loop with step 16.
58module {
59  func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
60                           %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
61			   %arg4: memref<?x?xf32>) {
62    %cst = constant 0.000000e+00 : f32
63    %c0 = constant 0 : index
64    %c1 = constant 1 : index
    // Sizes for the temporaries: dimension 0 of %arg0 and dimension 1 of
    // %arg1 / %arg2. NOTE: %n3 is defined here but not used by any op below.
65    %m = dim %arg0, %c0 : memref<?x?xf32>
66    %n1 = dim %arg1, %c1 : memref<?x?xf32>
67    %n2 = dim %arg2, %c1 : memref<?x?xf32>
68    %n3 = dim %arg3, %c1 : memref<?x?xf32>
69    %0 = alloc(%m, %n1) : memref<?x?xf32>
70    %1 = alloc(%m, %n2) : memref<?x?xf32>
    // Stage 1: %0 receives the matmul of %arg0 and %arg1.
71    linalg.fill(%0, %cst) : memref<?x?xf32>, f32
72    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
73      outs(%0 : memref<?x?xf32>)
    // Stage 2: %1 receives the matmul of %0 and %arg2.
74    linalg.fill(%1, %cst) : memref<?x?xf32>, f32
75    linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
76      outs(%1 : memref<?x?xf32>)
    // Stage 3: %arg4 receives the matmul of %1 and %arg3.
77    linalg.fill(%arg4, %cst) : memref<?x?xf32>, f32
78    linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
79      outs(%arg4 : memref<?x?xf32>)
80    return
81  }
82}
83
84//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
85//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
86//       CHECK: func @sequence_of_matmul
87//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
88//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
89//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
90//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
91//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
92//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
93//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
94//   CHECK-DAG:   %[[C16:.+]] = constant 16 : index
95//   CHECK-DAG:   %[[M:.+]] = dim %[[ARG0]], %[[C0]]
96//   CHECK-DAG:   %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
97//   CHECK-DAG:   %[[N2:.+]] = dim %[[ARG2]], %[[C1]]
98//       CHECK:   %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]])
99//       CHECK:   %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]])
100//       CHECK:   scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
101//  CHECK-SAME:     step (%[[C16]]) {
102//       CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
103//       CHECK:     %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0]
104//  CHECK-SAME:       [%[[TILE_M]], %[[N2]]]
105//       CHECK:     %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
106//       CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
107//       CHECK:     %[[N3:.+]] = dim %[[ARG4]], %[[C1]]
108//       CHECK:     %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0]
109//  CHECK-SAME:       [%[[TILE_M_2]], %[[N3]]]
110//       CHECK:     %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
111//  CHECK-SAME:       [%[[TILE_M]], %[[N3]]]
112//       CHECK:     %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0]
113//  CHECK-SAME:       [%[[TILE_M]], %[[N1]]]
114//       CHECK:     %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
115//       CHECK:     %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
116//       CHECK:     %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
117//  CHECK-SAME:       [%[[TILE_M:.+]], %[[N0]]]
118//       CHECK:     %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
119//       CHECK:     linalg.fill(%[[SV_ALLOC1]], %{{.+}})
120//       CHECK:     linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
121//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
122//  CHECK-SAME:        outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
123//       CHECK:     linalg.fill(%[[SV_ALLOC2]], %{{.+}})
124//       CHECK:     linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
125//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
126//  CHECK-SAME:        outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
127//       CHECK:     linalg.fill(%[[SV_ARG4_2]], %{{.+}})
128//       CHECK:     linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
129//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
130//  CHECK-SAME:        outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
131//       CHECK:     scf.yield
132//       CHECK:   }
133
134