1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "tensorflow/compiler/xla/client/global_data.h"
20 #include "tensorflow/compiler/xla/client/local_client.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
23 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
24 #include "tensorflow/compiler/xla/tests/test_macros.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/types.h"
27 
28 namespace xla {
29 namespace {
30 
31 class SelectTest : public ClientLibraryTestBase {
32  public:
33   ErrorSpec error_spec_{0.0001};
34 };
35 
// A scalar `true` predicate must forward the on-true F32 scalar.
TEST_F(SelectTest, SelectScalarF32True) {
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp if_true = ConstantR0<float>(&builder, 123.0f);
  XlaOp if_false = ConstantR0<float>(&builder, 42.0f);
  Select(predicate, if_true, if_false);

  const float expected = 123.0f;
  ComputeAndCompareR0<float>(&builder, expected, {}, error_spec_);
}
45 
// A scalar `true` predicate must forward the on-true S32 scalar. The integer
// comparison is called without an ErrorSpec.
TEST_F(SelectTest, SelectScalarS32True) {
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp if_true = ConstantR0<int32>(&builder, -42);
  XlaOp if_false = ConstantR0<int32>(&builder, 42);
  Select(predicate, if_true, if_false);

  const int32 expected = -42;
  ComputeAndCompareR0<int32>(&builder, expected, {});
}
55 
// A scalar `false` predicate must forward the on-false F32 scalar.
TEST_F(SelectTest, SelectScalarF32False) {
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp if_true = ConstantR0<float>(&builder, 123.0f);
  XlaOp if_false = ConstantR0<float>(&builder, 42.0f);
  Select(predicate, if_true, if_false);

  const float expected = 42.0f;
  ComputeAndCompareR0<float>(&builder, expected, {}, error_spec_);
}
65 
// Zero-element R1 predicate and zero-element R1 F32 operands: the select must
// produce an empty F32 vector.
XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
  XlaBuilder builder(TestName());
  XlaOp empty_mask = ConstantR1<bool>(&builder, {});
  XlaOp empty_if_true = ConstantR1<float>(&builder, {});
  XlaOp empty_if_false = ConstantR1<float>(&builder, {});
  Select(empty_mask, empty_if_true, empty_if_false);

  const std::vector<float> expected;
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
75 
// Element-wise select driven by a constant R1 predicate: each lane takes the
// on-true element where the mask is true and the on-false element otherwise.
TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
  XlaBuilder builder(TestName());
  XlaOp mask = ConstantR1<bool>(&builder, {false, true, false, true, false});
  XlaOp if_true =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp if_false =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(mask, if_true, if_false);

  // Lanes 1 and 3 come from if_true; lanes 0, 2 and 4 from if_false.
  const std::vector<float> expected = {10.0f, 25.5f, 1.0f, -10.0f, -6.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
88 
// Like SelectR1S0F32WithConstantR1S0PRED, but the (empty) predicate vector is
// computed by Eq-comparing two empty S32 vectors instead of being a constant.
XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, {});
  XlaOp rhs = ConstantR1<int32>(&builder, {});
  XlaOp mask = Eq(lhs, rhs);
  XlaOp if_true = ConstantR1<float>(&builder, {});
  XlaOp if_false = ConstantR1<float>(&builder, {});
  Select(mask, if_true, if_false);

  const std::vector<float> expected;
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
102 
// Like SelectR1F32WithConstantR1PRED, but the predicate vector is computed by
// Eq-comparing two S32 vectors instead of being a constant.
TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, {1, 2, 3, 4, 5});
  XlaOp rhs = ConstantR1<int32>(&builder, {9, 2, 9, 4, 9});
  // Equal only at lanes 1 and 3.
  XlaOp mask = Eq(lhs, rhs);
  XlaOp if_true =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp if_false =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(mask, if_true, if_false);

  const std::vector<float> expected = {10.0f, 25.5f, 1.0f, -10.0f, -6.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
119 
// Like SelectR1F32WithCmpR1S32s, but the mask comes from a Gt comparison of
// two R1 F32 vectors.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
  XlaOp rhs = ConstantR1<float>(&builder, {-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
  // lhs > rhs at lanes 0, 1 and 4.
  XlaOp mask = Gt(lhs, rhs);
  XlaOp if_true =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp if_false =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(mask, if_true, if_false);

  const std::vector<float> expected = {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
135 
// Both select operands are computation parameters. The mask is Gt(v1, v2), and
// the select picks v1 where it is greater and v2 otherwise, so for these
// NaN-free inputs the result is the element-wise larger value.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
  XlaBuilder builder(TestName());

  XlaOp lhs;
  XlaOp rhs;
  std::unique_ptr<GlobalData> lhs_data = CreateR1Parameter<float>(
      {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
      /*builder=*/&builder, /*data_handle=*/&lhs);
  std::unique_ptr<GlobalData> rhs_data = CreateR1Parameter<float>(
      {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
      /*builder=*/&builder, /*data_handle=*/&rhs);

  XlaOp lhs_is_greater = Gt(lhs, rhs);
  Select(lhs_is_greater, lhs, rhs);

  const std::vector<float> expected = {41.0f, 22.0f, 23.0f, 84.0f};
  ComputeAndCompareR1<float>(&builder, expected,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
155 
// Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the
// data size passed in and out is large.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
  XlaBuilder builder(TestName());

  // Number of floats in the data passed into and out of the computation.
  constexpr int datalen = 15 * 1000;

  // The inputs follow a pattern in which v1[i] >= v2[i] in the first third of
  // the data, and v1[i] <= v2[i] elsewhere. The inequality is strict for every
  // i > 0. At i == 0 both values are 0: Gt is false there, and Select picks
  // v2[0], which equals v1[0]. Either way the result is the larger value,
  // 2 * i.
  std::vector<float> v1vec;
  std::vector<float> v2vec;
  std::vector<float> expected_vec;
  // The final sizes are known up front; avoid repeated reallocation.
  v1vec.reserve(datalen);
  v2vec.reserve(datalen);
  expected_vec.reserve(datalen);
  for (int i = 0; i < datalen; ++i) {
    // Every value stays far below 2^24, so these int -> float conversions are
    // exact.
    const float smaller = static_cast<float>(i);
    const float larger = static_cast<float>(i * 2);
    if (i < datalen / 3) {
      v1vec.push_back(larger);
      v2vec.push_back(smaller);
    } else {
      v1vec.push_back(smaller);
      v2vec.push_back(larger);
    }
    expected_vec.push_back(larger);
  }

  XlaOp v1, v2;
  std::unique_ptr<GlobalData> param0_data =
      CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
                               /*builder=*/&builder, /*data_handle=*/&v1);
  std::unique_ptr<GlobalData> param1_data =
      CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
                               /*builder=*/&builder, /*data_handle=*/&v2);

  auto cmp = Gt(v1, v2);
  Select(cmp, v1, v2);
  ComputeAndCompareR1<float>(&builder, expected_vec,
                             {param0_data.get(), param1_data.get()},
                             error_spec_);
}
196 
// Gt-compares an R1 S32 vector against an S32 scalar, then uses the resulting
// R1 predicate to choose between two R1 F32 vectors.
TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
  XlaBuilder builder(TestName());
  XlaOp values = ConstantR1<int32>(&builder, {1, -1, 2, -2});
  XlaOp threshold = ConstantR0<int32>(&builder, 0);
  // Positive entries (lanes 0 and 2) produce true.
  XlaOp is_positive = Gt(values, threshold);

  XlaOp if_true = ConstantR1<float>(&builder, {11.0f, 22.0f, 33.0f, 44.0f});
  XlaOp if_false =
      ConstantR1<float>(&builder, {-111.0f, -222.0f, -333.0f, -444.0f});
  Select(is_positive, if_true, if_false);

  const std::vector<float> expected = {11.0f, -222.0f, 33.0f, -444.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
213 
// Gt-compares an R1 F32 vector against an F32 scalar, then uses the resulting
// R1 predicate to choose between two R1 F32 vectors.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
  XlaBuilder builder(TestName());
  XlaOp values = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f, 4.0f});
  XlaOp threshold = ConstantR0<float>(&builder, 2.5f);
  // Only lanes 2 and 3 exceed 2.5.
  XlaOp above_threshold = Gt(values, threshold);

  XlaOp if_true = ConstantR1<float>(&builder, {11.0f, 22.0f, 33.0f, 44.0f});
  XlaOp if_false =
      ConstantR1<float>(&builder, {-111.0f, -222.0f, -333.0f, -444.0f});
  Select(above_threshold, if_true, if_false);

  const std::vector<float> expected = {-111.0f, -222.0f, 33.0f, 44.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
230 
// A scalar predicate with zero-element R1 F32 operands must yield an empty
// vector, whichever value the predicate has.
XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
  for (int iteration = 0; iteration < 2; ++iteration) {
    // Iteration 0 checks `false`; iteration 1 checks `true`.
    const bool which = iteration != 0;
    XlaBuilder builder(TestName());
    XlaOp predicate = ConstantR0<bool>(&builder, which);
    XlaOp empty_if_true = ConstantR1<float>(&builder, {});
    XlaOp empty_if_false = ConstantR1<float>(&builder, {});
    Select(predicate, empty_if_true, empty_if_false);

    const std::vector<float> expected;
    ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
  }
}
242 
// A scalar `true` predicate selects the entire on-true R1 F32 vector.
TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, true);
  XlaOp if_true = ConstantR1<float>(&builder, {-2.5f, 25.5f});
  XlaOp if_false = ConstantR1<float>(&builder, {10.0f, 5.0f});
  Select(predicate, if_true, if_false);

  const std::vector<float> expected = {-2.5f, 25.5f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
252 
// A scalar `false` predicate selects the entire on-false R1 F32 vector.
TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
  XlaBuilder builder(TestName());
  XlaOp predicate = ConstantR0<bool>(&builder, false);
  XlaOp if_true = ConstantR1<float>(&builder, {-2.5f, 25.5f});
  XlaOp if_false = ConstantR1<float>(&builder, {10.0f, 5.0f});
  Select(predicate, if_true, if_false);

  const std::vector<float> expected = {10.0f, 5.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
262 }  // namespace
263 }  // namespace xla
264