1// RUN: mlir-opt -split-input-file %s | FileCheck %s
2// | mlir-opt | FileCheck %s
3
4// TODO: Re-enable LLVM lowering test after IndexedGenericOp is lowered.
5//
// Test that we can lower all the way to LLVM without crashing; results are not checked here.
7// DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
8
// Parser/printer round-trip for linalg.range built from three index operands.
func @range(%arg0: index, %arg1: index, %arg2: index) {
  %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
  return
}
13// CHECK-LABEL: func @range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
14//  CHECK-NEXT:  linalg.range %{{.*}} : %{{.*}} : %{{.*}} : !linalg.range
15
16// -----
17
18// CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
19
// Round-trip std.view and linalg.slice: a rank-preserving slice (two ranges),
// two rank-reducing slices (range + index), a rank-0 slice (two indices),
// and a view with a vector element type.
func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
  %c0 = constant 0 : index
  %0 = muli %arg0, %arg0 : index
  %1 = alloc (%0) : memref<?xi8>
  %2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
  %3 = view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xf32>
  %4 = linalg.slice %3[%2, %2] :
    memref<?x?xf32>,
    !linalg.range,
    !linalg.range,
    memref<?x?xf32>
  %5 = linalg.slice %3[%2, %arg2] : memref<?x?xf32>,
                                    !linalg.range,
                                    index,
                                    memref<?xf32, offset: ?, strides: [1]>
  %6 = linalg.slice %3[%arg2, %2] : memref<?x?xf32>,
                                    index,
                                    !linalg.range,
                                    memref<?xf32, offset: ?, strides: [1]>
  %7 = linalg.slice %3[%arg2, %arg3] : memref<?x?xf32>,
                                       index,
                                       index,
                                       memref<f32>
  %8 = view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xvector<4x4xf32>>
  dealloc %1 : memref<?xi8>
  return
}
47// CHECK-LABEL: func @views
48//  CHECK:  muli %{{.*}}, %{{.*}} : index
49//  CHECK-NEXT:  alloc(%{{.*}}) : memref<?xi8>
50//  CHECK-NEXT:  range
51//  CHECK-NEXT:  std.view %{{.*}}[%{{.*}}][%{{.*}}] :
52//  CHECK-SAME:     memref<?xi8> to memref<?x?xf32>
53//  CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] :
54//  CHECK-SAME:     memref<?x?xf32>,
55//  CHECK-SAME:     !linalg.range,
56//  CHECK-SAME:     !linalg.range,
57//  CHECK-SAME:     memref<?x?xf32>
58//  CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] :
59//  CHECK-SAME:     memref<?x?xf32>,
60//  CHECK-SAME:     !linalg.range,
61//  CHECK-SAME:     index,
62//  CHECK-SAME:     memref<?xf32, #[[$strided1D]]>
63//  CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] :
64//  CHECK-SAME:     memref<?x?xf32>,
65//  CHECK-SAME:     index,
66//  CHECK-SAME:     !linalg.range,
67//  CHECK-SAME:     memref<?xf32, #[[$strided1D]]>
68//  CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] :
69//  CHECK-SAME:     memref<?x?xf32>,
70//  CHECK-SAME:     index,
71//  CHECK-SAME:     index,
72//  CHECK-SAME:     memref<f32>
73//  CHECK-NEXT:  view %{{.*}}[%{{.*}}][%{{.*}}] :
74//  CHECK-SAME:     memref<?xi8> to memref<?x?xvector<4x4xf32>>
75//  CHECK-NEXT:  dealloc %{{.*}} : memref<?xi8>
76
77// -----
78
79// CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
80// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
81
// Round-trip the named contraction ops (matmul, matvec, dot) in their
// ins/outs buffer form on strided memrefs.
func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
          %arg1: memref<?xf32, offset: ?, strides: [1]>,
          %arg2: memref<?xf32, offset: ?, strides: [1]>,
          %arg3: memref<f32>) {
  linalg.matmul ins(%arg0, %arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>,
                                   memref<?x?xf32, offset: ?, strides: [?, 1]>)
               outs(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
  linalg.matvec ins(%arg0, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
                                  memref<?xf32, offset: ?, strides: [1]>)
               outs(%arg2: memref<?xf32, offset: ?, strides: [1]>)
  linalg.dot ins(%arg1, %arg2: memref<?xf32, offset: ?, strides: [1]>,
                               memref<?xf32, offset: ?, strides: [1]>)
            outs(%arg3: memref<f32>)
  return
}
97// CHECK-LABEL: func @ops(%
98// CHECK: linalg.matmul
99// CHECK-SAME:   ins(%{{.*}}, %{{.*}} : memref<?x?xf32, #[[$strided2D]]>,
100// CHECK-SAME:                          memref<?x?xf32, #[[$strided2D]]>)
101// CHECK-SAME:  outs(%{{.*}} : memref<?x?xf32, #[[$strided2D]]>)
102// CHECK: linalg.matvec
103// CHECK-SAME:   ins(%{{.*}}, %{{.*}}: memref<?x?xf32, #[[$strided2D]]>,
104// CHECK-SAME:                         memref<?xf32, #[[$strided1D]]>)
105// CHECK-SAME:  outs(%{{.*}}: memref<?xf32, #[[$strided1D]]>)
106// CHECK: linalg.dot
107// CHECK-SAME:   ins(%{{.*}}, %{{.*}}: memref<?xf32, #[[$strided1D]]>,
108// CHECK-SAME:                         memref<?xf32, #[[$strided1D]]>)
109// CHECK-SAME:  outs(%{{.*}}: memref<f32>)
110
111// -----
112
113// CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
114
// Round-trip linalg.fill writing a scalar f32 into a 1-D strided memref.
func @fill_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: f32) {
  linalg.fill(%arg0, %arg1) : memref<?xf32, offset: ?, strides: [1]>, f32
  return
}
119// CHECK-LABEL: func @fill_view(
120//       CHECK:  %{{.*}}: memref<?xf32, #[[$strided1D]]>, %{{.*}}: f32) {
121//       CHECK:   linalg.fill(%{{.*}}, %{{.*}}) : memref<?xf32, #[[$strided1D]]>, f32
122
123// -----
124
125// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
126// CHECK-DAG: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
127
// Round-trip transpose with permutation (i, j, k) -> (k, j, i); the result
// type carries the correspondingly permuted strided affine-map layout.
func @transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
  %0 = transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>>
  return
}
132// CHECK-LABEL: func @transpose
133//       CHECK:   transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
134//  CHECK-SAME:      memref<?x?x?xf32, #[[$strided3D]]> to memref<?x?x?xf32, #[[$strided3DT]]>
135
136// -----
137
138// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
139
// Round-trip linalg.fill on a 3-D strided memref.
func @fill_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg1: f32) {
  linalg.fill(%arg0, %arg1) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, f32
  return
}
144// CHECK-LABEL: func @fill_view3(
145//       CHECK:  %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>, %{{.*}}: f32) {
146//       CHECK:   linalg.fill(%{{.*}}, %{{.*}}) : memref<?x?x?xf32, #[[$strided3D]]>, f32
147
148// -----
149
150// CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
151
// Round-trip linalg.copy between two 1-D strided memrefs (no permutations).
func @copy_view(%arg0: memref<?xf32, offset: ?, strides: [1]>,
                %arg1: memref<?xf32, offset: ?, strides: [1]>) {
  linalg.copy(%arg0, %arg1) : memref<?xf32, offset: ?, strides: [1]>,
                              memref<?xf32, offset: ?, strides: [1]>
  return
}
158// CHECK-LABEL: func @copy_view(
159//       CHECK:   linalg.copy(%{{.*}}, %{{.*}}) :
160//  CHECK-SAME:     memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>
161
162// -----
163
164// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
165// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
166// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
167
// Round-trip linalg.copy with inputPermutation / outputPermutation
// affine-map attributes on 3-D strided memrefs.
func @copy_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
                 %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
  linalg.copy(%arg0, %arg1) {inputPermutation = affine_map<(i, j, k) -> (i, k, j)>,
                             outputPermutation = affine_map<(i, j, k) -> (k, j, i)>} :
    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
  return
}
175// CHECK-LABEL: func @copy_view3(
176//       CHECK:  %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>, %{{.*}}: memref<?x?x?xf32, #[[$strided3D]]>) {
177//       CHECK:   linalg.copy(%{{.*}}, %{{.*}}) {
178//  CHECK-SAME:     inputPermutation = #[[$map0]],
179//  CHECK-SAME:     outputPermutation = #[[$map1]]} :
180//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]>,
181//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]>
182
183// -----
184
185// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
186
// Round-trip linalg.conv on 3-D strided memrefs with no explicit
// strides/dilations attributes.
func @conv_view3(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
                 %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
                 %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
  linalg.conv(%arg0, %arg1, %arg2) : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
                                     memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
                                     memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
  return
}
195// CHECK-LABEL: func @conv_view3(
196//       CHECK:   linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) :
197//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]>,
198//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]>,
199//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]>
200
201// -----
202
203// CHECK-DAG: #[[$strided6D:.*]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5)>
204
// Round-trip linalg.conv on 6-D strided memrefs with explicit dilations
// and strides attributes.
func @conv_view6(%arg0: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>,
                 %arg1: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>,
                 %arg2: memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>) {
  linalg.conv(%arg0, %arg1, %arg2) {dilations = [4, 4, 5, 5], strides = [2, 2, 3, 3]} :
    memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>,
    memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>,
    memref<?x?x?x?x?x?xf32, offset: ?, strides: [?, ?, ?, ?, ?, 1]>
  return
}
214// CHECK-LABEL: func @conv_view6(
215//       CHECK:   linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) {
216//  CHECK-SAME:     dilations = [4, 4, 5, 5], strides = [2, 2, 3, 3]} :
217//  CHECK-SAME:     memref<?x?x?x?x?x?xf32, #[[$strided6D]]>,
218//  CHECK-SAME:     memref<?x?x?x?x?x?xf32, #[[$strided6D]]>,
219//  CHECK-SAME:     memref<?x?x?x?x?x?xf32, #[[$strided6D]]>
220
221// -----
222
// Round-trip linalg.conv carrying a dense<2x2xi64> padding attribute
// alongside dilations and strides.
func @conv_padding(%arg0: memref<?x?x?x?xf32>,
                   %arg1: memref<?x?x?x?xf32>,
                   %arg2: memref<?x?x?x?xf32>) {
  linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1],
                                    padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>,
                                    strides = [1, 1]} :
    memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
  return
}
232
233// CHECK-LABEL: func @conv_padding(
234//       CHECK:   linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) {
235//  CHECK-SAME:     dilations = [1, 1],
236//  CHECK-SAME:     padding = dense<[
237//  CHECK-SAME:                      [0, 1], [1, 1]]> : tensor<2x2xi64>,
238//  CHECK-SAME:     strides = [1, 1]} :
239//  CHECK-SAME:     memref<?x?x?x?xf32>,
240//  CHECK-SAME:     memref<?x?x?x?xf32>,
241//  CHECK-SAME:     memref<?x?x?x?xf32>
242
243// -----
244
// Round-trip linalg.pooling_max with an explicit strides attribute.
func @pooling_max(%arg0: memref<?x?x?xf32>,
                  %arg1: memref<?x?x?xi32>,
                  %arg2: memref<?x?x?xf32>) {
  linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
    memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
  return
}
252// CHECK-LABEL: func @pooling_max
253//       CHECK:   linalg.pooling_max(%{{.*}}, %{{.*}}, %{{.*}})
254//  CHECK-SAME:   {strides = [2, 1, 2]}
255//  CHECK-SAME:   memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
256
257// -----
258
// Round-trip linalg.pooling_min with an explicit strides attribute.
func @pooling_min(%arg0: memref<?x?x?xf32>,
                  %arg1: memref<?x?x?xi32>,
                  %arg2: memref<?x?x?xf32>) {
  linalg.pooling_min(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
    memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
  return
}
266// CHECK-LABEL: func @pooling_min
267//       CHECK:   linalg.pooling_min(%{{.*}}, %{{.*}}, %{{.*}})
268//  CHECK-SAME:   {strides = [2, 1, 2]}
269//  CHECK-SAME:   memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
270
271// -----
272
// Round-trip linalg.pooling_sum with an explicit strides attribute.
func @pooling_sum(%arg0: memref<?x?x?xf32>,
                  %arg1: memref<?x?x?xi32>,
                  %arg2: memref<?x?x?xf32>) {
  linalg.pooling_sum(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
    memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
  return
}
280// CHECK-LABEL: func @pooling_sum
281//       CHECK:   linalg.pooling_sum(%{{.*}}, %{{.*}}, %{{.*}})
282//  CHECK-SAME:   {strides = [2, 1, 2]}
283//  CHECK-SAME:   memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
284
285// -----
286
287// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
288// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
289
// Indexing maps for the generic tests in this split section: a 2-D access
// and a 3-D access whose last dimension uses the expression i + j.
#accesses = [
  affine_map<(i, j, k) -> (j, i)>,
  affine_map<(i, j, k) -> (i, k, i + j)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel", "parallel", "parallel"],
  library_call = "some_external_function_name_1"
}

// Round-trip linalg.generic referencing the trait alias, with an extra
// attrs dictionary and a region yielding a constant.
func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
              %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
  linalg.generic #trait
      ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
      outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
      attrs = {foo = 1} {
    ^bb(%0: vector<3x4xi4>, %1: f32) :
      %f0 = constant 0.0 : f32
      linalg.yield %f0 : f32
  }
  return
}
313// CHECK-LABEL: func @generic
314//       CHECK:   linalg.generic {
315//  CHECK-SAME:     indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
316//  CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel"],
317//  CHECK-SAME:     library_call = "some_external_function_name_1"}
318//  CHECK-SAME:     ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
319//  CHECK-SAME:     outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
320//  CHECK-SAME:     {foo = 1 : i64}
321
// Round-trip linalg.generic whose input is a tensor while the output
// remains a memref.
func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
                                %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
  linalg.generic #trait
      ins(%arg0 : tensor<?x?xvector<3x4xi4>>)
      outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
      attrs = {foo = 1} {
    ^bb(%0: vector<3x4xi4>, %1: f32) :
      %f0 = constant 0.0 : f32
      linalg.yield %f0 : f32
  }
  return
}
334// CHECK-LABEL: func @generic_with_tensor_input
335//       CHECK:   linalg.generic {
336//  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
337//  CHECK-SAME:     library_call = "some_external_function_name_1"}
338//  CHECK-SAME:     ins({{.*}} : tensor<?x?xvector<3x4xi4>>)
339//  CHECK-SAME:     outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
340//  CHECK-SAME:     {foo = 1 : i64}
341
342// -----
343
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// Round-trip linalg.generic with no inputs: only an outs operand.
func @generic_without_inputs(%arg0 : memref<?x?x?xf32>) {
  linalg.generic  {indexing_maps = [#map0],
                   iterator_types = ["parallel", "parallel", "parallel"]}
                  outs(%arg0 : memref<?x?x?xf32>) {
   ^bb0(%arg3: f32):  // no predecessors
      %cst = constant 0.000000e+00 : f32
      linalg.yield %cst : f32
    }
  return
}
355
356// CHECK-LABEL: func @generic_without_inputs
357//       CHECK:   linalg.generic
358//   CHECK-NOT:     ins
359
360// -----
361
// Three indexing maps / trait shared with the tensor-result generic test.
#accesses = [
  affine_map<(i, j, k) -> (j, i)>,
  affine_map<(i, j, k) -> (i, k, i + j)>,
  affine_map<(i, j, k) -> (i, k, i + j)>
]

