// Regression tests for the -lhlo-fuse-linalg pass. Each RUN line exercises a
// different fusion strategy over the same inputs:
//   * default flags        -> fusion into unit-step scf.for nests  (prefix CHECK)
//   * tile-sizes=2,3       -> fusion into tiled scf.for nests      (prefix TILED)
//   * use-parallel-loops   -> fusion into one scf.parallel loop    (prefix PLOOP)
// RUN: mlir-hlo-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP

// Two pointwise producers (add feeding mul) through a temp buffer fuse into a
// single loop nest containing both generics.
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
             %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
  %temp_result = alloc() : memref<6x6xf32>
  linalg.generic #pointwise_2d_trait
    ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
    outs(%temp_result : memref<6x6xf32>) {
  ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
    %out = addf %summand_1_in, %summand_2_in : f32
    linalg.yield %out : f32
  }
  linalg.generic #pointwise_2d_trait
    ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
    outs(%result : memref<6x6xf32>) {
  ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
    %out = mulf %temp_result_in, %multiplier_in : f32
    linalg.yield %out : f32
  }
  dealloc %temp_result : memref<6x6xf32>
  return
}
// CHECK-LABEL: func @fusion
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf

// TILED-LABEL: func @fusion
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf

// PLOOP-LABEL: func @fusion
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf

// -----

// A three-op chain (broadcast -> sub -> exp) through two temp buffers fuses
// into a single loop nest containing all three generics.
func @fusion_of_three(%arg0: memref<100x10xf32>,
                      %arg1: memref<100xf32>,
                      %arg2: memref<100x10xf32>) {
  %0 = alloc() : memref<100x10xf32>
  linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg1 : memref<100xf32>)
    outs(%0 : memref<100x10xf32>) {
  ^bb0(%arg3: f32, %arg4: f32): // no predecessors
    linalg.yield %arg3 : f32
  }
  %1 = alloc() : memref<100x10xf32>
  linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
    outs(%1 : memref<100x10xf32>) {
  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
    %2 = subf %arg3, %arg4 : f32
    linalg.yield %2 : f32
  }
  dealloc %0 : memref<100x10xf32>
  linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%1 : memref<100x10xf32>)
    outs(%arg2 : memref<100x10xf32>) {
  ^bb0(%arg3: f32, %arg4: f32): // no predecessors
    %2 = math.exp %arg3 : f32
    linalg.yield %2 : f32
  }
  dealloc %1 : memref<100x10xf32>
  return
}
// CHECK-LABEL: func @fusion_of_three
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: linalg.generic
// CHECK: subf
// CHECK: linalg.generic
// CHECK: exp

// TILED-LABEL: func @fusion_of_three
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: linalg.generic
// TILED: subf
// TILED: linalg.generic
// TILED: exp

// PLOOP-LABEL: func @fusion_of_three
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: linalg.generic
// PLOOP: subf
// PLOOP: linalg.generic
// PLOOP: exp

// -----

// Same add->mul chain as above but on rank-4 buffers; with tile sizes 2,3 only
// the two leading dimensions are tiled into loops.
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel", "parallel", "parallel",
                                         "parallel"]}
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
                %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
  %temp_result = alloc() : memref<6x6x6x6xf32>
  linalg.generic #pointwise_4d_trait
    ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
    outs(%temp_result : memref<6x6x6x6xf32>) {
  ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
    %out = addf %summand_1_in, %summand_2_in : f32
    linalg.yield %out : f32
  }
  linalg.generic #pointwise_4d_trait
    ins(%temp_result, %multiplier : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>)
    outs(%result : memref<6x6x6x6xf32>) {
  ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
    %out = mulf %temp_result_in, %multiplier_in : f32
    linalg.yield %out : f32
  }
  dealloc %temp_result : memref<6x6x6x6xf32>
  return
}
// CHECK-LABEL: func @fusion_4d
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf

// TILED-LABEL: func @fusion_4d
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf

// PLOOP-LABEL: func @fusion_4d
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf

// -----

