1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cmath>
17 #include <limits>
18 #include <memory>
19 #include <numeric>
20 #include <vector>
21
22 #include "absl/base/casts.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/array2d.h"
25 #include "tensorflow/compiler/xla/array3d.h"
26 #include "tensorflow/compiler/xla/array4d.h"
27 #include "tensorflow/compiler/xla/client/global_data.h"
28 #include "tensorflow/compiler/xla/client/local_client.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
35 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
36 #include "tensorflow/compiler/xla/tests/test_macros.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/core/platform/types.h"
39
40 namespace xla {
41 namespace {
42
43 class ArrayElementwiseOpTest : public ClientLibraryTestBase {
44 public:
45 ErrorSpec error_spec_{0.0001, 0.0001};
46 };
47
// Value-parameterized flavour of ArrayElementwiseOpTest. The int parameter is
// read through GetParam(); AddManyValues uses it as the number of elements.
class ArrayElementwiseOpTestParamCount
    : public ArrayElementwiseOpTest,
      public ::testing::WithParamInterface<int> {};
51
// Negating an empty F32 vector yields an empty F32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
  XlaBuilder builder(TestName());
  Neg(ConstantR1<float>(&builder, {}));

  const std::vector<float> no_elements;
  ComputeAndCompareR1<float>(&builder, no_elements, {}, error_spec_);
}
59
// Neg flips the sign of every element of an F32 constant.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
  XlaBuilder builder(TestName());
  Neg(ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}));

  const std::vector<float> negated = {2.5f, -3.14f, -2.25f, 10.0f, -6.0f};
  ComputeAndCompareR1<float>(&builder, negated, {}, error_spec_);
}
68
// Neg on S32 values, including both ends of the representable range.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();

  XlaBuilder builder(TestName());
  Neg(ConstantR1<int32>(&builder, {-1, 0, 1, 324, kMin, kMax}));

  // Negating the minimum int32 overflows and wraps back to the minimum. That
  // computation is undefined behavior in C++, which is why the expectation
  // below names kMin directly instead of writing -kMin. XLA does not declare
  // it undefined, so the op is expected to work.
  ComputeAndCompareR1<int32>(&builder, {1, 0, -1, -324, kMin, -kMax}, {});
}
84
// Negating an empty C64 vector yields an empty C64 vector.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) {
  XlaBuilder builder(TestName());
  Neg(ConstantR1<complex64>(&builder, {}));

  const std::vector<complex64> no_elements;
  ComputeAndCompareR1<complex64>(&builder, no_elements, {}, error_spec_);
}
92
// Neg on C64 negates the real and the imaginary component of each element.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) {
  XlaBuilder builder(TestName());
  const std::vector<complex64> input = {
      {-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}};
  Neg(ConstantR1<complex64>(&builder, input));

  const std::vector<complex64> negated = {
      {2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}};
  ComputeAndCompareR1<complex64>(&builder, negated, {}, error_spec_);
}
103
// Neg on S64 values, including values that need more than 32 bits and the
// minimum int64 (bit pattern 0x8000000000000000), whose negation wraps back
// to itself.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) {
  XlaBuilder builder(TestName());
  auto a =
      ConstantR1<int64>(&builder, {
                                      -1,
                                      1,
                                      0,
                                      0x12345678,
                                      static_cast<int64>(0xffffffff12345678LL),
                                      static_cast<int64>(0x8000000000000000LL),
                                      static_cast<int64>(0x8000000000000001LL),
                                  });
  Neg(a);

  ComputeAndCompareR1<int64>(&builder,
                             {
                                 1,
                                 -1,
                                 0,
                                 -0x12345678,
                                 // Two's complement of 0xffffffff12345678.
                                 0xedcba988,
                                 // The minimum int64 negates to itself.
                                 static_cast<int64>(0x8000000000000000LL),
                                 -static_cast<int64>(0x8000000000000001LL),
                             },
                             {});
}
131
// IsFinite on an empty F32 vector produces an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
  XlaBuilder builder(TestName());
  IsFinite(ConstantR1<float>(&builder, {}));

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
139
// A non-canonical quiet NaN value: the bit pattern 0x7FD01234 has all
// exponent bits set, the top mantissa bit set (quiet), and additional
// non-zero payload bits in the mantissa. The IsFinite tests use it alongside
// the standard NAN macro.
static const float kNonCanonicalNaN = absl::bit_cast<float>(0x7FD01234);
142
// IsFinite on scalars: NaNs and infinities are not finite, zero is.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) {
  // Sanity-check the host-side constant before handing it to XLA.
  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));

  const float inf = std::numeric_limits<float>::infinity();

  struct Case {
    float input;
    bool is_finite;
  };
  const Case cases[] = {
      {NAN, false},  {kNonCanonicalNaN, false}, {inf, false},
      {-inf, false}, {0.0f, true},
  };

  // The same builder is reused: each iteration enqueues one op and then
  // builds and runs it.
  XlaBuilder builder(TestName());
  for (const Case& c : cases) {
    IsFinite(ConstantR0<float>(&builder, c.input));
    ComputeAndCompareR0<bool>(&builder, c.is_finite, {});
  }
}
162
// IsFinite applied element-wise to a vector mixing finite values, NaNs and
// infinities.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));

  const float inf = std::numeric_limits<float>::infinity();
  const std::vector<float> inputs = {NAN,   7.0f, kNonCanonicalNaN,
                                     -1.0f, inf,  -inf};

  XlaBuilder builder(TestName());
  IsFinite(ConstantR1<float>(&builder, inputs));

  ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false},
                            {});
}
174
// Element-wise sum of two F32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  auto rhs =
      ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Add(lhs, rhs);

  const std::vector<float> sums = {97.5f, 6.27f, 5.0f, 0.5f, -993.0f};
  ComputeAndCompareR1<float>(&builder, sums, {}, error_spec_);
}
184
// Adding two empty F32 vectors yields an empty F32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<float>(&builder, {});
  auto rhs = ConstantR1<float>(&builder, {});
  Add(lhs, rhs);

  const std::vector<float> no_elements;
  ComputeAndCompareR1<float>(&builder, no_elements, {}, error_spec_);
}
193
// Element-wise sum of two C64 constants.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}});
  auto rhs = ConstantR1<complex64>(
      &builder, {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}});
  Add(lhs, rhs);

  // The first sum is purely real; its zero imaginary part is spelled out.
  const std::vector<complex64> sums = {
      {97.5f, 0.0f}, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}};
  ComputeAndCompareR1<complex64>(&builder, sums, {}, error_spec_);
}
206
// Adding two empty C64 vectors yields an empty C64 vector.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<complex64>(&builder, {});
  auto rhs = ConstantR1<complex64>(&builder, {});
  Add(lhs, rhs);

  const std::vector<complex64> no_elements;
  ComputeAndCompareR1<complex64>(&builder, no_elements, {}, error_spec_);
}
215
// Adds two U64 parameter vectors whose values include operands around the
// 2^63 and 2^64 boundaries. The expected values are computed on the host
// with uint64 arithmetic, which wraps on overflow.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
  XlaBuilder b(TestName());

  std::vector<uint64> lhs{0xFFFFFFFF,
                          static_cast<uint64>(-1),
                          0,
                          0,
                          0x7FFFFFFFFFFFFFFFLL,
                          0x7FFFFFFFFFFFFFFLL,
                          0x8000000000000000LL,
                          0x8000000000000000LL,
                          1};
  Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  std::vector<uint64> rhs{1,
                          0x7FFFFFFFFFFFFFFLL,
                          0x7FFFFFFFFFFFFFFFLL,
                          0x8000000000000000LL,
                          0,
                          static_cast<uint64>(-1),
                          0,
                          1,
                          0x8000000000000000LL};
  Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  Add(lhs_param, rhs_param);

  // Index with an unsigned type so the loop bound comparison does not mix
  // signed and unsigned operands.
  std::vector<uint64> expected(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    expected[i] = lhs[i] + rhs[i];
  }

  ComputeAndCompareR1<uint64>(&b, expected, {lhs_data.get(), rhs_data.get()});
}
256
// Subtracts two S64 parameter vectors whose values sit at or near the ends
// of the int64 range. The expected values are computed on the host; the
// operand pairs listed here all produce differences that fit in int64.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
  XlaBuilder b(TestName());

  std::vector<int64> lhs{static_cast<int64>(0x8000000000000000LL),
                         static_cast<int64>(0x8000000000000000LL),
                         -1,
                         0x7FFFFFFFFFFFFFFLL,
                         0x7FFFFFFFFFFFFFFFLL,
                         1,
                         0,
                         -1};
  Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  std::vector<int64> rhs{-1,
                         0,
                         static_cast<int64>(0x8000000000000000LL),
                         1,
                         0,
                         0x7FFFFFFFFFFFFFFLL,
                         0x7FFFFFFFFFFFFFFFLL,
                         0x7FFFFFFFFFFFFFFFLL};
  Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  Sub(lhs_param, rhs_param);

  // Index with an unsigned type so the loop bound comparison does not mix
  // signed and unsigned operands.
  std::vector<int64> expected(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    expected[i] = lhs[i] - rhs[i];
  }

  ComputeAndCompareR1<int64>(&b, expected, {lhs_data.get(), rhs_data.get()});
}
295
// Less-than on U64 operands on either side of 2^63 (0x8000... versus
// 0x7FFF...). Only inputs are supplied to ComputeAndCompare here; no expected
// literal is passed.
XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
  XlaBuilder b(TestName());

  const std::vector<uint64> lhs_values{0x8000000000000000ULL};
  Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs_values});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");

  const std::vector<uint64> rhs_values{0x7FFFFFFFFFFFFFFFULL};
  Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs_values});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");

  Lt(lhs_param, rhs_param);

  // The literals are moved only after their shapes were read above.
  ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
}
311
// Adds two `count`-element F32 vectors in all four constant/parameter operand
// combinations and checks that the accumulated result equals 4 * (a + b).
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());
  std::vector<float> a_values;
  std::vector<float> b_values;
  for (int i = 0; i < count; ++i) {
    a_values.push_back(i / static_cast<float>(count));
    b_values.push_back(2 * i / static_cast<float>(count + 2));
  }

  Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto a_constant = ConstantR1<float>(&builder, a_values);
  auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");

  Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
  std::unique_ptr<GlobalData> b_data =
      client_->TransferToServer(b_literal).ConsumeValueOrDie();
  // Declare the parameter with its own literal's shape. Both vectors have
  // `count` elements, so the shape is the same as a_literal's.
  auto b_param = Parameter(&builder, 1, b_literal.shape(), "b_param");
  auto b_constant = ConstantR1<float>(&builder, b_values);

  auto sum1 = Add(a_constant, b_param);
  auto sum2 = Add(a_constant, b_constant);
  auto sum3 = Add(a_param, b_param);
  auto sum4 = Add(a_param, b_constant);

  auto sum = Add(sum1, sum2);
  sum = Add(sum, sum3);
  sum = Add(sum, sum4);

  // Each of the four partial sums equals a + b.
  std::vector<float> expected;
  for (int i = 0; i < count; ++i) {
    expected.push_back(4 * (a_values[i] + b_values[i]));
  }

  ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
                             error_spec_);
}
351
// Builds a long chain of slice/slice/add gadgets on top of an initial add of
// two 30-element parameters and checks the single remaining element.
XLA_TEST_F(ArrayElementwiseOpTest, DeeplyNestedAddWithSlices) {
  XlaBuilder builder(TestName());
  std::vector<float> values(30, 0.0);
  auto lhs_literal = LiteralUtil::CreateR1<float>(values);
  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "x");
  auto rhs_literal = LiteralUtil::CreateR1<float>(values);
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "x");

  // Each step takes the previous add, slices it twice and adds the slices:
  // the "left" slice drops the last element and the "right" slice drops the
  // first one, so every step shortens the vector by one element. Because the
  // two slices index into the add with different multi-dimensional index
  // arrays, this defeats the caching used to avoid exponential compile time.
  //
  // The ops are enqueued from the full-size add down to the one-element add,
  // which is the last op and therefore the result that is compared.
  XlaOp current = Add(lhs, rhs);
  for (int64 slice_size = static_cast<int64>(values.size()) - 1;
       slice_size >= 1; --slice_size) {
    auto left = Slice(current, {0}, {slice_size}, {1});
    auto right = Slice(current, {1}, {slice_size + 1}, {1});
    current = Add(left, right);
  }

  auto lhs_data = client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
  auto rhs_data = client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
  ComputeAndCompareR1<float>(&builder, {0.0}, {lhs_data.get(), rhs_data.get()});
}
387
// Element-wise difference of two F32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
  XlaBuilder builder(TestName());
  auto minuend =
      ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  auto subtrahend =
      ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Sub(minuend, subtrahend);

  const std::vector<float> differences = {-102.5f, 0.01f, -0.5f, -20.5f,
                                          1005.0f};
  ComputeAndCompareR1<float>(&builder, differences, {}, error_spec_);
}
397
// Subtracting two empty F32 vectors yields an empty F32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto minuend = ConstantR1<float>(&builder, {});
  auto subtrahend = ConstantR1<float>(&builder, {});
  Sub(minuend, subtrahend);

  const std::vector<float> no_elements;
  ComputeAndCompareR1<float>(&builder, no_elements, {}, error_spec_);
}
406
// Element-wise difference of two S32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
  XlaBuilder builder(TestName());
  auto minuend = ConstantR1<int32>(&builder, {-1, 0, 2, 1000000000});
  auto subtrahend = ConstantR1<int32>(&builder, {-1, 2, 1, -1});
  Sub(minuend, subtrahend);

  const std::vector<int32> differences = {0, -2, 1, 1000000001};
  ComputeAndCompareR1<int32>(&builder, differences, {});
}
415
// Subtracting two empty S32 vectors yields an empty S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
  XlaBuilder builder(TestName());
  auto minuend = ConstantR1<int32>(&builder, {});
  auto subtrahend = ConstantR1<int32>(&builder, {});
  Sub(minuend, subtrahend);

  const std::vector<int32> no_elements;
  ComputeAndCompareR1<int32>(&builder, no_elements, {});
}
424
// Element-wise difference of two C64 constants.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
  XlaBuilder builder(TestName());
  auto minuend = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}});
  auto subtrahend = ConstantR1<complex64>(
      &builder, {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}});
  Sub(minuend, subtrahend);

  const std::vector<complex64> differences = {
      {-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}};
  ComputeAndCompareR1<complex64>(&builder, differences, {}, error_spec_);
}
437
// Subtracting two empty C64 vectors yields an empty C64 vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  auto minuend = ConstantR1<complex64>(&builder, {});
  auto subtrahend = ConstantR1<complex64>(&builder, {});
  Sub(minuend, subtrahend);

  const std::vector<complex64> no_elements;
  ComputeAndCompareR1<complex64>(&builder, no_elements, {}, error_spec_);
}
446
// Element-wise quotient of two F32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
  XlaBuilder builder(TestName());
  auto numerator =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  auto denominator =
      ConstantR1<float>(&builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
  Div(numerator, denominator);

  const std::vector<float> quotients = {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f};
  ComputeAndCompareR1<float>(&builder, quotients, {}, error_spec_);
}
456
// Dividing two empty F32 vectors yields an empty F32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto numerator = ConstantR1<float>(&builder, {});
  auto denominator = ConstantR1<float>(&builder, {});
  Div(numerator, denominator);

  const std::vector<float> no_elements;
  ComputeAndCompareR1<float>(&builder, no_elements, {}, error_spec_);
}
465
466 class IntegerDivideOpTest : public ArrayElementwiseOpTest {
467 protected:
468 template <typename T>
TestDivRem(absl::Span<const T> dividends,absl::Span<const T> divisors,absl::Span<const T> quotients,absl::Span<const T> remainders)469 void TestDivRem(absl::Span<const T> dividends, absl::Span<const T> divisors,
470 absl::Span<const T> quotients,
471 absl::Span<const T> remainders) {
472 {
473 XlaBuilder builder(TestName());
474 XlaOp dividend;
475 XlaOp divisor;
476 auto dividend_data =
477 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
478 auto divisor_data =
479 CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
480 Div(dividend, divisor);
481
482 ComputeAndCompareR1<T>(&builder, quotients,
483 {dividend_data.get(), divisor_data.get()});
484 }
485
486 // Test with a compile-time constant divisor.
487 {
488 XlaBuilder builder(TestName());
489 XlaOp dividend;
490 auto dividend_data =
491 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
492 Div(dividend, ConstantR1<T>(&builder, divisors));
493
494 ComputeAndCompareR1<T>(&builder, quotients, {dividend_data.get()});
495 }
496
497 {
498 XlaBuilder builder(TestName());
499 XlaOp dividend;
500 XlaOp divisor;
501 auto dividend_data =
502 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
503 auto divisor_data =
504 CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
505 Rem(dividend, divisor);
506
507 ComputeAndCompareR1<T>(&builder, remainders,
508 {dividend_data.get(), divisor_data.get()});
509 }
510
511 // Test with a compile-time constant divisor.
512 {
513 XlaBuilder builder(TestName());
514 XlaOp dividend;
515 auto dividend_data =
516 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
517 Rem(dividend, ConstantR1<T>(&builder, divisors));
518
519 ComputeAndCompareR1<T>(&builder, remainders, {dividend_data.get()});
520 }
521 }
522 };
523
// Signed 32-bit Div/Rem over the cross product of a set of interesting
// values, with expectations computed by host integer division.
XLA_TEST_F(IntegerDivideOpTest, DivS32s) {
  // clang-format off
  // Some interesting values to test.
  const std::vector<int32> vals = {
    INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff,
    -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101,
    7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX};
  // clang-format on

  std::vector<int32> dividends;
  std::vector<int32> divisors;
  std::vector<int32> quotients;
  std::vector<int32> remainders;
  for (const int32 divisor : vals) {
    // Division by zero cannot be evaluated on the host.
    if (divisor == 0) {
      continue;
    }
    for (const int32 dividend : vals) {
      // Avoid integer overflow.
      if (dividend == INT32_MIN && divisor == -1) {
        continue;
      }
      dividends.push_back(dividend);
      divisors.push_back(divisor);
      quotients.push_back(dividend / divisor);
      remainders.push_back(dividend % divisor);
    }
  }

  TestDivRem<int32>(dividends, divisors, quotients, remainders);
}
550
// Signed Div/Rem in the two cases host C++ cannot evaluate: division by zero
// (5 / 0) and INT32_MIN / -1. The expected values are hard-coded.
XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) {
  const std::vector<int32> dividends = {5, INT32_MIN};
  const std::vector<int32> divisors = {0, -1};
  const std::vector<int32> quotients = {-1, INT32_MIN};
  const std::vector<int32> remainders = {5, 0};

  TestDivRem<int32>(dividends, divisors, quotients, remainders);
}
557
// Unsigned 32-bit Div/Rem over the cross product of a set of interesting
// values, with expectations computed by host integer division.
XLA_TEST_F(IntegerDivideOpTest, DivU32s) {
  // clang-format off
  // Some interesting values to test.
  const std::vector<uint32> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000,
    0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  std::vector<uint32> dividends;
  std::vector<uint32> divisors;
  std::vector<uint32> quotients;
  std::vector<uint32> remainders;
  for (const uint32 divisor : vals) {
    // Division by zero cannot be evaluated on the host.
    if (divisor == 0) {
      continue;
    }
    for (const uint32 dividend : vals) {
      dividends.push_back(dividend);
      divisors.push_back(divisor);
      quotients.push_back(dividend / divisor);
      remainders.push_back(dividend % divisor);
    }
  }

  TestDivRem<uint32>(dividends, divisors, quotients, remainders);
}
580
// Unsigned Div/Rem by zero. This test runs on uint32 so that it exercises the
// unsigned code path rather than repeating the signed 5 / 0 case from
// SignedOverflow. The expected quotient is the all-ones bit pattern, i.e. the
// unsigned counterpart of the -1 expected by the signed test, and the
// expected remainder is the dividend itself.
XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) {
  const std::vector<uint32> dividends = {5};
  const std::vector<uint32> divisors = {0};
  const std::vector<uint32> quotients = {UINT32_MAX};
  const std::vector<uint32> remainders = {5};

  TestDivRem<uint32>(dividends, divisors, quotients, remainders);
}
587
// Element-wise quotient of two C64 constants: division by a real number, by
// a purely imaginary number, and of a value by itself.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
  XlaBuilder builder(TestName());
  auto numerator = ConstantR1<complex64>(
      &builder, {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}});
  auto denominator = ConstantR1<complex64>(
      &builder, {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}});
  Div(numerator, denominator);

  const std::vector<complex64> quotients = {
      {-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}};
  ComputeAndCompareR1<complex64>(&builder, quotients, {}, error_spec_);
}
599
// Dividing two empty C64 vectors yields an empty C64 vector.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  auto numerator = ConstantR1<complex64>(&builder, {});
  auto denominator = ConstantR1<complex64>(&builder, {});
  Div(numerator, denominator);

  const std::vector<complex64> no_elements;
  ComputeAndCompareR1<complex64>(&builder, no_elements, {}, error_spec_);
}
608
// Element-wise remainder of F32 constants. In the expected values the
// remainder carries the sign of the dividend.
XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
  XlaBuilder builder(TestName());
  auto dividend = ConstantR1<float>(
      &builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f});
  auto divisor = ConstantR1<float>(
      &builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f});
  Rem(dividend, divisor);

  const std::vector<float> remainders = {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f,
                                         1.0f,  1.0f, -1.0f, -0.0f};
  ComputeAndCompareR1<float>(&builder, remainders, {}, error_spec_);
}
621
// Rem on two empty F32 vectors yields an empty F32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto dividend = ConstantR1<float>(&builder, {});
  auto divisor = ConstantR1<float>(&builder, {});
  Rem(dividend, divisor);

  const std::vector<float> no_elements;
  ComputeAndCompareR1<float>(&builder, no_elements, {}, error_spec_);
}
630
// Element-wise remainder of F64 constants, using the same operand values as
// RemF32s.
XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
  XlaBuilder builder(TestName());
  auto dividend = ConstantR1<double>(
      &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0});
  auto divisor = ConstantR1<double>(
      &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0});
  Rem(dividend, divisor);

  const std::vector<double> remainders = {-2.5, 0.0, 0.25, 0.0, -0.0,
                                          1.0,  1.0, -1.0, -0.0};
  ComputeAndCompareR1<double>(&builder, remainders, {}, error_spec_);
}
643
// Element-wise product of two F32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Mul(lhs, rhs);

  const std::vector<float> products = {-25.0f, 127.5f, 2.25f, -100.0f,
                                       -36.0f};
  ComputeAndCompareR1<float>(&builder, products, {}, error_spec_);
}
653
// Multiplying two empty F32 vectors yields an empty F32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<float>(&builder, {});
  auto rhs = ConstantR1<float>(&builder, {});
  Mul(lhs, rhs);

  const std::vector<float> no_elements;
  ComputeAndCompareR1<float>(&builder, no_elements, {}, error_spec_);
}
662
// S32 multiplication over every ordered pair drawn from a small value set,
// including the extremes of the int32 range.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
  const std::vector<int32> values = {0,
                                     1,
                                     -1,
                                     1234,
                                     0x1a243514,
                                     std::numeric_limits<int32>::max(),
                                     std::numeric_limits<int32>::min()};

  // Form the test data set using all products of 'values' with itself.
  std::vector<int32> lhs_values;
  std::vector<int32> rhs_values;
  std::vector<int32> products;
  for (const int32 x : values) {
    for (const int32 y : values) {
      lhs_values.push_back(x);
      rhs_values.push_back(y);
      // Multiply in uint32 so that the host computation wraps rather than
      // overflowing a signed type.
      const uint32 wrapped = static_cast<uint32>(x) * static_cast<uint32>(y);
      products.push_back(wrapped);
    }
  }

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, lhs_values);
  auto rhs = ConstantR1<int32>(&builder, rhs_values);
  Mul(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, products, {});
}
688
// Multiplying two empty S32 vectors yields an empty S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, {});
  auto rhs = ConstantR1<int32>(&builder, {});
  Mul(lhs, rhs);

  const std::vector<int32> no_elements;
  ComputeAndCompareR1<int32>(&builder, no_elements, {});
}
697
// U32 multiplication over every ordered pair drawn from a small value set.
// The expected products come from host uint32 arithmetic, which wraps.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
  const std::vector<uint32> values = {0,          1,          0xDEADBEEF, 1234,
                                      0x1a243514, 0xFFFFFFFF, 0x80808080};

  // Form the test data set using all products of 'values' with itself.
  std::vector<uint32> lhs_values;
  std::vector<uint32> rhs_values;
  std::vector<uint32> products;
  for (const uint32 x : values) {
    for (const uint32 y : values) {
      lhs_values.push_back(x);
      rhs_values.push_back(y);
      products.push_back(x * y);
    }
  }

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32>(&builder, lhs_values);
  auto rhs = ConstantR1<uint32>(&builder, rhs_values);
  Mul(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, products, {});
}
719
// Element-wise product of two C64 constants.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}});
  auto rhs = ConstantR1<complex64>(
      &builder, {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}});
  Mul(lhs, rhs);

  const std::vector<complex64> products = {
      {0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0f}};
  ComputeAndCompareR1<complex64>(&builder, products, {}, error_spec_);
}
732
// Multiplying two empty C64 vectors yields an empty C64 vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<complex64>(&builder, {});
  auto rhs = ConstantR1<complex64>(&builder, {});
  Mul(lhs, rhs);

  const std::vector<complex64> no_elements;
  ComputeAndCompareR1<complex64>(&builder, no_elements, {}, error_spec_);
}
741
// Logical AND over the full two-input truth table, as a rank-1 PRED.
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  auto rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  And(lhs, rhs);

  // Only the last pair, true AND true, yields true.
  ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
}
750
// Logical AND over the full two-input truth table, as a 2x2 PRED.
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  auto rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  And(lhs, rhs);

  Array2D<bool> conjunction({{false, false}, {false, true}});
  ComputeAndCompareR2<bool>(&builder, conjunction, {});
}
760
// AND of two empty PRED vectors yields an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<bool>(&builder, {});
  auto rhs = ConstantR1<bool>(&builder, {});
  And(lhs, rhs);

  // Nothing to compare element-wise; the result must simply be empty.
  ComputeAndCompareR1<bool>(&builder, {}, {});
}
769
// Bitwise AND on S32 vectors, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, {0, -1, -8});
  auto rhs = ConstantR1<int32>(&builder, {5, -7, 12});
  And(lhs, rhs);

  const std::vector<int32> masked = {0, -7, 8};
  ComputeAndCompareR1<int32>(&builder, masked, {});
}
778
// Bitwise AND on 2x2 S32 arrays, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<int32>(&builder, {{0, -5}, {-1, 5}});
  auto rhs = ConstantR2<int32>(&builder, {{1, -6}, {4, 5}});
  And(lhs, rhs);

  Array2D<int32> masked({{0, -6}, {4, 5}});
  ComputeAndCompareR2<int32>(&builder, masked, {});
}
788
// AND of two empty S32 vectors yields an empty S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, {});
  auto rhs = ConstantR1<int32>(&builder, {});
  And(lhs, rhs);

  const std::vector<int32> no_elements;
  ComputeAndCompareR1<int32>(&builder, no_elements, {});
}
797
// Bitwise AND on U32 vectors. The operands and the expectation are declared
// as uint32 so that the test covers the unsigned element type named in its
// title, in line with AndU32R2 and OrU32R1.
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<uint32>(&builder, {0, 1, 8});
  auto b = ConstantR1<uint32>(&builder, {5, 7, 12});
  And(a, b);

  ComputeAndCompareR1<uint32>(&builder, {0, 1, 8}, {});
}
806
// Bitwise AND on 2x2 U32 arrays.
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<uint32>(&builder, {{0, 1}, {3, 8}});
  auto rhs = ConstantR2<uint32>(&builder, {{1, 0}, {7, 6}});
  And(lhs, rhs);

  Array2D<uint32> masked({{0, 0}, {3, 0}});
  ComputeAndCompareR2<uint32>(&builder, masked, {});
}
816
// AND of two empty U32 vectors yields an empty U32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32>(&builder, {});
  auto rhs = ConstantR1<uint32>(&builder, {});
  And(lhs, rhs);

  const std::vector<uint32> no_elements;
  ComputeAndCompareR1<uint32>(&builder, no_elements, {});
}
825
// Logical OR over the full two-input truth table, as a rank-1 PRED.
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  auto rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  Or(lhs, rhs);

  // Only the first pair, false OR false, yields false.
  ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {});
}
834
// Logical OR over the full two-input truth table, as a 2x2 PRED.
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  auto rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Or(lhs, rhs);

  Array2D<bool> disjunction({{false, true}, {true, true}});
  ComputeAndCompareR2<bool>(&builder, disjunction, {});
}
844
// OR of two empty PRED vectors yields an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<bool>(&builder, {});
  auto rhs = ConstantR1<bool>(&builder, {});
  Or(lhs, rhs);

  // Nothing to compare element-wise; the result must simply be empty.
  ComputeAndCompareR1<bool>(&builder, {}, {});
}
853
// Bitwise OR on S32 vectors, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, {0, -1, 8});
  auto rhs = ConstantR1<int32>(&builder, {5, -7, 4});
  Or(lhs, rhs);

  const std::vector<int32> combined = {5, -1, 12};
  ComputeAndCompareR1<int32>(&builder, combined, {});
}
862
// Bitwise OR on 2x2 S32 arrays, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<int32>(&builder, {{0, -1}, {8, 8}});
  auto rhs = ConstantR2<int32>(&builder, {{5, -7}, {4, 1}});
  Or(lhs, rhs);

  Array2D<int32> combined({{5, -1}, {12, 9}});
  ComputeAndCompareR2<int32>(&builder, combined, {});
}
872
// Or of two empty rank-1 S32 constants must produce an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR1<int32>(&b, {});
  const XlaOp rhs = ConstantR1<int32>(&b, {});
  Or(lhs, rhs);

  ComputeAndCompareR1<int32>(&b, {}, {});
}
881
// Bitwise Or on rank-1 U32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&b, {0, 1, 8});
  const XlaOp rhs = ConstantR1<uint32>(&b, {5, 7, 4});
  Or(lhs, rhs);

  // 0|5 == 5, 1|7 == 7, 8|4 == 12.
  ComputeAndCompareR1<uint32>(&b, {5, 7, 12}, {});
}
890
// Bitwise Or on 2x2 U32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR2<uint32>(&b, {{0, 1}, {8, 8}});
  const XlaOp rhs = ConstantR2<uint32>(&b, {{5, 7}, {4, 1}});
  Or(lhs, rhs);

  const Array2D<uint32> want({{5, 7}, {12, 9}});
  ComputeAndCompareR2<uint32>(&b, want, {});
}
900
// Or of two empty rank-1 U32 constants must produce an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&b, {});
  const XlaOp rhs = ConstantR1<uint32>(&b, {});
  Or(lhs, rhs);

  ComputeAndCompareR1<uint32>(&b, {}, {});
}
909
// Element-wise Xor of two rank-1 PRED constants; the four lanes cover every
// combination of the two boolean inputs.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR1<bool>(&b, {false, false, true, true});
  const XlaOp rhs = ConstantR1<bool>(&b, {false, true, false, true});
  Xor(lhs, rhs);

  ComputeAndCompareR1<bool>(&b, {false, true, true, false}, {});
}
918
// Element-wise Xor of two 2x2 PRED constants covering every combination of
// the two boolean inputs.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR2<bool>(&b, {{false, false}, {true, true}});
  const XlaOp rhs = ConstantR2<bool>(&b, {{false, true}, {false, true}});
  Xor(lhs, rhs);

  const Array2D<bool> want({{false, true}, {true, false}});
  ComputeAndCompareR2<bool>(&b, want, {});
}
928
// Xor of two empty rank-1 PRED constants must produce an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR1<bool>(&b, {});
  const XlaOp rhs = ConstantR1<bool>(&b, {});
  Xor(lhs, rhs);

  ComputeAndCompareR1<bool>(&b, {}, {});
}
937
// Bitwise Xor on rank-1 S32 constants, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR1<int32>(&b, {0, -1, 8});
  const XlaOp rhs = ConstantR1<int32>(&b, {5, -7, 4});
  Xor(lhs, rhs);

  // 0^5 == 5, -1^-7 == 6, 8^4 == 12.
  ComputeAndCompareR1<int32>(&b, {5, 6, 12}, {});
}
946
// Bitwise Xor on 2x2 S32 constants, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR2<int32>(&b, {{0, -1}, {8, 8}});
  const XlaOp rhs = ConstantR2<int32>(&b, {{5, -7}, {4, 1}});
  Xor(lhs, rhs);

  const Array2D<int32> want({{5, 6}, {12, 9}});
  ComputeAndCompareR2<int32>(&b, want, {});
}
956
// Xor of two empty rank-1 S32 constants must produce an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR1<int32>(&b, {});
  const XlaOp rhs = ConstantR1<int32>(&b, {});
  Xor(lhs, rhs);

  ComputeAndCompareR1<int32>(&b, {}, {});
}
965
// Bitwise Xor on rank-1 U32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&b, {0, 1, 8});
  const XlaOp rhs = ConstantR1<uint32>(&b, {5, 7, 4});
  Xor(lhs, rhs);

  // 0^5 == 5, 1^7 == 6, 8^4 == 12.
  ComputeAndCompareR1<uint32>(&b, {5, 6, 12}, {});
}
974
// Bitwise Xor on 2x2 U32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR2<uint32>(&b, {{0, 1}, {8, 8}});
  const XlaOp rhs = ConstantR2<uint32>(&b, {{5, 7}, {4, 1}});
  Xor(lhs, rhs);

  const Array2D<uint32> want({{5, 6}, {12, 9}});
  ComputeAndCompareR2<uint32>(&b, want, {});
}
984
// Xor of two empty rank-1 U32 constants must produce an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&b, {});
  const XlaOp rhs = ConstantR1<uint32>(&b, {});
  Xor(lhs, rhs);

  ComputeAndCompareR1<uint32>(&b, {}, {});
}
// Logical Not on a rank-1 PRED constant flips every element.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) {
  XlaBuilder b(TestName());
  const XlaOp input = ConstantR1<bool>(&b, {false, true, true, false});
  Not(input);

  ComputeAndCompareR1<bool>(&b, {true, false, false, true}, {});
}
1000
// Logical Not on a 2x2 PRED constant flips every element.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
  XlaBuilder b(TestName());
  const XlaOp input = ConstantR2<bool>(&b, {{false, true}, {true, false}});
  Not(input);

  const Array2D<bool> want({{true, false}, {false, true}});
  ComputeAndCompareR2<bool>(&b, want, {});
}
1009
// Not of an empty rank-1 PRED constant must produce an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) {
  XlaBuilder b(TestName());
  const XlaOp input = ConstantR1<bool>(&b, {});
  Not(input);

  ComputeAndCompareR1<bool>(&b, {}, {});
}
1017
// Bitwise Not on rank-1 S32: each expected value equals -x - 1.
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) {
  XlaBuilder b(TestName());
  const XlaOp input = ConstantR1<int32>(&b, {-1, 0, 1});
  Not(input);

  ComputeAndCompareR1<int32>(&b, {0, -1, -2}, {});
}
1025
// Bitwise Not on a 2x2 S32 constant: each expected value equals -x - 1.
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
  XlaBuilder b(TestName());
  const XlaOp input = ConstantR2<int32>(&b, {{-1, 0}, {1, 8}});
  Not(input);

  const Array2D<int32> want({{0, -1}, {-2, -9}});
  ComputeAndCompareR2<int32>(&b, want, {});
}
1034
// Not of an empty rank-1 S32 constant must produce an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) {
  XlaBuilder b(TestName());
  const XlaOp input = ConstantR1<int32>(&b, {});
  Not(input);

  ComputeAndCompareR1<int32>(&b, {}, {});
}
1042
// Bitwise Not on rank-1 U32: 0 and 4294967295 (all 32 bits set) map to each
// other.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) {
  XlaBuilder b(TestName());
  const XlaOp input = ConstantR1<uint32>(&b, {0, 4294967295});
  Not(input);

  ComputeAndCompareR1<uint32>(&b, {4294967295, 0}, {});
}
1050
// Bitwise Not on a 2x2 U32 constant; inputs sit at both ends of the 32-bit
// unsigned range.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
  XlaBuilder b(TestName());
  const XlaOp input =
      ConstantR2<uint32>(&b, {{0, 4294967295}, {1, 4294967294}});
  Not(input);

  const Array2D<uint32> want({{4294967295, 0}, {4294967294, 1}});
  ComputeAndCompareR2<uint32>(&b, want, {});
}
1059
// Not of an empty rank-1 U32 constant must produce an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
  XlaBuilder b(TestName());
  const XlaOp input = ConstantR1<uint32>(&b, {});
  Not(input);

  ComputeAndCompareR1<uint32>(&b, {}, {});
}
1067
// ShiftLeft on S32. The last three lanes use shift amounts outside [0, 32)
// (32, 100 and -1) and are expected to yield 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
  XlaBuilder b(TestName());
  const XlaOp values = ConstantR1<int32>(
      &b, {static_cast<int32>(0x12345678), static_cast<int32>(0xF0001000), 1,
           3, 77, 1, -3, 77});
  const XlaOp amounts = ConstantR1<int32>(&b, {4, 8, 2, 7, 15, 32, 100, -1});
  ShiftLeft(values, amounts);

  const std::vector<int32> want = {
      static_cast<int32>(0x23456780), 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0};
  ComputeAndCompareR1<int32>(&b, want, {});
}
1081
// ShiftRightArithmetic on S32. For the out-of-range shift amounts (32, 100,
// -1) the expected result is 0 for non-negative inputs and -1 for the
// negative input, i.e. the sign bit fills the whole word.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
  XlaBuilder b(TestName());
  const XlaOp values = ConstantR1<int32>(
      &b, {static_cast<int32>(0x92345678), static_cast<int32>(0x10001000), 1,
           3, 77, 1, -3, 77});
  const XlaOp amounts = ConstantR1<int32>(&b, {4, 8, 2, 7, 2, 32, 100, -1});
  ShiftRightArithmetic(values, amounts);

  const std::vector<int32> want = {static_cast<int32>(0xF9234567),
                                   static_cast<int32>(0x00100010),
                                   0,
                                   0,
                                   19,
                                   0,
                                   -1,
                                   0};
  ComputeAndCompareR1<int32>(&b, want, {});
}
1096
// ShiftRightLogical on S32: zeros are shifted in even for negative inputs,
// and the out-of-range shift amounts (32, 100, -1) are expected to yield 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
  XlaBuilder b(TestName());
  const XlaOp values = ConstantR1<int32>(
      &b, {static_cast<int32>(0x92345678), static_cast<int32>(0x10001000), 1,
           3, 77, 1, -3, 77});
  const XlaOp amounts = ConstantR1<int32>(&b, {4, 8, 2, 7, 5, 32, 100, -1});
  ShiftRightLogical(values, amounts);

  const std::vector<int32> want = {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0};
  ComputeAndCompareR1<int32>(&b, want, {});
}
1108
// ShiftLeft on U32. The last three lanes use shift amounts of 32 or more
// (32, 100 and ~0u) and are expected to yield 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
  XlaBuilder b(TestName());
  const XlaOp values = ConstantR1<uint32>(
      &b, {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77});
  const XlaOp amounts = ConstantR1<uint32>(&b, {4, 8, 2, 7, 15, 32, 100, ~0u});
  ShiftLeft(values, amounts);

  const std::vector<uint32> want = {0x23456780, 0x00100000, 0x4, 0x180,
                                    2523136,    0,          0,   0};
  ComputeAndCompareR1<uint32>(&b, want, {});
}
1119
// ShiftRightArithmetic on U32. The expected values show the top bit being
// replicated even though the element type is unsigned: 0x92345678 >> 4 is
// 0xF9234567, and ~3u shifted by 100 is ~0u.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
  XlaBuilder b(TestName());
  const XlaOp values = ConstantR1<uint32>(
      &b, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  const XlaOp amounts = ConstantR1<uint32>(&b, {4, 8, 2, 7, 2, 32, 100, ~0u});
  ShiftRightArithmetic(values, amounts);

  const std::vector<uint32> want = {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u,
                                    0};
  ComputeAndCompareR1<uint32>(&b, want, {});
}
1130
// ShiftRightLogical on U32; shift amounts of 32 or more (32, 100, ~0u) are
// expected to yield 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
  XlaBuilder b(TestName());
  const XlaOp values = ConstantR1<uint32>(
      &b, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  const XlaOp amounts = ConstantR1<uint32>(&b, {4, 8, 2, 7, 5, 32, 100, ~0u});
  ShiftRightLogical(values, amounts);

  const std::vector<uint32> want = {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0};
  ComputeAndCompareR1<uint32>(&b, want, {});
}
1141
// Eq on F32 vectors. A NaN on either side is expected to compare unequal.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
  // Fast-math is disabled because the inputs contain NaNs.
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  const XlaOp left = ConstantR1<float>(&b, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const XlaOp right = ConstantR1<float>(&b, {10.0f, 5.0f, 2.25f, 10.0f, NAN});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&b, {false, false, true, false, false}, {});
}
1151
// Eq of two empty F32 vectors must produce an empty PRED result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
  XlaBuilder b(TestName());
  const XlaOp left = ConstantR1<float>(&b, {});
  const XlaOp right = ConstantR1<float>(&b, {});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&b, {}, {});
}
1160
// Ge on F32 vectors. Lanes with a NaN operand are expected to be false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
  // Fast-math is disabled because the inputs contain NaNs.
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  const XlaOp left = ConstantR1<float>(&b, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const XlaOp right = ConstantR1<float>(&b, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Ge(left, right);

  ComputeAndCompareR1<bool>(&b, {false, true, true, false, false}, {});
}
1170
// Gt on F32 vectors. Lanes with a NaN operand are expected to be false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
  // Fast-math is disabled because the inputs contain NaNs.
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  const XlaOp left = ConstantR1<float>(&b, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const XlaOp right = ConstantR1<float>(&b, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Gt(left, right);

  ComputeAndCompareR1<bool>(&b, {false, true, true, false, false}, {});
}
1180
// Le on F32 vectors, including an equal pair (5.0 vs 5.0). Lanes with a NaN
// operand are expected to be false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
  // Fast-math is disabled because the inputs contain NaNs.
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  const XlaOp left = ConstantR1<float>(&b, {-2.5f, 5.0f, 2.25f, NAN, 6.0f});
  const XlaOp right = ConstantR1<float>(&b, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Le(left, right);

  ComputeAndCompareR1<bool>(&b, {true, true, false, false, false}, {});
}
1190
// Lt on F32 vectors. Lanes with a NaN operand are expected to be false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
  // Fast-math is disabled because the inputs contain NaNs.
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  const XlaOp left = ConstantR1<float>(&b, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const XlaOp right = ConstantR1<float>(&b, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Lt(left, right);

  ComputeAndCompareR1<bool>(&b, {true, false, false, false, false}, {});
}
1200
// Eq on S32 vectors pairing each of {min, 0, max} with a smaller, equal and
// larger value.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<int32>(&b, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&b, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Eq(left, right);

  ComputeAndCompareR1<bool>(
      &b, {true, false, false, false, true, false, false, false, true}, {});
}
1214
// Eq of two empty S32 vectors must produce an empty PRED result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
  XlaBuilder b(TestName());
  const XlaOp left = ConstantR1<int32>(&b, {});
  const XlaOp right = ConstantR1<int32>(&b, {});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&b, {}, {});
}
1223
// Eq on C64 vectors: only the lane whose real and imaginary parts both match
// is expected to be true; lanes containing a NaN component are false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
  // Fast-math is disabled because the inputs contain NaNs.
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  const XlaOp left = ConstantR1<complex64>(&b, {{-2.5f, 10.0f},
                                                {1.0f, 25.5f},
                                                {2.25f, -3.0f},
                                                {NAN, 0.0f},
                                                {1.0f, 6.0f}});
  const XlaOp right = ConstantR1<complex64>(&b, {{0.0f, 10.0f},
                                                 {1.0f, 5.0f},
                                                 {2.25f, -3.0f},
                                                 {10.0f, 0.0f},
                                                 {1.0f, NAN}});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&b, {false, false, true, false, false}, {});
}
1241
// Eq of two empty C64 vectors must produce an empty PRED result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
  XlaBuilder b(TestName());
  const XlaOp left = ConstantR1<complex64>(&b, {});
  const XlaOp right = ConstantR1<complex64>(&b, {});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&b, {}, {});
}
1250
// Ne on C64 vectors: only the lane whose real and imaginary parts both match
// is expected to be false; lanes containing a NaN component are true.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
  // NaNs are part of the inputs, so fast-math has to be turned off.
  SetFastMathDisabled(true);

  XlaBuilder b(TestName());
  const XlaOp left = ConstantR1<complex64>(&b, {{-2.5f, 10.0f},
                                                {1.0f, 25.5f},
                                                {2.25f, -3.0f},
                                                {NAN, 0.0f},
                                                {1.0f, 6.0f}});
  const XlaOp right = ConstantR1<complex64>(&b, {{0.0f, 10.0f},
                                                 {1.0f, 5.0f},
                                                 {2.25f, -3.0f},
                                                 {10.0f, 0.0f},
                                                 {1.0f, NAN}});
  Ne(left, right);

  ComputeAndCompareR1<bool>(&b, {true, true, false, true, true}, {});
}
1270
// Ne on F32 vectors. A NaN on either side is expected to compare unequal.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
  // NaNs are part of the inputs, so fast-math has to be turned off.
  SetFastMathDisabled(true);

  XlaBuilder b(TestName());
  const XlaOp left = ConstantR1<float>(&b, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const XlaOp right = ConstantR1<float>(&b, {10.0f, 25.5f, 1.0f, 10.0f, NAN});
  Ne(left, right);

  ComputeAndCompareR1<bool>(&b, {true, false, true, true, true}, {});
}
1282
// Ne on S32 vectors pairing each of {min, 0, max} with a smaller, equal and
// larger value.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<int32>(&b, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&b, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Ne(left, right);

  ComputeAndCompareR1<bool>(
      &b, {false, true, true, true, false, true, true, true, false}, {});
}
1295
// Ge on S32 vectors pairing each of {min, 0, max} with a smaller, equal and
// larger value.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<int32>(&b, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&b, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Ge(left, right);

  ComputeAndCompareR1<bool>(
      &b, {true, false, false, true, true, false, true, true, true}, {});
}
1308
// Gt on S32 vectors pairing each of {min, 0, max} with a smaller, equal and
// larger value.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<int32>(&b, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&b, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Gt(left, right);

  ComputeAndCompareR1<bool>(
      &b, {false, false, false, true, false, false, true, true, false}, {});
}
1322
// Le on S32 vectors pairing each of {min, 0, max} with a smaller, equal and
// larger value.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<int32>(&b, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&b, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Le(left, right);

  ComputeAndCompareR1<bool>(
      &b, {true, true, true, false, true, true, false, false, true}, {});
}
1335
// Lt on S32 vectors pairing each of {min, 0, max} with a smaller, equal and
// larger value.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<int32>(&b, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&b, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Lt(left, right);

  ComputeAndCompareR1<bool>(
      &b, {false, true, true, false, false, true, false, false, false}, {});
}
1349
// Eq on U32 vectors mixing 0, a mid-range value and the maximum U32.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&b, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right = ConstantR1<uint32>(&b, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Eq(left, right);

  ComputeAndCompareR1<bool>(
      &b, {true, false, false, false, true, false, false, false, true}, {});
}
1361
// Ne on U32 vectors mixing 0, a mid-range value and the maximum U32.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&b, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right = ConstantR1<uint32>(&b, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Ne(left, right);

  ComputeAndCompareR1<bool>(
      &b, {false, true, true, true, false, true, true, true, false}, {});
}
1372
// Ge on U32 vectors. The maximum U32 is expected to compare as the largest
// value, i.e. the comparison is unsigned.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&b, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right = ConstantR1<uint32>(&b, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Ge(left, right);

  ComputeAndCompareR1<bool>(
      &b, {true, false, false, true, true, false, true, true, true}, {});
}
1383
// Gt on U32 vectors. The maximum U32 is expected to compare as the largest
// value, i.e. the comparison is unsigned.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&b, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right = ConstantR1<uint32>(&b, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Gt(left, right);

  ComputeAndCompareR1<bool>(
      &b, {false, false, false, true, false, false, true, true, false}, {});
}
1395
// Le on U32 vectors. The maximum U32 is expected to compare as the largest
// value, i.e. the comparison is unsigned.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&b, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right = ConstantR1<uint32>(&b, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Le(left, right);

  ComputeAndCompareR1<bool>(
      &b, {true, true, true, false, true, true, false, false, true}, {});
}
1406
// Lt on U32 vectors. The maximum U32 is expected to compare as the largest
// value, i.e. the comparison is unsigned.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder b(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&b, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right = ConstantR1<uint32>(&b, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Lt(left, right);

  ComputeAndCompareR1<bool>(
      &b, {false, true, true, false, false, true, false, false, false}, {});
}
1418
// Pow on F32 vectors with integer-valued exponents, negative bases and NaN
// on either side; a NaN operand is expected to give NaN.
XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
  // Fast-math is disabled because the inputs contain NaNs.
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  const XlaOp base =
      ConstantR1<float>(&b, {4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f});
  const XlaOp exponent =
      ConstantR1<float>(&b, {2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f});
  Pow(base, exponent);

  const std::vector<float> want = {16.0f, 0.25f, 8.0f, NAN,
                                   NAN,   -8.0f, 16.0f};
  ComputeAndCompareR1<float>(&b, want, {}, error_spec_);
}
1431
// Pow with non-integer exponents: a negative base is expected to give NaN,
// and 0 raised to a negative exponent is expected to give +infinity.
XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  const XlaOp base = ConstantR1<float>(&b, {-2.0f, -0.6f, -0.6f, 0.0f});
  const XlaOp exponent = ConstantR1<float>(&b, {0.5f, 0.6f, -0.6f, -0.6f});
  Pow(base, exponent);

  ComputeAndCompareR1<float>(&b, {NAN, NAN, NAN, INFINITY}, {}, error_spec_);
}
1442
// Pow on C64 vectors built from real-valued inputs. Negative bases with
// fractional exponents are expected to give results with a non-zero
// imaginary part; 0^0 is expected to be 1.
XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) {
  SetFastMathDisabled(true);
  XlaBuilder b(TestName());
  const XlaOp base =
      ConstantR1<complex64>(&b, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f});
  const XlaOp exponent =
      ConstantR1<complex64>(&b, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f});
  Pow(base, exponent);

  ComputeAndCompareR1<complex64>(&b,
                                 {
                                     {0, 1.41421356},
                                     {-2.27443288e-01, 0.69999846},
                                     {-4.19847531e-01, -1.29215783},
                                     {0, 0},
                                     {0, 0},
                                     {1, 0},
                                 },
                                 {}, error_spec_);
}
1463
// Pow of two empty F32 vectors must produce an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
  XlaBuilder b(TestName());
  const XlaOp base = ConstantR1<float>(&b, {});
  const XlaOp exponent = ConstantR1<float>(&b, {});
  Pow(base, exponent);

  ComputeAndCompareR1<float>(&b, {}, {}, error_spec_);
}
1472
1473 // Some Pow cases that can be implemented more efficiently.
XLA_TEST_F(ArrayElementwiseOpTest,PowSpecialF32)1474 XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
1475 XlaBuilder b(TestName());
1476
1477 std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
1478 std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
1479
1480 Literal param_literal = LiteralUtil::CreateR1<float>(values);
1481 std::unique_ptr<GlobalData> param_data =
1482 client_->TransferToServer(param_literal).ConsumeValueOrDie();
1483
1484 auto sum = ConstantR0<float>(&b, 0.0f);
1485 auto param = Parameter(&b, 0, param_literal.shape(), "param");
1486 for (float exponent : exponents) {
1487 sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
1488 }
1489
1490 std::vector<float> expected;
1491 for (auto value : values) {
1492 float sum = 0.0f;
1493 for (float exponent : exponents) {
1494 sum += std::pow(value, exponent);
1495 }
1496 expected.push_back(sum);
1497 }
1498
1499 ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
1500 }
1501
// Computes Pow(Exp(param0), param1) on two runtime F32 parameters and checks
// it against the same expression evaluated on the host, within error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  // Upload both operands so they reach the computation as parameters.
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Pow(Exp(param0), param1);

  // Host-side reference. The index is size_t so it is not a signed value
  // compared against the unsigned size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::pow(std::exp(values0[i]), values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1526