#trait2 = {
  indexing_maps = #accesses,
  iterator_types = ["parallel", "parallel", "parallel"],
  library_call = "some_external_function_name_1"
}

// Round-trip linalg.generic taking tensor inputs and returning a tensor
// result via "-> tensor<...>".
func @generic_with_tensor_input_and_output(
    %arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
    -> (tensor<?x?x?xf32>) {
  %0 = linalg.generic #trait2
      ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
      attrs = {foo = 1} {
    ^bb(%0: vector<3x4xi4>, %1: f32) :
      %f0 = constant 0.0 : f32
      linalg.yield %f0 : f32
  } -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
386// CHECK-LABEL: func @generic_with_tensor_input_and_output
387//       CHECK:   linalg.generic {
388//  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
389//  CHECK-SAME:     library_call = "some_external_function_name_1"}
390//  CHECK-SAME:     ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
391//  CHECK-SAME:     {foo = 1 : i64}
392//       CHECK:     -> tensor<?x?x?xf32>
393//       CHECK:   return {{.*}} : tensor<?x?x?xf32>
394
395// -----
396
// Same maps/trait as the previous split section, re-declared here because
// -split-input-file isolates each section.
#accesses = [
  affine_map<(i, j, k) -> (j, i)>,
  affine_map<(i, j, k) -> (i, k, i + j)>,
  affine_map<(i, j, k) -> (i, k, i + j)>
]

#trait2 = {
  indexing_maps = #accesses,
  iterator_types = ["parallel", "parallel", "parallel"],
  library_call = "some_external_function_name_1"
}

