1// RUN: mlir-opt %s -scf-bufferize | FileCheck %s
2
// scf.if with tensor results. Each branch converts its yielded tensor with
// tensor_to_memref, and the scf.if itself yields memrefs. The function
// signature stays in tensor form, so the scf.if result is turned back into a
// tensor with tensor_load before the return.
// CHECK-LABEL:   func @if(
// CHECK-SAME:             %[[PRED:.*]]: i1,
// CHECK-SAME:             %[[TRUE_TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME:             %[[FALSE_TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[RESULT_MEMREF:.*]] = scf.if %[[PRED]] -> (memref<?xf32>) {
// CHECK:             %[[TRUE_MEMREF:.*]] = tensor_to_memref %[[TRUE_TENSOR]] : memref<?xf32>
// CHECK:             scf.yield %[[TRUE_MEMREF]] : memref<?xf32>
// CHECK:           } else {
// CHECK:             %[[FALSE_MEMREF:.*]] = tensor_to_memref %[[FALSE_TENSOR]] : memref<?xf32>
// CHECK:             scf.yield %[[FALSE_MEMREF]] : memref<?xf32>
// CHECK:           }
// Use (do not re-bind) RESULT_MEMREF so the load is tied to the scf.if result.
// CHECK:           %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT_MEMREF]] : memref<?xf32>
// CHECK:           return %[[RESULT_TENSOR]] : tensor<?xf32>
// CHECK:         }
func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
  %0 = scf.if %pred -> (tensor<?xf32>) {
    scf.yield %true_val : tensor<?xf32>
  } else {
    scf.yield %false_val : tensor<?xf32>
  }
  return %0 : tensor<?xf32>
}
25
// scf.for with a tensor iter_arg. The init value is converted with
// tensor_to_memref, and the loop carries a memref. Inside the body the
// iter_arg is loaded back into a tensor and reconverted for the yield. The
// loop result is loaded into a tensor for the tensor-typed return.
// CHECK-LABEL:   func @for(
// CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME:              %[[LB:.*]]: index, %[[UB:.*]]: index,
// CHECK-SAME:              %[[STEP:.*]]: index) -> tensor<f32> {
// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
// CHECK:           %[[RESULT_MEMREF:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
// CHECK:             %[[TENSOR_ITER:.*]] = tensor_load %[[ITER]] : memref<f32>
// CHECK:             %[[MEMREF_YIELDED:.*]] = tensor_to_memref %[[TENSOR_ITER]] : memref<f32>
// CHECK:             scf.yield %[[MEMREF_YIELDED]] : memref<f32>
// CHECK:           }
// Use (do not re-bind) RESULT_MEMREF so the load is tied to the scf.for result.
// CHECK:           %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT_MEMREF]] : memref<f32>
// CHECK:           return %[[RESULT_TENSOR]] : tensor<f32>
// CHECK:         }
func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
  %ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor<f32> {
    scf.yield %iter : tensor<f32>
  }
  return %ret : tensor<f32>
}
45
// Regression test: this input used to fail to convert at all. Both branches of
// the scf.if contain a non-scf op ("test.munge_tensor") operating on tensors.
// The check below is only a smoke test: the pass must run, and the branch
// bodies must still be present in the output.
// CHECK-LABEL:   func @if_correct_recursive_legalization_behavior
// CHECK: "test.munge_tensor"
func @if_correct_recursive_legalization_behavior(%cond: i1, %input: tensor<f32>) -> tensor<f32> {
  %result = scf.if %cond -> (tensor<f32>) {
    %munged = "test.munge_tensor"(%input) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %munged : tensor<f32>
  } else {
    %munged = "test.munge_tensor"(%input) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %munged : tensor<f32>
  }
  return %result : tensor<f32>
}
61
// scf.for whose body contains a non-scf op on tensors. The loop carries a
// memref. The body loads the iter_arg into a tensor for "test.munge_tensor"
// and converts the munged tensor back to a memref for the yield. The loop
// result is loaded into a tensor for the tensor-typed return.
// CHECK-LABEL:   func @for_correct_recursive_legalization_behavior(
// CHECK-SAME:                                                      %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME:                                                      %[[INDEX:.*]]: index) -> tensor<f32> {
// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
// CHECK:           %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[INDEX]] to %[[INDEX]] step %[[INDEX]] iter_args(%[[MEMREF_ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
// CHECK:             %[[TENSOR_ITER:.*]] = tensor_load %[[MEMREF_ITER]] : memref<f32>
// CHECK:             %[[TENSOR_MUNGED:.*]] = "test.munge_tensor"(%[[TENSOR_ITER]]) : (tensor<f32>) -> tensor<f32>
// CHECK:             %[[MEMREF_MUNGED:.*]] = tensor_to_memref %[[TENSOR_MUNGED]] : memref<f32>
// CHECK:             scf.yield %[[MEMREF_MUNGED]] : memref<f32>
// CHECK:           }
// Bind a fresh name for the loaded tensor (TENSOR is the function argument),
// and use the RESULT capture so the load is tied to the scf.for result.
// CHECK:           %[[TENSOR_RESULT:.*]] = tensor_load %[[RESULT]] : memref<f32>
// CHECK:           return %[[TENSOR_RESULT]] : tensor<f32>
// CHECK:         }
func @for_correct_recursive_legalization_behavior(%arg0: tensor<f32>, %index: index) -> tensor<f32> {
  %ret = scf.for %iv = %index to %index step %index iter_args(%iter = %arg0) -> tensor<f32> {
    %0 = "test.munge_tensor"(%iter) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %0 : tensor<f32>
  }
  return %ret : tensor<f32>
}
82