// Computes Log(Pow(param0, param1)) on two runtime F32 parameters and checks
// it against the same expression evaluated on the host, within error_spec_.
// All bases in values0 are positive.
XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  // Upload both operands so they reach the computation as parameters.
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Log(Pow(param0, param1));

  // Host-side reference. The index is size_t so it is not a signed value
  // compared against the unsigned size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::log(std::pow(values0[i], values1[i]));
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1551
// Computes Mul(Exp(param0), Exp(param1)) on two runtime F32 parameters and
// checks it against the same expression evaluated on the host, within
// error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  // Upload both operands so they reach the computation as parameters.
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Mul(Exp(param0), Exp(param1));

  // Host-side reference. The index is size_t so it is not a signed value
  // compared against the unsigned size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::exp(values0[i]) * std::exp(values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1576
// Computes Div(param0, Exp(param1)) on two runtime F32 parameters and checks
// it against the same expression evaluated on the host, within error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  // Upload both operands so they reach the computation as parameters.
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Div(param0, Exp(param1));

  // Host-side reference. The index is size_t so it is not a signed value
  // compared against the unsigned size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = values0[i] / std::exp(values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1601
// Computes Div(Div(param0, param1), param2), i.e. a division whose left
// operand is itself a division, on three runtime F32 parameters, and checks
// it against the same expression evaluated on the host, within error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};

  // Upload all operands so they reach the computation as parameters.
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  Div(Div(param0, param1), param2);

  // Host-side reference. The index is size_t so it is not a signed value
  // compared against the unsigned size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = (values0[i] / values1[i]) / values2[i];
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
}
1633
// Computes Div(param0, Div(param1, param2)), i.e. a division whose right
// operand is itself a division, on three runtime F32 parameters, and checks
// it against the same expression evaluated on the host, within error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};

  // Upload all operands so they reach the computation as parameters.
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).ConsumeValueOrDie();

  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  Div(param0, Div(param1, param2));

  // Host-side reference. The index is size_t so it is not a signed value
  // compared against the unsigned size().
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = values0[i] / (values1[i] / values2[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
}
1666
// Checks p0 / pow(p1, p2) on three F32 parameters against the same expression
// evaluated on the host with std::pow.
XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
  XlaBuilder builder(TestName());

  const std::vector<float> dividend = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> base = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
  const std::vector<float> exponent = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};

  // Transfer each operand so it can be bound to a computation parameter.
  Literal dividend_literal = LiteralUtil::CreateR1<float>(dividend);
  std::unique_ptr<GlobalData> dividend_data =
      client_->TransferToServer(dividend_literal).ConsumeValueOrDie();

  Literal base_literal = LiteralUtil::CreateR1<float>(base);
  std::unique_ptr<GlobalData> base_data =
      client_->TransferToServer(base_literal).ConsumeValueOrDie();

  Literal exponent_literal = LiteralUtil::CreateR1<float>(exponent);
  std::unique_ptr<GlobalData> exponent_data =
      client_->TransferToServer(exponent_literal).ConsumeValueOrDie();

  auto p0 = Parameter(&builder, 0, dividend_literal.shape(), "param0");
  auto p1 = Parameter(&builder, 1, base_literal.shape(), "param1");
  auto p2 = Parameter(&builder, 2, exponent_literal.shape(), "param2");
  Div(p0, Pow(p1, p2));

  // Host-side reference. The index is size_t so that it is not compared
  // against the unsigned container size as a signed value.
  std::vector<float> expected(dividend.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    expected[i] = dividend[i] / std::pow(base[i], exponent[i]);
  }

  ComputeAndCompareR1<float>(
      &builder, expected,
      {dividend_data.get(), base_data.get(), exponent_data.get()},
      error_spec_);
}
1699
// Checks a division tree over four F32 parameters, (p0 / p1) / (p2 / p3),
// against the same expression evaluated on the host.
XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
  XlaBuilder builder(TestName());

  const std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  const std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
  const std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};

  // Transfer each operand so it can be bound to a computation parameter.
  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).ConsumeValueOrDie();

  Literal literal3 = LiteralUtil::CreateR1<float>(values3);
  std::unique_ptr<GlobalData> data3 =
      client_->TransferToServer(literal3).ConsumeValueOrDie();

  auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&builder, 2, literal2.shape(), "param2");
  // The fourth parameter gets its own label; it was previously a second
  // "param2".
  auto param3 = Parameter(&builder, 3, literal3.shape(), "param3");
  Div(Div(param0, param1), Div(param2, param3));

  // Host-side reference. The index is size_t so that it is not compared
  // against the unsigned container size as a signed value.
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]);
  }

  ComputeAndCompareR1<float>(
      &builder, expected, {data0.get(), data1.get(), data2.get(), data3.get()},
      error_spec_);
}
1739
// Squares `GetParam()` evenly spaced values i / count via Pow(x, 2) and
// compares against value * value computed on the host.
TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());

  std::vector<float> inputs;
  std::vector<float> squares;
  inputs.reserve(count);
  squares.reserve(count);
  for (int i = 0; i < count; ++i) {
    const float value = i / static_cast<float>(count);
    inputs.push_back(value);
    squares.push_back(value * value);
  }

  auto operand = ConstantR1<float>(&builder, inputs);
  auto two = ConstantR0<float>(&builder, 2.0f);
  Pow(operand, two);

  ComputeAndCompareR1<float>(&builder, squares, {}, error_spec_);
}
1759
// Squares a 2x2x2x2 F32 array via Pow(x, 2); element i holds
// i / num_elements and the expected value is its product with itself.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
  XlaBuilder builder(TestName());
  Array4D<float> operand(2, 2, 2, 2);

  std::vector<float> inputs;
  std::vector<float> squares;
  for (int i = 0; i < operand.num_elements(); ++i) {
    const float value = static_cast<float>(i) / operand.num_elements();
    inputs.push_back(value);
    squares.push_back(value * value);
  }
  operand.SetValues(inputs);

  Array4D<float> expected(2, 2, 2, 2, squares);

  auto x = ConstantR4FromArray4D<float>(&builder, operand);
  auto two = ConstantR0<float>(&builder, 2.0f);
  Pow(x, two);

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
1779
// Pow(x, 2) on a 2x2x0x2 F32 array; the expected result has the same
// zero-element shape.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
  XlaBuilder builder(TestName());
  Array4D<float> operand(2, 2, 0, 2);

  auto x = ConstantR4FromArray4D<float>(&builder, operand);
  auto two = ConstantR0<float>(&builder, 2.0f);
  Pow(x, two);

  Array4D<float> expected(2, 2, 0, 2);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
