1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 #include <limits>
18 #include <memory>
19 #include <numeric>
20 #include <vector>
21 
22 #include "absl/base/casts.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/array2d.h"
25 #include "tensorflow/compiler/xla/array3d.h"
26 #include "tensorflow/compiler/xla/array4d.h"
27 #include "tensorflow/compiler/xla/client/global_data.h"
28 #include "tensorflow/compiler/xla/client/local_client.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
35 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
36 #include "tensorflow/compiler/xla/tests/test_macros.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/core/platform/types.h"
39 
40 namespace xla {
41 namespace {
42 
43 class ArrayElementwiseOpTest : public ClientLibraryTestBase {
44  public:
45   ErrorSpec error_spec_{0.0001, 0.0001};
46 };
47 
// Variant of ArrayElementwiseOpTest for value-parameterized tests. The int
// parameter is read through GetParam() by the TEST_P cases using this fixture.
class ArrayElementwiseOpTestParamCount
    : public ArrayElementwiseOpTest,
      public ::testing::WithParamInterface<int> {};
51 
// Neg applied to an F32 constant with no elements must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
  XlaBuilder builder(TestName());
  const XlaOp empty_operand = ConstantR1<float>(&builder, {});
  Neg(empty_operand);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
59 
// Checks element-wise negation of a five-element F32 constant.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
  XlaBuilder builder(TestName());
  const XlaOp operand =
      ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  Neg(operand);

  const std::vector<float> negated = {2.5f, -3.14f, -2.25f, 10.0f, -6.0f};
  ComputeAndCompareR1<float>(&builder, negated, {}, error_spec_);
}
68 
// Checks element-wise negation of S32 values, including both extremes of the
// int32 range.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();

  XlaBuilder builder(TestName());
  const XlaOp operand =
      ConstantR1<int32>(&builder, {-1, 0, 1, 324, kMin, kMax});
  Neg(operand);

  // Negating the minimum int32 overflows and yields the minimum again. That
  // computation is undefined behavior in C++, but XLA has not specified it as
  // such, so it ought to work; the expectation therefore spells out kMin
  // rather than computing -kMin on the host.
  const std::vector<int32> negated = {1, 0, -1, -324, kMin, -kMax};
  ComputeAndCompareR1<int32>(&builder, negated, {});
}
84 
// Neg applied to a C64 constant with no elements must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) {
  XlaBuilder builder(TestName());
  const XlaOp empty_operand = ConstantR1<complex64>(&builder, {});
  Neg(empty_operand);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
92 
// Checks that Neg flips the sign of both the real and imaginary parts of each
// C64 element.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) {
  XlaBuilder builder(TestName());
  const XlaOp operand = ConstantR1<complex64>(
      &builder, {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}});
  Neg(operand);

  const std::vector<complex64> negated = {
      {2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}};
  ComputeAndCompareR1<complex64>(&builder, negated, {}, error_spec_);
}
103 
// Checks element-wise negation of S64 values, including values that do not fit
// in 32 bits and the minimum int64 (0x8000000000000000), whose negation is
// expected to wrap back to itself.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) {
  XlaBuilder builder(TestName());
  auto a =
      ConstantR1<int64>(&builder, {
                                      -1,
                                      1,
                                      0,
                                      0x12345678,
                                      static_cast<int64>(0xffffffff12345678LL),
                                      static_cast<int64>(0x8000000000000000LL),
                                      static_cast<int64>(0x8000000000000001LL),
                                  });
  Neg(a);

  // Expected values are listed in the same order as the inputs above.
  ComputeAndCompareR1<int64>(&builder,
                             {
                                 1,
                                 -1,
                                 0,
                                 -0x12345678,
                                 0xedcba988,
                                 static_cast<int64>(0x8000000000000000LL),
                                 -static_cast<int64>(0x8000000000000001LL),
                             },
                             {});
}
131 
// IsFinite applied to an empty F32 constant must give an empty predicate
// vector.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_operand = ConstantR1<float>(&builder, {});
  IsFinite(empty_operand);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
139 
// A non-canonical quiet NaN value. Assuming IEEE-754 binary32 floats, the bit
// pattern 0x7FD01234 has every exponent bit set and a non-zero mantissa whose
// top bit is set; the IsFinite tests use it alongside the NAN macro.
static const float kNonCanonicalNaN = absl::bit_cast<float>(0x7FD01234);
142 
// IsFinite on F32 scalars: NaN (from the NAN macro and non-canonical) and both
// infinities must report false, zero must report true. One builder is reused
// for all five computations.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) {
  const float kInfinity = std::numeric_limits<float>::infinity();
  XlaBuilder builder(TestName());

  // NaN produced by the NAN macro.
  IsFinite(ConstantR0<float>(&builder, NAN));
  ComputeAndCompareR0<bool>(&builder, false, {});

  // Non-canonical NaN; first confirm the host also sees it as a NaN.
  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
  IsFinite(ConstantR0<float>(&builder, kNonCanonicalNaN));
  ComputeAndCompareR0<bool>(&builder, false, {});

  // Positive infinity.
  IsFinite(ConstantR0<float>(&builder, kInfinity));
  ComputeAndCompareR0<bool>(&builder, false, {});

  // Negative infinity.
  IsFinite(ConstantR0<float>(&builder, -kInfinity));
  ComputeAndCompareR0<bool>(&builder, false, {});

  // An ordinary finite value.
  IsFinite(ConstantR0<float>(&builder, 0.0f));
  ComputeAndCompareR0<bool>(&builder, true, {});
}
162 
// IsFinite on an F32 vector mixing NaNs, infinities and ordinary values.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
  XlaBuilder builder(TestName());
  const float kInfinity = std::numeric_limits<float>::infinity();
  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));

  const std::vector<float> inputs = {NAN,   7.0f,      kNonCanonicalNaN,
                                     -1.0f, kInfinity, -kInfinity};
  const XlaOp operand = ConstantR1<float>(&builder, inputs);
  IsFinite(operand);

  // Only the two ordinary values (7.0 and -1.0) are finite.
  ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false},
                            {});
}
174 
// Checks element-wise addition of two five-element F32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs =
      ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  const XlaOp rhs =
      ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Add(lhs, rhs);

  const std::vector<float> sums = {97.5f, 6.27f, 5.0f, 0.5f, -993.0f};
  ComputeAndCompareR1<float>(&builder, sums, {}, error_spec_);
}
184 
// Adding two empty F32 constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<float>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<float>(&builder, {});
  Add(empty_lhs, empty_rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
193 
// Checks element-wise addition of two four-element C64 constants.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}});
  const XlaOp rhs = ConstantR1<complex64>(
      &builder, {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}});
  Add(lhs, rhs);

  // The first expected element is purely real, so it is given as a bare float.
  const std::vector<complex64> sums = {
      97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}};
  ComputeAndCompareR1<complex64>(&builder, sums, {}, error_spec_);
}
206 
// Adding two empty C64 constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<complex64>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<complex64>(&builder, {});
  Add(empty_lhs, empty_rhs);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
215 
// Adds two U64 parameter vectors whose values sit near the 32-bit, 63-bit and
// 64-bit boundaries. The expectation is computed on the host with uint64
// arithmetic, which wraps on overflow.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
  XlaBuilder builder(TestName());

  // Left operand, transferred to the device as parameter 0.
  std::vector<uint64> lhs_values{0xFFFFFFFF,
                                 static_cast<uint64>(-1),
                                 0,
                                 0,
                                 0x7FFFFFFFFFFFFFFFLL,
                                 0x7FFFFFFFFFFFFFFLL,
                                 0x8000000000000000LL,
                                 0x8000000000000000LL,
                                 1};
  Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs_values});
  const XlaOp lhs_op =
      Parameter(&builder, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_handle =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  // Right operand, transferred to the device as parameter 1.
  std::vector<uint64> rhs_values{1,
                                 0x7FFFFFFFFFFFFFFLL,
                                 0x7FFFFFFFFFFFFFFFLL,
                                 0x8000000000000000LL,
                                 0,
                                 static_cast<uint64>(-1),
                                 0,
                                 1,
                                 0x8000000000000000LL};
  Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs_values});
  const XlaOp rhs_op =
      Parameter(&builder, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_handle =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  Add(lhs_op, rhs_op);

  std::vector<uint64> expected;
  expected.reserve(lhs_values.size());
  for (size_t i = 0; i < lhs_values.size(); ++i) {
    expected.push_back(lhs_values[i] + rhs_values[i]);
  }

  ComputeAndCompareR1<uint64>(&builder, expected,
                              {lhs_handle.get(), rhs_handle.get()});
}
256 
// Subtracts two S64 parameter vectors whose values sit near the int64 limits.
// The expectation is computed on the host from the same operand pairs.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
  XlaBuilder builder(TestName());

  // Left operand, transferred to the device as parameter 0.
  std::vector<int64> lhs_values{static_cast<int64>(0x8000000000000000LL),
                                static_cast<int64>(0x8000000000000000LL),
                                -1,
                                0x7FFFFFFFFFFFFFFLL,
                                0x7FFFFFFFFFFFFFFFLL,
                                1,
                                0,
                                -1};
  Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs_values});
  const XlaOp lhs_op =
      Parameter(&builder, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_handle =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  // Right operand, transferred to the device as parameter 1.
  std::vector<int64> rhs_values{-1,
                                0,
                                static_cast<int64>(0x8000000000000000LL),
                                1,
                                0,
                                0x7FFFFFFFFFFFFFFLL,
                                0x7FFFFFFFFFFFFFFFLL,
                                0x7FFFFFFFFFFFFFFFLL};
  Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs_values});
  const XlaOp rhs_op =
      Parameter(&builder, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_handle =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  Sub(lhs_op, rhs_op);

  std::vector<int64> expected;
  expected.reserve(lhs_values.size());
  for (size_t i = 0; i < lhs_values.size(); ++i) {
    expected.push_back(lhs_values[i] - rhs_values[i]);
  }

  ComputeAndCompareR1<int64>(&builder, expected,
                             {lhs_handle.get(), rhs_handle.get()});
}
295 
// Builds Lt over single-element U64 parameters 0x8000000000000000 and
// 0x7FFFFFFFFFFFFFFF. As unsigned values the left side is the larger one; a
// signed reading of the same bits would order them the other way round.
// NOTE(review): no expected value is passed here, so ComputeAndCompare
// presumably checks against a reference execution -- confirm.
XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
  XlaBuilder builder(TestName());

  std::vector<uint64> lhs_values{static_cast<uint64>(0x8000000000000000ULL)};
  Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs_values});
  const XlaOp lhs_op =
      Parameter(&builder, 0, lhs_literal.shape(), "lhs_param");

  std::vector<uint64> rhs_values{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)};
  Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs_values});
  const XlaOp rhs_op =
      Parameter(&builder, 1, rhs_literal.shape(), "rhs_param");

  Lt(lhs_op, rhs_op);

  ComputeAndCompare(&builder,
                    {std::move(lhs_literal), std::move(rhs_literal)});
}
311 
// For the parameterized element count, builds two F32 vectors, feeds each one
// both as a constant and as a parameter, and adds all four constant/parameter
// pairings together. The total must equal 4 * (a + b) element-wise.
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());
  std::vector<float> a_values;
  std::vector<float> b_values;
  for (int i = 0; i < count; ++i) {
    a_values.push_back(i / static_cast<float>(count));
    b_values.push_back(2 * i / static_cast<float>(count + 2));
  }

  Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto a_constant = ConstantR1<float>(&builder, a_values);
  auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");

  Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
  std::unique_ptr<GlobalData> b_data =
      client_->TransferToServer(b_literal).ConsumeValueOrDie();
  // Parameter 1 is described by its own literal's shape (the two shapes are
  // equal here because both vectors hold `count` floats).
  auto b_param = Parameter(&builder, 1, b_literal.shape(), "b_param");
  auto b_constant = ConstantR1<float>(&builder, b_values);

  // One sum per constant/parameter pairing of the two operands.
  auto sum1 = Add(a_constant, b_param);
  auto sum2 = Add(a_constant, b_constant);
  auto sum3 = Add(a_param, b_param);
  auto sum4 = Add(a_param, b_constant);

  auto sum = Add(sum1, sum2);
  sum = Add(sum, sum3);
  sum = Add(sum, sum4);

  std::vector<float> expected;
  for (int i = 0; i < count; ++i) {
    expected.push_back(4 * (a_values[i] + b_values[i]));
  }

  ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
                             error_spec_);
}
351 
// Builds a deep chain of add/slice "diamonds" over two 30-element zero vectors
// and checks that the single remaining element is 0.
XLA_TEST_F(ArrayElementwiseOpTest, DeeplyNestedAddWithSlices) {
  XlaBuilder builder(TestName());
  std::vector<float> values(30, 0.0);
  auto a_literal = LiteralUtil::CreateR1<float>(values);
  auto a = Parameter(&builder, 0, a_literal.shape(), "x");
  auto b_literal = LiteralUtil::CreateR1<float>(values);
  auto b = Parameter(&builder, 1, b_literal.shape(), "y");

  // Construct a sequence of diamond-shaped gadgets like this:
  //
  //      add
  //    /    \
  //  slice  slice
  //     \   /
  //      add
  //
  // Each 'left' slice removes the last element, each 'right' slice removes the
  // first element. In this way, we index into the add with different
  // multi-dimensional index arrays, which defeats the caching we use to avoid
  // exponential compile time.
  //
  // The chain starts from Add(a, b) over the full vectors and shrinks the
  // result by one element per diamond until one element is left. The last Add
  // created is the one the comparison below evaluates.
  XlaOp sum = Add(a, b);
  for (int64 slice_size = static_cast<int64>(values.size()) - 1;
       slice_size >= 1; --slice_size) {
    auto slice1 = Slice(sum, {0}, {slice_size}, {1});
    auto slice2 = Slice(sum, {1}, {slice_size + 1}, {1});
    sum = Add(slice1, slice2);
  }
  auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie();
  ComputeAndCompareR1<float>(&builder, {0.0}, {a_data.get(), b_data.get()});
}
387 
// Checks element-wise subtraction of two five-element F32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp minuend =
      ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  const XlaOp subtrahend =
      ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Sub(minuend, subtrahend);

  const std::vector<float> differences = {-102.5f, 0.01f, -0.5f, -20.5f,
                                          1005.0f};
  ComputeAndCompareR1<float>(&builder, differences, {}, error_spec_);
}
397 
// Subtracting two empty F32 constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<float>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<float>(&builder, {});
  Sub(empty_lhs, empty_rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
406 
// Checks element-wise subtraction of two four-element S32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
  XlaBuilder builder(TestName());
  const XlaOp minuend = ConstantR1<int32>(&builder, {-1, 0, 2, 1000000000});
  const XlaOp subtrahend = ConstantR1<int32>(&builder, {-1, 2, 1, -1});
  Sub(minuend, subtrahend);

  const std::vector<int32> differences = {0, -2, 1, 1000000001};
  ComputeAndCompareR1<int32>(&builder, differences, {});
}
415 
// Subtracting two empty S32 constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<int32>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<int32>(&builder, {});
  Sub(empty_lhs, empty_rhs);

  ComputeAndCompareR1<int32>(&builder, {}, {});
}
424 
// Checks element-wise subtraction of two three-element C64 constants.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
  XlaBuilder builder(TestName());
  const XlaOp minuend = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}});
  const XlaOp subtrahend = ConstantR1<complex64>(
      &builder, {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}});
  Sub(minuend, subtrahend);

  const std::vector<complex64> differences = {
      {-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}};
  ComputeAndCompareR1<complex64>(&builder, differences, {}, error_spec_);
}
437 
// Subtracting two empty C64 constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<complex64>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<complex64>(&builder, {});
  Sub(empty_lhs, empty_rhs);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