// Variant of the first test where the fused result buffer is allocated locally
// and returned instead of being an output argument.
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
             %summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
  %temp_result = alloc() : memref<6x6xf32>
  linalg.generic #pointwise_2d_trait
    ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>)
    outs(%temp_result : memref<6x6xf32>) {
  ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
    %out = addf %summand_1_in, %summand_2_in : f32
    linalg.yield %out : f32
  }
  %result = alloc() : memref<6x6xf32>
  linalg.generic #pointwise_2d_trait
    ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>)
    outs(%result : memref<6x6xf32>) {
  ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
    %out = mulf %temp_result_in, %multiplier_in : f32
    linalg.yield %out : f32
  }
  dealloc %temp_result : memref<6x6xf32>
  return %result : memref<6x6xf32>
}

// CHECK-LABEL: func @fusion
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf

// TILED-LABEL: func @fusion
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf

// PLOOP-LABEL: func @fusion
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf

// -----

// Confirm that fusion still fires when the generic's output is only consumed
// through a memref_reshape (a view-like op) before being returned.
func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
    -> memref<*xf32> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %1 = alloc(%arg2) : memref<?xf32>
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                   affine_map<(d0) -> (d0)>],
                  iterator_types = ["parallel"]}
    ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32): // no predecessors
    %13 = absf %arg3 : f32
    linalg.yield %13 : f32
  }
  %2 = memref_reshape %1(%arg1)
      : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
  return %2 : memref<*xf32>
}

// CHECK-LABEL: func @view_result
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: memref_reshape

// TILED-LABEL: func @view_result
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: memref_reshape

// PLOOP-LABEL: func @view_result
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: memref_reshape

// -----

// Confirm that tiling information is passed through RegionBranchOpInterfaces.
// This test also uses memref_reshape, just to have a value to return through
// the if statement.
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
    -> memref<*xf32> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %1 = alloc(%arg2) : memref<?xf32>
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                   affine_map<(d0) -> (d0)>],
                  iterator_types = ["parallel"]}
    ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
  ^bb0(%arg3: f32, %arg4: f32): // no predecessors
    %13 = absf %arg3 : f32
    linalg.yield %13 : f32
  }
  %true = constant 1 : i1
  %3 = scf.if %true -> memref<*xf32> {
    %2 = memref_reshape %1(%arg1)
        : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
    scf.yield %2 : memref<*xf32>
  } else {
    %2 = memref_reshape %1(%arg1)
        : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
    scf.yield %2 : memref<*xf32>
  }
  return %3 : memref<*xf32>
}

// CHECK-LABEL: func @branching_result
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: scf.if
// CHECK: memref_reshape
// CHECK: scf.yield
// CHECK: else
// CHECK: memref_reshape
// CHECK: scf.yield

// TILED-LABEL: func @branching_result
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: scf.if
// TILED: memref_reshape
// TILED: scf.yield
// TILED: else
// TILED: memref_reshape
// TILED: scf.yield

// PLOOP-LABEL: func @branching_result
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: scf.if
// PLOOP: memref_reshape
// PLOOP: scf.yield
// PLOOP: else
// PLOOP: memref_reshape
// PLOOP: scf.yield

// -----

// Confirm that tiling information is passed through tensor_load, tensor.cast
// and memref_to_tensor operations.
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
    -> memref<?xf32> {
  %c1 = constant 1 : index
  %1 = alloc() : memref<32xf32>
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                   affine_map<(d0) -> (d0)>],
                  iterator_types = ["parallel"]}
    ins(%arg0 : memref<32xf32>) outs(%1 : memref<32xf32>) {
  ^bb0(%arg3: f32, %arg4: f32): // no predecessors
    %13 = absf %arg3 : f32
    linalg.yield %13 : f32
  }
  %2 = tensor_load %1 : memref<32xf32>
  %3 = tensor.cast %2 : tensor<32xf32> to tensor<?xf32>
  %4 = tensor_to_memref %3 : memref<?xf32>
  return %4 : memref<?xf32>
}

// CHECK-LABEL: func @tensor_ops
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: tensor_load
// CHECK: tensor.cast
// CHECK: tensor_to_memref

// TILED-LABEL: func @tensor_ops
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: tensor_load
// TILED: tensor.cast
// TILED: tensor_to_memref

// PLOOP-LABEL: func @tensor_ops
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: tensor_load
// PLOOP: tensor.cast
// PLOOP: tensor_to_memref