1790
// Elementwise Min of two F32 vectors with fast-math disabled. The expected
// values have NaN wherever either operand is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  auto a = ConstantR1<float>(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f});
  auto b = ConstantR1<float>(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN});
  Min(a, b);

  const std::vector<float> expected = {1.0f, -5.0f, 1.0f, NAN, NAN};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1801
// Min of two empty F32 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
  XlaBuilder builder(TestName());

  auto empty_lhs = ConstantR1<float>(&builder, {});
  auto empty_rhs = ConstantR1<float>(&builder, {});
  Min(empty_lhs, empty_rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
1809
// Elementwise Min of two F64 vectors with fast-math disabled. The expected
// values have NaN wherever either operand is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  auto a = ConstantR1<double>(&builder, {1.0, 1.0, 2.25, NAN, 6.0});
  auto b = ConstantR1<double>(&builder, {2.0, -5.0, 1.0, 10.0, NAN});
  Min(a, b);

  const std::vector<double> expected = {1.0, -5.0, 1.0, NAN, NAN};
  ComputeAndCompareR1<double>(&builder, expected, {}, error_spec_);
}
1820
// Elementwise Max of two F32 vectors with fast-math disabled. The expected
// values have NaN wherever either operand is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  auto a = ConstantR1<float>(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f});
  auto b = ConstantR1<float>(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN});
  Max(a, b);

  const std::vector<float> expected = {2.0f, 1.0f, 2.25f, NAN, NAN};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1831