446 
// Checks element-wise division of two five-element F32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp numerator =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  const XlaOp denominator =
      ConstantR1<float>(&builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
  Div(numerator, denominator);

  const std::vector<float> quotients = {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f};
  ComputeAndCompareR1<float>(&builder, quotients, {}, error_spec_);
}
456 
// Dividing two empty F32 constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<float>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<float>(&builder, {});
  Div(empty_lhs, empty_rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
465 
466 class IntegerDivideOpTest : public ArrayElementwiseOpTest {
467  protected:
468   template <typename T>
TestDivRem(absl::Span<const T> dividends,absl::Span<const T> divisors,absl::Span<const T> quotients,absl::Span<const T> remainders)469   void TestDivRem(absl::Span<const T> dividends, absl::Span<const T> divisors,
470                   absl::Span<const T> quotients,
471                   absl::Span<const T> remainders) {
472     {
473       XlaBuilder builder(TestName());
474       XlaOp dividend;
475       XlaOp divisor;
476       auto dividend_data =
477           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
478       auto divisor_data =
479           CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
480       Div(dividend, divisor);
481 
482       ComputeAndCompareR1<T>(&builder, quotients,
483                              {dividend_data.get(), divisor_data.get()});
484     }
485 
486     // Test with a compile-time constant divisor.
487     {
488       XlaBuilder builder(TestName());
489       XlaOp dividend;
490       auto dividend_data =
491           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
492       Div(dividend, ConstantR1<T>(&builder, divisors));
493 
494       ComputeAndCompareR1<T>(&builder, quotients, {dividend_data.get()});
495     }
496 
497     {
498       XlaBuilder builder(TestName());
499       XlaOp dividend;
500       XlaOp divisor;
501       auto dividend_data =
502           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
503       auto divisor_data =
504           CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
505       Rem(dividend, divisor);
506 
507       ComputeAndCompareR1<T>(&builder, remainders,
508                              {dividend_data.get(), divisor_data.get()});
509     }
510 
511     // Test with a compile-time constant divisor.
512     {
513       XlaBuilder builder(TestName());
514       XlaOp dividend;
515       auto dividend_data =
516           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
517       Rem(dividend, ConstantR1<T>(&builder, divisors));
518 
519       ComputeAndCompareR1<T>(&builder, remainders, {dividend_data.get()});
520     }
521   }
522 };
523 
// Exercises S32 Div/Rem over every (dividend, divisor) pair drawn from a list
// of interesting values, skipping pairs whose result is not defined in C++.
XLA_TEST_F(IntegerDivideOpTest, DivS32s) {
  // clang-format off
  // Some interesting values to test.
  std::vector<int32> candidates = {
    INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff,
    -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101,
    7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX};
  // clang-format on

  std::vector<int32> dividends;
  std::vector<int32> divisors;
  std::vector<int32> quotients;
  std::vector<int32> remainders;
  for (const int32 denominator : candidates) {
    // Division by zero is not part of this test.
    if (denominator == 0) {
      continue;
    }
    for (const int32 numerator : candidates) {
      // INT32_MIN / -1 overflows on the host, so that pair is left out.
      if (numerator == INT32_MIN && denominator == -1) {
        continue;
      }
      dividends.push_back(numerator);
      divisors.push_back(denominator);
      quotients.push_back(numerator / denominator);
      remainders.push_back(numerator % denominator);
    }
  }

  TestDivRem<int32>(dividends, divisors, quotients, remainders);
}
550 
// Pins the results expected for the two S32 cases that overflow: 5 / 0 and
// INT32_MIN / -1, together with the matching remainders.
XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) {
  const std::vector<int32> dividends = {5, INT32_MIN};
  const std::vector<int32> divisors = {0, -1};
  const std::vector<int32> quotients = {-1, INT32_MIN};
  const std::vector<int32> remainders = {5, 0};

  TestDivRem<int32>(dividends, divisors, quotients, remainders);
}
557 
// Exercises U32 Div/Rem over every (dividend, divisor) pair drawn from a list
// of interesting values, skipping zero divisors.
XLA_TEST_F(IntegerDivideOpTest, DivU32s) {
  // clang-format off
  // Some interesting values to test.
  std::vector<uint32> candidates = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000,
    0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  std::vector<uint32> dividends;
  std::vector<uint32> divisors;
  std::vector<uint32> quotients;
  std::vector<uint32> remainders;
  for (const uint32 denominator : candidates) {
    // Division by zero is not part of this test.
    if (denominator == 0) {
      continue;
    }
    for (const uint32 numerator : candidates) {
      dividends.push_back(numerator);
      divisors.push_back(denominator);
      quotients.push_back(numerator / denominator);
      remainders.push_back(numerator % denominator);
    }
  }

  TestDivRem<uint32>(dividends, divisors, quotients, remainders);
}
580 
// Pins the results expected for unsigned division by zero.
XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) {
  // Operands are uint32 so that the unsigned Div/Rem path is exercised. The
  // expected quotient is all ones (the unsigned counterpart of the -1 that
  // SignedOverflow expects for 5 / 0) and the expected remainder is the
  // dividend itself.
  std::vector<uint32> dividends = {5}, divisors = {0},
                      quotients = {UINT32_MAX}, remainders = {5};

  TestDivRem<uint32>(dividends, divisors, quotients, remainders);
}
587 
// Checks element-wise division of two three-element C64 constants: by a real
// value, by a purely imaginary value, and of a value by itself.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
  XlaBuilder builder(TestName());
  const XlaOp numerator = ConstantR1<complex64>(
      &builder, {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}});
  const XlaOp denominator = ConstantR1<complex64>(
      &builder, {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}});
  Div(numerator, denominator);

  const std::vector<complex64> quotients = {
      {-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}};
  ComputeAndCompareR1<complex64>(&builder, quotients, {}, error_spec_);
}
599 
// Dividing two empty C64 constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<complex64>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<complex64>(&builder, {});
  Div(empty_lhs, empty_rhs);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
608 
// Checks the F32 remainder for every combination of operand signs; each
// expected value carries the sign of the dividend.
XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
  XlaBuilder builder(TestName());
  const XlaOp dividend = ConstantR1<float>(
      &builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f});
  const XlaOp divisor = ConstantR1<float>(
      &builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f});
  Rem(dividend, divisor);

  const std::vector<float> remainders = {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f,
                                         1.0f,  1.0f, -1.0f, -0.0f};
  ComputeAndCompareR1<float>(&builder, remainders, {}, error_spec_);
}
621 
// Rem over two empty F32 constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<float>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<float>(&builder, {});
  Rem(empty_lhs, empty_rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
630 
// F64 counterpart of RemF32s: same operand values, double precision.
XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
  XlaBuilder builder(TestName());
  const XlaOp dividend = ConstantR1<double>(
      &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0});
  const XlaOp divisor = ConstantR1<double>(
      &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0});
  Rem(dividend, divisor);

  const std::vector<double> remainders = {-2.5, 0.0, 0.25, 0.0, -0.0,
                                          1.0,  1.0, -1.0, -0.0};
  ComputeAndCompareR1<double>(&builder, remainders, {}, error_spec_);
}
643 
// Checks element-wise multiplication of two five-element F32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  const XlaOp rhs =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Mul(lhs, rhs);

  const std::vector<float> products = {-25.0f, 127.5f, 2.25f, -100.0f,
                                       -36.0f};
  ComputeAndCompareR1<float>(&builder, products, {}, error_spec_);
}
653 
// Multiplying two empty F32 constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<float>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<float>(&builder, {});
  Mul(empty_lhs, empty_rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
662 
// Multiplies every ordered pair drawn from a small set of S32 values,
// including the int32 extremes.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
  std::vector<int32> seeds = {0,
                              1,
                              -1,
                              1234,
                              0x1a243514,
                              std::numeric_limits<int32>::max(),
                              std::numeric_limits<int32>::min()};

  // Form the test data set using all products of 'seeds' with itself. The
  // host-side product is computed on the operands cast to uint32 and then
  // stored back into an int32 element.
  std::vector<int32> lhs_values;
  std::vector<int32> rhs_values;
  std::vector<int32> products;
  for (const int32 left : seeds) {
    for (const int32 right : seeds) {
      lhs_values.push_back(left);
      rhs_values.push_back(right);
      products.push_back(static_cast<uint32>(left) *
                         static_cast<uint32>(right));
    }
  }

  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, lhs_values);
  const XlaOp rhs = ConstantR1<int32>(&builder, rhs_values);
  Mul(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, products, {});
}
688 
// Multiplying two empty S32 constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<int32>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<int32>(&builder, {});
  Mul(empty_lhs, empty_rhs);

  ComputeAndCompareR1<int32>(&builder, {}, {});
}
697 
// Multiplies every ordered pair drawn from a small set of U32 values; the
// expectation is the host's uint32 product of each pair.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
  std::vector<uint32> seeds = {0,          1,          0xDEADBEEF, 1234,
                               0x1a243514, 0xFFFFFFFF, 0x80808080};

  // Form the test data set using all products of 'seeds' with itself.
  std::vector<uint32> lhs_values;
  std::vector<uint32> rhs_values;
  std::vector<uint32> products;
  for (const uint32 left : seeds) {
    for (const uint32 right : seeds) {
      lhs_values.push_back(left);
      rhs_values.push_back(right);
      products.push_back(left * right);
    }
  }

  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&builder, lhs_values);
  const XlaOp rhs = ConstantR1<uint32>(&builder, rhs_values);
  Mul(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, products, {});
}
719 
// Checks element-wise multiplication of two three-element C64 constants.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}});
  const XlaOp rhs = ConstantR1<complex64>(
      &builder, {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}});
  Mul(lhs, rhs);

  const std::vector<complex64> products = {
      {0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}};
  ComputeAndCompareR1<complex64>(&builder, products, {}, error_spec_);
}
732 
// Multiplying two empty C64 constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<complex64>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<complex64>(&builder, {});
  Mul(empty_lhs, empty_rhs);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
741 
// Checks logical And over rank-1 predicates, covering all four input
// combinations.
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  const XlaOp rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  And(lhs, rhs);

  // Only the last pair has both inputs true.
  ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
}
750 
// Checks logical And over 2x2 predicates, covering all four input
// combinations.
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  const XlaOp rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  And(lhs, rhs);

  // Only the bottom-right pair has both inputs true.
  Array2D<bool> expected({{false, false}, {false, true}});
  ComputeAndCompareR2<bool>(&builder, expected, {});
}
760 
// And over two empty predicate constants must give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp empty_lhs = ConstantR1<bool>(&builder, {});
  const XlaOp empty_rhs = ConstantR1<bool>(&builder, {});
  And(empty_lhs, empty_rhs);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
