// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s

// A fill + matmul pair produces %0, which is consumed by an elementwise
// linalg.generic that adds the 1-D %arg2 (indexed by d1) to every row of %0.
// The consumer is tiled and both producers are expected to be fused into the
// same 2-D scf.parallel loop, with every op operating on subviews.
module {
  func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
                        %arg2: memref<?xf32>, %arg3 : memref<?x?xf32>) {
    %cst = constant 0.000000e+00 : f32
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %d0 = dim %arg0, %c0 : memref<?x?xf32>
    %d1 = dim %arg1, %c1 : memref<?x?xf32>
    %0 = alloc(%d0, %d1) : memref<?x?xf32>
    linalg.fill(%0, %cst) : memref<?x?xf32>, f32
    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
      outs(%0 : memref<?x?xf32>)
    linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
      ins(%0, %arg2 : memref<?x?xf32>, memref<?xf32>)
      outs(%arg3 : memref<?x?xf32>) {
      ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) :
        %5 = addf %arg4, %arg5 : f32
        linalg.yield %5 : f32
    }
    return
  }
}

// Layout maps attached to the 2-D and 1-D subview types below.
//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
//       CHECK: func @three_op_fusion
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?xf32>
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// The temporary stays a full-size alloc outside the loop; inside the loop
// every operand is sliced by a subview offset with the loop IVs.
//       CHECK:   %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
//       CHECK:   scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
//   CHECK-DAG:     %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]]
//   CHECK-DAG:     %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]]
//   CHECK-DAG:     %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]]
//   CHECK-DAG:     %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
//   CHECK-DAG:     %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]]
//       CHECK:     linalg.fill(%[[SV_TEMP]], %{{.+}})
//       CHECK:     linalg.matmul
//  CHECK-SAME:       ins(%[[SV_ARG0]], %[[SV_ARG1]]
//  CHECK-SAME:         : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
//  CHECK-SAME:       outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
//       CHECK:     linalg.generic
//  CHECK-SAME:       ins(%[[SV_TEMP]], %[[SV_ARG2]]
//  CHECK-SAME:         : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
//  CHECK-SAME:       outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
//       CHECK:     scf.yield
//       CHECK:   }

// -----

// Three fill + matmul pairs chained through %0 and %1 into %arg4. The expected
// output is a single scf.parallel loop with step 16 into which all three
// fill + matmul pairs are fused; %arg3 is read by the last matmul without a
// subview.
module {
  func @sequence_of_matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
                           %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
                           %arg4: memref<?x?xf32>) {
    %cst = constant 0.000000e+00 : f32
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %m = dim %arg0, %c0 : memref<?x?xf32>
    %n1 = dim %arg1, %c1 : memref<?x?xf32>
    %n2 = dim %arg2, %c1 : memref<?x?xf32>
    %0 = alloc(%m, %n1) : memref<?x?xf32>
    %1 = alloc(%m, %n2) : memref<?x?xf32>
    linalg.fill(%0, %cst) : memref<?x?xf32>, f32
    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
      outs(%0 : memref<?x?xf32>)
    linalg.fill(%1, %cst) : memref<?x?xf32>, f32
    linalg.matmul ins(%0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
      outs(%1 : memref<?x?xf32>)
    linalg.fill(%arg4, %cst) : memref<?x?xf32>, f32
    linalg.matmul ins(%1, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
      outs(%arg4 : memref<?x?xf32>)
    return
  }
}

//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
//       CHECK: func @sequence_of_matmul
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
//   CHECK-DAG:   %[[C16:.+]] = constant 16 : index
//   CHECK-DAG:   %[[M:.+]] = dim %[[ARG0]], %[[C0]]
//   CHECK-DAG:   %[[N1:.+]] = dim %[[ARG1]], %[[C1]]
//   CHECK-DAG:   %[[N2:.+]] = dim %[[ARG2]], %[[C1]]
//       CHECK:   %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]])
//       CHECK:   %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]])
//       CHECK:   scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
//  CHECK-SAME:     step (%[[C16]]) {
//       CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
//       CHECK:     %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0]
//  CHECK-SAME:       [%[[TILE_M]], %[[N2]]]
//       CHECK:     %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
//       CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
//       CHECK:     %[[N3:.+]] = dim %[[ARG4]], %[[C1]]
//       CHECK:     %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0]
//  CHECK-SAME:       [%[[TILE_M_2]], %[[N3]]]
//       CHECK:     %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
//  CHECK-SAME:       [%[[TILE_M]], %[[N3]]]
//       CHECK:     %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0]
//  CHECK-SAME:       [%[[TILE_M]], %[[N1]]]
//       CHECK:     %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
//       CHECK:     %[[N0:.+]] = dim %[[ARG0]], %[[C1]]
// The ARG0 tile's row extent must be the TILE_M value captured above (it is
// bounded by M, which is dim 0 of ARG0), so match that value rather than
// re-capturing the variable, which would accept any operand.
//       CHECK:     %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0]
//  CHECK-SAME:       [%[[TILE_M]], %[[N0]]]
//       CHECK:     %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
//       CHECK:     linalg.fill(%[[SV_ALLOC1]], %{{.+}})
//       CHECK:     linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
//  CHECK-SAME:        outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
//       CHECK:     linalg.fill(%[[SV_ALLOC2]], %{{.+}})
//       CHECK:     linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
//  CHECK-SAME:        outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
//       CHECK:     linalg.fill(%[[SV_ARG4_2]], %{{.+}})
//       CHECK:     linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
//  CHECK-SAME:        : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
//  CHECK-SAME:        outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
//       CHECK:     scf.yield
//       CHECK:   }