// Max of two empty F32 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
  XlaBuilder builder(TestName());

  auto empty_lhs = ConstantR1<float>(&builder, {});
  auto empty_rhs = ConstantR1<float>(&builder, {});
  Max(empty_lhs, empty_rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
1839
// Elementwise Max of two F64 vectors with fast-math disabled. The expected
// values have NaN wherever either operand is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  auto a = ConstantR1<double>(&builder, {1.0, 1.0, 2.25, NAN, 6.0});
  auto b = ConstantR1<double>(&builder, {2.0, -5.0, 1.0, 10.0, NAN});
  Max(a, b);

  const std::vector<double> expected = {2.0, 1.0, 2.25, NAN, NAN};
  ComputeAndCompareR1<double>(&builder, expected, {}, error_spec_);
}
1850
// Elementwise Max on S32 vectors, including the extreme int32 values.
XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
  const int32 lowest = std::numeric_limits<int32>::min();
  const int32 highest = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());

  auto lhs = ConstantR1<int32>(&builder,
                               {lowest, lowest, lowest, -1, -1, 0, 0, 0, 1, 1,
                                highest, highest, highest});
  auto rhs = ConstantR1<int32>(
      &builder,
      {lowest, highest, 0, -10, 0, -1, 0, 1, 0, 10, 0, highest, lowest});
  Max(lhs, rhs);

  const std::vector<int32> expected = {lowest, highest, 0,       -1,     0,
                                       0,      0,       1,       1,      10,
                                       highest, highest, highest};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