769 
// Bitwise And over rank-1 int32 operands, including negative values.
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {0, -1, -8});
  const XlaOp rhs = ConstantR1<int32>(&builder, {5, -7, 12});
  And(lhs, rhs);

  // 0 & 5 = 0; -1 & -7 = -7; -8 & 12 = 8.
  ComputeAndCompareR1<int32>(&builder, {0, -7, 8}, /*arguments=*/{});
}
778 
// Bitwise And over 2x2 int32 operands, including negative values.
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<int32>(&builder, {{0, -5}, {-1, 5}});
  const XlaOp rhs = ConstantR2<int32>(&builder, {{1, -6}, {4, 5}});
  And(lhs, rhs);

  // 0 & 1 = 0; -5 & -6 = -6; -1 & 4 = 4; 5 & 5 = 5.
  const Array2D<int32> expected({{0, -6}, {4, 5}});
  ComputeAndCompareR2<int32>(&builder, expected, /*arguments=*/{});
}
788 
// And of two empty int32 vectors is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {});
  const XlaOp rhs = ConstantR1<int32>(&builder, {});
  And(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, /*expected=*/{}, /*arguments=*/{});
}
797 
// Bitwise And over rank-1 uint32 operands.
//
// The operands and the expected result are built as uint32 so that this test
// covers the unsigned case its name advertises, matching AndU32R2 and the
// Or/Xor U32 rank-1 tests (it previously used int32 throughout).
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<uint32>(&builder, {0, 1, 8});
  auto b = ConstantR1<uint32>(&builder, {5, 7, 12});
  And(a, b);

  // 0 & 5 = 0; 1 & 7 = 1; 8 & 12 = 8.
  ComputeAndCompareR1<uint32>(&builder, {0, 1, 8}, {});
}
806 
// Bitwise And over 2x2 uint32 operands.
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<uint32>(&builder, {{0, 1}, {3, 8}});
  const XlaOp rhs = ConstantR2<uint32>(&builder, {{1, 0}, {7, 6}});
  And(lhs, rhs);

  // 0 & 1 = 0; 1 & 0 = 0; 3 & 7 = 3; 8 & 6 = 0.
  const Array2D<uint32> expected({{0, 0}, {3, 0}});
  ComputeAndCompareR2<uint32>(&builder, expected, /*arguments=*/{});
}
816 
// And of two empty uint32 vectors is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&builder, {});
  const XlaOp rhs = ConstantR1<uint32>(&builder, {});
  And(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, /*expected=*/{}, /*arguments=*/{});
}
825 
// Logical Or over rank-1 bool operands; the inputs enumerate the full
// two-input truth table.
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  const XlaOp rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  Or(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, true},
                            /*arguments=*/{});
}
834 
// Logical Or over 2x2 bool operands covering the full truth table.
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  const XlaOp rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Or(lhs, rhs);

  const Array2D<bool> expected({{false, true}, {true, true}});
  ComputeAndCompareR2<bool>(&builder, expected, /*arguments=*/{});
}
844 
// Or of two empty bool vectors is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<bool>(&builder, {});
  const XlaOp rhs = ConstantR1<bool>(&builder, {});
  Or(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
853 
// Bitwise Or over rank-1 int32 operands, including negative values.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {0, -1, 8});
  const XlaOp rhs = ConstantR1<int32>(&builder, {5, -7, 4});
  Or(lhs, rhs);

  // 0 | 5 = 5; -1 | -7 = -1; 8 | 4 = 12.
  ComputeAndCompareR1<int32>(&builder, {5, -1, 12}, /*arguments=*/{});
}
862 
// Bitwise Or over 2x2 int32 operands, including negative values.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<int32>(&builder, {{0, -1}, {8, 8}});
  const XlaOp rhs = ConstantR2<int32>(&builder, {{5, -7}, {4, 1}});
  Or(lhs, rhs);

  // 0 | 5 = 5; -1 | -7 = -1; 8 | 4 = 12; 8 | 1 = 9.
  const Array2D<int32> expected({{5, -1}, {12, 9}});
  ComputeAndCompareR2<int32>(&builder, expected, /*arguments=*/{});
}
872 
// Or of two empty int32 vectors is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {});
  const XlaOp rhs = ConstantR1<int32>(&builder, {});
  Or(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, /*expected=*/{}, /*arguments=*/{});
}
881 
// Bitwise Or over rank-1 uint32 operands.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&builder, {0, 1, 8});
  const XlaOp rhs = ConstantR1<uint32>(&builder, {5, 7, 4});
  Or(lhs, rhs);

  // 0 | 5 = 5; 1 | 7 = 7; 8 | 4 = 12.
  ComputeAndCompareR1<uint32>(&builder, {5, 7, 12}, /*arguments=*/{});
}
890 
// Bitwise Or over 2x2 uint32 operands.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<uint32>(&builder, {{0, 1}, {8, 8}});
  const XlaOp rhs = ConstantR2<uint32>(&builder, {{5, 7}, {4, 1}});
  Or(lhs, rhs);

  // 0 | 5 = 5; 1 | 7 = 7; 8 | 4 = 12; 8 | 1 = 9.
  const Array2D<uint32> expected({{5, 7}, {12, 9}});
  ComputeAndCompareR2<uint32>(&builder, expected, /*arguments=*/{});
}
900 
// Or of two empty uint32 vectors is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&builder, {});
  const XlaOp rhs = ConstantR1<uint32>(&builder, {});
  Or(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, /*expected=*/{}, /*arguments=*/{});
}
909 
// Logical Xor over rank-1 bool operands; the inputs enumerate the full
// two-input truth table.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  const XlaOp rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  Xor(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false},
                            /*arguments=*/{});
}
918 
// Logical Xor over 2x2 bool operands covering the full truth table.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  const XlaOp rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Xor(lhs, rhs);

  const Array2D<bool> expected({{false, true}, {true, false}});
  ComputeAndCompareR2<bool>(&builder, expected, /*arguments=*/{});
}
928 
// Xor of two empty bool vectors is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<bool>(&builder, {});
  const XlaOp rhs = ConstantR1<bool>(&builder, {});
  Xor(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
937 
// Bitwise Xor over rank-1 int32 operands, including negative values.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {0, -1, 8});
  const XlaOp rhs = ConstantR1<int32>(&builder, {5, -7, 4});
  Xor(lhs, rhs);

  // 0 ^ 5 = 5; -1 ^ -7 = 6; 8 ^ 4 = 12.
  ComputeAndCompareR1<int32>(&builder, {5, 6, 12}, /*arguments=*/{});
}
946 
// Bitwise Xor over 2x2 int32 operands, including negative values.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<int32>(&builder, {{0, -1}, {8, 8}});
  const XlaOp rhs = ConstantR2<int32>(&builder, {{5, -7}, {4, 1}});
  Xor(lhs, rhs);

  // 0 ^ 5 = 5; -1 ^ -7 = 6; 8 ^ 4 = 12; 8 ^ 1 = 9.
  const Array2D<int32> expected({{5, 6}, {12, 9}});
  ComputeAndCompareR2<int32>(&builder, expected, /*arguments=*/{});
}
956 
// Xor of two empty int32 vectors is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {});
  const XlaOp rhs = ConstantR1<int32>(&builder, {});
  Xor(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, /*expected=*/{}, /*arguments=*/{});
}
965 
// Bitwise Xor over rank-1 uint32 operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&builder, {0, 1, 8});
  const XlaOp rhs = ConstantR1<uint32>(&builder, {5, 7, 4});
  Xor(lhs, rhs);

  // 0 ^ 5 = 5; 1 ^ 7 = 6; 8 ^ 4 = 12.
  ComputeAndCompareR1<uint32>(&builder, {5, 6, 12}, /*arguments=*/{});
}
974 
// Bitwise Xor over 2x2 uint32 operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR2<uint32>(&builder, {{0, 1}, {8, 8}});
  const XlaOp rhs = ConstantR2<uint32>(&builder, {{5, 7}, {4, 1}});
  Xor(lhs, rhs);

  // 0 ^ 5 = 5; 1 ^ 7 = 6; 8 ^ 4 = 12; 8 ^ 1 = 9.
  const Array2D<uint32> expected({{5, 6}, {12, 9}});
  ComputeAndCompareR2<uint32>(&builder, expected, /*arguments=*/{});
}
984 
// Xor of two empty uint32 vectors is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<uint32>(&builder, {});
  const XlaOp rhs = ConstantR1<uint32>(&builder, {});
  Xor(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, /*expected=*/{}, /*arguments=*/{});
}
// Logical Not over a rank-1 bool operand: every element is inverted.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp operand =
      ConstantR1<bool>(&builder, {false, true, true, false});
  Not(operand);

  ComputeAndCompareR1<bool>(&builder, {true, false, false, true},
                            /*arguments=*/{});
}
1000 
// Logical Not over a 2x2 bool operand: every element is inverted.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
  XlaBuilder builder(TestName());
  const XlaOp operand =
      ConstantR2<bool>(&builder, {{false, true}, {true, false}});
  Not(operand);

  const Array2D<bool> expected({{true, false}, {false, true}});
  ComputeAndCompareR2<bool>(&builder, expected, /*arguments=*/{});
}
1009 
// Not of an empty bool vector is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) {
  XlaBuilder builder(TestName());
  const XlaOp operand = ConstantR1<bool>(&builder, {});
  Not(operand);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1017 
// Bitwise Not over a rank-1 int32 operand (~x == -x - 1).
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp operand = ConstantR1<int32>(&builder, {-1, 0, 1});
  Not(operand);

  ComputeAndCompareR1<int32>(&builder, {0, -1, -2}, /*arguments=*/{});
}
1025 
// Bitwise Not over a 2x2 int32 operand (~x == -x - 1).
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
  XlaBuilder builder(TestName());
  const XlaOp operand = ConstantR2<int32>(&builder, {{-1, 0}, {1, 8}});
  Not(operand);

  const Array2D<int32> expected({{0, -1}, {-2, -9}});
  ComputeAndCompareR2<int32>(&builder, expected, /*arguments=*/{});
}
1034 
// Not of an empty int32 vector is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) {
  XlaBuilder builder(TestName());
  const XlaOp operand = ConstantR1<int32>(&builder, {});
  Not(operand);

  ComputeAndCompareR1<int32>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1042 
// Bitwise Not over a rank-1 uint32 operand: 0 and the all-ones value
// (4294967295) map to each other.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) {
  XlaBuilder builder(TestName());
  const XlaOp operand = ConstantR1<uint32>(&builder, {0, 4294967295});
  Not(operand);

  ComputeAndCompareR1<uint32>(&builder, {4294967295, 0}, /*arguments=*/{});
}
1050 
// Bitwise Not over a 2x2 uint32 operand; each value maps to its bitwise
// complement within 32 bits.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
  XlaBuilder builder(TestName());
  const XlaOp operand =
      ConstantR2<uint32>(&builder, {{0, 4294967295}, {1, 4294967294}});
  Not(operand);

  const Array2D<uint32> expected({{4294967295, 0}, {4294967294, 1}});
  ComputeAndCompareR2<uint32>(&builder, expected, /*arguments=*/{});
}
1059 
// Not of an empty uint32 vector is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
  XlaBuilder builder(TestName());
  const XlaOp operand = ConstantR1<uint32>(&builder, {});
  Not(operand);

  ComputeAndCompareR1<uint32>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1067 
// Left shift on int32 operands. The last three shift amounts (32, 100, -1)
// are outside [0, 31]; the test expects those lanes to produce 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
  XlaBuilder builder(TestName());
  const XlaOp values = ConstantR1<int32>(
      &builder, {static_cast<int32>(0x12345678), static_cast<int32>(0xF0001000),
                 1, 3, 77, 1, -3, 77});
  const XlaOp amounts =
      ConstantR1<int32>(&builder, {4, 8, 2, 7, 15, 32, 100, -1});
  ShiftLeft(values, amounts);

  ComputeAndCompareR1<int32>(&builder,
                             {static_cast<int32>(0x23456780), 0x00100000, 0x4,
                              0x180, 2523136, 0, 0, 0},
                             /*arguments=*/{});
}
1081 
// Arithmetic (sign-extending) right shift on int32 operands. For the
// out-of-range shift amounts (32, 100, -1) the test expects 0 for
// non-negative inputs and -1 for the negative input.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
  XlaBuilder builder(TestName());
  const XlaOp values = ConstantR1<int32>(
      &builder, {static_cast<int32>(0x92345678), static_cast<int32>(0x10001000),
                 1, 3, 77, 1, -3, 77});
  const XlaOp amounts =
      ConstantR1<int32>(&builder, {4, 8, 2, 7, 2, 32, 100, -1});
  ShiftRightArithmetic(values, amounts);

  ComputeAndCompareR1<int32>(
      &builder,
      {static_cast<int32>(0xF9234567), static_cast<int32>(0x00100010), 0, 0, 19,
       0, -1, 0},
      /*arguments=*/{});
}
1096 
// Logical (zero-filling) right shift on int32 operands. The out-of-range
// shift amounts (32, 100, -1) are expected to produce 0, even for the
// negative input.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
  XlaBuilder builder(TestName());
  const XlaOp values = ConstantR1<int32>(
      &builder, {static_cast<int32>(0x92345678), static_cast<int32>(0x10001000),
                 1, 3, 77, 1, -3, 77});
  const XlaOp amounts =
      ConstantR1<int32>(&builder, {4, 8, 2, 7, 5, 32, 100, -1});
  ShiftRightLogical(values, amounts);

  ComputeAndCompareR1<int32>(&builder,
                             {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0},
                             /*arguments=*/{});
}
1108 
// Left shift on uint32 operands. The last three shift amounts (32, 100, ~0u)
// are at least the bit width; the test expects those lanes to produce 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
  XlaBuilder builder(TestName());
  const XlaOp values = ConstantR1<uint32>(
      &builder, {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77});
  const XlaOp amounts =
      ConstantR1<uint32>(&builder, {4, 8, 2, 7, 15, 32, 100, ~0u});
  ShiftLeft(values, amounts);

  ComputeAndCompareR1<uint32>(
      &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0},
      /*arguments=*/{});
}
1119 
// Arithmetic right shift applied to uint32 operands. The expectations show
// the top bit being replicated (0x92345678 >> 4 == 0xF9234567, and ~3u
// shifted by an out-of-range amount yields ~0u).
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
  XlaBuilder builder(TestName());
  const XlaOp values = ConstantR1<uint32>(
      &builder, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  const XlaOp amounts =
      ConstantR1<uint32>(&builder, {4, 8, 2, 7, 2, 32, 100, ~0u});
  ShiftRightArithmetic(values, amounts);

  ComputeAndCompareR1<uint32>(
      &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0},
      /*arguments=*/{});
}
1130 
// Logical (zero-filling) right shift on uint32 operands. Shift amounts of at
// least the bit width (32, 100, ~0u) are expected to produce 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
  XlaBuilder builder(TestName());
  const XlaOp values = ConstantR1<uint32>(
      &builder, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  const XlaOp amounts =
      ConstantR1<uint32>(&builder, {4, 8, 2, 7, 5, 32, 100, ~0u});
  ShiftRightLogical(values, amounts);

  ComputeAndCompareR1<uint32>(&builder,
                              {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0},
                              /*arguments=*/{});
}
1141 
// Eq on float operands. Fast-math is disabled because the inputs contain
// NaN; any lane with a NaN operand is expected to compare false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const XlaOp right =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, 10.0f, NAN});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false},
                            /*arguments=*/{});
}
1151 
// Eq on two empty float vectors is expected to give an empty bool result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp left = ConstantR1<float>(&builder, {});
  const XlaOp right = ConstantR1<float>(&builder, {});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1160 
