// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FILECHECK_OPTS="" FileCheck %s
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s | FileCheck %s --check-prefix CHLO --dump-input-filter=all
// This test runs twice:
//   1. Through FILECHECK_OPTS="" FileCheck with chlo legalization disabled,
//      since verifying the chlo ops that are emitted produces more useful
//      tests.
//   2. With chlo legalization enabled, verifying diagnostics to pick up any
//      issues with the full lowering (this can catch some broadcasting corner
//      cases which are emitted with a warning).

//===----------------------------------------------------------------------===//
// BatchNorm op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: fusedBatchNorm_notraining
func @fusedBatchNorm_notraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}
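
// In inference mode (is_training = false), FusedBatchNorm maps directly onto
// "mhlo.batch_norm_inference": epsilon carries over as an attribute and
// feature_index is derived from data_format (3 for NHWC). Training mode for
// the V1 op is not legalized yet; see the TODO below.
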
// CHECK-LABEL: fusedBatchNorm_training
func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // TODO(riverriddle) Support training.
  // CHECK: "tf.FusedBatchNorm"
  %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same
// code), so only do a couple of basic checks.

// CHECK-LABEL: fusedBatchNormV2_noTraining
func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV2_training
func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: mhlo.constant
  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  return %0#0 : tensor<8x8x8x8xf32>
}
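
// In training mode the V2/V3 ops lower to "mhlo.batch_norm_training"; the
// tuple result is unpacked with "mhlo.get_tuple_element", and the batch
// variance (tuple index 2) is rescaled with a broadcast multiply. The factor
// is derived ahead of the exponential_avg_factor test below.
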
// CHECK-LABEL: fusedBatchNormV3_noTraining
func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision
// CHECK-SAME: ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>)
func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) {
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK: [[Y:%.*]] = "mhlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)
  // CHECK: [[Y_CONVERT:%.*]] = "mhlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK: [[DUMMY:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<0xf32>
  // CHECK: [[DUMMY_CAST:%.*]] = tensor.cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32>
  // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]]
  return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training
func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: mhlo.constant
  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  return %0#0 : tensor<8x8x8x8xf32>
}
// CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance
func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: return %[[VAR]]
  return %0#4 : tensor<8xf32>
}
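
// The next test uses exponential_avg_factor = 0.8, so the running stats are
// blended as new = (1 - 0.8) * old + 0.8 * batch, i.e. ALPHA = 0.2 and
// BETA = 0.8. The batch variance is first rescaled by N / (N - 1), the usual
// biased-to-unbiased correction, with N = 8 * 8 * 8 = 512 reduced elements:
// 512 / 511 = 1.00195694, the FACTOR constant checked below.
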
// CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor
func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK-DAG: %[[BATCH_MEAN:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32}
  // CHECK-DAG: %[[BATCH_VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32}

  // CHECK: %[[FACTOR:.*]] = mhlo.constant dense<1.00195694>
  // CHECK: %[[CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]]

  // CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988>
  // CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01>

  // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3
  // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]]
  // CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]]

  // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4
  // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]]
  // CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]]

  // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]]
  return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision
func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK: "mhlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormV3_NCHW
func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}
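
// data_format selects the feature dimension: feature_index is 1 for NCHW
// above versus 3 for NHWC. The next three tests show that inference legalizes
// even with fully dynamic shapes, while the training lowering bails out and
// keeps tf.FusedBatchNormV3 when dimensions are dynamic.
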
// CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported
func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x?x?x?xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
  return %0#0 : tensor<?x?x?x?xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported1
func @fusedBatchNormV3_training_dynamic_unsupported1(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) {
  // CHECK: tf.FusedBatchNormV3
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
  return %0#0 : tensor<?x?x?x?xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported2
func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor<?x6x?x?xf32>, %arg1: tensor<6xf32>, %arg2: tensor<6xf32>, %arg3: tensor<6xf32>, %arg4: tensor<6xf32>) -> (tensor<?x6x?x?xf32>) {
  // CHECK: tf.FusedBatchNormV3
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>)
  return %0#0 : tensor<?x6x?x?xf32>
}

// CHECK-LABEL: fusedBatchNormGrad_noTraining
func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):  // no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):  // no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}
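
// The inference-mode gradient above is expanded inline rather than lowered to
// "mhlo.batch_norm_grad". In terms of the checked values:
//   scr1            = rsqrt(variance + epsilon)
//   scr2            = sum(grad * (x - mean))   (reduced over dims [0, 1, 2])
//   x_backprop      = grad * broadcast(scale * scr1)
//   scale_backprop  = scr1 * scr2
//   offset_backprop = sum(grad)
// The training-mode gradients below lower to a single "mhlo.batch_norm_grad"
// whose tuple result is unpacked.
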
"mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> 200 // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> 201 // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>> 202 // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32> 203 // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> 204 // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> 205 // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> 206 // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32> 207 208 %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) 209 return %0#0 : tensor<8x8x8x8xf32> 210} 211 212// CHECK-LABEL: fusedBatchNormGradV2_noTraining 213func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { 214 // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> 215 // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> 216 // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32> 217 218 // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> 219 // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> 220 221 // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> 222 // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> 223 // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> 224 // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64> 225 // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> 226 // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 227 // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( { 228 // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>): // no predecessors 229 // CHECK-NEXT: %[[reduced:.*]] = mhlo.add %arg5, %arg6 : tensor<f32> 230 // CHECK-NEXT: "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> () 231 // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> 232 // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> 233 234 // CHECK-NEXT: %[[mul2:.*]] 
// CHECK-LABEL: fusedBatchNormGradV2_noTraining
func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):  // no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):  // no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}
// CHECK-LABEL: fusedBatchNormGradV2_Training
func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}
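
// In the mixed-precision gradient variants below, the bf16 activation is
// converted to f32 before the computation and only x_backprop is converted
// back to bf16; scale_backprop and offset_backprop stay in f32.
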
// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision
func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>

  // CHECK: %[[x_backprop:.*]] = "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision
func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV3_noTraining
func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):  // no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):  // no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}
// CHECK-LABEL: fusedBatchNormGradV3_Training
func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK: return %[[x_backprop]]
  // CHECK-SAME: tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>)
  return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>
}

// CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision
func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>

  // CHECK: %[[x_backprop:.*]] = "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision
func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW
func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):  // no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK: %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):  // no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW
func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %{{.*}} = "mhlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %{{.*}}) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

//===----------------------------------------------------------------------===//
// Bias op legalizations.
//===----------------------------------------------------------------------===//
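
// tf.BiasAdd lowers shape-polymorphically: the bias is broadcast to the
// operand's shape (computed via shape.shape_of / shape.to_extent_tensor and
// "mhlo.dynamic_broadcast_in_dim") along the feature dimension implied by
// data_format (3 for NHWC and the default, 1 for NCHW), then added.
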
// CHECK-LABEL: func @biasAdd_default
func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  return %0 : tensor<1x32x10x32xi32>
}

// CHECK-LABEL: func @biasAdd_NHWC
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  return %0 : tensor<1x32x10x32xi32>
}

// CHECK-LABEL: func @biasAdd_NCHW
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
  return %0 : tensor<1x32x10x32xi32>
}

// CHECK-LABEL: func @biasAdd_dynamic
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
  // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
  %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
  return %0 : tensor<?x?x?x?xi32>
}

//===----------------------------------------------------------------------===//
// ClipByValue
//===----------------------------------------------------------------------===//
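
// tf.ClipByValue becomes "mhlo.clamp"(min, operand, max); when the bounds are
// scalar and the operand is not, the bounds are first broadcast to the
// operand's shape.
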
// CHECK-LABEL: @clip
func @clip(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<f32> {
  // CHECK: [[VAL:%.+]] = "mhlo.clamp"(%arg1, %arg0, %arg2)

  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  // CHECK: return [[VAL]]
  return %0 : tensor<f32>
}

// CHECK-LABEL: @clip_dynamic
func @clip_dynamic(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"(%arg1, %arg0, %arg2)
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

  // CHECK: return [[CLAMP]]
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: @clip_static_broadcast
func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<5xf32> {
  // CHECK-DAG: [[SHP:%.+]] = mhlo.constant dense<5>
  // CHECK-DAG: [[SHPIDX:%.+]] = tensor.cast [[SHP]]
  // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"([[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]])
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<5xf32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>

  // CHECK: return [[CLAMP]]
  return %0 : tensor<5xf32>
}

// CHECK-LABEL: @clip_dynamic_broadcast
func @clip_dynamic_broadcast(%arg0 : tensor<?xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<?xf32> {
  // CHECK-DAG: [[SHP:%.+]] = shape.shape_of %arg0
  // CHECK-DAG: [[EXT:%.+]] = shape.to_extent_tensor [[SHP]]
  // CHECK-DAG: [[SHPIDX:%.+]] = index_cast [[EXT]]
  // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"([[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]])
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<?xf32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>

  // CHECK: return [[CLAMP]]
  return %0 : tensor<?xf32>
}

//===----------------------------------------------------------------------===//
// DiagPart
//===----------------------------------------------------------------------===//
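
// tf.DiagPart on a <4x3x4x3> input reshapes to a square <12x12> matrix
// (12 = 4 * 3), zeroes out everything off the main diagonal with an
// iota/compare/select, sums along dimension 0, and reshapes the <12> result
// back to <4x3>.
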
// CHECK-LABEL: func @diag_part
// CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32>
func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> {
  // CHECK: %[[RS:.*]] = "mhlo.reshape"(%[[ARG]]) : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<12x12xi32>
  // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<12x12xi32>
  // CHECK-DAG: %[[COMP:.*]] = "mhlo.compare"(%[[IOTA0]], %[[IOTA1]]) {comparison_direction = "EQ"} : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1>
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) {broadcast_sizes = dense<12> : tensor<2xi64>} : (tensor<f32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[SEL:.*]] = "mhlo.select"(%[[COMP]], %[[RS]], %[[ZERO_MAT]]) : (tensor<12x12xi1>, tensor<12x12xf32>, tensor<12x12xf32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[RED:.*]] = "mhlo.reduce"(%[[SEL]], %[[ZERO]])
  // CHECK-DAG: mhlo.add
  // CHECK-DAG: {dimensions = dense<0> : tensor<1xi64>} : (tensor<12x12xf32>, tensor<f32>) -> tensor<12xf32>
  // CHECK-DAG: %[[RES:.*]] = "mhlo.reshape"(%[[RED]]) : (tensor<12xf32>) -> tensor<4x3xf32>
  // CHECK-DAG: return %[[RES]] : tensor<4x3xf32>
  %0 = "tf.DiagPart"(%arg0) : (tensor<4x3x4x3xf32>) -> tensor<4x3xf32>
  return %0: tensor<4x3xf32>
}

//===----------------------------------------------------------------------===//
// MatrixDiagPart
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @matrix_diag_part
// CHECK-SAME: %[[ARG:.*]]: tensor<7x140x128xi32>
func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  // CHECK-DAG: %[[V0:.*]] = mhlo.constant dense<42> : tensor<i32>
  // CHECK-DAG: %[[V1:.*]] = mhlo.constant dense<[-10, 11]> : tensor<2xi32>
  // CHECK-DAG: %[[V2:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V3:.*]] = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V4:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK-DAG: %[[V5:.*]] = "mhlo.broadcast"(%[[V4]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V6:.*]] = mhlo.constant dense<false> : tensor<i1>
  // CHECK-DAG: %[[V7:.*]] = "mhlo.broadcast"(%[[V6]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V8:.*]] = mhlo.constant dense<true> : tensor<i1>
  // CHECK-DAG: %[[V9:.*]] = "mhlo.broadcast"(%[[V8]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V10:.*]] = mhlo.constant dense<11> : tensor<i32>
  // CHECK-DAG: %[[V11:.*]] = "mhlo.broadcast"(%[[V10]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V12:.*]] = mhlo.constant dense<140> : tensor<i32>
  // CHECK-DAG: %[[V13:.*]] = "mhlo.broadcast"(%[[V12]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V14:.*]] = mhlo.constant dense<128> : tensor<i32>
  // CHECK-DAG: %[[V15:.*]] = "mhlo.broadcast"(%[[V14]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V16:.*]] = mhlo.constant dense<128> : tensor<i32>
  // CHECK-DAG: %[[V17:.*]] = "mhlo.broadcast"(%[[V16]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V18:.*]] = mhlo.subtract %[[V11]], %[[V2]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V19:.*]] = "mhlo.negate"(%[[V18]]) : (tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V20:.*]] = mhlo.minimum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V21:.*]] = mhlo.add %[[V13]], %[[V20]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V22:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V23:.*]] = mhlo.subtract %[[V15]], %[[V22]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V24:.*]] = mhlo.minimum %[[V21]], %[[V23]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V25:.*]] = chlo.broadcast_compare %[[V18]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V26:.*]] = mhlo.subtract %[[V17]], %[[V24]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V27:.*]] = "mhlo.select"(%[[V25]], %[[V26]], %[[V5]]) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V28:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V29:.*]] = mhlo.subtract %[[V28]], %[[V27]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V30:.*]] = mhlo.maximum %[[V19]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V31:.*]] = mhlo.subtract %[[V30]], %[[V27]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V32:.*]] = mhlo.add %[[V3]], %[[V29]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V33:.*]] = mhlo.add %[[V3]], %[[V31]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V34:.*]] = chlo.broadcast_compare %[[V32]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V35:.*]] = chlo.broadcast_compare %[[V32]], %[[V15]] {comparison_direction = "LT"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V36:.*]] = mhlo.and %[[V34]], %[[V35]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V37:.*]] = chlo.broadcast_compare %[[V33]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V38:.*]] = chlo.broadcast_compare %[[V33]], %[[V13]] {comparison_direction = "LT"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V39:.*]] = mhlo.and %[[V37]], %[[V38]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V40:.*]] = mhlo.and %[[V36]], %[[V39]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V41:.*]] = "mhlo.reshape"(%[[V40]]) : (tensor<1x22x128xi1>) -> tensor<22x128xi1>
  // CHECK-DAG: %[[V42:.*]] = "mhlo.concatenate"(%[[V33]], %[[V32]]) {dimension = 0 : i64} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32>
  // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) {dimension_numbers = {collapsed_slice_dims = dense<[1, 2]> : tensor<2xi64>, index_vector_dim = 0 : i64, offset_dims = dense<0> : tensor<1xi64>, start_index_map = dense<[1, 2]> : tensor<2xi64>}, indices_are_sorted = false, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>} : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32>
  // CHECK-DAG: %[[V44:.*]] = "mhlo.broadcast"(%[[V41]]) {broadcast_sizes = dense<7> : tensor<1xi64>} : (tensor<22x128xi1>) -> tensor<7x22x128xi1>
  // CHECK-DAG: %[[V45:.*]] = "mhlo.broadcast"(%[[V0]]) {broadcast_sizes = dense<[7, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[V46:.*]] = "mhlo.select"(%[[V44]], %[[V43]], %[[V45]]) : (tensor<7x22x128xi1>, tensor<7x22x128xi32>, tensor<7x22x128xi32>) -> tensor<7x22x128xi32>
  // CHECK: return %[[V46]] : tensor<7x22x128xi32>
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  return %2: tensor<7x22x128xi32>
}
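
// The lowering above extracts the 22 diagonals requested by k = [-10, 11]:
// iota ops build per-element (row, col) indices, "mhlo.gather" fetches those
// positions, and "mhlo.select" substitutes the padding value (42) wherever an
// index falls outside the 140x128 matrix. The align attribute chooses the
// select predicate that decides whether short diagonals are packed toward the
// left or the right end; the following tests pin down the single-diagonal
// case and each alignment.
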
// CHECK-LABEL: func @matrix_diag_part_align_ll
func @matrix_diag_part_align_ll(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32> // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "LEFT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[false:.*]] = mhlo.constant dense<false> : tensor<i1>
  // CHECK: %[[b_false:.*]] = "mhlo.broadcast"(%[[false]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_false]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_lr
func @matrix_diag_part_align_lr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32> // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "LEFT_RIGHT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[le:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = "LE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[le]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_rl
func @matrix_diag_part_align_rl(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32> // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[ge:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[ge]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_rr
func @matrix_diag_part_align_rr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32> // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_RIGHT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[true:.*]] = mhlo.constant dense<true> : tensor<i1>
  // CHECK: %[[b_true:.*]] = "mhlo.broadcast"(%[[true]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_true]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_7d
// CHECK: (%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32>
func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> {
  %0 = mhlo.constant dense<-1.> : tensor<f32> // padding value
  %1 = mhlo.constant dense<[-6, -3]> : tensor<2xi32> // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = f32, align = "LEFT_RIGHT"
  } : (tensor<3x5x7x9x11x13x17xf32>, tensor<2xi32>, tensor<f32>) -> tensor<3x5x7x9x11x4x10xf32>
  return %2: tensor<3x5x7x9x11x4x10xf32>
}

//===----------------------------------------------------------------------===//
// Erf
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @erf
func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  // CHECK: chlo.erf %arg0 : tensor<2x3xf32>
  %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

//===----------------------------------------------------------------------===//
// Erfc
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @erfc
func @erfc(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  // CHECK: chlo.erfc %arg0 : tensor<2x3xf32>
  %0 = "tf.Erfc"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

//===----------------------------------------------------------------------===//
// Einsum.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @einsum
func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
  // CHECK: mhlo.einsum
  %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
  return %0: tensor<2x4xf32>
}

// CHECK-LABEL: func @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
  // CHECK: mhlo.unary_einsum
  %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
  return %0: tensor<2x2xf32>
}

//===----------------------------------------------------------------------===//
// FloorDiv and FloorMod.
//===----------------------------------------------------------------------===//
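
// For floating-point operands, floor division is a plain divide followed by
// mhlo.floor. For signed integers there is no direct HLO equivalent, so the
// patterns below reconstruct floor semantics from truncating division (a
// sketch of the logic the checks encode): when both operands have the same
// sign, the truncated quotient already equals the floor; otherwise the
// result is computed as -(|x| + |y| - 1) / |y|, which rounds toward negative
// infinity. For example, floordiv(-7, 2) = -(7 + 1) / 2 = -4, where plain
// truncation would give -3. FloorMod similarly patches up the truncated
// remainder: when it is nonzero and its sign differs from the divisor's, the
// divisor is added back, e.g. floormod(-7, 2) = -1 + 2 = 1.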

// CHECK-LABEL: func @floordiv_broadcast_i32
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
  // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
  // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
  // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
  // CHECK: return [[SELECT]]
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floordiv_reverse_broadcast_i32
func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
  // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
  // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
  // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]]
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
  // CHECK: return [[SELECT]]
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floordiv_f32
func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK-NEXT: %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0
  // CHECK-NEXT: %[[FLOOR:.*]] = "mhlo.floor"(%[[DIV]])
  // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32>
  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  return %0: tensor<2xf32>
}

// CHECK-LABEL: func @floordiv_bf16
func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> {
  // CHECK-NEXT: mhlo.convert
  // CHECK-NEXT: mhlo.convert
  // CHECK-NEXT: chlo.broadcast_divide
  // CHECK-NEXT: mhlo.floor
  // CHECK-NEXT: mhlo.convert
  // CHECK-NEXT: return
  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16>
  return %0: tensor<2xbf16>
}

// CHECK-LABEL: func @floordiv_f16_broadcast
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
  // CHECK-NEXT: chlo.broadcast_divide
  // CHECK-NEXT: mhlo.floor
  // CHECK-NEXT: return
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
  return %0: tensor<2x3xf16>
}

// CHECK-LABEL: func @floordiv_dynamic
func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
  // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
  // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
  // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
  // CHECK: return [[SELECT]]
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
  return %0: tensor<?x?xi32>
}

// CHECK-LABEL: func @floordiv_unranked
func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK-NOT: tf.FloorDiv
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0: tensor<*xf32>
}

// CHECK-LABEL: func @floordiv_int
func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK: tf.FloorDiv
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
  return %0: tensor<*xi32>
}

// CHECK-LABEL: func @floormod_broadcast_numerator
func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]]
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floormod_broadcast_denominator
func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floormod_dynamic
func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
  return %0: tensor<?x?xi32>
}

// CHECK-LABEL: func @floormod_unranked
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK-NOT: tf.FloorMod
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
  return %0: tensor<*xi32>
}

//===----------------------------------------------------------------------===//
// BroadcastTo.
//===----------------------------------------------------------------------===//
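
// Note that even with a fully static output shape, the lowering goes through
// mhlo.dynamic_broadcast_in_dim: the constant shape is materialized, cast,
// and passed as the shape operand, while broadcast_dimensions maps the 1-D
// input onto the trailing dimension of the 4-D result.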

// CHECK-LABEL: func @broadcast_to
func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
  %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32>

  // CHECK: [[CST:%.+]] = mhlo.constant
  // CHECK: [[CAST:%.+]] = tensor.cast [[CST]] : tensor<4xi32> to tensor<4xi32>
  // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg0, [[CAST]])
  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
  %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32>
  return %0 : tensor<16x16x16x16xf32>
}

//===----------------------------------------------------------------------===//
// Complex op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @complex
func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
  // CHECK: chlo.broadcast_complex
  %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
  return %1 : tensor<3xcomplex<f32>>
}

// CHECK-LABEL: func @imag
func @imag(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> {
  // CHECK: "mhlo.imag"
  %1 = "tf.Imag"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
  return %1 : tensor<3xf32>
}

// CHECK-LABEL: func @real
func @real(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> {
  // CHECK: "mhlo.real"
  %1 = "tf.Real"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
  return %1 : tensor<3xf32>
}

//===----------------------------------------------------------------------===//
// Concat op legalizations.
//===----------------------------------------------------------------------===//
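
// ConcatV2 only lowers when the axis operand is a constant and the inputs
// are ranked; a negative axis is normalized against the rank, so dense<-2>
// on rank-2 operands becomes dimension 0 below. The non-constant-axis and
// unranked cases are expected to stay as tf.ConcatV2.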

// CHECK-LABEL: func @concat_v2
func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
  %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32>
  return %1 : tensor<6x3xf32>
}

// CHECK-LABEL: func @concat_v2_neg_axis
func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>

  %axis = "tf.Const"() { value = dense<-2> : tensor<i64> } : () -> tensor<i64>
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32>
  return %1 : tensor<6x3xf32>
}

// CHECK-LABEL: func @concat_v2_1d_axis
func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> {
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>

  %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64>
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32>
  return %1 : tensor<3x6xf32>
}

// CHECK-LABEL: func @concat_v2_non_const_axis
func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %axis: tensor<i64>) -> tensor<3x6xf32> {
  // CHECK: "tf.ConcatV2"
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<3x6xf32>
  return %1 : tensor<3x6xf32>
}

// CHECK-LABEL: func @concat_v2_unranked
func @concat_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
  // CHECK: "tf.ConcatV2"
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<*xf32>, tensor<*xf32>, tensor<i64>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

//===----------------------------------------------------------------------===//
// Pad op legalizations.
//===----------------------------------------------------------------------===//
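
// The Nx2 paddings matrix of tf.PadV2 is split column-wise into the
// edge_padding_low and edge_padding_high attributes of mhlo.pad (i32
// paddings are widened to i64). interior_padding is always zero here, since
// tf.PadV2 has no interior-padding component.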

// CHECK-LABEL: func @padv2_1D
func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor<f32>) -> tensor<6xf32> {
  %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64>
  // CHECK: "mhlo.pad"(%arg0, %arg1) {
  // CHECK-SAME: edge_padding_high = dense<2> : tensor<1xi64>,
  // CHECK-SAME: edge_padding_low = dense<1> : tensor<1xi64>,
  // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64>
  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<6xf32>
  return %1 : tensor<6xf32>
}

// CHECK-LABEL: func @padv2_2D
func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
  %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64>
  // CHECK: "mhlo.pad"(%arg0, %arg1) {
  // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
  // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
  // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64>
  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi64>, tensor<f32>) -> tensor<6x9xf32>
  return %1 : tensor<6x9xf32>
}

// CHECK-LABEL: func @padv2_i32_paddings
func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
  %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32>
  // CHECK: "mhlo.pad"(%arg0, %arg1) {
  // CHECK-SAME: edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
  // CHECK-SAME: edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
  // CHECK-SAME: interior_padding = dense<0> : tensor<2xi64>
  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<6x9xf32>
  return %1 : tensor<6x9xf32>
}

//===----------------------------------------------------------------------===//
// Identity op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @identity
func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK-NEXT: return %arg0 : tensor<1xi32>
  %0 = "tf.Identity"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @identityN
func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) {
  // CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32>
  %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>)
  return %0#0, %0#1: tensor<1xi32>, tensor<1xf32>
}

// CHECK-LABEL: func @stopgradient
func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK-NEXT: return %arg0 : tensor<1xi32>
  %0 = "tf.StopGradient"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @preventgradient
func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK-NEXT: return %arg0 : tensor<1xi32>
  %0 = "tf.PreventGradient"(%arg0) {message = "fin gradients"} : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @checkNumerics
func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  // CHECK-NEXT: return %arg0 : tensor<1xf32>
  %0 = "tf.CheckNumerics"(%arg0) {message = "check numerics"} : (tensor<1xf32>) -> tensor<1xf32>
  return %0: tensor<1xf32>
}

//===----------------------------------------------------------------------===//
// InfeedDequeueTuple legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @infeed_dequeue_tuple
func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) {
// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token
// CHECK: [[INFEED:%.*]] = "mhlo.infeed"([[TOKEN]]) {infeed_config = ""} : (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token>
// CHECK: [[INFEED_VAL:%.*]] = "mhlo.get_tuple_element"([[INFEED]]) {index = 0 : i32} : (tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token>) -> tuple<tensor<3xi32>, tensor<4xf32>>
// CHECK: [[RES_1:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 0 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<3xi32>
// CHECK: [[RES_2:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 1 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<4xf32>
// CHECK: return [[RES_1]], [[RES_2]]
  %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
  return %0#0, %0#1 : tensor<3xi32>, tensor<4xf32>
}

// The following op sharding is used:
// Proto debug string:
//   type: TUPLE
//   tuple_shardings {
//     type: MAXIMAL
//     tile_assignment_dimensions: 1
//     tile_assignment_devices: 0
//   }
// Serialized string:
//   "\08\02*\08\08\01\1A\01\01\22\01\00"

// CHECK-LABEL: infeed_dequeue_tuple_sharding
func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> {
  // CHECK: "mhlo.infeed"
  // An additional sharding is added at the end to account for the token result.
  // Proto debug string:
  //   type: TUPLE
  //   tuple_shardings {
  //     type: MAXIMAL
  //     tile_assignment_dimensions: 1
  //     tile_assignment_devices: 0
  //   }
  //   tuple_shardings {
  //     type: MAXIMAL
  //     tile_assignment_dimensions: 1
  //     tile_assignment_devices: 0
  //   }
  // CHECK-SAME: mhlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00"
  %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
  return %0 : tensor<8xi32>
}

//===----------------------------------------------------------------------===//
// Nullary op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @const
func @const() -> tensor<2xi32> {
  // CHECK: mhlo.constant dense<0> : tensor<2xi32>
  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>)
  return %0: tensor<2xi32>
}

// CHECK-LABEL: @const_dynamic_output
func @const_dynamic_output() -> tensor<*xi32> {
  // CHECK: [[CONST:%.*]] = mhlo.constant dense<0> : tensor<2xi32>
  // CHECK: [[CAST:%.*]] = tensor.cast [[CONST]] : tensor<2xi32> to tensor<*xi32>
  %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>)
  // CHECK: return [[CAST]]
  return %0: tensor<*xi32>
}

// CHECK-LABEL: @opaque_const
func @opaque_const() -> tensor<!tf.variant<tensor<2xi32>>> {
  // CHECK-NOT: mhlo.constant
  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<2xi32>>>
  return %0 : tensor<!tf.variant<tensor<2xi32>>>
}

//===----------------------------------------------------------------------===//
// Matmul op legalizations.
//===----------------------------------------------------------------------===//
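
// MatMul lowers to mhlo.dot; the transpose_a/transpose_b attributes become
// explicit "mhlo.transpose" ops with permutation [1, 0] on the corresponding
// operand, and both dynamically-shaped and unranked operands are accepted.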

// CHECK-LABEL: matmul_notranspose
// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<7x11xf32>)
func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tensor<5x11xf32> {
  // CHECK: "mhlo.dot"(%[[A]], %[[B]])
  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32>

  return %0 : tensor<5x11xf32>
}

// CHECK-LABEL: matmul_transpose_b
// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<11x7xf32>)
func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> {
  // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
  // CHECK: "mhlo.dot"(%[[A]], %[[UPDATED_B]])
  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32>

  return %0 : tensor<5x11xf32>
}

// CHECK-LABEL: matmul_transpose_both
// CHECK-SAME: (%[[A:.*]]: tensor<7x5xf32>, %[[B:.*]]: tensor<11x7xf32>)
func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> {
  // CHECK: %[[UPDATED_A:.*]] = "mhlo.transpose"(%[[A]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
  // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
  // CHECK: "mhlo.dot"(%[[UPDATED_A]], %[[UPDATED_B]])
  %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} : (tensor<7x5xf32>, tensor<11x7xf32>) -> tensor<5x11xf32>

  return %0 : tensor<5x11xf32>
}

// Verify that MatMul with ranked inputs is lowered to HLO.
// CHECK-LABEL: matmul_ranked
func @matmul_ranked(%a: tensor<?x7xf32>, %b: tensor<7x?xf32>) -> tensor<?x?xf32> {
  // CHECK: "mhlo.dot"
  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<?x7xf32>, tensor<7x?xf32>) -> tensor<?x?xf32>

  return %0 : tensor<?x?xf32>
}

// Verify that MatMul with unranked inputs is lowered to HLO.
// CHECK-LABEL: matmul_unranked
func @matmul_unranked(%a: tensor<*xf32>, %b: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: "mhlo.dot"
  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>

  return %0 : tensor<*xf32>
}

// Verify SparseMatMul is legalized to dot.
// CHECK-LABEL: test_sparse_mat_mul
func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> {
  // CHECK: "mhlo.dot"
  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
  return %0: tensor<3x5xf32>
}

// SparseMatMul where one operand needs to be transposed and the other one
// does not.
//
// CHECK-LABEL: @test_sparse_mat_mul_with_transpose
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32>
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[ARG1]])
// CHECK-SAME: permutation = dense<[1, 0]>
// CHECK-SAME: -> tensor<4x5xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: return %[[RESULT]]
func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
  return %0: tensor<3x5xf32>
}

// SparseMatMul where one operand needs to be cast and the other one does not.
//
// CHECK-LABEL: @test_sparse_mat_mul_with_cast
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16>
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: %[[CAST:.*]] = "mhlo.convert"(%[[ARG1]])
// CHECK-SAME: -> tensor<4x5xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
// CHECK-SAME: -> tensor<3x5xf32>
// CHECK: return %[[RESULT]]
func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
  return %0: tensor<3x5xf32>
}

//===----------------------------------------------------------------------===//
// MatrixBandPart op legalizations.
//===----------------------------------------------------------------------===//
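
// A sketch of the band computation the checks below encode: a negative band
// width means "keep everything on that side", so num_lower/num_upper are
// first replaced by the matrix dimensions when negative. Two iotas then give
// per-element row and column numbers, offset = col - row, and an element is
// kept iff -num_lower <= offset <= num_upper; everything outside the band is
// selected to zero.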

// CHECK-LABEL: matrix_band_part
// CHECK-SAME: (%[[INPUT:.*]]: tensor<64x64xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
  // CHECK-DAG: %[[M:.*]] = mhlo.constant dense<64> : tensor<i64>
  // CHECK-DAG: %[[N:.*]] = mhlo.constant dense<64> : tensor<i64>

  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK-DAG: %[[A:.*]] = "mhlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK-DAG: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>

  // CHECK-DAG: %[[C:.*]] = "mhlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK-DAG: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
  // CHECK-DAG: %[[F:.*]] = "mhlo.negate"(%[[B]]) : (tensor<i64>) -> tensor<i64>

  // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xi64>
  // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xi64>
  // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xi64>
  // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<64x64xi64>) -> tensor<64x64xi1>

  // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<64x64xi64>, tensor<i64>) -> tensor<64x64xi1>

  // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1>

  // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16>

  // CHECK-DAG: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]])
  // CHECK-DAG: return %[[R]]
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
  return %0 : tensor<64x64xbf16>
}

// CHECK-LABEL: matrix_band_part_2
// CHECK-SAME: (%[[INPUT:.*]]: tensor<12x24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<12x24x48xbf16> {
  // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xi64>
  // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xi64>
  // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xi64>

  // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<24x48xi64>) -> tensor<24x48xi1>

  // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<24x48xi64>, tensor<i64>) -> tensor<24x48xi1>
  // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1>

  // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16>

  // CHECK-DAG: %[[K:.*]] = "mhlo.broadcast_in_dim"(%[[J]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<24x48xi1>) -> tensor<12x24x48xi1>
  // CHECK-DAG: %[[R:.*]] = "mhlo.select"(%[[K]], %[[INPUT]], %[[ZERO2]])
  // CHECK-DAG: return %[[R]]
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<12x24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<12x24x48xbf16>
  return %0 : tensor<12x24x48xbf16>
}

// CHECK-LABEL: matrix_band_part_3
// CHECK-SAME: (%[[INPUT:.*]]: tensor<*xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_3(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
  // CHECK: "tf.MatrixBandPart"
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
  return %0 : tensor<*xbf16>
}

// CHECK-LABEL: matrix_band_part_4
// CHECK-SAME: (%[[INPUT:.*]]: tensor<24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_4(%arg0: tensor<24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<24x48xbf16> {
  // This one should lower.
  // CHECK-NOT: "tf.MatrixBandPart"
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<24x48xbf16>
  return %0 : tensor<24x48xbf16>
}

//===----------------------------------------------------------------------===//
// MaxPool op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: maxpool_valid_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK: {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>}

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
  return %0 : tensor<2x3x5x7xi32>
}

// CHECK-LABEL: maxpool_same_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32>
  return %0 : tensor<2x4x7x7xi32>
}

// CHECK-LABEL: maxpool_3d_valid_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK: {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>}

  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32>
  return %0 : tensor<2x8x3x5x7xf32>
}

// CHECK-LABEL: maxpool_3d_same_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>

  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32>
  return %0 : tensor<2x8x4x7x7xf32>
}

// CHECK-LABEL: maxpool_explicit_padding
func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
  // CHECK: tf.MaxPool
  // TODO(b/165938852): need to support explicit padding in max_pool.

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
  return %0 : tensor<2x3x5x7xi32>
}

//===----------------------------------------------------------------------===//
// MaxPoolGrad op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @max_pool_grad_valid
// CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32>
func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = "mhlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<i1>) -> ()
  // CHECK: }, {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
  // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32>
  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 2, 1],
     padding = "VALID",
     strides = [1, 2, 2, 1]
  } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32>
  return %result : tensor<10x24x24x64xf32>
}

// CHECK-LABEL: @max_pool_3d_grad_valid
// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32>
func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = "mhlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<i1>) -> ()
  // CHECK: }, {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>) -> tensor<10x8x24x24x64xf32>
  // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32>
  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32>
  return %result : tensor<10x8x24x24x64xf32>
}

// CHECK-LABEL: @max_pool_grad_same
func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 3, 1],
     padding = "SAME",
     strides = [1, 4, 4, 1]
  } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32>
  return %result : tensor<2x13x25x7xf32>
}

// CHECK-LABEL: @max_pool_3d_grad_same
func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32>
  return %result : tensor<2x8x13x25x7xf32>
}

//===----------------------------------------------------------------------===//
// OneHot op legalizations.
//===----------------------------------------------------------------------===//
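
// OneHot has no dedicated HLO op: an iota along the one-hot axis is compared
// (EQ) against the broadcast indices, and the resulting mask selects between
// the broadcast on_value and off_value.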

// CHECK-LABEL: one_hot
func @one_hot(%indices: tensor<3xi32>, %on_value: tensor<f32>, %off_value: tensor<f32>) -> tensor<3x5xf32> {
  // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32>
  // CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32>
  // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[BCAST_ARG0]], %[[IOTA]]) {comparison_direction = "EQ"} : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1>
  // CHECK: %[[ON_VALUE:.*]] = "mhlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32>
  // CHECK: %[[OFF_VALUE:.*]] = "mhlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32>
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32>
  // CHECK: return %[[RESULT]] : tensor<3x5xf32>
  %depth = "tf.Const"() { value = dense<5> : tensor<i32> } : () -> tensor<i32>
  %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<3x5xf32>
  return %result : tensor<3x5xf32>
}

//===----------------------------------------------------------------------===//
// tf.OutfeedEnqueueTuple legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @outfeed_enqueue_tuple
// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>)
func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () {
// CHECK: [[TUPLE:%.*]] = "mhlo.tuple"([[VAL_0]], [[VAL_1]]) : (tensor<3xi32>, tensor<4xf32>) -> tuple<tensor<3xi32>, tensor<4xf32>>
// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token
// CHECK: "mhlo.outfeed"([[TUPLE]], [[TOKEN]]) {outfeed_config = ""} : (tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token) -> !mhlo.token
  "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> ()
  return
}

//===----------------------------------------------------------------------===//
// Pack op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @pack
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
  // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32>
  // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32>
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>

  %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}

//===----------------------------------------------------------------------===//
// PartitionedCall op legalization.
//===----------------------------------------------------------------------===//
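
// PartitionedCall lowers to a plain standard call when the operand types
// line up with the callee's signature; the "unhandled" cases below keep the
// TF op because the equivalent standard CallOp would not type-check.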

// CHECK-LABEL: func @partitioned_call
func @partitioned_call(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK: call @pcall_func(%arg0) : (tensor<i32>) -> tensor<i32>
  %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_func} : (tensor<i32>) -> (tensor<i32>)
  return %0 : tensor<i32>
}

func @pcall_func(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0 : tensor<i32>
}

// CHECK-LABEL: func @partitioned_call_multi_input
func @partitioned_call_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  // CHECK: call @pcall_multi_input(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %0 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_input} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
  return %0 : tensor<i32>
}

func @pcall_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  return %arg0 : tensor<i32>
}

// CHECK-LABEL: func @partitioned_call_multi_in_out
func @partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: call @pcall_multi_in_out(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}

func @pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  return %arg1, %arg0 : tensor<i32>, tensor<i32>
}

// CHECK-LABEL: func @unhandled_partitioned_call
func @unhandled_partitioned_call(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) {
  // The argument types don't match the parameter types for the
  // pcall_multi_in_out function. That's fine for a PartitionedCallOp but not
  // for a standard CallOp, so this op can't be lowered.
  // CHECK: "tf.PartitionedCall"
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}

// CHECK-LABEL: func @unhandled_partitioned_call_2
func @unhandled_partitioned_call_2(%arg0: tensor<i32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: "tf.PartitionedCall"
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}


//===----------------------------------------------------------------------===//
// ReverseV2 op legalization.
//===----------------------------------------------------------------------===//
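
// ReverseV2 requires a constant axis operand, which becomes the i64
// dimensions attribute of mhlo.reverse; negative axes are normalized against
// the rank, so dense<[-1]> on a rank-2 input becomes dense<1>.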

// CHECK-LABEL: @reverse_func_32
func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> {
  %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32>

  // CHECK: return [[VAL]] : tensor<5xi32>
  return %reversed : tensor<5xi32>
}

// CHECK-LABEL: @reverse_func_64
func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> {
  %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32>

  // CHECK: return [[VAL]] : tensor<5xi32>
  return %reversed : tensor<5xi32>
}

// CHECK-LABEL: @reverse_func_neg
func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> {
  %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32>

  // CHECK: return [[VAL]] : tensor<5x5xi32>
  return %reversed : tensor<5x5xi32>
}

//===----------------------------------------------------------------------===//
// StatefulPartitionedCall op legalization.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @stateful_partitioned_call
// CHECK-SAME: [[ARG:%.+]]: tensor<i32>
func @stateful_partitioned_call(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK: call @stateful_pcall_func([[ARG]]) : (tensor<i32>) -> tensor<i32>
  %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor<i32>) -> (tensor<i32>)
  return %0 : tensor<i32>
}

func @stateful_pcall_func(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0 : tensor<i32>
}

// CHECK-LABEL: func @stateful_partitioned_call_multi_in_out
// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>, [[ARG1:%.+]]: tensor<i32>)
func @stateful_partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: call @stateful_pcall_multi_in_out([[ARG0]], [[ARG1]]) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  %0, %1 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}

func @stateful_pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  return %arg1, %arg0 : tensor<i32>, tensor<i32>
}

//===----------------------------------------------------------------------===//
// Elu op legalizations.
//===----------------------------------------------------------------------===//
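
// Elu lowers to select(x > 0, x, expm1(x)), e.g. elu(-1) = expm1(-1), which
// is roughly -0.632. EluGrad, per the checks below, lowers to
// select(features > 0, gradients, gradients * (features + 1)).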

// CHECK-LABEL: func @elu
func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %arg0, %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"}
  // CHECK-DAG: %[[EXP:.*]] = "mhlo.exponential_minus_one"(%arg0)
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]])
  // CHECK: return %[[RESULT]]
  %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
  return %0: tensor<1xf32>
}

// CHECK-LABEL: func @elu_grad
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>)
func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> {
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"}
  // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: %[[MULGRAD:.*]] = "mhlo.multiply"(%[[GRADIENTS]], %[[ADD1]])
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[MULGRAD]])
  // CHECK: return %[[RESULT]]
  %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
  return %2 : tensor<4x8xf32>
}

//===----------------------------------------------------------------------===//
// Relu op legalizations.
//===----------------------------------------------------------------------===//
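
// Relu becomes max(0, x) via chlo.broadcast_maximum against a scalar zero,
// Relu6 becomes mhlo.clamp(0, x, 6), and ReluGrad selects between the
// incoming gradients and zero based on features > 0.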
1630//===----------------------------------------------------------------------===// 1631 1632// CHECK-LABEL: func @relu 1633func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1634 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1635 // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32> 1636 %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> 1637 return %0: tensor<1xi32> 1638} 1639 1640// CHECK-LABEL: func @relu_unranked 1641func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { 1642 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1643 // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32> 1644 %0 = "tf.Relu"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> 1645 return %0: tensor<?xi32> 1646} 1647 1648// CHECK-LABEL: func @relu6 1649func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { 1650 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1651 // CHECK: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32> 1652 // CHECK: "mhlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor<i32>, tensor<1xi32>, tensor<i32>) -> tensor<1xi32> 1653 %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> 1654 return %0: tensor<1xi32> 1655} 1656 1657// CHECK-LABEL: func @relu6_unranked 1658func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { 1659 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 1660 // CHECK: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32> 1661 // CHECK: "mhlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor<i32>, tensor<?xi32>, tensor<i32>) -> tensor<?xi32> 1662 %0 = "tf.Relu6"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> 1663 return %0: tensor<?xi32> 1664} 1665 1666// CHECK-LABEL: func @relu_grad 1667// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>) 1668func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> { 1669 // CHECK-DAG: %[[ZERO_SCALAR:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 1670 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32> 1671 // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1> 1672 // CHECK-DAG: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> 1673 // CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32> 1674 %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32> 1675 return %2 : tensor<4x8xf32> 1676} 1677 1678//===----------------------------------------------------------------------===// 1679// Select op legalizations. 
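// SelectV2 legalizes to mhlo.select once the predicate and both branches
// have been broadcast (via mhlo.broadcast_in_dim) to a common static shape;
// dynamic and unranked operands are left as tf.SelectV2.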
1680//===----------------------------------------------------------------------===// 1681 1682// CHECK-LABEL: func @selectv2 1683func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { 1684 // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2) 1685 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> 1686 return %0: tensor<2xi32> 1687} 1688 1689// CHECK-LABEL: func @selectv2_pred_scalar 1690func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { 1691 // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2) 1692 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> 1693 return %0: tensor<2xi32> 1694} 1695 1696// CHECK-LABEL: func @selectv2_broadcast_then 1697func @selectv2_broadcast_then(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { 1698 // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32> 1699 // CHECK: "mhlo.select"(%arg0, %[[BROADCAST]], %arg2) 1700 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> 1701 return %0: tensor<2x8x8xi32> 1702} 1703 1704// CHECK-LABEL: func @selectv2_broadcast_else 1705func @selectv2_broadcast_else(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> { 1706 // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32> 1707 // CHECK: "mhlo.select"(%arg0, %arg1, %[[BROADCAST]]) 1708 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32> 1709 return %0: tensor<2x8x8xi32> 1710} 1711 1712// CHECK-LABEL: func @selectv2_broadcast_pred 1713func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { 1714 // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<2x8x8xi1> 1715 // CHECK: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2) 1716 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> 1717 return %0: tensor<2x8x8xi32> 1718} 1719 1720// CHECK-LABEL: func @selectv2_broadcast_tensor_pred 1721func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { 1722 // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1> 1723 // CHECK: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2) 1724 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> 1725 return %0: tensor<2x3xf16> 1726} 1727 1728// CHECK-LABEL: func @selectv2_broadcast_all 1729func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> { 1730 // CHECK-DAG: %[[BROADCAST_0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1> 1731 // CHECK-DAG: %[[BROADCAST_1:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32> 
1732 // CHECK-DAG: %[[BROADCAST_2:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32> 1733 // CHECK: "mhlo.select"(%[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]]) 1734 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32> 1735 return %0: tensor<8x8x8xi32> 1736} 1737 1738// CHECK-LABEL: func @selectv2_dynamic_ranked 1739func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> { 1740 // CHECK: tf.SelectV2 1741 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32> 1742 return %0: tensor<2x?x8xi32> 1743} 1744 1745// CHECK-LABEL: func @selectv2_unranked 1746func @selectv2_unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> { 1747 // CHECK: tf.SelectV2 1748 %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32> 1749 return %0: tensor<*xi32> 1750} 1751 1752//===----------------------------------------------------------------------===// 1753// Softmax op legalizations. 1754//===----------------------------------------------------------------------===// 1755 1756// CHECK-LABEL: func @simple_softmax 1757// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) 1758func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { 1759 1760 // Verify reduce op for max computation and its body. 1761 // CHECK-DAG: %[[NEG_INF:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32> 1762 // CHECK-DAG: %[[CASTED_INP:.*]] = "mhlo.convert"(%[[ARG0]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> 1763 // CHECK: %[[MAX:.*]] = "mhlo.reduce"(%[[CASTED_INP]], %[[NEG_INF]]) 1764 // CHECK: mhlo.maximum 1765 // CHECK: "mhlo.return" 1766 // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32> 1767 // CHECK: %[[CASTED_MAX:.*]] = "mhlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> 1768 1769 // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] 1770 // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] 1771 // CHECK: %[[BCAST_MAX:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} 1772 // CHECK: %[[SHIFTED_INP:.*]] = mhlo.subtract %[[ARG0]], %[[BCAST_MAX]] 1773 // CHECK: %[[EXP:.*]] = "mhlo.exponential"(%[[SHIFTED_INP]]) 1774 1775 // Verify reduce op for summation and its body. 
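  // (Overall: softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))); the max
  // subtraction above keeps the exponentials from overflowing.)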
1776 // CHECK-DAG: %[[CASTED_EXP:.*]] = "mhlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> 1777 // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 1778 // CHECK: %[[SUM:.*]] = "mhlo.reduce"(%[[CASTED_EXP]], %[[ZERO]]) 1779 // CHECK: mhlo.add 1780 // CHECK: "mhlo.return" 1781 // CHECK: {dimensions = dense<1> : tensor<1xi64>} 1782 // CHECK: %[[CASTED_SUM:.*]] = "mhlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> 1783 1784 // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] 1785 // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] 1786 // CHECK: %[[BCAST_SUM:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} 1787 // CHECK: %[[RESULT:.*]] = mhlo.divide %[[EXP]], %[[BCAST_SUM]] 1788 // CHECK: return %[[RESULT]] 1789 1790 %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> 1791 return %0: tensor<2x3xf32> 1792} 1793 1794// Verify intermediate and final shape are correct with dynamic shapes. 1795// CHECK-LABEL: func @dynamic_softmax 1796func @dynamic_softmax(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> { 1797 // CHECK: mhlo.divide {{.*}} : tensor<?x?xf32> 1798 %0 = "tf.Softmax"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32> 1799 return %0: tensor<?x?xf32> 1800} 1801 1802// CHECK-LABEL: bf16_softmax 1803func @bf16_softmax(%arg0: tensor<2x3xbf16>) -> tensor<2x3xbf16> { 1804 // Verify that conversion to f32 and then back to bf16 are introduced. 1805 1806 // CHECK: "mhlo.convert"({{.*}}) : (tensor<2x3xbf16>) -> tensor<2x3xf32> 1807 // CHECK: "mhlo.convert"({{.*}}) : (tensor<2xf32>) -> tensor<2xbf16> 1808 1809 %0 = "tf.Softmax"(%arg0) : (tensor<2x3xbf16>) -> tensor<2x3xbf16> 1810 return %0: tensor<2x3xbf16> 1811} 1812 1813// CHECK-LABEL: rank4_softmax 1814func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> { 1815 // Verify that reduce op dimensions and broadcast dimensions are correct. 1816 1817 // CHECK: "mhlo.reduce" 1818 // CHECK: dimensions = dense<3> 1819 1820 // CHECK: "mhlo.reduce" 1821 // CHECK: dimensions = dense<3> 1822 1823 // CHECK: {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} 1824 // CHECK: mhlo.divide {{.*}} 1825 %0 = "tf.Softmax"(%arg0) : (tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> 1826 return %0: tensor<2x3x4x5xf16> 1827} 1828 1829//===----------------------------------------------------------------------===// 1830// LogSoftmax op legalizations. 
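// log(softmax(x)) reduces to (x - max(x)) - log(sum(exp(x - max(x)))).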
1831// This just changes the tail of the regular Softmax legalization 1832//===----------------------------------------------------------------------===// 1833 1834// CHECK-LABEL: func @simple_logsoftmax 1835// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) 1836func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { 1837 // CHECK: %{{.*}} = "mhlo.reduce"({{.*}}) 1838 // CHECK: %[[SUM:.*]] = "mhlo.reduce"({{.*}}) 1839 // CHECK: %[[CASTED_SUM:.*]] = "mhlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> 1840 // CHECK: %[[LOG:.*]] = "mhlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32> 1841 // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] 1842 // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] 1843 // CHECK: %[[BCAST_SUM:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} 1844 // CHECK: %[[RESULT:.*]] = mhlo.subtract {{.*}}, %[[BCAST_SUM]] 1845 // CHECK: return %[[RESULT]] 1846 1847 %0 = "tf.LogSoftmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> 1848 return %0: tensor<2x3xf32> 1849} 1850 1851//===----------------------------------------------------------------------===// 1852// Fast Fourier Transform op legalization. 1853//===----------------------------------------------------------------------===// 1854 1855// CHECK-LABEL: func @fft_1D 1856func @fft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> { 1857 // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "FFT"} : (tensor<8xcomplex<f32>> 1858 %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> 1859 return %0 : tensor<8xcomplex<f32>> 1860} 1861 1862// CHECK-LABEL: func @ifft_1D 1863func @ifft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> { 1864 // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "IFFT"} : (tensor<8xcomplex<f32>> 1865 %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> 1866 return %0 : tensor<8xcomplex<f32>> 1867} 1868 1869// CHECK-LABEL: func @rfft_1D 1870func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> { 1871 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1872 // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32> 1873 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>> 1874 return %0 : tensor<8xcomplex<f32>> 1875} 1876 1877// CHECK-LABEL: func @rfft_1D_padded 1878func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<8xcomplex<f32>> { 1879 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1880 // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %2) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<7xf32>, tensor<f32>) -> tensor<8xf32> 1881 // CHECK: "mhlo.fft"(%[[PADDED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32> 1882 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>> 1883 return %0 : tensor<8xcomplex<f32>> 1884} 1885 1886// CHECK-LABEL: func @rfft_1D_sliced 1887func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x8xcomplex<f32>> { 1888 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1889 // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 8]> : tensor<2xi64>, 
start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x9xf32>) -> tensor<2x8xf32> 1890 // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<2x8xf32> 1891 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x8xcomplex<f32>> 1892 return %0 : tensor<2x8xcomplex<f32>> 1893} 1894 1895// CHECK-LABEL: func @irfft_1D 1896func @irfft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<5xf32> { 1897 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1898 // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<5> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xcomplex<f32>>) -> tensor<5xcomplex<f32>> 1899 // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<5> : tensor<1xi64>, fft_type = "IRFFT"} : (tensor<5xcomplex<f32>> 1900 %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex<f32>>, tensor<1xi32>) -> tensor<5xf32> 1901 return %0 : tensor<5xf32> 1902} 1903 1904// CHECK-LABEL: fft_1D_dynamic 1905func @fft_1D_dynamic(%arg0: tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> { 1906 // CHECK: "tf.FFT" 1907 %0 = "tf.FFT"(%arg0) : (tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> 1908 return %0 : tensor<8xcomplex<f32>> 1909} 1910 1911// CHECK-LABEL: rfft_1D_dynamic 1912func @rfft_1D_dynamic(%arg0: tensor<?xf32>) -> tensor<8xcomplex<f32>> { 1913 %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) 1914 // CHECK: "tf.RFFT" 1915 %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<?xf32>, tensor<1xi32>) -> tensor<8xcomplex<f32>> 1916 return %0 : tensor<8xcomplex<f32>> 1917} 1918 1919//===----------------------------------------------------------------------===// 1920// Shape op legalization. 1921//===----------------------------------------------------------------------===// 1922 1923// CHECK-LABEL: func @shape_1D 1924func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> { 1925 // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 1926 // CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]] 1927 // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]] 1928 %0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32> 1929 1930 // CHECK: return [[CAST]] 1931 return %0 : tensor<1xi32> 1932} 1933 1934// CHECK-LABEL: func @shape_2D 1935func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> { 1936 // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 1937 // CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]] 1938 // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]] 1939 %0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32> 1940 1941 // CHECK: return [[CAST]] 1942 return %0 : tensor<2xi32> 1943} 1944 1945// CHECK-LABEL: func @shape_rankless 1946func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> { 1947 // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 1948 // CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]] 1949 // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]] 1950 %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32> 1951 1952 // CHECK: return [[CAST]] 1953 return %0 : tensor<?xi32> 1954} 1955 1956//===----------------------------------------------------------------------===// 1957// Transpose op legalization. 
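// tf.Transpose with a constant permutation lowers to mhlo.transpose; the
// identity permutation simply folds away. For example (a sketch, not a
// checked test), permutation [1, 0] on a tensor<2x3xf32> becomes
//   "mhlo.transpose"(%x) {permutation = dense<[1, 0]> : tensor<2xi64>}
//       : (tensor<2x3xf32>) -> tensor<3x2xf32>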
1958//===----------------------------------------------------------------------===// 1959 1960// CHECK-LABEL: @transpose_noop 1961func @transpose_noop(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { 1962 %permutation = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> (tensor<2xi64>) 1963 // CHECK: return %arg0 1964 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<2x3xf32> 1965 return %0 : tensor<2x3xf32> 1966} 1967 1968// CHECK-LABEL: @transpose_2d 1969func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { 1970 %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 1971 // CHECK: "mhlo.transpose" 1972 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> 1973 return %0 : tensor<3x2xf32> 1974} 1975 1976// CHECK-LABEL: @transpose_3d_int32 1977func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { 1978 %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>) 1979 // CHECK: "mhlo.transpose" 1980 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32> 1981 return %0 : tensor<3x2x1xf32> 1982} 1983 1984// CHECK-LABEL: @transpose_3d 1985func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { 1986 %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>) 1987 // CHECK: "mhlo.transpose" 1988 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> 1989 return %0 : tensor<3x2x1xf32> 1990} 1991 1992// CHECK-LABEL: @transpose_dynamic_2d 1993func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> { 1994 %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 1995 // CHECK: "mhlo.transpose" 1996 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<?x4xf32>, tensor<2xi64>) -> tensor<4x?xf32> 1997 return %0 : tensor<4x?xf32> 1998} 1999 2000// CHECK-LABEL: @transpose_unranked_2d 2001func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2002 %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2003 // CHECK: "mhlo.transpose" 2004 %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32> 2005 return %0 : tensor<*xf32> 2006} 2007 2008 2009//===----------------------------------------------------------------------===// 2010// Unary op legalizations. 
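// Simple unary ops map 1:1 onto an mhlo op for static, dynamic, and unranked
// shapes alike. Ops with no direct mhlo equivalent (e.g. tf.Acos, tf.Tan) go
// through chlo and are only expanded further when chlo legalization is
// enabled.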
2011//===----------------------------------------------------------------------===// 2012 2013// CHECK-LABEL: @abs 2014func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2015 // CHECK: "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2016 %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2017 return %0 : tensor<2xf32> 2018} 2019 2020// CHECK-LABEL: func @abs_dynamic 2021func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2022 // CHECK: "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2023 %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2024 return %0 : tensor<?xf32> 2025} 2026 2027// CHECK-LABEL: func @abs_unranked 2028func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2029 // CHECK: "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2030 %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2031 return %0 : tensor<*xf32> 2032} 2033 2034// CHECK-LABEL: @acos 2035// CHLO-LABEL: @acos 2036func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2037 // CHECK: chlo.acos %arg0 : tensor<2xf32> 2038// CHLO: %[[VAL_1:.*]] = "mhlo.compare"({{.*}}) {comparison_direction = "NE"} 2039// CHLO: %[[VAL_5:.*]] = mhlo.multiply %arg0, %arg0 2040// CHLO: %[[VAL_4:.*]] = mhlo.constant dense<1.000000e+00> 2041// CHLO: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_4]], %[[VAL_5]] 2042// CHLO: %[[VAL_7:.*]] = "mhlo.sqrt"(%[[VAL_6]]) 2043// CHLO: %[[VAL_8:.*]] = mhlo.constant dense<1.000000e+00> 2044// CHLO: %[[VAL_9:.*]] = mhlo.add %[[VAL_8]], %arg0 2045// CHLO: %[[VAL_10:.*]] = mhlo.atan2 %[[VAL_7]], %[[VAL_9]] 2046// CHLO: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> 2047// CHLO: %[[VAL_11:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_10]] 2048// CHLO: %[[VAL_12:.*]] = mhlo.constant dense<3.14159274> 2049// CHLO: %[[VAL_13:.*]] = "mhlo.select"(%[[VAL_1]], %[[VAL_11]], %[[VAL_12]]) 2050// CHLO: return %[[VAL_13]] : tensor<2xf32> 2051 %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2052 return %0 : tensor<2xf32> 2053} 2054 2055// CHECK-LABEL: @acos_complex 2056// CHLO-LABEL: @acos_complex 2057func @acos_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> { 2058 // CHLO: tf.Acos 2059 %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> 2060 return %0 : tensor<2xcomplex<f32>> 2061} 2062 2063// CHECK-LABEL: @acos_dynamic 2064// CHLO-LABEL: @acos_dynamic 2065func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2066 // CHECK: chlo.acos %arg0 : tensor<*xf32> 2067 // `tf.Acos` is lowered to `chlo.constant_like` operations which can only be 2068 // lowered further on ranked tensors. Unranked CHLO must be transformed to 2069 // ranked code before further lowering. 
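  // With chlo legalization enabled, the op is therefore expected to remain
  // a tf.Acos here: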
  // CHLO: "tf.Acos"
  %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @tan
// CHECK-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32>
// CHLO-LABEL: @tan
// CHLO-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32>
func @tan(%arg : tensor<2xf32>) -> tensor<2xf32> {
  // CHECK: chlo.tan %[[ARG]] : tensor<2xf32>
  // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]])
  // CHLO: %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]])
  // CHLO: %[[RESULT:.*]] = mhlo.divide %[[SINE]], %[[COSINE]]
  %result = "tf.Tan"(%arg) : (tensor<2xf32>) -> tensor<2xf32>
  return %result : tensor<2xf32>
}

// CHECK-LABEL: @tan_unranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
// CHLO-LABEL: @tan_unranked
// CHLO-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: chlo.tan %[[ARG]] : tensor<*xf32>
  // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]])
  // CHLO: %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]])
  // CHLO: %[[RESULT:.*]] = mhlo.divide %[[SINE]], %[[COSINE]]
  %result = "tf.Tan"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
  return %result : tensor<*xf32>
}

// CHECK-LABEL: @sinh_complex
// CHLO-LABEL: @sinh_complex
func @sinh_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  // CHLO: tf.Sinh
  %0 = "tf.Sinh"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  return %0 : tensor<2xcomplex<f32>>
}

// CHECK-LABEL: func @cast_dynamic_i2f
func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> {
  // CHECK: "mhlo.convert"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
  %0 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @cast_i2f
func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> {
  // CHECK: "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
  %0 = "tf.Cast"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @cast_c2f
func @cast_c2f(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
  // CHECK: tf.Cast
  %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @ceil
func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK: "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @ceil_dynamic
func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK: "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @ceil_unranked
func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @complex_abs
func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
  // CHECK: "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @cos
func @cos(%arg0:
tensor<2xf32>) -> tensor<2xf32> { 2160 // CHECK: "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2161 %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2162 return %0 : tensor<2xf32> 2163} 2164 2165// CHECK-LABEL: func @cos_dynamic 2166func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2167 // CHECK: "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2168 %0 = "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2169 return %0 : tensor<?xf32> 2170} 2171 2172// CHECK-LABEL: func @cos_unranked 2173func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2174 // CHECK: "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2175 %0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2176 return %0 : tensor<*xf32> 2177} 2178 2179// CHECK-LABEL: @exp 2180func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2181 // CHECK: "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2182 %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2183 return %0 : tensor<2xf32> 2184} 2185 2186// CHECK-LABEL: @expm1 2187func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2188 // CHECK: "mhlo.exponential_minus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2189 %0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2190 return %0 : tensor<2xf32> 2191} 2192 2193// CHECK-LABEL: func @exp_dynamic 2194func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2195 // CHECK: "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2196 %0 = "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2197 return %0 : tensor<?xf32> 2198} 2199 2200// CHECK-LABEL: func @exp_unranked 2201func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2202 // CHECK: "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2203 %0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2204 return %0 : tensor<*xf32> 2205} 2206 2207// CHECK-LABEL: @floor 2208func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2209 // CHECK: "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2210 %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2211 return %0 : tensor<2xf32> 2212} 2213 2214// CHECK-LABEL: func @floor_dynamic 2215func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2216 // CHECK: "mhlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2217 %0 = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2218 return %0 : tensor<?xf32> 2219} 2220 2221// CHECK-LABEL: func @floor_unranked 2222func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2223 // CHECK: "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2224 %0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2225 return %0 : tensor<*xf32> 2226} 2227 2228// CHECK-LABEL: func @invert_op_unranked 2229func @invert_op_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { 2230 // CHECK: "mhlo.not"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> 2231 %0 = "tf.Invert"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> 2232 return %0 : tensor<*xi32> 2233} 2234 2235// CHECK-LABEL: @is_finite 2236func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { 2237 // CHECK: "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> 2238 %0 = "tf.IsFinite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> 2239 return %0 : tensor<2xi1> 2240} 2241 2242// CHECK-LABEL: func @is_finite_dynamic 2243func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> { 2244 // CHECK: "mhlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1> 2245 %0 = "tf.IsFinite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1> 2246 return %0 : tensor<?xi1> 2247} 2248 2249// CHECK-LABEL: func 
@is_finite_unranked 2250func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> { 2251 // CHECK: "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> 2252 %0 = "tf.IsFinite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> 2253 return %0 : tensor<*xi1> 2254} 2255 2256// CHECK-LABEL: @log 2257func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2258 // CHECK: "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2259 %0 = "tf.Log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2260 return %0 : tensor<2xf32> 2261} 2262 2263// CHECK-LABEL: func @log_dynamic 2264func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2265 // CHECK: "mhlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2266 %0 = "tf.Log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2267 return %0 : tensor<?xf32> 2268} 2269 2270// CHECK-LABEL: func @log_unranked 2271func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2272 // CHECK: "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2273 %0 = "tf.Log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2274 return %0 : tensor<*xf32> 2275} 2276 2277// CHECK-LABEL: @log1p 2278func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2279 // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2280 %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2281 return %0 : tensor<2xf32> 2282} 2283 2284// CHECK-LABEL: func @log1p_dynamic 2285func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2286 // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2287 %0 = "tf.Log1p"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2288 return %0 : tensor<?xf32> 2289} 2290 2291// CHECK-LABEL: func @log1p_unranked 2292func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2293 // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2294 %0 = "tf.Log1p"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2295 return %0 : tensor<*xf32> 2296} 2297 2298// CHECK-LABEL: func @not_op_unranked 2299func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> { 2300 // CHECK: "mhlo.not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> 2301 %0 = "tf.LogicalNot"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> 2302 return %0 : tensor<*xi1> 2303} 2304 2305// CHECK-LABEL: @neg 2306func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2307 // CHECK: "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2308 %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2309 return %0 : tensor<2xf32> 2310} 2311 2312// CHECK-LABEL: func @neg_dynamic 2313func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2314 // CHECK: "mhlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2315 %0 = "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2316 return %0 : tensor<?xf32> 2317} 2318 2319// CHECK-LABEL: func @neg_unranked 2320func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2321 // CHECK: "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2322 %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2323 return %0 : tensor<*xf32> 2324} 2325 2326// CHECK-LABEL: @sigmoid 2327func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2328 // CHECK-DAG: [[SCALAR:%.+]] = mhlo.constant dense<5.000000e-01> : tensor<f32> 2329 // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<2xf32> 2330 // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] 2331 // CHECK-DAG: [[HALF:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<2xf32> 2332 // CHECK-DAG: [[R1:%.+]] = mhlo.multiply %arg0, 
[[HALF]] : tensor<2xf32> 2333 // CHECK-DAG: [[R2:%.+]] = "mhlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32> 2334 // CHECK-DAG: [[R3:%.+]] = mhlo.multiply [[R2]], [[HALF]] : tensor<2xf32> 2335 // CHECK-DAG: [[R4:%.+]] = mhlo.add [[R3]], [[HALF]] : tensor<2xf32> 2336 %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2337 return %0 : tensor<2xf32> 2338} 2339 2340// CHECK-LABEL: @sigmoid_complex 2341func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> { 2342 // CHECK: [[R0:%.+]] = mhlo.constant dense<(5.000000e-01,0.000000e+00)> : tensor<complex<f32>> 2343 // CHECK-NOT: tf.Sigmoid 2344 %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> 2345 return %0 : tensor<2xcomplex<f32>> 2346} 2347 2348// CHECK-LABEL: @sigmoid_unranked 2349func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2350 // CHECK-DAG: [[SCALAR:%.+]] = mhlo.constant dense<5.000000e-01> : tensor<f32> 2351 // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<*xf32> 2352 // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] 2353 // CHECK-DAG: [[HALF:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<?xindex>) -> tensor<*xf32> 2354 // CHECK-DAG: [[R1:%.+]] = mhlo.multiply %arg0, [[HALF]] : tensor<*xf32> 2355 // CHECK-DAG: [[R2:%.+]] = "mhlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32> 2356 // CHECK-DAG: [[R3:%.+]] = mhlo.multiply [[R2]], [[HALF]] : tensor<*xf32> 2357 // CHECK-DAG: [[R4:%.+]] = mhlo.add [[R3]], [[HALF]] : tensor<*xf32> 2358 %0 = "tf.Sigmoid"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2359 return %0 : tensor<*xf32> 2360} 2361 2362 2363// CHECK-LABEL: @sigmoid_grad 2364func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { 2365 // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xf32> 2366 // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32> 2367 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xf32> 2368 // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32> 2369 // CHECK: return [[MUL1]] 2370 %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> 2371 return %0 : tensor<2xf32> 2372} 2373 2374// CHECK-LABEL: @sigmoid_grad_complex 2375func @sigmoid_grad_complex(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> { 2376 // CHECK-DAG: [[MUL0:%.+]] = mhlo.multiply %arg1, %arg0 : tensor<2xcomplex<f32>> 2377 // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>> 2378 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract [[ONE]], %arg0 : tensor<2xcomplex<f32>> 2379 // CHECK-DAG: [[MUL1:%.+]] = mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex<f32>> 2380 // CHECK: return [[MUL1]] 2381 %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> 2382 return %0 : tensor<2xcomplex<f32>> 2383} 2384 2385// CHECK-LABEL: @sin 2386func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2387 // CHECK: "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2388 %0 = "tf.Sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2389 return %0 : tensor<2xf32> 2390} 2391 2392// CHECK-LABEL: func @sin_dynamic 2393func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2394 // CHECK: "mhlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2395 %0 = "tf.Sin"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2396 return %0 : 
tensor<?xf32> 2397} 2398 2399// CHECK-LABEL: func @sin_unranked 2400func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2401 // CHECK: "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2402 %0 = "tf.Sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2403 return %0 : tensor<*xf32> 2404} 2405 2406// CHECK-LABEL: func @rsqrt 2407func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2408 // CHECK: "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2409 %0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2410 return %0 : tensor<2xf32> 2411} 2412 2413// CHECK-LABEL: func @rsqrt_dynamic 2414func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2415 // CHECK: "mhlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2416 %0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2417 return %0 : tensor<?xf32> 2418} 2419 2420// CHECK-LABEL: func @rsqrt_unranked 2421func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2422 // CHECK: "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2423 %0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2424 return %0 : tensor<*xf32> 2425} 2426 2427// CHECK-LABEL: func @sqrt 2428func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2429 // CHECK: "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2430 %0 = "tf.Sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2431 return %0 : tensor<2xf32> 2432} 2433 2434// CHECK-LABEL: func @sqrt_dynamic 2435func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2436 // CHECK: "mhlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2437 %0 = "tf.Sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2438 return %0 : tensor<?xf32> 2439} 2440 2441// CHECK-LABEL: func @sqrt_unranked 2442func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2443 // CHECK: "mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2444 %0 = "tf.Sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2445 return %0 : tensor<*xf32> 2446} 2447 2448// CHECK-LABEL: func @tanh 2449func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2450 // CHECK: "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2451 %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2452 return %0 : tensor<2xf32> 2453} 2454 2455// CHECK-LABEL: func @tanh_dynamic 2456func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2457 // CHECK: "mhlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2458 %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2459 return %0 : tensor<?xf32> 2460} 2461 2462// CHECK-LABEL: func @tanh_unranked 2463func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2464 // CHECK: "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2465 %0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2466 return %0 : tensor<*xf32> 2467} 2468 2469// CHECK-LABEL: func @bitcast 2470func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { 2471 // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2472 %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> 2473 return %0 : tensor<2xf32> 2474} 2475 2476// CHECK-LABEL: func @bitcast_dynamic 2477func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { 2478 // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2479 %0 = "tf.Bitcast"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> 2480 return %0 : tensor<?xf32> 2481} 2482 2483// CHECK-LABEL: func @bitcast_unranked 2484func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { 2485 // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> 2486 %0 = "tf.Bitcast"(%arg0) : 
(tensor<*xf32>) -> tensor<*xf32> 2487 return %0 : tensor<*xf32> 2488} 2489 2490// CHECK-LABEL: func @bitcast_same_widths 2491func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { 2492 // CHECK: "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> 2493 %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> 2494 return %0 : tensor<2xi32> 2495} 2496 2497// CHECK-LABEL: func @bitcast_smaller_input_width 2498func @bitcast_smaller_input_width(%arg0: tensor<2xi8>) -> tensor<2xi64> { 2499 // CHECK: "tf.Bitcast"(%arg0) : (tensor<2xi8>) -> tensor<2xi64> 2500 %0 = "tf.Bitcast"(%arg0) : (tensor<2xi8>) -> tensor<2xi64> 2501 return %0 : tensor<2xi64> 2502} 2503 2504// CHECK-LABEL: func @bitcast_smaller_output_width 2505func @bitcast_smaller_output_width(%arg0: tensor<2xf32>) -> tensor<2xf16> { 2506 // CHECK: "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf16> 2507 %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf16> 2508 return %0 : tensor<2xf16> 2509} 2510 2511// CHECK-LABEL: reshape 2512func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> { 2513 // CHECK: "mhlo.reshape" 2514 %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<2x1xf32> 2515 return %0 : tensor<2x1xf32> 2516} 2517 2518// CHECK-LABEL: reshape_dynamic 2519func @reshape_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> { 2520 // CHECK: "mhlo.dynamic_reshape" 2521 %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32> 2522 return %0 : tensor<?x?xf32> 2523} 2524 2525// CHECK-LABEL: reshape_unranked 2526func @reshape_unranked(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> { 2527 // CHECK: "tf.Reshape" 2528 %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32> 2529 return %0 : tensor<?x?xf32> 2530} 2531 2532// CHECK-LABEL: squeeze 2533func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> { 2534 // CHECK: "mhlo.reshape" 2535 %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> 2536 return %0 : tensor<1x10xf32> 2537} 2538 2539// CHECK-LABEL: squeeze_dynamic 2540func @squeeze_dynamic(%arg0: tensor<?x10xf32>) -> tensor<*xf32> { 2541 // CHECK: "tf.Squeeze" 2542 %0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32> 2543 return %0 : tensor<*xf32> 2544} 2545 2546// CHECK-LABEL: expand_dims 2547func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor<i32>) -> tensor<1x2xf32> { 2548 // CHECK: "mhlo.reshape" 2549 %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor<i32>) -> tensor<1x2xf32> 2550 return %0 : tensor<1x2xf32> 2551} 2552 2553// CHECK-LABEL: expand_dims_dynamic 2554func @expand_dims_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?x1x?xf32> { 2555 %axis = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> (tensor<i32>) 2556 2557 // CHECK-DAG: [[SHAPEOF:%.+]] = shape.shape_of %arg0 2558 // CHECK-DAG: [[CST0:%.+]] = constant 0 2559 // CHECK-DAG: [[CST1:%.+]] = constant 1 2560 // CHECK-DAG: [[GETEXTENT0:%.+]] = shape.get_extent [[SHAPEOF]], [[CST0]] 2561 // CHECK-DAG: [[CST1_0:%.+]] = constant 1 2562 // CHECK-DAG: [[GETEXTENT1:%.+]] = shape.get_extent [[SHAPEOF]], [[CST1_0]] 2563 // CHECK-DAG: [[FROMEXTENTS:%.+]] = shape.from_extents [[GETEXTENT0]], [[CST1]], [[GETEXTENT1]] 2564 // CHECK-DAG: [[TOEXTENTS:%.+]] = shape.to_extent_tensor [[FROMEXTENTS]] 2565 // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.dynamic_reshape"(%arg0, [[TOEXTENTS]]) 2566 %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?xf32>, tensor<i32>) -> tensor<?x1x?xf32> 2567 2568 // 
CHECK: return [[RESHAPE]] 2569 return %0 : tensor<?x1x?xf32> 2570} 2571 2572// CHECK-LABEL: expand_dynamic_dims_rank1_axis 2573func @expand_dynamic_dims_rank1_axis(%arg0: tensor<?x?x4xf32>) -> tensor<?x1x?x4xf32> { 2574 %axis = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 2575 2576 // CHECK-DAG: [[SHAPEOF:%.+]] = shape.shape_of %arg0 2577 // CHECK-DAG: [[CST0:%.+]] = constant 0 2578 // CHECK-DAG: [[CST1:%.+]] = constant 1 2579 // CHECK-DAG: [[GETEXTENT0:%.+]] = shape.get_extent [[SHAPEOF]], [[CST0]] 2580 // CHECK-DAG: [[CST1_0:%.+]] = constant 1 2581 // CHECK-DAG: [[GETEXTENT1:%.+]] = shape.get_extent [[SHAPEOF]], [[CST1_0]] 2582 // CHECK-DAG: [[CST2:%.+]] = constant 2 2583 // CHECK-DAG: [[GETEXTENT2:%.+]] = shape.get_extent [[SHAPEOF]], [[CST2]] 2584 // CHECK-DAG: [[FROMEXTENTS:%.+]] = shape.from_extents [[GETEXTENT0]], [[CST1]], [[GETEXTENT1]], [[GETEXTENT2]] 2585 // CHECK-DAG: [[TOEXTENTS:%.+]] = shape.to_extent_tensor [[FROMEXTENTS]] 2586 // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.dynamic_reshape"(%arg0, [[TOEXTENTS]]) 2587 %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?x4xf32>, tensor<1xi32>) -> tensor<?x1x?x4xf32> 2588 2589 // CHECK: return [[RESHAPE]] 2590 return %0 : tensor<?x1x?x4xf32> 2591} 2592 2593// CHECK-LABEL: func @sign 2594// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32> 2595func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { 2596 // CHECK: [[SIGN:%.*]] = "mhlo.sign"([[ARG]]) 2597 // CHECK: return [[SIGN]] : tensor<1x2x3x4xf32> 2598 %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>) 2599 return %0 : tensor<1x2x3x4xf32> 2600} 2601 2602// CHECK-LABEL: slice_constant_start 2603func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { 2604 // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi64> 2605 // CHECK: %[[CAST:.*]] = tensor.cast %[[START]] : tensor<1xi64> to tensor<1xi64> 2606 // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[CAST]]) : (tensor<1xi64>) -> tensor<1xi64> 2607 // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]]) 2608 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 2609 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 2610 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : 2611 // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> 2612 // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START:.*]]) : 2613 // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64> 2614 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) 2615 // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : 2616 // CHECK-DAG-SAME: (tensor<4xi32>, tensor<i64>) -> tensor<2xi32> 2617 // CHECK: return %[[RESULT]] : tensor<2xi32> 2618 %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) 2619 %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) 2620 %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32> 2621 return %0 : tensor<2xi32> 2622} 2623 2624// CHECK-LABEL: slice_i32_consts 2625func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { 2626 // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi32> 2627 // CHECK: %[[START_CAST:.*]] = tensor.cast %[[START]] : tensor<1xi32> to tensor<1xi32> 2628 // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START_CAST]]) : (tensor<1xi32>) -> tensor<1xi64> 2629 // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]]) 2630 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 2631 // 
CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 2632 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64> 2633 // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor<i64> 2634 // CHECK: "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32> 2635 %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2636 %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2637 %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> 2638 return %0 : tensor<2xi32> 2639} 2640 2641// CHECK-LABEL: slice_constant_start_negative_one_size 2642func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { 2643 // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi64> 2644 // CHECK: %[[START_CAST:.*]] = tensor.cast %[[START]] : tensor<1xi64> to tensor<1xi64> 2645 // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START_CAST]]) : (tensor<1xi64>) -> tensor<1xi64> 2646 // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]]) 2647 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 2648 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 2649 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64> 2650 // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor<i64> 2651 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<3xi32> 2652 // CHECK: return %[[RESULT]] : tensor<3xi32> 2653 %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) 2654 %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>) 2655 %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32> 2656 return %0 : tensor<3xi32> 2657} 2658 2659// CHECK-LABEL: slice_constant_start_dynamic_shape 2660func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { 2661 // CHECK: %[[START:.*]] = mhlo.constant dense<[1, 0]> : tensor<2xi64> 2662 // CHECK: %[[START_CAST:.*]] = tensor.cast %[[START]] : tensor<2xi64> to tensor<2xi64> 2663 // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START_CAST]]) : (tensor<2xi64>) -> tensor<2xi64> 2664 // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%[[START_I64]]) 2665 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 2666 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 2667 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : 2668 // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64> 2669 // CHECK: %[[RESHAPED_START1:.*]] = "mhlo.reshape"(%[[SLICED_START1]]) : 2670 // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64> 2671 // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%[[START_I64]]) 2672 // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, 2673 // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, 2674 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : 2675 // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64> 2676 // CHECK: %[[RESHAPED_START2:.*]] = "mhlo.reshape"(%[[SLICED_START2]]) : 2677 // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64> 2678 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice" 2679 // CHECK-DAG-SAME: (%arg0, 
%[[RESHAPED_START1]], %[[RESHAPED_START2]]) 2680 // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : 2681 // CHECK-DAG-SAME: (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32> 2682 // CHECK: return %[[RESULT]] : tensor<1x4xi32> 2683 %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2684 %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2685 %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<?x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> 2686 return %0 : tensor<1x4xi32> 2687} 2688 2689// CHECK-LABEL: slice_variable_start 2690func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { 2691 // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64> 2692 // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%[[START_I64]]) 2693 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 2694 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 2695 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 2696 // CHECK: %[[RESHAPED_START1:.*]] = "mhlo.reshape"(%[[SLICED_START1]]) : (tensor<1xi64>) -> tensor<i64> 2697 // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%[[START_I64]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> 2698 // CHECK: %[[RESHAPED_START2:.*]] = "mhlo.reshape"(%[[SLICED_START2]]) : (tensor<1xi64>) -> tensor<i64> 2699 // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32> 2700 // CHECK: return %[[RESULT]] : tensor<1x4xi32> 2701 %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2702 %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> 2703 return %0 : tensor<1x4xi32> 2704} 2705 2706// CHECK-LABEL: slice_mhlo_sizes 2707func @slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> { 2708 // CHECK-NOT: "tf.Slice" 2709 %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32> 2710 %1 = "tf.Slice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32> 2711 return %1 : tensor<1x512x4xf32> 2712} 2713 2714// CHECK-LABEL: slice_variable_start_negative_one_size 2715func @slice_variable_start_negative_one_size(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { 2716 // CHECK: %[[RESULT:.*]] = "tf.Slice" 2717 // CHECK: return %[[RESULT]] : tensor<1x4xi32> 2718 %sizes = "tf.Const"() {value = dense<[1, -1]> : tensor<2xi64>} : () -> (tensor<2xi64>) 2719 %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> 2720 return %0 : tensor<1x4xi32> 2721} 2722 2723//===----------------------------------------------------------------------===// 2724// StridedSlice op legalizations. 
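// With constant begin/end/strides the op lowers to mhlo.slice (preceded by
// mhlo.reverse for negative strides); each result extent is
// ceil((end - begin) / stride). E.g. in the first test below, begin [0, 1],
// end [3, 7], strides [1, 3] on a tensor<4x8xf32> give ceil(3/1) = 3 and
// ceil(6/3) = 2, hence tensor<3x2xf32>. Dynamic shapes are left as
// tf.StridedSlice.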
2725//===----------------------------------------------------------------------===// 2726 2727// CHECK-LABEL: simple_strided_slice 2728func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { 2729 %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2730 %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2731 %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2732 2733 // CHECK: mhlo.slice 2734 // CHECK-DAG-SAME: start_indices = dense<[0, 1]> 2735 // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> 2736 // CHECK-DAG-SAME: strides = dense<[1, 3]> 2737 // CHECK-SAME: -> tensor<3x2xf32> 2738 2739 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 2740 : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> 2741 return %output : tensor<3x2xf32> 2742} 2743 2744// CHECK-LABEL: dynamic_strided_slice 2745func @dynamic_strided_slice(%input: tensor<?x8xf32>) -> tensor<?x2xf32> { 2746 %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2747 %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2748 %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2749 2750 // CHECK: "tf.StridedSlice" 2751 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 2752 : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32> 2753 return %output : tensor<?x2xf32> 2754} 2755 2756// CHECK-LABEL: strided_slice_negative_indices 2757func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { 2758 %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2759 %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2760 %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2761 2762 // CHECK: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} 2763 2764 // CHECK: mhlo.slice 2765 // CHECK-DAG-SAME: start_indices = dense<[0, 1]> 2766 // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> 2767 // CHECK-DAG-SAME: strides = dense<[1, 3]> 2768 // CHECK-SAME: -> tensor<3x2xf32> 2769 2770 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 2771 : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> 2772 return %output : tensor<3x2xf32> 2773} 2774 2775// CHECK-LABEL: dynamic_strided_slice_negative_indices 2776func @dynamic_strided_slice_negative_indices(%input: tensor<?x8xf32>) -> tensor<?x2xf32> { 2777 %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2778 %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2779 %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2780 2781 // CHECK: tf.StridedSlice 2782 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 2783 : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32> 2784 return %output : tensor<?x2xf32> 2785} 2786 2787// CHECK-LABEL: strided_slice_range_clamping 2788func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<1x3xf32> { 2789 %begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2790 %end = "tf.Const"() {value = dense<[1, 10]> : tensor<2xi32>} : () -> (tensor<2xi32>) 2791 %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> 
(tensor<2xi32>) 2792 2793 // CHECK: mhlo.slice 2794 // CHECK-DAG-SAME: start_indices = dense<[0, 0]> 2795 // CHECK-DAG-SAME: limit_indices = dense<[1, 8]> 2796 // CHECK-DAG-SAME: strides = dense<[1, 3]> 2797 // CHECK-SAME: -> tensor<1x3xf32> 2798 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 2799 : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32> 2800 return %output : tensor<1x3xf32> 2801} 2802 2803// CHECK-LABEL: strided_slice_empty 2804func @strided_slice_empty(%input: tensor<4xf32>) -> tensor<0xf32> { 2805 %begin = "tf.Const"() {value = dense<[-4]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2806 %end = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2807 %strides = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) 2808 2809 // CHECK: mhlo.constant dense<> : tensor<0xf32> 2810 %output = "tf.StridedSlice"(%input, %begin, %end, %strides) 2811 : (tensor<4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf32> 2812 return %output : tensor<0xf32> 2813} 2814 2815// CHECK-LABEL: strided_slice_begin_end_mask 2816// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<4x128x1024xf32> 2817func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) { 2818 2819 // For StridedSlice 2820 // Dim #: 0, 1, 2 2821 // Input shape: [4, 128, 1024] 2822 // Begin: 1, 4, -3 2823 // End: 8, 65, 42 2824 // Stride: 1, 4, -1 2825 // Begin mask: 0, 0, 1 (= 1) 2826 // End mask: 1, 0, 0 (= 4) 2827 2828 // So result shape: 2829 // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 2830 // Dim #1: 4 to 65 stride 4: so 16 2831 // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 2832 // result shape: [4, 16, 1022] 2833 2834 %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) 2835 %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) 2836 %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) 2837 2838 // CHECK: %[[REVERSE:.*]] = "mhlo.reverse"(%[[INPUT]]) 2839 2840 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[REVERSE]]) 2841 // CHECK-DAG-SAME: limit_indices = dense<[4, 65, 1024]> 2842 // CHECK-DAG-SAME: start_indices = dense<[0, 4, 2]> 2843 // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> 2844 // CHECK-SAME: -> tensor<4x16x1022xf32> 2845 2846 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x16x1022xf32> 2847 2848 // CHECK: "mhlo.reshape"(%[[SLICE]]) 2849 // CHECK-SAME: -> tensor<4x16x1022xf32> 2850 2851 return 2852} 2853 2854// CHECK-LABEL: strided_slice_shrink_axis_mask 2855// CHECK-SAME: %[[INPUT:.+]]: tensor<4x128x1024xf32> 2856func @strided_slice_shrink_axis_mask(%input: tensor<4x128x1024xf32>) { 2857 2858 // For StridedSlice 2859 // Dim #: 0, 1, 2 2860 // Input shape: [4, 128, 1024] 2861 // Begin: 1, 4, -3 2862 // End: 8, 65, 42 2863 // Stride: 1, 4, -1 2864 // Begin mask: 1, 0, 0 (= 1) 2865 // End mask: 0, 0, 1 (= 4) 2866 // Shrink axis mask: 1, 0, 1 (= 5) 2867 2868 // So result shape: 2869 // Dim #0: shrink axis, take value at [1] 2870 // Dim #1: 4 to 65 stride 4: so 16 2871 // Dim #2: shrink axis, take value at [-3] 2872 // result shape: [16] 2873 2874 // As output shape of StridedSlice differs, a reshape will follow. 
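// For example, with the constants below the mhlo.slice yields
// tensor<1x16x1xf32> (dim #1 keeps indices 4, 8, ..., 64, i.e.
// ceil((65 - 4) / 4) = 16 elements), and the trailing mhlo.reshape drops the
// two shrunk unit dimensions to produce tensor<16xf32>.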
2875 2876 %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) 2877 %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) 2878 %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) 2879 2880 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) 2881 // CHECK-DAG-SAME: limit_indices = dense<[1, 65, 1022]> 2882 // CHECK-DAG-SAME: start_indices = dense<[0, 4, 1021]> 2883 // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> 2884 // CHECK-SAME: -> tensor<1x16x1xf32> 2885 2886 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4, shrink_axis_mask = 5} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<16xf32> 2887 2888 // CHECK: "mhlo.reshape"(%[[SLICE]]) 2889 // CHECK-SAME: -> tensor<16xf32> 2890 2891 return 2892} 2893 2894// CHECK-LABEL: strided_slice_ellipsis_mask 2895// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32> 2896func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) { 2897 // For StridedSlice input[1, ..., 8:, :10, 2:6:2] 2898 // The ellipsis mask is applied to dim #1, #2, i.e., we get the canonicalized 2899 // slice input[1, :, :, 8:, :10, 2:6:2] 2900 2901 // The start, limit indices and strides attributes of mhlo.slice would 2902 // reflect the canonicalized slice. 2903 // As output shape of StridedSlice differs, a reshape will follow. 2904 2905 %begin = "tf.Const"() {value = dense<[1, 0, 8, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>) 2906 %end = "tf.Const"() {value = dense<[2, 0, 10, 10, 6]> : tensor<5xi32>} : () -> (tensor<5xi32>) 2907 %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>) 2908 2909 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) 2910 // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> 2911 // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> 2912 // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensor<6xi64> 2913 // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> 2914 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 8, end_mask = 4, shrink_axis_mask = 1, ellipsis_mask = 2} : (tensor<2x4x8x16x32x64xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<4x8x8x10x2xf32> 2915 2916 // CHECK: "mhlo.reshape"(%[[SLICE]]) 2917 // CHECK-SAME: -> tensor<4x8x8x10x2xf32> 2918 2919 return 2920} 2921 2922// CHECK-LABEL: strided_slice_new_axis_mask 2923// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32> 2924func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) { 2925 // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis] 2926 // New axis mask is set at indices 1 and 6 of the sparse spec, so 2927 // new_axis_mask = 2^1 + 2^6 = 66 2928 // The ellipsis mask is applied to dim #1, #2 of input, i.e., we get the 2929 // canonicalized slice input[1, :, :, 8:, :10, 2:6:2] 2930 // This is then reshaped to add the new axes. 2931 2932 // The start, limit indices and strides attributes of mhlo.slice would 2933 // reflect the canonicalized slice. 2934 // As output shape of StridedSlice differs, a reshape will follow to reflect 2935 // new axes added.
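// For example, here the mhlo.slice produces tensor<1x4x8x8x10x2xf32>; the
// reshape then drops the shrunk leading dimension and inserts unit dimensions
// for the two new axes, giving tensor<1x4x8x8x10x2x1xf32>.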
2936 2937 %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) 2938 %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) 2939 %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) 2940 2941 // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) 2942 // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> 2943 // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> 2944 // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensor<6xi64> 2945 // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> 2946 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<2x4x8x16x32x64xf32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>) -> tensor<1x4x8x8x10x2x1xf32> 2947 2948 // CHECK: "mhlo.reshape"(%[[SLICE]]) 2949 // CHECK-SAME: -> tensor<1x4x8x8x10x2x1xf32> 2950 2951 return 2952} 2953 2954// CHECK-LABEL: strided_slice_implicit_ellipsis_mask( 2955// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32> 2956func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> { 2957 // StridedSlice gets input[8:10], which is the same as input[8:10, ...] 2958 // The start_indices, limit_indices, and strides attributes of mhlo.slice 2959 // reflect the canonicalized slice. 2960 %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32> 2961 %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> 2962 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 2963 // CHECK: [[SLICE:%.*]] = "mhlo.slice"([[INPUT]]) 2964 // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64> 2965 // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64> 2966 // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64> 2967 // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[SLICE]]) : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32> 2968 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32> 2969 // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32> 2970 return %0 : tensor<2x16x2xf32> 2971} 2972 2973// CHECK-LABEL: strided_slice_nonconstant_begin_end 2974func @strided_slice_nonconstant_begin_end(%arg0: tensor<i32>, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) { 2975 // In this case, the `begin` and `end` inputs are unknown at compile time -- 2976 // so the StridedSlice needs to slice these vectors and use that as input to 2977 // an HLO dynamic slice.
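// The checks below follow that recipe: slice the packed begin vector down to
// a scalar index, wrap a possibly negative index by adding the dimension size
// (32) and selecting on the sign, feed the result to mhlo.dynamic-slice, and
// reshape to apply the shrink_axis_mask.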
2978 %begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32> 2979 %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> 2980 %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 2981 %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32> 2982 %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32> 2983 // CHECK: %[[A:.*]] = "mhlo.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32> 2984 // CHECK-NEXT: %[[BEGIN:.*]] = "mhlo.concatenate"(%[[A]]) 2985 // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32> 2986 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32> 2987 // CHECK-NEXT: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]]) 2988 // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, 2989 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, 2990 // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> 2991 // CHECK-NEXT: %[[INDEX2:.*]] = "mhlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor<i32> 2992 // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] 2993 // CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> 2994 // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor<i32> 2995 // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32> 2996 // CHECK-NEXT: %[[INDEX3:.*]] = "mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : 2997 // CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 2998 // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic-slice" 2999 // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) 3000 // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : 3001 // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x97xi32> 3002 // CHECK-NEXT: %[[FINAL:.*]] = "mhlo.reshape"(%[[SLICED]]) : (tensor<1x97xi32>) -> tensor<1x97xi32> 3003 %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3004 // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32> 3005 return %result : tensor<1x97xi32> 3006} 3007 3008// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1 3009func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) { 3010 // Dynamic stride: when `begin` and `end` inputs are unknown at compile time, 3011 // `strides` must be known. 
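// Here %strides is a function argument rather than a constant, so the op is
// expected to remain untouched.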
3012 // CHECK: tf.StridedSlice 3013 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3014 return %result : tensor<1x97xi32> 3015} 3016 3017// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2 3018func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3019 // Invalid stride (not equal to 1): when `begin` and `end` inputs are unknown 3020 // at compile time, `strides` must be known to have all 1 values. 3021 %strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> 3022 // CHECK: tf.StridedSlice 3023 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3024 return %result : tensor<1x97xi32> 3025} 3026 3027// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count 3028func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> { 3029 %strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64> 3030 // When begin/end are dynamic, the number of output elements must be equal to 3031 // the number of input elements sliced. 3032 // CHECK: tf.StridedSlice 3033 %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32> 3034 return %0 : tensor<6x10xf32> 3035} 3036 3037// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_begin_mask 3038func @strided_slice_nonconstant_begin_end_and_begin_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3039 // Begin mask: When `begin` and `end` inputs are unknown at compile time, we 3040 // can't support a begin mask. 3041 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3042 // CHECK: tf.StridedSlice 3043 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3044 return %result : tensor<1x97xi32> 3045} 3046 3047// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_end_mask 3048func @strided_slice_nonconstant_begin_end_and_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3049 // End mask: When `begin` and `end` inputs are unknown at compile time, we 3050 // can't support an end mask. 
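// end_mask = 1 below targets dim #0, the dimension being sliced dynamically.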
3051 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3052 // CHECK: tf.StridedSlice 3053 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3054 return %result : tensor<1x97xi32> 3055} 3056 3057// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_new_axis_mask 3058func @strided_slice_nonconstant_begin_end_and_new_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3059 // New axis mask: When `begin` and `end` inputs are unknown at compile time, 3060 // we can't support a new_axis mask. 3061 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3062 // CHECK: tf.StridedSlice 3063 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 15 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3064 return %result : tensor<1x97xi32> 3065} 3066 3067// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask 3068func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3069 // This ellipsis mask is not supported because it does not refer to the last 3070 // dimension. 3071 // [0, 1, 0] = 2 3072 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3073 // CHECK: tf.StridedSlice 3074 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3075 return %result : tensor<1x97xi32> 3076} 3077 3078// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask 3079func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3080 // This ellipsis mask is supported because it refers to the last dimension. 3081 // [1, 0, 0] = 4 3082 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3083 // CHECK: mhlo.dynamic-slice 3084 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3085 return %result : tensor<1x97xi32> 3086} 3087 3088// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask 3089func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3090 // This shrink_axis mask is supported because it refers to a major dimension. 
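// A leading shrink axis fits the mhlo.dynamic-slice plus mhlo.reshape pattern
// checked in strided_slice_nonconstant_begin_end above.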
3091 // [1, 1, 1] = 7 3092 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3093 // CHECK: mhlo.dynamic-slice 3094 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3095 return %result : tensor<1x97xi32> 3096} 3097 3098// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask 3099func @strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { 3100 // This shrink_axis mask is unsupported because it does not refer to a major 3101 // dimension. 3102 // [0, 1, 0] = 2 3103 %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3104 // CHECK: tf.StridedSlice 3105 %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> 3106 return %result : tensor<1x97xi32> 3107} 3108 3109 3110//===----------------------------------------------------------------------===// 3111// Reduction op legalizations. 3112//===----------------------------------------------------------------------===// 3113 3114// CHECK-LABEL: func @mean 3115func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3116 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> 3117 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 3118 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3119 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>): 3120 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32> 3121 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> () 3122 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> 3123 // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<8.000000e+00> : tensor<f32> 3124 // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 3125 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16> 3126 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3127 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3128 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3129 %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3130 return %0 : tensor<4x1xf16> 3131} 3132 3133// CHECK-LABEL: func @mean_scalar_dim 3134func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3135 // Verify that a tf.Mean op with a scalar dimension operand is lowered successfully.
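// The reduction dimension below is a scalar tensor<i64> rather than the usual
// 1-D tensor<1xi64>.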
3136 3137 // CHECK-NOT: tf.Mean 3138 %dimension = "tf.Const"() { value = dense<1> : tensor<i64> } : () -> tensor<i64> 3139 %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<i64>) -> tensor<4x1xf16> 3140 return %0 : tensor<4x1xf16> 3141} 3142 3143// CHECK-LABEL: func @mean_dynamic 3144func @mean_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { 3145 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3146 // CHECK: "tf.Mean" 3147 %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3148 return %0 : tensor<4x1xf16> 3149} 3150 3151// CHECK-LABEL: func @sum 3152func @sum(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3153 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> 3154 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 3155 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3156 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>): 3157 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32> 3158 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> () 3159 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> 3160 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> 3161 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3162 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3163 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3164 %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3165 return %0 : tensor<4x1xf16> 3166} 3167 3168// CHECK-LABEL: func @sum_dynamic 3169func @sum_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { 3170 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf32> 3171 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 3172 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3173 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>): 3174 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32> 3175 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> () 3176 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x?xf32>, tensor<f32>) -> tensor<4xf32> 3177 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> 3178 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3179 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3180 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3181 %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3182 return %0 : tensor<4x1xf16> 3183} 3184 3185// CHECK-LABEL: func @max 3186func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3187 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16> 3188 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16> 3189 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3190 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>): 3191 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.maximum %[[ARGA]], %[[ARGB]] : tensor<f16> 3192 // CHECK: 
"mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> () 3193 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16> 3194 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> 3195 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3196 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3197 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3198 %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3199 return %0 : tensor<4x1xf16> 3200} 3201 3202// CHECK-LABEL: func @max_qint 3203// Regression test to ensure we don't crash getting the initial value for 3204// tf.Max when using quantized integer types. 3205func @max_qint(%arg0: tensor<4x8x!tf.qint8>) -> tensor<4x1x!tf.qint8> { 3206 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3207 %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf.qint8>, tensor<1xi64>) -> tensor<4x1x!tf.qint8> 3208 return %0 : tensor<4x1x!tf.qint8> 3209} 3210 3211// CHECK-LABEL: func @max_dynamic 3212func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { 3213 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf16> 3214 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16> 3215 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3216 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>): 3217 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.maximum %[[ARGA]], %[[ARGB]] : tensor<f16> 3218 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> () 3219 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x?xf16>, tensor<f16>) -> tensor<4xf16> 3220 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> 3221 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3222 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3223 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3224 %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3225 return %0 : tensor<4x1xf16> 3226} 3227 3228// CHECK-LABEL: func @min 3229func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3230 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16> 3231 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0x7C00> : tensor<f16> 3232 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3233 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>): 3234 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.minimum %[[ARGA]], %[[ARGB]] : tensor<f16> 3235 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> () 3236 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16> 3237 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16> 3238 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3239 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3240 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3241 %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3242 return %0 : tensor<4x1xf16> 3243} 3244 3245// 
CHECK-LABEL: func @min_qint 3246// Regression test to ensure we don't crash getting the initial value for 3247// tf.Min when using quantized integer types. 3248func @min_qint(%arg0: tensor<4x8x!tf.qint8>) -> tensor<4x1x!tf.qint8> { 3249 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3250 %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf.qint8>, tensor<1xi64>) -> tensor<4x1x!tf.qint8> 3251 return %0 : tensor<4x1x!tf.qint8> 3252} 3253 3254// CHECK-LABEL: func @prod 3255func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { 3256 // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32> 3257 // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 3258 // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( { 3259 // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>): 3260 // CHECK: %[[REDUCE_BODY_RESULT:.*]] = mhlo.multiply %[[ARGA]], %[[ARGB]] : tensor<f32> 3261 // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> () 3262 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> 3263 // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16> 3264 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> 3265 // CHECK: return %[[RESULT]] : tensor<4x1xf16> 3266 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3267 %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> 3268 return %0 : tensor<4x1xf16> 3269} 3270 3271// CHECK-LABEL: func @prod_qint 3272// Regression test to ensure we don't crash getting the initial value for 3273// tf.Prod when using quantized integer types. 
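// As with max_qint and min_qint above, the test only requires that
// legalization completes without crashing.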
3274func @prod_qint(%arg0: tensor<4x8x!tf.qint8>) -> tensor<4x1x!tf.qint8> { 3275 %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> 3276 %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf.qint8>, tensor<1xi64>) -> tensor<4x1x!tf.qint8> 3277 return %0 : tensor<4x1x!tf.qint8> 3278} 3279 3280// CHECK-LABEL: @all 3281func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> { 3282 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3283 // CHECK: %[[INIT:.*]] = mhlo.constant dense<true> : tensor<i1> 3284 // CHECK: "mhlo.reduce"(%{{.*}}, %[[INIT]]) ( { 3285 // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>): 3286 // CHECK: %[[AND:.*]] = mhlo.and %[[ARGA]], %[[ARGB]] : tensor<i1> 3287 // CHECK: "mhlo.return"(%[[AND]]) : (tensor<i1>) -> () 3288 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1> 3289 %0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> 3290 return %0 : tensor<4xi1> 3291} 3292 3293// CHECK-LABEL: @all_keep_dim 3294func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { 3295 // CHECK: "mhlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1> 3296 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3297 %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> 3298 return %0 : tensor<4x1xi1> 3299} 3300 3301// CHECK-LABEL: @all_dynamic 3302func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { 3303 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3304 // CHECK: %[[ARG:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1> 3305 // CHECK: "mhlo.reduce"(%[[ARG]] 3306 %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> 3307 return %0 : tensor<4x1xi1> 3308} 3309 3310// CHECK-LABEL: @any 3311func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> { 3312 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3313 // CHECK: %[[INIT:.*]] = mhlo.constant dense<false> : tensor<i1> 3314 // CHECK: "mhlo.reduce"(%{{.*}}, %[[INIT]]) ( { 3315 // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>): 3316 // CHECK: %[[OR:.*]] = mhlo.or %[[ARGA]], %[[ARGB]] : tensor<i1> 3317 // CHECK: "mhlo.return"(%[[OR]]) : (tensor<i1>) -> () 3318 // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1> 3319 %0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> 3320 return %0 : tensor<4xi1> 3321} 3322 3323// CHECK-LABEL: @any_keep_dim 3324func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { 3325 // CHECK: "mhlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1> 3326 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3327 %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> 3328 return %0 : tensor<4x1xi1> 3329} 3330 3331// CHECK-LABEL: @any_dynamic 3332func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { 3333 %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> 3334 // CHECK: %[[ARG:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1> 3335 // CHECK: "mhlo.reduce"(%[[ARG]] 3336 %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> 3337 return %0 : tensor<4x1xi1> 3338} 3339 
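// For reference, the float reduction lowerings above all share one emitted
// pattern; the following is an illustrative sketch (not a FileCheck'd test) of
// what the keep_dims sum over dim 1 of a tensor<4x8xf16> lowers to:
//
//   %cast = "mhlo.convert"(%input) : (tensor<4x8xf16>) -> tensor<4x8xf32>
//   %init = mhlo.constant dense<0.000000e+00> : tensor<f32>
//   %red = "mhlo.reduce"(%cast, %init) ( {
//   ^bb0(%a: tensor<f32>, %b: tensor<f32>):
//     %sum = mhlo.add %a, %b : tensor<f32>
//     "mhlo.return"(%sum) : (tensor<f32>) -> ()
//   }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
//   %back = "mhlo.convert"(%red) : (tensor<4xf32>) -> tensor<4xf16>
//   %out = "mhlo.reshape"(%back) : (tensor<4xf16>) -> tensor<4x1xf16>
//
// tf.Max and tf.Min keep the f16 element type and only swap the init constant
// and the reduction body op, while tf.Mean additionally divides by the reduced
// dimension size before converting back (see @mean, @max and @min above).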
3340//===----------------------------------------------------------------------===// 3341// Tile op legalizations. 3342//===----------------------------------------------------------------------===// 3343 3344// CHECK-LABEL: func @tile_by_reshape 3345func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> { 3346 // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32> 3347 // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[BROADCASTED]]) : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32> 3348 // CHECK: return %[[RESULT]] : tensor<28x24xf32> 3349 %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> 3350 %0 = "tf.Tile"(%arg0, %multiples) : (tensor<4x8xf32>, tensor<2xi64>) -> tensor<28x24xf32> 3351 return %0 : tensor<28x24xf32> 3352} 3353 3354// CHECK-LABEL: func @tile_just_broadcast 3355func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> { 3356 // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<7x3xf32> 3357 // CHECK: return %[[RESULT]] : tensor<7x3xf32> 3358 %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> 3359 %0 = "tf.Tile"(%arg0, %multiples) : (tensor<1x1xf32>, tensor<2xi64>) -> tensor<7x3xf32> 3360 return %0 : tensor<7x3xf32> 3361} 3362 3363//===----------------------------------------------------------------------===// 3364// ArgMax op legalizations. 3365//===----------------------------------------------------------------------===// 3366 3367// CHECK-LABEL: func @argmax_i64_input_i32_output_axis_0 3368func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { 3369 // CHECK: %[[INIT:.*]] = mhlo.constant dense<-9223372036854775808> : tensor<i64> 3370 // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32> 3371 // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x7xi32> 3372 // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) 3373 // CHECK: ^bb0(%[[ARG1:.*]]: tensor<i64>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i32>): 3374 // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "GT"} : (tensor<i64>, tensor<i64>) -> tensor<i1> 3375 // CHECK: %[[RESULT1:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64> 3376 // CHECK: %[[RESULT2:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> 3377 // CHECK: "mhlo.return"(%[[RESULT1]], %[[RESULT2]]) : (tensor<i64>, tensor<i32>) -> () 3378 // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> 3379 %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32> 3380 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi64>, tensor<i32>) -> tensor<7xi32> 3381 return %0 : tensor<7xi32> 3382} 3383 3384// CHECK-LABEL: func @argmax_f32_input_i64_output_axis_1 3385func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor<3xi64> { 3386 // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32> 3387 // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i64> 3388 // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x7xi64> 3389 // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) 3390 // 
CHECK: return %[[REDUCE]]#1 : tensor<3xi64> 3391 %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32> 3392 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xf32>, tensor<i32>) -> tensor<3xi64> 3393 return %0 : tensor<3xi64> 3394} 3395 3396// CHECK-LABEL: func @argmax_dynamic_shape_input_output 3397func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor<?xi32> { 3398 // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32> 3399 // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32> 3400 // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x?xi32> 3401 // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) 3402 // CHECK: return %[[REDUCE]]#1 : tensor<?xi32> 3403 %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32> 3404 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<?xi32> 3405 return %0 : tensor<?xi32> 3406} 3407 3408// CHECK-LABEL: func @argmax_dynamic_shape_input 3409func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> { 3410 // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32> 3411 // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32> 3412 // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x?xi32> 3413 // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]]) 3414 // CHECK: return %[[REDUCE]]#1 : tensor<3xi32> 3415 %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32> 3416 %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<3xi32> 3417 return %0 : tensor<3xi32> 3418} 3419 3420//===----------------------------------------------------------------------===// 3421// Random op legalizations. 3422//===----------------------------------------------------------------------===// 3423 3424// CHECK-LABEL: func @rng_uniform 3425func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { 3426 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 3427 // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 3428 // CHECK: %[[CONV:.*]] = "mhlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64> 3429 // CHECK: %[[F32:.*]] = "mhlo.rng_uniform"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> 3430 %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> 3431 // CHECK: return %[[F32]] 3432 return %0 : tensor<12x?x64xf32> 3433} 3434 3435// CHECK-LABEL: func @rng_std_normal 3436func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { 3437 // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 3438 // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 3439 // CHECK: %[[CONV:.*]] = "mhlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64> 3440 // CHECK: %[[F32:.*]] = "mhlo.rng_normal"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32> 3441 %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> 3442 // CHECK: return %[[F32]] 3443 return %0 : tensor<12x?x64xf32> 3444} 3445 3446//===----------------------------------------------------------------------===// 3447// Range op legalizations. 
3448//===----------------------------------------------------------------------===// 3449 3450// CHECK-LABEL: func @range 3451// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32> 3452func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> { 3453 %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor<f32>} : () -> tensor<f32> 3454 // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota" 3455 // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3456 // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3457 %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32> 3458 return %3 : tensor<5xf32> 3459} 3460 3461// CHECK-LABEL: func @range_dynamic 3462// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32> 3463func @range_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<?xf32> { 3464 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 3465 // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"([[SUB]]) 3466 // CHECK-DAG: [[CONVERT1:%.+]] = "mhlo.convert"([[ABS1]]) 3467 // CHECK-DAG: [[CONVERT2:%.+]] = "mhlo.convert"(%arg2) 3468 // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]] 3469 // CHECK-DAG: [[CEIL:%.+]] = "mhlo.ceil"([[DIV]]) 3470 // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"([[CEIL]]) 3471 // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.reshape"([[CONVERT3]]) 3472 // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} 3473 // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0) 3474 // CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2) 3475 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3476 // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3477 %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32> 3478 3479 // CHECK: return [[ADD]] 3480 return %2 : tensor<?xf32> 3481} 3482 3483// CHECK-LABEL: func @range_int_dynamic 3484// CHECK-SAME: [[START:%.*]]: tensor<i32>, [[DELTA:%.*]]: tensor<i32> 3485func @range_int_dynamic(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?xi32> { 3486 // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 3487 // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"([[SUB]]) 3488 // CHECK-DAG: [[CONVERT1:%.+]] = "mhlo.convert"([[ABS1]]) 3489 // CHECK-DAG: [[CONVERT2:%.+]] = "mhlo.convert"(%arg2) 3490 // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]] 3491 // CHECK-DAG: [[CEIL:%.+]] = "mhlo.ceil"([[DIV]]) 3492 // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"([[CEIL]]) 3493 // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.reshape"([[CONVERT3]]) 3494 // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} 3495 // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0) 3496 // CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2) 3497 // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3498 // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3499 %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = 
"tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32> 3500 3501 // CHECK: return [[ADD]] 3502 return %2 : tensor<?xi32> 3503} 3504 3505// CHECK-LABEL: func @linspace_static 3506// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[STOP:%.*]]: tensor<f32> 3507func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> { 3508 // CHECK-DAG: [[NUM:%.*]] = mhlo.constant dense<4> 3509 // CHECK-DAG: [[NUM_CAST:%.*]] = tensor.cast [[NUM]] 3510 // CHECK-DAG: [[NUM_F32:%.*]] = "mhlo.convert"([[NUM_CAST]]) 3511 // CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> 3512 // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]] 3513 // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]] 3514 // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] 3515 // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} 3516 // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3517 // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>} 3518 // CHECK: return [[LINSPACE]] 3519 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor<i32>} : () -> tensor<i32> 3520 %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<4xf32> 3521 return %1 : tensor<4xf32> 3522} 3523 3524// CHECK-LABEL: func @linspace_dynamic 3525func @linspace_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>) -> tensor<?xf32> { 3526 // CHECK: "tf.LinSpace" 3527 %0 = "tf.LinSpace"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<?xf32> 3528 return %0 : tensor<?xf32> 3529} 3530 3531// CHECK-LABEL: func @linspace_invalid_num 3532func @linspace_invalid_num(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<?xf32> { 3533 // CHECK: mhlo.constant dense<> : tensor<0xi32> 3534 // CHECK: "tf.LinSpace" 3535 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> 3536 %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<0xi32>) -> tensor<?xf32> 3537 return %1 : tensor<?xf32> 3538} 3539 3540//===----------------------------------------------------------------------===// 3541// LegacyCall op legalizations. 
3542//===----------------------------------------------------------------------===// 3543 3544func @identity_func(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { 3545 return %arg0: tensor<10x2xf32> 3546} 3547 3548// CHECK-LABEL: testSimpleLegacyCallOp 3549func @testSimpleLegacyCallOp(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { 3550 // CHECK: %[[RESULT:.*]] = call @identity_func(%arg0) : (tensor<10x2xf32>) -> tensor<10x2xf32> 3551 %0 = "tf.LegacyCall"(%arg0) {f = @identity_func} : (tensor<10x2xf32>) -> tensor<10x2xf32> 3552 // CHECK: return %[[RESULT]] 3553 return %0: tensor<10x2xf32> 3554} 3555 3556func @select_first(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { 3557 return %arg0: tensor<10x2xf32> 3558} 3559 3560// CHECK-LABEL: testMultiInputLegacyCallOp 3561func @testMultiInputLegacyCallOp(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { 3562 // CHECK: %[[RESULT:.*]] = call @select_first(%arg0, %arg1) : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> 3563 %0 = "tf.LegacyCall"(%arg0, %arg1) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @select_first} : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> 3564 // CHECK: return %[[RESULT]] 3565 return %0: tensor<10x2xf32> 3566} 3567 3568//===----------------------------------------------------------------------===// 3569// Conv op legalizations. 3570//===----------------------------------------------------------------------===// 3571 3572// CHECK-LABEL: conv_simple 3573func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> { 3574 3575 // CHECK: "mhlo.convolution"(%arg0, %arg1) 3576 3577 // Default attributes 3578 // CHECK-NOT: lhs_dilation 3579 // CHECK-NOT: precision_config 3580 3581 // CHECK-DAG-SAME: window_strides = dense<[4, 5]> 3582 // CHECK-DAG-SAME: padding = dense<{{\[\[}}44, 45], [60, 60]]> 3583 // CHECK-DAG-SAME: rhs_dilation = dense<[2, 3]> 3584 3585 // CHECK-DAG-SAME: dimension_numbers 3586 // CHECK-DAG-SAME: input_batch_dimension = 0 3587 // CHECK-DAG-SAME: input_feature_dimension = 3 3588 // CHECK-DAG-SAME: input_spatial_dimensions = dense<[1, 2]> 3589 // CHECK-DAG-SAME: kernel_input_feature_dimension = 2 3590 // CHECK-DAG-SAME: kernel_output_feature_dimension = 3 3591 // CHECK-DAG-SAME: kernel_spatial_dimensions = dense<[0, 1]> 3592 // CHECK-DAG-SAME: output_batch_dimension = 0 3593 // CHECK-DAG-SAME: output_feature_dimension = 3 3594 // CHECK-DAG-SAME: output_spatial_dimensions = dense<[1, 2]> 3595 3596 // CHECK-DAG-SAME: feature_group_count = 2 3597 // CHECK-DAG-SAME: batch_group_count = 1 3598 3599 %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> 3600 return %0 : tensor<256x8x7x16xf32> 3601} 3602 3603// CHECK-LABEL: conv3d_simple 3604func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> { 3605 3606 // CHECK: "mhlo.convolution"(%arg0, %arg1) 3607 3608 // Default attributes 3609 // CHECK-NOT: lhs_dilation 3610 // CHECK-NOT: precision_config 3611 3612 // CHECK-DAG-SAME: window_strides = dense<[5, 6, 7]> 3613 // CHECK-DAG-SAME: padding = dense<[[1, 2], [2, 3], [2, 3]]> 3614 // CHECK-DAG-SAME: rhs_dilation = dense<[2, 3, 4]> 3615 3616 // CHECK-DAG-SAME: dimension_numbers 3617 // CHECK-DAG-SAME: input_batch_dimension = 0 3618 // CHECK-DAG-SAME: input_feature_dimension = 4 
  // CHECK-DAG-SAME: input_spatial_dimensions = dense<[1, 2, 3]>
  // CHECK-DAG-SAME: kernel_input_feature_dimension = 3
  // CHECK-DAG-SAME: kernel_output_feature_dimension = 4
  // CHECK-DAG-SAME: kernel_spatial_dimensions = dense<[0, 1, 2]>
  // CHECK-DAG-SAME: output_batch_dimension = 0
  // CHECK-DAG-SAME: output_feature_dimension = 4
  // CHECK-DAG-SAME: output_spatial_dimensions = dense<[1, 2, 3]>

  // CHECK-DAG-SAME: feature_group_count = 2
  // CHECK-DAG-SAME: batch_group_count = 1

  %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32>
  return %0 : tensor<256x7x6x5x16xf32>
}

// CHECK-LABEL: depthwiseconv_simple
func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> {
  // CHECK: %[[RESHAPED_FILTER:.*]] = "mhlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32>
  // CHECK: "mhlo.convolution"(%arg0, %[[RESHAPED_FILTER]])
  // CHECK: feature_group_count = 3
  %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {
    data_format = "NHWC",
    device = "",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1]
  } : (tensor<2x4x5x3xf32>, tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32>
  return %0 : tensor<2x3x4x9xf32>
}

// CHECK-LABEL: conv_valid_padding
func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> {
  // CHECK: "mhlo.convolution"(%arg0, %arg1)

  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x4x5x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32>
  return %0 : tensor<1x2x3x1xf32>
}

// CHECK-LABEL: conv_explicit_paddings
func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> {

  // CHECK: "mhlo.convolution"(%arg0, %arg1)
  // CHECK-SAME: padding = dense<{{\[\[}}6, 0], [3, 3]]>

  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32>
  return %0 : tensor<256x9x7x16xf32>
}

// CHECK-LABEL: @conv2d_backprop_input
func @conv2d_backprop_input(
    %filter: tensor<3x3x1x32xf32>,
    %out_backprop: tensor<100x26x26x32xf32>
  ) -> tensor<100x28x28x1xf32> {
  // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>}
  // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg1, %[[REV_FILTER]]) {
  // CHECK-SAME: batch_group_count = 1 : i64,
  // CHECK-SAME: dimension_numbers = {
  // CHECK-SAME: input_batch_dimension = 0 : i64,
  // CHECK-SAME: input_feature_dimension = 3 : i64,
  // CHECK-SAME: input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
  // CHECK-SAME: kernel_input_feature_dimension = 3 : i64,
  // CHECK-SAME: kernel_output_feature_dimension = 2 : i64,
  // CHECK-SAME: kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
  // CHECK-SAME: output_batch_dimension = 0 : i64,
  // CHECK-SAME: output_feature_dimension = 3 : i64,
  // CHECK-SAME: output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
  // CHECK-SAME: },
  // CHECK-SAME: feature_group_count = 1 : i64,
  // CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>,
  // CHECK-SAME: padding = dense<2> : tensor<2x2xi64>,
  // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>,
  // CHECK-SAME: window_strides = dense<1> : tensor<2xi64>
  // CHECK: return %[[RESULT]]
  %input_sizes = "tf.Const" () { value = dense<[100,28,28,1]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
  return %result : tensor<100x28x28x1xf32>
}

// CHECK-LABEL: @conv2d_backprop_input_grouped
func @conv2d_backprop_input_grouped(
    %filter: tensor<2x2x5x21xf32>,
    %out_backprop: tensor<5x2x2x21xf32>
  ) -> tensor<5x3x3x15xf32> {
  %input_sizes = "tf.Const" () { value = dense<[5, 3, 3, 15]> : tensor<4xi32> } : () -> tensor<4xi32>

  // Verify filter transformation for grouped convolution.

  // CHECK: %[[RESHAPE:.*]] = "mhlo.reshape"(%arg0) : (tensor<2x2x5x21xf32>) -> tensor<2x2x5x3x7xf32>
  // CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[RESHAPE]])
  // CHECK-SAME: permutation = dense<[0, 1, 3, 2, 4]>
  // CHECK-SAME: (tensor<2x2x5x3x7xf32>) -> tensor<2x2x3x5x7xf32>
  // CHECK: "mhlo.reshape"(%[[TRANSPOSE]]) : (tensor<2x2x3x5x7xf32>) -> tensor<2x2x15x7xf32>

  %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<4xi32>, tensor<2x2x5x21xf32>, tensor<5x2x2x21xf32>) -> tensor<5x3x3x15xf32>
  return %result : tensor<5x3x3x15xf32>
}


// CHECK-LABEL: @conv3d_backprop_input
func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
  // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
  // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg1, %[[REV_FILTER]])

  // CHECK-DAG-SAME: batch_group_count = 1 : i64,

  // CHECK-DAG-SAME: dimension_numbers =
  // CHECK-DAG-SAME: input_batch_dimension = 0 : i64
  // CHECK-DAG-SAME: input_feature_dimension = 4 : i64
  // CHECK-DAG-SAME: input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
  // CHECK-DAG-SAME: kernel_input_feature_dimension = 4 : i64
  // CHECK-DAG-SAME: kernel_output_feature_dimension = 3 : i64
  // CHECK-DAG-SAME: kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-DAG-SAME: output_batch_dimension = 0 : i64
  // CHECK-DAG-SAME: output_feature_dimension = 4 : i64
  // CHECK-DAG-SAME: output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>

  // CHECK-DAG-SAME: feature_group_count = 1 : i64
  // CHECK-DAG-SAME: lhs_dilation = dense<1> : tensor<3xi64>
  // CHECK-DAG-SAME: padding = dense<1> : tensor<3x2xi64>
  // CHECK-DAG-SAME: rhs_dilation = dense<1> : tensor<3xi64>
  // CHECK-DAG-SAME: window_strides = dense<1> : tensor<3xi64>

  // CHECK: return %[[RESULT]]
  %input_sizes = "tf.Const" () {value = dense<[2, 8, 8, 8, 1]> : tensor<5xi32>} : () -> tensor<5xi32>
  %result = "tf.Conv3DBackpropInputV2"(%input_sizes, %filter, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<5xi32>, tensor<3x3x3x1x6xf32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
  return %result : tensor<2x8x8x8x1xf32>
}

// CHECK-LABEL: @conv2d_backprop_filter
func @conv2d_backprop_filter(
    %input: tensor<100x28x28x1xf32>,
    %out_backprop: tensor<100x26x26x32xf32>
  ) -> tensor<100x28x28x1xf32> {
  // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg0, %arg1) {
  // CHECK-SAME: batch_group_count = 1 : i64,
  // CHECK-SAME: dimension_numbers = {
  // CHECK-SAME: input_batch_dimension = 3 : i64,
  // CHECK-SAME: input_feature_dimension = 0 : i64,
  // CHECK-SAME: input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
  // CHECK-SAME: kernel_input_feature_dimension = 0 : i64,
  // CHECK-SAME: kernel_output_feature_dimension = 3 : i64,
  // CHECK-SAME: kernel_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
  // CHECK-SAME: output_batch_dimension = 2 : i64,
  // CHECK-SAME: output_feature_dimension = 3 : i64,
  // CHECK-SAME: output_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: },
  // CHECK-SAME: feature_group_count = 1 : i64,
  // CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>,
  // CHECK-SAME: padding = dense<0> : tensor<2x2xi64>,
  // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>,
  // CHECK-SAME: window_strides = dense<1> : tensor<2xi64>
  // CHECK: return %[[RESULT]]
  %filter_sizes = "tf.Const" () { value = dense<[3,3,1,32]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
  return %result : tensor<100x28x28x1xf32>
}

// CHECK-LABEL: @conv2d_backprop_filter_grouped
func @conv2d_backprop_filter_grouped(
    %input: tensor<1x2x2x2xf32>,
    %out_backprop: tensor<1x1x1x2xf32>
  ) -> tensor<2x2x1x2xf32> {

  // CHECK: "mhlo.convolution"(%arg0, %arg1) {
  // CHECK-SAME: batch_group_count = 2 : i64,
  // CHECK-SAME: feature_group_count = 1 : i64,

  %filter_sizes = "tf.Const" () { value = dense<[2, 2, 1, 2]> : tensor<4xi32> } : () -> tensor<4xi32>
  %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) {
    data_format = "NHWC",
    dilations = [1, 1, 1, 1],
    explicit_paddings = [],
    padding = "VALID",
    strides = [1, 1, 1, 1],
    use_cudnn_on_gpu = true
  } : (tensor<1x2x2x2xf32>, tensor<4xi32>, tensor<1x1x1x2xf32>) -> tensor<2x2x1x2xf32>
  return %result : tensor<2x2x1x2xf32>
}


// CHECK-LABEL: @conv3d_backprop_filter
func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
  // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg0, %arg1)

  // CHECK-DAG-SAME: batch_group_count = 1 : i64

  // CHECK-DAG-SAME: dimension_numbers =
  // CHECK-DAG-SAME: input_batch_dimension = 4 : i64
  // CHECK-DAG-SAME: input_feature_dimension = 0 : i64
  // CHECK-DAG-SAME: input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
  // CHECK-DAG-SAME: kernel_input_feature_dimension = 0 : i64
  // CHECK-DAG-SAME: kernel_output_feature_dimension = 4 : i64
  // CHECK-DAG-SAME: kernel_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
  // CHECK-DAG-SAME: output_batch_dimension = 3 : i64
  // CHECK-DAG-SAME: output_feature_dimension = 4 : i64
  // CHECK-DAG-SAME: output_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>

  // CHECK-DAG-SAME: feature_group_count = 1 : i64
  // CHECK-DAG-SAME: lhs_dilation = dense<1> : tensor<3xi64>
  // CHECK-DAG-SAME: padding = dense<1> : tensor<3x2xi64>
  // CHECK-DAG-SAME: rhs_dilation = dense<1> : tensor<3xi64>
  // CHECK-DAG-SAME: window_strides = dense<1> : tensor<3xi64>

  // CHECK: return %[[RESULT]]
  %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32>
  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
  return %result : tensor<2x8x8x8x1xf32>
}

// CHECK-LABEL: @collective_permute
func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
  %source_target_pairs = "tf.Const" () {
    value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32>
  } : () -> tensor<3x2xi32>

  // CHECK: "mhlo.collective_permute"
  // CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
  %0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) {
  } : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32>

  return %0 : tensor<128x32xf32>
}

// CHECK-LABEL: @cross_replica_sum
func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {
  %replica_groups = "tf.Const" () {
    value = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
  } : () -> tensor<2x4xi32>

  // CHECK: mhlo.cross-replica-sum
  // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
  %result = "tf.CrossReplicaSum" (%input, %replica_groups) : (tensor<10xf32>, tensor<2x4xi32>) -> tensor<10xf32>
  return %result : tensor<10xf32>
}

//===----------------------------------------------------------------------===//
// tf.Split legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @split_not_match_non_const_split_dim
func @split_not_match_non_const_split_dim(%input: tensor<4x4xf32>, %split_dim: tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
  // CHECK: tf.Split
  %0:2 = "tf.Split"(%split_dim, %input) : (tensor<i32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>)
  return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32>
}

// CHECK-LABEL: @split_not_match_unknown_input_dim
func @split_not_match_unknown_input_dim(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) {
  %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: tf.Split
  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>)
  return %0#0, %0#1 : tensor<4x?x4xf32>, tensor<4x?x4xf32>
}

// CHECK-LABEL: @split_match_and_split_into_two
func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) {
  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32>
  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>)
  // CHECK: return %[[ONE]], %[[TWO]]
  return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32>
}

// CHECK-LABEL: @split_match_and_split_into_two_dynamic
func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) {
  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>)
  // CHECK: return %[[ONE]], %[[TWO]]
  return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32>
}

// CHECK-LABEL: @split_match_and_split_into_three
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) {
  %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  %0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
  // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
  return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
}

//===----------------------------------------------------------------------===//
// tf.TopKV2 legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: topk_v2_non_const_k
func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
  // CHECK: tf.TopKV2
  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
  return %0#0, %0#1: tensor<?xf32>, tensor<?xi32>
}

// CHECK-LABEL: topk_v2_unknown_input_last_dim
func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
  // CHECK: tf.TopKV2
  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>)
  return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
}

// CHECK-LABEL: topk_v2
// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>

  // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64}
  // CHECK-NEXT: %[[SORT:.*]]:2 = "mhlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
  // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f32>, %[[RHS:.*]]: tensor<f32>, %{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
  // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[LHS]], %[[RHS]]) {compare_type = "TOTALORDER", comparison_direction = "GT"}
  // CHECK-NEXT: "mhlo.return"(%[[CMP]])
  // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
  // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[SORT]]#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[SORT]]#1) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-NEXT: return %[[VAL]], %[[IDX]]
  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
  return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
}

//===----------------------------------------------------------------------===//
// tf.SplitV legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @splitv_match_and_split_into_three
// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
  %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32>
  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
  // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32>
  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
  // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
  return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
}

// CHECK-LABEL: @splitv_match_and_split_into_three_dynamic
func @splitv_match_and_split_into_three_dynamic(%input: tensor<?x6xf32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>) {
  %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x1xf32>
  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x2xf32>
  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x3xf32>
  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<?x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>)
  return %0#0, %0#1, %0#2 : tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>
}

// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes
func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
  %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  // CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>
  // CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>
  // CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>
  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
  return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
}

//===----------------------------------------------------------------------===//
// tf.Assert legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @assert
func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) {
  // CHECK-NOT: tf.Assert
  "tf.Assert"(%arg0, %arg1) {summarize = 1} : (tensor<i1>, tensor<*xf32>) -> ()
  return
}

//===----------------------------------------------------------------------===//
// tf.Unpack legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @unpack
func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
  // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
  // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
  // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
  // CHECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>

  %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
  // CHECK: return %[[RES1]], %[[RES2]], %[[RES3]]
  return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>
}

// CHECK-LABEL: @unpack_dynamic
func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {

  // CHECK: tf.Unpack
  %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
}

// CHECK-LABEL: @unpack_unranked
func @unpack_unranked(%input: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {

  // CHECK: tf.Unpack
  %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
}

//===----------------------------------------------------------------------===//
// tf.UnsortedSegment{Max|Min|Prod|Sum} legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @unsorted_segment_sum
// CHECK-SAME: [[DATA:%.*]]: tensor<8x16x64xf32>
// CHECK-SAME: [[SI:%.*]]: tensor<8x16xi32>
func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x16xi32>) -> (tensor<4x64xf32>) {
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32>
  // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( {
  // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
  // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor<f32>
  // CHECK: "mhlo.return"([[ADD]])
  // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = {index_vector_dim = 2 : i64, inserted_window_dims = dense<0> : tensor<1xi64>, scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, update_window_dims = dense<2> : tensor<1xi64>}, unique_indices = false} : (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32>
  // CHECK: return [[SCATTER]]
  %0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<8x16x64xf32>, tensor<8x16xi32>, tensor<i32>) -> (tensor<4x64xf32>)
  return %0: tensor<4x64xf32>
}

// CHECK-LABEL: @unsorted_segment_prod
// CHECK-SAME: [[DATA:%.*]]: tensor<8x?x64xf32>
// CHECK-SAME: [[SI:%.*]]: tensor<?x16xi32>
func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32>
  // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( {
  // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
  // CHECK: [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor<f32>
  // CHECK: "mhlo.return"([[MUL]])
  // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = {index_vector_dim = 2 : i64, inserted_window_dims = dense<0> : tensor<1xi64>, scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, update_window_dims = dense<2> : tensor<1xi64>}, unique_indices = false} : (tensor<4x64xf32>, tensor<?x16xi32>, tensor<8x?x64xf32>) -> tensor<4x?xf32>
  // CHECK: return [[SCATTER]]
  %0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
  return %0: tensor<4x?xf32>
}

// CHECK-LABEL: @unsorted_segment_min
func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: mhlo.constant dense<3.40282347E+38> : tensor<f32>
  // CHECK: mhlo.scatter
  // CHECK: mhlo.minimum
  %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
  return %0: tensor<4x?xf32>
}

// CHECK-LABEL: @unsorted_segment_max
func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // CHECK: mhlo.constant dense<-3.40282347E+38> : tensor<f32>
  // CHECK: mhlo.scatter
  // CHECK: mhlo.maximum
  %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
  return %0: tensor<4x?xf32>
}

//===----------------------------------------------------------------------===//
// tf.GatherV2 legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @gather_v2
func @gather_v2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5xf32> {
  // CHECK: "mhlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5xf32>
  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32>
  return %1 : tensor<16x2x5xf32>
}

// CHECK-LABEL: @gather_v2_dynamic
func @gather_v2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi32>) -> tensor<*xf32> {
  // CHECK: "mhlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi32>) -> tensor<*xf32>
  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi32>, tensor<1xi32>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

// CHECK-LABEL: @gather_v2_unranked
func @gather_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
  // CHECK: tf.GatherV2
  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<*xf32>, tensor<*xi32>, tensor<1xi32>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

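// Note on the GatherV2 tests above: tf.GatherV2 accepts negative batch_dims
// and axis values, which the legalization resolves against the operand ranks.
// As a hand-derived illustration (not verified by FileCheck): with
// batch_dims = -1 and indices of rank 2, batch_dims canonicalizes to
// -1 + 2 = 1; with axis = -1 and params of rank 3, the select dimension
// becomes -1 + 3 = 2, matching the torch_index_select attributes checked in
// @gather_v2 and @gather_v2_dynamic.
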
//===----------------------------------------------------------------------===//
// tf.StridedSliceGrad legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: strided_slice_grad
// CHECK-SAME: [[GRAD:%.*]]: tensor<4x16x1022xf32>
func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> {

  // For StridedSlice
  // Dim #:        0,   1,    2
  // Input shape: [4, 128, 1024]
  // Begin:        1,   4,   -3
  // End:          8,  65,   42
  // Stride:       1,   4,   -1
  // Begin mask:   1,   0,    0  (= 1)
  // End mask:     0,   0,    1  (= 4)

  // So result shape:
  // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4
  // Dim #1: 4 to 65 stride 4: so 16
  // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022
  // result shape: [4, 16, 1022]

  // To pad back:
  // Dim #:          0,  1,  2
  // Pad low:        0,  4,  0
  // Pad interior:   0,  3,  0
  // Pad high:       0, 63,  2

  %shape = "tf.Const"() {value = dense<[4, 128, 1024]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)

  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"(%arg0) : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32>
  // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32>
  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) {edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>} : (tensor<4x16x1022xf32>, tensor<f32>) -> tensor<4x128x1024xf32>

  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 1, end_mask = 4} : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32>
  // CHECK: return [[PAD]]
  return %0: tensor<4x128x1024xf32>
}

// CHECK-LABEL: strided_slice_grad_shrink_axis_mask
// CHECK-SAME: [[GRAD:%.*]]: tensor<8xf32>
func @strided_slice_grad_shrink_axis_mask(%grad: tensor<8xf32>) -> tensor<4x8xf32> {
  // Input to StridedSlice was of shape 4x8xf32
  // Strided slice gets input[2:3, 0:8]
  // shrink_axis_mask is 1, denoting that dim#0 is shrunk. So the output is
  // 8xf32, which is the shape of the gradient.
  // StridedSliceGrad would reshape the gradient to 1x8xf32 and
  // then pad to match the shape of input 4x8xf32.
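  // As a hand-derived illustration (not verified by FileCheck): the slice
  // kept only row 2 of the 4-row input, so the 1x8 reshaped gradient goes
  // back at row offset 2. That gives edge_padding_low = 2 rows before it,
  // edge_padding_high = 4 - 2 - 1 = 1 row after it, and no interior padding
  // because the stride is 1.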

  %shape = "tf.Const"() {value = dense<[4, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %begin = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[3, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<8xf32>) -> tensor<1x8xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0]> : tensor<2xi64>
  // CHECK-DAG-SAME: edge_padding_high = dense<[1, 0]> : tensor<2xi64>
  // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<2xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, shrink_axis_mask = 1} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<8xf32>) -> tensor<4x8xf32>

  // CHECK: return [[PAD]] : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}

// CHECK-LABEL: strided_slice_grad_new_axis_mask
// CHECK-SAME: [[GRAD:%.*]]: tensor<1x2xf32>
func @strided_slice_grad_new_axis_mask(%grad: tensor<1x2xf32>) -> tensor<8xf32> {
  // Input to StridedSlice was of shape 8xf32
  // Strided slice gets input[tf.new_axis, 2:4]
  // new_axis_mask is 1, denoting that a new axis is inserted at dim#0. So the
  // output is 1x2xf32, which is the shape of the gradient.
  // StridedSliceGrad would reshape the gradient to 2xf32 and
  // then pad to match the shape of input 8xf32.

  %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %begin = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[0, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<1x2xf32>) -> tensor<2xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-DAG-SAME: edge_padding_low = dense<2> : tensor<1xi64>
  // CHECK-DAG-SAME: edge_padding_high = dense<4> : tensor<1xi64>
  // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<1xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, new_axis_mask = 1} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1x2xf32>) -> tensor<8xf32>

  // CHECK: return [[PAD]] : tensor<8xf32>
  return %0 : tensor<8xf32>
}

// CHECK-LABEL: strided_slice_grad_ellipsis_mask
// CHECK-SAME: [[GRAD:%.*]]: tensor<2x4x8xf32>
func @strided_slice_grad_ellipsis_mask(%grad: tensor<2x4x8xf32>) -> tensor<4x4x8xf32> {
  // Input to StridedSlice was of shape 4x4x8xf32
  // Strided slice gets input[2:4, ...]
  // ellipsis_mask is 2, denoting that the slice contains all elements in dim#1
  // and dim#2, ignoring begin and end indices for these dimensions. So the
  // output is 2x4x8xf32, which is the shape of the gradient.
  // StridedSliceGrad would pad the gradient to match the shape of
  // input 4x4x8xf32.
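  // As a hand-derived illustration (not verified by FileCheck): only dim #0
  // was actually sliced (rows 2:4 of 4), so the gradient needs
  // edge_padding_low = 2 there and edge_padding_high = 4 - 2 - 2 = 0; the
  // ellipsis dims #1 and #2 were taken whole, so they get no padding at all.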

  %shape = "tf.Const"() {value = dense<[4, 4, 8]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  %begin = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %end = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>)
  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)

  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0, 0]> : tensor<3xi64>
  // CHECK-DAG-SAME: edge_padding_high = dense<0> : tensor<3xi64>
  // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<3xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, ellipsis_mask = 2} : (tensor<3xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2x4x8xf32>) -> tensor<4x4x8xf32>

  // CHECK: return [[PAD]] : tensor<4x4x8xf32>
  return %0 : tensor<4x4x8xf32>
}


// CHECK-LABEL: strided_slice_grad_all_masks
// CHECK-SAME: [[GRAD:%.*]]: tensor<1x4x8x8x10x2x1xf32>
func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> {
  // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis]
  // The new axes are at index 1 and 6 of the sparse spec, so
  // new_axis_mask = 2^1 + 2^6 = 66
  // The ellipsis mask is applied to dims #1 and #2 of the input, i.e., we get
  // the canonicalized slice input[1, :, :, 8:, :10, 2:6:2]
  // The StridedSliceGrad op would propagate the gradient for the sliced tensor
  // to the original input tensor by padding with zeroes.

  %shape = "tf.Const"() {value = dense<[2, 4, 8, 16, 32, 64]> : tensor<6xi32>} : () -> (tensor<6xi32>)
  %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
  %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
  %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>)

  // Remove the 2 new axes (at sparse-spec index 1 and 6) and add back the
  // size-1 dim for the shrink axis (at index 0)
  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32>
  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // The edge_padding_low, edge_padding_high and interior_padding attributes of
  // mhlo.pad would reflect the padding required to get the shape of the
  // input of the StridedSlice op.
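  // As a hand-derived illustration of one entry (not verified by FileCheck):
  // dim #5 of the input (size 64) was sliced as 2:6:2, yielding 2 elements.
  // Scattering the gradient back needs edge_padding_low = 2, interior
  // padding = stride - 1 = 1, and edge_padding_high = 64 - (2 + 2 + 1) = 59,
  // matching the last entries of the attributes checked below.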
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZERO]])
  // CHECK-DAG-SAME: edge_padding_low = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
  // CHECK-DAG-SAME: edge_padding_high = dense<[0, 0, 0, 0, 22, 59]> : tensor<6xi64>
  // CHECK-DAG-SAME: interior_padding = dense<[0, 0, 0, 0, 0, 1]> : tensor<6xi64>
  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<6xi32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>, tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32>

  // CHECK: return [[PAD]] : tensor<2x4x8x16x32x64xf32>
  return %0 : tensor<2x4x8x16x32x64xf32>
}

// CHECK-LABEL: @tensor_scatter_update
func @tensor_scatter_update(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  // CHECK: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
  // CHECK: "mhlo.return"(%arg4) : (tensor<f32>) -> ()
  // CHECK: })
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: scatter_dimension_numbers
  // CHECK-SAME: index_vector_dim = 1 : i64
  // CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
  // CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: unique_indices = false
  %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}

//===----------------------------------------------------------------------===//
// tf.RandomShuffle legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @random_shuffle_first_dim_1
// CHECK-SAME: [[INPUT:%.*]]: tensor<1x?xf32>
func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> {
  %0 = "tf.RandomShuffle"(%input) : (tensor<1x?xf32>) -> (tensor<1x?xf32>)
  // CHECK-NEXT: return [[INPUT]]
  return %0: tensor<1x?xf32>
}

// CHECK-LABEL: @random_shuffle_1D_16
// CHECK-SAME: [[INPUT:%.*]]: tensor<16xf32>
func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> {
  // CHECK: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64>
  // CHECK: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor<i32>
  // CHECK: [[RNG:%.*]] = "mhlo.rng_uniform"([[LOWER]], [[UPPER]], [[SHAPE]])
  // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ( {
  // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor<i32>, [[ARG2:%.*]]: tensor<i32>, {{.*}}: tensor<f32>, {{.*}}: tensor<f32>):
  // CHECK: "mhlo.compare"([[ARG1]], [[ARG2]]) {comparison_direction = "LT"}
  // CHECK: }) {dimension = -1 : i64, is_stable = true} : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>)
  // CHECK: return [[SORT]]#1
  %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>)
  return %0: tensor<16xf32>
}

// CHECK-LABEL: @random_shuffle_1D_10240
func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> {
  // CHECK: mhlo.rng_uniform
  // CHECK: mhlo.sort
  // CHECK: mhlo.rng_uniform
  // CHECK: mhlo.sort
  %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>)
  return %0: tensor<10240xf32>
}

// CHECK-LABEL: @random_shuffle_3D
// CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32>
func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
  // CHECK: [[INDICES:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>

  // CHECK: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64>
  // CHECK: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor<i32>
  // CHECK: [[SWAPS:%.*]] = "mhlo.rng_uniform"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]])

  // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[WHILE_INIT:%.*]] = "mhlo.tuple"([[IV_INIT]], [[SWAPS]], [[INDICES]])

  // CHECK: [[WHILE_OUT:%.*]] = "mhlo.while"([[WHILE_INIT]]) ( {
  // CHECK: ^{{.*}}([[COND_ARG:%.*]]: tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>):
  // CHECK: [[IV:%.*]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
  // CHECK: [[LIMIT:%.*]] = mhlo.constant dense<4> : tensor<i32>
  // CHECK: [[CMP:%.*]] = "mhlo.compare"([[IV]], [[LIMIT]]) {comparison_direction = "LT"}
  // CHECK: "mhlo.return"([[CMP]])
  // CHECK: }, {
  // CHECK: ^{{.*}}([[BODY_ARG:%.*]]: tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>):
  // CHECK: [[IV:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
  // CHECK: [[SWAPS:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
  // CHECK: [[INDICES:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
  // CHECK: [[SRC_IDX:%.*]] = "mhlo.dynamic-slice"([[INDICES]], [[IV]]) {slice_sizes = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
  // CHECK: [[SWP_IDX:%.*]] = "mhlo.dynamic-slice"([[SWAPS]], [[IV]]) {slice_sizes = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
  // CHECK: [[SWP:%.*]] = "mhlo.reshape"([[SWP_IDX]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[TGT_IDX:%.*]] = "mhlo.dynamic-slice"([[INDICES]], [[SWP]]) {slice_sizes = dense<1> : tensor<1xi64>}
  // CHECK: [[INDICES1:%.*]] = "mhlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
  // CHECK: [[INDICES2:%.*]] = "mhlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
  // CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor<i32>
  // CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[IV]], [[ONE]]
  // CHECK: [[NEW_TUPLE:%.*]] = "mhlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]])
  // CHECK: "mhlo.return"([[NEW_TUPLE]])
  // CHECK: }) : (tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>) -> tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>

  // CHECK: [[SWAPPED_INDICES:%.*]] = "mhlo.get_tuple_element"([[WHILE_OUT]]) {index = 2 : i32} : (tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>) -> tensor<4xi32>
  // CHECK: [[GATHER:%.*]] = "mhlo.gather"([[INPUT]], [[SWAPPED_INDICES]])
  // CHECK-SAME: dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 1 : i64, offset_dims = dense<[1, 2]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}
  // CHECK-SAME: indices_are_sorted = false
  // CHECK-SAME: slice_sizes = dense<[1, -1, 16]> : tensor<3xi64>
  // CHECK: (tensor<4x?x16xf32>, tensor<4xi32>) -> tensor<4x?x16xf32>

  // CHECK: return [[GATHER]]
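  // The while loop above implements a Fisher-Yates style shuffle over the
  // leading dimension only; a rough sketch of the intent (assumed reading,
  // not verified by FileCheck): for i in [0, 4), swap indices[i] with
  // indices[swaps[i]], then gather the permuted rows from the input.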
4392 4393 %0 = "tf.RandomShuffle"(%input) : (tensor<4x?x16xf32>) -> (tensor<4x?x16xf32>) 4394 return %0: tensor<4x?x16xf32> 4395} 4396 4397//===----------------------------------------------------------------------===// 4398// tf.AvgPool legalization 4399//===----------------------------------------------------------------------===// 4400 4401// CHECK-LABEL: @avgpool_valid_padding 4402// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16> 4403// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32> 4404// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4405// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( { 4406// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>): 4407// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] 4408// CHECK: "mhlo.return"([[ADD]]) 4409// CHECK: }) 4410// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> 4411// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> 4412// CHECK-SAME: -> tensor<2x3x5x7xf32> 4413// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 4414// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] 4415// CHECK-SAME: broadcast_dimensions = dense<> 4416// CHECK-SAME: -> tensor<2x3x5x7xf32> 4417// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]]) 4418// CHECK-SAME: -> tensor<2x3x5x7xf16> 4419// CHECK: return [[CONV16]] 4420func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> { 4421 %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> 4422 return %0 : tensor<2x3x5x7xf16> 4423} 4424 4425// CHECK-LABEL: @avgpool_3d_valid_padding 4426// CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16> 4427// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32> 4428// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4429// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( { 4430// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>): 4431// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] 4432// CHECK: "mhlo.return"([[ADD]]) 4433// CHECK: }) 4434// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> 4435// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> 4436// CHECK-SAME: -> tensor<2x4x3x5x7xf32> 4437// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 4438// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] 4439// CHECK-SAME: broadcast_dimensions = dense<> 4440// CHECK-SAME: -> tensor<2x4x3x5x7xf32> 4441// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]]) 4442// CHECK-SAME: -> tensor<2x4x3x5x7xf16> 4443// CHECK: return [[CONV16]] 4444func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> { 4445 %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> 4446 return %0 : tensor<2x4x3x5x7xf16> 4447} 4448 4449// CHECK-LABEL: @avgpool_nchw_format 4450// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x12x21xf16> 4451// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32> 4452// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4453// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( { 4454// CHECK: ^bb0([[ARG1:%.+]]: 
tensor<f32>, [[ARG2:%.+]]: tensor<f32>): 4455// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] 4456// CHECK: "mhlo.return"([[ADD]]) 4457// CHECK: }) 4458// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]> 4459// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> 4460// CHECK-SAME: -> tensor<2x7x3x5xf32> 4461// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 4462// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] 4463// CHECK-SAME: broadcast_dimensions = dense<> 4464// CHECK-SAME: -> tensor<2x7x3x5xf32> 4465// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]]) 4466// CHECK-SAME: -> tensor<2x7x3x5xf16> 4467// CHECK: return [[CONV16]] 4468func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> { 4469 %0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> 4470 return %0 : tensor<2x7x3x5xf16> 4471} 4472 4473// CHECK-LABEL: @avgpool_3d_ncdhw_format 4474// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x4x12x21xf16> 4475// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32> 4476// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4477// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( { 4478// CHECK: ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>): 4479// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] 4480// CHECK: "mhlo.return"([[ADD]]) 4481// CHECK: }) 4482// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]> 4483// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> 4484// CHECK-SAME: -> tensor<2x7x4x3x5xf32> 4485// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 4486// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] 4487// CHECK-SAME: broadcast_dimensions = dense<> 4488// CHECK-SAME: -> tensor<2x7x4x3x5xf32> 4489// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]]) 4490// CHECK-SAME: -> tensor<2x7x4x3x5xf16> 4491// CHECK: return [[CONV16]] 4492func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> { 4493 %0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> 4494 return %0 : tensor<2x7x4x3x5xf16> 4495} 4496 4497// CHECK-LABEL: @avgpool_same_padding( 4498// CHECK-SAME: %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> 4499// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4500// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( { 4501// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 4502// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 4503// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> () 4504// CHECK: }) 4505// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> 4506// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> 4507// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> 4508// CHECK-SAME: -> tensor<2x4x6x7xf32> 4509// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32> 4510// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( { 4511// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): 4512// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32> 4513// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> () 4514// CHECK: }) 
4515// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> 4516// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> 4517// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> 4518// CHECK-SAME: -> tensor<2x4x6x7xf32> 4519// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32> 4520// CHECK: return %[[RESULT]] : tensor<2x4x6x7xf32> 4521// CHECK: } 4522func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> { 4523 %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> 4524 return %0 : tensor<2x4x6x7xf32> 4525} 4526 4527// CHECK-LABEL: @avgpool_3d_same_padding( 4528// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> 4529// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4530// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( { 4531// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 4532// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 4533// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> () 4534// CHECK: }) 4535// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> 4536// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> 4537// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> 4538// CHECK-SAME: -> tensor<2x4x4x6x7xf32> 4539// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32> 4540// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( { 4541// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): 4542// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32> 4543// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> () 4544// CHECK: }) 4545// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> 4546// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> 4547// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> 4548// CHECK-SAME: -> tensor<2x4x4x6x7xf32> 4549// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] 4550// CHECK: return %[[RESULT]] : tensor<2x4x4x6x7xf32> 4551// CHECK: } 4552func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> { 4553 %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> 4554 return %0 : tensor<2x4x4x6x7xf32> 4555} 4556 4557//===----------------------------------------------------------------------===// 4558// AvgPoolGrad op legalizations. 
4559//===----------------------------------------------------------------------===// 4560 4561// CHECK-LABEL: @avgpool_grad_valid_padding( 4562// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { 4563// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4564// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 4565// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] 4566// CHECK_SAME: broadcast_dimensions = dense<> 4567// CHECK_SAME: -> tensor<10x12x16x64xf32> 4568// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) 4569// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> 4570// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> 4571// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> 4572// CHECK-SAME: -> tensor<10x25x33x64xf32> 4573// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { 4574// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 4575// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 4576// CHECK: "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> () 4577// CHECK: }) 4578// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> 4579// CHECK-SAME: window_strides = dense<1> 4580// CHECK-SAME: -> tensor<10x24x32x64xf32> 4581// CHECK: return %[[RESULT]] : tensor<10x24x32x64xf32> 4582func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { 4583 %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>) 4584 %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { 4585 data_format = "NHWC", 4586 ksize = [1, 2, 2, 1], 4587 padding = "VALID", 4588 strides = [1, 2, 2, 1] 4589 } : (tensor<4xi32>, tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> 4590 return %result : tensor<10x24x32x64xf32> 4591} 4592 4593// CHECK-LABEL: @avgpool_3d_grad_valid_padding( 4594// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { 4595// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4596// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32> 4597// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32> 4598// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) 4599// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]> 4600// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> 4601// CHECK-SAME: interior_padding = dense<[0, 0, 1, 1, 0]> 4602// CHECK-SAME: -> tensor<10x8x25x33x64xf32> 4603// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { 4604// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 4605// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 4606// CHECK: "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> () 4607// CHECK: }) 4608// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> 4609// CHECK-SAME: window_strides = dense<1> 4610// CHECK-SAME: -> tensor<10x8x24x32x64xf32> 4611// CHECK: return %[[RESULT]] : tensor<10x8x24x32x64xf32> 4612func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { 4613 %orig_input_shape = "tf.Const"() {value = dense<[10, 8, 24, 32, 64]> : tensor<5xi32>} : () -> (tensor<5xi32>) 4614 %result = 
"tf.AvgPool3DGrad"(%orig_input_shape, %grad) { 4615 data_format = "NDHWC", 4616 ksize = [1, 1, 2, 2, 1], 4617 padding = "VALID", 4618 strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> 4619 return %result : tensor<10x8x24x32x64xf32> 4620} 4621 4622// CHECK-LABEL: @avgpool_grad_same_padding( 4623// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { 4624// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4625// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32> 4626// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { 4627// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 4628// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 4629// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> () 4630// CHECK: }) 4631// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> 4632// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> 4633// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> 4634// CHECK-SAME: -> tensor<2x4x7x9xf32> 4635// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32> 4636// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) 4637// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 0]> 4638// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> 4639// CHECK-SAME: interior_padding = dense<[0, 3, 3, 0]> 4640// CHECK-SAME: -> tensor<2x14x27x9xf32> 4641// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( { 4642// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): 4643// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32> 4644// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> () 4645// CHECK: }) 4646// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> 4647// CHECK-SAME: window_strides = dense<1> 4648// CHECK-SAME: -> tensor<2x13x25x9xf32> 4649// CHECK: return %[[RESULT]] : tensor<2x13x25x9xf32> 4650func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { 4651 %orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>) 4652 %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { 4653 data_format = "NHWC", 4654 ksize = [1, 2, 3, 1], 4655 padding = "SAME", 4656 strides = [1, 4, 4, 1] 4657 } : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> 4658 return %result : tensor<2x13x25x9xf32> 4659} 4660 4661// CHECK-LABEL: @avgpool_3d_grad_same_padding( 4662// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { 4663// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4664// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32> 4665// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( { 4666// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>): 4667// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32> 4668// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> () 4669// CHECK: }) 4670// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> 4671// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> 4672// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> 4673// CHECK-SAME: -> tensor<2x8x4x7x9xf32> 4674// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32> 4675// CHECK: 

// CHECK-LABEL: @avgpool_grad_same_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32>
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]>
// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]>
// CHECK-SAME: -> tensor<2x4x7x9xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME: interior_padding = dense<[0, 3, 3, 0]>
// CHECK-SAME: -> tensor<2x14x27x9xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<2x13x25x9xf32>
// CHECK: return %[[RESULT]] : tensor<2x13x25x9xf32>
func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
    data_format = "NHWC",
    ksize = [1, 2, 3, 1],
    padding = "SAME",
    strides = [1, 4, 4, 1]
  } : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32>
  return %result : tensor<2x13x25x9xf32>
}

// CHECK-LABEL: @avgpool_3d_grad_same_padding(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32>
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]>
// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]>
// CHECK-SAME: -> tensor<2x8x4x7x9xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3, 0]>
// CHECK-SAME: -> tensor<2x8x14x27x9xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<2x8x13x25x9xf32>
// CHECK: return %[[RESULT]] : tensor<2x8x13x25x9xf32>
func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 8, 13, 25, 9]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NDHWC",
    ksize = [1, 1, 2, 3, 1],
    padding = "SAME",
    strides = [1, 1, 4, 4, 1]} : (tensor<5xi32>, tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32>
  return %result : tensor<2x8x13x25x9xf32>
}

// CHECK-LABEL: @avgpool_grad_nchw_format(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32>
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]>
// CHECK-SAME: -> tensor<2x9x4x7xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1]>
// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3]>
// CHECK-SAME: -> tensor<2x9x14x27xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<2x9x13x25xf32>
// CHECK: return %[[RESULT]] : tensor<2x9x13x25xf32>
func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 13, 25]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
    data_format = "NCHW",
    ksize = [1, 1, 2, 3],
    padding = "SAME",
    strides = [1, 1, 4, 4]
  } : (tensor<4xi32>, tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32>
  return %result : tensor<2x9x13x25xf32>
}

// CHECK-LABEL: @avgpool_3d_grad_ncdhw_format(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32>
// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]>
// CHECK-SAME: -> tensor<2x9x8x4x7xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 0, 1]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 0, 1, 1]>
// CHECK-SAME: interior_padding = dense<[0, 0, 0, 3, 3]>
// CHECK-SAME: -> tensor<2x9x8x14x27xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME: window_strides = dense<1> : tensor<5xi64>
// CHECK-SAME: -> tensor<2x9x8x13x25xf32>
// CHECK: return %[[RESULT]] : tensor<2x9x8x13x25xf32>
func @avgpool_3d_grad_ncdhw_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 8, 13, 25]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NCDHW",
    ksize = [1, 1, 1, 2, 3],
    padding = "SAME",
    strides = [1, 1, 1, 4, 4]} : (tensor<5xi32>, tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32>
  return %result : tensor<2x9x8x13x25xf32>
}
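
// For bf16 gradients the summation is expected to run at f32 precision: the
// checks below require the padded input to be converted to f32, reduced, and
// converted back to bf16, presumably to avoid accumulating rounding error in
// the low-precision type.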

// CHECK-LABEL: @avgpool_grad_bf16(
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME: broadcast_dimensions = dense<>
// CHECK-SAME: -> tensor<10x12x16x64xbf16>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]>
// CHECK-SAME: -> tensor<10x25x33x64xbf16>
// CHECK: %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = "mhlo.convert"(%[[REDUCE_WINDOW_INPUT]]) : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32>
// CHECK: %[[ZERO_F32:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK: "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME: window_strides = dense<1>
// CHECK-SAME: -> tensor<10x24x32x64xf32>
// CHECK: %[[RESULT_CONVERTED:.*]] = "mhlo.convert"(%[[RESULT]]) : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16>
// CHECK: return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16>
func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
    data_format = "NHWC",
    ksize = [1, 2, 2, 1],
    padding = "VALID",
    strides = [1, 2, 2, 1]
  } : (tensor<4xi32>, tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16>
  return %result : tensor<10x24x32x64xbf16>
}

// CHECK-LABEL: xla_sharding
func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // CHECK-NEXT: "mhlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", has_side_effect = false, mhlo.sharding = ""}
  %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}
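
// tf.InplaceUpdate and tf.XlaDynamicUpdateSlice below both bottom out in
// "mhlo.dynamic-update-slice": InplaceUpdate slices one row of updates and
// one scalar index per updated position and chains one dynamic-update-slice
// per row, while XlaDynamicUpdateSlice unpacks its index vector into scalar
// start indices for a single update.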

// CHECK-LABEL: inplace_update_one
func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
  // CHECK-DAG: [[UPDATE:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]])
  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32>

  // CHECK: return [[UPDATE]]
  return %0 : tensor<8x4xf32>
}

// CHECK-LABEL: inplace_update_three
func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
  // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
  // CHECK-DAG: [[RESHAPE3:%.+]] = "mhlo.reshape"([[SLICE3]])
  // CHECK-DAG: [[UPDATE1:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]])
  // CHECK-DAG: [[UPDATE2:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]])
  // CHECK-DAG: [[UPDATE3:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]])
  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32>

  // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32>
  return %0 : tensor<8x8x4xf32>
}

// CHECK-LABEL: xla_dynamic_update_slice
func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> {
  // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]]) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x16xf32>
  // CHECK: return [[DUS]]
  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}

// CHECK-LABEL: xla_dynamic_update_slice2
func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> {
  // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]]) : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
  // CHECK: return [[DUS]]
  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

//===----------------------------------------------------------------------===//
// AllToAll op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @alltoall_basic
func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> {
  %group_assignment = "tf.Const"() {
    value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32>
  } : () -> tensor<3x4xi32>
  %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} : (tensor<10xf32>, tensor<3x4xi32>) -> tensor<10xf32>
  // CHECK: mhlo.all_to_all
  // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64>
  return %result : tensor<10xf32>
}

//===----------------------------------------------------------------------===//
// Cumsum op legalizations.
//===----------------------------------------------------------------------===//
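
// Cumsum lowers to a summing "mhlo.reduce_window" whose window spans the full
// axis length with low padding of len - 1, i.e. a windowed prefix sum. The
// exclusive variant shifts the result by one element with an "mhlo.pad"
// (edge_padding_low = 1, edge_padding_high = -1), and the reverse variant
// brackets the computation with a pair of "mhlo.reverse" ops.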
["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 4927 %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 4928 return %1 : tensor<4xf32> 4929} 4930 4931// CHECK-LABEL: func @cumsum_reverse 4932// CHECK-SAME: [[X:%.*]]: tensor<4xf32> 4933func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { 4934 // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32> 4935 // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 4936 // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32> 4937 // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4938 // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { 4939 // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>): 4940 // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32> 4941 // CHECK: "mhlo.return"([[SUM]]) : (tensor<f32>) -> () 4942 // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 4943 // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[REDUCE]]) : (tensor<4xf32>) -> tensor<4xf32> 4944 // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 4945 // CHECK: return [[REVERSE_BACK]] 4946 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 4947 %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 4948 return %1 : tensor<4xf32> 4949} 4950 4951// CHECK-LABEL: func @cumsum_exclusive_reverse 4952// CHECK-SAME: [[X:%.*]]: tensor<4xf32> 4953func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { 4954 // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32> 4955 // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 4956 // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32> 4957 // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32> 4958 // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { 4959 // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>): 4960 // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32> 4961 // CHECK: "mhlo.return"([[SUM]]) : (tensor<f32>) -> () 4962 // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 4963 // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> 4964 // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[PAD]]) : (tensor<4xf32>) -> tensor<4xf32> 4965 // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> 4966 // CHECK: return [[REVERSE_BACK]] 4967 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 4968 %1 = 
"tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 4969 return %1 : tensor<4xf32> 4970} 4971 4972// CHECK-LABEL: func @cumsum_empty 4973func @cumsum_empty(%arg0: tensor<0xf32>) -> tensor<0xf32> { 4974 %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> 4975 4976 // CHECK: mhlo.constant dense<> : tensor<0xf32> 4977 %1 = "tf.Cumsum"(%arg0, %0) : (tensor<0xf32>, tensor<i32>) -> tensor<0xf32> 4978 return %1 : tensor<0xf32> 4979} 4980 4981// CHECK-LABEL: func @cumsum_dynamic 4982func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32> { 4983 // CHECK: "tf.Cumsum" 4984 %0 = "tf.Cumsum"(%arg0, %arg1) : (tensor<?xf32>, tensor<i32>) -> tensor<?xf32> 4985 return %0 : tensor<?xf32> 4986} 4987 4988//===----------------------------------------------------------------------===// 4989// Cumprod op legalizations. 4990//===----------------------------------------------------------------------===// 4991 4992// CHECK-LABEL: func @cumprod 4993func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> { 4994 // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32> 4995 // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( { 4996 // CHECK: mhlo.mul 4997 %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32> 4998 %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32> 4999 return %1 : tensor<4xf32> 5000} 5001 5002//===----------------------------------------------------------------------===// 5003// Qr op legalization 5004//===----------------------------------------------------------------------===// 5005 5006// CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) 5007func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) { 5008 // The tf.Qr lowering is a full algorithm that is not effective to verify with 5009 // FileCheck. Just verify that it converted. 5010 // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is 5011 // really only applicable to certain legacy uses. 

// CHECK-LABEL: func @cumprod
func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( {
  // CHECK: mhlo.mul
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

//===----------------------------------------------------------------------===//
// Qr op legalization.
//===----------------------------------------------------------------------===//

// CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) {
  // The tf.Qr lowering expands into a full decomposition algorithm that is
  // impractical to verify with FileCheck, so just verify that the op was
  // converted.
  // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is
  // really only applicable to certain legacy uses.
  // CHECK-NOT: "tf.Qr"
  %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
  return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32>
}

//===----------------------------------------------------------------------===//
// tf.Softplus legalization.
//===----------------------------------------------------------------------===//
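
// Softplus(x) = log(1 + exp(x)) is legalized with a numerically stable
// three-way select around a type-dependent threshold of log(eps) + 2, where
// eps is a small constant tied to the element type's precision: for x above
// -threshold the result is x itself, for x below threshold it is exp(x), and
// otherwise it is log1p(exp(x)).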

// CHECK-LABEL: func @softplus_f16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>)
func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor<f16>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f16>
  // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16>

  // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf16>
  return %0 : tensor<8x16xf16>
}

// CHECK-LABEL: func @softplus_bf16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>)
func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor<bf16>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<bf16>
  // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16>

  // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xbf16>
  return %0 : tensor<8x16xbf16>
}

// CHECK-LABEL: func @softplus_f32
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>)
func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor<f32>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
  // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>

  // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf32>
  return %0 : tensor<8x16xf32>
}

// CHECK-LABEL: func @softplus_f64
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>)
func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor<f64>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f64>
  // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64>

  // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64>
  return %0 : tensor<8x16xf64>
}
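
// tf.XlaGather carries its gather dimension numbers as an opaque serialized
// string attribute; the legalization is expected to decode it into the
// structured dimension_numbers attribute of "mhlo.gather", as the checks
// below assert.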

// CHECK-LABEL: @xla_gather
func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> {
  %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64>

  // CHECK: "mhlo.gather"
  // CHECK-SAME: dimension_numbers =
  // CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>
  // CHECK-SAME: index_vector_dim = 1 : i64
  // CHECK-SAME: offset_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: start_index_map = dense<0> : tensor<1xi64>
  // CHECK-SAME: indices_are_sorted = true
  // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>

  %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<10x1x300xf32>
  return %0 : tensor<10x1x300xf32>
}

// CHECK-LABEL: @xla_gather_i32
func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> {
  %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi32> } : () -> tensor<3xi32>

  // CHECK: "mhlo.gather"
  // CHECK-SAME: dimension_numbers =
  // CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>
  // CHECK-SAME: index_vector_dim = 1 : i64
  // CHECK-SAME: offset_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME: start_index_map = dense<0> : tensor<1xi64>
  // CHECK-SAME: indices_are_sorted = true
  // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>

  %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<10x1x300xf32>
  return %0 : tensor<10x1x300xf32>
}

// CHECK: func @stridedslice_with_i32
func @stridedslice_with_i32(%arg0: tensor<i32>) -> tensor<4xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "const_0_arg", outputs = "identity_0_retval_RetVal"}} {
// CHECK-NOT: tf.StridedSlice
// CHECK: [[DYNSLICE:%.*]] = "mhlo.dynamic-slice"
// CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[DYNSLICE]])
// CHECK: return [[RESHAPE]]
  %0 = "tf.Const"() {value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
  %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %3 = "tf.AddV2"(%arg0, %1) {_xla_inferred_shapes = [#tf.shape<>], device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %4 = "tf.Pack"(%3) {_xla_inferred_shapes = [#tf.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  %5 = "tf.Pack"(%arg0) {_xla_inferred_shapes = [#tf.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  %6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32>
  return %6 : tensor<4xf32>
}

// CHECK-LABEL: func @replica_id
func @replica_id() -> tensor<i32> {
  // CHECK: %[[ID:.*]] = "mhlo.replica_id"() : () -> tensor<ui32>
  // CHECK: %[[RESULT:.*]] = "mhlo.convert"(%[[ID]]) : (tensor<ui32>) -> tensor<i32>
  %0 = "tf.XlaReplicaId"() : () -> tensor<i32>
  return %0 : tensor<i32>
}

// CHECK: func @angle_c64
// CHECK-SAME: ([[ARG0:%.*]]: tensor<complex<f32>>)
func @angle_c64(%arg0: tensor<complex<f32>>) -> tensor<f32> {
// CHECK: [[IMAG:%.*]] = "mhlo.imag"([[ARG0]])
// CHECK: [[REAL:%.*]] = "mhlo.real"([[ARG0]])
// CHECK: [[ATAN2:%.*]] = mhlo.atan2 [[IMAG]], [[REAL]]
  %0 = "tf.Angle"(%arg0) : (tensor<complex<f32>>) -> tensor<f32>
  return %0 : tensor<f32>
}