1865
// Elementwise Min on S32 vectors, including the extreme int32 values.
XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
  const int32 lowest = std::numeric_limits<int32>::min();
  const int32 highest = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());

  auto lhs = ConstantR1<int32>(&builder,
                               {lowest, lowest, lowest, -1, -1, 0, 0, 0, 1, 1,
                                highest, highest, highest});
  auto rhs = ConstantR1<int32>(
      &builder,
      {lowest, highest, 0, -10, 0, -1, 0, 1, 0, 10, 0, highest, lowest});
  Min(lhs, rhs);

  const std::vector<int32> expected = {lowest, lowest, lowest, -10,    -1,
                                       -1,     0,      0,      0,      1,
                                       0,      highest, lowest};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
1880
// Elementwise Max on U32 vectors, including the largest uint32 value.
XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
  const uint32 highest = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());

  auto lhs =
      ConstantR1<uint32>(&builder, {0, 0, 1, 1, 1, highest, highest, highest});
  auto rhs = ConstantR1<uint32>(&builder, {0, 1, 0, 1, 10, 0, 234234, highest});
  Max(lhs, rhs);

  const std::vector<uint32> expected = {0,  1,       1,       1,
                                        10, highest, highest, highest};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
1891
// Elementwise Min on U32 vectors, including the largest uint32 value.
XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
  const uint32 highest = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());

  auto lhs =
      ConstantR1<uint32>(&builder, {0, 0, 1, 1, 1, highest, highest, highest});
  auto rhs = ConstantR1<uint32>(&builder, {0, 1, 0, 1, 10, 0, 234234, highest});
  Min(lhs, rhs);

  const std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, highest};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