// Ge on float operands. Fast-math is disabled because the inputs contain
// NaN; any lane with a NaN operand is expected to compare false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const XlaOp right =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Ge(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false},
                            /*arguments=*/{});
}
1170 
// Gt on float operands. Fast-math is disabled because the inputs contain
// NaN; any lane with a NaN operand is expected to compare false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const XlaOp right =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Gt(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false},
                            /*arguments=*/{});
}
1180 
// Le on float operands, including an equal pair (5.0 <= 5.0). Fast-math is
// disabled because the inputs contain NaN; any lane with a NaN operand is
// expected to compare false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<float>(&builder, {-2.5f, 5.0f, 2.25f, NAN, 6.0f});
  const XlaOp right =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Le(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false},
                            /*arguments=*/{});
}
1190 
// Lt on float operands. Fast-math is disabled because the inputs contain
// NaN; any lane with a NaN operand is expected to compare false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const XlaOp right =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Lt(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false},
                            /*arguments=*/{});
}
1200 
// Eq on int32 operands, pairing each of {min, 0, max} on the left with a
// smaller, equal and larger value on the right where one exists.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
  const int32 kMin = std::numeric_limits<int32>::min();
  const int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left = ConstantR1<int32>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Eq(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, false, true, false, false, false, true},
      /*arguments=*/{});
}
1214 
// Eq on two empty int32 vectors is expected to give an empty bool result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
  XlaBuilder builder(TestName());
  const XlaOp left = ConstantR1<int32>(&builder, {});
  const XlaOp right = ConstantR1<int32>(&builder, {});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1223 
// Eq on complex64 operands. Only the lane whose real and imaginary parts both
// match is expected to be true; lanes with a NaN component compare false.
// Fast-math is disabled because the inputs contain NaN.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const XlaOp left = ConstantR1<complex64>(&builder, {{-2.5f, 10.0f},
                                                      {1.0f, 25.5f},
                                                      {2.25f, -3.0f},
                                                      {NAN, 0.0f},
                                                      {1.0f, 6.0f}});
  const XlaOp right = ConstantR1<complex64>(&builder, {{0.0f, 10.0f},
                                                       {1.0f, 5.0f},
                                                       {2.25f, -3.0f},
                                                       {10.0f, 0.0f},
                                                       {1.0f, NAN}});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false},
                            /*arguments=*/{});
}
1241 
// Eq on two empty complex64 vectors is expected to give an empty bool result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp left = ConstantR1<complex64>(&builder, {});
  const XlaOp right = ConstantR1<complex64>(&builder, {});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1250 