// Round-trip linalg.indexed_generic with tensor input/output; the region's
// leading block arguments are the three loop indices.
func @indexed_generic_with_tensor_input_and_output(
    %arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
    -> (tensor<?x?x?xf32>) {
  %0 = linalg.indexed_generic #trait2
      ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
      attrs = {foo = 1} {
    ^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) :
      %f0 = constant 0.0 : f32
      linalg.yield %f0 : f32
  } -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
421// CHECK-LABEL: func @indexed_generic_with_tensor_input_and_output
422//       CHECK:   linalg.indexed_generic {
423//  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
424//  CHECK-SAME:     library_call = "some_external_function_name_1"}
425//  CHECK-SAME:     ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
426//  CHECK-SAME:     {foo = 1 : i64}
427//       CHECK:     -> tensor<?x?x?xf32>
428//       CHECK:   return {{.*}} : tensor<?x?x?xf32>
429
430// -----
431
// Broadcast maps: a rank-0 input mapped to every (i, j) output point.
#broadcast_access = [
  affine_map<(i, j) -> ()>,
  affine_map<(i, j) -> (i, j)>
]

#trait_broadcast = {
  indexing_maps = #broadcast_access,
  iterator_types = ["parallel", "parallel"],
  library_call = "some_broadcast_external_fn"
}