1902
// Elementwise Max of two ten-element F32 vectors whose entries are negations
// of each other; the expected result is the magnitude at each position.
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
  XlaBuilder builder(TestName());

  auto lhs = ConstantR1<float>(
      &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
  auto rhs = ConstantR1<float>(
      &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
  Max(lhs, rhs);

  const std::vector<float> magnitudes = {-0.0, 1.0, 2.0, 3.0, 4.0,
                                         5.0,  6.0, 7.0, 8.0, 9.0};
  ComputeAndCompareR1<float>(&builder, magnitudes, {}, error_spec_);
}
1915
// Max of a one-element F32 vector and an empty F32 vector; the expected
// result is an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
  XlaBuilder builder(TestName());

  auto single = ConstantR1<float>(&builder, {3.5});
  auto empty = ConstantR1<float>(&builder, {});
  Max(single, empty);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
1924
// Max of a one-element F32 vector and a 0x2 F32 matrix, once for each choice
// of broadcast dimension; the expected result is a 0x2 matrix in both cases.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
  for (int dim : {0, 1}) {
    XlaBuilder builder(TestName());

    auto vec = ConstantR1<float>(&builder, {3.5});
    auto mat = ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 2));
    Max(vec, mat, /*broadcast_dimensions=*/{dim});

    const Array2D<float> expected(0, 2);
    ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
  }
}
1935
// Max of a 3-element F32 vector and a 2x3 F32 matrix, with the vector mapped
// onto dimension 1 of the matrix.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
  XlaBuilder builder(TestName());

  auto row = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  auto matrix = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  Max(row, matrix, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1946
// Max of an empty F32 vector and a matrix built from two empty rows, with
// the vector mapped onto dimension 1.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
  XlaBuilder builder(TestName());

  auto row = ConstantR1<float>(&builder, {});
  auto matrix = ConstantR2<float>(&builder, {{}, {}});
  Max(row, matrix, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1956
// Max of a 2x2x3 S32 array and the scalar 2; every element below 2 is
// expected to be raised to 2.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
  XlaBuilder builder(TestName());

  auto floor_value = ConstantR0<int32>(&builder, 2);
  const Array3D<int32> input(
      {{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
  auto operand = ConstantR3FromArray3D<int32>(&builder, input);
  Max(operand, floor_value, /*broadcast_dimensions=*/{});

  const Array3D<int32> expected(
      {{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
  ComputeAndCompareR3<int32>(&builder, expected, {});
}
1967
// Max of a 2x0x3 S32 array and a scalar; the expected result keeps the
// zero-element 2x0x3 shape.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
  XlaBuilder builder(TestName());

  auto floor_value = ConstantR0<int32>(&builder, 2);
  const Array3D<int32> input(2, 0, 3);
  auto operand = ConstantR3FromArray3D<int32>(&builder, input);
  Max(operand, floor_value, /*broadcast_dimensions=*/{});

  const Array3D<int32> expected(2, 0, 3);
  ComputeAndCompareR3<int32>(&builder, expected, {});
}
1978
// Min of a 2x3 F32 matrix and a 2-element F32 vector, with the vector mapped
// onto dimension 0 of the matrix.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
  XlaBuilder builder(TestName());

  auto matrix = ConstantR2<float>(
      &builder, {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
  auto column = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(matrix, column, /*broadcast_dimensions=*/{0});

  const Array2D<float> expected(
      {{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1989
// Min of a matrix built from two empty rows and a 2-element F32 vector
// mapped onto dimension 0.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
  XlaBuilder builder(TestName());

  auto matrix = ConstantR2<float>(&builder, {{}, {}});
  auto column = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(matrix, column, /*broadcast_dimensions=*/{0});

  const Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1999
// Min of a 2x3 F32 matrix and a 2x2x1x3 F32 array, with the matrix mapped
// onto dimensions 1 and 3 of the array.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
  XlaBuilder builder(TestName());

  auto matrix =
      ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  const Array4D<float> input(
      {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
       {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
  auto tensor = ConstantR4FromArray4D<float>(&builder, input);
  Min(matrix, tensor, /*broadcast_dimensions=*/{1, 3});

  const Array4D<float> expected(
      {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
       {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
2014
// Min of a 2x3 F32 matrix and a 2x2x0x3 F32 array, with the matrix mapped
// onto dimensions 1 and 3; the expected result keeps the 2x2x0x3 shape.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
  XlaBuilder builder(TestName());

  auto matrix =
      ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  const Array4D<float> input(2, 2, 0, 3);
  auto tensor = ConstantR4FromArray4D<float>(&builder, input);
  Min(matrix, tensor, /*broadcast_dimensions=*/{1, 3});

  const Array4D<float> expected(2, 2, 0, 3);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
2026
// Elementwise Min of an ascending and a descending ten-element S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
  XlaBuilder builder(TestName());

  auto ascending = ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  auto descending =
      ConstantR1<int32>(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
  Min(ascending, descending);

  const std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2036
// Elementwise Max of an ascending and a descending ten-element S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
  XlaBuilder builder(TestName());

  auto ascending = ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  auto descending =
      ConstantR1<int32>(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
  Max(ascending, descending);

  const std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2046
// Elementwise Rem on S32 constants with mixed signs; each expected remainder
// carries the sign of the dividend.
XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
  XlaBuilder builder(TestName());

  auto dividend = ConstantR1<int32>(&builder, {-3, 26, 2, -1, 1});
  auto divisor = ConstantR1<int32>(&builder, {10, 5, 1, 10, -10});
  Rem(dividend, divisor);

  const std::vector<int32> expected = {-3, 1, 0, -1, 1};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2055
// Elementwise Clamp(min, x, max) on F32 vectors that contain no NaNs.
XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
  XlaBuilder builder(TestName());

  auto lower = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto operand = ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
  auto upper = ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0});
  Clamp(lower, operand, upper);

  const std::vector<float> expected = {2.0f, 0.5f, 1.0f, 2.25f, 10.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2067
// Elementwise Clamp(min, x, max) on F32 vectors with fast-math disabled. The
// expected values are NaN at the positions where a bound is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());

  auto lower = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, NAN});
  auto operand = ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
  auto upper = ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, NAN, 123.0f});
  Clamp(lower, operand, upper);

  const std::vector<float> expected = {2.0f, 0.5f, 1.0f, NAN, NAN};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2080
// Clamp of an F32 vector between scalar bounds 0 and 5.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
  XlaBuilder builder(TestName());

  auto lower = ConstantR0<float>(&builder, 0.0f);
  auto operand = ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto upper = ConstantR0<float>(&builder, 5.0f);
  Clamp(lower, operand, upper);

  const std::vector<float> expected = {2.0f, 5.0f, 0.0f, 1.0f, 4.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2091
// Sums four F32 clamps of the same operand, one for every scalar/vector
// combination of the lower and upper bounds.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
  XlaBuilder builder(TestName());

  auto lower_scalar = ConstantR0<float>(&builder, 0.0f);
  auto lower_vector =
      ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto operand = ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto upper_scalar = ConstantR0<float>(&builder, 3.0f);
  auto upper_vector =
      ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0});

  // Each clamp mixes broadcast scalar bounds with vector bounds differently.
  auto vector_scalar = Clamp(lower_vector, operand, upper_scalar);
  auto scalar_vector = Clamp(lower_scalar, operand, upper_vector);
  auto vector_vector = Clamp(lower_vector, operand, upper_vector);
  auto scalar_scalar = Clamp(lower_scalar, operand, upper_scalar);
  auto first_pair = Add(vector_scalar, scalar_vector);
  auto second_pair = Add(vector_vector, scalar_scalar);
  Add(first_pair, second_pair);

  const std::vector<float> expected = {8.0f, 7.0f, 2.0f, 6.5f, 14.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2111
// Elementwise Clamp(min, x, max) on S32 vectors. In the last element the
// upper bound (-1) lies below the operand and the expected value is -1.
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
  XlaBuilder builder(TestName());

  auto lower = ConstantR1<int32>(&builder, {1, -6, 1, 2, 0, -5});
  auto operand = ConstantR1<int32>(&builder, {2, 10, -5, 1, 4, 10});
  auto upper = ConstantR1<int32>(&builder, {3, 0, 25, 5, 123, -1});
  Clamp(lower, operand, upper);

  const std::vector<int32> expected = {2, 0, 1, 2, 4, -1};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2121
// Sums four S32 clamps of the same operand, one for every scalar/vector
// combination of the lower and upper bounds.
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) {
  XlaBuilder builder(TestName());

  auto lower_scalar = ConstantR0<int32>(&builder, 0);
  auto lower_vector = ConstantR1<int32>(&builder, {1, -6, 1, 2, 0});
  auto operand = ConstantR1<int32>(&builder, {2, 10, -5, 1, 4});
  auto upper_scalar = ConstantR0<int32>(&builder, 3);
  auto upper_vector = ConstantR1<int32>(&builder, {3, 1, 25, 5, 123});

  // Each clamp mixes broadcast scalar bounds with vector bounds differently.
  auto vector_scalar = Clamp(lower_vector, operand, upper_scalar);
  auto scalar_vector = Clamp(lower_scalar, operand, upper_vector);
  auto vector_vector = Clamp(lower_vector, operand, upper_vector);
  auto scalar_scalar = Clamp(lower_scalar, operand, upper_scalar);
  auto first_pair = Add(vector_scalar, scalar_vector);
  auto second_pair = Add(vector_vector, scalar_scalar);
  Add(first_pair, second_pair);

  const std::vector<int32> expected = {8, 8, 2, 6, 14};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2137
// Elementwise Clamp(min, x, max) on U32 vectors, including bounds near the
// top of the uint32 range.
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
  XlaBuilder builder(TestName());

  auto lower = ConstantR1<uint32>(&builder, {1, 2, 1, 2, 0, ~0u - 4});
  auto operand = ConstantR1<uint32>(&builder, {2, 10, 5, 1, 4, 10});
  auto upper = ConstantR1<uint32>(&builder, {3, 5, 25, 5, 123, ~0u});
  Clamp(lower, operand, upper);

  const std::vector<uint32> expected = {2, 5, 5, 2, 4, ~0u - 4};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
2147
// Sums four U32 clamps of the same operand, one for every scalar/vector
// combination of the lower and upper bounds.
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
  XlaBuilder builder(TestName());

  auto lower_scalar = ConstantR0<uint32>(&builder, 0);
  auto lower_vector = ConstantR1<uint32>(&builder, {1, 0, 1, 2, 0});
  auto operand = ConstantR1<uint32>(&builder, {2, 10, 0, 1, 4});
  auto upper_scalar = ConstantR0<uint32>(&builder, 3);
  auto upper_vector = ConstantR1<uint32>(&builder, {3, 1, 25, 5, 123});

  // Each clamp mixes broadcast scalar bounds with vector bounds differently.
  auto vector_scalar = Clamp(lower_vector, operand, upper_scalar);
  auto scalar_vector = Clamp(lower_scalar, operand, upper_vector);
  auto vector_vector = Clamp(lower_vector, operand, upper_vector);
  auto scalar_scalar = Clamp(lower_scalar, operand, upper_scalar);
  auto first_pair = Add(vector_scalar, scalar_vector);
  auto second_pair = Add(vector_vector, scalar_scalar);
  Add(first_pair, second_pair);

  const std::vector<uint32> expected = {8, 8, 2, 6, 14};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
2163
// Adds two four-element F32 vectors that are both supplied as computation
// parameters.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
  XlaBuilder builder(TestName());

  // Transfer both operands so they can be bound to parameters 0 and 1.
  Literal lhs_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  Literal rhs_literal = LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  const std::vector<float> expected = {8.3f, 4.5f, 6.7f, 11.1f};
  ComputeAndCompareR1<float>(&builder, expected,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
2185
// Adds two 0x7x0 F32 arrays that are both supplied as computation
// parameters; the expected result is also 0x7x0.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
  XlaBuilder builder(TestName());

  // Transfer both operands so they can be bound to parameters 0 and 1.
  Literal lhs_literal =
      LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  Literal rhs_literal =
      LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  const Array3D<float> expected(0, 7, 0);
  ComputeAndCompareR3<float>(&builder, expected,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
2207
// Adds an F32 constant vector to an F32 vector supplied as a computation
// parameter.
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
  XlaBuilder builder(TestName());

  // Transfer the runtime operand so it can be bound to parameter 0.
  Literal operand_literal =
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> operand_data =
      client_->TransferToServer(operand_literal).ConsumeValueOrDie();

  auto constant = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto operand = Parameter(&builder, 0, operand_literal.shape(), "param0");
  Add(constant, operand);

  const std::vector<float> expected = {2.2f, 4.4f, 6.6f, 9.9f};
  ComputeAndCompareR1<float>(&builder, expected, {operand_data.get()},
                             error_spec_);
}
2223
// Elementwise Cos on a small F32 vector, compared against hard-coded
// expected values within error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
  XlaBuilder builder(TestName());

  auto angles =
      ConstantR1<float>(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f});
  Cos(angles);

  const std::vector<float> expected = {-1.0f, 1.0f, 0.0f, 0.707107f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2232
// Elementwise Sin on a small F32 vector, compared against hard-coded
// expected values within error_spec_.
XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
  XlaBuilder builder(TestName());

  auto angles =
      ConstantR1<float>(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f});
  Sin(angles);

  const std::vector<float> expected = {0.0f, 0.0f, 1.0f, -0.707107f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2241
// Elementwise Atan2(y, x) on F32 vectors; the inputs include points on each
// half-axis and on both diagonals with positive x.
XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
  XlaBuilder builder(TestName());

  auto y = ConstantR1<float>(&builder, {0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f});
  auto x = ConstantR1<float>(&builder, {6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f});
  Atan2(y, x);

  const std::vector<float> expected = {0.0f,         1.57079633f,
                                       3.14159265f,  -1.57079633f,
                                       0.78539816f,  -0.78539816f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2253
// Elementwise Tanh on a three-element F32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
  XlaBuilder builder(TestName());

  auto operand = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f});
  Tanh(operand);

  const std::vector<float> expected = {-0.986614f, 0.996260f, 0.978026};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2262
// Tanh on a 64-element F32 parameter. Same idea as TanhF32s, but the input is
// large enough to exercise the vectorized tanh implementation on XLA CPU.
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
  XlaBuilder builder(TestName());

  Literal input_literal = LiteralUtil::CreateR1<float>(
      {1.02,  -0.32, 0.85,  0.90,  1.23,  -0.91, -0.49, 0.80,  -0.67, 0.16,
       -0.07, 0.39,  -0.41, 0.04,  1.36,  1.25,  0.41,  0.65,  -1.08, 0.32,
       -1.45, -0.77, -1.09, 0.91,  -1.03, -0.30, -1.11, -1.17, 1.50,  -0.85,
       0.04,  1.02,  0.34,  -0.61, 0.41,  0.07,  -0.02, 1.42,  -0.62, 0.81,
       0.08,  0.81,  -0.30, 1.17,  -0.65, -0.44, 0.92,  1.26,  -1.29, 1.35,
       0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
       -0.79, 1.41,  1.21,  1.05});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(input_literal));

  auto operand = Parameter(&builder, 0, input_literal.shape(), "input");
  Tanh(operand);

  const std::vector<float> expected = {
      0.77009583,  -0.30665702, 0.69070244,  0.71401149,  0.84400684,
      -0.71985596, -0.45764771, 0.66664988,  -0.58278900, 0.16050975,
      -0.06770509, 0.36843640,  -0.38476998, 0.04018109,  0.87562293,
      0.84788644,  0.38603750,  0.57294142,  -0.79140943, 0.31032649,
      -0.89590985, -0.64770776, -0.79625875, 0.72234446,  -0.77389336,
      -0.28871772, -0.80428445, -0.82541436, 0.90456349,  -0.68856895,
      0.03877772,  0.76877952,  0.32561871,  -0.54546672, 0.39072621,
      0.07273290,  -0.01924866, 0.88924897,  -0.55283129, 0.67183107,
      0.08006320,  0.66944766,  -0.29068485, 0.82573754,  -0.57170743,
      -0.41581789, 0.72739530,  0.85025692,  -0.85931867, 0.87357593,
      0.07782833,  -0.84597743, -0.72748238, 0.45396307,  0.82449573,
      -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558,
      -0.65565848, 0.88789743,  0.83566397,  0.78287679};

  // The tolerance is looser than error_spec_ because tanh is approximated
  // with a rational interpolant.
  const ErrorSpec tanh_error_spec(0.004, 0.004);
  ComputeAndCompareR1<float>(&builder, expected, {input_data.get()},
                             tanh_error_spec);
}
2302
// Exp on a large F32 parameter, compared against std::exp of each input. The
// input is large enough to exercise the vectorized exp implementation on XLA
// CPU.
XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
  XlaBuilder builder(TestName());

  // For a sense of scale: exp(89) saturates float32, and exp(-10) is smaller
  // than our error spec.
  Literal input_literal = LiteralUtil::CreateR1<float>(
      {1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49,  0.8,    -1.31,
       -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,   -195.6, -194.5,
       -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6,  -18.5,  -17.4,
       -16.3,  -15.2,  -14.1,  -13.0,  -11.9,  -10.8,  -9.7,   -8.6,   -7.5,
       -6.4,   -5.3,   -4.2,   -3.1,   -2.0,   -0.9,   0.2,    1.3,    2.4,
       3.5,    4.6,    5.7,    6.8,    7.9,    9.0,    10.1,   11.2,   12.3,
       13.4,   14.5,   15.6,   16.7,   17.8,   18.9,   20.0,   21.1,   22.2,
       23.3,   24.4,   25.5,   26.6,   27.7,   28.8,   29.9,   31.0,   32.1,
       68.4,   69.5,   70.6,   71.7,   72.8,   73.9,   75.0,   76.1,   77.2,
       78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,   85.2,   86.3,
       86.4,   86.5,   87.6,   87.7,   87.8,   87.9});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(input_literal));

  auto operand = Parameter(&builder, 0, input_literal.shape(), "input");
  Exp(operand);

  // Host-side reference: std::exp of every element of the input literal.
  const int64 element_count = input_literal.shape().dimensions(0);
  std::vector<float> expected;
  expected.reserve(element_count);
  for (int64 index = 0; index < element_count; ++index) {
    const float value = input_literal.Get<float>({index});
    expected.push_back(std::exp(value));
  }

  ComputeAndCompareR1<float>(&builder, expected, {input_data.get()},
                             error_spec_);
}
2338
// Log on a large F32 parameter, compared against std::log of each input. The
// input is large enough to exercise the vectorized log implementation on XLA
// CPU. The first eight inputs are negative, so std::log is applied to negative
// values when the reference is built.
XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
  XlaBuilder builder(TestName());

  Literal input_literal = LiteralUtil::CreateR1<float>(
      {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
       -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
       198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
       1.74e+04, 1.89e+05, 1.9e+05,  1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
       1.66e+07, 1e+07,    1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
       1.44e+10, 1.5e+10,  1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
       1.4e+12,  1.03e+13, 1.6e+13,  1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
       1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
       2e+18,    1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
       1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21,  1.35e+22, 1.84e+22, 1.02e+22,
       1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
       1.62e+25, 1.2e+26,  1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
       1.5e+28,  1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30,  1.81e+30, 1.34e+30,
       1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
       1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(input_literal));

  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
  Log(input);

  // Host-side reference: std::log of every element of the input literal.
  std::vector<float> expected_result;
  const int64 input_size = input_literal.shape().dimensions(0);
  expected_result.reserve(input_size);
  for (int64 i = 0; i < input_size; i++) {
    expected_result.push_back(std::log(input_literal.Get<float>({i})));
  }

  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
                             error_spec_);
}
2376
XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) {
  // Count-leading-zeros over U32 values, from zero (all 32 bits clear) up to a
  // value with the top bit set (no leading zeros).
  XlaBuilder builder(TestName());

  const std::vector<uint32> inputs = {0,        1,          0x10,      0x10000,
                                      0x700000, 0x12345678, 0xF2345678};
  const std::vector<uint32> leading_zeros = {32, 31, 27, 15, 9, 3, 0};

  Clz(ConstantR1<uint32>(&builder, inputs));

  ComputeAndCompareR1<uint32>(&builder, leading_zeros, {});
}
2385
XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) {
  // Count-leading-zeros over S64 values. The bit pattern is what is counted:
  // -1 (all bits set) has no leading zeros, and 0 has all 64.
  XlaBuilder builder(TestName());

  const std::vector<int64> inputs = {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul,
                                     -1};
  const std::vector<int64> leading_zeros = {64, 63, 32, 1, 0};

  Clz(ConstantR1<int64>(&builder, inputs));

  ComputeAndCompareR1<int64>(&builder, leading_zeros, {});
}
2394
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
  // Left-nested chain of adds, i.e. (a + b) + c:
  //
  //   a ---\
  //         (add) ---\
  //   b ---/          (add)
  //   c -------------/
  XlaBuilder builder(TestName());

  const auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  const auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  const auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  // The inner add is the left operand of the outer add.
  Add(Add(a, b), c);

  ComputeAndCompareR1<float>(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {},
                             error_spec_);
}
2412
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
  // Right-nested chain of adds, i.e. a + (b + c):
  //
  //   b ---\
  //         (add) ---\
  //   c ---/          (add)
  //   a -------------/
  XlaBuilder builder(TestName());

  const auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  const auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  const auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  // The inner add is the right operand of the outer add.
  Add(a, Add(b, c));

  ComputeAndCompareR1<float>(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {},
                             error_spec_);
}
2430
XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
  // Adds two negated operands, i.e. (-a) + (-b):
  //
  //   a --- (neg) ---\
  //                   (add)
  //   b --- (neg) ---/
  XlaBuilder builder(TestName());

  const auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  const auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});

  // Named temporaries keep the two negations in a fixed creation order.
  const auto negated_a = Neg(a);
  const auto negated_b = Neg(b);
  Add(negated_a, negated_b);

  ComputeAndCompareR1<float>(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {},
                             error_spec_);
}
2447
XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
  // Balanced tree of adds, i.e. (a + b) + (c + d):
  //
  //   a ---\
  //         (add) ---\
  //   b ---/          \
  //                    (add)
  //   c ---\          /
  //         (add) ---/
  //   d ---/
  XlaBuilder builder(TestName());

  const auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  const auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  const auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});
  const auto d = ConstantR1<float>(&builder, {-19.0f, 10.0f, -40.0f, 20.2f});

  // Named temporaries keep the two partial sums in a fixed creation order.
  const auto left_sum = Add(a, b);
  const auto right_sum = Add(c, d);
  Add(left_sum, right_sum);

  ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
                             error_spec_);
}
2470
2471 XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
2472 XlaBuilder builder(TestName());
2473 auto a = ConstantR2<float>(&builder,
2474 {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2475 auto b = ConstantR2<float>(&builder,
2476 {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
2477 Add(a, b);
2478
2479 Array2D<float> expected_array(
2480 {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
2481 ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2482 }
2483
XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
  // scalar + matrix: the scalar is the left operand and is added to every
  // element of the 2x3 matrix.
  XlaBuilder builder(TestName());

  const auto matrix = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  const auto three = ConstantR0<float>(&builder, 3.0f);
  Add(three, matrix);

  const Array2D<float> expected_sum(
      {{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2495
2496 XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
2497 // Add a matrix + scalar.
2498 XlaBuilder builder(TestName());
2499 auto a = ConstantR2<float>(&builder,
2500 {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2501 auto scalar = ConstantR0<float>(&builder, 3.0f);
2502 Add(a, scalar);
2503
2504 Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
2505 ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2506 }
2507
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
  // Simple broadcast of an R1 F32 over an R2 F32. The three-element vector can
  // only line up with the matrix dimension of size three, which is selected
  // here with broadcast_dimensions={1}; the vector is added to every row.
  XlaBuilder builder(TestName());

  const auto vec = ConstantR1<float>(&builder, {20.0f, 40.0f, 60.0f});
  const auto matrix = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  Add(vec, matrix, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected_sum(
      {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2523
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
  // Broadcasting in an Eq comparison between a two-element vector and a 2x2
  // matrix. Because both matrix dimensions have size two, the vector can be
  // mapped onto either one; both mappings are computed and returned as a
  // tuple.
  XlaBuilder builder(TestName());
  const auto vec = ConstantR1<int32>(&builder, {42, 73});
  const auto matrix = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});

  // Vector mapped onto matrix dimension 1: compared against every row.
  const auto eq_over_dim1 = Eq(vec, matrix, /*broadcast_dimensions=*/{1});
  // Vector mapped onto matrix dimension 0: element i compared against row i.
  const auto eq_over_dim0 = Eq(vec, matrix, /*broadcast_dimensions=*/{0});
  Tuple(&builder, {eq_over_dim1, eq_over_dim0});

  // Expected tuple elements, in the same order as the Tuple operands.
  const Literal expected_over_dim1 =
      LiteralUtil::CreateR2<bool>({{true, true}, {true, false}});
  const Literal expected_over_dim0 =
      LiteralUtil::CreateR2<bool>({{true, false}, {false, false}});
  auto expected_tuple = LiteralUtil::MakeTupleFromSlices(
      {expected_over_dim1, expected_over_dim0});
  ComputeAndCompareTuple(&builder, expected_tuple, {}, error_spec_);
}
2541
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
  // Broadcasting in a Ne comparison: the two-element vector is mapped onto
  // dimension 1 of the 2x2 matrix, so it is compared against every row.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {42, 73});
  auto m = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});
  Ne(v, m, /*broadcast_dimensions=*/{1});

  // The result is checked as a typed PRED array rather than against its
  // pretty-printed string form, so the test does not depend on how literals
  // are formatted.
  const Array2D<bool> expected({{false, false}, {false, true}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}
2555
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
  // Broadcasting in a Ge comparison: the four-element vector is mapped onto
  // dimension 1 of the 2x4 matrix, so result[r][c] is v[c] >= m[r][c].
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Ge(v, m, /*broadcast_dimensions=*/{1});

  // The result is checked as a typed PRED array rather than against its
  // pretty-printed string form, so the test does not depend on how literals
  // are formatted.
  const Array2D<bool> expected(
      {{true, true, false, false}, {false, false, false, true}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}
2569
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
  // Broadcasting in a Gt comparison: the four-element vector is mapped onto
  // dimension 1 of the 2x4 matrix, so result[r][c] is v[c] > m[r][c].
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Gt(v, m, /*broadcast_dimensions=*/{1});

  // The result is checked as a typed PRED array rather than against its
  // pretty-printed string form, so the test does not depend on how literals
  // are formatted.
  const Array2D<bool> expected(
      {{false, true, false, false}, {false, false, false, false}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}
2583
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
  // Broadcasting in a Le comparison: the four-element vector is mapped onto
  // dimension 1 of the 2x4 matrix, so result[r][c] is v[c] <= m[r][c].
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Le(v, m, /*broadcast_dimensions=*/{1});

  // The result is checked as a typed PRED array rather than against its
  // pretty-printed string form, so the test does not depend on how literals
  // are formatted.
  const Array2D<bool> expected(
      {{true, false, true, true}, {true, true, true, true}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}
2597
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
  // Broadcasting in a Lt comparison: the four-element vector is mapped onto
  // dimension 1 of the 2x4 matrix, so result[r][c] is v[c] < m[r][c].
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Lt(v, m, /*broadcast_dimensions=*/{1});

  // The result is checked as a typed PRED array rather than against its
  // pretty-printed string form, so the test does not depend on how literals
  // are formatted.
  const Array2D<bool> expected(
      {{false, false, true, true}, {true, true, true, false}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}
2611
XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
  // Broadcast of an R1 F32 over an R2 F32 where the matrix is the left operand
  // and the vector the right one. The vector is mapped onto matrix dimension 1,
  // so each row is multiplied element-wise by the vector.
  XlaBuilder builder(TestName());

  const auto matrix =
      ConstantR2<float>(&builder, {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
  const auto vec = ConstantR1<float>(&builder, {2.0f, 4.0f, 6.0f});
  Mul(matrix, vec, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected_product(
      {{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
  ComputeAndCompareR2<float>(&builder, expected_product, {}, error_spec_);
}
2623
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
  // Implicit broadcasting of an operand with a degenerate (size == 1)
  // dimension. As written below, `full` has two rows of three values and
  // `single_row` has one row of three values; the single row is added to each
  // row of `full`, giving a result shaped like `full`.
  // NOTE(review): the dimension index in the test name presumably follows a
  // minor-to-major numbering; confirm before relying on it.
  XlaBuilder builder(TestName());

  const auto full = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  const auto single_row = ConstantR2<float>(&builder, {{10.0f, 20.0f, 30.0f}});
  Add(full, single_row);

  const Array2D<float> expected_sum(
      {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2638
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
  // Implicit broadcasting of an operand with a degenerate (size == 1)
  // dimension. As written below, `full` has two rows of three values and
  // `single_column` has two rows of one value; row i of `full` has
  // single_column's i-th value added to every element.
  // NOTE(review): the dimension index in the test name presumably follows a
  // minor-to-major numbering; confirm before relying on it.
  XlaBuilder builder(TestName());

  const auto full = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  const auto single_column = ConstantR2<float>(&builder, {{10.0f}, {20.0f}});
  Add(full, single_column);

  const Array2D<float> expected_sum(
      {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2653
XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
  // Both operands carry a degenerate (size == 1) dimension, each in a
  // different position, so the add behaves like an "outer product": every
  // value of one operand is combined with every value of the other.
  // The example comes from the NumPy broadcasting documentation:
  // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
  XlaBuilder builder(TestName());

  // As written: four rows of one value each.
  const auto column = ConstantR2<float>(
      &builder, {{0.0f}, {10.0f}, {20.0f}, {30.0f}});
  // As written: one row of three values.
  const auto row = ConstantR2<float>(&builder, {{1.0f, 2.0f, 3.0f}});
  Add(column, row);

  // Four rows of three values: expected[i][j] == column[i] + row[j].
  const Array2D<float> expected_sum({{1.0f, 2.0f, 3.0f},
                                     {11.0f, 12.0f, 13.0f},
                                     {21.0f, 22.0f, 23.0f},
                                     {31.0f, 32.0f, 33.0f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2672
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
  // Adds a two-element vector to a 2x2 matrix. Either matrix dimension could
  // host the vector; this test selects broadcast_dimensions={1}, so the vector
  // is added to each row of the matrix.
  XlaBuilder builder(TestName());

  const auto vec = ConstantR1<float>(&builder, {20.0f, 40.0f});
  const auto matrix =
      ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(vec, matrix, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected_sum({{30.0f, 90.0f}, {97.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2683
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
  // Adds a two-element vector to a 2x2 matrix. Either matrix dimension could
  // host the vector; this test selects broadcast_dimensions={0}, so vector
  // element i is added to every element of matrix row i.
  XlaBuilder builder(TestName());

  const auto vec = ConstantR1<float>(&builder, {20.0f, 40.0f});
  const auto matrix =
      ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(vec, matrix, /*broadcast_dimensions=*/{0});

  const Array2D<float> expected_sum({{30.0f, 70.0f}, {117.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2694
2695 XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
2696 // Binary add of two R3s together
2697 XlaBuilder builder(TestName());
2698 Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
2699 {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
2700 auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
2701
2702 Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
2703 {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
2704 auto b = ConstantR3FromArray3D<float>(&builder, b_3d);
2705 Add(a, b);
2706
2707 Array3D<float> expected_3d(
2708 {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
2709 {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
2710 ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
2711 }
2712
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
  // Adds a two-element vector to a 2x3x2 array. Two of the array's dimensions
  // have size two and could host the vector; this test selects
  // broadcast_dimensions={2}, so the vector is added along the last dimension
  // of the array as written below.
  XlaBuilder builder(TestName());

  const Array3D<float> cube_values(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  const auto cube = ConstantR3FromArray3D<float>(&builder, cube_values);
  const auto vec = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(cube, vec, /*broadcast_dimensions=*/{2});

  // Within each innermost pair, +10 on the first value and +20 on the second.
  const Array3D<float> expected_sum(
      {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
       {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_sum, {}, error_spec_);
}
2736
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
  // Adds a two-element vector to a 2x3x2 array. Two of the array's dimensions
  // have size two and could host the vector; this test selects
  // broadcast_dimensions={0}, so vector element i is added to the whole i-th
  // outermost slice of the array as written below.
  XlaBuilder builder(TestName());

  const Array3D<float> cube_values(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  const auto cube = ConstantR3FromArray3D<float>(&builder, cube_values);
  const auto vec = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(cube, vec, /*broadcast_dimensions=*/{0});

  // +10 everywhere in the first outer slice, +20 everywhere in the second.
  const Array3D<float> expected_sum(
      {{{11.0f, 12.0f}, {13.0f, 14.0f}, {15.0f, 16.0f}},
       {{27.0f, 28.0f}, {29.0f, 30.0f}, {31.0f, 32.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_sum, {}, error_spec_);
}
2767
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
  // Adds a rank-2 array to a rank-3 array. As written below, the rank-3
  // operand is 2x3x2 and the rank-2 operand has two rows of three values;
  // broadcast_dimensions={0, 1} maps the matrix onto the first two dimensions
  // of the cube, so matrix[i][j] is added to every element of cube[i][j].
  XlaBuilder builder(TestName());

  const Array3D<float> cube_values(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  const auto cube = ConstantR3FromArray3D<float>(&builder, cube_values);
  const auto matrix = ConstantR2<float>(
      &builder, {{10.0f, 20.0f, 30.0f}, {40.0f, 50.0f, 60.0f}});
  Add(cube, matrix, /*broadcast_dimensions=*/{0, 1});

  const Array3D<float> expected_sum(
      {{{11.0f, 12.0f}, {23.0f, 24.0f}, {35.0f, 36.0f}},
       {{47.0f, 48.0f}, {59.0f, 60.0f}, {71.0f, 72.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_sum, {}, error_spec_);
}
2799
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
  // Gt comparison between two rank-3 arrays of compatible shapes, one of which
  // has a degenerate (size == 1) dimension. As written below, `a` is 2x3x2 and
  // `b` is 1x3x2, so `b` is compared against each outermost slice of `a` and
  // the result is a 2x3x2 array of PREDs.
  // NOTE(review): the dimension index in the test name presumably follows a
  // minor-to-major numbering; confirm before relying on it.
  XlaBuilder builder(TestName());
  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);

  Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
  auto b = ConstantR3FromArray3D<float>(&builder, b_3d);

  Gt(a, b);

  // The expected values are compared as a typed PRED array. The original test
  // built an integer array that was never used and instead matched the
  // pretty-printed string form of the result.
  const Array3D<bool> expected_3d(
      {{{false, true}, {false, false}, {false, false}},
       {{false, true}, {true, false}, {false, true}}});
  ComputeAndCompareR3<bool>(&builder, expected_3d, {});
}
2829
2830 XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
2831 XlaBuilder builder(TestName());
2832
2833 std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
2834 std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5));
2835 std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
2836 float value = 0.0;
2837 for (int64 p = 0; p < 2; ++p) {
2838 for (int64 z = 0; z < 3; ++z) {
2839 for (int64 y = 0; y < 4; ++y) {
2840 for (int64 x = 0; x < 5; ++x) {
2841 (*operand_a_4d)(p, z, y, x) = value;
2842 (*operand_b_4d)(p, z, y, x) = 2.0 * value;
2843 (*expected_4d)(p, z, y, x) = 3.0 * value;
2844 value += 0.1;
2845 }
2846 }
2847 }
2848 }
2849
2850 auto a = ConstantR4FromArray4D<float>(&builder, *operand_a_4d);
2851 auto b = ConstantR4FromArray4D<float>(&builder, *operand_b_4d);
2852 Add(a, b);
2853
2854 ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
2855 }
2856
XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
  // Adds a three-element R1 to a 2x3x4x5 R4, mapping the vector onto
  // dimension 1 of the R4 (the dimension of size three).
  XlaBuilder builder(TestName());

  // Plain stack objects: there is no need to heap-allocate the Array4D
  // wrappers themselves.
  Array4D<float> operand_a_4d(2, 3, 4, 5);
  Array4D<float> expected_4d(2, 3, 4, 5);
  // The vector holds 1, 2, 3.
  std::vector<float> operand_b_1d(3);
  std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0);

  // Fill in p, z, y, x order with a running value that grows by 0.1 per
  // element; the expected result adds the vector entry selected by z, the
  // index along dimension 1.
  float value = 0.0;
  for (int64 p = 0; p < 2; ++p) {
    for (int64 z = 0; z < 3; ++z) {
      for (int64 y = 0; y < 4; ++y) {
        for (int64 x = 0; x < 5; ++x) {
          operand_a_4d(p, z, y, x) = value;
          expected_4d(p, z, y, x) = value + operand_b_1d[z];
          value += 0.1;
        }
      }
    }
  }

  auto a = ConstantR4FromArray4D<float>(&builder, operand_a_4d);
  auto b = ConstantR1<float>(&builder, operand_b_1d);
  Add(a, b, {1});

  ComputeAndCompareR4<float>(&builder, expected_4d, {}, error_spec_);
}
2884
XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
  // Adds a 16-element R1 to a 16x16x2x2 R4 along dimension 1. The R4 constant
  // is created with an explicit {0, 1, 2, 3} layout (presumably given in
  // minor-to-major order, i.e. not the default layout; confirm against
  // LayoutUtil::MakeLayout).
  constexpr int kDim0 = 16;
  constexpr int kDim1 = 16;
  constexpr int kDim2 = 2;
  constexpr int kDim3 = 2;

  // The R4 starts as all ones; the R1 holds 1, 2, ..., 16.
  Array4D<float> values(kDim0, kDim1, kDim2, kDim3);
  values.Fill(1.0);
  std::vector<float> addend(kDim1);
  std::iota(addend.begin(), addend.end(), 1.0);

  XlaBuilder builder(TestName());
  Literal r4_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      values, LayoutUtil::MakeLayout({0, 1, 2, 3}));
  auto lhs = ConstantLiteral(&builder, r4_literal);
  auto rhs = ConstantR1<float>(&builder, addend);
  Add(lhs, rhs, {1});

  // The literal above already captured the operand, so `values` is now
  // updated in place to become the expected result.
  for (int a = 0; a < kDim0; ++a) {
    for (int b = 0; b < kDim1; ++b) {
      for (int c = 0; c < kDim2; ++c) {
        for (int d = 0; d < kDim3; ++d) {
          values(a, b, c, d) += addend[b];
        }
      }
    }
  }
  ComputeAndCompareR4<float>(&builder, values, {}, error_spec_);
}
2913
2914 // Show that we can't add two opaques.
XLA_TEST_F(ArrayElementwiseOpTest,CannotAddOpaques)2915 XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
2916 XlaBuilder builder(TestName());
2917 auto shape = ShapeUtil::MakeOpaqueShape();
2918 auto x = Parameter(&builder, 0, shape, "x");
2919 Add(x, x);
2920 auto computation_status = builder.Build();
2921 ASSERT_FALSE(computation_status.ok());
2922 EXPECT_THAT(computation_status.status().ToString(),
2923 ::testing::ContainsRegex(
2924 "Expected array argument for lhs of binary operation"));
2925 }
2926
XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
  // Two operands of the same rank may be given explicit broadcast dimensions
  // as long as they form the identity mapping {0, 1}; the result is then the
  // plain element-wise sum.
  XlaBuilder builder(TestName());

  const auto lhs = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  const auto rhs = ConstantR2<float>(
      &builder, {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
  Add(lhs, rhs, /*broadcast_dimensions=*/{0, 1});

  const Array2D<float> expected_sum(
      {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2939
XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
  // Two operands of the same rank with broadcast dimensions {1, 0} (a
  // permutation rather than the identity) must be rejected at build time.
  XlaBuilder builder(TestName());

  const auto lhs = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  const auto rhs = ConstantR2<float>(
      &builder, {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
  Add(lhs, rhs, /*broadcast_dimensions=*/{1, 0});

  // Build() must fail with a message saying the dimensions must be the
  // identity.
  const auto build_result = builder.Build();
  ASSERT_FALSE(build_result.ok());
  EXPECT_THAT(build_result.status().error_message(),
              ::testing::ContainsRegex("must.*be the identity"));
}
2953
2954 // Regression test for b/31927799. "slice - y" is fused and requires implicit
2955 // broadcast.
XLA_TEST_F(ArrayElementwiseOpTest,ImplictBroadcastInFusedExpressions)2956 XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
2957 XlaBuilder builder(TestName());
2958 auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
2959 auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
2960 auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
2961 auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
2962
2963 auto x = Parameter(&builder, 0, x_literal.shape(), "x");
2964 auto y = Parameter(&builder, 1, y_literal.shape(), "y");
2965 auto slice = Slice(x, {1}, {2}, {1});
2966 Sub(slice, y);
2967
2968 ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
2969 error_spec_);
2970 }
2971
2972 INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
2973 ArrayElementwiseOpTestParamCount,
2974 ::testing::Values(127, 128, 129, 17 * 4096));
2975
2976 } // namespace
2977 } // namespace xla
2978