// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FILECHECK_OPTS="" FileCheck %s
// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s | FileCheck %s --check-prefix CHLO --dump-input-filter=all
// This test runs twice:
//   1. Through FILECHECK_OPTS="" FileCheck with chlo legalization disabled,
//      since verifying that the chlo ops are emitted produces more useful
//      tests.
//   2. With chlo legalization enabled, verifying diagnostics to pick up any
//      issues with the full lowering (this can catch some broadcasting
//      corner cases which are emitted with a warning).
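//
// Note: allow-partial-conversion means ops that fail to legalize are left
// in their original tf.* form instead of aborting the conversion, which is
// why several tests below CHECK for the surviving tf op itself (e.g. the
// training form of tf.FusedBatchNorm and the dynamic-shape training cases
// of tf.FusedBatchNormV3).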

//===----------------------------------------------------------------------===//
// BatchNorm op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: fusedBatchNorm_notraining
func @fusedBatchNorm_notraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNorm_training
func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // TODO(riverriddle) Support training.
  // CHECK: "tf.FusedBatchNorm"
  %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true}  : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same
// code), so only do a couple of basic checks.

// CHECK-LABEL: fusedBatchNormV2_noTraining
func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV2_training
func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: mhlo.constant
  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  return %0#0 : tensor<8x8x8x8xf32>
}
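
// The mhlo.constant / chlo.broadcast_multiply pair above rescales the batch
// variance returned by mhlo.batch_norm_training, presumably converting the
// biased estimate (normalized by N) into the unbiased one TF reports
// (normalized by N - 1). For this test's NHWC shapes that means:
//   N = 8 * 8 * 8 = 512 elements reduced per feature
//   factor = N / (N - 1) = 512 / 511 ~= 1.00195694
// which matches the dense<1.00195694> constant checked explicitly in the
// exponential-avg-factor test below.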

// CHECK-LABEL: fusedBatchNormV3_noTraining
func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision
// CHECK-SAME:  ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>)
func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) {
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK: [[Y:%.*]] = "mhlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)
  // CHECK: [[Y_CONVERT:%.*]] = "mhlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK: [[DUMMY:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<0xf32>
  // CHECK: [[DUMMY_CAST:%.*]] = tensor.cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32>
  // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]]
  return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training
func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: mhlo.constant
  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance
func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK: return %[[VAR]]
  return %0#4 : tensor<8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor
func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK-DAG: %[[BATCH_MEAN:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32}
  // CHECK-DAG: %[[BATCH_VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32}

  // CHECK: %[[FACTOR:.*]] = mhlo.constant dense<1.00195694>
  // CHECK: %[[CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]]

  // CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988>
  // CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01>

  // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3
  // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]]
  // CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]]

  // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4
  // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]]
  // CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]]

  // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]]
  return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
}
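
// A worked form of the moving-average update checked above, with
// exponential_avg_factor f = 0.8, ALPHA = 1 - f = 0.2 (printed as the f32
// value 0.199999988) and BETA = f = 0.8:
//   new_running_mean = ALPHA * old_mean + BETA * batch_mean
//   new_running_var  = ALPHA * old_var  + BETA * (batch_var * 512/511)
// where 512/511 ~= 1.00195694 is the same variance-correction factor as in
// the training tests above.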

// CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision
func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK: "mhlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  // CHECK: "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormV3_NCHW
func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported
func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) {
  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?x?x?x?xf32>
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
  return %0#0 : tensor<?x?x?x?xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported1
func @fusedBatchNormV3_training_dynamic_unsupported1(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> (tensor<?x?x?x?xf32>) {
  // CHECK: tf.FusedBatchNormV3
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
  return %0#0 : tensor<?x?x?x?xf32>
}

// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported2
func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor<?x6x?x?xf32>, %arg1: tensor<6xf32>, %arg2: tensor<6xf32>, %arg3: tensor<6xf32>, %arg4: tensor<6xf32>) -> (tensor<?x6x?x?xf32>) {
  // CHECK: tf.FusedBatchNormV3
  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor<?x6x?x?xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>)
  return %0#0 : tensor<?x6x?x?xf32>
}
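
// Note that both training cases above are expected to remain as
// tf.FusedBatchNormV3: the training lowering appears to require statically
// known non-feature dimensions (plausibly to materialize the N/(N-1)
// variance-correction constant), so a dynamic batch or spatial dimension
// blocks it even when the feature dimension is static, as in the second
// case.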

// CHECK-LABEL: fusedBatchNormGrad_noTraining
func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}
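
// The check sequence above amounts to the inference-mode gradient formulas,
// written out here as a sketch (reductions are over the non-feature
// dimensions {0, 1, 2} for NHWC):
//   scratch1        = rsqrt(variance + epsilon)
//   scratch2        = sum(grad * (x - mean))
//   x_backprop      = grad * (scale * scratch1)
//   scale_backprop  = scratch1 * scratch2
//   offset_backprop = sum(grad)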

// CHECK-LABEL: fusedBatchNormGrad_Training
func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV2_noTraining
func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg5, %arg6 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV2_Training
func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision
func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>

  // CHECK: %[[x_backprop:.*]] = "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision
func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV3_noTraining
func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV3_Training
func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK: return %[[x_backprop]]
  // CHECK-SAME: tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>)
  return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>
}

// CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision
func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>

  // CHECK: %[[x_backprop:.*]] = "mhlo.convert"({{.*}}) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision
func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[training:.*]] = "mhlo.batch_norm_grad"(%[[act]], %arg2, %arg3, %arg4, %[[grad]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
  // CHECK-NEXT: %[[tact:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[scale_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.get_tuple_element"(%[[training]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[tact]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xbf16>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xbf16>
}

// CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW
func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK-NEXT: %[[grad:.*]] = "mhlo.convert"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>

  // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK:      %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[sub:.*]] = mhlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul:.*]] = mhlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32>
  // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64>
  // CHECK-NEXT: %[[cmul:.*]] = "mhlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red1:.*]] = "mhlo.reduce"(%[[cmul]], %[[init]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[scr2:.*]] = "mhlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[mul2:.*]] = mhlo.multiply %arg2, %[[scr1]] : tensor<8xf32>
  // CHECK:      %[[bcast_mul2:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[mul3:.*]] = mhlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32>

  // CHECK-NEXT: %[[scale_backprop:.*]] = mhlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32>

  // CHECK-NEXT: mhlo.constant dense<[0, 2, 3]> : tensor<3xi64>
  // CHECK-NEXT: %[[cgrad:.*]] = "mhlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: %[[init2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-NEXT: %[[red2:.*]] = "mhlo.reduce"(%[[cgrad]], %[[init2]]) ( {
  // CHECK-NEXT: ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):	// no predecessors
  // CHECK-NEXT:   %[[reduced1:.*]] = mhlo.add %arg6, %arg7 : tensor<f32>
  // CHECK-NEXT:   "mhlo.return"(%[[reduced1]]) : (tensor<f32>) -> ()
  // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK-NEXT: %[[offset_backprop:.*]] = "mhlo.convert"(%[[red2]]) : (tensor<8xf32>) -> tensor<8xf32>

  // CHECK-NEXT: %[[x_backprop:.*]] = "mhlo.convert"(%[[mul3]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
  // CHECK-NEXT: return %[[x_backprop]] : tensor<8x8x8x8xf32>

  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
  return %0#0 : tensor<8x8x8x8xf32>
}

// CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW
func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
  // CHECK: %{{.*}} = "mhlo.batch_norm_grad"(%{{.*}}, %arg2, %arg3, %arg4, %{{.*}}) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
435  %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
436  return %0#0 : tensor<8x8x8x8xf32>
437}
438
439//===----------------------------------------------------------------------===//
440// Bias op legalizations.
441//===----------------------------------------------------------------------===//
442
443// CHECK-LABEL: func @biasAdd_default
444func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
445  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
446  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
447  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
448  // CHECK-SAME:   {broadcast_dimensions = dense<3> : tensor<1xi64>}
449  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
450  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
451  return %0 : tensor<1x32x10x32xi32>
452}
453
454// CHECK-LABEL: func @biasAdd_NHWC
455func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
456  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
457  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
458  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
459  // CHECK-SAME:   {broadcast_dimensions = dense<3> : tensor<1xi64>}
460  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
461  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
462  return %0 : tensor<1x32x10x32xi32>
463}
464
465// CHECK-LABEL: func @biasAdd_NCHW
466func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
467  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
468  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
469  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
470  // CHECK-SAME:   {broadcast_dimensions = dense<1> : tensor<1xi64>}
471  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
472  %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
473  return %0 : tensor<1x32x10x32xi32>
474}
475
476// CHECK-LABEL: func @biasAdd_dynamic
477func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
478  // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
479  // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
480  // CHECK: %[[ARG1_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
481  // CHECK-SAME:   {broadcast_dimensions = dense<1> : tensor<1xi64>}
482  // CHECK: %[[RESULT:.+]] = mhlo.add %arg0, %[[ARG1_BCAST]]
483  %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
484  return %0 : tensor<?x?x?x?xi32>
485}
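
// All four BiasAdd cases lower the same way: broadcast the 1-D bias along
// the feature dimension of the value (dimension 3 for the default/NHWC
// layouts, dimension 1 for NCHW) using the value's runtime shape, then add.
// A minimal sketch of the expanded form for NHWC (SSA names illustrative,
// not part of any CHECK):
//   %shape = shape.shape_of %value
//   %extents = shape.to_extent_tensor %shape
//   %bcast = "mhlo.dynamic_broadcast_in_dim"(%bias, %extents)
//             {broadcast_dimensions = dense<3> : tensor<1xi64>}
//   %sum = mhlo.add %value, %bcast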


//===----------------------------------------------------------------------===//
// ClipByValue
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @clip
func @clip(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<f32> {
  // CHECK: [[VAL:%.+]] = "mhlo.clamp"(%arg1, %arg0, %arg2)

  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
  // CHECK: return [[VAL]]
  return %0 : tensor<f32>
}

// CHECK-LABEL: @clip_dynamic
func @clip_dynamic(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> tensor<?xf32> {
  // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"(%arg1, %arg0, %arg2)
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

  // CHECK: return [[CLAMP]]
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: @clip_static_broadcast
func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<5xf32> {
  // CHECK-DAG: [[SHP:%.+]] = mhlo.constant dense<5>
  // CHECK-DAG: [[SHPIDX:%.+]] = tensor.cast [[SHP]]
  // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"([[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]])
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<5xf32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>

  // CHECK: return [[CLAMP]]
  return %0 : tensor<5xf32>
}


// CHECK-LABEL: @clip_dynamic_broadcast
func @clip_dynamic_broadcast(%arg0 : tensor<?xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<?xf32> {
  // CHECK-DAG: [[SHP:%.+]] = shape.shape_of %arg0
  // CHECK-DAG: [[EXT:%.+]] = shape.to_extent_tensor [[SHP]]
  // CHECK-DAG: [[SHPIDX:%.+]] = index_cast [[EXT]]
  // CHECK-DAG: [[BROADCAST_MIN:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[BROADCAST_MAX:%.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg2, [[SHPIDX]]) {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: [[CLAMP:%.+]] = "mhlo.clamp"([[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]])
  %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<?xf32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>

  // CHECK: return [[CLAMP]]
  return %0 : tensor<?xf32>
}
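
// In all the ClipByValue lowerings above, mhlo.clamp takes its operands in
// (min, operand, max) order, so %arg1 and %arg2 bracket %arg0. When the
// bounds are scalars and the operand is not, they are first broadcast to
// the operand's shape (from a constant shape when it is static, from
// shape.shape_of when it is dynamic) before the clamp.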

//===----------------------------------------------------------------------===//
// DiagPart
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @diag_part
// CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32>
func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> {
  // CHECK: %[[RS:.*]] = "mhlo.reshape"(%[[ARG]]) : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<12x12xi32>
  // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<12x12xi32>
  // CHECK-DAG: %[[COMP:.*]] = "mhlo.compare"(%[[IOTA0]], %[[IOTA1]]) {comparison_direction = "EQ"} : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1>
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) {broadcast_sizes = dense<12> : tensor<2xi64>} : (tensor<f32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[SEL:.*]] = "mhlo.select"(%[[COMP]], %[[RS]], %[[ZERO_MAT]]) : (tensor<12x12xi1>, tensor<12x12xf32>, tensor<12x12xf32>) -> tensor<12x12xf32>
  // CHECK-DAG: %[[RED:.*]] = "mhlo.reduce"(%[[SEL]], %[[ZERO]])
  // CHECK-DAG:  mhlo.add
  // CHECK-DAG: {dimensions = dense<0> : tensor<1xi64>} : (tensor<12x12xf32>, tensor<f32>) -> tensor<12xf32>
  // CHECK-DAG:  %[[RES:.*]] = "mhlo.reshape"(%[[RED]]) : (tensor<12xf32>) -> tensor<4x3xf32>
  // CHECK-DAG:  return %[[RES]] : tensor<4x3xf32>
  %0 = "tf.DiagPart"(%arg0) : (tensor<4x3x4x3xf32>) -> tensor<4x3xf32>
  return %0: tensor<4x3xf32>
}
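
// Sketch of the DiagPart lowering checked above: the 4x3x4x3 input is
// reshaped to a 12x12 matrix, two iotas along dimensions 0 and 1 are
// compared EQ to form an identity mask, off-diagonal entries are zeroed by
// a select against a broadcast zero, a sum-reduction over dimension 0
// collapses each column to its diagonal entry, and the 12-element result is
// reshaped back to 4x3.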
560
561//===----------------------------------------------------------------------===//
562// MatrixDiagPart
563//===----------------------------------------------------------------------===//
564
565// CHECK-LABEL: func @matrix_diag_part
566// CHECK-SAME: %[[ARG:.*]]: tensor<7x140x128xi32>
567func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
568  // CHECK-DAG: %[[V0:.*]] = mhlo.constant dense<42> : tensor<i32>
569  // CHECK-DAG: %[[V1:.*]] = mhlo.constant dense<[-10, 11]> : tensor<2xi32>
570  // CHECK-DAG: %[[V2:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x22x128xi32>
571  // CHECK-DAG: %[[V3:.*]] = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<1x22x128xi32>
572  // CHECK-DAG: %[[V4:.*]] = mhlo.constant dense<0> : tensor<i32>
573  // CHECK-DAG: %[[V5:.*]] = "mhlo.broadcast"(%[[V4]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
574  // CHECK-DAG: %[[V6:.*]] = mhlo.constant dense<false> : tensor<i1>
575  // CHECK-DAG: %[[V7:.*]] = "mhlo.broadcast"(%[[V6]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
576  // CHECK-DAG: %[[V8:.*]] = mhlo.constant dense<true> : tensor<i1>
577  // CHECK-DAG: %[[V9:.*]] = "mhlo.broadcast"(%[[V8]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
578  // CHECK-DAG: %[[V10:.*]] = mhlo.constant dense<11> : tensor<i32>
579  // CHECK-DAG: %[[V11:.*]] = "mhlo.broadcast"(%[[V10]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
580  // CHECK-DAG: %[[V12:.*]] = mhlo.constant dense<140> : tensor<i32>
581  // CHECK-DAG: %[[V13:.*]] = "mhlo.broadcast"(%[[V12]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
582  // CHECK-DAG: %[[V14:.*]] = mhlo.constant dense<128> : tensor<i32>
  // CHECK-DAG: %[[V15:.*]] = "mhlo.broadcast"(%[[V14]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V16:.*]] = mhlo.constant dense<128> : tensor<i32>
  // CHECK-DAG: %[[V17:.*]] = "mhlo.broadcast"(%[[V16]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V18:.*]] = mhlo.subtract %[[V11]], %[[V2]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V19:.*]] = "mhlo.negate"(%[[V18]]) : (tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V20:.*]] = mhlo.minimum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V21:.*]] = mhlo.add %[[V13]], %[[V20]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V22:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V23:.*]] = mhlo.subtract %[[V15]], %[[V22]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V24:.*]] = mhlo.minimum %[[V21]], %[[V23]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V25:.*]] = chlo.broadcast_compare %[[V18]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V26:.*]] = mhlo.subtract %[[V17]], %[[V24]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V27:.*]] = "mhlo.select"(%[[V25]], %[[V26]], %[[V5]]) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  // CHECK-DAG: %[[V28:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V29:.*]] = mhlo.subtract %[[V28]], %[[V27]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V30:.*]] = mhlo.maximum %[[V19]], %[[V5]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V31:.*]] = mhlo.subtract %[[V30]], %[[V27]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V32:.*]] = mhlo.add %[[V3]], %[[V29]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V33:.*]] = mhlo.add %[[V3]], %[[V31]] : tensor<1x22x128xi32>
  // CHECK-DAG: %[[V34:.*]] = chlo.broadcast_compare %[[V32]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V35:.*]] = chlo.broadcast_compare %[[V32]], %[[V15]] {comparison_direction = "LT"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V36:.*]] = mhlo.and %[[V34]], %[[V35]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V37:.*]] = chlo.broadcast_compare %[[V33]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V38:.*]] = chlo.broadcast_compare %[[V33]], %[[V13]] {comparison_direction = "LT"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK-DAG: %[[V39:.*]] = mhlo.and %[[V37]], %[[V38]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V40:.*]] = mhlo.and %[[V36]], %[[V39]] : tensor<1x22x128xi1>
  // CHECK-DAG: %[[V41:.*]] = "mhlo.reshape"(%[[V40]]) : (tensor<1x22x128xi1>) -> tensor<22x128xi1>
  // CHECK-DAG: %[[V42:.*]] = "mhlo.concatenate"(%[[V33]], %[[V32]]) {dimension = 0 : i64} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32>
  // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) {dimension_numbers = {collapsed_slice_dims = dense<[1, 2]> : tensor<2xi64>, index_vector_dim = 0 : i64, offset_dims = dense<0> : tensor<1xi64>, start_index_map = dense<[1, 2]> : tensor<2xi64>}, indices_are_sorted = false, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>} : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32>
  // CHECK-DAG: %[[V44:.*]] = "mhlo.broadcast"(%[[V41]]) {broadcast_sizes = dense<7> : tensor<1xi64>} : (tensor<22x128xi1>) -> tensor<7x22x128xi1>
  // CHECK-DAG: %[[V45:.*]] = "mhlo.broadcast"(%[[V0]]) {broadcast_sizes = dense<[7, 22, 128]> : tensor<3xi64>} : (tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[V46:.*]] = "mhlo.select"(%[[V44]], %[[V43]], %[[V45]]) : (tensor<7x22x128xi1>, tensor<7x22x128xi32>, tensor<7x22x128xi32>) -> tensor<7x22x128xi32>
  // CHECK: return %[[V46]] : tensor<7x22x128xi32>
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  return %2: tensor<7x22x128xi32>
}
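
// In rough terms, the pattern checked above computes a (diagonal, column)
// index for every output element, gathers the input at those indices, and
// uses the final select to substitute the padding value wherever an index
// falls out of range; a sketch of the intent:
//   out[k, i] = in_bounds(k, i) ? input[row(k, i), col(k, i)] : padding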

// CHECK-LABEL: func @matrix_diag_part_single_diagonal
func @matrix_diag_part_single_diagonal(%arg0: tensor<7x140x128xi32>) -> tensor<7x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<0> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x128xi32>
  // CHECK: %[[result:.*]] = "mhlo.reshape"({{.*}}) : (tensor<7x1x128xi32>) -> tensor<7x128xi32>
  // CHECK: return %[[result]] : tensor<7x128xi32>
  return %2: tensor<7x128xi32>
}

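// The four alignment tests below exercise the same lowering; only the
// predicate feeding the final mhlo.select differs: LEFT_LEFT folds to a
// constant false, LEFT_RIGHT compares with LE, RIGHT_LEFT compares with GE,
// and RIGHT_RIGHT folds to a constant true.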
// CHECK-LABEL: func @matrix_diag_part_align_ll
func @matrix_diag_part_align_ll(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "LEFT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[false:.*]] = mhlo.constant dense<false> : tensor<i1>
  // CHECK: %[[b_false:.*]] = "mhlo.broadcast"(%[[false]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_false]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_lr
func @matrix_diag_part_align_lr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "LEFT_RIGHT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[le:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = "LE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[le]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_rl
func @matrix_diag_part_align_rl(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_LEFT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[ge:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[ge]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_rr
func @matrix_diag_part_align_rr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
  %0 = mhlo.constant dense<42> : tensor<i32>  // padding value
  %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = i32, align = "RIGHT_RIGHT"
  } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor<i32>) -> tensor<7x22x128xi32>
  // CHECK: %[[true:.*]] = mhlo.constant dense<true> : tensor<i1>
  // CHECK: %[[b_true:.*]] = "mhlo.broadcast"(%[[true]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor<i1>) -> tensor<1x22x128xi1>
  // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_true]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32>
  return %2: tensor<7x22x128xi32>
}

// CHECK-LABEL: func @matrix_diag_part_align_7d
// CHECK: (%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32>
func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> {
  %0 = mhlo.constant dense<-1.> : tensor<f32>  // padding value
  %1 = mhlo.constant dense<[-6, -3]> : tensor<2xi32>  // k
  %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) {
      T = f32, align = "LEFT_RIGHT"
  } : (tensor<3x5x7x9x11x13x17xf32>, tensor<2xi32>, tensor<f32>) -> tensor<3x5x7x9x11x4x10xf32>
  return %2: tensor<3x5x7x9x11x4x10xf32>
}

//===----------------------------------------------------------------------===//
// Erf
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @erf
func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  // CHECK: chlo.erf %arg0 : tensor<2x3xf32>
  %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

//===----------------------------------------------------------------------===//
// Erfc
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @erfc
func @erfc(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  // CHECK: chlo.erfc %arg0 : tensor<2x3xf32>
  %0 = "tf.Erfc"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

//===----------------------------------------------------------------------===//
// Einsum.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @einsum
func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> {
  // CHECK:  mhlo.einsum
  %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
  return %0: tensor<2x4xf32>
}

// CHECK-LABEL: func @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
  // CHECK:  mhlo.unary_einsum
  %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
  return %0: tensor<2x2xf32>
}

//===----------------------------------------------------------------------===//
// FloorDiv and FloorMod.
//===----------------------------------------------------------------------===//

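// For signed integers, the lowerings checked below compute, roughly:
//   floordiv(x, y) = (x < 0) == (y < 0) ? x / y : -(|x| + |y| - 1) / |y|
// i.e. truncating division when the signs agree and division rounded away
// from zero otherwise; floating-point inputs simply divide and floor.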
// CHECK-LABEL: func @floordiv_broadcast_i32
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
  // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
  // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
  // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
  // CHECK: return [[SELECT]]
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floordiv_reverse_broadcast_i32
func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
  // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
  // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
  // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]]
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
  // CHECK: return [[SELECT]]
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floordiv_f32
func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK-NEXT:  %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0
  // CHECK-NEXT:  %[[FLOOR:.*]] = "mhlo.floor"(%[[DIV]])
  // CHECK-NEXT:  return %[[FLOOR]] : tensor<2xf32>
  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  return %0: tensor<2xf32>
}

// CHECK-LABEL: func @floordiv_bf16
func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> {
  // CHECK-NEXT:  mhlo.convert
  // CHECK-NEXT:  mhlo.convert
  // CHECK-NEXT:  chlo.broadcast_divide
  // CHECK-NEXT:  mhlo.floor
  // CHECK-NEXT:  mhlo.convert
  // CHECK-NEXT:  return
  %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16>
  return %0: tensor<2xbf16>
}

// CHECK-LABEL: func @floordiv_f16_broadcast
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
  // CHECK-NEXT:  chlo.broadcast_divide
  // CHECK-NEXT:  mhlo.floor
  // CHECK-NEXT:  return
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
  return %0: tensor<2x3xf16>
}

// CHECK-LABEL: func @floordiv_dynamic
func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
  // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
  // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
  // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
  // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
  // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
  // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
  // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
  // CHECK: return [[SELECT]]
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
  return %0: tensor<?x?xi32>
}

// CHECK-LABEL: func @floordiv_unranked
func @floordiv_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK-NOT: tf.FloorDiv
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0: tensor<*xf32>
}

// CHECK-LABEL: func @floordiv_int
func @floordiv_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK: tf.FloorDiv
  %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
  return %0: tensor<*xi32>
}

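// FloorMod lowers to a remainder plus a sign fix-up; a rough reading of the
// compares and selects checked below:
//   rem = x % y
//   floormod(x, y) = (rem != 0 && (rem < 0) != (y < 0)) ? rem + y : rem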
// CHECK-LABEL: func @floormod_broadcast_numerator
func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]]
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floormod_broadcast_denominator
func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  return %0: tensor<2x3xi32>
}

// CHECK-LABEL: func @floormod_dynamic
func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
  // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
  // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
  // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "LT"}
  // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
  // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
  // CHECK-NEXT: return [[SELECT]]
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
  return %0: tensor<?x?xi32>
}

// CHECK-LABEL: func @floormod_unranked
func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK-NOT: tf.FloorMod
  %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
  return %0: tensor<*xi32>
}

//===----------------------------------------------------------------------===//
// BroadcastTo.
//===----------------------------------------------------------------------===//

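// BroadcastTo lowers to "mhlo.dynamic_broadcast_in_dim": the target shape is
// materialized as a constant and broadcast_dimensions maps each input
// dimension to its position in the output (here the single input dimension
// maps to output dimension 3).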
// CHECK-LABEL: func @broadcast_to
func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
  %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32>

  // CHECK: [[CST:%.+]] = mhlo.constant
  // CHECK: [[CAST:%.+]] = tensor.cast [[CST]] : tensor<4xi32> to tensor<4xi32>
  // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg0, [[CAST]])
  // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
  %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32>
  return %0 : tensor<16x16x16x16xf32>
}

//===----------------------------------------------------------------------===//
// Complex op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @complex
func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
  // CHECK: chlo.broadcast_complex
  %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
  return %1 : tensor<3xcomplex<f32>>
}

// CHECK-LABEL: func @imag
func @imag(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> {
  // CHECK: "mhlo.imag"
  %1 = "tf.Imag"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
  return %1 : tensor<3xf32>
}

// CHECK-LABEL: func @real
func @real(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xf32> {
  // CHECK: "mhlo.real"
  %1 = "tf.Real"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xf32>
  return %1 : tensor<3xf32>
}

//===----------------------------------------------------------------------===//
// Concat op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @concat_v2
func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
  %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32>
  return %1 : tensor<6x3xf32>
}

// CHECK-LABEL: func @concat_v2_neg_axis
func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>

  %axis = "tf.Const"() { value = dense<-2> : tensor<i64> } : () -> tensor<i64>
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<6x3xf32>
  return %1 : tensor<6x3xf32>
}

// CHECK-LABEL: func @concat_v2_1d_axis
func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> {
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>

  %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64>
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32>
  return %1 : tensor<3x6xf32>
}

// CHECK-LABEL: func @concat_v2_non_const_axis
func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %axis: tensor<i64>) -> tensor<3x6xf32> {
  // CHECK: "tf.ConcatV2"
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<i64>) -> tensor<3x6xf32>
  return %1 : tensor<3x6xf32>
}

// CHECK-LABEL: func @concat_v2_unranked
func @concat_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
  // CHECK: "tf.ConcatV2"
  %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<*xf32>, tensor<*xf32>, tensor<i64>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

//===----------------------------------------------------------------------===//
// Pad op legalizations.
//===----------------------------------------------------------------------===//

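// PadV2 lowers to "mhlo.pad": column 0 of the paddings operand becomes
// edge_padding_low, column 1 becomes edge_padding_high, and interior_padding
// is always zero since tf.PadV2 has no interior padding.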
// CHECK-LABEL: func @padv2_1D
func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor<f32>) -> tensor<6xf32> {
  %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64>
  // CHECK: "mhlo.pad"(%arg0, %arg1) {
  // CHECK-SAME: edge_padding_high = dense<2> : tensor<1xi64>,
  // CHECK-SAME: edge_padding_low = dense<1> : tensor<1xi64>,
  // CHECK-SAME: interior_padding = dense<0> : tensor<1xi64>
  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3xf32>, tensor<1x2xi64>, tensor<f32>) -> tensor<6xf32>
  return %1 : tensor<6xf32>
}

// CHECK-LABEL: func @padv2_2D
func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
  %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64>
  // CHECK: "mhlo.pad"(%arg0, %arg1) {
  // CHECK-SAME:    edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
  // CHECK-SAME:    edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
  // CHECK-SAME:    interior_padding = dense<0> : tensor<2xi64>
  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi64>, tensor<f32>) -> tensor<6x9xf32>
  return %1 : tensor<6x9xf32>
}

// CHECK-LABEL: func @padv2_i32_paddings
func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor<f32>) -> tensor<6x9xf32> {
  %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32>
  // CHECK: "mhlo.pad"(%arg0, %arg1) {
  // CHECK-SAME:    edge_padding_high = dense<[2, 4]> : tensor<2xi64>,
  // CHECK-SAME:    edge_padding_low = dense<[1, 3]> : tensor<2xi64>,
  // CHECK-SAME:    interior_padding = dense<0> : tensor<2xi64>
  %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<6x9xf32>
  return %1 : tensor<6x9xf32>
}

//===----------------------------------------------------------------------===//
// Identity op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @identity
func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK-NEXT:  return %arg0 : tensor<1xi32>
  %0 = "tf.Identity"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @identityN
func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) {
  // CHECK-NEXT:  return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32>
  %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>)
  return %0#0, %0#1: tensor<1xi32>, tensor<1xf32>
}

// CHECK-LABEL: func @stopgradient
func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK-NEXT:  return %arg0 : tensor<1xi32>
  %0 = "tf.StopGradient"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @preventgradient
func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK-NEXT:  return %arg0 : tensor<1xi32>
  %0 = "tf.PreventGradient"(%arg0) {message = "fin gradients"} : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @checkNumerics
func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> {
  // CHECK-NEXT:  return %arg0 : tensor<1xf32>
  %0 = "tf.CheckNumerics"(%arg0) {message = "check numerics"} : (tensor<1xf32>) -> tensor<1xf32>
  return %0: tensor<1xf32>
}

//===----------------------------------------------------------------------===//
// InfeedDequeueTuple legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @infeed_dequeue_tuple
func @infeed_dequeue_tuple() -> (tensor<3xi32>, tensor<4xf32>) {
// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token
// CHECK: [[INFEED:%.*]] = "mhlo.infeed"([[TOKEN]]) {infeed_config = ""} : (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token>
// CHECK: [[INFEED_VAL:%.*]] = "mhlo.get_tuple_element"([[INFEED]]) {index = 0 : i32} : (tuple<tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token>) -> tuple<tensor<3xi32>, tensor<4xf32>>
// CHECK: [[RES_1:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 0 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<3xi32>
// CHECK: [[RES_2:%.*]] = "mhlo.get_tuple_element"([[INFEED_VAL]]) {index = 1 : i32} : (tuple<tensor<3xi32>, tensor<4xf32>>) -> tensor<4xf32>
// CHECK: return [[RES_1]], [[RES_2]]
  %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
  return %0#0, %0#1 : tensor<3xi32>, tensor<4xf32>
}

// The following op sharding is used:
// Proto debug string:
//   type: TUPLE
//   tuple_shardings {
//     type: MAXIMAL
//     tile_assignment_dimensions: 1
//     tile_assignment_devices: 0
//   }
// Serialized string:
//   "\08\02*\08\08\01\1A\01\01\22\01\00"

// CHECK-LABEL: infeed_dequeue_tuple_sharding
func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> {
  // CHECK: "mhlo.infeed"
  // An additional sharding is added at the end to account for the token result.
1102  // Proto debug string:
1103  //   type: TUPLE
1104  //   tuple_shardings {
1105  //     type: MAXIMAL
1106  //     tile_assignment_dimensions: 1
1107  //     tile_assignment_devices: 0
1108  //   }
1109  //   tuple_shardings {
1110  //     type: MAXIMAL
1111  //     tile_assignment_dimensions: 1
1112  //     tile_assignment_devices: 0
1113  //   }
1114  // CHECK-SAME: mhlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00"
1115  %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32>
1116  return %0 : tensor<8xi32>
1117}
1118
1119//===----------------------------------------------------------------------===//
1120// Nullary op legalizations.
1121//===----------------------------------------------------------------------===//
1122
1123// CHECK-LABEL: @const
1124func @const() -> tensor<2xi32> {
1125  // CHECK: mhlo.constant dense<0> : tensor<2xi32>
1126  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>)
1127  return %0: tensor<2xi32>
1128}
1129
1130// CHECK-LABEL: @const_dynamic_output
1131func @const_dynamic_output() -> tensor<*xi32> {
1132  // CHECK: [[CONST:%.*]] = mhlo.constant dense<0> : tensor<2xi32>
1133  // CHECK: [[CAST:%.*]] = tensor.cast [[CONST]] : tensor<2xi32> to tensor<*xi32>
1134  %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>)
1135  // CHECK: return [[CAST]]
1136  return %0: tensor<*xi32>
1137}
1138
1139// CHECK-LABEL: @opaque_const
1140func @opaque_const() -> tensor<!tf.variant<tensor<2xi32>>> {
1141  // CHECK-NOT: mhlo.constant
1142  %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<2xi32>>>
1143  return %0 : tensor<!tf.variant<tensor<2xi32>>>
1144}
1145
1146//===----------------------------------------------------------------------===//
1147// Matmul op legalizations.
1148//===----------------------------------------------------------------------===//
1149
1150// CHECK-LABEL: matmul_notranspose
1151// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<7x11xf32>)
1152func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tensor<5x11xf32> {
1153  // CHECK: "mhlo.dot"(%[[A]], %[[B]])
1154  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32>
1155
1156  return %0 : tensor<5x11xf32>
1157}
1158
1159// CHECK-LABEL: matmul_transpose_b
1160// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<11x7xf32>)
1161func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> {
1162  // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
1163  // CHECK: "mhlo.dot"(%[[A]], %[[UPDATED_B]])
1164  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32>
1165
1166  return %0 : tensor<5x11xf32>
1167}
1168
1169// CHECK-LABEL: matmul_transpose_both
1170// CHECK-SAME: (%[[A:.*]]: tensor<7x5xf32>, %[[B:.*]]: tensor<11x7xf32>)
1171func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> {
1172  // CHECK: %[[UPDATED_A:.*]] = "mhlo.transpose"(%[[A]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
1173  // CHECK: %[[UPDATED_B:.*]] = "mhlo.transpose"(%[[B]]) {permutation = dense<[1, 0]> : tensor<2xi64>}
1174  // CHECK: "mhlo.dot"(%[[UPDATED_A]], %[[UPDATED_B]])
1175  %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} : (tensor<7x5xf32>, tensor<11x7xf32>) -> tensor<5x11xf32>
1176
1177  return %0 : tensor<5x11xf32>
1178}
1179
// Verify that MatMul with ranked inputs is lowered to HLO.
// CHECK-LABEL: matmul_ranked
func @matmul_ranked(%a: tensor<?x7xf32>, %b: tensor<7x?xf32>) -> tensor<?x?xf32> {
  // CHECK: "mhlo.dot"
  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<?x7xf32>, tensor<7x?xf32>) -> tensor<?x?xf32>

  return %0 : tensor<?x?xf32>
}

// Verify that MatMul with unranked inputs is lowered to HLO.
// CHECK-LABEL: matmul_unranked
func @matmul_unranked(%a: tensor<*xf32>, %b: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: "mhlo.dot"
  %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>

  return %0 : tensor<*xf32>
}

// Verify SparseMatMul is legalized to dot.
// CHECK-LABEL: test_sparse_mat_mul
func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> {
  // CHECK: "mhlo.dot"
  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32>
  return %0: tensor<3x5xf32>
}

// SparseMatMul where one operand needs to be transposed and the other one does not.
//
// CHECK-LABEL:   @test_sparse_mat_mul_with_transpose
// CHECK-SAME:      %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME:      %[[ARG1:.*]]: tensor<5x4xf32>
// CHECK-SAME:      -> tensor<3x5xf32>
// CHECK:           %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[ARG1]])
// CHECK-SAME:        permutation = dense<[1, 0]>
// CHECK-SAME:        -> tensor<4x5xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]])
// CHECK-SAME:        -> tensor<3x5xf32>
// CHECK:           return %[[RESULT]]
func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> {
  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32>
  return %0: tensor<3x5xf32>
}

// SparseMatMul where one operand needs to be cast and the other one does not.
//
// CHECK-LABEL:   @test_sparse_mat_mul_with_cast
// CHECK-SAME:      %[[ARG0:.*]]: tensor<3x4xf32>
// CHECK-SAME:      %[[ARG1:.*]]: tensor<4x5xbf16>
// CHECK-SAME:      -> tensor<3x5xf32>
// CHECK:           %[[CAST:.*]] = "mhlo.convert"(%[[ARG1]])
// CHECK-SAME:        -> tensor<4x5xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]])
// CHECK-SAME:        -> tensor<3x5xf32>
// CHECK:           return %[[RESULT]]
func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> {
  %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32>
  return %0: tensor<3x5xf32>
}

//===----------------------------------------------------------------------===//
// MatrixBandPart op legalizations.
//===----------------------------------------------------------------------===//

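// The band is computed from iota-based offsets: with offset = col - row, an
// element is kept iff -lower <= offset <= upper, where a negative num_lower
// or num_upper is first replaced by the corresponding matrix dimension (a
// rough reading of the compares and selects checked below).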
// CHECK-LABEL: matrix_band_part
// CHECK-SAME: (%[[INPUT:.*]]: tensor<64x64xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
  // CHECK-DAG: %[[M:.*]] = mhlo.constant dense<64> : tensor<i64>
  // CHECK-DAG: %[[N:.*]] = mhlo.constant dense<64> : tensor<i64>

  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i64>
  // CHECK-DAG: %[[A:.*]] = "mhlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK-DAG: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>

  // CHECK-DAG: %[[C:.*]] = "mhlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK-DAG: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
  // CHECK-DAG: %[[F:.*]] = "mhlo.negate"(%[[B]]) : (tensor<i64>) -> tensor<i64>

  // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xi64>
  // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xi64>
  // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xi64>
  // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<64x64xi64>) -> tensor<64x64xi1>

  // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<64x64xi64>, tensor<i64>) -> tensor<64x64xi1>

  // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1>

  // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16>

  // CHECK-DAG: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]])
  // CHECK-DAG: return %[[R]]
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
  return %0 : tensor<64x64xbf16>
}

// CHECK-LABEL: matrix_band_part_2
// CHECK-SAME: (%[[INPUT:.*]]: tensor<12x24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<12x24x48xbf16> {
  // CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xi64>
  // CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xi64>
  // CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xi64>

  // CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<24x48xi64>) -> tensor<24x48xi1>

  // CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<24x48xi64>, tensor<i64>) -> tensor<24x48xi1>
  // CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1>

  // CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16>

  // CHECK-DAG: %[[K:.*]] = "mhlo.broadcast_in_dim"(%[[J]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<24x48xi1>) -> tensor<12x24x48xi1>
  // CHECK-DAG: %[[R:.*]] = "mhlo.select"(%[[K]], %[[INPUT]], %[[ZERO2]])
  // CHECK-DAG: return %[[R]]
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<12x24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<12x24x48xbf16>
  return %0 : tensor<12x24x48xbf16>
}

// CHECK-LABEL: matrix_band_part_3
// CHECK-SAME: (%[[INPUT:.*]]: tensor<*xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_3(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
  // CHECK: "tf.MatrixBandPart"
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
  return %0 : tensor<*xbf16>
}

// CHECK-LABEL: matrix_band_part_4
// CHECK-SAME: (%[[INPUT:.*]]: tensor<24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_4(%arg0: tensor<24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<24x48xbf16> {
  // This one should lower.
  // CHECK-NOT: "tf.MatrixBandPart"
  %0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<24x48xbf16>
  return %0 : tensor<24x48xbf16>
}

//===----------------------------------------------------------------------===//
// MaxPool op legalizations.
//===----------------------------------------------------------------------===//

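// MaxPool lowers to "mhlo.reduce_window" with an mhlo.maximum body and an
// init value equal to the lowest value of the element type (-2147483648 for
// i32, 0xFF800000, i.e. -infinity, for f32).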
// CHECK-LABEL: maxpool_valid_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK: {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>}

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
  return %0 : tensor<2x3x5x7xi32>
}

// CHECK-LABEL: maxpool_same_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32>
  return %0 : tensor<2x4x7x7xi32>
}

// CHECK-LABEL: maxpool_3d_valid_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> {
  // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
  // CHECK: "mhlo.reduce_window"(%[[ARG]], %[[INIT]])
  // CHECK: mhlo.maximum
  // CHECK: mhlo.return
  // CHECK: {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 4, 4, 1]> : tensor<5xi64>}

  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32>
  return %0 : tensor<2x8x3x5x7xf32>
}

// CHECK-LABEL: maxpool_3d_same_padding
// CHECK-SAME: %[[ARG:.*]]: tensor
func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>

  %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32>
  return %0 : tensor<2x8x4x7x7xf32>
}

// CHECK-LABEL: maxpool_explicit_padding
func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> {
  // CHECK: tf.MaxPool
  // TODO(b/165938852): need to support explicit padding in max_pool.

  %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32>
  return %0 : tensor<2x3x5x7xi32>
}

//===----------------------------------------------------------------------===//
// MaxPoolGrad op legalizations.
//===----------------------------------------------------------------------===//

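// MaxPoolGrad lowers to "mhlo.select_and_scatter": the first region
// re-selects the window maximum via a GE compare and the second region
// accumulates the scattered gradient values with mhlo.add.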
// CHECK-LABEL: @max_pool_grad_valid
// CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32>
func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = "mhlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<i1>) -> ()
  // CHECK: },  {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
  // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32>
  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 2, 1],
     padding = "VALID",
     strides = [1, 2, 2, 1]
  } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32>
  return %result : tensor<10x24x24x64xf32>
}

// CHECK-LABEL: @max_pool_3d_grad_valid
// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32>
func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) ( {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = "mhlo.compare"(%[[VALUE_A]], %[[VALUE_B]]) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<i1>) -> ()
  // CHECK: },  {
  // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor<f32>, %[[VALUE_B:.*]]: tensor<f32>):
  // CHECK: %[[SELECT_RESULT:.*]] = mhlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor<f32>
  // CHECK: "mhlo.return"(%[[SELECT_RESULT]]) : (tensor<f32>) -> ()
  // CHECK: }) {window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<f32>) -> tensor<10x8x24x24x64xf32>
  // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32>
  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32>
  return %result : tensor<10x8x24x24x64xf32>
}

// CHECK-LABEL: @max_pool_grad_same
func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
  %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 3, 1],
     padding = "SAME",
     strides = [1, 4, 4, 1]
  } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32>
  return %result : tensor<2x13x25x7xf32>
}

// CHECK-LABEL: @max_pool_3d_grad_same
func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> {
  // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64>
  %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32>
  return %result : tensor<2x8x13x25x7xf32>
}

//===----------------------------------------------------------------------===//
// OneHot op legalizations.
//===----------------------------------------------------------------------===//

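// OneHot lowers to an iota over the depth dimension that is compared for
// equality against the broadcast indices, followed by a select between the
// broadcast on/off values; in effect: out[i, d] = (indices[i] == d) ? on : off.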
// CHECK-LABEL: one_hot
func @one_hot(%indices: tensor<3xi32>, %on_value: tensor<f32>, %off_value: tensor<f32>) -> tensor<3x5xf32> {
  // CHECK: %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32>
  // CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32>
  // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[BCAST_ARG0]], %[[IOTA]]) {comparison_direction = "EQ"} : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1>
  // CHECK: %[[ON_VALUE:.*]] = "mhlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32>
  // CHECK: %[[OFF_VALUE:.*]] = "mhlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32>
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32>
  // CHECK: return %[[RESULT]] : tensor<3x5xf32>
  %depth = "tf.Const"() { value = dense<5> : tensor<i32> } : () -> tensor<i32>
  %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<3x5xf32>
  return %result : tensor<3x5xf32>
}

//===----------------------------------------------------------------------===//
// tf.OutfeedEnqueueTuple legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @outfeed_enqueue_tuple
// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>)
func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () {
// CHECK: [[TUPLE:%.*]] = "mhlo.tuple"([[VAL_0]], [[VAL_1]]) : (tensor<3xi32>, tensor<4xf32>) -> tuple<tensor<3xi32>, tensor<4xf32>>
// CHECK: [[TOKEN:%.*]] = "mhlo.create_token"() : () -> !mhlo.token
// CHECK: "mhlo.outfeed"([[TUPLE]], [[TOKEN]]) {outfeed_config = ""} : (tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token) -> !mhlo.token
  "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> ()
  return
}

//===----------------------------------------------------------------------===//
// Pack op legalizations.
//===----------------------------------------------------------------------===//

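// Pack lowers to a reshape of each operand that prepends a unit dimension,
// followed by a concatenate along the pack axis.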
// CHECK-LABEL: func @pack
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
  // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32>
  // CHECK: "mhlo.reshape"({{.*}}) : (tensor<2xi32>) -> tensor<1x2xi32>
  // CHECK: "mhlo.concatenate"({{.*}}) {dimension = 0 : i64} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>

  %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}

//===----------------------------------------------------------------------===//
// PartitionedCall op legalization.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @partitioned_call
func @partitioned_call(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK: call @pcall_func(%arg0) : (tensor<i32>) -> tensor<i32>
  %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_func} : (tensor<i32>) -> (tensor<i32>)
  return %0 : tensor<i32>
}

func @pcall_func(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0 : tensor<i32>
}

// CHECK-LABEL: func @partitioned_call_multi_input
func @partitioned_call_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  // CHECK: call @pcall_multi_input(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %0 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_input} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
  return %0 : tensor<i32>
}

func @pcall_multi_input(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  return %arg0 : tensor<i32>
}

// CHECK-LABEL: func @partitioned_call_multi_in_out
func @partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: call @pcall_multi_in_out(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}

func @pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  return %arg1, %arg0 : tensor<i32>, tensor<i32>
}

// CHECK-LABEL: func @unhandled_partitioned_call
func @unhandled_partitioned_call(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) {
  // The argument types don't match the parameter types for the
  // pcall_multi_in_out function. That's fine for a PartitionedCallOp but not
  // for a standard CallOp, so this op can't be lowered.
  // CHECK: "tf.PartitionedCall"
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}

// CHECK-LABEL: func @unhandled_partitioned_call_2
func @unhandled_partitioned_call_2(%arg0: tensor<i32>, %arg1: tensor<*xi32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: "tf.PartitionedCall"
  %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<i32>, tensor<*xi32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}


//===----------------------------------------------------------------------===//
// ReverseV2 op legalization.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @reverse_func_32
func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> {
  %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32>

  // CHECK: return [[VAL]] : tensor<5xi32>
  return %reversed : tensor<5xi32>
}

// CHECK-LABEL: @reverse_func_64
func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> {
  %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32>

  // CHECK: return [[VAL]] : tensor<5xi32>
  return %reversed : tensor<5xi32>
}

// CHECK-LABEL: @reverse_func_neg
func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> {
  %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)

  // CHECK: [[VAL:%.+]] = "mhlo.reverse"(%arg0) {dimensions = dense<1> : tensor<1xi64>}
  %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32>

  // CHECK: return [[VAL]] : tensor<5x5xi32>
  return %reversed : tensor<5x5xi32>
}

//===----------------------------------------------------------------------===//
// StatefulPartitionedCall op legalization.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @stateful_partitioned_call
// CHECK-SAME: [[ARG:%.+]]: tensor<i32>
func @stateful_partitioned_call(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK: call @stateful_pcall_func([[ARG]]) : (tensor<i32>) -> tensor<i32>
  %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor<i32>) -> (tensor<i32>)
  return %0 : tensor<i32>
}

func @stateful_pcall_func(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0 : tensor<i32>
}

// CHECK-LABEL: func @stateful_partitioned_call_multi_in_out
// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>, [[ARG1:%.+]]: tensor<i32>)
func @stateful_partitioned_call_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  // CHECK: call @stateful_pcall_multi_in_out([[ARG0]], [[ARG1]]) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  %0, %1 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_multi_in_out} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %0, %1 : tensor<i32>, tensor<i32>
}

func @stateful_pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  return %arg1, %arg0 : tensor<i32>, tensor<i32>
}

//===----------------------------------------------------------------------===//
// Elu op legalizations.
//===----------------------------------------------------------------------===//

1603// CHECK-LABEL: func @elu
1604func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> {
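  // Note: elu(x) = x for x > 0 and exp(x) - 1 otherwise, hence the
  // compare/select around exponential_minus_one below.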
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %arg0, %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"}
  // CHECK-DAG: %[[EXP:.*]] = "mhlo.exponential_minus_one"(%arg0)
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]])
  // CHECK: return %[[RESULT]]
  %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
  return %0: tensor<1xf32>
}

// CHECK-LABEL: func @elu_grad
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>)
func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> {
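  // Note: tf.EluGrad receives the elu outputs as its second operand; since
  // elu'(x) = exp(x) = elu(x) + 1 for x <= 0, the lowering scales the incoming
  // gradients by (outputs + 1) on the non-positive side.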
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"}
  // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = dense<> : tensor<0xi64>}
  // CHECK-DAG: %[[MULGRAD:.*]] = "mhlo.multiply"(%[[GRADIENTS]], %[[ADD1]])
  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[MULGRAD]])
  // CHECK: return %[[RESULT]]
  %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
  return %2 : tensor<4x8xf32>
}

//===----------------------------------------------------------------------===//
// Relu op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @relu
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
  %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @relu_unranked
func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
  %0 = "tf.Relu"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  return %0: tensor<?xi32>
}

// CHECK-LABEL: func @relu6
func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32>
  // CHECK: "mhlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor<i32>, tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
  %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
  return %0: tensor<1xi32>
}

// CHECK-LABEL: func @relu6_unranked
func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: %[[SIX:.*]] = mhlo.constant dense<6> : tensor<i32>
  // CHECK: "mhlo.clamp"(%[[ZERO]], %arg0, %[[SIX]]) : (tensor<i32>, tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
  %0 = "tf.Relu6"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  return %0: tensor<?xi32>
}

// CHECK-LABEL: func @relu_grad
// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor<?x?xf32>)
func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> {
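  // Note: the gradients pass through where features > 0 and are zeroed
  // elsewhere; the static and dynamic operand shapes are assumed to agree at
  // runtime.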
  // CHECK-DAG: %[[ZERO_SCALAR:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32>
  // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
  // CHECK-DAG: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
  // CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32>
  %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
  return %2 : tensor<4x8xf32>
}

//===----------------------------------------------------------------------===//
// Select op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @selectv2
func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
  // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2)
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  return %0: tensor<2xi32>
}

// CHECK-LABEL: func @selectv2_pred_scalar
func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
  // CHECK-NEXT: "mhlo.select"(%arg0, %arg1, %arg2)
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  return %0: tensor<2xi32>
}

// CHECK-LABEL: func @selectv2_broadcast_then
func @selectv2_broadcast_then(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
  // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
  // CHECK: "mhlo.select"(%arg0, %[[BROADCAST]], %arg2)
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
  return %0: tensor<2x8x8xi32>
}

// CHECK-LABEL: func @selectv2_broadcast_else
func @selectv2_broadcast_else(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
  // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x1xi32>) -> tensor<2x8x8xi32>
  // CHECK: "mhlo.select"(%arg0, %arg1, %[[BROADCAST]])
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
  return %0: tensor<2x8x8xi32>
}

// CHECK-LABEL: func @selectv2_broadcast_pred
func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
  // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<2x8x8xi1>
  // CHECK: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2)
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
  return %0: tensor<2x8x8xi32>
}

// CHECK-LABEL: func @selectv2_broadcast_tensor_pred
func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
  // CHECK: %[[BROADCAST:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1>
  // CHECK: "mhlo.select"(%[[BROADCAST]], %arg1, %arg2)
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
  return %0: tensor<2x3xf16>
}

// CHECK-LABEL: func @selectv2_broadcast_all
func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
  // CHECK-DAG: %[[BROADCAST_0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1>
  // CHECK-DAG: %[[BROADCAST_1:.*]] = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x8x1xi32>) -> tensor<8x8x8xi32>
  // CHECK-DAG: %[[BROADCAST_2:.*]] = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
  // CHECK: "mhlo.select"(%[[BROADCAST_0]], %[[BROADCAST_1]], %[[BROADCAST_2]])
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
  return %0: tensor<8x8x8xi32>
}

// CHECK-LABEL: func @selectv2_dynamic_ranked
func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
  // CHECK: tf.SelectV2
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
  return %0: tensor<2x?x8xi32>
}

// CHECK-LABEL: func @selectv2_unranked
func @selectv2_unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK: tf.SelectV2
  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32>
  return %0: tensor<*xi32>
}

//===----------------------------------------------------------------------===//
// Softmax op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @simple_softmax
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>)
func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {

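  // Softmax is lowered in the numerically stable form
  // softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))), i.e. a max-reduce,
  // subtract, exponential, sum-reduce, and divide.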
  // Verify reduce op for max computation and its body.
  // CHECK-DAG: %[[NEG_INF:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
  // CHECK-DAG: %[[CASTED_INP:.*]] = "mhlo.convert"(%[[ARG0]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  // CHECK: %[[MAX:.*]] = "mhlo.reduce"(%[[CASTED_INP]], %[[NEG_INF]])
  // CHECK:  mhlo.maximum
  // CHECK: "mhlo.return"
  // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
  // CHECK: %[[CASTED_MAX:.*]] = "mhlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32>

  // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
  // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
  // CHECK: %[[BCAST_MAX:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
  // CHECK: %[[SHIFTED_INP:.*]] = mhlo.subtract %[[ARG0]], %[[BCAST_MAX]]
  // CHECK: %[[EXP:.*]] = "mhlo.exponential"(%[[SHIFTED_INP]])

  // Verify reduce op for summation and its body.
  // CHECK-DAG: %[[CASTED_EXP:.*]] = "mhlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: %[[SUM:.*]] = "mhlo.reduce"(%[[CASTED_EXP]], %[[ZERO]])
  // CHECK:  mhlo.add
  // CHECK: "mhlo.return"
  // CHECK: {dimensions = dense<1> : tensor<1xi64>}
  // CHECK: %[[CASTED_SUM:.*]] = "mhlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32>

  // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
  // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
  // CHECK: %[[BCAST_SUM:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
  // CHECK: %[[RESULT:.*]] = mhlo.divide %[[EXP]], %[[BCAST_SUM]]
  // CHECK: return %[[RESULT]]

  %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0: tensor<2x3xf32>
}

// Verify that the intermediate and final shapes are correct with dynamic shapes.
// CHECK-LABEL: func @dynamic_softmax
func @dynamic_softmax(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // CHECK: mhlo.divide {{.*}}  : tensor<?x?xf32>
  %0 = "tf.Softmax"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0: tensor<?x?xf32>
}

// CHECK-LABEL: bf16_softmax
func @bf16_softmax(%arg0: tensor<2x3xbf16>) -> tensor<2x3xbf16> {
  // Verify that conversions to f32 and back to bf16 are introduced so that
  // the reductions are computed in f32.

  // CHECK: "mhlo.convert"({{.*}}) : (tensor<2x3xbf16>) -> tensor<2x3xf32>
  // CHECK: "mhlo.convert"({{.*}}) : (tensor<2xf32>) -> tensor<2xbf16>

  %0 = "tf.Softmax"(%arg0) : (tensor<2x3xbf16>) -> tensor<2x3xbf16>
  return %0: tensor<2x3xbf16>
}

// CHECK-LABEL: rank4_softmax
func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> {
  // Verify that reduce op dimensions and broadcast dimensions are correct.

  // CHECK: "mhlo.reduce"
  // CHECK: dimensions = dense<3>

  // CHECK: "mhlo.reduce"
  // CHECK: dimensions = dense<3>

  // CHECK: {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
  // CHECK: mhlo.divide {{.*}}
  %0 = "tf.Softmax"(%arg0) : (tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16>
  return %0: tensor<2x3x4x5xf16>
}

//===----------------------------------------------------------------------===//
// LogSoftmax op legalizations.
// This just changes the tail of the regular Softmax legalization.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @simple_logsoftmax
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>)
func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
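  // Note: log(softmax(x)) = (x - max(x)) - log(sum(exp(x - max(x)))), so only
  // the final divide of the Softmax lowering is replaced by a log and subtract.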
  // CHECK: %{{.*}} = "mhlo.reduce"({{.*}})
  // CHECK: %[[SUM:.*]] = "mhlo.reduce"({{.*}})
  // CHECK: %[[CASTED_SUM:.*]] = "mhlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
  // CHECK: %[[LOG:.*]] = "mhlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
  // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
  // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
  // CHECK: %[[BCAST_SUM:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
  // CHECK: %[[RESULT:.*]] = mhlo.subtract {{.*}}, %[[BCAST_SUM]]
  // CHECK: return %[[RESULT]]

  %0 = "tf.LogSoftmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %0: tensor<2x3xf32>
}

//===----------------------------------------------------------------------===//
// Fast Fourier Transform op legalization.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @fft_1D
func @fft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
  // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "FFT"} : (tensor<8xcomplex<f32>>
  %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
  return %0 : tensor<8xcomplex<f32>>
}

// CHECK-LABEL: func @ifft_1D
func @ifft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
  // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "IFFT"} : (tensor<8xcomplex<f32>>
  %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex<f32>>) -> tensor<8xcomplex<f32>>
  return %0 : tensor<8xcomplex<f32>>
}

// CHECK-LABEL: func @rfft_1D
func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<5xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
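  // Note: an RFFT of length 8 produces fft_length / 2 + 1 = 5 complex
  // coefficients, hence the tensor<5xcomplex<f32>> result type.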
  // CHECK: "mhlo.fft"(%arg0) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32>
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<5xcomplex<f32>>
  return %0 : tensor<5xcomplex<f32>>
}

// CHECK-LABEL: func @rfft_1D_padded
func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<5xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
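  // The input is padded (with what is presumably a zero constant) from 7 up
  // to the requested fft_length of 8 before the FFT.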
  // CHECK: %[[PADDED:.*]] = "mhlo.pad"(%arg0, %2) {edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<7xf32>, tensor<f32>) -> tensor<8xf32>
  // CHECK: "mhlo.fft"(%[[PADDED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<8xf32>
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<5xcomplex<f32>>
  return %0 : tensor<5xcomplex<f32>>
}

// CHECK-LABEL: func @rfft_1D_sliced
func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x5xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
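  // The innermost dimension is cropped from 9 down to the requested
  // fft_length of 8 before the FFT.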
  // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x9xf32>) -> tensor<2x8xf32>
  // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<2x8xf32>
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x5xcomplex<f32>>
  return %0 : tensor<2x5xcomplex<f32>>
}

// CHECK-LABEL: func @irfft_1D
func @irfft_1D(%arg0: tensor<8xcomplex<f32>>) -> tensor<8xf32> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
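  // An IRFFT of length 8 only consumes fft_length / 2 + 1 = 5 input
  // coefficients, so the input is sliced down to 5 before the FFT.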
  // CHECK: %[[SLICED:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<5> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<8xcomplex<f32>>) -> tensor<5xcomplex<f32>>
  // CHECK: "mhlo.fft"(%[[SLICED]]) {fft_length = dense<8> : tensor<1xi64>, fft_type = "IRFFT"} : (tensor<5xcomplex<f32>>
  %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex<f32>>, tensor<1xi32>) -> tensor<8xf32>
  return %0 : tensor<8xf32>
}

// CHECK-LABEL: fft_1D_dynamic
func @fft_1D_dynamic(%arg0: tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>> {
  // CHECK: "tf.FFT"
  %0 = "tf.FFT"(%arg0) : (tensor<?xcomplex<f32>>) -> tensor<8xcomplex<f32>>
  return %0 : tensor<8xcomplex<f32>>
}

// CHECK-LABEL: rfft_1D_dynamic
func @rfft_1D_dynamic(%arg0: tensor<?xf32>) -> tensor<5xcomplex<f32>> {
  %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  // CHECK: "tf.RFFT"
  %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<?xf32>, tensor<1xi32>) -> tensor<5xcomplex<f32>>
  return %0 : tensor<5xcomplex<f32>>
}

1918
1919//===----------------------------------------------------------------------===//
1920// Shape op legalization.
1921//===----------------------------------------------------------------------===//
1922
1923// CHECK-LABEL: func @shape_1D
1924func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
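  // tf.Shape is lowered through the shape dialect: shape_of, materialized as
  // an extent tensor, then index_cast to the requested element type.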
  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
  // CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
  // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
  %0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>

  // CHECK: return [[CAST]]
  return %0 : tensor<1xi32>
}

// CHECK-LABEL: func @shape_2D
func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
  // CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
  // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
  %0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>

  // CHECK: return [[CAST]]
  return %0 : tensor<2xi32>
}

// CHECK-LABEL: func @shape_rankless
func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
  // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
  // CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
  // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
  %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>

  // CHECK: return [[CAST]]
  return %0 : tensor<?xi32>
}

//===----------------------------------------------------------------------===//
// Transpose op legalization.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @transpose_noop
func @transpose_noop(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %permutation = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> (tensor<2xi64>)
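  // An identity permutation folds the transpose away entirely.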
  // CHECK: return %arg0
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

// CHECK-LABEL: @transpose_2d
func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
  %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32>
  return %0 : tensor<3x2xf32>
}

// CHECK-LABEL: @transpose_3d_int32
func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
  %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32>
  return %0 : tensor<3x2x1xf32>
}

// CHECK-LABEL: @transpose_3d
func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
  %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32>
  return %0 : tensor<3x2x1xf32>
}

// CHECK-LABEL: @transpose_dynamic_2d
func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> {
  %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<?x4xf32>, tensor<2xi64>) -> tensor<4x?xf32>
  return %0 : tensor<4x?xf32>
}

// CHECK-LABEL: @transpose_unranked_2d
func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  // CHECK: "mhlo.transpose"
  %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<*xf32>, tensor<2xi64>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}


//===----------------------------------------------------------------------===//
// Unary op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @abs
func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @abs_dynamic
func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @abs_unranked
func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @acos
// CHLO-LABEL: @acos
func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  chlo.acos %arg0 : tensor<2xf32>
// CHLO:   %[[VAL_1:.*]] = "mhlo.compare"({{.*}}) {comparison_direction = "NE"}
// CHLO:   %[[VAL_5:.*]] = mhlo.multiply %arg0, %arg0
// CHLO:   %[[VAL_4:.*]] = mhlo.constant dense<1.000000e+00>
// CHLO:   %[[VAL_6:.*]] = mhlo.subtract %[[VAL_4]], %[[VAL_5]]
// CHLO:   %[[VAL_7:.*]] = "mhlo.sqrt"(%[[VAL_6]])
// CHLO:   %[[VAL_8:.*]] = mhlo.constant dense<1.000000e+00>
// CHLO:   %[[VAL_9:.*]] = mhlo.add %[[VAL_8]], %arg0
// CHLO:   %[[VAL_10:.*]] = mhlo.atan2 %[[VAL_7]], %[[VAL_9]]
// CHLO:   %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00>
// CHLO:   %[[VAL_11:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_10]]
// CHLO:   %[[VAL_12:.*]] = mhlo.constant dense<3.14159274>
// CHLO:   %[[VAL_13:.*]] = "mhlo.select"(%[[VAL_1]], %[[VAL_11]], %[[VAL_12]])
// CHLO:       return %[[VAL_13]] : tensor<2xf32>
  %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @acos_complex
// CHLO-LABEL: @acos_complex
func @acos_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  // CHLO: tf.Acos
  %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  return %0 : tensor<2xcomplex<f32>>
}

// CHECK-LABEL: @acos_dynamic
// CHLO-LABEL: @acos_dynamic
func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  chlo.acos %arg0 : tensor<*xf32>
  // The lowering of `tf.Acos` introduces `chlo.constant_like` operations,
  // which can only be lowered further on ranked tensors. Unranked CHLO must
  // be transformed to ranked code before further lowering.
  // CHLO: "tf.Acos"
  %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @tan
// CHECK-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32>
// CHLO-LABEL: @tan
// CHLO-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32>
func @tan(%arg : tensor<2xf32>) -> tensor<2xf32> {
  // CHECK: chlo.tan %[[ARG]] : tensor<2xf32>
  // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]])
  // CHLO: %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]])
  // CHLO: %[[RESULT:.*]] = mhlo.divide %[[SINE]], %[[COSINE]]
  %result = "tf.Tan"(%arg) : (tensor<2xf32>) -> tensor<2xf32>
  return %result : tensor<2xf32>
}

// CHECK-LABEL: @tan_unranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
// CHLO-LABEL: @tan_unranked
// CHLO-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32>
func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: chlo.tan %[[ARG]] : tensor<*xf32>
  // CHLO: %[[SINE:.*]] = "mhlo.sine"(%[[ARG]])
  // CHLO: %[[COSINE:.*]] = "mhlo.cosine"(%[[ARG]])
  // CHLO: %[[RESULT:.*]] = mhlo.divide %[[SINE]], %[[COSINE]]
  %result = "tf.Tan"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
  return %result : tensor<*xf32>
}

// CHECK-LABEL: @sinh_complex
// CHLO-LABEL: @sinh_complex
func @sinh_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  // CHLO: tf.Sinh
  %0 = "tf.Sinh"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  return %0 : tensor<2xcomplex<f32>>
}

// CHECK-LABEL: func @cast_dynamic_i2f
func @cast_dynamic_i2f(%arg0: tensor<?xi32>) -> tensor<?xf32> {
  // CHECK: "mhlo.convert"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
  %0 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @cast_i2f
func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> {
  // CHECK: "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
  %0 = "tf.Cast"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @cast_c2f
func @cast_c2f(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
  // CHECK: tf.Cast
  %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @ceil
func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @ceil_dynamic
func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @ceil_unranked
func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @complex_abs
func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
  // CHECK:  "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @cos
func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @cos_dynamic
func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Cos"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @cos_unranked
func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Cos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @exp
func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @expm1
func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.exponential_minus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @exp_dynamic
func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Exp"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @exp_unranked
func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Exp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @floor
func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @floor_dynamic
func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @floor_unranked
func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @invert_op_unranked
func @invert_op_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> {
  // CHECK:  "mhlo.not"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
  %0 = "tf.Invert"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
  return %0 : tensor<*xi32>
}

// CHECK-LABEL: @is_finite
func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> {
  // CHECK:  "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
  %0 = "tf.IsFinite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
  return %0 : tensor<2xi1>
}

// CHECK-LABEL: func @is_finite_dynamic
func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> {
  // CHECK:  "mhlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
  %0 = "tf.IsFinite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
  return %0 : tensor<?xi1>
}

// CHECK-LABEL: func @is_finite_unranked
func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> {
  // CHECK:  "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
  %0 = "tf.IsFinite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
  return %0 : tensor<*xi1>
}

// CHECK-LABEL: @log
func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @log_dynamic
func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @log_unranked
func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @log1p
func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @log1p_dynamic
func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Log1p"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @log1p_unranked
func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Log1p"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @not_op_unranked
func @not_op_unranked(%arg0: tensor<*xi1>) -> tensor<*xi1> {
  // CHECK:  "mhlo.not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1>
  %0 = "tf.LogicalNot"(%arg0) : (tensor<*xi1>) -> tensor<*xi1>
  return %0 : tensor<*xi1>
}

// CHECK-LABEL: @neg
func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @neg_dynamic
func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Neg"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @neg_unranked
func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: @sigmoid
func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
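  // Note: the lowering uses the identity
  // sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5, broadcasting the 0.5 constant to
  // the (possibly dynamic) input shape first.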
  // CHECK-DAG: [[SCALAR:%.+]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
  // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<2xf32>
  // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]]
  // CHECK-DAG: [[HALF:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<2xf32>
  // CHECK-DAG: [[R1:%.+]] =  mhlo.multiply %arg0, [[HALF]] : tensor<2xf32>
  // CHECK-DAG: [[R2:%.+]] =  "mhlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32>
  // CHECK-DAG: [[R3:%.+]] =  mhlo.multiply [[R2]], [[HALF]] : tensor<2xf32>
  // CHECK-DAG: [[R4:%.+]] =  mhlo.add [[R3]], [[HALF]] : tensor<2xf32>
  %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @sigmoid_complex
func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  // CHECK: [[R0:%.+]] = mhlo.constant dense<(5.000000e-01,0.000000e+00)> : tensor<complex<f32>>
  // CHECK-NOT: tf.Sigmoid
  %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  return %0 : tensor<2xcomplex<f32>>
}

// CHECK-LABEL: @sigmoid_unranked
func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK-DAG: [[SCALAR:%.+]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
  // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<*xf32>
  // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]]
  // CHECK-DAG: [[HALF:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<?xindex>) -> tensor<*xf32>
  // CHECK-DAG: [[R1:%.+]] =  mhlo.multiply %arg0, [[HALF]] : tensor<*xf32>
  // CHECK-DAG: [[R2:%.+]] =  "mhlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32>
  // CHECK-DAG: [[R3:%.+]] =  mhlo.multiply [[R2]], [[HALF]] : tensor<*xf32>
  // CHECK-DAG: [[R4:%.+]] =  mhlo.add [[R3]], [[HALF]] : tensor<*xf32>
  %0 = "tf.Sigmoid"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}


// CHECK-LABEL: @sigmoid_grad
func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK-DAG: [[MUL0:%.+]] =  mhlo.multiply %arg1, %arg0 : tensor<2xf32>
  // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32>
  // CHECK-DAG: [[SUB:%.+]] =  mhlo.subtract [[ONE]], %arg0 : tensor<2xf32>
  // CHECK-DAG: [[MUL1:%.+]] =  mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32>
  // CHECK: return [[MUL1]]
  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: @sigmoid_grad_complex
func @sigmoid_grad_complex(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
  // CHECK-DAG: [[MUL0:%.+]] =  mhlo.multiply %arg1, %arg0 : tensor<2xcomplex<f32>>
  // CHECK-DAG: [[ONE:%.+]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>>
  // CHECK-DAG: [[SUB:%.+]] =  mhlo.subtract [[ONE]], %arg0 : tensor<2xcomplex<f32>>
  // CHECK-DAG: [[MUL1:%.+]] =  mhlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex<f32>>
  // CHECK: return [[MUL1]]
  %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
  return %0 : tensor<2xcomplex<f32>>
}

// CHECK-LABEL: @sin
func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @sin_dynamic
func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Sin"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @sin_unranked
func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Sin"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @rsqrt
func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @rsqrt_dynamic
func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @rsqrt_unranked
func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @sqrt
func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @sqrt_dynamic
func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @sqrt_unranked
func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @tanh
func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @tanh_dynamic
func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @tanh_unranked
func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @bitcast
func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // CHECK:  "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @bitcast_dynamic
func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  // CHECK:  "mhlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

// CHECK-LABEL: func @bitcast_unranked
func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK:  "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @bitcast_same_widths
func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> {
  // CHECK:  "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
  %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}

// CHECK-LABEL: func @bitcast_smaller_input_width
func @bitcast_smaller_input_width(%arg0: tensor<2xi8>) -> tensor<2xi64> {
  // CHECK:  "tf.Bitcast"(%arg0) : (tensor<2xi8>) -> tensor<2xi64>
  %0 = "tf.Bitcast"(%arg0) : (tensor<2xi8>) -> tensor<2xi64>
  return %0 : tensor<2xi64>
}

// CHECK-LABEL: func @bitcast_smaller_output_width
func @bitcast_smaller_output_width(%arg0: tensor<2xf32>) -> tensor<2xf16> {
  // CHECK:  "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf16>
  %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf16>
  return %0 : tensor<2xf16>
}

// CHECK-LABEL: reshape
func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<2x1xf32> {
  // CHECK:  "mhlo.reshape"
  %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<2x1xf32>
  return %0 : tensor<2x1xf32>
}

// CHECK-LABEL: reshape_dynamic
func @reshape_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
  // CHECK:  "mhlo.dynamic_reshape"
  %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: reshape_unranked
func @reshape_unranked(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor<?x?xf32> {
  // CHECK:  "tf.Reshape"
  %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<*xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: squeeze
func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> {
  // CHECK: "mhlo.reshape"
  %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32>
  return %0 : tensor<1x10xf32>
}

// CHECK-LABEL: squeeze_dynamic
func @squeeze_dynamic(%arg0: tensor<?x10xf32>) -> tensor<*xf32> {
  // CHECK: "tf.Squeeze"
  %0 = "tf.Squeeze"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// CHECK-LABEL: expand_dims
func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor<i32>) -> tensor<1x2xf32> {
  // CHECK: "mhlo.reshape"
  %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor<i32>) -> tensor<1x2xf32>
  return %0 : tensor<1x2xf32>
}

// CHECK-LABEL: expand_dims_dynamic
func @expand_dims_dynamic(%arg0: tensor<?x?xf32>) -> tensor<?x1x?xf32> {
  %axis = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> (tensor<i32>)

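  // The result shape is assembled from the input extents with a unit extent
  // inserted at the axis, then fed to mhlo.dynamic_reshape.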
  // CHECK-DAG: [[SHAPEOF:%.+]] = shape.shape_of %arg0
  // CHECK-DAG: [[CST0:%.+]] = constant 0
  // CHECK-DAG: [[CST1:%.+]] = constant 1
  // CHECK-DAG: [[GETEXTENT0:%.+]] = shape.get_extent [[SHAPEOF]], [[CST0]]
  // CHECK-DAG: [[CST1_0:%.+]] = constant 1
  // CHECK-DAG: [[GETEXTENT1:%.+]] = shape.get_extent [[SHAPEOF]], [[CST1_0]]
  // CHECK-DAG: [[FROMEXTENTS:%.+]] = shape.from_extents [[GETEXTENT0]], [[CST1]], [[GETEXTENT1]]
  // CHECK-DAG: [[TOEXTENTS:%.+]] = shape.to_extent_tensor [[FROMEXTENTS]]
  // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.dynamic_reshape"(%arg0, [[TOEXTENTS]])
  %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?xf32>, tensor<i32>) -> tensor<?x1x?xf32>

  // CHECK: return [[RESHAPE]]
  return %0 : tensor<?x1x?xf32>
}

// CHECK-LABEL: expand_dynamic_dims_rank1_axis
func @expand_dynamic_dims_rank1_axis(%arg0: tensor<?x?x4xf32>) -> tensor<?x1x?x4xf32> {
  %axis = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>

  // CHECK-DAG: [[SHAPEOF:%.+]] = shape.shape_of %arg0
  // CHECK-DAG: [[CST0:%.+]] = constant 0
  // CHECK-DAG: [[CST1:%.+]] = constant 1
  // CHECK-DAG: [[GETEXTENT0:%.+]] = shape.get_extent [[SHAPEOF]], [[CST0]]
  // CHECK-DAG: [[CST1_0:%.+]] = constant 1
  // CHECK-DAG: [[GETEXTENT1:%.+]] = shape.get_extent [[SHAPEOF]], [[CST1_0]]
  // CHECK-DAG: [[CST2:%.+]] = constant 2
  // CHECK-DAG: [[GETEXTENT2:%.+]] = shape.get_extent [[SHAPEOF]], [[CST2]]
  // CHECK-DAG: [[FROMEXTENTS:%.+]] = shape.from_extents [[GETEXTENT0]], [[CST1]], [[GETEXTENT1]], [[GETEXTENT2]]
  // CHECK-DAG: [[TOEXTENTS:%.+]] = shape.to_extent_tensor [[FROMEXTENTS]]
  // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.dynamic_reshape"(%arg0, [[TOEXTENTS]])
  %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<?x?x4xf32>, tensor<1xi32>) -> tensor<?x1x?x4xf32>

  // CHECK: return [[RESHAPE]]
  return %0 : tensor<?x1x?x4xf32>
}

// CHECK-LABEL: func @sign
// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32>
func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
  // CHECK: [[SIGN:%.*]] = "mhlo.sign"([[ARG]])
  // CHECK: return [[SIGN]] : tensor<1x2x3x4xf32>
  %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>)
  return %0 : tensor<1x2x3x4xf32>
}

// CHECK-LABEL: slice_constant_start
func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
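  // Even with a constant start, tf.Slice goes through mhlo.dynamic-slice: the
  // start vector is converted to i64 and each per-dimension start is sliced
  // out and reshaped to a scalar index.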
  // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi64>
  // CHECK: %[[CAST:.*]] = tensor.cast %[[START]] : tensor<1xi64> to tensor<1xi64>
  // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[CAST]]) : (tensor<1xi64>) -> tensor<1xi64>
  // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]])
  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START:.*]]) :
  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]])
  // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} :
  // CHECK-DAG-SAME: (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
  // CHECK: return %[[RESULT]] : tensor<2xi32>
  %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
  %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}

// CHECK-LABEL: slice_i32_consts
func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> {
  // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi32>
  // CHECK: %[[START_CAST:.*]] = tensor.cast %[[START]] : tensor<1xi32> to tensor<1xi32>
  // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START_CAST]]) : (tensor<1xi32>) -> tensor<1xi64>
  // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]])
  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor<i64>
  // CHECK: "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
  %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>)
  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}

// CHECK-LABEL: slice_constant_start_negative_one_size
func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> {
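  // A size of -1 means "to the end of this dimension", so the slice size
  // becomes 4 - 1 = 3 here.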
  // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor<1xi64>
  // CHECK: %[[START_CAST:.*]] = tensor.cast %[[START]] : tensor<1xi64> to tensor<1xi64>
  // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START_CAST]]) : (tensor<1xi64>) -> tensor<1xi64>
  // CHECK: %[[SLICED_START:.*]] = "mhlo.slice"(%[[START_I64]])
  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START:.*]] = "mhlo.reshape"(%[[SLICED_START]]) : (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[RESULT:.*]] =  "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START]]) {slice_sizes = dense<3> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<3xi32>
  // CHECK: return %[[RESULT]] : tensor<3xi32>
  %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
  %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32>
  return %0 : tensor<3xi32>
}

// CHECK-LABEL: slice_constant_start_dynamic_shape
func @slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
  // CHECK: %[[START:.*]] = mhlo.constant dense<[1, 0]> : tensor<2xi64>
  // CHECK: %[[START_CAST:.*]] = tensor.cast %[[START]] : tensor<2xi64> to tensor<2xi64>
  // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%[[START_CAST]]) : (tensor<2xi64>) -> tensor<2xi64>
  // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%[[START_I64]])
  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
  // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START1:.*]] = "mhlo.reshape"(%[[SLICED_START1]]) :
  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%[[START_I64]])
  // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>,
  // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>,
  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} :
  // CHECK-DAG-SAME: (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START2:.*]] = "mhlo.reshape"(%[[SLICED_START2]]) :
  // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"
  // CHECK-DAG-SAME: (%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]])
  // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} :
  // CHECK-DAG-SAME: (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
  %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<?x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
  return %0 : tensor<1x4xi32>
}

// CHECK-LABEL: slice_variable_start
func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
  // CHECK: %[[START_I64:.*]] = "mhlo.convert"(%arg1) : (tensor<2xi64>) -> tensor<2xi64>
  // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%[[START_I64]])
  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START1:.*]] = "mhlo.reshape"(%[[SLICED_START1]]) : (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%[[START_I64]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
  // CHECK: %[[RESHAPED_START2:.*]] = "mhlo.reshape"(%[[SLICED_START2]]) : (tensor<1xi64>) -> tensor<i64>
  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic-slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
  %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
  %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
  return %0 : tensor<1x4xi32>
}

// CHECK-LABEL: slice_mhlo_sizes
func @slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> {
  // CHECK-NOT: "tf.Slice"
  %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
  %1 = "tf.Slice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32>
  return %1 : tensor<1x512x4xf32>
}

// CHECK-LABEL: slice_variable_start_negative_one_size
func @slice_variable_start_negative_one_size(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
2716  // CHECK: %[[RESULT:.*]] = "tf.Slice"
2717  // CHECK: return %[[RESULT]] : tensor<1x4xi32>
2718  %sizes = "tf.Const"() {value = dense<[1, -1]> : tensor<2xi64>} : () -> (tensor<2xi64>)
2719  %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
2720  return %0 : tensor<1x4xi32>
2721}
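
// A size of -1 in tf.Slice means "everything from `start` to the end of that
// dimension", i.e. size_d = dim_size_d - start_d. With a variable start this
// is not a compile-time constant, so the op above intentionally stays as
// tf.Slice instead of being lowered.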
2722
2723//===----------------------------------------------------------------------===//
2724// StridedSlice op legalizations.
2725//===----------------------------------------------------------------------===//
2726
2727// CHECK-LABEL: simple_strided_slice
2728func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> {
2729  %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2730  %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2731  %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2732
2733  // CHECK: mhlo.slice
2734  // CHECK-DAG-SAME: start_indices = dense<[0, 1]>
2735  // CHECK-DAG-SAME: limit_indices = dense<[3, 7]>
2736  // CHECK-DAG-SAME: strides = dense<[1, 3]>
2737  // CHECK-SAME: -> tensor<3x2xf32>
2738
2739  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
2740      : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32>
2741  return %output : tensor<3x2xf32>
2742}
2743
2744// CHECK-LABEL: dynamic_strided_slice
2745func @dynamic_strided_slice(%input: tensor<?x8xf32>) -> tensor<?x2xf32> {
2746  %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2747  %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2748  %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2749
2750  // CHECK: "tf.StridedSlice"
2751  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
2752      : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32>
2753  return %output : tensor<?x2xf32>
2754}
2755
2756// CHECK-LABEL: strided_slice_negative_indices
2757func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> {
2758  %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2759  %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2760  %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2761
2762  // CHECK: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>}
2763
2764  // CHECK: mhlo.slice
2765  // CHECK-DAG-SAME: start_indices = dense<[0, 1]>
2766  // CHECK-DAG-SAME: limit_indices = dense<[3, 7]>
2767  // CHECK-DAG-SAME: strides = dense<[1, 3]>
2768  // CHECK-SAME: -> tensor<3x2xf32>
2769
2770  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
2771      : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32>
2772  return %output : tensor<3x2xf32>
2773}
2774
2775// CHECK-LABEL: dynamic_strided_slice_negative_indices
2776func @dynamic_strided_slice_negative_indices(%input: tensor<?x8xf32>) -> tensor<?x2xf32> {
2777  %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2778  %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2779  %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2780
2781  // CHECK: tf.StridedSlice
2782  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
2783      : (tensor<?x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x2xf32>
2784  return %output : tensor<?x2xf32>
2785}
2786
2787// CHECK-LABEL: strided_slice_range_clamping
2788func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<1x3xf32> {
2789  %begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2790  %end = "tf.Const"() {value = dense<[1, 10]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2791  %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
2792
2793  // CHECK: mhlo.slice
2794  // CHECK-DAG-SAME: start_indices = dense<[0, 0]>
2795  // CHECK-DAG-SAME: limit_indices = dense<[1, 8]>
2796  // CHECK-DAG-SAME: strides = dense<[1, 3]>
2797  // CHECK-SAME: -> tensor<1x3xf32>
2798  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
2799      : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32>
2800  return %output : tensor<1x3xf32>
2801}
2802
2803// CHECK-LABEL: strided_slice_empty
2804func @strided_slice_empty(%input: tensor<4xf32>) -> tensor<0xf32> {
2805  %begin = "tf.Const"() {value = dense<[-4]> : tensor<1xi32>} : () -> (tensor<1xi32>)
2806  %end = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
2807  %strides = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>)
2808
2809  // CHECK: mhlo.constant dense<> : tensor<0xf32>
2810  %output = "tf.StridedSlice"(%input, %begin, %end, %strides)
2811      : (tensor<4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf32>
2812  return %output : tensor<0xf32>
2813}
2814
2815// CHECK-LABEL: strided_slice_begin_end_mask
2816// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<4x128x1024xf32>
2817func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) {
2818
2819  // For StridedSlice
2820  // Dim #:        0,   1,    2
2821  // Input shape: [4, 128, 1024]
2822  // Begin:        1,   4,   -3
2823  // End:          8,  65,   42
2824  // Stride:       1,   4,   -1
  // Begin mask:   1,   0,    0  (= 1)
  // End mask:     0,   0,    1  (= 4)
2827
2828  // So result shape:
2829  // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4
2830  // Dim #1: 4 to 65 stride 4: so 16
2831  // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022
2832  // result shape: [4, 16, 1022]
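  //
  // For reference, the mask encoding: bit i of begin_mask/end_mask applies
  // to dim #i, so begin_mask = 0b001 = 1 covers dim #0 and
  // end_mask = 0b100 = 4 covers dim #2.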
2833
2834  %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2835  %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2836  %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2837
2838  // CHECK: %[[REVERSE:.*]] = "mhlo.reverse"(%[[INPUT]])
2839
2840  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[REVERSE]])
2841  // CHECK-DAG-SAME: limit_indices = dense<[4, 65, 1024]>
2842  // CHECK-DAG-SAME: start_indices = dense<[0, 4, 2]>
2843  // CHECK-DAG-SAME: strides = dense<[1, 4, 1]>
2844  // CHECK-SAME: -> tensor<4x16x1022xf32>
2845
2846  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x16x1022xf32>
2847
2848  // CHECK: "mhlo.reshape"(%[[SLICE]])
2849  // CHECK-SAME: -> tensor<4x16x1022xf32>
2850
2851  return
2852}
2853
2854// CHECK-LABEL: strided_slice_shrink_axis_mask
2855// CHECK-SAME: %[[INPUT:.+]]: tensor<4x128x1024xf32>
2856func @strided_slice_shrink_axis_mask(%input: tensor<4x128x1024xf32>) {
2857
2858  // For StridedSlice
2859  // Dim #:            0,   1,    2
2860  // Input shape:     [4, 128, 1024]
2861  // Begin:            1,   4,   -3
2862  // End:              8,  65,   42
2863  // Stride:           1,   4,   -1
2864  // Begin mask:       1,   0,    0  (= 1)
2865  // End mask:         0,   0,    1  (= 4)
2866  // Shrink axis mask: 1,   0,    1  (= 5)
2867
2868  // So result shape:
2869  // Dim #0: shrink axis, take value at [1]
2870  // Dim #1: 4 to 65 stride 4: so 16
2871  // Dim #2: shrink axis, take value at [-3]
2872  // result shape: [16]
2873
  // As the output shape of StridedSlice differs, a reshape will follow.
2875
2876  %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2877  %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2878  %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)
2879
2880  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
2881  // CHECK-DAG-SAME: limit_indices = dense<[1, 65, 1022]>
2882  // CHECK-DAG-SAME: start_indices = dense<[0, 4, 1021]>
2883  // CHECK-DAG-SAME: strides = dense<[1, 4, 1]>
2884  // CHECK-SAME: -> tensor<1x16x1xf32>
2885
2886  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4, shrink_axis_mask = 5} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<16xf32>
2887
2888  // CHECK: "mhlo.reshape"(%[[SLICE]])
2889  // CHECK-SAME: -> tensor<16xf32>
2890
2891  return
2892}
2893
2894// CHECK-LABEL: strided_slice_ellipsis_mask
2895// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32>
2896func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
  // For StridedSlice input[1, ..., 8:, :10, 2:6:2]
  // The ellipsis mask is applied to dims #1 and #2, i.e., we get the
  // canonicalized slice input[1, :, :, 8:, :10, 2:6:2]
2900
  // The start, limit indices, and strides attributes of mhlo.slice reflect
  // the canonicalized slice.
  // As the output shape of StridedSlice differs, a reshape will follow.
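  //
  // Rough rank arithmetic for the expansion: the sparse spec has 5 entries,
  // one of which is the ellipsis, so the ellipsis has to cover
  // 6 - (5 - 1) = 2 dimensions of the rank-6 input (dims #1 and #2).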
2904
2905  %begin = "tf.Const"() {value = dense<[1, 0, 8, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>)
2906  %end = "tf.Const"() {value = dense<[2, 0, 10, 10, 6]> : tensor<5xi32>} : () -> (tensor<5xi32>)
2907  %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>)
2908
2909  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
2910  // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64>
2911  // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
  // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensor<6xi64>
2913  // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32>
2914  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 8, end_mask = 4, shrink_axis_mask = 1, ellipsis_mask = 2} : (tensor<2x4x8x16x32x64xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<4x8x8x10x2xf32>
2915
2916  // CHECK: "mhlo.reshape"(%[[SLICE]])
2917  // CHECK-SAME: -> tensor<4x8x8x10x2xf32>
2918
2919  return
2920}
2921
2922// CHECK-LABEL: strided_slice_new_axis_mask
2923// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32>
2924func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) {
2925  // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis]
  // The new axis mask bits are at indices 1 and 6 of the sparse spec, so
  // new_axis_mask = 2^1 + 2^6 = 66
  // The ellipsis mask is applied to dims #1 and #2 of the input, i.e., we get
  // the canonicalized slice input[1, :, :, 8:, :10, 2:6:2]
2930  // This is then reshaped to add the new axes.
2931
  // The start, limit indices, and strides attributes of mhlo.slice reflect
  // the canonicalized slice.
  // As the output shape of StridedSlice differs, a reshape will follow to
  // reflect the new axes that were added.
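  //
  // Rough rank arithmetic: 6 input dims - 1 shrunk dim + 2 new axes gives a
  // rank-7 result, matching the 1x4x8x8x10x2x1 type of the final reshape.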
2936
2937  %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
2938  %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
2939  %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>)
2940
2941  // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]])
2942  // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64>
2943  // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
  // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensor<6xi64>
2945  // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32>
2946  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<2x4x8x16x32x64xf32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>) -> tensor<1x4x8x8x10x2x1xf32>
2947
2948  // CHECK: "mhlo.reshape"(%[[SLICE]])
2949  // CHECK-SAME: -> tensor<1x4x8x8x10x2x1xf32>
2950
2951  return
2952}
2953
2954// CHECK-LABEL: strided_slice_implicit_ellipsis_mask(
2955// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32>
2956func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> {
  // StridedSlice gets input[8:10], which is the same as input[8:10, ...]
  // The start_indices, limit_indices, and strides attributes of mhlo.slice
  // reflect the canonicalized slice.
2960  %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32>
2961  %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
2962  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
2963  // CHECK: [[SLICE:%.*]] = "mhlo.slice"([[INPUT]])
2964  // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64>
2965  // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64>
2966  // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64>
2967  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[SLICE]]) : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32>
2968  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32>
2969  // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32>
2970  return %0 : tensor<2x16x2xf32>
2971}
2972
2973// CHECK-LABEL: strided_slice_nonconstant_begin_end
2974func @strided_slice_nonconstant_begin_end(%arg0: tensor<i32>, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) {
  // In this case, the `begin` and `end` inputs are unknown at compile time,
  // so the StridedSlice lowering needs to slice these vectors and use the
  // results as inputs to an HLO dynamic slice.
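  // The lowering also wraps negative start indices into range, effectively
  //   index' = index < 0 ? index + dim_size : index
  // which is what the broadcast_compare / broadcast_add / select sequence
  // below computes.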
2978  %begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
2979  %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
2980  %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
2981  %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
2982  %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
2983  // CHECK: %[[A:.*]] = "mhlo.reshape"(%arg0) : (tensor<i32>) -> tensor<1xi32>
2984  // CHECK-NEXT: %[[BEGIN:.*]] = "mhlo.concatenate"(%[[A]])
2985  // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32>
2986  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
2987  // CHECK-NEXT: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]])
2988  // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>,
2989  // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
2990  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
2991  // CHECK-NEXT: %[[INDEX2:.*]] = "mhlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor<i32>
2992  // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]]
2993  // CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
2994  // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor<i32>
2995  // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
2996  // CHECK-NEXT: %[[INDEX3:.*]] = "mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) :
2997  // CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
2998  // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic-slice"
2999  // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]])
3000  // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} :
3001  // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x97xi32>
3002  // CHECK-NEXT: %[[FINAL:.*]] = "mhlo.reshape"(%[[SLICED]]) : (tensor<1x97xi32>) -> tensor<1x97xi32>
3003  %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3004  // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32>
3005  return %result : tensor<1x97xi32>
3006}
3007
3008// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1
3009func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // Dynamic stride: when the `begin` and `end` inputs are unknown at compile
  // time, `strides` must be a compile-time constant; here it is not, so the
  // op is not legalized.
3012  // CHECK: tf.StridedSlice
3013  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3014  return %result : tensor<1x97xi32>
3015}
3016
3017// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2
3018func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
  // Invalid stride (not equal to 1): when the `begin` and `end` inputs are
  // unknown at compile time, `strides` must be known to be all ones.
3021  %strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
3022  // CHECK: tf.StridedSlice
3023  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3024  return %result : tensor<1x97xi32>
3025}
3026
3027// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count
3028func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> {
3029  %strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64>
3030  // When begin/end are dynamic, the number of output elements must be equal to
3031  // the number of input elements sliced.
3032  // CHECK: tf.StridedSlice
3033  %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32>
3034  return %0 : tensor<6x10xf32>
3035}
3036
3037// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_begin_mask
3038func @strided_slice_nonconstant_begin_end_and_begin_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3039  // Begin mask: When `begin` and `end` inputs are unknown at compile time, we
3040  // can't support a begin mask.
3041  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3042  // CHECK: tf.StridedSlice
3043  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3044  return %result : tensor<1x97xi32>
3045}
3046
3047// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_end_mask
3048func @strided_slice_nonconstant_begin_end_and_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3049  // End mask: When `begin` and `end` inputs are unknown at compile time, we
3050  // can't support an end mask.
3051  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3052  // CHECK: tf.StridedSlice
3053  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3054  return %result : tensor<1x97xi32>
3055}
3056
3057// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_new_axis_mask
3058func @strided_slice_nonconstant_begin_end_and_new_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3059  // New axis mask: When `begin` and `end` inputs are unknown at compile time,
3060  // we can't support a new_axis mask.
3061  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3062  // CHECK: tf.StridedSlice
3063  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 15 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3064  return %result : tensor<1x97xi32>
3065}
3066
3067// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask
3068func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3069  // This ellipsis mask is not supported because it does not refer to the last
3070  // dimension.
3071  // [0, 1, 0] = 2
3072  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3073  // CHECK: tf.StridedSlice
3074  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3075  return %result : tensor<1x97xi32>
3076}
3077
3078// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask
3079func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3080  // This ellipsis mask is supported because it refers to the last dimension.
  // [0, 0, 1] = 4
3082  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3083  // CHECK: mhlo.dynamic-slice
3084  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3085  return %result : tensor<1x97xi32>
3086}
3087
3088// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask
3089func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3090  // This shrink_axis mask is supported because it refers to a major dimension.
3091  // [1, 1, 1] = 7
3092  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3093  // CHECK: mhlo.dynamic-slice
3094  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3095  return %result : tensor<1x97xi32>
3096}
3097
3098// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask
3099func @strided_slice_nonconstant_begin_end_and_invalid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) {
3100  // This shrink_axis mask is unsupported because it does not refer to a major
3101  // dimension.
3102  // [0, 1, 0] = 2
3103  %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3104  // CHECK: tf.StridedSlice
3105  %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32>
3106  return %result : tensor<1x97xi32>
3107}
3108
3110//===----------------------------------------------------------------------===//
3111// Reduction op legalizations.
3112//===----------------------------------------------------------------------===//
3113
3114// CHECK-LABEL: func @mean
3115func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3116  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32>
3117  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
3118  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
3119  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>):
3120  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32>
3121  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
3122  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
3123  // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<8.000000e+00> : tensor<f32>
3124  // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
3125  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16>
3126  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
3127  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3128  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3129  %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3130  return %0 : tensor<4x1xf16>
3131}
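
// For reference, the mean above is computed as sum / N, where N = 8 is the
// extent of the reduced dimension (hence the dense<8.000000e+00> divisor),
// and the f16 input is widened to f32 so that the accumulation runs in f32.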
3132
3133// CHECK-LABEL: func @mean_scalar_dim
3134func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
  // Verify that a tf.Mean op with a scalar dimension operand is lowered
  // successfully.
3136
3137  // CHECK-NOT: tf.Mean
3138  %dimension = "tf.Const"() { value = dense<1> : tensor<i64> } : () -> tensor<i64>
3139  %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<i64>) -> tensor<4x1xf16>
3140  return %0 : tensor<4x1xf16>
3141}
3142
3143// CHECK-LABEL: func @mean_dynamic
3144func @mean_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
3145  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3146  // CHECK: "tf.Mean"
3147  %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3148  return %0 : tensor<4x1xf16>
3149}
3150
3151// CHECK-LABEL: func @sum
3152func @sum(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3153  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32>
3154  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
3155  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
3156  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>):
3157  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32>
3158  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
3159  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
3160  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16>
3161  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
3162  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3163  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3164  %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3165  return %0 : tensor<4x1xf16>
3166}
3167
3168// CHECK-LABEL: func @sum_dynamic
3169func @sum_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
3170    // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf32>
3171    // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
3172    // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
3173    // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>):
3174    // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.add %[[ARGA]], %[[ARGB]] : tensor<f32>
3175    // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
3176    // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x?xf32>, tensor<f32>) -> tensor<4xf32>
3177    // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16>
3178    // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
3179    // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3180  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3181  %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3182  return %0 : tensor<4x1xf16>
3183}
3184
3185// CHECK-LABEL: func @max
3186func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3187  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16>
3188  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16>
3189  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
3190  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>):
3191  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.maximum %[[ARGA]], %[[ARGB]] : tensor<f16>
3192  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> ()
3193  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16>
3194  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16>
3195  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
3196  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3197  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3198  %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3199  return %0 : tensor<4x1xf16>
3200}
3201
3202// CHECK-LABEL: func @max_qint
3203// Regression test to ensure we don't crash getting the initial value for
3204// tf.Max when using quantized integer types.
3205func @max_qint(%arg0: tensor<4x8x!tf.qint8>) -> tensor<4x1x!tf.qint8> {
3206  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3207  %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf.qint8>, tensor<1xi64>) -> tensor<4x1x!tf.qint8>
3208  return %0 : tensor<4x1x!tf.qint8>
3209}
3210
3211// CHECK-LABEL: func @max_dynamic
3212func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> {
3213    // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x?xf16>) -> tensor<4x?xf16>
3214    // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor<f16>
3215    // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
3216    // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>):
3217    // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.maximum %[[ARGA]], %[[ARGB]] : tensor<f16>
3218    // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> ()
3219    // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x?xf16>, tensor<f16>) -> tensor<4xf16>
3220    // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16>
3221    // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
3222    // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3223  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3224  %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3225  return %0 : tensor<4x1xf16>
3226}
3227
3228// CHECK-LABEL: func @min
3229func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3230  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf16>
3231  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0x7C00> : tensor<f16>
3232  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
3233  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f16>, %[[ARGB:.*]]: tensor<f16>):
3234  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.minimum %[[ARGA]], %[[ARGB]] : tensor<f16>
3235  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f16>) -> ()
3236  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf16>, tensor<f16>) -> tensor<4xf16>
3237  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf16>) -> tensor<4xf16>
3238  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
3239  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3240  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3241  %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3242  return %0 : tensor<4x1xf16>
3243}
3244
3245// CHECK-LABEL: func @min_qint
3246// Regression test to ensure we don't crash getting the initial value for
3247// tf.Min when using quantized integer types.
3248func @min_qint(%arg0: tensor<4x8x!tf.qint8>) -> tensor<4x1x!tf.qint8> {
3249  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3250  %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf.qint8>, tensor<1xi64>) -> tensor<4x1x!tf.qint8>
3251  return %0 : tensor<4x1x!tf.qint8>
3252}
3253
3254// CHECK-LABEL: func @prod
3255func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
3256  // CHECK: %[[CAST:.*]] = "mhlo.convert"(%arg0) : (tensor<4x8xf16>) -> tensor<4x8xf32>
3257  // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
3258  // CHECK: %[[REDUCED:.*]] = "mhlo.reduce"(%[[CAST]], %[[INITIAL]]) ( {
3259  // CHECK: ^bb0(%[[ARGA:.*]]: tensor<f32>, %[[ARGB:.*]]: tensor<f32>):
3260  // CHECK:  %[[REDUCE_BODY_RESULT:.*]] = mhlo.multiply %[[ARGA]], %[[ARGB]] : tensor<f32>
3261  // CHECK:  "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
3262  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
3263  // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[REDUCED]]) : (tensor<4xf32>) -> tensor<4xf16>
3264  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
3265  // CHECK: return %[[RESULT]] : tensor<4x1xf16>
3266  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3267  %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16>
3268  return %0 : tensor<4x1xf16>
3269}
3270
3271// CHECK-LABEL: func @prod_qint
3272// Regression test to ensure we don't crash getting the initial value for
3273// tf.Prod when using quantized integer types.
3274func @prod_qint(%arg0: tensor<4x8x!tf.qint8>) -> tensor<4x1x!tf.qint8> {
3275  %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64>
3276  %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf.qint8>, tensor<1xi64>) -> tensor<4x1x!tf.qint8>
3277  return %0 : tensor<4x1x!tf.qint8>
3278}
3279
3280// CHECK-LABEL: @all
3281func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> {
3282  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3283  // CHECK: %[[INIT:.*]] = mhlo.constant dense<true> : tensor<i1>
3284  // CHECK: "mhlo.reduce"(%{{.*}}, %[[INIT]]) ( {
3285  // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>):
3286  // CHECK:  %[[AND:.*]] = mhlo.and %[[ARGA]], %[[ARGB]] : tensor<i1>
3287  // CHECK:  "mhlo.return"(%[[AND]]) : (tensor<i1>) -> ()
3288  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
3289  %0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1>
3290  return %0 : tensor<4xi1>
3291}
3292
3293// CHECK-LABEL: @all_keep_dim
3294func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> {
3295  // CHECK: "mhlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1>
3296  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3297  %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1>
3298  return %0 : tensor<4x1xi1>
3299}
3300
// CHECK-LABEL: @all_dynamic
3302func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
3303  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3304  // CHECK: %[[ARG:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1>
3305  // CHECK: "mhlo.reduce"(%[[ARG]]
3306  %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1>
3307  return %0 : tensor<4x1xi1>
3308}
3309
3310// CHECK-LABEL: @any
3311func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> {
3312  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3313  // CHECK: %[[INIT:.*]] = mhlo.constant dense<false> : tensor<i1>
3314  // CHECK: "mhlo.reduce"(%{{.*}}, %[[INIT]]) ( {
3315  // CHECK: ^{{.*}}(%[[ARGA:.*]]: tensor<i1>, %[[ARGB:.*]]: tensor<i1>):
  // CHECK:  %[[OR:.*]] = mhlo.or %[[ARGA]], %[[ARGB]] : tensor<i1>
  // CHECK:  "mhlo.return"(%[[OR]]) : (tensor<i1>) -> ()
3318  // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xi1>, tensor<i1>) -> tensor<4xi1>
3319  %0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1>
3320  return %0 : tensor<4xi1>
3321}
3322
3323// CHECK-LABEL: @any_keep_dim
3324func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> {
3325  // CHECK: "mhlo.reshape"(%{{.*}}) : (tensor<4xi1>) -> tensor<4x1xi1>
3326  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3327  %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1>
3328  return %0 : tensor<4x1xi1>
3329}
3330
// CHECK-LABEL: @any_dynamic
3332func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> {
3333  %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
3334  // CHECK: %[[ARG:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<4x?xi1>) -> tensor<4x?xi1>
3335  // CHECK: "mhlo.reduce"(%[[ARG]]
3336  %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1>
3337  return %0 : tensor<4x1xi1>
3338}
3339
3340//===----------------------------------------------------------------------===//
3341// Tile op legalizations.
3342//===----------------------------------------------------------------------===//
3343
3344// CHECK-LABEL: func @tile_by_reshape
3345func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> {
3346  // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32>
3347  // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[BROADCASTED]]) : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32>
3348  // CHECK: return %[[RESULT]] : tensor<28x24xf32>
3349  %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64>
3350  %0 = "tf.Tile"(%arg0, %multiples) : (tensor<4x8xf32>, tensor<2xi64>) -> tensor<28x24xf32>
3351  return %0 : tensor<28x24xf32>
3352}
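
// Rough shape arithmetic for the pattern above: tiling a 4x8 input by
// multiples [7, 3] broadcasts to 7x4x3x8 (one new dimension per multiple)
// and then reshapes to the 28x24 result, since 7 * 4 = 28 and 3 * 8 = 24.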
3353
3354// CHECK-LABEL: func @tile_just_broadcast
3355func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> {
3356  // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<7x3xf32>
3357  // CHECK: return %[[RESULT]] : tensor<7x3xf32>
3358  %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64>
3359  %0 = "tf.Tile"(%arg0, %multiples) : (tensor<1x1xf32>, tensor<2xi64>) -> tensor<7x3xf32>
3360  return %0 : tensor<7x3xf32>
3361}
3362
3363//===----------------------------------------------------------------------===//
3364// ArgMax op legalizations.
3365//===----------------------------------------------------------------------===//
3366
3367// CHECK-LABEL: func @argmax_i64_input_i32_output_axis_0
3368func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> {
3369  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-9223372036854775808> : tensor<i64>
3370  // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
3371  // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x7xi32>
3372  // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]])
3373  // CHECK: ^bb0(%[[ARG1:.*]]: tensor<i64>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i64>, %[[ARG4:.*]]: tensor<i32>):
3374  // CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG1]], %[[ARG3]]) {comparison_direction = "GT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
3375  // CHECK:  %[[RESULT1:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG1]], %[[ARG3]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
3376  // CHECK:  %[[RESULT2:.*]] = "mhlo.select"(%[[COMPARE]], %[[ARG2]], %[[ARG4]]) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
3377  // CHECK: "mhlo.return"(%[[RESULT1]], %[[RESULT2]]) : (tensor<i64>, tensor<i32>) -> ()
3378  // CHECK: return %[[REDUCE]]#1 : tensor<7xi32>
3379  %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
3380  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi64>, tensor<i32>) -> tensor<7xi32>
3381  return %0 : tensor<7xi32>
3382}
3383
3384// CHECK-LABEL: func @argmax_f32_input_i64_output_axis_1
3385func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor<3xi64> {
3386  // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor<f32>
  // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i64>
3388  // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x7xi64>
3389  // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]])
3390  // CHECK: return %[[REDUCE]]#1 : tensor<3xi64>
3391  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
3392  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xf32>, tensor<i32>) -> tensor<3xi64>
3393  return %0 : tensor<3xi64>
3394}
3395
3396// CHECK-LABEL: func @argmax_dynamic_shape_input_output
3397func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor<?xi32> {
3398  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
3399  // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
3400  // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<3x?xi32>
3401  // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]])
3402  // CHECK: return %[[REDUCE]]#1 : tensor<?xi32>
3403  %axis = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
3404  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<?xi32>
3405  return %0 : tensor<?xi32>
3406}
3407
3408// CHECK-LABEL: func @argmax_dynamic_shape_input
3409func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> {
3410  // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor<i32>
3411  // CHECK: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor<i32>
3412  // CHECK: %[[INDEX:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x?xi32>
3413  // CHECK: %[[REDUCE:.*]]:2 = "mhlo.reduce"(%arg0, %[[INDEX]], %[[INIT]], %[[INDEX_INIT]])
3414  // CHECK: return %[[REDUCE]]#1 : tensor<3xi32>
3415  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
3416  %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor<i32>) -> tensor<3xi32>
3417  return %0 : tensor<3xi32>
3418}
3419
3420//===----------------------------------------------------------------------===//
3421// Random op legalizations.
3422//===----------------------------------------------------------------------===//
3423
3424// CHECK-LABEL: func @rng_uniform
3425func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> {
3426  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
3427  // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
3428  // CHECK: %[[CONV:.*]] = "mhlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64>
3429  // CHECK: %[[F32:.*]] = "mhlo.rng_uniform"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32>
3430  %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32>
3431  // CHECK: return %[[F32]]
3432  return %0 : tensor<12x?x64xf32>
3433}
3434
3435// CHECK-LABEL: func @rng_std_normal
3436func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> {
3437  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
3438  // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
3439  // CHECK: %[[CONV:.*]] = "mhlo.convert"(%arg0) : (tensor<3xi32>) -> tensor<3xi64>
3440  // CHECK: %[[F32:.*]] = "mhlo.rng_normal"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*}} -> tensor<12x?x64xf32>
3441  %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32>
3442  // CHECK: return %[[F32]]
3443  return %0 : tensor<12x?x64xf32>
3444}
3445
3446//===----------------------------------------------------------------------===//
3447// Range op legalizations.
3448//===----------------------------------------------------------------------===//
3449
3450// CHECK-LABEL: func @range
3451// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32>
3452func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> {
3453  %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor<f32>} : () -> tensor<f32>
3454  // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"
3455  // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<> : tensor<0xi64>}
3456  // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>}
3457  %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
3458  return %3 : tensor<5xf32>
3459}
3460
3461// CHECK-LABEL: func @range_dynamic
3462// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[DELTA:%.*]]: tensor<f32>
3463func @range_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<?xf32> {
3464  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0
3465  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"([[SUB]])
3466  // CHECK-DAG: [[CONVERT1:%.+]] = "mhlo.convert"([[ABS1]])
3467  // CHECK-DAG: [[CONVERT2:%.+]] = "mhlo.convert"(%arg2)
3468  // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]]
3469  // CHECK-DAG: [[CEIL:%.+]] = "mhlo.ceil"([[DIV]])
3470  // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"([[CEIL]])
3471  // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.reshape"([[CONVERT3]])
3472  // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
3473  // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0)
3474  // CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2)
3475  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>}
3476  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>}
3477  %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
3478
3479  // CHECK: return [[ADD]]
3480  return %2 : tensor<?xf32>
3481}
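
// For reference, the op sequence above computes the dynamic result size as
//   size = ceil(|limit - start| / |delta|)
// and feeds it to mhlo.dynamic_iota before scaling by delta and shifting by
// start.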
3482
3483// CHECK-LABEL: func @range_int_dynamic
3484// CHECK-SAME: [[START:%.*]]: tensor<i32>, [[DELTA:%.*]]: tensor<i32>
3485func @range_int_dynamic(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?xi32> {
3486  // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0
3487  // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"([[SUB]])
3488  // CHECK-DAG: [[CONVERT1:%.+]] = "mhlo.convert"([[ABS1]])
3489  // CHECK-DAG: [[CONVERT2:%.+]] = "mhlo.convert"(%arg2)
3490  // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT1]], [[CONVERT2]]
3491  // CHECK-DAG: [[CEIL:%.+]] = "mhlo.ceil"([[DIV]])
3492  // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"([[CEIL]])
3493  // CHECK-DAG: [[RESHAPE:%.+]] = "mhlo.reshape"([[CONVERT3]])
3494  // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
3495  // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0)
3496  // CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2)
3497  // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<> : tensor<0xi64>}
3498  // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<> : tensor<0xi64>}
3499  %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
3500
3501  // CHECK: return [[ADD]]
3502  return %2 : tensor<?xi32>
3503}
3504
3505// CHECK-LABEL: func @linspace_static
3506// CHECK-SAME: [[START:%.*]]: tensor<f32>, [[STOP:%.*]]: tensor<f32>
3507func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> {
3508  // CHECK-DAG: [[NUM:%.*]] = mhlo.constant dense<4>
3509  // CHECK-DAG: [[NUM_CAST:%.*]] = tensor.cast [[NUM]]
3510  // CHECK-DAG: [[NUM_F32:%.*]] = "mhlo.convert"([[NUM_CAST]])
3511  // CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00>
3512  // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]]
3513  // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]]
3514  // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]]
3515  // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64}
3516  // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<> : tensor<0xi64>}
3517  // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<> : tensor<0xi64>}
3518  // CHECK: return [[LINSPACE]]
3519  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor<i32>} : () -> tensor<i32>
3520  %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<4xf32>
3521  return %1 : tensor<4xf32>
3522}
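
// For reference, the expansion above implements
//   step = (stop - start) / (num - 1),  result[i] = start + i * step
// with num = 4 taken from the constant operand.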
3523
3524// CHECK-LABEL: func @linspace_dynamic
3525func @linspace_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>) -> tensor<?xf32> {
3526  // CHECK: "tf.LinSpace"
3527  %0 = "tf.LinSpace"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<?xf32>
3528  return %0 : tensor<?xf32>
3529}
3530
3531// CHECK-LABEL: func @linspace_invalid_num
3532func @linspace_invalid_num(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<?xf32> {
3533  // CHECK: mhlo.constant dense<> : tensor<0xi32>
3534  // CHECK: "tf.LinSpace"
3535  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
3536  %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<0xi32>) -> tensor<?xf32>
3537  return %1 : tensor<?xf32>
3538}
3539
3540//===----------------------------------------------------------------------===//
3541// LegacyCall op legalizations.
3542//===----------------------------------------------------------------------===//
3543
3544func @identity_func(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> {
3545  return %arg0: tensor<10x2xf32>
3546}
3547
3548// CHECK-LABEL: testSimpleLegacyCallOp
3549func @testSimpleLegacyCallOp(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> {
3550  // CHECK: %[[RESULT:.*]] = call @identity_func(%arg0) : (tensor<10x2xf32>) -> tensor<10x2xf32>
3551  %0 = "tf.LegacyCall"(%arg0) {f = @identity_func} : (tensor<10x2xf32>) -> tensor<10x2xf32>
3552  // CHECK: return %[[RESULT]]
3553  return %0: tensor<10x2xf32>
3554}
3555
3556func @select_first(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> {
3557  return %arg0: tensor<10x2xf32>
3558}
3559
3560// CHECK-LABEL: testMultiInputLegacyCallOp
3561func @testMultiInputLegacyCallOp(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> {
3562  // CHECK: %[[RESULT:.*]] = call @select_first(%arg0, %arg1) : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32>
3563  %0 = "tf.LegacyCall"(%arg0, %arg1) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @select_first} : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32>
3564  // CHECK: return %[[RESULT]]
3565  return %0: tensor<10x2xf32>
3566}
3567
3568//===----------------------------------------------------------------------===//
3569// Conv op legalizations.
3570//===----------------------------------------------------------------------===//
3571
3572// CHECK-LABEL: conv_simple
3573func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> {
3574
3575  // CHECK: "mhlo.convolution"(%arg0, %arg1)
3576
3577  // Default attributes
3578  // CHECK-NOT: lhs_dilation
3579  // CHECK-NOT: precision_config
3580
3581  // CHECK-DAG-SAME: window_strides = dense<[4, 5]>
  // CHECK-DAG-SAME: padding = dense<{{\[\[}}0, 1], [2, 3]]>
3583  // CHECK-DAG-SAME: rhs_dilation = dense<[2, 3]>
3584
3585  // CHECK-DAG-SAME: dimension_numbers
3586  // CHECK-DAG-SAME:   input_batch_dimension = 0
3587  // CHECK-DAG-SAME:   input_feature_dimension = 3
3588  // CHECK-DAG-SAME:   input_spatial_dimensions = dense<[1, 2]>
3589  // CHECK-DAG-SAME:   kernel_input_feature_dimension = 2
3590  // CHECK-DAG-SAME:   kernel_output_feature_dimension = 3
3591  // CHECK-DAG-SAME:   kernel_spatial_dimensions = dense<[0, 1]>
3592  // CHECK-DAG-SAME:   output_batch_dimension = 0
3593  // CHECK-DAG-SAME:   output_feature_dimension = 3
3594  // CHECK-DAG-SAME:   output_spatial_dimensions = dense<[1, 2]>
3595
3596  // CHECK-DAG-SAME: feature_group_count = 2
3597  // CHECK-DAG-SAME: batch_group_count = 1
3598
3599  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32>
3600  return %0 : tensor<256x8x7x16xf32>
3601}
3602
3603// CHECK-LABEL: conv3d_simple
3604func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> {
3605
3606  // CHECK: "mhlo.convolution"(%arg0, %arg1)
3607
3608  // Default attributes
3609  // CHECK-NOT: lhs_dilation
3610  // CHECK-NOT: precision_config
3611
3612  // CHECK-DAG-SAME: window_strides = dense<[5, 6, 7]>
  // CHECK-DAG-SAME: padding = dense<{{\[\[}}1, 2], [2, 3], [2, 3]]>
3614  // CHECK-DAG-SAME: rhs_dilation = dense<[2, 3, 4]>
3615
3616  // CHECK-DAG-SAME: dimension_numbers
3617  // CHECK-DAG-SAME:   input_batch_dimension = 0
3618  // CHECK-DAG-SAME:   input_feature_dimension = 4
3619  // CHECK-DAG-SAME:   input_spatial_dimensions = dense<[1, 2, 3]>
3620  // CHECK-DAG-SAME:   kernel_input_feature_dimension = 3
3621  // CHECK-DAG-SAME:   kernel_output_feature_dimension = 4
3622  // CHECK-DAG-SAME:   kernel_spatial_dimensions = dense<[0, 1, 2]>
3623  // CHECK-DAG-SAME:   output_batch_dimension = 0
3624  // CHECK-DAG-SAME:   output_feature_dimension = 4
3625  // CHECK-DAG-SAME:   output_spatial_dimensions = dense<[1, 2, 3]>
3626
3627  // CHECK-DAG-SAME: feature_group_count = 2
3628  // CHECK-DAG-SAME: batch_group_count = 1
3629
3630  %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32>
3631  return %0 : tensor<256x7x6x5x16xf32>
3632}
3633
3634// CHECK-LABEL: depthwiseconv_simple
3635func @depthwiseconv_simple(%arg0: tensor<2x4x5x3xf32>, %arg1: tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32> {
3636  // CHECK: %[[RESHAPED_FILTER:.*]] = "mhlo.reshape"(%arg1) : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32>
3637  // CHECK: "mhlo.convolution"(%arg0, %[[RESHAPED_FILTER]])
3638  // CHECK: feature_group_count = 3
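  // Depthwise conv lowers to a grouped convolution: with C = 3 input channels
  // and depth multiplier M = 3, the HxWxCxM filter (2x2x3x3) is reshaped to
  // HxWx1x(C*M) = 2x2x1x9 and feature_group_count is set to C = 3.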
3639  %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {
3640    data_format = "NHWC",
3641    device = "",
3642    dilations = [1, 1, 1, 1],
3643    explicit_paddings = [],
3644    padding = "VALID",
3645    strides = [1, 1, 1, 1]
3646  } : (tensor<2x4x5x3xf32>, tensor<2x2x3x3xf32>) -> tensor<2x3x4x9xf32>
3647  return %0 : tensor<2x3x4x9xf32>
3648}
3649
3650// CHECK-LABEL: conv_valid_padding
3651func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> {
3652  // CHECK: "mhlo.convolution"(%arg0, %arg1)
3653
3654  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x4x5x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32>
3655  return %0 : tensor<1x2x3x1xf32>
3656}
3657
3658// CHECK-LABEL: conv_explicit_paddings
3659func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> {
3660
3661  // CHECK: "mhlo.convolution"(%arg0, %arg1)
3662  // CHECK-SAME: padding = dense<{{\[\[}}6, 0], [3, 3]]>
3663
3664  %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32>
3665  return %0 : tensor<256x9x7x16xf32>
3666}
3667
3668// CHECK-LABEL: @conv2d_backprop_input
3669func @conv2d_backprop_input(
3670    %filter: tensor<3x3x1x32xf32>,
3671    %out_backprop: tensor<100x26x26x32xf32>
3672  ) -> tensor<100x28x28x1xf32> {
3673    // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>}
3674    // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg1, %[[REV_FILTER]]) {
3675    // CHECK-SAME: batch_group_count = 1 : i64,
3676    // CHECK-SAME: dimension_numbers = {
3677    // CHECK-SAME:   input_batch_dimension = 0 : i64,
3678    // CHECK-SAME:   input_feature_dimension = 3 : i64,
3679    // CHECK-SAME:   input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
3680    // CHECK-SAME:   kernel_input_feature_dimension = 3 : i64,
3681    // CHECK-SAME:   kernel_output_feature_dimension = 2 : i64,
3682    // CHECK-SAME:   kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
3683    // CHECK-SAME:   output_batch_dimension = 0 : i64,
3684    // CHECK-SAME:   output_feature_dimension = 3 : i64,
3685    // CHECK-SAME:   output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
3686    // CHECK-SAME: },
3687    // CHECK-SAME: feature_group_count = 1 : i64,
3688    // CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>,
3689    // CHECK-SAME: padding = dense<2> : tensor<2x2xi64>,
3690    // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>,
3691    // CHECK-SAME: window_strides = dense<1> : tensor<2xi64>
3692    // CHECK: return %[[RESULT]]
3693  %input_sizes = "tf.Const" () { value = dense<[100,28,28,1]> : tensor<4xi32> } : () -> tensor<4xi32>
3694  %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) {
3695    data_format = "NHWC",
3696    dilations = [1, 1, 1, 1],
3697    explicit_paddings = [],
3698    padding = "VALID",
3699    strides = [1, 1, 1, 1],
3700    use_cudnn_on_gpu = true
3701  } : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
3702  return %result : tensor<100x28x28x1xf32>
3703}
3704
3705// CHECK-LABEL: @conv2d_backprop_input_grouped
3706func @conv2d_backprop_input_grouped(
3707    %filter: tensor<2x2x5x21xf32>,
3708    %out_backprop: tensor<5x2x2x21xf32>
3709  ) -> tensor<5x3x3x15xf32> {
3710  %input_sizes = "tf.Const" () { value = dense<[5, 3, 3, 15]> : tensor<4xi32> } : () -> tensor<4xi32>
3711
3712  // Verify filter transformation for grouped convolution.
3713
3714  // CHECK: %[[RESHAPE:.*]] = "mhlo.reshape"(%arg0) : (tensor<2x2x5x21xf32>) -> tensor<2x2x5x3x7xf32>
3715  // CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[RESHAPE]])
3716  // CHECK-SAME: permutation = dense<[0, 1, 3, 2, 4]>
3717  // CHECK-SAME: (tensor<2x2x5x3x7xf32>) -> tensor<2x2x3x5x7xf32>
3718  // CHECK: "mhlo.reshape"(%[[TRANSPOSE]]) : (tensor<2x2x3x5x7xf32>) -> tensor<2x2x15x7xf32>
3719
3720  %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) {
3721    data_format = "NHWC",
3722    dilations = [1, 1, 1, 1],
3723    explicit_paddings = [],
3724    padding = "VALID",
3725    strides = [1, 1, 1, 1],
3726    use_cudnn_on_gpu = true
3727  } : (tensor<4xi32>, tensor<2x2x5x21xf32>, tensor<5x2x2x21xf32>) -> tensor<5x3x3x15xf32>
3728  return %result : tensor<5x3x3x15xf32>
3729}
3730
3731
3732// CHECK-LABEL: @conv3d_backprop_input
3733func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
3734  // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
3735  // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg1, %[[REV_FILTER]])
3736
  // CHECK-DAG: batch_group_count = 1 : i64,

  // CHECK-DAG: dimension_numbers =
  // CHECK-DAG:   input_batch_dimension = 0 : i64
  // CHECK-DAG:   input_feature_dimension = 4 : i64
  // CHECK-DAG:   input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
  // CHECK-DAG:   kernel_input_feature_dimension = 4 : i64
  // CHECK-DAG:   kernel_output_feature_dimension = 3 : i64
  // CHECK-DAG:   kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>
  // CHECK-DAG:   output_batch_dimension = 0 : i64
  // CHECK-DAG:   output_feature_dimension = 4 : i64
  // CHECK-DAG:   output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>

  // CHECK-DAG: feature_group_count = 1 : i64
  // CHECK-DAG: lhs_dilation = dense<1> : tensor<3xi64>
  // CHECK-DAG: padding = dense<1> : tensor<3x2xi64>
  // CHECK-DAG: rhs_dilation = dense<1> : tensor<3xi64>
  // CHECK-DAG: window_strides = dense<1> : tensor<3xi64>
3755
3756  // CHECK: return %[[RESULT]]
3757  %input_sizes = "tf.Const" () {value = dense<[2, 8, 8, 8, 1]> : tensor<5xi32>} : () -> tensor<5xi32>
3758  %result = "tf.Conv3DBackpropInputV2"(%input_sizes, %filter, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1],  padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<5xi32>, tensor<3x3x3x1x6xf32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
3759  return %result : tensor<2x8x8x8x1xf32>
3760}
3761
3762// CHECK-LABEL: @conv2d_backprop_filter
3763func @conv2d_backprop_filter(
3764    %input: tensor<100x28x28x1xf32>,
3765    %out_backprop: tensor<100x26x26x32xf32>
3766  ) -> tensor<100x28x28x1xf32> {
3767  // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg0, %arg1) {
3768  // CHECK-SAME:  batch_group_count = 1 : i64,
3769  // CHECK-SAME:  dimension_numbers = {
3770  // CHECK-SAME:    input_batch_dimension = 3 : i64,
3771  // CHECK-SAME:    input_feature_dimension = 0 : i64,
3772  // CHECK-SAME:    input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
3773  // CHECK-SAME:    kernel_input_feature_dimension = 0 : i64,
3774  // CHECK-SAME:    kernel_output_feature_dimension = 3 : i64,
3775  // CHECK-SAME:    kernel_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
3776  // CHECK-SAME:    output_batch_dimension = 2 : i64,
3777  // CHECK-SAME:    output_feature_dimension = 3 : i64,
3778  // CHECK-SAME:    output_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>
3779  // CHECK-SAME:  },
3780  // CHECK-SAME:  feature_group_count = 1 : i64,
3781  // CHECK-SAME:  lhs_dilation = dense<1> : tensor<2xi64>,
3782  // CHECK-SAME:  padding = dense<0> : tensor<2x2xi64>,
3783  // CHECK-SAME:  rhs_dilation = dense<1> : tensor<2xi64>,
3784  // CHECK-SAME:  window_strides = dense<1> : tensor<2xi64>
3785  // CHECK: return %[[RESULT]]
3786  %filter_sizes = "tf.Const" () { value = dense<[3,3,1,32]> : tensor<4xi32> } : () -> tensor<4xi32>
3787  %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) {
3788    data_format = "NHWC",
3789    dilations = [1, 1, 1, 1],
3790    explicit_paddings = [],
3791    padding = "VALID",
3792    strides = [1, 1, 1, 1],
3793    use_cudnn_on_gpu = true
3794  } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
3795  return %result : tensor<100x28x28x1xf32>
3796}
3797
3798// CHECK-LABEL: @conv2d_backprop_filter_grouped
3799func @conv2d_backprop_filter_grouped(
3800    %input: tensor<1x2x2x2xf32>,
3801    %out_backprop: tensor<1x1x1x2xf32>
3802  ) -> tensor<2x2x1x2xf32> {
3803
3804  // CHECK: "mhlo.convolution"(%arg0, %arg1) {
3805  // CHECK-SAME:  batch_group_count = 2 : i64,
3806  // CHECK-SAME:  feature_group_count = 1 : i64,
3807
3808  %filter_sizes = "tf.Const" () { value = dense<[2, 2, 1, 2]> : tensor<4xi32> } : () -> tensor<4xi32>
3809  %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) {
3810    data_format = "NHWC",
3811    dilations = [1, 1, 1, 1],
3812    explicit_paddings = [],
3813    padding = "VALID",
3814    strides = [1, 1, 1, 1],
3815    use_cudnn_on_gpu = true
3816  } : (tensor<1x2x2x2xf32>, tensor<4xi32>, tensor<1x1x1x2xf32>) -> tensor<2x2x1x2xf32>
3817  return %result : tensor<2x2x1x2xf32>
3818}
3819
3820
3821// CHECK-LABEL: @conv3d_backprop_filter
3822func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> {
3823  // CHECK: %[[RESULT:.*]] = "mhlo.convolution"(%arg0, %arg1)
3824
  // CHECK-DAG: batch_group_count = 1 : i64

  // CHECK-DAG: dimension_numbers =
  // CHECK-DAG:   input_batch_dimension = 4 : i64
  // CHECK-DAG:   input_feature_dimension = 0 : i64
  // CHECK-DAG:   input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
  // CHECK-DAG:   kernel_input_feature_dimension = 0 : i64
  // CHECK-DAG:   kernel_output_feature_dimension = 4 : i64
  // CHECK-DAG:   kernel_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
  // CHECK-DAG:   output_batch_dimension = 3 : i64
  // CHECK-DAG:   output_feature_dimension = 4 : i64
  // CHECK-DAG:   output_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>

  // CHECK-DAG: feature_group_count = 1 : i64
  // CHECK-DAG: lhs_dilation = dense<1> : tensor<3xi64>
  // CHECK-DAG: padding = dense<1> : tensor<3x2xi64>
  // CHECK-DAG: rhs_dilation = dense<1> : tensor<3xi64>
  // CHECK-DAG: window_strides = dense<1> : tensor<3xi64>
3843
3844  // CHECK: return %[[RESULT]]
3845  %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32>
3846  %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1],  padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32>
3847  return %result : tensor<2x8x8x8x1xf32>
3848}
3849
3850// CHECK-LABEL: @collective_permute
3851func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
3852  %source_target_pairs = "tf.Const" () {
3853    value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32>
3854  } : () -> tensor<3x2xi32>
3855
3856  // CHECK: "mhlo.collective_permute"
3857  // CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
3858  %0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) {
3859  } : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32>
3860
3861  return %0 : tensor<128x32xf32>
3862}
3863
3864// CHECK-LABEL: @cross_replica_sum
3865func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {
3866  %replica_groups = "tf.Const" () {
3867    value = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
3868  } : () -> tensor<2x4xi32>
3869
3870  // CHECK: mhlo.cross-replica-sum
3871  // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
3872  %result = "tf.CrossReplicaSum" (%input, %replica_groups) : (tensor<10xf32>, tensor<2x4xi32>) -> tensor<10xf32>
3873  return %result : tensor<10xf32>
3874}
3875
3876//===----------------------------------------------------------------------===//
3877// tf.Split legalization
3878//===----------------------------------------------------------------------===//
3879
3880// CHECK-LABEL: @split_not_match_non_const_split_dim
3881func @split_not_match_non_const_split_dim(%input: tensor<4x4xf32>, %split_dim: tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
3882  // CHECK: tf.Split
3883  %0:2 = "tf.Split"(%split_dim, %input) : (tensor<i32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>)
3884  return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32>
3885}
3886
3887// CHECK-LABEL: @split_not_match_unknown_input_dim
3888func @split_not_match_unknown_input_dim(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) {
3889  %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
3890  // CHECK: tf.Split
3891  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>)
3892  return %0#0, %0#1 : tensor<4x?x4xf32>, tensor<4x?x4xf32>
3893}
3894
3895// CHECK-LABEL: @split_match_and_split_into_two
3896func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) {
3897  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
3898  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32>
3899  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<2x6xf32>
3900  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>)
3901  // CHECK: return %[[ONE]], %[[TWO]]
3902  return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32>
3903}
3904
3905// CHECK-LABEL: @split_match_and_split_into_two_dynamic
3906func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) {
3907  %cst = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
3908  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
3909  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32>
3910  %0:2 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>)
3911  // CHECK: return %[[ONE]], %[[TWO]]
3912  return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32>
3913}
3914
3915// CHECK-LABEL: @split_match_and_split_into_three
3916// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
3917func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) {
3918  %cst = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
3919  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
3920  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
3921  // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
3922  %0:3 = "tf.Split"(%cst, %input) : (tensor<i32>, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
3923  // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
3924  return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>
3925}
3926
3927//===----------------------------------------------------------------------===//
3928// tf.TopKV2 legalization
3929//===----------------------------------------------------------------------===//
3930
3931// CHECK-LABEL: topk_v2_non_const_k
3932func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
3933  // CHECK: tf.TopKV2
3934  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
3935  return %0#0, %0#1: tensor<?xf32>, tensor<?xi32>
3936}
3937
3938// CHECK-LABEL: topk_v2_unknown_input_last_dim
3939func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) {
3940  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
3941  // CHECK: tf.TopKV2
3942  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor<i32>) -> (tensor<16x?xf32>, tensor<16x?xi32>)
3943  return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32>
3944}
3945
3946// CHECK-LABEL: topk_v2
3947// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32>
3948func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
3949  %k = "tf.Const"() {value = dense<8> : tensor<i32>} : () -> tensor<i32>
3950
3951  // CHECK:      %[[IOTA:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64}
3952  // CHECK-NEXT: %[[SORT:.*]]:2 = "mhlo.sort"(%[[INPUT]], %[[IOTA]]) ( {
3953  // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor<f32>, %[[RHS:.*]]: tensor<f32>, %{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
3954  // CHECK-NEXT:   %[[CMP:.*]] = "mhlo.compare"(%[[LHS]], %[[RHS]]) {compare_type = "TOTALORDER", comparison_direction = "GT"}
3955  // CHECK-NEXT:   "mhlo.return"(%[[CMP]])
3956  // CHECK-NEXT: }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
3957  // CHECK-NEXT: %[[VAL:.*]] = "mhlo.slice"(%[[SORT]]#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
3958  // CHECK-NEXT: %[[IDX:.*]] = "mhlo.slice"(%[[SORT]]#1) {limit_indices = dense<[16, 8]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
3959  // CHECK-NEXT: return %[[VAL]], %[[IDX]]
3960  %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor<i32>) -> (tensor<16x8xf32>, tensor<16x8xi32>)
3961  return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32>
3962}
3963
3964//===----------------------------------------------------------------------===//
3965// tf.SplitV legalization
3966//===----------------------------------------------------------------------===//
3967
3968// CHECK-LABEL: @splitv_match_and_split_into_three
3969// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>)
3970func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
3971  %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
3972  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
3973  // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x1xf32>
3974  // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x2xf32>
3975  // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) {limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>) -> tensor<4x3xf32>
3976  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
3977  // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]]
3978  return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
3979}
3980
3981// CHECK-LABEL: @splitv_match_and_split_into_three_dynamic
3982func @splitv_match_and_split_into_three_dynamic(%input: tensor<?x6xf32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>) {
3983  %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
3984  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
3985  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x1xf32>
3986  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x2xf32>
3987  // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<?x6xf32>) -> tensor<?x3xf32>
3988  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<?x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>)
3989  return %0#0, %0#1, %0#2 : tensor<?x1xf32>, tensor<?x2xf32>, tensor<?x3xf32>
3990}
3991
3992// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes
3993func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) {
3994  %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
3995  %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
3996  // CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>
3997  // CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>
3998  // CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>
3999  %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
4000  return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>
4001}
4002
4003//===----------------------------------------------------------------------===//
4004// tf.Assert legalization
4005//===----------------------------------------------------------------------===//
4006
4007// CHECK-LABEL: @assert
4008func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) {
4009  // CHECK-NOT: tf.Assert
4010  "tf.Assert"(%arg0, %arg1) {summarize = 1} : (tensor<i1>, tensor<*xf32>) -> ()
4011  return
4012}
4013
4014//===----------------------------------------------------------------------===//
4015// tf.Unpack legalization
4016//===----------------------------------------------------------------------===//
4017
4018// CHECK-LABEL: @unpack
4019func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) {
4020  // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
4021  // CHECK: %[[RES1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
4022  // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
4023  // CHECK: %[[RES2:.*]] = "mhlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
4024  // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32>
4025  // CHECK: %[[RES3:.*]] = "mhlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32>
4026
4027  %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
  // CHECK: return %[[RES1]], %[[RES2]], %[[RES3]]
4029  return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>
4030}
4031
4032// CHECK-LABEL: @unpack_dynamic
4033func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
4034
4035  // CHECK: tf.Unpack
4036  %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
4037  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
4038}
4039
4040// CHECK-LABEL: @unpack_unranked
4041func @unpack_unranked(%input: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
4042
4043  // CHECK: tf.Unpack
4044  %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<*xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>)
4045  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
4046}
4047
4048//===----------------------------------------------------------------------===//
4049// tf.UnsortedSegment{Max|Min|Prod|Sum} legalization
4050//===----------------------------------------------------------------------===//
4051
4052// CHECK-LABEL: @unsorted_segment_sum
4053// CHECK-SAME: [[DATA:%.*]]: tensor<8x16x64xf32>
4054// CHECK-SAME: [[SI:%.*]]: tensor<8x16xi32>
4055func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x16xi32>) -> (tensor<4x64xf32>) {
4056  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
4057  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
4058  // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32>
4059  // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( {
4060  // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
4061  // CHECK:   [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor<f32>
4062  // CHECK:   "mhlo.return"([[ADD]])
4063  // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = {index_vector_dim = 2 : i64, inserted_window_dims = dense<0> : tensor<1xi64>, scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, update_window_dims = dense<2> : tensor<1xi64>}, unique_indices = false} : (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32>
4064  // CHECK: return [[SCATTER]]
4065  %0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<8x16x64xf32>, tensor<8x16xi32>, tensor<i32>) -> (tensor<4x64xf32>)
4066  return %0: tensor<4x64xf32>
4067}
4068
4069// CHECK-LABEL: @unsorted_segment_prod
4070// CHECK-SAME: [[DATA:%.*]]: tensor<8x?x64xf32>
4071// CHECK-SAME: [[SI:%.*]]: tensor<?x16xi32>
4072func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
4073  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
4074  // CHECK: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
4075  // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) {broadcast_sizes = dense<[4, 64]> : tensor<2xi64>} : (tensor<f32>) -> tensor<4x64xf32>
4076  // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) ( {
4077  // CHECK: ^{{.*}}([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
4078  // CHECK:   [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor<f32>
4079  // CHECK:   "mhlo.return"([[MUL]])
4080  // CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = {index_vector_dim = 2 : i64, inserted_window_dims = dense<0> : tensor<1xi64>, scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, update_window_dims = dense<2> : tensor<1xi64>}, unique_indices = false} : (tensor<4x64xf32>, tensor<?x16xi32>, tensor<8x?x64xf32>) -> tensor<4x?xf32>
4081  // CHECK: return [[SCATTER]]
4082  %0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
4083  return %0: tensor<4x?xf32>
4084}
4085
4086// CHECK-LABEL: @unsorted_segment_min
4087func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
4088  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
4089  // CHECK: mhlo.constant dense<3.40282347E+38> : tensor<f32>
4090  // CHECK: mhlo.scatter
4091  // CHECK: mhlo.minimum
4092  %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
4093  return %0: tensor<4x?xf32>
4094}
4095
4096// CHECK-LABEL: @unsorted_segment_max
4097func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor<?x16xi32>) -> (tensor<4x?xf32>) {
4098  %num_segments = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
4099  // CHECK: mhlo.constant dense<-3.40282347E+38> : tensor<f32>
4100  // CHECK: mhlo.scatter
4101  // CHECK: mhlo.maximum
4102  %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor<?x16xi32>, tensor<i32>) -> (tensor<4x?xf32>)
4103  return %0: tensor<4x?xf32>
4104}
4105
4106//===----------------------------------------------------------------------===//
4107// tf.GatherV2 legalization
4108//===----------------------------------------------------------------------===//
4109
4110// CHECK-LABEL: @gather_v2
4111func @gather_v2(%arg0: tensor<16x2x3xf32>, %arg1: tensor<16x5xi32>) -> tensor<16x2x5xf32> {
4112  // CHECK: "mhlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>) -> tensor<16x2x5xf32>
4113  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
4114  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32>
4115  return %1 : tensor<16x2x5xf32>
4116}
4117
4118// CHECK-LABEL: @gather_v2_dynamic
4119func @gather_v2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi32>) -> tensor<*xf32> {
4120  // CHECK: "mhlo.torch_index_select"(%arg0, %arg1) {batch_dims = 1 : i64, dim = 2 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi32>) -> tensor<*xf32>
4121  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
4122  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<?x?x?xf32>, tensor<?x?xi32>, tensor<1xi32>) -> tensor<*xf32>
4123  return %1 : tensor<*xf32>
4124}
4125
4126// CHECK-LABEL: @gather_v2_unranked
4127func @gather_v2_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
4128  // CHECK: tf.GatherV2
4129  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
4130  %1 = "tf.GatherV2"(%arg0, %arg1, %0) {batch_dims = -1 : i64} : (tensor<*xf32>, tensor<*xi32>, tensor<1xi32>) -> tensor<*xf32>
4131  return %1 : tensor<*xf32>
4132}
4133
4134//===----------------------------------------------------------------------===//
4135// tf.StridedSliceGrad legalization
4136//===----------------------------------------------------------------------===//
4137
4138// CHECK-LABEL: strided_slice_grad
4139// CHECK-SAME: [[GRAD:%.*]]: tensor<4x16x1022xf32>
4140func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> {
4141
4142  // For StridedSlice
4143  // Dim #:        0,   1,    2
4144  // Input shape: [4, 128, 1024]
4145  // Begin:        1,   4,   -3
4146  // End:          8,  65,   42
4147  // Stride:       1,   4,   -1
4148  // Begin mask:   1,   0,    0  (= 1)
4149  // End mask:     0,   0,    1  (= 4)
4150
4151  // So result shape:
4152  // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4
4153  // Dim #1: 4 to 65 stride 4: so 16
4154  // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022
4155  // result shape: [4, 16, 1022]
4156
4157  // To pad back:
4158  // Dim #:        0,   1,   2
4159  // Pad low:      0,   4,   0
  // Pad interior: 0,   3,   0
4161  // Pad high:     0,  63,   2
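  // As a worked example for dim #1 (a sanity note, not a directive): slicing
  // 4:65:4 yields ceil((65 - 4) / 4) = 16 elements, and padding back satisfies
  // pad_low + n + (n - 1) * interior + pad_high
  //   = 4 + 16 + 15 * 3 + 63 = 128, the original dim #1 size.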
4162
4163  %shape = "tf.Const"() {value = dense<[4, 128, 1024]> : tensor<3xi32>} : () -> (tensor<3xi32>)
4164  %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>)
4165  %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>)
4166  %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>)
4167
4168  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"(%arg0) : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32>
4169  // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32>
4170  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
4171  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) {edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>} : (tensor<4x16x1022xf32>, tensor<f32>) -> tensor<4x128x1024xf32>
4172
4173  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 1, end_mask = 4} : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32>
4174  // CHECK: return [[PAD]]
4175  return %0: tensor<4x128x1024xf32>
4176}
4177
4178// CHECK-LABEL: strided_slice_grad_shrink_axis_mask
4179// CHECK-SAME: [[GRAD:%.*]]: tensor<8xf32>
4180func @strided_slice_grad_shrink_axis_mask(%grad: tensor<8xf32>) -> tensor<4x8xf32> {
  // Input to StridedSlice was of shape 4x8xf32
  // Strided slice gets input[2:3, 0:8]
  // shrink_axis_mask is 1, denoting that dim #0 is shrunk. So the output is
  // 8xf32, which is the shape of the gradient.
  // StridedSliceGrad reshapes the gradient to 1x8xf32 and
  // then pads it to match the input shape of 4x8xf32.
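  // Sanity note on the pad amounts: edge_padding_low = 2 (the begin index) and
  // edge_padding_high = 1, so 2 + 1 + 1 = 4 recovers the input's dim #0.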
4187
4188  %shape = "tf.Const"() {value = dense<[4, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
4189  %begin = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> (tensor<2xi32>)
4190  %end = "tf.Const"() {value = dense<[3, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>)
4191  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)
4192
4193  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<8xf32>) -> tensor<1x8xf32>
4194  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
4195  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-DAG: edge_padding_low = dense<[2, 0]> : tensor<2xi64>
  // CHECK-DAG: edge_padding_high = dense<[1, 0]> : tensor<2xi64>
  // CHECK-DAG: interior_padding = dense<0> : tensor<2xi64>
4199  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, shrink_axis_mask = 1} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<8xf32>) -> tensor<4x8xf32>
4200
4201  // CHECK: return [[PAD]] : tensor<4x8xf32>
4202  return %0 : tensor<4x8xf32>
4203}
4204
4205// CHECK-LABEL: strided_slice_grad_new_axis_mask
4206// CHECK-SAME: [[GRAD:%.*]]: tensor<1x2xf32>
4207func @strided_slice_grad_new_axis_mask(%grad: tensor<1x2xf32>) -> tensor<8xf32> {
  // Input to StridedSlice was of shape 8xf32
  // Strided slice gets input[tf.new_axis, 2:4]
  // new_axis_mask is 1, denoting that a new axis is inserted at dim #0. So the
  // output is 1x2xf32, which is the shape of the gradient.
  // StridedSliceGrad reshapes the gradient to 2xf32 and
  // then pads it to match the input shape of 8xf32.
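  // Sanity note on the pad amounts: edge_padding_low = 2 and
  // edge_padding_high = 4, so 2 + 2 + 4 = 8 recovers the input shape.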
4214
4215  %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>)
4216  %begin = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>)
4217  %end = "tf.Const"() {value = dense<[0, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>)
4218  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)
4219
4220  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<1x2xf32>) -> tensor<2xf32>
4221  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
4222  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-DAG: edge_padding_low = dense<2> : tensor<1xi64>
  // CHECK-DAG: edge_padding_high = dense<4> : tensor<1xi64>
  // CHECK-DAG: interior_padding = dense<0> : tensor<1xi64>
4226  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, new_axis_mask = 1} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1x2xf32>) -> tensor<8xf32>
4227
4228  // CHECK: return [[PAD]] : tensor<8xf32>
4229  return %0 : tensor<8xf32>
4230}
4231
4232// CHECK-LABEL: strided_slice_grad_ellipsis_mask
4233// CHECK-SAME: [[GRAD:%.*]]: tensor<2x4x8xf32>
4234func @strided_slice_grad_ellipsis_mask(%grad: tensor<2x4x8xf32>) -> tensor<4x4x8xf32> {
  // Input to StridedSlice was of shape 4x4x8xf32
  // Strided slice gets input[2:4, ...]
  // ellipsis_mask is 2, denoting that the slice contains all elements in
  // dim #1 and dim #2, ignoring begin and end indices for these dimensions. So
  // the output is 2x4x8xf32, which is the shape of the gradient.
  // StridedSliceGrad pads the gradient to match the input shape of
  // 4x4x8xf32.
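  // Sanity note on the pad amounts: only dim #0 is padded, with
  // edge_padding_low = 2 and edge_padding_high = 0, so 2 + 2 + 0 = 4.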
4242
4243  %shape = "tf.Const"() {value = dense<[4, 4, 8]> : tensor<3xi32>} : () -> (tensor<3xi32>)
4244  %begin = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>)
4245  %end = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>)
4246  %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>)
4247
4248  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
4249  // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
4250  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]])
  // CHECK-DAG: edge_padding_low = dense<[2, 0, 0]> : tensor<3xi64>
  // CHECK-DAG: edge_padding_high = dense<0> : tensor<3xi64>
  // CHECK-DAG: interior_padding = dense<0> : tensor<3xi64>
4254  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, ellipsis_mask = 2} : (tensor<3xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2x4x8xf32>) -> tensor<4x4x8xf32>
4255
4256  // CHECK: return [[PAD]] : tensor<4x4x8xf32>
4257  return %0 : tensor<4x4x8xf32>
4258}
4259
4260
4261// CHECK-LABEL: strided_slice_grad_all_masks
4262// CHECK-SAME: [[GRAD:%.*]]: tensor<1x4x8x8x10x2x1xf32>
4263func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> {
  // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis]
  // New axes are at indices 1 and 6 of the sparse spec, so
  // new_axis_mask = 2^1 + 2^6 = 66
  // The ellipsis mask is applied to dims #1 and #2 of the input, i.e. we get
  // the canonicalized slice input[1, :, :, 8:, :10, 2:6:2]
  // The StridedSliceGrad op propagates the gradient of the sliced tensor back
  // to the original input tensor by padding with zeros.
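  // Spelled out, the mask bits used below are: shrink_axis_mask = 2^0 = 1
  // (sparse index 0), ellipsis_mask = 2^2 = 4, end_mask = 2^3 = 8 ("8:"),
  // begin_mask = 2^4 = 16 (":10"), and new_axis_mask = 2^1 + 2^6 = 66.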
4271
4272  %shape = "tf.Const"() {value = dense<[2, 4, 8, 16, 32, 64]> : tensor<6xi32>} : () -> (tensor<6xi32>)
4273  %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
4274  %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>)
4275  %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>)
4276
  // Remove the 2 new axes (at indices 1 and 6); the axis shrunk at index 0
  // reappears as a unit dim of the reshape result.
4278  // CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[GRAD]]) : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32>
4279  // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // The edge_padding_low, edge_padding_high and interior_padding attributes of
  // mhlo.pad reflect the padding required to recover the shape of the
  // original input of the StridedSlice op.
4283  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZERO]])
  // CHECK-DAG: edge_padding_low = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64>
  // CHECK-DAG: edge_padding_high = dense<[0, 0, 0, 0, 22, 59]> : tensor<6xi64>
  // CHECK-DAG: interior_padding = dense<[0, 0, 0, 0, 0, 1]> : tensor<6xi64>
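  // E.g. for dim #5 (slice 2:6:2, n = 2 kept elements):
  // 2 + 2 + 1 + 59 = 64, and for dim #4 (slice :10): 0 + 10 + 22 = 32,
  // each recovering the corresponding input dimension.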
4287  %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<6xi32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>, tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32>
4288
4289  // CHECK: return [[PAD]] : tensor<2x4x8x16x32x64xf32>
4290  return %0 : tensor<2x4x8x16x32x64xf32>
4291}
4292
4293// CHECK-LABEL: @tensor_scatter_update
4294func @tensor_scatter_update(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi32>, %updates: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
4295  // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
4296  // CHECK:  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
4297  // CHECK:    "mhlo.return"(%arg4) : (tensor<f32>) -> ()
4298  // CHECK:  })
4299  // CHECK-SAME: indices_are_sorted = false
4300  // CHECK-SAME: scatter_dimension_numbers
4301  // CHECK-SAME:   index_vector_dim = 1 : i64
4302  // CHECK-SAME:   inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
4303  // CHECK-SAME:   scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
4304  // CHECK-SAME:   update_window_dims = dense<1> : tensor<1xi64>
4305  // CHECK-SAME: unique_indices = false
4306  %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<?x?x?xf32>, tensor<?x2xi32>, tensor<?x?xf32>) -> tensor<?x?x?xf32>
4307  return %0 : tensor<?x?x?xf32>
4308}
4309
4310//===----------------------------------------------------------------------===//
4311// tf.RandomShuffle legalization
4312//===----------------------------------------------------------------------===//
4313
4314// CHECK-LABEL: @random_shuffle_first_dim_1
4315// CHECK-SAME: [[INPUT:%.*]]: tensor<1x?xf32>
4316func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> {
4317  %0 = "tf.RandomShuffle"(%input) : (tensor<1x?xf32>) -> (tensor<1x?xf32>)
4318  // CHECK-NEXT: return [[INPUT]]
4319  return %0: tensor<1x?xf32>
4320}
4321
4322// CHECK-LABEL: @random_shuffle_1D_16
4323// CHECK-SAME: [[INPUT:%.*]]: tensor<16xf32>
4324func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> {
4325  // CHECK: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64>
4326  // CHECK: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32>
4327  // CHECK: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor<i32>
4328  // CHECK: [[RNG:%.*]] = "mhlo.rng_uniform"([[LOWER]], [[UPPER]], [[SHAPE]])
4329  // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) ( {
4330  // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor<i32>, [[ARG2:%.*]]: tensor<i32>, {{.*}}: tensor<f32>, {{.*}}: tensor<f32>):
4331  // CHECK:   "mhlo.compare"([[ARG1]], [[ARG2]]) {comparison_direction = "LT"}
4332  // CHECK: }) {dimension = -1 : i64, is_stable = true} : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>)
4333  // CHECK: return [[SORT]]#1
4334  %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>)
4335  return %0: tensor<16xf32>
4336}
4337
4338// CHECK-LABEL: @random_shuffle_1D_10240
4339func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> {
4340  // CHECK: mhlo.rng_uniform
4341  // CHECK: mhlo.sort
4342  // CHECK: mhlo.rng_uniform
4343  // CHECK: mhlo.sort
4344  %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>)
4345  return %0: tensor<10240xf32>
4346}
4347
4348// CHECK-LABEL: @random_shuffle_3D
4349// CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32>
4350func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
4351  // CHECK: [[INDICES:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
4352
4353  // CHECK: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64>
4354  // CHECK: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor<i32>
4355  // CHECK: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor<i32>
4356  // CHECK: [[SWAPS:%.*]] = "mhlo.rng_uniform"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]])
4357
4358  // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor<i32>
4359  // CHECK: [[WHILE_INIT:%.*]] = "mhlo.tuple"([[IV_INIT]], [[SWAPS]], [[INDICES]])
4360
4361  // CHECK: [[WHILE_OUT:%.*]] = "mhlo.while"([[WHILE_INIT]]) ( {
4362  // CHECK: ^{{.*}}([[COND_ARG:%.*]]: tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>):
4363  // CHECK:   [[IV:%.*]] = "mhlo.get_tuple_element"([[COND_ARG]]) {index = 0 : i32}
4364  // CHECK:   [[LIMIT:%.*]] = mhlo.constant dense<4> : tensor<i32>
4365  // CHECK:   [[CMP:%.*]] = "mhlo.compare"([[IV]], [[LIMIT]]) {comparison_direction = "LT"}
4366  // CHECK:   "mhlo.return"([[CMP]])
4367  // CHECK: },  {
4368  // CHECK: ^{{.*}}([[BODY_ARG:%.*]]: tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>):
4369  // CHECK:   [[IV:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 0 : i32}
4370  // CHECK:   [[SWAPS:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 1 : i32}
4371  // CHECK:   [[INDICES:%.*]] = "mhlo.get_tuple_element"([[BODY_ARG]]) {index = 2 : i32}
4372  // CHECK:   [[SRC_IDX:%.*]] = "mhlo.dynamic-slice"([[INDICES]], [[IV]]) {slice_sizes = dense<1> : tensor<i64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
4373  // CHECK:   [[SWP_IDX:%.*]] = "mhlo.dynamic-slice"([[SWAPS]], [[IV]]) {slice_sizes = dense<1> : tensor<i64>} : (tensor<4xi32>, tensor<i32>) -> tensor<1xi32>
4374  // CHECK:   [[SWP:%.*]] = "mhlo.reshape"([[SWP_IDX]]) : (tensor<1xi32>) -> tensor<i32>
4375  // CHECK:   [[TGT_IDX:%.*]] = "mhlo.dynamic-slice"([[INDICES]], [[SWP]]) {slice_sizes = dense<1> : tensor<i64>}
4376  // CHECK:   [[INDICES1:%.*]] = "mhlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
4377  // CHECK:   [[INDICES2:%.*]] = "mhlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
4378  // CHECK:   [[ONE:%.*]] = mhlo.constant dense<1> : tensor<i32>
4379  // CHECK:   [[NEW_IV:%.*]] = chlo.broadcast_add [[IV]], [[ONE]]
4380  // CHECK:   [[NEW_TUPLE:%.*]] = "mhlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]])
4381  // CHECK:   "mhlo.return"([[NEW_TUPLE]])
4382  // CHECK: }) : (tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>) -> tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>
4383
  // CHECK: [[SWAPPED_INDICES:%.*]] = "mhlo.get_tuple_element"([[WHILE_OUT]]) {index = 2 : i32} : (tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>) -> tensor<4xi32>
  // CHECK: [[GATHER:%.*]] = "mhlo.gather"([[INPUT]], [[SWAPPED_INDICES]])
4386  // CHECK-SAME: dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 1 : i64, offset_dims = dense<[1, 2]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}
4387  // CHECK-SAME: indices_are_sorted = false
4388  // CHECK-SAME: slice_sizes = dense<[1, -1, 16]> : tensor<3xi64>
4389  // CHECK: (tensor<4x?x16xf32>, tensor<4xi32>) -> tensor<4x?x16xf32>
4390
4391  // CHECK: return [[GATHER]]
4392
4393  %0 = "tf.RandomShuffle"(%input) : (tensor<4x?x16xf32>) -> (tensor<4x?x16xf32>)
4394  return %0: tensor<4x?x16xf32>
4395}
4396
4397//===----------------------------------------------------------------------===//
4398// tf.AvgPool legalization
4399//===----------------------------------------------------------------------===//
4400
4401// CHECK-LABEL:   @avgpool_valid_padding
4402// CHECK-SAME:      [[ARG:%.+]]: tensor<2x12x21x7xf16>
4403// CHECK:           [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32>
4404// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
4405// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
4406// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
4407// CHECK:             [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
4408// CHECK:             "mhlo.return"([[ADD]])
4409// CHECK:           })
4410// CHECK-SAME:        window_dimensions = dense<[1, 2, 2, 1]>
4411// CHECK-SAME:        window_strides = dense<[1, 4, 4, 1]>
4412// CHECK-SAME:        -> tensor<2x3x5x7xf32>
4413// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
4414// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
4415// CHECK-SAME:        broadcast_dimensions = dense<>
4416// CHECK-SAME:        -> tensor<2x3x5x7xf32>
4417// CHECK:           [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
4418// CHECK-SAME:        -> tensor<2x3x5x7xf16>
4419// CHECK:           return [[CONV16]]
4420func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> {
4421  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16>
4422  return %0 : tensor<2x3x5x7xf16>
4423}
4424
4425// CHECK-LABEL:   @avgpool_3d_valid_padding
4426// CHECK-SAME:      [[ARG:%.+]]: tensor<2x4x12x21x7xf16>
4427// CHECK:           [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32>
4428// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
4429// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
4430// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
4431// CHECK:           [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
4432// CHECK:             "mhlo.return"([[ADD]])
4433// CHECK:           })
4434// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 2, 1]>
4435// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4, 1]>
4436// CHECK-SAME:        -> tensor<2x4x3x5x7xf32>
4437// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
4438// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
4439// CHECK-SAME:        broadcast_dimensions = dense<>
4440// CHECK-SAME:        -> tensor<2x4x3x5x7xf32>
4441// CHECK:           [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
4442// CHECK-SAME:        -> tensor<2x4x3x5x7xf16>
4443// CHECK:           return [[CONV16]]
4444func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> {
4445  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16>
4446  return %0 : tensor<2x4x3x5x7xf16>
4447}
4448
4449// CHECK-LABEL:   @avgpool_nchw_format
4450// CHECK-SAME:      [[ARG:%.+]]: tensor<2x7x12x21xf16>
4451// CHECK:           [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32>
4452// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
4453// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
4454// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
4455// CHECK:             [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
4456// CHECK:             "mhlo.return"([[ADD]])
4457// CHECK:           })
4458// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 2]>
4459// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4]>
4460// CHECK-SAME:        -> tensor<2x7x3x5xf32>
4461// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
4462// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
4463// CHECK-SAME:        broadcast_dimensions = dense<>
4464// CHECK-SAME:        -> tensor<2x7x3x5xf32>
4465// CHECK:           [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
4466// CHECK-SAME:        -> tensor<2x7x3x5xf16>
4467// CHECK:           return [[CONV16]]
4468func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> {
4469  %0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16>
4470  return %0 : tensor<2x7x3x5xf16>
4471}
4472
4473// CHECK-LABEL:   @avgpool_3d_ncdhw_format
4474// CHECK-SAME:      [[ARG:%.+]]: tensor<2x7x4x12x21xf16>
4475// CHECK:           [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32>
4476// CHECK:           [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
4477// CHECK:           [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( {
4478// CHECK:           ^bb0([[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<f32>):
4479// CHECK:             [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]]
4480// CHECK:             "mhlo.return"([[ADD]])
4481// CHECK:           })
4482// CHECK-SAME:        window_dimensions = dense<[1, 1, 1, 2, 2]>
4483// CHECK-SAME:        window_strides = dense<[1, 1, 1, 4, 4]>
4484// CHECK-SAME:        -> tensor<2x7x4x3x5xf32>
4485// CHECK:           [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
4486// CHECK:           [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]]
4487// CHECK-SAME:        broadcast_dimensions = dense<>
4488// CHECK-SAME:        -> tensor<2x7x4x3x5xf32>
4489// CHECK:           [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]])
4490// CHECK-SAME:        -> tensor<2x7x4x3x5xf16>
4491// CHECK:           return [[CONV16]]
4492func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> {
4493  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16>
4494  return %0 : tensor<2x7x4x3x5xf16>
4495}

// CHECK-LABEL:   @avgpool_same_padding(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x6x7xf32>
// CHECK:           %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x6x7xf32>
// CHECK:           %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32>
// CHECK:           return %[[RESULT]] : tensor<2x4x6x7xf32>
// CHECK:         }
func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> {
  %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32>
  return %0 : tensor<2x4x6x7xf32>
}
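
// Note on the SAME-padding lowering above: border windows cover fewer valid
// input elements, so a constant divisor would be wrong. The per-window
// element count is instead materialized by running an identical
// reduce_window (same padding, window, strides) over an all-ones tensor and
// dividing the two results elementwise.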

// CHECK-LABEL:   @avgpool_3d_same_padding(
// CHECK-SAME:      %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x4x6x7xf32>
// CHECK:           %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 5, 2, 1]>
// CHECK-SAME:        window_strides = dense<[1, 1, 3, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x4x6x7xf32>
// CHECK:           %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]]
// CHECK:           return %[[RESULT]] : tensor<2x4x4x6x7xf32>
// CHECK:         }
func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> {
  %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32>
  return %0 : tensor<2x4x4x6x7xf32>
}

//===----------------------------------------------------------------------===//
// AvgPoolGrad op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL:   @avgpool_grad_valid_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<10x12x16x64xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 1, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 1, 1, 0]>
// CHECK-SAME:        -> tensor<10x25x33x64xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<10x24x32x64xf32>
// CHECK:           return %[[RESULT]] : tensor<10x24x32x64xf32>
func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 2, 1],
     padding = "VALID",
     strides = [1, 2, 2, 1]
  } : (tensor<4xi32>, tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32>
  return %result : tensor<10x24x32x64xf32>
}
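
// Note on the gradient lowering above: each output gradient is first scaled
// by 1/4 (the VALID window element count), then spread back over the input
// positions by padding with interior padding of stride - 1 = 1 and edge
// padding of ksize - 1 = 1 (12 -> 12 + 11 + 2 = 25, 16 -> 16 + 15 + 2 = 33),
// and finally summed with a stride-1 reduce_window so overlapping windows
// accumulate: 25 - 2 + 1 = 24 and 33 - 2 + 1 = 32 recover the input shape.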

// CHECK-LABEL:   @avgpool_3d_grad_valid_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        -> tensor<10x8x25x33x64xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<10x8x24x32x64xf32>
// CHECK:           return %[[RESULT]] : tensor<10x8x24x32x64xf32>
func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 8, 24, 32, 64]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NDHWC",
    ksize = [1, 1, 2, 2, 1],
    padding = "VALID",
    strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32>
  return %result : tensor<10x8x24x32x64xf32>
}

// CHECK-LABEL:   @avgpool_grad_same_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<[1, 4, 4, 1]>
// CHECK-SAME:        -> tensor<2x4x7x9xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 3, 3, 0]>
// CHECK-SAME:        -> tensor<2x14x27x9xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<2x13x25x9xf32>
// CHECK:           return %[[RESULT]] : tensor<2x13x25x9xf32>
func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 3, 1],
     padding = "SAME",
     strides = [1, 4, 4, 1]
  } : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32>
  return %result : tensor<2x13x25x9xf32>
}
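
// Note on the SAME-padding gradient above: as in the forward SAME case, the
// per-window element counts come from a reduce_window over an all-ones
// tensor (here of the original input shape), so each gradient value is
// divided by the number of inputs its window actually covered before being
// spread back through pad + reduce_window.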

// CHECK-LABEL:   @avgpool_3d_grad_same_padding(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4, 1]>
// CHECK-SAME:        -> tensor<2x8x4x7x9xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 0, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 3, 3, 0]>
// CHECK-SAME:        -> tensor<2x8x14x27x9xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<2x8x13x25x9xf32>
// CHECK:           return %[[RESULT]] : tensor<2x8x13x25x9xf32>
func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 8, 13, 25, 9]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NDHWC",
    ksize = [1, 1, 2, 3, 1],
    padding = "SAME",
    strides = [1, 1, 4, 4, 1]} : (tensor<5xi32>, tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32>
  return %result : tensor<2x8x13x25x9xf32>
}

// CHECK-LABEL:   @avgpool_grad_nchw_format(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<[1, 1, 4, 4]>
// CHECK-SAME:        -> tensor<2x9x4x7xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 0, 1]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 1, 1]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 3, 3]>
// CHECK-SAME:        -> tensor<2x9x14x27xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<2x9x13x25xf32>
// CHECK:           return %[[RESULT]] : tensor<2x9x13x25xf32>
func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 13, 25]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NCHW",
     ksize = [1, 1, 2, 3],
     padding = "SAME",
     strides = [1, 1, 4, 4]
  } : (tensor<4xi32>, tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32>
  return %result : tensor<2x9x13x25xf32>
}

// CHECK-LABEL:   @avgpool_3d_grad_ncdhw_format(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32>
// CHECK:           %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM1]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]>
// CHECK-SAME:        window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<[1, 1, 1, 4, 4]>
// CHECK-SAME:        -> tensor<2x9x8x4x7xf32>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 0, 0, 0, 1]>
// CHECK-SAME:        edge_padding_low = dense<[0, 0, 0, 1, 1]>
// CHECK-SAME:        interior_padding = dense<[0, 0, 0, 3, 3]>
// CHECK-SAME:        -> tensor<2x9x8x14x27xf32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) ( {
// CHECK:           ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>):
// CHECK:             %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM2]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 1, 1, 2, 3]>
// CHECK-SAME:        window_strides = dense<1> : tensor<5xi64>
// CHECK-SAME:        -> tensor<2x9x8x13x25xf32>
// CHECK:           return %[[RESULT]] : tensor<2x9x8x13x25xf32>
func @avgpool_3d_grad_ncdhw_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> {
  %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 8, 13, 25]> : tensor<5xi32>} : () -> (tensor<5xi32>)
  %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) {
    data_format = "NCDHW",
    ksize = [1, 1, 1, 2, 3],
    padding = "SAME",
    strides = [1, 1, 1, 4, 4]} : (tensor<5xi32>, tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32>
  return %result : tensor<2x9x8x13x25xf32>
}

// CHECK-LABEL:   @avgpool_grad_bf16(
// CHECK-SAME:      %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
// CHECK:           %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16>
// CHECK:           %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16>
// CHECK:           %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME:        broadcast_dimensions = dense<>
// CHECK-SAME:        -> tensor<10x12x16x64xbf16>
// CHECK:           %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME:        edge_padding_high = dense<[0, 1, 1, 0]>
// CHECK-SAME:        edge_padding_low = dense<[0, 1, 1, 0]>
// CHECK-SAME:        interior_padding = dense<[0, 1, 1, 0]>
// CHECK-SAME:        -> tensor<10x25x33x64xbf16>
// CHECK:           %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = "mhlo.convert"(%[[REDUCE_WINDOW_INPUT]]) : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32>
// CHECK:           %[[ZERO_F32:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK:           %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) ( {
// CHECK:           ^bb0(%[[ARG1:.*]]: tensor<f32>, %[[ARG2:.*]]: tensor<f32>):
// CHECK:             %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor<f32>
// CHECK:             "mhlo.return"(%[[SUM]]) : (tensor<f32>) -> ()
// CHECK:           })
// CHECK-SAME:        window_dimensions = dense<[1, 2, 2, 1]>
// CHECK-SAME:        window_strides = dense<1>
// CHECK-SAME:        -> tensor<10x24x32x64xf32>
// CHECK:           %[[RESULT_CONVERTED:.*]] = "mhlo.convert"(%[[RESULT]]) : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16>
// CHECK:           return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16>
func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
  %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>)
  %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) {
     data_format = "NHWC",
     ksize = [1, 2, 2, 1],
     padding = "VALID",
     strides = [1, 2, 2, 1]
  } : (tensor<4xi32>, tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16>
  return %result : tensor<10x24x32x64xbf16>
}
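
// Note on the bf16 gradient above: the padded gradient is widened to f32
// before the summing reduce_window and narrowed back to bf16 afterwards, so
// the accumulation itself never runs in bf16.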

// CHECK-LABEL: xla_sharding
func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // CHECK-NEXT: "mhlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", has_side_effect = false, mhlo.sharding = ""}
  %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}

// CHECK-LABEL: inplace_update_one
func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
  // CHECK-DAG: [[UPDATE:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE2]], [[RESHAPE1]], [[CST]])
  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32>

  // CHECK: return [[UPDATE]]
  return %0 : tensor<8x4xf32>
}

// CHECK-LABEL: inplace_update_three
func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> {
  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
  // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
  // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
  // CHECK-DAG: [[RESHAPE3:%.+]] = "mhlo.reshape"([[SLICE3]])
  // CHECK-DAG: [[UPDATE1:%.+]] = "mhlo.dynamic-update-slice"(%arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]])
  // CHECK-DAG: [[UPDATE2:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]])
  // CHECK-DAG: [[UPDATE3:%.+]] = "mhlo.dynamic-update-slice"([[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]])
  %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32>

  // CHECK:  return [[UPDATE3]] : tensor<8x8x4xf32>
  return %0 : tensor<8x8x4xf32>
}
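
// Note on the InplaceUpdate lowerings above: there is no in-place update in
// mhlo, so each updated row is carved out of the updates tensor with a
// static mhlo.slice, its index is extracted and reshaped to a scalar, and
// the rows are folded in with a chain of mhlo.dynamic-update-slice ops, one
// per updated row.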

// CHECK-LABEL: xla_dynamic_update_slice
func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> {
  // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]]) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x16xf32>
  // CHECK: return [[DUS]]
  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32>
  return %0 : tensor<4x16xf32>
}

// CHECK-LABEL: xla_dynamic_update_slice2
func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> {
  // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
  // CHECK: [[RESHAPE0:%.+]] = "mhlo.reshape"([[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
  // CHECK: [[DUS:%.+]] = "mhlo.dynamic-update-slice"(%arg0, %arg1, [[RESHAPE0]]) : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
  // CHECK: return [[DUS]]
  %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}
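
// Note on the XlaDynamicUpdateSlice lowerings above:
// mhlo.dynamic-update-slice takes one scalar start index per dimension, so
// the packed index vector is split up with mhlo.slice and each piece is
// reshaped from tensor<1xi32> to a scalar tensor<i32>.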

//===----------------------------------------------------------------------===//
// AllToAll op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @alltoall_basic
func @alltoall_basic(%input: tensor<10xf32>) -> tensor<10xf32> {
  %group_assignment = "tf.Const" () {
    value = dense<[[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi32>
  } : () -> tensor<3x4xi32>
  %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 1 : i64, split_count = 2 : i64, split_dimension = 0 : i64} : (tensor<10xf32>, tensor<3x4xi32>) -> tensor<10xf32>
  // CHECK: mhlo.all_to_all
  // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7], [3, 5, 6, 8]]> : tensor<3x4xi64>
  return %result : tensor<10xf32>
}
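
// Note on the AllToAll lowering above: the group_assignment operand is a
// compile-time constant, so it is folded straight into the replica_groups
// attribute (widened from i32 to i64).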

//===----------------------------------------------------------------------===//
// Cumsum op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @cumsum_static
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( {
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   "mhlo.return"([[SUM]]) : (tensor<f32>) -> ()
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[REDUCE]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: return [[CONVERT_REDUCE]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}
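
// Note on the cumsum lowering above: a cumulative sum over a dimension of
// size n is a reduce_window with window_dimensions = n, stride 1, and
// padding [n - 1, 0], so the window ending at position i covers exactly
// elements 0..i. For n = 4, output[2] sums padded positions 2..5, i.e.
// x[0] + x[1] + x[2], with the leading zero padding absorbing the
// out-of-range position.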

// CHECK-LABEL: func @cumsum_exclusive
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( {
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   "mhlo.return"([[SUM]]) : (tensor<f32>) -> ()
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[PAD]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: return [[CONVERT_REDUCE]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}
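
// Note on the exclusive variant above: the inclusive result is shifted right
// by one with mhlo.pad, where edge_padding_low = 1 inserts the leading zero
// and the negative edge_padding_high = -1 drops the last element, turning
// [x0, x0+x1, ...] into [0, x0, x0+x1, ...].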

// CHECK-LABEL: func @cumsum_reverse
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( {
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   "mhlo.return"([[SUM]]) : (tensor<f32>) -> ()
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[REDUCE]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: return [[REVERSE_BACK]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}
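
// Note on the reverse variant above: reverse = true is handled by reversing
// the input along the axis, computing the ordinary forward cumsum, and
// reversing the result back.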

// CHECK-LABEL: func @cumsum_exclusive_reverse
// CHECK-SAME: [[X:%.*]]: tensor<4xf32>
func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor<i32>
  // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
  // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( {
  // CHECK: ^bb0([[A:%.*]]: tensor<f32>, [[B:%.*]]: tensor<f32>):
  // CHECK:   [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor<f32>
  // CHECK:   "mhlo.return"([[SUM]]) : (tensor<f32>) -> ()
  // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
  // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[PAD]]) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: return [[REVERSE_BACK]]
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

// CHECK-LABEL: func @cumsum_empty
func @cumsum_empty(%arg0: tensor<0xf32>) -> tensor<0xf32> {
  %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>

  // CHECK: mhlo.constant dense<> : tensor<0xf32>
  %1 = "tf.Cumsum"(%arg0, %0) : (tensor<0xf32>, tensor<i32>) -> tensor<0xf32>
  return %1 : tensor<0xf32>
}

// CHECK-LABEL: func @cumsum_dynamic
func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32> {
  // CHECK: "tf.Cumsum"
  %0 = "tf.Cumsum"(%arg0, %arg1) : (tensor<?xf32>, tensor<i32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

//===----------------------------------------------------------------------===//
// Cumprod op legalizations.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @cumprod
func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
  // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( {
  // CHECK:   mhlo.mul
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

//===----------------------------------------------------------------------===//
// Qr op legalization
//===----------------------------------------------------------------------===//

// CHECK:  func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) {
  // The tf.Qr lowering expands into a full algorithm that is impractical to
  // verify with FileCheck. Just verify that it converted.
  // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is
  // really only applicable to certain legacy uses.
  // CHECK-NOT: "tf.Qr"
  %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
  return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32>
}

//===----------------------------------------------------------------------===//
// tf.Softplus legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @softplus_f16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>)
func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor<f16>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f16>
  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK:     [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16>

  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xf16>
  return %0 : tensor<8x16xf16>
}

// CHECK-LABEL: func @softplus_bf16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>)
func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor<bf16>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<bf16>
  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK:     [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16>

  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xbf16>
  return %0 : tensor<8x16xbf16>
}

// CHECK-LABEL: func @softplus_f32
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>)
func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor<f32>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK:     [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>

  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xf32>
  return %0 : tensor<8x16xf32>
}

// CHECK-LABEL: func @softplus_f64
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>)
func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
  // CHECK-DAG: [[FEATURES_EXP:%.*]] = "mhlo.exponential"([[FEATURES]])
  // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor<f64>
  // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
  // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f64>
  // CHECK:     [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
  // CHECK:     [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
  // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
  // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
  // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
  // CHECK:     [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
  // CHECK:     [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
  %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64>

  // CHECK:     return [[ENTRY_SELECT]] : tensor<8x16xf64>
  return %0 : tensor<8x16xf64>
}
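
// Note on the Softplus lowerings above: with threshold = log(epsilon) + 2
// for a small per-element-type epsilon, the lowering picks a numerically
// safe branch: for features > -threshold, softplus(x) = log1p(exp(x)) ~= x;
// for features < threshold it is ~= exp(x); only in between is
// log_plus_one(exponential(x)) evaluated directly.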

// CHECK-LABEL: @xla_gather
func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> {
  %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64>

  // CHECK: "mhlo.gather"
  // CHECK-SAME: dimension_numbers =
  // CHECK-SAME:   collapsed_slice_dims = dense<0> : tensor<1xi64>
  // CHECK-SAME:   index_vector_dim = 1 : i64
  // CHECK-SAME:   offset_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME:   start_index_map = dense<0> : tensor<1xi64>
  // CHECK-SAME: indices_are_sorted = true
  // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>

  %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<10x1x300xf32>
  return %0 : tensor<10x1x300xf32>
}

// CHECK-LABEL: @xla_gather_i32
func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> {
  %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi32> } : () -> tensor<3xi32>

  // CHECK: "mhlo.gather"
  // CHECK-SAME: dimension_numbers =
  // CHECK-SAME:   collapsed_slice_dims = dense<0> : tensor<1xi64>
  // CHECK-SAME:   index_vector_dim = 1 : i64
  // CHECK-SAME:   offset_dims = dense<1> : tensor<1xi64>
  // CHECK-SAME:   start_index_map = dense<0> : tensor<1xi64>
  // CHECK-SAME: indices_are_sorted = true
  // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>

  %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<10x1x300xf32>
  return %0 : tensor<10x1x300xf32>
}
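
// Note on the XlaGather lowerings above: the opaque dimension_numbers string
// (a serialized gather dimension numbers proto) is decoded into the
// structured mhlo attribute checked here, and i32 slice_sizes are promoted
// to the i64 form mhlo expects.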

// CHECK-LABEL: func @stridedslice_with_i32
func @stridedslice_with_i32(%arg0: tensor<i32>) -> tensor<4xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "const_0_arg", outputs = "identity_0_retval_RetVal"}} {
// CHECK-NOT: tf.StridedSlice
// CHECK: [[DYNSLICE:%.*]] = "mhlo.dynamic-slice"
// CHECK: [[RESHAPE:%.*]] = "mhlo.reshape"([[DYNSLICE]])
// CHECK: return [[RESHAPE]]
  %0 = "tf.Const"() {value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32>
  %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %3 = "tf.AddV2"(%arg0, %1) {_xla_inferred_shapes = [#tf.shape<>], device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %4 = "tf.Pack"(%3) {_xla_inferred_shapes = [#tf.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  %5 = "tf.Pack"(%arg0) {_xla_inferred_shapes = [#tf.shape<1>], axis = 0 : i64, device = ""} : (tensor<i32>) -> tensor<1xi32>
  %6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32>
  return %6 : tensor<4xf32>
}
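
// Note on the StridedSlice lowering above: the begin/end values are computed
// at runtime from %arg0, so the slice cannot fold to a static mhlo.slice; it
// becomes an mhlo.dynamic-slice of the constant followed by a reshape that
// applies shrink_axis_mask.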

// CHECK-LABEL: func @replica_id
func @replica_id() -> tensor<i32> {
  // CHECK: %[[ID:.*]] = "mhlo.replica_id"() : () -> tensor<ui32>
  // CHECK: %[[RESULT:.*]] = "mhlo.convert"(%[[ID]]) : (tensor<ui32>) -> tensor<i32>
  %0 = "tf.XlaReplicaId"() : () -> tensor<i32>
  return %0 : tensor<i32>
}
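
// Note on XlaReplicaId above: mhlo.replica_id produces an unsigned ui32, so
// a convert back to the signed i32 TensorFlow expects is inserted.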

// CHECK-LABEL: func @angle_c64
// CHECK-SAME: ([[ARG0:%.*]]: tensor<complex<f32>>)
func @angle_c64(%arg0: tensor<complex<f32>>) -> tensor<f32> {
// CHECK: [[IMAG:%.*]] = "mhlo.imag"([[ARG0]])
// CHECK: [[REAL:%.*]] = "mhlo.real"([[ARG0]])
// CHECK: [[ATAN2:%.*]] = mhlo.atan2 [[IMAG]], [[REAL]]
  %0 = "tf.Angle"(%arg0) : (tensor<complex<f32>>) -> tensor<f32>
  return %0 : tensor<f32>
}
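
// Note on tf.Angle above: the phase of a complex number is
// atan2(imag(z), real(z)), which is exactly the expansion checked here.
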