// Round-trip linalg.generic with a zero-rank tensor input.
func @generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
{
  %0 = linalg.generic #trait_broadcast
      ins(%arg0 : tensor<f32>) {
    ^bb(%a: f32) :
      linalg.yield %a : f32
  } -> tensor<3x4xf32>
  return %0 : tensor<3x4xf32>
}
452
// Round-trip linalg.indexed_generic with a zero-rank tensor input; the
// region takes the two loop indices before the scalar element.
func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
{
  %0 = linalg.indexed_generic #trait_broadcast
      ins(%arg0 : tensor<f32>) {
    ^bb(%i: index, %j: index, %a: f32) :
      linalg.yield %a : f32
  } -> tensor<3x4xf32>
  return %0 : tensor<3x4xf32>
}
462
463// -----
464
465// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
466// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
467
// Maps and trait for the region-printing generic tests below.
#accesses = [
  affine_map<(i, j, k) -> (j, i)>,
  affine_map<(i, j, k) -> (i, k, i + j)>
]

#trait3 = {
  indexing_maps = #accesses,
  iterator_types = ["parallel", "parallel", "parallel"],
  library_call = "some_external_function_name_2"
}

// Round-trip linalg.generic with an explicit region that yields one of its
// own block arguments.
func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                     %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
  linalg.generic #trait3
      ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
      outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
      attrs = {foo = 1} {
    ^bb(%a: vector<3x4xi4>, %b: f32) :
      linalg.yield %b : f32
  }
  return
}
490// CHECK-LABEL: func @generic_region
491//       CHECK:   linalg.generic {
492//  CHECK-SAME:     indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
493//  CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel"],
494//  CHECK-SAME:     library_call = "some_external_function_name_2"
495//  CHECK-SAME:     ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
496//  CHECK-SAME:     outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
497//  CHECK-SAME:     attrs = {foo = 1 : i64} {
498//       CHECK:  ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
499//       CHECK:    linalg.yield %{{.*}} : f32
500
// Round-trip linalg.indexed_generic on memrefs; its region takes one index
// per loop dimension before the element arguments.
func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                      %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
  linalg.indexed_generic #trait3
      ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
      outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
      attrs = {foo = 1} {
    ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
      linalg.yield %b : f32
  }
  return
}
512// CHECK-LABEL: func @indexed_generic
513//       CHECK:   linalg.indexed_generic {
514//  CHECK-SAME:     indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
515//  CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel"],
516//  CHECK-SAME:     library_call = "some_external_function_name_2"
517//  CHECK-SAME:      ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
518//  CHECK-SAME:     outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
519//  CHECK-SAME:     {foo = 1 : i64}
520//       CHECK:    ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
521//       CHECK:      linalg.yield %{{.*}} : f32
522//       CHECK:    }
523
524// -----
525
526// CHECK-DAG: #[[$reshapeD01:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
527// CHECK-DAG: #[[$reshapeD2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
528// CHECK-DAG: #[[$reshapeD0:.*]] = affine_map<(d0, d1, d2) -> (d0)>
529// CHECK-DAG: #[[$reshapeD12:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
530// CHECK-DAG: #[[$reshapeD012:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
531// CHECK-DAG: #[[$reshape5D01:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
532// CHECK-DAG: #[[$reshape5D2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
533// CHECK-DAG: #[[$reshape5D34:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
534
// Round-trip linalg.reshape / linalg.tensor_reshape on static shapes:
// collapse-then-expand pairs, a full collapse to 1-D, unit-dimension
// expansion, and the tensor variants (including one dynamic dimension).
func @reshape_static(%arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>, %arg2: tensor<3x?x5xf32>) {
  // Reshapes that collapse and expand back a contiguous buffer.
  %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>,
                             affine_map<(i, j, k) -> (k)>] :
    memref<3x4x5xf32> into memref<12x5xf32>
  %r0 = linalg.reshape %0 [affine_map<(i, j, k) -> (i, j)>,
                           affine_map<(i, j, k) -> (k)>] :
    memref<12x5xf32> into memref<3x4x5xf32>
  %1 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i)>,
                             affine_map<(i, j, k) -> (j, k)>] :
    memref<3x4x5xf32> into memref<3x20xf32>
  %r1 = linalg.reshape %1 [affine_map<(i, j, k) -> (i)>,
                           affine_map<(i, j, k) -> (j, k)>] :
    memref<3x20xf32> into memref<3x4x5xf32>
  %2 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j, k)>] :
    memref<3x4x5xf32> into memref<60xf32>
  %r2 = linalg.reshape %2 [affine_map<(i, j, k) -> (i, j, k)>] :
    memref<60xf32> into memref<3x4x5xf32>
  // Reshapes that expand and collapse back a contiguous buffer with some 1's.
  %3 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
                             affine_map<(i, j, k, l, m) -> (k)>,
                             affine_map<(i, j, k, l, m) -> (l, m)>] :
    memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
  %r3 = linalg.reshape %3 [affine_map<(i, j, k, l, m) -> (i, j)>,
                           affine_map<(i, j, k, l, m) -> (k)>,
                           affine_map<(i, j, k, l, m) -> (l, m)>] :
    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
  // Reshapes on tensors.
  %t0 = linalg.tensor_reshape %arg1 [affine_map<(i, j, k, l, m) -> (i, j)>,
                                     affine_map<(i, j, k, l, m) -> (k)>,
                                     affine_map<(i, j, k, l, m) -> (l, m)>] :
    tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
  %rt0 = linalg.tensor_reshape %t0 [affine_map<(i, j, k, l, m) -> (i, j)>,
                                   affine_map<(i, j, k, l, m) -> (k)>,
                                   affine_map<(i, j, k, l, m) -> (l, m)>] :
    tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
  %t1 = linalg.tensor_reshape %arg2 [affine_map<(i, j, k, l, m) -> (i, j)>,
                                     affine_map<(i, j, k, l, m) -> (k)>,
                                     affine_map<(i, j, k, l, m) -> (l, m)>] :
    tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
  %rt1 = linalg.tensor_reshape %t1 [affine_map<(i, j, k, l, m) -> (i)>,
                                    affine_map<(i, j, k, l, m) -> (j, k)>,
                                    affine_map<(i, j, k, l, m) -> (l, m)>] :
    tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
  return
}
581// CHECK-LABEL: func @reshape_static
582//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
583//  CHECK-SAME:     memref<3x4x5xf32> into memref<12x5xf32>
584//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
585//  CHECK-SAME:     memref<12x5xf32> into memref<3x4x5xf32>
586//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD0]], #[[$reshapeD12]]]
587//  CHECK-SAME:     memref<3x4x5xf32> into memref<3x20xf32>
588//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD0]], #[[$reshapeD12]]]
589//  CHECK-SAME:     memref<3x20xf32> into memref<3x4x5xf32>
590//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD012]]]
591//  CHECK-SAME:     memref<3x4x5xf32> into memref<60xf32>
592//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD012]]]
593//  CHECK-SAME:     memref<60xf32> into memref<3x4x5xf32>
594//       CHECK:   linalg.reshape {{.*}} [#[[$reshape5D01]], #[[$reshape5D2]], #[[$reshape5D34]]]
595//  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
596//       CHECK:   linalg.reshape {{.*}} [#[[$reshape5D01]], #[[$reshape5D2]], #[[$reshape5D34]]]
597//  CHECK-SAME:     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
598//
599//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
600//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
601//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
602//       CHECK:   linalg.tensor_reshape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
603
604// -----
605
606// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
607// CHECK-DAG: #[[$strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
608// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>
609// CHECK-DAG: #[[$strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)>
610
// Round-trip linalg.reshape on dynamically-shaped memrefs with identity,
// offset-0 strided, and dynamic-offset strided layouts.
func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
                      %arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
                      %arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>) {
  %0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>,
                             affine_map<(i, j, k) -> (k)>] :
    memref<?x?x?xf32> into memref<?x?xf32>
  %r0 = linalg.reshape %0 [affine_map<(i, j, k) -> (i, j)>,
                           affine_map<(i, j, k) -> (k)>] :
    memref<?x?xf32> into memref<?x?x?xf32>
  %1 = linalg.reshape %arg1 [affine_map<(i, j, k) -> (i, j)>,
                             affine_map<(i, j, k) -> (k)>] :
    memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
    memref<?x?xf32, offset : 0, strides : [?, 1]>
  %r1 = linalg.reshape %1 [affine_map<(i, j, k) -> (i, j)>,
                           affine_map<(i, j, k) -> (k)>] :
    memref<?x?xf32, offset : 0, strides : [?, 1]> into
    memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>
  %2 = linalg.reshape %arg2 [affine_map<(i, j, k) -> (i, j)>,
                             affine_map<(i, j, k) -> (k)>] :
    memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
    memref<?x?xf32, offset : ?, strides : [?, 1]>
  %r2 = linalg.reshape %2 [affine_map<(i, j, k) -> (i, j)>,
                           affine_map<(i, j, k) -> (k)>] :
    memref<?x?xf32, offset : ?, strides : [?, 1]> into
    memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>
  return
}
638
639// CHECK-DAG: #[[$reshapeD01:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
640// CHECK-DAG: #[[$reshapeD2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
641
642// CHECK-LABEL: func @reshape
643//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
644//  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
645//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
646//  CHECK-SAME:     memref<?x?xf32> into memref<?x?x?xf32>
647//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
648//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
649//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
650//  CHECK-SAME:     memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x?x?xf32, #[[$strided3DOFF0]]>
651//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
652//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
653//       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
654//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]> into memref<?x?x?xf32, #[[$strided3D]]>
655
// Round-trip linalg.batch_matmul in its operand flavors: buffers only,
// tensor inputs with a buffer output, and tensor results via init operands
// (including mixed tensor/buffer inputs).
func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
                %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
{
  linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
                     outs(%c3: memref<?x?x?xf32>)
  linalg.batch_matmul ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
                     outs(%c3: memref<?x?x?xf32>)
  %res1 = linalg.batch_matmul ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
                     init(%tc3: tensor<?x?x?xf32>)
                  -> tensor<?x?x?xf32>
  %res2 = linalg.batch_matmul ins(%ta3, %b3: tensor<?x?x?xf32>, memref<?x?x?xf32>)
                     init(%tc3: tensor<?x?x?xf32>)
                  -> tensor<?x?x?xf32>
  return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
672// CHECK-LABEL: func @named_ops
673//       CHECK:   linalg.batch_matmul
674//       CHECK:   linalg.batch_matmul
675//       CHECK:   linalg.batch_matmul
676//       CHECK:   linalg.batch_matmul
677
678// -----
679
// Round-trip linalg.tensor_reshape with an empty reassociation:
// tensor<1x1xf32> collapses to a rank-0 tensor and expands back.
func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>) -> (tensor<f32>, tensor<1x1xf32>)
{
  %0 = linalg.tensor_reshape %arg0 [] : tensor<1x1xf32> into tensor<f32>
  %1 = linalg.tensor_reshape %0 [] : tensor<f32> into tensor<1x1xf32>
  return %0, %1 : tensor<f32>, tensor<1x1xf32>
}
686// CHECK-LABEL: func @tensor_reshape_zero_dim
687//       CHECK:   linalg.tensor_reshape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
688//       CHECK:   linalg.tensor_reshape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
689
690// -----
691
// Round-trip linalg.reshape with an empty reassociation:
// memref<1x1xf32> collapses to a rank-0 memref and expands back.
func @memref_reshape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>) -> (memref<f32>, memref<1x1xf32>)
{
  %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
  %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
  return %0, %1 : memref<f32>, memref<1x1xf32>
}
698// CHECK-LABEL: func @memref_reshape_zero_dim
699//       CHECK:   linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
700//       CHECK:   linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>
701