/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include <memory>
17 #include <vector>
18
19 #include "tensorflow/compiler/xla/client/global_data.h"
20 #include "tensorflow/compiler/xla/client/local_client.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
23 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
24 #include "tensorflow/compiler/xla/tests/test_macros.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/types.h"
27
28 namespace xla {
29 namespace {
30
31 class SelectTest : public ClientLibraryTestBase {
32 public:
33 ErrorSpec error_spec_{0.0001};
34 };
35
TEST_F(SelectTest, SelectScalarF32True) {
  // A true scalar predicate must pick the on-true F32 scalar.
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp if_true = ConstantR0<float>(&builder, 123.0f);
  XlaOp if_false = ConstantR0<float>(&builder, 42.0f);
  Select(predicate, if_true, if_false);

  const float expected = 123.0f;
  ComputeAndCompareR0<float>(&builder, expected, {}, error_spec_);
}
45
TEST_F(SelectTest, SelectScalarS32True) {
  // Integer variant: with a true predicate the on-true value (-42) wins.
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp if_true = ConstantR0<int32>(&builder, -42);
  XlaOp if_false = ConstantR0<int32>(&builder, 42);
  Select(predicate, if_true, if_false);

  // Integer results are compared without an ErrorSpec.
  const int32 expected = -42;
  ComputeAndCompareR0<int32>(&builder, expected, {});
}
55
TEST_F(SelectTest, SelectScalarF32False) {
  // A false scalar predicate must pick the on-false F32 scalar.
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp if_true = ConstantR0<float>(&builder, 123.0f);
  XlaOp if_false = ConstantR0<float>(&builder, 42.0f);
  Select(predicate, if_true, if_false);

  const float expected = 42.0f;
  ComputeAndCompareR0<float>(&builder, expected, {}, error_spec_);
}
65
XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
  // Zero-element case: selecting between two empty F32 vectors using an
  // empty constant predicate vector must yield an empty vector.
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR1<bool>(&builder, {});
  XlaOp if_true = ConstantR1<float>(&builder, {});
  XlaOp if_false = ConstantR1<float>(&builder, {});
  Select(predicate, if_true, if_false);

  const std::vector<float> expected;  // Intentionally empty.
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
75
TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
  // Element-wise select driven by a constant predicate vector: lanes whose
  // predicate is true come from `if_true`, all others from `if_false`.
  XlaBuilder builder(TestName());
  XlaOp mask = ConstantR1<bool>(&builder, {false, true, false, true, false});
  XlaOp if_true =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp if_false =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(mask, if_true, if_false);

  // Lanes 1 and 3 are true (taken from if_true); the rest come from if_false.
  const std::vector<float> expected = {10.0f, 25.5f, 1.0f, -10.0f, -6.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
88
XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
  // Zero-element case again, but here the predicate is computed by Eq on two
  // empty S32 vectors rather than being supplied as a constant.
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, {});
  XlaOp rhs = ConstantR1<int32>(&builder, {});
  XlaOp equal = Eq(lhs, rhs);
  XlaOp if_true = ConstantR1<float>(&builder, {});
  XlaOp if_false = ConstantR1<float>(&builder, {});
  Select(equal, if_true, if_false);

  const std::vector<float> expected;  // Intentionally empty.
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
102
TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
  // Same data as SelectR1F32WithConstantR1PRED, but the mask is produced at
  // run time by Eq-comparing two S32 vectors.
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, {1, 2, 3, 4, 5});
  XlaOp rhs = ConstantR1<int32>(&builder, {9, 2, 9, 4, 9});
  XlaOp equal = Eq(lhs, rhs);  // {false, true, false, true, false}
  XlaOp if_true =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp if_false =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(equal, if_true, if_false);

  const std::vector<float> expected = {10.0f, 25.5f, 1.0f, -10.0f, -6.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
119
TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
  // Like SelectR1F32WithCmpR1S32s, but the mask comes from a Gt comparison
  // of two F32 vectors.
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
  XlaOp rhs = ConstantR1<float>(&builder, {-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
  XlaOp greater = Gt(lhs, rhs);  // {true, true, false, false, true}
  XlaOp if_true =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp if_false =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(greater, if_true, if_false);

  const std::vector<float> expected = {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
135
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
  // Both F32 operands are fed in as parameters. Each lane takes v1 where
  // v1 > v2 and v2 otherwise, i.e. Select(Gt(v1, v2), v1, v2).
  XlaBuilder builder(TestName());

  XlaOp v1;
  XlaOp v2;
  std::unique_ptr<GlobalData> v1_data = CreateR1Parameter<float>(
      {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
      /*builder=*/&builder, /*data_handle=*/&v1);
  std::unique_ptr<GlobalData> v2_data = CreateR1Parameter<float>(
      {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
      /*builder=*/&builder, /*data_handle=*/&v2);

  XlaOp v1_is_greater = Gt(v1, v2);
  Select(v1_is_greater, v1, v2);

  const std::vector<float> expected = {41.0f, 22.0f, 23.0f, 84.0f};
  ComputeAndCompareR1<float>(&builder, expected,
                             {v1_data.get(), v2_data.get()}, error_spec_);
}
155
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
  // Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the
  // data size passed in and out is large.
  XlaBuilder builder(TestName());

  // Number of floats in the data passed into and out of the computation.
  constexpr int datalen = 15 * 1000;

  // For each index i, one operand holds i and the other holds 2 * i. Both
  // values are exactly representable as float at this size. In the first
  // third of the data v1 receives the larger value; elsewhere v2 does. So
  // Select(Gt(v1, v2), v1, v2) must always produce 2 * i.
  // At i == 0 both operands are 0. Gt is false there, so v2 is selected,
  // which still equals the expected value; this also covers the tie case.
  std::vector<float> v1vec;
  std::vector<float> v2vec;
  std::vector<float> expected_vec;
  // Sizes are known up front; avoid repeated reallocation while filling.
  v1vec.reserve(datalen);
  v2vec.reserve(datalen);
  expected_vec.reserve(datalen);
  for (int i = 0; i < datalen; ++i) {
    const float smaller = static_cast<float>(i);
    const float larger = static_cast<float>(i * 2);
    if (i < datalen / 3) {
      v1vec.push_back(larger);
      v2vec.push_back(smaller);
    } else {
      v1vec.push_back(smaller);
      v2vec.push_back(larger);
    }
    expected_vec.push_back(larger);
  }

  XlaOp v1, v2;
  std::unique_ptr<GlobalData> param0_data =
      CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
                               /*builder=*/&builder, /*data_handle=*/&v1);
  std::unique_ptr<GlobalData> param1_data =
      CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
                               /*builder=*/&builder, /*data_handle=*/&v2);

  auto cmp = Gt(v1, v2);
  Select(cmp, v1, v2);
  ComputeAndCompareR1<float>(&builder, expected_vec,
                             {param0_data.get(), param1_data.get()},
                             error_spec_);
}
196
TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
  // Gt-compares an R1 S32 vector with an S32 scalar. The resulting R1 PRED
  // mask chooses between two R1 F32 vectors.
  XlaBuilder builder(TestName());
  XlaOp values = ConstantR1<int32>(&builder, {1, -1, 2, -2});
  XlaOp threshold = ConstantR0<int32>(&builder, 0);
  XlaOp positive = Gt(values, threshold);  // {true, false, true, false}

  XlaOp if_true = ConstantR1<float>(&builder, {11.0f, 22.0f, 33.0f, 44.0f});
  XlaOp if_false =
      ConstantR1<float>(&builder, {-111.0f, -222.0f, -333.0f, -444.0f});
  Select(positive, if_true, if_false);

  const std::vector<float> expected = {11.0f, -222.0f, 33.0f, -444.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
213
TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
  // Gt-compares an R1 F32 vector with an F32 scalar. The resulting R1 PRED
  // mask chooses between two R1 F32 vectors.
  XlaBuilder builder(TestName());
  XlaOp values = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f, 4.0f});
  XlaOp threshold = ConstantR0<float>(&builder, 2.5f);
  XlaOp above = Gt(values, threshold);  // {false, false, true, true}

  XlaOp if_true = ConstantR1<float>(&builder, {11.0f, 22.0f, 33.0f, 44.0f});
  XlaOp if_false =
      ConstantR1<float>(&builder, {-111.0f, -222.0f, -333.0f, -444.0f});
  Select(above, if_true, if_false);

  const std::vector<float> expected = {-111.0f, -222.0f, 33.0f, 44.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
230
XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
  // A scalar predicate applied to empty F32 vectors yields an empty vector
  // for either predicate value. Both values are exercised, each with a
  // freshly built computation.
  const std::vector<float> expected;  // Intentionally empty.
  for (const bool predicate_value : {false, true}) {
    XlaBuilder builder(TestName());
    XlaOp predicate = ConstantR0<bool>(&builder, predicate_value);
    XlaOp if_true = ConstantR1<float>(&builder, {});
    XlaOp if_false = ConstantR1<float>(&builder, {});
    Select(predicate, if_true, if_false);

    ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
  }
}
242
TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
  // A true scalar predicate selects the whole on-true vector.
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp if_true = ConstantR1<float>(&builder, {-2.5f, 25.5f});
  XlaOp if_false = ConstantR1<float>(&builder, {10.0f, 5.0f});
  Select(predicate, if_true, if_false);

  const std::vector<float> expected = {-2.5f, 25.5f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
252
TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
  // A false scalar predicate selects the whole on-false vector.
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp if_true = ConstantR1<float>(&builder, {-2.5f, 25.5f});
  XlaOp if_false = ConstantR1<float>(&builder, {10.0f, 5.0f});
  Select(predicate, if_true, if_false);

  const std::vector<float> expected = {10.0f, 5.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
262 } // namespace
263 } // namespace xla
264