// Ne on complex64 operands. Only the lane whose real and imaginary parts both
// match is expected to be false; lanes with a NaN component compare true.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
  // The inputs contain NaNs, so fast-math must be turned off.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  const XlaOp left = ConstantR1<complex64>(&builder, {{-2.5f, 10.0f},
                                                      {1.0f, 25.5f},
                                                      {2.25f, -3.0f},
                                                      {NAN, 0.0f},
                                                      {1.0f, 6.0f}});
  const XlaOp right = ConstantR1<complex64>(&builder, {{0.0f, 10.0f},
                                                       {1.0f, 5.0f},
                                                       {2.25f, -3.0f},
                                                       {10.0f, 0.0f},
                                                       {1.0f, NAN}});
  Ne(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true},
                            /*arguments=*/{});
}
1270 
// Ne on float operands. Lanes with a NaN operand are expected to compare
// true; only the equal pair (25.5, 25.5) is expected to be false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
  // The inputs contain NaNs, so fast-math must be turned off.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  const XlaOp right =
      ConstantR1<float>(&builder, {10.0f, 25.5f, 1.0f, 10.0f, NAN});
  Ne(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true},
                            /*arguments=*/{});
}
1282 
// Ne on int32 operands, pairing each of {min, 0, max} on the left with a
// smaller, equal and larger value on the right where one exists.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
  const int32 kMin = std::numeric_limits<int32>::min();
  const int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left = ConstantR1<int32>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Ne(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true, true, true, false},
      /*arguments=*/{});
}
1295 
// Ge on int32 operands, exercising the extremes of the signed range.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
  const int32 kMin = std::numeric_limits<int32>::min();
  const int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left = ConstantR1<int32>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Ge(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, true, true, false, true, true, true},
      /*arguments=*/{});
}
1308 
// Gt on int32 operands, exercising the extremes of the signed range.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
  const int32 kMin = std::numeric_limits<int32>::min();
  const int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left = ConstantR1<int32>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Gt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, false, false, true, false, false, true, true, false},
      /*arguments=*/{});
}
1322 
// Le on int32 operands, exercising the extremes of the signed range.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
  const int32 kMin = std::numeric_limits<int32>::min();
  const int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left = ConstantR1<int32>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Le(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, true, true, false, true, true, false, false, true},
      /*arguments=*/{});
}
1335 
// Lt on int32 operands, exercising the extremes of the signed range.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
  const int32 kMin = std::numeric_limits<int32>::min();
  const int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left = ConstantR1<int32>(
      &builder, {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<int32>(&builder, {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
  Lt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, false, false, true, false, false, false},
      /*arguments=*/{});
}
1349 
// Eq on uint32 operands, pairing each of {0, 5, max} on the left with a
// range of values on the right including the unsigned maximum.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
  const uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<uint32>(&builder, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Eq(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, false, true, false, false, false, true},
      /*arguments=*/{});
}
1361 
// Ne on uint32 operands, pairing each of {0, 5, max} on the left with a
// range of values on the right including the unsigned maximum.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
  const uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<uint32>(&builder, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Ne(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true, true, true, false},
      /*arguments=*/{});
}
1372 
// Ge on uint32 operands; the expectations treat max as the largest value,
// i.e. the comparison is unsigned.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
  const uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<uint32>(&builder, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Ge(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, true, true, false, true, true, true},
      /*arguments=*/{});
}
1383 
// Gt on uint32 operands; the expectations treat max as the largest value,
// i.e. the comparison is unsigned.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
  const uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<uint32>(&builder, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Gt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, false, false, true, false, false, true, true, false},
      /*arguments=*/{});
}
1395 
// Le on uint32 operands; the expectations treat max as the largest value,
// i.e. the comparison is unsigned.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
  const uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<uint32>(&builder, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Le(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, true, true, false, true, true, false, false, true},
      /*arguments=*/{});
}
1406 
// Lt on uint32 operands; the expectations treat max as the largest value,
// i.e. the comparison is unsigned.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
  const uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  const XlaOp left =
      ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  const XlaOp right =
      ConstantR1<uint32>(&builder, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Lt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, false, false, true, false, false, false},
      /*arguments=*/{});
}
1418 
// Pow on float operands with integral-valued exponents, negative bases, and
// NaN in either position (expected to propagate). Fast-math is disabled
// because NaNs are involved.
XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const XlaOp base =
      ConstantR1<float>(&builder, {4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f});
  const XlaOp exponent =
      ConstantR1<float>(&builder, {2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f});
  Pow(base, exponent);

  ComputeAndCompareR1<float>(
      &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f},
      /*arguments=*/{}, error_spec_);
}
1431 
// Pow with non-integer exponents: a negative base is expected to give NaN,
// and zero raised to a negative exponent is expected to give +infinity.
XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const XlaOp base = ConstantR1<float>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f});
  const XlaOp exponent =
      ConstantR1<float>(&builder, {0.5f, 0.6f, -0.6f, -0.6f});
  Pow(base, exponent);

  ComputeAndCompareR1<float>(&builder, {NAN, NAN, NAN, INFINITY},
                             /*arguments=*/{}, error_spec_);
}
1442 
// Pow on complex64 operands built from real values. Unlike the float case,
// negative bases with fractional exponents yield finite complex results, and
// 0^0 is expected to be 1.
XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  const XlaOp base =
      ConstantR1<complex64>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f});
  const XlaOp exponent =
      ConstantR1<complex64>(&builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f});
  Pow(base, exponent);

  ComputeAndCompareR1<complex64>(&builder,
                                 {
                                     {0, 1.41421356},               // (-2)^0.5
                                     {-2.27443288e-01, 0.69999846},  // (-.6)^.6
                                     {-4.19847531e-01, -1.29215783},
                                     {0, 0},  // 0^0.5
                                     {0, 0},  // 0^0.6
                                     {1, 0},  // 0^0
                                 },
                                 /*arguments=*/{}, error_spec_);
}
1463 
// Pow on two empty float vectors is expected to give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp base = ConstantR1<float>(&builder, {});
  const XlaOp exponent = ConstantR1<float>(&builder, {});
  Pow(base, exponent);

  ComputeAndCompareR1<float>(&builder, /*expected=*/{}, /*arguments=*/{},
                             error_spec_);
}
1472 
1473 // Some Pow cases that can be implemented more efficiently.
XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
  // Builds sum_e pow(param, e) over a fixed set of scalar exponents and
  // compares it with the same sum computed on the host via std::pow.
  XlaBuilder b(TestName());

  std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
  std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal param_literal = LiteralUtil::CreateR1<float>(values);
  std::unique_ptr<GlobalData> param_data =
      client_->TransferToServer(param_literal).ConsumeValueOrDie();

  // Computation under test: accumulate one Pow per exponent into `sum`.
  auto sum = ConstantR0<float>(&b, 0.0f);
  auto param = Parameter(&b, 0, param_literal.shape(), "param");
  for (float exponent : exponents) {
    sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
  }

  // Host-side reference. The accumulator has its own name so that it does not
  // shadow the XlaOp `sum` above.
  std::vector<float> expected;
  expected.reserve(values.size());
  for (float value : values) {
    float expected_sum = 0.0f;
    for (float exponent : exponents) {
      expected_sum += std::pow(value, exponent);
    }
    expected.push_back(expected_sum);
  }

  ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
}
1501 
// Computes Pow(Exp(param0), param1) on two float parameters and compares it
// with std::pow(std::exp(x), y) evaluated element-wise on the host.
XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Pow(Exp(param0), param1);

  // Host-side reference. The index is size_t to match vector::size() and
  // avoid a signed/unsigned comparison.
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::pow(std::exp(values0[i]), values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1526 
// Checks Log(Pow(x, y)) on two F32 parameters against
// std::log(std::pow(x, y)) computed element by element on the host.
XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
  XlaBuilder b(TestName());

  const std::vector<float> bases = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
  const std::vector<float> powers = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal bases_literal = LiteralUtil::CreateR1<float>(bases);
  std::unique_ptr<GlobalData> bases_data =
      client_->TransferToServer(bases_literal).ConsumeValueOrDie();
  Literal powers_literal = LiteralUtil::CreateR1<float>(powers);
  std::unique_ptr<GlobalData> powers_data =
      client_->TransferToServer(powers_literal).ConsumeValueOrDie();

  auto x = Parameter(&b, 0, bases_literal.shape(), "param0");
  auto y = Parameter(&b, 1, powers_literal.shape(), "param1");
  Log(Pow(x, y));

  // Host-side reference values.
  const int64 count = bases.size();
  std::vector<float> expected;
  expected.reserve(count);
  for (int64 i = 0; i < count; ++i) {
    expected.push_back(std::log(std::pow(bases[i], powers[i])));
  }

  ComputeAndCompareR1<float>(
      &b, expected, {bases_data.get(), powers_data.get()}, error_spec_);
}
1551 
// Checks Mul(Exp(x), Exp(y)) on two F32 parameters against
// std::exp(x) * std::exp(y) computed element by element on the host.
XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
  XlaBuilder b(TestName());

  const std::vector<float> lhs_values = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  const std::vector<float> rhs_values = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal lhs_literal = LiteralUtil::CreateR1<float>(lhs_values);
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
  Literal rhs_literal = LiteralUtil::CreateR1<float>(rhs_values);
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  auto x = Parameter(&b, 0, lhs_literal.shape(), "param0");
  auto y = Parameter(&b, 1, rhs_literal.shape(), "param1");
  Mul(Exp(x), Exp(y));

  // Host-side reference values.
  const int64 count = lhs_values.size();
  std::vector<float> expected;
  expected.reserve(count);
  for (int64 i = 0; i < count; ++i) {
    expected.push_back(std::exp(lhs_values[i]) * std::exp(rhs_values[i]));
  }

  ComputeAndCompareR1<float>(&b, expected, {lhs_data.get(), rhs_data.get()},
                             error_spec_);
}
1576 
// Checks Div(x, Exp(y)) on two F32 parameters against x / std::exp(y)
// computed element by element on the host.
XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
  XlaBuilder b(TestName());

  const std::vector<float> numerators = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  const std::vector<float> exp_inputs = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal numerators_literal = LiteralUtil::CreateR1<float>(numerators);
  std::unique_ptr<GlobalData> numerators_data =
      client_->TransferToServer(numerators_literal).ConsumeValueOrDie();
  Literal exp_inputs_literal = LiteralUtil::CreateR1<float>(exp_inputs);
  std::unique_ptr<GlobalData> exp_inputs_data =
      client_->TransferToServer(exp_inputs_literal).ConsumeValueOrDie();

  auto x = Parameter(&b, 0, numerators_literal.shape(), "param0");
  auto y = Parameter(&b, 1, exp_inputs_literal.shape(), "param1");
  Div(x, Exp(y));

  // Host-side reference values.
  const int64 count = numerators.size();
  std::vector<float> expected;
  expected.reserve(count);
  for (int64 i = 0; i < count; ++i) {
    expected.push_back(numerators[i] / std::exp(exp_inputs[i]));
  }

  ComputeAndCompareR1<float>(
      &b, expected, {numerators_data.get(), exp_inputs_data.get()},
      error_spec_);
}
1601 
// Checks the left-nested division (x / y) / z on three F32 parameters against
// the same expression computed element by element on the host.
XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
  XlaBuilder b(TestName());

  // Transfers a literal to the server and returns the resulting handle.
  auto upload = [this](const Literal& literal) {
    return client_->TransferToServer(literal).ConsumeValueOrDie();
  };

  const std::vector<float> x_values = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> y_values = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  const std::vector<float> z_values = {0.1f,  1.1f,   6.9f,
                                       12.5f, -15.0f, -0.5f};

  Literal x_literal = LiteralUtil::CreateR1<float>(x_values);
  std::unique_ptr<GlobalData> x_data = upload(x_literal);
  Literal y_literal = LiteralUtil::CreateR1<float>(y_values);
  std::unique_ptr<GlobalData> y_data = upload(y_literal);
  Literal z_literal = LiteralUtil::CreateR1<float>(z_values);
  std::unique_ptr<GlobalData> z_data = upload(z_literal);

  auto x = Parameter(&b, 0, x_literal.shape(), "param0");
  auto y = Parameter(&b, 1, y_literal.shape(), "param1");
  auto z = Parameter(&b, 2, z_literal.shape(), "param2");
  Div(Div(x, y), z);

  // Host-side reference values.
  const int64 count = x_values.size();
  std::vector<float> expected;
  expected.reserve(count);
  for (int64 i = 0; i < count; ++i) {
    expected.push_back((x_values[i] / y_values[i]) / z_values[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {x_data.get(), y_data.get(), z_data.get()}, error_spec_);
}
1633 
// Checks the right-nested division x / (y / z) on three F32 parameters
// against the same expression computed element by element on the host.
XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
  XlaBuilder b(TestName());

  // Transfers a literal to the server and returns the resulting handle.
  auto upload = [this](const Literal& literal) {
    return client_->TransferToServer(literal).ConsumeValueOrDie();
  };

  const std::vector<float> x_values = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> y_values = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  const std::vector<float> z_values = {0.1f,  1.1f,   6.9f,
                                       12.5f, -15.0f, -0.5f};

  Literal x_literal = LiteralUtil::CreateR1<float>(x_values);
  std::unique_ptr<GlobalData> x_data = upload(x_literal);
  Literal y_literal = LiteralUtil::CreateR1<float>(y_values);
  std::unique_ptr<GlobalData> y_data = upload(y_literal);
  Literal z_literal = LiteralUtil::CreateR1<float>(z_values);
  std::unique_ptr<GlobalData> z_data = upload(z_literal);

  auto x = Parameter(&b, 0, x_literal.shape(), "param0");
  auto y = Parameter(&b, 1, y_literal.shape(), "param1");
  auto z = Parameter(&b, 2, z_literal.shape(), "param2");
  Div(x, Div(y, z));

  // Host-side reference values.
  const int64 count = x_values.size();
  std::vector<float> expected;
  expected.reserve(count);
  for (int64 i = 0; i < count; ++i) {
    expected.push_back(x_values[i] / (y_values[i] / z_values[i]));
  }

  ComputeAndCompareR1<float>(
      &b, expected, {x_data.get(), y_data.get(), z_data.get()}, error_spec_);
}
1666 
// Checks Div(x, Pow(y, z)) on three F32 parameters against
// x / std::pow(y, z) computed element by element on the host.
XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
  XlaBuilder b(TestName());

  // Transfers a literal to the server and returns the resulting handle.
  auto upload = [this](const Literal& literal) {
    return client_->TransferToServer(literal).ConsumeValueOrDie();
  };

  const std::vector<float> numerators = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> bases = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
  const std::vector<float> powers = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};

  Literal numerators_literal = LiteralUtil::CreateR1<float>(numerators);
  std::unique_ptr<GlobalData> numerators_data = upload(numerators_literal);
  Literal bases_literal = LiteralUtil::CreateR1<float>(bases);
  std::unique_ptr<GlobalData> bases_data = upload(bases_literal);
  Literal powers_literal = LiteralUtil::CreateR1<float>(powers);
  std::unique_ptr<GlobalData> powers_data = upload(powers_literal);

  auto x = Parameter(&b, 0, numerators_literal.shape(), "param0");
  auto y = Parameter(&b, 1, bases_literal.shape(), "param1");
  auto z = Parameter(&b, 2, powers_literal.shape(), "param2");
  Div(x, Pow(y, z));

  // Host-side reference values.
  const int64 count = numerators.size();
  std::vector<float> expected;
  expected.reserve(count);
  for (int64 i = 0; i < count; ++i) {
    expected.push_back(numerators[i] / std::pow(bases[i], powers[i]));
  }

  ComputeAndCompareR1<float>(
      &b, expected,
      {numerators_data.get(), bases_data.get(), powers_data.get()},
      error_spec_);
}
1699 
// Checks (p0 / p1) / (p2 / p3) on four F32 parameters against the same
// expression computed element by element on the host.
XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
  std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).ConsumeValueOrDie();

  Literal literal3 = LiteralUtil::CreateR1<float>(values3);
  std::unique_ptr<GlobalData> data3 =
      client_->TransferToServer(literal3).ConsumeValueOrDie();

  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  // Parameter 3 gets its own name; it previously reused "param2".
  auto param3 = Parameter(&b, 3, literal3.shape(), "param3");
  Div(Div(param0, param1), Div(param2, param3));

  // Host-side reference values. The size is hoisted into an int64 so the loop
  // condition does not mix signed and unsigned operands.
  const int64 count = values0.size();
  std::vector<float> expected(count);
  for (int64 i = 0; i < count; ++i) {
    expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()},
      error_spec_);
}
1739 
// Squares GetParam() values of the form i / count via Pow(x, 2) and compares
// against value * value computed on the host.
TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());

  // inputs[i] = i / count and expected[i] = inputs[i] * inputs[i].
  std::vector<float> inputs;
  inputs.reserve(count);
  std::vector<float> expected;
  expected.reserve(count);
  for (int i = 0; i < count; ++i) {
    const float input = i / static_cast<float>(count);
    inputs.push_back(input);
    expected.push_back(input * input);
  }

  auto x = ConstantR1<float>(&builder, inputs);
  Pow(x, ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1759 
// Squares a 2x2x2x2 F32 constant via Pow(x, 2) and compares against the
// element-wise products computed on the host.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
  XlaBuilder builder(TestName());
  Array4D<float> inputs(2, 2, 2, 2);
  const int64 element_count = inputs.num_elements();

  // Element i holds i / element_count; its expected square is kept alongside.
  std::vector<float> input_values;
  std::vector<float> squared_values;
  for (int i = 0; i < element_count; ++i) {
    const float input = static_cast<float>(i) / element_count;
    input_values.push_back(input);
    squared_values.push_back(input * input);
  }
  inputs.SetValues(input_values);

  Array4D<float> expected(2, 2, 2, 2, squared_values);

  auto x = ConstantR4FromArray4D<float>(&builder, inputs);
  Pow(x, ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
1779 
// Squaring a 2x2x0x2 F32 constant is expected to yield a 2x2x0x2 result.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
  XlaBuilder builder(TestName());
  const Array4D<float> empty_input(2, 2, 0, 2);
  const Array4D<float> empty_expected(2, 2, 0, 2);

  auto x = ConstantR4FromArray4D<float>(&builder, empty_input);
  auto two = ConstantR0<float>(&builder, 2.0f);
  Pow(x, two);

  ComputeAndCompareR4<float>(&builder, empty_expected, {}, error_spec_);
}
1790 
// Element-wise Min on F32 vectors with fast-math disabled; the expected
// result is NaN wherever either operand is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<float> lhs_values = {1.0f, 1.0f, 2.25f, NAN, 6.0f};
  const std::vector<float> rhs_values = {2.0f, -5.0f, 1.0f, 10.0f, NAN};
  const std::vector<float> expected = {1.0f, -5.0f, 1.0f, NAN, NAN};

  auto lhs = ConstantR1<float>(&builder, lhs_values);
  auto rhs = ConstantR1<float>(&builder, rhs_values);
  Min(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1801 
// Min applied to two empty F32 vectors is expected to yield an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<float>(&builder, {});
  auto empty_rhs = ConstantR1<float>(&builder, {});
  Min(empty_lhs, empty_rhs);

  const std::vector<float> expected;
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1809 
// Element-wise Min on F64 vectors with fast-math disabled; the expected
// result is NaN wherever either operand is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<double> lhs_values = {1.0, 1.0, 2.25, NAN, 6.0};
  const std::vector<double> rhs_values = {2.0, -5.0, 1.0, 10.0, NAN};
  const std::vector<double> expected = {1.0, -5.0, 1.0, NAN, NAN};

  auto lhs = ConstantR1<double>(&builder, lhs_values);
  auto rhs = ConstantR1<double>(&builder, rhs_values);
  Min(lhs, rhs);

  ComputeAndCompareR1<double>(&builder, expected, {}, error_spec_);
}
1820 
// Element-wise Max on F32 vectors with fast-math disabled; the expected
// result is NaN wherever either operand is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<float> lhs_values = {1.0f, 1.0f, 2.25f, NAN, 6.0f};
  const std::vector<float> rhs_values = {2.0f, -5.0f, 1.0f, 10.0f, NAN};
  const std::vector<float> expected = {2.0f, 1.0f, 2.25f, NAN, NAN};

  auto lhs = ConstantR1<float>(&builder, lhs_values);
  auto rhs = ConstantR1<float>(&builder, rhs_values);
  Max(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1831 
// Max applied to two empty F32 vectors is expected to yield an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<float>(&builder, {});
  auto empty_rhs = ConstantR1<float>(&builder, {});
  Max(empty_lhs, empty_rhs);

  const std::vector<float> expected;
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1839 
// Element-wise Max on F64 vectors with fast-math disabled; the expected
// result is NaN wherever either operand is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<double> lhs_values = {1.0, 1.0, 2.25, NAN, 6.0};
  const std::vector<double> rhs_values = {2.0, -5.0, 1.0, 10.0, NAN};
  const std::vector<double> expected = {2.0, 1.0, 2.25, NAN, NAN};

  auto lhs = ConstantR1<double>(&builder, lhs_values);
  auto rhs = ConstantR1<double>(&builder, rhs_values);
  Max(lhs, rhs);

  ComputeAndCompareR1<double>(&builder, expected, {}, error_spec_);
}
1850 
// Element-wise Max on S32 vectors, including the extreme int32 values.
XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
  const int32 lo = std::numeric_limits<int32>::min();
  const int32 hi = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());

  const std::vector<int32> x_values = {lo, lo, lo, -1, -1, 0, 0,
                                       0,  1,  1,  hi, hi, hi};
  const std::vector<int32> y_values = {lo, hi, 0, -10, 0, -1, 0,
                                       1,  0,  10, 0,  hi, lo};
  const std::vector<int32> expected = {lo, hi, 0,  -1, 0,  0, 0,
                                       1,  1,  10, hi, hi, hi};

  auto x = ConstantR1<int32>(&builder, x_values);
  auto y = ConstantR1<int32>(&builder, y_values);
  Max(x, y);

  ComputeAndCompareR1<int32>(&builder, expected, {});
}
1865 
// Element-wise Min on S32 vectors, including the extreme int32 values.
XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
  const int32 lo = std::numeric_limits<int32>::min();
  const int32 hi = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());

  const std::vector<int32> x_values = {lo, lo, lo, -1, -1, 0, 0,
                                       0,  1,  1,  hi, hi, hi};
  const std::vector<int32> y_values = {lo, hi, 0, -10, 0, -1, 0,
                                       1,  0,  10, 0,  hi, lo};
  const std::vector<int32> expected = {lo, lo, lo, -10, -1, -1, 0,
                                       0,  0,  1,  0,   hi, lo};

  auto x = ConstantR1<int32>(&builder, x_values);
  auto y = ConstantR1<int32>(&builder, y_values);
  Min(x, y);

  ComputeAndCompareR1<int32>(&builder, expected, {});
}
1880 
// Element-wise Max on U32 vectors, including the largest uint32 value.
XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
  const uint32 hi = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());

  const std::vector<uint32> x_values = {0, 0, 1, 1, 1, hi, hi, hi};
  const std::vector<uint32> y_values = {0, 1, 0, 1, 10, 0, 234234, hi};
  const std::vector<uint32> expected = {0, 1, 1, 1, 10, hi, hi, hi};

  auto x = ConstantR1<uint32>(&builder, x_values);
  auto y = ConstantR1<uint32>(&builder, y_values);
  Max(x, y);

  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
1891 
// Element-wise Min on U32 vectors, including the largest uint32 value.
XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
  const uint32 hi = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());

  const std::vector<uint32> x_values = {0, 0, 1, 1, 1, hi, hi, hi};
  const std::vector<uint32> y_values = {0, 1, 0, 1, 10, 0, 234234, hi};
  const std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, hi};

  auto x = ConstantR1<uint32>(&builder, x_values);
  auto y = ConstantR1<uint32>(&builder, y_values);
  Min(x, y);

  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
1902 
// Element-wise Max on two ten-element F32 vectors whose entries are negations
// of each other.
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
  XlaBuilder builder(TestName());

  const std::vector<float> x_values = {-0.0f, 1.0f,  2.0f,  -3.0f, -4.0f,
                                       5.0f,  6.0f,  -7.0f, -8.0f, 9.0f};
  const std::vector<float> y_values = {-0.0f, -1.0f, -2.0f, 3.0f, 4.0f,
                                       -5.0f, -6.0f, 7.0f,  8.0f, -9.0f};
  const std::vector<float> expected = {-0.0f, 1.0f, 2.0f, 3.0f, 4.0f,
                                       5.0f,  6.0f, 7.0f, 8.0f, 9.0f};

  auto x = ConstantR1<float>(&builder, x_values);
  auto y = ConstantR1<float>(&builder, y_values);
  Max(x, y);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1915 
