// RUN: mlir-opt %s -scf-bufferize | FileCheck %s

// Bufferization of scf.if: the tensor-typed yields in both branches are
// rewritten to yield memrefs, and the final result is loaded back to a tensor.
// CHECK-LABEL:   func @if(
// CHECK-SAME:             %[[PRED:.*]]: i1,
// CHECK-SAME:             %[[TRUE_TENSOR:.*]]: tensor<?xf32>,
// CHECK-SAME:             %[[FALSE_TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
// CHECK:           %[[RESULT_MEMREF:.*]] = scf.if %[[PRED]] -> (memref<?xf32>) {
// CHECK:             %[[TRUE_MEMREF:.*]] = tensor_to_memref %[[TRUE_TENSOR]] : memref<?xf32>
// CHECK:             scf.yield %[[TRUE_MEMREF]] : memref<?xf32>
// CHECK:           } else {
// CHECK:             %[[FALSE_MEMREF:.*]] = tensor_to_memref %[[FALSE_TENSOR]] : memref<?xf32>
// CHECK:             scf.yield %[[FALSE_MEMREF]] : memref<?xf32>
// CHECK:           }
// CHECK:           %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT_MEMREF:.*]] : memref<?xf32>
// CHECK:           return %[[RESULT_TENSOR]] : tensor<?xf32>
// CHECK:         }
func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
  %0 = scf.if %pred -> (tensor<?xf32>) {
    scf.yield %true_val : tensor<?xf32>
  } else {
    scf.yield %false_val : tensor<?xf32>
  }
  return %0 : tensor<?xf32>
}

// Bufferization of scf.for: the tensor iter_arg becomes a memref iter_arg,
// with tensor_load/tensor_to_memref materializations inserted around the
// loop body and after the loop.
// CHECK-LABEL:   func @for(
// CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME:              %[[LB:.*]]: index, %[[UB:.*]]: index,
// CHECK-SAME:              %[[STEP:.*]]: index) -> tensor<f32> {
// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
// CHECK:           %[[RESULT_MEMREF:.*]] = scf.for %[[VAL_6:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
// CHECK:             %[[TENSOR_ITER:.*]] = tensor_load %[[ITER]] : memref<f32>
// CHECK:             %[[MEMREF_YIELDED:.*]] = tensor_to_memref %[[TENSOR_ITER]] : memref<f32>
// CHECK:             scf.yield %[[MEMREF_YIELDED]] : memref<f32>
// CHECK:           }
// CHECK:           %[[VAL_8:.*]] = tensor_load %[[VAL_9:.*]] : memref<f32>
// CHECK:           return %[[VAL_8]] : tensor<f32>
// CHECK:         }
func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
  %ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor<f32> {
    scf.yield %iter : tensor<f32>
  }
  return %ret : tensor<f32>
}

// Check whether this converts at all.
//
// It would previously fail altogether.
// CHECK-LABEL:   func @if_correct_recursive_legalization_behavior
// CHECK:           "test.munge_tensor"
func @if_correct_recursive_legalization_behavior(%pred: i1, %tensor: tensor<f32>) -> tensor<f32> {
  %0 = scf.if %pred -> (tensor<f32>) {
    %1 = "test.munge_tensor"(%tensor) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %1 : tensor<f32>
  } else {
    %1 = "test.munge_tensor"(%tensor) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %1 : tensor<f32>
  }
  return %0 : tensor<f32>
}

// Ops inside the loop body that the pass does not itself convert
// ("test.munge_tensor") must still be legalized recursively via the
// inserted tensor_load/tensor_to_memref materializations.
// CHECK-LABEL:   func @for_correct_recursive_legalization_behavior(
// CHECK-SAME:                                                      %[[TENSOR:.*]]: tensor<f32>,
// CHECK-SAME:                                                      %[[INDEX:.*]]: index) -> tensor<f32> {
// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
// CHECK:           %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[INDEX]] to %[[INDEX]] step %[[INDEX]] iter_args(%[[MEMREF_ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
// CHECK:             %[[TENSOR_ITER:.*]] = tensor_load %[[MEMREF_ITER]] : memref<f32>
// CHECK:             %[[TENSOR_MUNGED:.*]] = "test.munge_tensor"(%[[TENSOR_ITER]]) : (tensor<f32>) -> tensor<f32>
// CHECK:             %[[MEMREF_MUNGED:.*]] = tensor_to_memref %[[TENSOR_MUNGED]] : memref<f32>
// CHECK:             scf.yield %[[MEMREF_MUNGED]] : memref<f32>
// CHECK:           }
// CHECK:           %[[TENSOR:.*]] = tensor_load %[[RESULT:.*]] : memref<f32>
// CHECK:           return %[[TENSOR]] : tensor<f32>
// CHECK:         }
func @for_correct_recursive_legalization_behavior(%arg0: tensor<f32>, %index: index) -> tensor<f32> {
  %ret = scf.for %iv = %index to %index step %index iter_args(%iter = %arg0) -> tensor<f32> {
    %0 = "test.munge_tensor"(%iter) : (tensor<f32>) -> (tensor<f32>)
    scf.yield %0 : tensor<f32>
  }
  return %ret : tensor<f32>
}