// Max of a one-element F32 vector and an empty F32 vector; the expected
// result is an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
  XlaBuilder builder(TestName());
  auto single = ConstantR1<float>(&builder, {3.5});
  auto empty = ConstantR1<float>(&builder, {});
  Max(single, empty);

  const std::vector<float> expected;
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1924 
// Max of a one-element F32 vector and a 0x2 F32 matrix, once for each
// broadcast dimension; the expected result is a 0x2 matrix in both cases.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
  for (int broadcast_dim : {0, 1}) {
    XlaBuilder builder(TestName());
    const Array2D<float> empty_matrix(0, 2);
    const Array2D<float> expected(0, 2);

    auto u = ConstantR1<float>(&builder, {3.5});
    auto v = ConstantR2FromArray2D<float>(&builder, empty_matrix);
    Max(u, v, /*broadcast_dimensions=*/{broadcast_dim});

    ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
  }
}
1935 
// Max of a 3-element F32 vector and a 2x3 F32 matrix, with the vector mapped
// to dimension 1 of the matrix.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
  XlaBuilder builder(TestName());
  auto row = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  auto matrix = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  Max(row, matrix, /*broadcast_dimensions=*/{1});

  // Each matrix row is compared element-wise against {2, 3, 4}.
  const Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1946 
// Max of an empty F32 vector and a matrix of two empty rows, with the vector
// mapped to dimension 1; the expected result is two empty rows.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto empty_row = ConstantR1<float>(&builder, {});
  auto empty_matrix = ConstantR2<float>(&builder, {{}, {}});
  Max(empty_row, empty_matrix, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1956 
// Max of a 2x2x3 S32 array and the scalar 2: every element below 2 is
// expected to become 2.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
  XlaBuilder builder(TestName());
  const Array3D<int32> input(
      {{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
  const Array3D<int32> expected(
      {{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});

  auto floor_value = ConstantR0<int32>(&builder, 2);
  auto array = ConstantR3FromArray3D<int32>(&builder, input);
  Max(array, floor_value, /*broadcast_dimensions=*/{});

  ComputeAndCompareR3<int32>(&builder, expected, {});
}
1967 
// Max of a 2x0x3 S32 array and a scalar; the expected result is 2x0x3.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
  XlaBuilder builder(TestName());
  const Array3D<int32> empty_input(2, 0, 3);
  const Array3D<int32> expected(2, 0, 3);

  auto floor_value = ConstantR0<int32>(&builder, 2);
  auto array = ConstantR3FromArray3D<int32>(&builder, empty_input);
  Max(array, floor_value, /*broadcast_dimensions=*/{});

  ComputeAndCompareR3<int32>(&builder, expected, {});
}
1978 
// Min of a 2x3 F32 matrix and a 2-element F32 vector, with the vector mapped
// to dimension 0 of the matrix.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
  XlaBuilder builder(TestName());
  auto matrix = ConstantR2<float>(
      &builder, {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
  auto column = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(matrix, column, /*broadcast_dimensions=*/{0});

  // Row 0 is compared against -10.2 and row 1 against 16.4.
  const Array2D<float> expected(
      {{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1989 
// Min of a matrix with two empty rows and a 2-element F32 vector mapped to
// dimension 0; the expected result is two empty rows.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto empty_matrix = ConstantR2<float>(&builder, {{}, {}});
  auto column = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(empty_matrix, column, /*broadcast_dimensions=*/{0});

  const Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
1999 
// Min of a 2x3 F32 matrix and a 2x2x1x3 F32 array, with the matrix mapped to
// dimensions 1 and 3 of the 4D array.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
  XlaBuilder builder(TestName());
  const Array4D<float> tensor_values(
      {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
       {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
  const Array4D<float> expected(
      {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
       {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});

  auto matrix =
      ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  auto tensor = ConstantR4FromArray4D<float>(&builder, tensor_values);
  Min(matrix, tensor, /*broadcast_dimensions=*/{1, 3});

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
2014 
// Min of a 2x3 F32 matrix and a 2x2x0x3 F32 array, with the matrix mapped to
// dimensions 1 and 3; the expected result is 2x2x0x3.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
  XlaBuilder builder(TestName());
  const Array4D<float> empty_tensor(2, 2, 0, 3);
  const Array4D<float> expected(2, 2, 0, 3);

  auto matrix =
      ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  auto tensor = ConstantR4FromArray4D<float>(&builder, empty_tensor);
  Min(matrix, tensor, /*broadcast_dimensions=*/{1, 3});

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
2026 
// Element-wise Min of an ascending and a descending ten-element S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
  XlaBuilder builder(TestName());

  const std::vector<int32> ascending = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const std::vector<int32> descending = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  const std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};

  auto x = ConstantR1<int32>(&builder, ascending);
  auto y = ConstantR1<int32>(&builder, descending);
  Min(x, y);

  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2036 
// Element-wise Max of an ascending and a descending ten-element S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
  XlaBuilder builder(TestName());

  const std::vector<int32> ascending = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const std::vector<int32> descending = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  const std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};

  auto x = ConstantR1<int32>(&builder, ascending);
  auto y = ConstantR1<int32>(&builder, descending);
  Max(x, y);

  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2046 
// Element-wise Rem on S32 constants; in the expected values the result takes
// the sign of the dividend.
XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
  XlaBuilder builder(TestName());
  auto dividend = ConstantR1<int32>(&builder, {-3, 26, 2, -1, 1});
  auto divisor = ConstantR1<int32>(&builder, {10, 5, 1, 10, -10});
  Rem(dividend, divisor);

  const std::vector<int32> expected = {-3, 1, 0, -1, 1};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2055 
// Element-wise Clamp(min, x, max) on F32 vectors that contain no NaNs.
XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
  XlaBuilder builder(TestName());

  const std::vector<float> lower = {1.0f, -6.5f, 1.0f, 2.25f, 0.0f};
  const std::vector<float> input = {2.0f, 10.0f, -5.0f, 1.0f, 10.0f};
  const std::vector<float> upper = {3.0f, 0.5f, 25.5f, 5.0f, 123.0f};
  const std::vector<float> expected = {2.0f, 0.5f, 1.0f, 2.25f, 10.0f};

  auto minimum = ConstantR1<float>(&builder, lower);
  auto argument = ConstantR1<float>(&builder, input);
  auto maximum = ConstantR1<float>(&builder, upper);
  Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2067 
// Element-wise Clamp on F32 vectors with fast-math disabled; the expected
// result is NaN where either bound is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());

  const std::vector<float> lower = {1.0f, -6.5f, 1.0f, 2.25f, NAN};
  const std::vector<float> input = {2.0f, 10.0f, -5.0f, 1.0f, 10.0f};
  const std::vector<float> upper = {3.0f, 0.5f, 25.5f, NAN, 123.0f};
  const std::vector<float> expected = {2.0f, 0.5f, 1.0f, NAN, NAN};

  auto minimum = ConstantR1<float>(&builder, lower);
  auto argument = ConstantR1<float>(&builder, input);
  auto maximum = ConstantR1<float>(&builder, upper);
  Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2080 
// Clamps an F32 vector between the scalar bounds 0 and 5.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
  XlaBuilder builder(TestName());

  const std::vector<float> input = {2.0f, 10.0f, -5.0f, 1.0f, 4.0f};
  const std::vector<float> expected = {2.0f, 5.0f, 0.0f, 1.0f, 4.0f};

  auto minimum = ConstantR0<float>(&builder, 0.0f);
  auto argument = ConstantR1<float>(&builder, input);
  auto maximum = ConstantR0<float>(&builder, 5.0f);
  Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2091 
// Sums four F32 clamps that cover every scalar/vector combination of the
// lower and upper bounds.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<float>(&builder, 0.0f);
  auto min_vector =
      ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto arg_vector =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto max_scalar = ConstantR0<float>(&builder, 3.0f);
  auto max_vector =
      ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0f});

  // Perform clamp with broadcasted scalar and vector.
  auto vector_min_scalar_max = Clamp(min_vector, arg_vector, max_scalar);
  auto scalar_min_vector_max = Clamp(min_scalar, arg_vector, max_vector);
  auto vector_min_vector_max = Clamp(min_vector, arg_vector, max_vector);
  auto scalar_min_scalar_max = Clamp(min_scalar, arg_vector, max_scalar);
  auto mixed_sum = Add(vector_min_scalar_max, scalar_min_vector_max);
  auto uniform_sum = Add(vector_min_vector_max, scalar_min_scalar_max);
  Add(mixed_sum, uniform_sum);

  const std::vector<float> expected = {8.0f, 7.0f, 2.0f, 6.5f, 14.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2111 
// Element-wise Clamp(min, x, max) on S32 vectors.
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
  XlaBuilder builder(TestName());

  const std::vector<int32> lower = {1, -6, 1, 2, 0, -5};
  const std::vector<int32> input = {2, 10, -5, 1, 4, 10};
  const std::vector<int32> upper = {3, 0, 25, 5, 123, -1};
  const std::vector<int32> expected = {2, 0, 1, 2, 4, -1};

  auto min_vector = ConstantR1<int32>(&builder, lower);
  auto arg_vector = ConstantR1<int32>(&builder, input);
  auto max_vector = ConstantR1<int32>(&builder, upper);
  Clamp(min_vector, arg_vector, max_vector);

  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2121 
// Sums four S32 clamps that cover every scalar/vector combination of the
// lower and upper bounds.
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) {
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<int32>(&builder, 0);
  auto min_vector = ConstantR1<int32>(&builder, {1, -6, 1, 2, 0});
  auto arg_vector = ConstantR1<int32>(&builder, {2, 10, -5, 1, 4});
  auto max_scalar = ConstantR0<int32>(&builder, 3);
  auto max_vector = ConstantR1<int32>(&builder, {3, 1, 25, 5, 123});

  // Perform clamp with broadcasted scalar and vector.
  auto vector_min_scalar_max = Clamp(min_vector, arg_vector, max_scalar);
  auto scalar_min_vector_max = Clamp(min_scalar, arg_vector, max_vector);
  auto vector_min_vector_max = Clamp(min_vector, arg_vector, max_vector);
  auto scalar_min_scalar_max = Clamp(min_scalar, arg_vector, max_scalar);
  auto mixed_sum = Add(vector_min_scalar_max, scalar_min_vector_max);
  auto uniform_sum = Add(vector_min_vector_max, scalar_min_scalar_max);
  Add(mixed_sum, uniform_sum);

  const std::vector<int32> expected = {8, 8, 2, 6, 14};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2137 
// Element-wise Clamp(min, x, max) on U32 vectors, including bounds close to
// the largest unsigned value.
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
  XlaBuilder builder(TestName());

  const std::vector<uint32> lower = {1, 2, 1, 2, 0, ~0u - 4};
  const std::vector<uint32> input = {2, 10, 5, 1, 4, 10};
  const std::vector<uint32> upper = {3, 5, 25, 5, 123, ~0u};
  const std::vector<uint32> expected = {2, 5, 5, 2, 4, ~0u - 4};

  auto min_vector = ConstantR1<uint32>(&builder, lower);
  auto arg_vector = ConstantR1<uint32>(&builder, input);
  auto max_vector = ConstantR1<uint32>(&builder, upper);
  Clamp(min_vector, arg_vector, max_vector);

  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
2147 
// Sums four U32 clamps that cover every scalar/vector combination of the
// lower and upper bounds.
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<uint32>(&builder, 0);
  auto min_vector = ConstantR1<uint32>(&builder, {1, 0, 1, 2, 0});
  auto arg_vector = ConstantR1<uint32>(&builder, {2, 10, 0, 1, 4});
  auto max_scalar = ConstantR0<uint32>(&builder, 3);
  auto max_vector = ConstantR1<uint32>(&builder, {3, 1, 25, 5, 123});

  // Perform clamp with broadcasted scalar and vector.
  auto vector_min_scalar_max = Clamp(min_vector, arg_vector, max_scalar);
  auto scalar_min_vector_max = Clamp(min_scalar, arg_vector, max_vector);
  auto vector_min_vector_max = Clamp(min_vector, arg_vector, max_vector);
  auto scalar_min_scalar_max = Clamp(min_scalar, arg_vector, max_scalar);
  auto mixed_sum = Add(vector_min_scalar_max, scalar_min_vector_max);
  auto uniform_sum = Add(vector_min_vector_max, scalar_min_scalar_max);
  Add(mixed_sum, uniform_sum);

  const std::vector<uint32> expected = {8, 8, 2, 6, 14};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
2163 
// Adds two F32 vectors that are both supplied as computation parameters.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
  XlaBuilder builder(TestName());

  Literal lhs_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  Literal rhs_literal = LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  const std::vector<float> expected = {8.3f, 4.5f, 6.7f, 11.1f};
  ComputeAndCompareR1<float>(&builder, expected,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
2185 
// Adds two 0x7x0 F32 parameters; the expected result is a 0x7x0 array.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
  XlaBuilder builder(TestName());
  const Array3D<float> empty_array(0, 7, 0);

  Literal lhs_literal = LiteralUtil::CreateR3FromArray3D<float>(empty_array);
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  Literal rhs_literal = LiteralUtil::CreateR3FromArray3D<float>(empty_array);
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  ComputeAndCompareR3<float>(&builder, empty_array,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
2207 
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
  // Adds an in-graph rank-1 constant (left operand) to a rank-1 parameter
  // (right operand) whose data is uploaded to the server beforehand.
  XlaBuilder builder(TestName());

  Literal param_literal =
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> param_data =
      client_->TransferToServer(param_literal).ConsumeValueOrDie();

  auto constant = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto param = Parameter(&builder, 0, param_literal.shape(), "param0");
  Add(constant, param);

  const std::vector<float> expected_sum = {2.2f, 4.4f, 6.6f, 9.9f};
  ComputeAndCompareR1<float>(&builder, expected_sum, {param_data.get()},
                             error_spec_);
}
2223 
XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
  // Cos over a constant vector whose entries approximate pi, 0, pi/2 and
  // -pi/4; results are compared within error_spec_.
  XlaBuilder builder(TestName());
  auto angles =
      ConstantR1<float>(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f});
  Cos(angles);

  const std::vector<float> expected_cos = {-1.0f, 1.0f, 0.0f, 0.707107f};
  ComputeAndCompareR1<float>(&builder, expected_cos, {}, error_spec_);
}
2232 
XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
  // Sin over a constant vector whose entries approximate pi, 0, pi/2 and
  // -pi/4; results are compared within error_spec_.
  XlaBuilder builder(TestName());
  auto angles =
      ConstantR1<float>(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f});
  Sin(angles);

  const std::vector<float> expected_sin = {0.0f, 0.0f, 1.0f, -0.707107f};
  ComputeAndCompareR1<float>(&builder, expected_sin, {}, error_spec_);
}
2241 
XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
  // Atan2(y, x) over pairs that lie on the coordinate axes (results 0, pi/2,
  // pi, -pi/2) and on the diagonals (results pi/4, -pi/4).
  XlaBuilder builder(TestName());
  auto y = ConstantR1<float>(&builder, {0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f});
  auto x = ConstantR1<float>(&builder, {6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f});
  Atan2(y, x);

  const std::vector<float> expected_angles = {
      0.0f,         1.57079633f, 3.14159265f,
      -1.57079633f, 0.78539816f, -0.78539816f};
  ComputeAndCompareR1<float>(&builder, expected_angles, {}, error_spec_);
}
2253 
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
  // Tanh over a short constant vector; results are compared within
  // error_spec_.
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f});
  Tanh(a);

  // Every expected value is spelled as a float literal so that no element of
  // the list relies on an implicit double -> float conversion.
  ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026f}, {},
                             error_spec_);
}
2262 
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
  // Tanh over a 64-element parameter. Unlike the short-vector Tanh test, the
  // input is large enough to exercise the vectorized tanh implementation on
  // XLA CPU.
  XlaBuilder builder(TestName());

  Literal tanh_input = LiteralUtil::CreateR1<float>(
      {1.02,  -0.32, 0.85,  0.90,  1.23,  -0.91, -0.49, 0.80,  -0.67, 0.16,
       -0.07, 0.39,  -0.41, 0.04,  1.36,  1.25,  0.41,  0.65,  -1.08, 0.32,
       -1.45, -0.77, -1.09, 0.91,  -1.03, -0.30, -1.11, -1.17, 1.50,  -0.85,
       0.04,  1.02,  0.34,  -0.61, 0.41,  0.07,  -0.02, 1.42,  -0.62, 0.81,
       0.08,  0.81,  -0.30, 1.17,  -0.65, -0.44, 0.92,  1.26,  -1.29, 1.35,
       0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
       -0.79, 1.41,  1.21,  1.05});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> tanh_input_data,
                          client_->TransferToServer(tanh_input));

  auto operand = Parameter(&builder, 0, tanh_input.shape(), "input");
  Tanh(operand);

  const std::vector<float> expected_tanh = {
      0.77009583,  -0.30665702, 0.69070244,  0.71401149,  0.84400684,
      -0.71985596, -0.45764771, 0.66664988,  -0.58278900, 0.16050975,
      -0.06770509, 0.36843640,  -0.38476998, 0.04018109,  0.87562293,
      0.84788644,  0.38603750,  0.57294142,  -0.79140943, 0.31032649,
      -0.89590985, -0.64770776, -0.79625875, 0.72234446,  -0.77389336,
      -0.28871772, -0.80428445, -0.82541436, 0.90456349,  -0.68856895,
      0.03877772,  0.76877952,  0.32561871,  -0.54546672, 0.39072621,
      0.07273290,  -0.01924866, 0.88924897,  -0.55283129, 0.67183107,
      0.08006320,  0.66944766,  -0.29068485, 0.82573754,  -0.57170743,
      -0.41581789, 0.72739530,  0.85025692,  -0.85931867, 0.87357593,
      0.07782833,  -0.84597743, -0.72748238, 0.45396307,  0.82449573,
      -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558,
      -0.65565848, 0.88789743,  0.83566397,  0.78287679};

  // The tolerance is much looser than error_spec_ because tanh is
  // approximated with a rational interpolant.
  const ErrorSpec tanh_error_spec(0.004, 0.004);
  ComputeAndCompareR1<float>(&builder, expected_tanh, {tanh_input_data.get()},
                             tanh_error_spec);
}
2302 
XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
  // Exp over a parameter that is large enough to exercise the vectorized exp
  // implementation on XLA CPU. Reference values come from std::exp.
  XlaBuilder builder(TestName());

  // For orientation on the magnitudes below: exp(89) saturates float32, and
  // exp(-10) is already below the error spec.
  Literal exp_input = LiteralUtil::CreateR1<float>(
      {1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49, 0.8,    -1.31,
       -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,  -195.6, -194.5,
       -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5,  -17.4,
       -16.3,  -15.2,  -14.1,  -13.0,  -11.9,  -10.8,  -9.7,  -8.6,   -7.5,
       -6.4,   -5.3,   -4.2,   -3.1,   -2.0,   -0.9,   0.2,   1.3,    2.4,
       3.5,    4.6,    5.7,    6.8,    7.9,    9.0,    10.1,  11.2,   12.3,
       13.4,   14.5,   15.6,   16.7,   17.8,   18.9,   20.0,  21.1,   22.2,
       23.3,   24.4,   25.5,   26.6,   27.7,   28.8,   29.9,  31.0,   32.1,
       68.4,   69.5,   70.6,   71.7,   72.8,   73.9,   75.0,  76.1,   77.2,
       78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,  85.2,   86.3,
       86.4,   86.5,   87.6,   87.7,   87.8,   87.9});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> exp_input_data,
                          client_->TransferToServer(exp_input));

  auto operand = Parameter(&builder, 0, exp_input.shape(), "input");
  Exp(operand);

  // Build the reference by applying std::exp to each input element.
  const int64 element_count = exp_input.shape().dimensions(0);
  std::vector<float> expected_exp(element_count);
  for (int64 idx = 0; idx != element_count; ++idx) {
    expected_exp[idx] = std::exp(exp_input.Get<float>({idx}));
  }

  ComputeAndCompareR1<float>(&builder, expected_exp, {exp_input_data.get()},
                             error_spec_);
}
2338 
XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
  // Log over a parameter that is large enough to exercise the vectorized
  // implementation on XLA CPU (the op under test here is Log). Reference
  // values come from std::log, including for the leading negative inputs.
  XlaBuilder builder(TestName());

  Literal log_input = LiteralUtil::CreateR1<float>(
      {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
       -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
       198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
       1.74e+04, 1.89e+05, 1.9e+05,  1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
       1.66e+07, 1e+07,    1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
       1.44e+10, 1.5e+10,  1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
       1.4e+12,  1.03e+13, 1.6e+13,  1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
       1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
       2e+18,    1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
       1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21,  1.35e+22, 1.84e+22, 1.02e+22,
       1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
       1.62e+25, 1.2e+26,  1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
       1.5e+28,  1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30,  1.81e+30, 1.34e+30,
       1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
       1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> log_input_data,
                          client_->TransferToServer(log_input));

  auto operand = Parameter(&builder, 0, log_input.shape(), "input");
  Log(operand);

  // Build the reference by applying std::log to each input element.
  const int64 element_count = log_input.shape().dimensions(0);
  std::vector<float> expected_log(element_count);
  for (int64 idx = 0; idx != element_count; ++idx) {
    expected_log[idx] = std::log(log_input.Get<float>({idx}));
  }

  ComputeAndCompareR1<float>(&builder, expected_log, {log_input_data.get()},
                             error_spec_);
}
2376 
XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) {
  // Count-leading-zeros on 32-bit unsigned values, ranging from 0 (all 32
  // bits clear) to a value with the top bit set (no leading zeros).
  XlaBuilder builder(TestName());
  auto values = ConstantR1<uint32>(
      &builder, {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678});
  Clz(values);

  const std::vector<uint32> expected_counts = {32, 31, 27, 15, 9, 3, 0};
  ComputeAndCompareR1<uint32>(&builder, expected_counts, {});
}
2385 
XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) {
  // Count-leading-zeros on 64-bit signed values, ranging from 0 (all 64 bits
  // clear) to -1 (no leading zeros).
  XlaBuilder builder(TestName());
  auto values =
      ConstantR1<int64>(&builder, {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1});
  Clz(values);

  const std::vector<int64> expected_counts = {64, 63, 32, 1, 0};
  ComputeAndCompareR1<int64>(&builder, expected_counts, {});
}
2394 
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
  // Left-leaning chain of additions: (a + b) + c.
  //
  //   a --+
  //       (add) --+
  //   b --+       (add)
  //   c ----------+
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  auto a_plus_b = Add(a, b);
  Add(a_plus_b, c);

  const std::vector<float> expected_sum = {-0.1f, -10.1f, -0.1f, -20.1f};
  ComputeAndCompareR1<float>(&builder, expected_sum, {}, error_spec_);
}
2412 
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
  // Right-leaning chain of additions: a + (b + c).
  //
  //   b --+
  //       (add) --+
  //   c --+       (add)
  //   a ----------+
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  auto b_plus_c = Add(b, c);
  Add(a, b_plus_c);

  const std::vector<float> expected_sum = {89.9f, -10.1f, -0.1f, -20.1f};
  ComputeAndCompareR1<float>(&builder, expected_sum, {}, error_spec_);
}
2430 
XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
  // Adds the negations of two constants: (-a) + (-b).
  //
  //   a --- (neg) --+
  //                 (add)
  //   b --- (neg) --+
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});

  auto minus_a = Neg(a);
  auto minus_b = Neg(b);
  Add(minus_a, minus_b);

  const std::vector<float> expected_sum = {-93.2f, -5.4f, -7.6f, -9.8f};
  ComputeAndCompareR1<float>(&builder, expected_sum, {}, error_spec_);
}
2447 
XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
  // Balanced tree of additions: (a + b) + (c + d).
  //
  //   a --+
  //       (add) --+
  //   b --+       |
  //               (add)
  //   c --+       |
  //       (add) --+
  //   d --+
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});
  auto d = ConstantR1<float>(&builder, {-19.0f, 10.0f, -40.0f, 20.2f});

  auto left_sum = Add(a, b);
  auto right_sum = Add(c, d);
  Add(left_sum, right_sum);

  const std::vector<float> expected_sum = {70.9f, -0.1f, -40.1f, 0.1f};
  ComputeAndCompareR1<float>(&builder, expected_sum, {}, error_spec_);
}
2470 
2471 XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
2472   XlaBuilder builder(TestName());
2473   auto a = ConstantR2<float>(&builder,
2474                              {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2475   auto b = ConstantR2<float>(&builder,
2476                              {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
2477   Add(a, b);
2478 
2479   Array2D<float> expected_array(
2480       {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
2481   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2482 }
2483 
XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
  // Scalar + matrix: the rank-0 constant is the left operand of the add and
  // ends up added to every element of the matrix.
  XlaBuilder builder(TestName());
  auto matrix = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto three = ConstantR0<float>(&builder, 3.0f);
  Add(three, matrix);

  Array2D<float> expected_sum({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2495 
2496 XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
2497   // Add a matrix + scalar.
2498   XlaBuilder builder(TestName());
2499   auto a = ConstantR2<float>(&builder,
2500                              {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2501   auto scalar = ConstantR0<float>(&builder, 3.0f);
2502   Add(a, scalar);
2503 
2504   Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
2505   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2506 }
2507 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
  // Simple broadcast of a 3-element R1F32 over an R2F32 written as two rows of
  // three values. Only one dimension of the matrix has the vector's size;
  // broadcast_dimensions={1} maps the vector onto it, so vec[j] is added to
  // every mat[i][j].
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<float>(&builder, {20.0f, 40.0f, 60.0f});
  // clang-format off
  auto mat = ConstantR2<float>(&builder, {
    {-2.5f, 3.14f, 1.0f},
    {2.25f, -10.0f, 3.33f}});
  // clang-format on
  Add(vec, mat, /*broadcast_dimensions=*/{1});

  Array2D<float> expected_sum(
      {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2523 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
  // Eq between a 2-element vector and a 2x2 matrix, with broadcasting. Both
  // possible broadcast dimensions are exercised and the two results are
  // returned together as a tuple.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {42, 73});
  auto mat = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});

  // With {1}, vec[j] is compared against mat[i][j]; with {0}, vec[i] is.
  auto eq_over_dim1 = Eq(vec, mat, /*broadcast_dimensions=*/{1});
  auto eq_over_dim0 = Eq(vec, mat, /*broadcast_dimensions=*/{0});
  Tuple(&builder, {eq_over_dim1, eq_over_dim0});

  Literal expected_over_dim1 =
      LiteralUtil::CreateR2<bool>({{true, true}, {true, false}});
  Literal expected_over_dim0 =
      LiteralUtil::CreateR2<bool>({{true, false}, {false, false}});
  auto expected_tuple = LiteralUtil::MakeTupleFromSlices(
      {expected_over_dim1, expected_over_dim0});
  ComputeAndCompareTuple(&builder, expected_tuple, {}, error_spec_);
}
2541 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
  // Ne between a 2-element vector and a 2x2 matrix with
  // broadcast_dimensions={1}; the result is checked via its string form.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {42, 73});
  auto mat = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});
  Ne(vec, mat, /*broadcast_dimensions=*/{1});

  const string expected_text = R"(pred[2,2] {
  { 0, 0 },
  { 0, 1 }
})";
  const string actual_text = ExecuteToString(&builder, {});
  EXPECT_EQ(expected_text, actual_text);
}
2555 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
  // Ge between a 4-element vector (left operand) and a 2x4 matrix with
  // broadcast_dimensions={1}; the result is checked via its string form.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Ge(vec, mat, /*broadcast_dimensions=*/{1});

  const string expected_text = R"(pred[2,4] {
  { 1, 1, 0, 0 },
  { 0, 0, 0, 1 }
})";
  const string actual_text = ExecuteToString(&builder, {});
  EXPECT_EQ(expected_text, actual_text);
}
2569 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
  // Gt between a 4-element vector (left operand) and a 2x4 matrix with
  // broadcast_dimensions={1}; the result is checked via its string form.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Gt(vec, mat, /*broadcast_dimensions=*/{1});

  const string expected_text = R"(pred[2,4] {
  { 0, 1, 0, 0 },
  { 0, 0, 0, 0 }
})";
  const string actual_text = ExecuteToString(&builder, {});
  EXPECT_EQ(expected_text, actual_text);
}
2583 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
  // Le between a 4-element vector (left operand) and a 2x4 matrix with
  // broadcast_dimensions={1}; the result is checked via its string form.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Le(vec, mat, /*broadcast_dimensions=*/{1});

  const string expected_text = R"(pred[2,4] {
  { 1, 0, 1, 1 },
  { 1, 1, 1, 1 }
})";
  const string actual_text = ExecuteToString(&builder, {});
  EXPECT_EQ(expected_text, actual_text);
}
2597 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
  // Lt between a 4-element vector (left operand) and a 2x4 matrix with
  // broadcast_dimensions={1}; the result is checked via its string form.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Lt(vec, mat, /*broadcast_dimensions=*/{1});

  const string expected_text = R"(pred[2,4] {
  { 0, 0, 1, 1 },
  { 1, 1, 1, 0 }
})";
  const string actual_text = ExecuteToString(&builder, {});
  EXPECT_EQ(expected_text, actual_text);
}
2611 
XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
  // Broadcast of an R1F32 over an R2F32 where the matrix is the left operand
  // and the vector the right one; vec[j] multiplies every mat[i][j].
  XlaBuilder builder(TestName());
  auto mat =
      ConstantR2<float>(&builder, {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
  auto vec = ConstantR1<float>(&builder, {2.0f, 4.0f, 6.0f});
  Mul(mat, vec, /*broadcast_dimensions=*/{1});

  Array2D<float> expected_product(
      {{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
  ComputeAndCompareR2<float>(&builder, expected_product, {}, error_spec_);
}
2623 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
  // Broadcasting between rank-2 arrays where one operand has a degenerate
  // (size == 1) dimension: `full` is written as two rows of three values and
  // `single_row` as one row of three values, so that row is added to each row
  // of `full`. No broadcast_dimensions are passed.
  XlaBuilder builder(TestName());
  auto full = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto single_row = ConstantR2<float>(&builder, {{10.0f, 20.0f, 30.0f}});
  Add(full, single_row);

  Array2D<float> expected_sum(
      {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2638 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
  // Broadcasting between rank-2 arrays where one operand has a degenerate
  // (size == 1) dimension: `full` is written as two rows of three values and
  // `single_column` as two rows of one value, so each of its values is added
  // across the matching row of `full`. No broadcast_dimensions are passed.
  XlaBuilder builder(TestName());
  auto full = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto single_column = ConstantR2<float>(&builder, {{10.0f}, {20.0f}});
  Add(full, single_column);

  Array2D<float> expected_sum(
      {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2653 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
  // Broadcasting where both operands have a degenerate dimension, which
  // effectively yields an "outer product"-style sum. The example comes from
  // the Numpy broadcasting docs:
  // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
  //
  // `column` is written as four rows of one value and `row` as one row of
  // three values; the expected result has four rows of three values.
  XlaBuilder builder(TestName());
  auto column =
      ConstantR2<float>(&builder, {{0.0f}, {10.0f}, {20.0f}, {30.0f}});
  auto row = ConstantR2<float>(&builder, {{1.0f, 2.0f, 3.0f}});
  Add(column, row);

  Array2D<float> expected_sum({{1.0f, 2.0f, 3.0f},
                               {11.0f, 12.0f, 13.0f},
                               {21.0f, 22.0f, 23.0f},
                               {31.0f, 32.0f, 33.0f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2672 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
  // Adds a 2-element vector to a 2x2 matrix. The vector's size matches both
  // matrix dimensions, so there are two ways to broadcast it; this test passes
  // broadcast_dimensions={1}, which adds vec[j] to mat[i][j].
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto mat = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(vec, mat, /*broadcast_dimensions=*/{1});

  Array2D<float> expected_sum({{30.0f, 90.0f}, {97.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2683 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
  // Adds a 2-element vector to a 2x2 matrix. The vector's size matches both
  // matrix dimensions, so there are two ways to broadcast it; this test passes
  // broadcast_dimensions={0}, which adds vec[i] to mat[i][j].
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto mat = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(vec, mat, /*broadcast_dimensions=*/{0});

  Array2D<float> expected_sum({{30.0f, 70.0f}, {117.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2694 
2695 XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
2696   // Binary add of two R3s together
2697   XlaBuilder builder(TestName());
2698   Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
2699                        {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
2700   auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
2701 
2702   Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
2703                        {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
2704   auto b = ConstantR3FromArray3D<float>(&builder, b_3d);
2705   Add(a, b);
2706 
2707   Array3D<float> expected_3d(
2708       {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
2709        {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
2710   ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
2711 }
2712 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
  // Adds a 2-element vector to a rank-3 array that has two dimensions of size
  // 2, so there are two ways to broadcast. This test passes
  // broadcast_dimensions={2}, which adds vec[k] to cube[i][j][k].
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> cube_values({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto cube = ConstantR3FromArray3D<float>(&builder, cube_values);
  auto vec = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(cube, vec, /*broadcast_dimensions=*/{2});

  Array3D<float> expected_sum(
      {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
       {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_sum, {}, error_spec_);
}
2736 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
  // Adds a 2-element vector to a rank-3 array that has two dimensions of size
  // 2, so there are two ways to broadcast. This test passes
  // broadcast_dimensions={0}, which adds vec[i] to cube[i][j][k].
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> cube_values({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto cube = ConstantR3FromArray3D<float>(&builder, cube_values);
  auto vec = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(cube, vec, /*broadcast_dimensions=*/{0});

  // clang-format off
  Array3D<float> expected_sum({
    {{11.0f, 12.0f},
     {13.0f, 14.0f},
     {15.0f, 16.0f}},
    {{27.0f, 28.0f},
     {29.0f, 30.0f},
     {31.0f, 32.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_sum, {}, error_spec_);
}
2767 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
  // Adds a rank-2 array (two rows of three values) to a rank-3 array using
  // broadcast_dimensions={0, 1}, so mat[i][j] is added to cube[i][j][k] for
  // every k.
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> cube_values({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  auto cube = ConstantR3FromArray3D<float>(&builder, cube_values);
  auto mat = ConstantR2<float>(&builder, {
    {10.0f, 20.0f, 30.0f},
    {40.0f, 50.0f, 60.0f},
  });
  Add(cube, mat, /*broadcast_dimensions=*/{0, 1});

  Array3D<float> expected_sum({
    {{11.0f, 12.0f},
     {23.0f, 24.0f},
     {35.0f, 36.0f}},
    {{47.0f, 48.0f},
     {59.0f, 60.0f},
     {71.0f, 72.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_sum, {}, error_spec_);
}
2799 
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
  // Gt between two rank-3 arrays of compatible shapes, one of which has a
  // degenerate (size == 1) dimension: `a` is built from two 3x2 slabs and `b`
  // from a single 3x2 slab, so `b` is compared against each slab of `a`. The
  // result is a pred[2,3,2] array, checked via its string form.
  XlaBuilder builder(TestName());
  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);

  Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
  auto b = ConstantR3FromArray3D<float>(&builder, b_3d);

  Gt(a, b);

  // The expectation lives only in the string below; the previously declared
  // Array3D<int> copy of it was never read and has been dropped.
  const string expected = R"(pred[2,3,2] {
{
  { 0, 1 },
  { 0, 0 },
  { 0, 0 }
},
{
  { 0, 1 },
  { 1, 0 },
  { 0, 1 }
}
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}
2829 
2830 XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
2831   XlaBuilder builder(TestName());
2832 
2833   std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
2834   std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5));
2835   std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
2836   float value = 0.0;
2837   for (int64 p = 0; p < 2; ++p) {
2838     for (int64 z = 0; z < 3; ++z) {
2839       for (int64 y = 0; y < 4; ++y) {
2840         for (int64 x = 0; x < 5; ++x) {
2841           (*operand_a_4d)(p, z, y, x) = value;
2842           (*operand_b_4d)(p, z, y, x) = 2.0 * value;
2843           (*expected_4d)(p, z, y, x) = 3.0 * value;
2844           value += 0.1;
2845         }
2846       }
2847     }
2848   }
2849 
2850   auto a = ConstantR4FromArray4D<float>(&builder, *operand_a_4d);
2851   auto b = ConstantR4FromArray4D<float>(&builder, *operand_b_4d);
2852   Add(a, b);
2853 
2854   ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
2855 }
2856 
XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
  // Adds a 3-element vector {1, 2, 3} to a rank-4 F32 constant of extent
  // 2x3x4x5 with broadcast dimension 1, so operand_b_1d[z] is added to every
  // element whose second index is z.
  XlaBuilder builder(TestName());

  // Plain value objects: nothing here needs the arrays to be allocated with
  // `new` and held through a unique_ptr.
  Array4D<float> operand_a_4d(2, 3, 4, 5);
  Array4D<float> expected_4d(2, 3, 4, 5);
  std::vector<float> operand_b_1d(3);
  std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0);

  float value = 0.0;
  for (int64 p = 0; p < 2; ++p) {
    for (int64 z = 0; z < 3; ++z) {
      for (int64 y = 0; y < 4; ++y) {
        for (int64 x = 0; x < 5; ++x) {
          operand_a_4d(p, z, y, x) = value;
          expected_4d(p, z, y, x) = value + operand_b_1d[z];
          value += 0.1;
        }
      }
    }
  }

  auto a = ConstantR4FromArray4D<float>(&builder, operand_a_4d);
  auto b = ConstantR1<float>(&builder, operand_b_1d);
  Add(a, b, {1});

  ComputeAndCompareR4<float>(&builder, expected_4d, {}, error_spec_);
}
2884 
XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
  // Adds the vector {1, ..., 16} along dimension 1 of a 16x16x2x2 array of
  // ones. The rank-4 operand is fed in as a literal built with the explicit
  // layout {0, 1, 2, 3}.
  constexpr int kDim0 = 16;
  constexpr int kDim1 = 16;
  constexpr int kDim2 = 2;
  constexpr int kDim3 = 2;

  Array4D<float> values(kDim0, kDim1, kDim2, kDim3);
  values.Fill(1.0);
  std::vector<float> addend(kDim1);
  std::iota(addend.begin(), addend.end(), 1.0);

  XlaBuilder builder(TestName());
  Literal values_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      values, LayoutUtil::MakeLayout({0, 1, 2, 3}));
  auto lhs = ConstantLiteral(&builder, values_literal);
  auto rhs = ConstantR1<float>(&builder, addend);
  Add(lhs, rhs, {1});

  // The literal above already captured the input, so `values` can now be
  // updated in place to become the expected result.
  for (int i0 = 0; i0 != kDim0; ++i0) {
    for (int i1 = 0; i1 != kDim1; ++i1) {
      const float increment = addend[i1];
      for (int i2 = 0; i2 != kDim2; ++i2) {
        for (int i3 = 0; i3 != kDim3; ++i3) {
          values(i0, i1, i2, i3) += increment;
        }
      }
    }
  }
  ComputeAndCompareR4<float>(&builder, values, {}, error_spec_);
}
2913 
2914 // Show that we can't add two opaques.
XLA_TEST_F(ArrayElementwiseOpTest,CannotAddOpaques)2915 XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
2916   XlaBuilder builder(TestName());
2917   auto shape = ShapeUtil::MakeOpaqueShape();
2918   auto x = Parameter(&builder, 0, shape, "x");
2919   Add(x, x);
2920   auto computation_status = builder.Build();
2921   ASSERT_FALSE(computation_status.ok());
2922   EXPECT_THAT(computation_status.status().ToString(),
2923               ::testing::ContainsRegex(
2924                   "Expected array argument for lhs of binary operation"));
2925 }
2926 
// Two 2x3 operands may be added with explicit broadcast dimensions as long as
// those dimensions are the identity mapping {0, 1}; the result is the plain
// elementwise sum.
XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
  XlaBuilder builder(TestName());

  auto lhs = ConstantR2<float>(&builder, {{-2.5f, 3.14f, 1.0f},
                                          {2.25f, -10.0f, 3.33f}});
  auto rhs = ConstantR2<float>(&builder, {{-1.5f, 8.14f, 42.0f},
                                          {-1.0f, -4.0f, 5.55f}});
  Add(lhs, rhs, /*broadcast_dimensions=*/{0, 1});

  // Elementwise lhs + rhs.
  Array2D<float> expected_sum({{-4.0f, 11.28f, 43.0f},
                               {1.25f, -14.0f, 8.88f}});
  ComputeAndCompareR2<float>(&builder, expected_sum, {}, error_spec_);
}
2939 
// Two 2x3 operands added with the non-identity broadcast dimensions {1, 0}
// must be rejected when the computation is built.
XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
  XlaBuilder builder(TestName());

  auto lhs = ConstantR2<float>(&builder, {{-2.5f, 3.14f, 1.0f},
                                          {2.25f, -10.0f, 3.33f}});
  auto rhs = ConstantR2<float>(&builder, {{-1.5f, 8.14f, 42.0f},
                                          {-1.0f, -4.0f, 5.55f}});
  Add(lhs, rhs, /*broadcast_dimensions=*/{1, 0});

  // No execution here: the failure is expected at build time.
  auto build_result = builder.Build();
  ASSERT_FALSE(build_result.ok());
  EXPECT_THAT(build_result.status().error_message(),
              ::testing::ContainsRegex("must.*be the identity"));
}
2953 
2954 // Regression test for b/31927799. "slice - y" is fused and requires implicit
2955 // broadcast.
XLA_TEST_F(ArrayElementwiseOpTest,ImplictBroadcastInFusedExpressions)2956 XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
2957   XlaBuilder builder(TestName());
2958   auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
2959   auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
2960   auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
2961   auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
2962 
2963   auto x = Parameter(&builder, 0, x_literal.shape(), "x");
2964   auto y = Parameter(&builder, 1, y_literal.shape(), "y");
2965   auto slice = Slice(x, {1}, {2}, {1});
2966   Sub(slice, y);
2967 
2968   ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
2969                              error_spec_);
2970 }
2971 
2972 INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
2973                         ArrayElementwiseOpTestParamCount,
2974                         ::testing::Values(127, 128, 129, 17 * 4096));
2975 
2976 }  // namespace
2977 }  // namespace xla
2978