• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Randomized tests for XLA implementations of Tensorflow operations.
17 //
18 // For each operator, the tests in this file choose a set of random inputs and
19 // attributes. The test then compares the outputs of the operator when executed
20 // via Tensorflow using the CPU device and when executed via XLA.
21 //
22 // By default, each test chooses a random seed nondeterministically (using
23 // std::random_device). However, a particular choice of random seed can be
24 // forced using the flag --tf_xla_random_seed; each test logs the
25 // flag value necessary to reproduce its outputs.
26 //
27 // Example usage:
28 // Run tests, comparing the Tensorflow CPU operators with their XLA-compiled
29 // counterparts:
30 // randomized_tests \
31 //   --tf_xla_test_use_jit=true --tf_xla_test_device=CPU:0 \
32 //   --tf_xla_test_repetitions=20
33 
34 // TODO(phawkins): add tests for:
35 // * DepthwiseConv2DNative
36 // * Gather
37 // * InvertPermutation
38 // * MaxPoolGrad (requires implementation of forward operator)
39 // * Select
40 // * Unpack
41 //
42 // TODO(phawkins): improve tests for:
43 // * StridedSliceGrad (need to use shape function to compute sensible inputs)
44 
45 #include <random>
46 #include <unordered_map>
47 
48 #include "absl/algorithm/container.h"
49 #include "absl/container/flat_hash_set.h"
50 #include "absl/strings/str_cat.h"
51 #include "absl/strings/string_view.h"
52 #include "tensorflow/compiler/jit/defs.h"
53 #include "tensorflow/compiler/tf2xla/type_util.h"
54 #include "tensorflow/core/common_runtime/device.h"
55 #include "tensorflow/core/common_runtime/device_factory.h"
56 #include "tensorflow/core/common_runtime/device_mgr.h"
57 #include "tensorflow/core/framework/node_def_builder.h"
58 #include "tensorflow/core/framework/node_def_util.h"
59 #include "tensorflow/core/framework/op_kernel.h"
60 #include "tensorflow/core/framework/tensor.h"
61 #include "tensorflow/core/framework/tensor_testutil.h"
62 #include "tensorflow/core/framework/types.pb.h"
63 #include "tensorflow/core/graph/graph.h"
64 #include "tensorflow/core/graph/graph_constructor.h"
65 #include "tensorflow/core/lib/bfloat16/bfloat16.h"
66 #include "tensorflow/core/lib/core/status.h"
67 #include "tensorflow/core/lib/core/status_test_util.h"
68 #include "tensorflow/core/platform/test.h"
69 #include "tensorflow/core/public/session.h"
70 #include "tensorflow/core/public/session_options.h"
71 #include "tensorflow/core/util/command_line_flags.h"
72 #include "tensorflow/core/util/device_name_utils.h"
73 #include "tensorflow/core/util/tensor_format.h"
74 
75 namespace tensorflow {
76 namespace {
77 
78 // Command line flags: see main() below.
79 int64 tf_xla_random_seed = 0;
80 int32 tf_xla_test_repetitions = 20;
81 int64 tf_xla_max_tensor_size = 10000LL;
82 string* tf_xla_test_device_ptr;  // initial value set in main()
83 string* tf_xla_reference_device_ptr;  // initial value set in main()
84 bool tf_xla_test_use_jit = true;
85 
// Turns a local device name (e.g. "CPU:0") into a fully-qualified device name
// by prepending the fixed "/job:localhost/replica:0/task:0/device:" prefix.
string LocalDeviceToFullDeviceName(const string& device) {
  static constexpr char kLocalDevicePrefix[] =
      "/job:localhost/replica:0/task:0/device:";
  return absl::StrCat(kLocalDevicePrefix, device);
}
89 
90 constexpr std::array<DataType, 5> kAllXlaTypes = {
91     {DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64, DT_INT64}};
92 
93 // An OpTestBuilder is a graph builder class that takes as input an operator to
94 // test, its inputs and attributes, and builds a graph that executes the
95 // operator.
96 class OpTestBuilder {
97  public:
98   explicit OpTestBuilder(const string& op_name);
99 
100   // Adds an input 'tensor' as a Placeholder node.
101   OpTestBuilder& Input(const Tensor& tensor);
102 
103   // Adds a random input tensor with 'type' as a Placeholder node.
104   // If 'dims' is not provided, RandomDims() is used.
105   OpTestBuilder& RandomInput(DataType type);
106   OpTestBuilder& RandomInput(DataType type, std::vector<int64> dims);
107 
108   // As RandomInput but the values are unique.
109   OpTestBuilder& RandomUniqueInput(DataType type, std::vector<int64> dims);
110 
111   // Sets an attribute.
112   template <class T>
113   OpTestBuilder& Attr(absl::string_view attr_name, T&& value);
114 
115   // Overload needed to allow {...} expressions for value.
116   template <class T>
117   OpTestBuilder& Attr(absl::string_view attr_name,
118                       std::initializer_list<T> value);
119 
120   // Adds nodes that executes the operator under test on 'device' to 'graphdef'.
121   // If 'use_jit' is true, marks the operator under test to be compiled by XLA.
122   // The graph will consist of one Placeholder node per input, the operator
123   // itself, and one Identity node per output. If 'test_node_def' is not null,
124   // sets it to the NodeDef of the operator under test. Fills 'inputs' and
125   // 'outputs' with the names of the input placeholder nodes and the output
126   // identity nodes, respectively.
127   Status BuildGraph(const string& name_prefix, const string& device,
128                     bool use_jit, GraphDef* graphdef, NodeDef** test_node_def,
129                     std::vector<string>* inputs,
130                     std::vector<string>* outputs) const;
131 
132   struct InputDescription {
133     Tensor tensor;
134 
135     DataType type = DT_INVALID;
136     bool has_dims = false;
137     bool needs_unique_values = false;
138     std::vector<int64> dims;
139   };
140 
inputs() const141   const std::vector<InputDescription>& inputs() const { return inputs_; }
142 
143  private:
144   NodeDef node_def_;
145   std::vector<InputDescription> inputs_;
146 };
147 
OpTestBuilder(const string & op_name)148 OpTestBuilder::OpTestBuilder(const string& op_name) {
149   node_def_.set_op(op_name);
150 }
151 
Input(const Tensor & tensor)152 OpTestBuilder& OpTestBuilder::Input(const Tensor& tensor) {
153   VLOG(1) << "Adding input: " << tensor.DebugString();
154   InputDescription input;
155   input.tensor = tensor;
156   inputs_.push_back(input);
157   return *this;
158 }
159 
RandomInput(DataType type)160 OpTestBuilder& OpTestBuilder::RandomInput(DataType type) {
161   VLOG(1) << "Adding random input: " << type;
162   InputDescription input;
163   input.type = type;
164   inputs_.push_back(input);
165   return *this;
166 }
167 
RandomInput(DataType type,std::vector<int64> dims)168 OpTestBuilder& OpTestBuilder::RandomInput(DataType type,
169                                           std::vector<int64> dims) {
170   VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString();
171   InputDescription input;
172   input.type = type;
173   input.has_dims = true;
174   input.dims = std::move(dims);
175   inputs_.push_back(input);
176   return *this;
177 }
178 
RandomUniqueInput(DataType type,std::vector<int64> dims)179 OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type,
180                                                 std::vector<int64> dims) {
181   VLOG(1) << "Adding input: " << type << " " << TensorShape(dims).DebugString();
182   InputDescription input;
183   input.type = type;
184   input.has_dims = true;
185   input.needs_unique_values = true;
186   input.dims = std::move(dims);
187   inputs_.push_back(input);
188   return *this;
189 }
190 
191 template <class T>
Attr(absl::string_view attr_name,T && value)192 OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) {
193   AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
194   return *this;
195 }
196 
197 template <class T>
Attr(absl::string_view attr_name,std::initializer_list<T> value)198 OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name,
199                                    std::initializer_list<T> value) {
200   Attr<std::initializer_list<T>>(attr_name, std::move(value));
201   return *this;
202 }
203 
BuildGraph(const string & name_prefix,const string & device,bool use_jit,GraphDef * graphdef,NodeDef ** test_node_def,std::vector<string> * inputs,std::vector<string> * outputs) const204 Status OpTestBuilder::BuildGraph(const string& name_prefix,
205                                  const string& device, bool use_jit,
206                                  GraphDef* graphdef, NodeDef** test_node_def,
207                                  std::vector<string>* inputs,
208                                  std::vector<string>* outputs) const {
209   OpRegistryInterface* op_registry = OpRegistry::Global();
210 
211   const OpDef* op_def;
212   TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(node_def_.op(), &op_def));
213 
214   NodeDef* test_def = graphdef->add_node();
215   *test_def = node_def_;
216   test_def->set_name(absl::StrCat(name_prefix, "_op_under_test"));
217   test_def->set_device(device);
218   AddDefaultsToNodeDef(*op_def, test_def);
219   if (use_jit) {
220     AddNodeAttr(kXlaCompileAttr, true, test_def);
221   }
222   VLOG(1) << "Op under test: " << test_def->DebugString();
223 
224   DataTypeVector input_types, output_types;
225   TF_RETURN_IF_ERROR(
226       InOutTypesForNode(*test_def, *op_def, &input_types, &output_types));
227 
228   // Build feed and fetch nodes.
229   for (int i = 0; i < input_types.size(); ++i) {
230     NodeDef* def = graphdef->add_node();
231     string name = absl::StrCat(name_prefix, "_input_", i);
232     TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder")
233                            .Device(device)
234                            .Attr("dtype", input_types[i])
235                            .Finalize(def));
236     inputs->push_back(name);
237     test_def->add_input(name);
238   }
239 
240   for (int i = 0; i < output_types.size(); ++i) {
241     NodeDef* def = graphdef->add_node();
242     string name = absl::StrCat(name_prefix, "_output_", i);
243     TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity")
244                            .Device(device)
245                            .Attr("T", output_types[i])
246                            .Input(test_def->name(), i, output_types[i])
247                            .Finalize(def));
248     outputs->push_back(name);
249   }
250 
251   if (test_node_def) {
252     *test_node_def = test_def;
253   }
254 
255   return Status::OK();
256 }
257 
258 // Test fixture. The fixture manages the random number generator and its seed,
259 // and has a number of convenience methods for building random Tensors, shapes,
260 // etc.
261 class OpTest : public ::testing::Test {
262  public:
263   OpTest();
264 
265   enum TestResult {
266     // The test saw an unrecoverable error. Don't try any more runs.
267     kFatalError,
268     // The parameters of the test were invalid (e.g., the "golden"
269     // implementation failed, or the parameters are oversize). Reruns are ok.
270     kInvalid,
271     // The test ran successfully, and we have a verdict. Does *not* mean the
272     // test passed.
273     kOk,
274   };
275 
276   // Runs 'fn' up to --tf_xla_test_repetitions times, or until a test failure
277   // occurs; whichever happens first. Reruns if the TestResult is kInvalid.
278   void Repeatedly(const std::function<TestResult(void)>& fn);
279 
280   // Select a random element from 'candidates'.
281   template <typename T>
282   T Choose(absl::Span<const T> candidates);
283 
284   static constexpr int kDefaultMaxRank = 5;
285   static constexpr int64 kDefaultMaxDimensionSize = 256LL;
286 
287   // Returns true if 'dims' have a size less than tf_xla_max_tensor_size.
288   bool TensorSizeIsOk(absl::Span<const int64> dims);
289 
290   // Returns a random dimension size, in the range [min, max).
291   int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize);
292 
293   // Returns a random shape. The tensor has rank in the range [min_rank,
294   // max_rank). Each dimension has size [min_size, max_size).
295   std::vector<int64> RandomDims(int min_rank = 0,
296                                 int max_rank = kDefaultMaxRank,
297                                 int64 min_size = 0,
298                                 int64 max_size = kDefaultMaxDimensionSize);
299 
300   // Given a shape 'dims', build a pair of dimensions such that one broadcasts
301   // to the other.
302   std::pair<std::vector<int64>, std::vector<int64>> BroadcastableDims(
303       std::vector<int64> dims);
304 
305   // Builds a random pair of broadcastable dims.
306   // TODO(phawkins): currently the maximum rank is 3, because broadcasting > 3
307   // dimensions is unimplemented by the Tensorflow Eigen code (b/29268487)
308   std::pair<std::vector<int64>, std::vector<int64>> BroadcastableDims();
309 
310   // Returns a tensor filled with random but "reasonable" values from the middle
311   // of the type's range. If the shape is omitted, a random shape is used.
312   // TODO(phawkins): generalize this code to a caller-supplied distribution.
313   Tensor RandomTensor(DataType dtype, bool needs_unique_values,
314                       absl::Span<const int64> shape);
315   Tensor RandomTensor(DataType dtype);
316 
317   // Like RandomTensor, but uses values >= 0.
318   Tensor RandomNonNegativeTensor(DataType dtype, absl::Span<const int64> shape);
319   Tensor RandomNonNegativeTensor(DataType dtype);
320 
321   // Returns a random subset of the integers in the range [0, rank), suitable
322   // for use as reduction indices.
323   Tensor RandomReductionIndices(int rank);
324 
325   // Returns a random bit.
326   bool RandomBool();
327 
328   struct WindowedSpatialDims {
329     Padding padding;
330     std::vector<int64> kernel_dims;
331     std::vector<int64> stride_dims;
332     std::vector<int64> input_dims;
333     std::vector<int64> output_dims;
334   };
335   // Choose spatial dimensions for a windowed op such as pooling or convolution.
336   WindowedSpatialDims ChooseWindowedSpatialDims(int num_spatial_dims);
337 
338   // Builds dimensions for a windowed op such as pooling or convolution,
339   // including a batch and feature dimension.
340   std::vector<int64> ImageDims(TensorFormat format, int batch, int feature,
341                                const std::vector<int64>& spatial_dims);
342 
343   // Converts an int64 vector to an int32 vector.
344   std::vector<int32> AsInt32s(const std::vector<int64>& int64s);
345 
generator()346   std::mt19937& generator() { return *generator_; }
347 
348   // Run the test case described by 'builder' with and without XLA and check
349   // that the outputs are close. Tensors x and y are close if they have the same
350   // type, same shape, and have close values. For floating-point tensors, the
351   // element-wise difference between x and y must no more than
352   // atol + rtol * abs(x); or both elements may be NaN or infinity. For
353   // non-floating-point tensors the element values must match exactly.
354   TestResult ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
355                                            double atol = 1e-2,
356                                            double rtol = 1e-2);
357 
358  protected:
359   // Per-test state:
360   std::unique_ptr<std::mt19937> generator_;
361 
362   std::unique_ptr<Session> session_;
363 
364   // Number of test cases built in 'session_'. Used to uniquify node names.
365   int num_tests_ = 0;
366 };
367 
OpTest()368 OpTest::OpTest() {
369   // Creates a random-number generator for the test case. Use the value of
370   // --tf_xla_random_seed as the seed, if provided.
371   int64 s = tf_xla_random_seed;
372   unsigned int seed;
373   if (s <= 0) {
374     std::random_device random_device;
375     seed = random_device();
376   } else {
377     seed = static_cast<unsigned int>(s);
378   }
379   LOG(ERROR) << "Random seed for test case: " << seed
380              << ". To reproduce the "
381                 "results of this test, pass flag --tf_xla_random_seed="
382              << seed;
383   generator_.reset(new std::mt19937(seed));
384 
385   // Create a session with an empty graph.
386   SessionOptions session_options;
387   session_.reset(NewSession(session_options));
388   GraphDef def;
389   TF_CHECK_OK(session_->Create(def));
390 }
391 
Repeatedly(const std::function<TestResult (void)> & fn)392 void OpTest::Repeatedly(const std::function<TestResult(void)>& fn) {
393   int const max_repetitions = tf_xla_test_repetitions;
394   int valid_test_runs = 0;
395   // We run up to 100 * max_repetitions times; the idea is that if we roll the
396   // dice enough times we will find some valid parameters. We want to put an
397   // upper limit on the number iterations just in case the probability of
398   // finding feasible parameters is very low.
399   for (int i = 0; !HasFailure() && i < max_repetitions * 100 &&
400                   valid_test_runs < max_repetitions;
401        ++i) {
402     TestResult result = fn();
403     switch (result) {
404       case kOk:
405         ++valid_test_runs;
406         break;
407 
408       case kFatalError:
409         ASSERT_TRUE(false) << "Test had fatal failure";
410         return;
411 
412       case kInvalid:
413         break;
414     }
415   }
416   if (!HasFailure()) {
417     EXPECT_GE(valid_test_runs, max_repetitions)
418         << "Not enough test instances passed; this means that either the "
419            "golden implementation is buggy or the operator harness is not "
420            "producing well-formed test cases with a high probability.";
421   }
422 }
423 
424 template <typename T>
Choose(absl::Span<const T> candidates)425 T OpTest::Choose(absl::Span<const T> candidates) {
426   std::uniform_int_distribution<size_t> d(0, candidates.size() - 1);
427   return candidates[d(generator())];
428 }
429 
RandomDim(int64 min,int64 max)430 int64 OpTest::RandomDim(int64 min, int64 max) {
431   std::uniform_int_distribution<int64> size_distribution(min, max - 1);
432   return size_distribution(generator());
433 }
434 
TensorSizeIsOk(absl::Span<const int64> dims)435 bool OpTest::TensorSizeIsOk(absl::Span<const int64> dims) {
436   int64 size = 1LL;
437   for (int64 dim : dims) {
438     size *= dim;
439   }
440   return size < tf_xla_max_tensor_size;
441 }
442 
RandomDims(int min_rank,int max_rank,int64 min_size,int64 max_size)443 std::vector<int64> OpTest::RandomDims(int min_rank, int max_rank,
444                                       int64 min_size, int64 max_size) {
445   CHECK_LE(0, min_rank);
446   CHECK_LE(min_rank, max_rank);
447   std::uniform_int_distribution<int> rank_distribution(min_rank, max_rank);
448   int rank = rank_distribution(generator());
449   std::vector<int64> dims(rank);
450   // TODO(phawkins): too small a maximum tensor size could lead to an infinite
451   // loop here.
452   do {
453     std::generate(dims.begin(), dims.end(), [this, min_size, max_size]() {
454       return RandomDim(min_size, max_size);
455     });
456   } while (!TensorSizeIsOk(dims));
457   return dims;
458 }
459 
RandomBool()460 bool OpTest::RandomBool() {
461   std::bernoulli_distribution d(0.5);
462   return d(generator());
463 }
464 
RandomTensor(DataType dtype,bool needs_unique_values,absl::Span<const int64> shape)465 Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
466                             absl::Span<const int64> shape) {
467   Tensor tensor(dtype, TensorShape(shape));
468   switch (dtype) {
469     case DT_FLOAT: {
470       absl::flat_hash_set<float> already_generated;
471       std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
472       test::FillFn<float>(&tensor, [&](int i) -> float {
473         float generated;
474         do {
475           generated = distribution(generator());
476         } while (needs_unique_values &&
477                  !already_generated.insert(generated).second);
478         return generated;
479       });
480       break;
481     }
482     case DT_DOUBLE: {
483       absl::flat_hash_set<double> already_generated;
484       std::uniform_real_distribution<double> distribution(-1.0, 1.0);
485       test::FillFn<double>(&tensor, [&](int i) -> double {
486         double generated;
487         do {
488           generated = distribution(generator());
489         } while (needs_unique_values &&
490                  !already_generated.insert(generated).second);
491         return generated;
492       });
493       break;
494     }
495     case DT_COMPLEX64: {
496       absl::flat_hash_set<std::pair<float, float>> already_generated;
497       std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
498       test::FillFn<complex64>(&tensor, [&](int i) {
499         complex64 generated;
500         do {
501           generated =
502               complex64(distribution(generator()), distribution(generator()));
503         } while (
504             needs_unique_values &&
505             !already_generated
506                  .insert(std::make_pair(generated.real(), generated.imag()))
507                  .second);
508         return generated;
509       });
510       break;
511     }
512     case DT_INT32: {
513       absl::flat_hash_set<int32> already_generated;
514       std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
515       test::FillFn<int32>(&tensor, [&](int i) -> int32 {
516         int32 generated;
517         do {
518           generated = distribution(generator());
519         } while (needs_unique_values &&
520                  !already_generated.insert(generated).second);
521         return generated;
522       });
523       break;
524     }
525     case DT_INT64: {
526       absl::flat_hash_set<int64> already_generated;
527       std::uniform_int_distribution<int64> distribution(-(1LL << 40),
528                                                         1LL << 40);
529       test::FillFn<int64>(&tensor, [&](int i) -> int64 {
530         int64 generated;
531         do {
532           generated = distribution(generator());
533         } while (needs_unique_values &&
534                  !already_generated.insert(generated).second);
535         return generated;
536       });
537       break;
538     }
539     case DT_BOOL: {
540       absl::flat_hash_set<bool> already_generated;
541       std::bernoulli_distribution distribution;
542       test::FillFn<bool>(&tensor, [&](int i) -> bool {
543         bool generated;
544         do {
545           generated = distribution(generator());
546         } while (needs_unique_values &&
547                  !already_generated.insert(generated).second);
548         return generated;
549       });
550       break;
551     }
552     default:
553       LOG(FATAL) << "Unimplemented type " << dtype << " in RandomTensor";
554   }
555   return tensor;
556 }
557 
RandomTensor(DataType dtype)558 Tensor OpTest::RandomTensor(DataType dtype) {
559   return RandomTensor(dtype, /*needs_unique_values=*/false, RandomDims());
560 }
561 
RandomNonNegativeTensor(DataType dtype,absl::Span<const int64> shape)562 Tensor OpTest::RandomNonNegativeTensor(DataType dtype,
563                                        absl::Span<const int64> shape) {
564   Tensor tensor(dtype, TensorShape(shape));
565   switch (dtype) {
566     case DT_FLOAT: {
567       std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
568       test::FillFn<float>(&tensor, [this, &distribution](int i) -> float {
569         return distribution(generator());
570       });
571       break;
572     }
573     case DT_DOUBLE: {
574       std::uniform_real_distribution<double> distribution(0.0, 1.0);
575       test::FillFn<double>(&tensor, [this, &distribution](int i) -> double {
576         return distribution(generator());
577       });
578       break;
579     }
580     case DT_INT32: {
581       std::uniform_int_distribution<int32> distribution(0, 1 << 20);
582       test::FillFn<int32>(&tensor, [this, &distribution](int i) -> int32 {
583         return distribution(generator());
584       });
585       break;
586     }
587     case DT_INT64: {
588       std::uniform_int_distribution<int64> distribution(0, 1LL << 40);
589       test::FillFn<int64>(&tensor, [this, &distribution](int i) -> int64 {
590         return distribution(generator());
591       });
592       break;
593     }
594     default:
595       LOG(FATAL) << "Unimplemented type " << dtype
596                  << " in RandomNonNegativeTensor";
597   }
598   return tensor;
599 }
600 
RandomNonNegativeTensor(DataType dtype)601 Tensor OpTest::RandomNonNegativeTensor(DataType dtype) {
602   return RandomNonNegativeTensor(dtype, RandomDims());
603 }
604 
BroadcastableDims(std::vector<int64> dims)605 std::pair<std::vector<int64>, std::vector<int64>> OpTest::BroadcastableDims(
606     std::vector<int64> dims) {
607   if (dims.empty()) return {dims, dims};
608 
609   // Remove some dimensions from the front of 'dims'.
610   size_t skip =
611       std::uniform_int_distribution<size_t>(0, dims.size() - 1)(generator());
612 
613   std::vector<int64> bdims(dims.begin() + skip, dims.end());
614 
615   // Randomly replace some of the remaining dimensions of 'dims' with 1.
616   std::bernoulli_distribution random_bool;
617 
618   for (int64& dim : bdims) {
619     if (random_bool(generator())) {
620       dim = 1LL;
621     }
622   }
623 
624   // Possibly swap the roles of 'dims' and 'bdims'.
625   if (random_bool(generator())) {
626     dims.swap(bdims);
627   }
628   return {dims, bdims};
629 }
630 
BroadcastableDims()631 std::pair<std::vector<int64>, std::vector<int64>> OpTest::BroadcastableDims() {
632   return BroadcastableDims(RandomDims(0, 3));
633 }
634 
RandomReductionIndices(int rank)635 Tensor OpTest::RandomReductionIndices(int rank) {
636   std::bernoulli_distribution random_bool;
637   std::vector<int32> indices;
638   for (int i = 0; i < rank; ++i) {
639     if (random_bool(generator())) {
640       indices.push_back(i);
641     }
642   }
643   return test::AsTensor<int32>(indices);
644 }
645 
ChooseWindowedSpatialDims(int num_spatial_dims)646 OpTest::WindowedSpatialDims OpTest::ChooseWindowedSpatialDims(
647     int num_spatial_dims) {
648   WindowedSpatialDims d;
649   d.padding = Choose<Padding>({SAME, VALID});
650   std::uniform_int_distribution<int> random_int(1, 5);
651   d.kernel_dims.resize(num_spatial_dims);
652   d.input_dims.resize(num_spatial_dims);
653   d.output_dims.resize(num_spatial_dims);
654   d.stride_dims.resize(num_spatial_dims);
655   for (int i = 0; i < num_spatial_dims; ++i) {
656     Status s;
657     // Repeatedly try different filter/stride sizes until we find a valid
658     // combination.
659     do {
660       // CPU implementations require stride <= kernel size.
661       d.kernel_dims[i] = random_int(generator()),
662       d.input_dims[i] = RandomDim(d.kernel_dims[i]);
663       d.stride_dims[i] =
664           std::uniform_int_distribution<int>(1, d.kernel_dims[i])(generator());
665       int64 pad_dummy;
666       s = GetWindowedOutputSize(d.input_dims[i], d.kernel_dims[i],
667                                 d.stride_dims[i], d.padding, &d.output_dims[i],
668                                 &pad_dummy);
669     } while (!s.ok());
670   }
671   return d;
672 }
673 
ImageDims(TensorFormat format,int batch,int feature,const std::vector<int64> & spatial_dims)674 std::vector<int64> OpTest::ImageDims(TensorFormat format, int batch,
675                                      int feature,
676                                      const std::vector<int64>& spatial_dims) {
677   std::vector<int64> dims;
678   switch (format) {
679     case FORMAT_NHWC:
680       dims.push_back(batch);
681       for (int dim : spatial_dims) {
682         dims.push_back(dim);
683       }
684       dims.push_back(feature);
685       break;
686     case FORMAT_NCHW:
687       dims.push_back(batch);
688       dims.push_back(feature);
689       for (int dim : spatial_dims) {
690         dims.push_back(dim);
691       }
692       break;
693     default:
694       LOG(FATAL) << "Tensor format " << ToString(format) << " not supported.";
695   }
696   return dims;
697 }
698 
AsInt32s(const std::vector<int64> & int64s)699 std::vector<int32> OpTest::AsInt32s(const std::vector<int64>& int64s) {
700   return std::vector<int32>(int64s.begin(), int64s.end());
701 }
702 
703 // Functions for comparing tensors.
704 
// Absolute value of a real (floating-point or integral) value, widened to
// double.
template <typename T>
double Abs(T x) {
  const auto magnitude = std::fabs(x);
  return static_cast<double>(magnitude);
}
709 
710 template <>
Abs(complex64 x)711 double Abs<complex64>(complex64 x) {
712   return std::abs(x);
713 }
714 
// Real-valued closeness test: true if both values are NaN, if they compare
// equal (which also accepts matching infinities), or if |x - y| is strictly
// below atol + rtol * |x|.
template <typename T>
bool IsClose(const T& x, const T& y, double atol, double rtol) {
  const bool both_nan = std::isnan(x) && std::isnan(y);
  return both_nan || x == y || Abs(x - y) < atol + rtol * Abs(x);
}
721 
722 template <>
IsClose(const complex64 & x,const complex64 & y,double atol,double rtol)723 bool IsClose<complex64>(const complex64& x, const complex64& y, double atol,
724                         double rtol) {
725   if (std::isnan(x.real()) && std::isnan(y.real())) {
726     if (std::isnan(x.imag()) && std::isnan(y.imag())) {
727       return true;
728     }
729     if (x.imag() == y.imag()) return true;  // Allow inf == inf.
730     return Abs(x.imag() - y.imag()) < atol + rtol * Abs(x.imag());
731   } else if (std::isnan(x.imag()) && std::isnan(y.imag())) {
732     if (x.real() == y.real()) return true;  // Allow inf == inf.
733     return Abs(x.real() - y.real()) < atol + rtol * Abs(x.real());
734   }
735   if (x == y) return true;  // Allow inf == inf.
736   return Abs(x - y) < atol + rtol * Abs(x);
737 }
738 
739 template <typename T>
Str(T x)740 string Str(T x) {
741   return absl::StrCat(x);
742 }
743 template <>
Str(complex64 x)744 string Str<complex64>(complex64 x) {
745   return absl::StrCat("(", x.real(), ", ", x.imag(), ")");
746 }
747 
748 template <typename T>
TensorsAreCloseImpl(const Tensor & x,const Tensor & y,double atol,double rtol)749 Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol,
750                            double rtol) {
751   auto Tx = x.flat<T>();
752   auto Ty = y.flat<T>();
753   for (int i = 0; i < Tx.size(); ++i) {
754     if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
755       return errors::InvalidArgument(
756           absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)),
757                        " vs. ", Str(Ty(i)), ". x = ", x.DebugString(),
758                        "y = ", y.DebugString(), "atol = ", atol,
759                        " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i))));
760     }
761   }
762   return Status::OK();
763 }
764 
765 template <typename T>
TensorsAreEqualImpl(const Tensor & x,const Tensor & y)766 Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) {
767   auto Tx = x.flat<T>();
768   auto Ty = y.flat<T>();
769   for (int i = 0; i < Tx.size(); ++i) {
770     if (Tx(i) != Ty(i)) {
771       return errors::InvalidArgument(absl::StrCat(
772           i, "-th tensor element isn't equal: ", Str(Tx(i)), " vs. ",
773           Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString()));
774     }
775   }
776   return Status::OK();
777 }
778 
TensorsAreEqualImplBfloat16(const Tensor & x,const Tensor & y)779 Status TensorsAreEqualImplBfloat16(const Tensor& x, const Tensor& y) {
780   auto Tx = x.flat<bfloat16>();
781   auto Ty = y.flat<bfloat16>();
782   for (int i = 0; i < Tx.size(); ++i) {
783     if (Tx(i) != Ty(i)) {
784       return errors::InvalidArgument(absl::StrCat(
785           i, "-th tensor element isn't equal: ", static_cast<float>(Tx(i)),
786           " vs. ", static_cast<float>(Ty(i)), ". x = ", x.DebugString(),
787           "y = ", y.DebugString()));
788     }
789   }
790   return Status::OK();
791 }
792 
793 // Tests if "x" and "y" are tensors of the same type, same shape, and with
794 // close values. For floating-point tensors, the element-wise difference between
795 // x and y must no more than atol + rtol * abs(x). For non-floating-point
796 // tensors the values must match exactly.
TensorsAreClose(const Tensor & a,const Tensor & b,double atol,double rtol)797 Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
798                        double rtol) {
799   if (a.dtype() != b.dtype()) {
800     return errors::InvalidArgument(absl::StrCat(
801         "Tensors have different types: ", DataTypeString(a.dtype()), " and ",
802         DataTypeString(b.dtype())));
803   }
804   if (!a.IsSameSize(b)) {
805     return errors::InvalidArgument(
806         absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(),
807                      " and ", b.shape().DebugString()));
808   }
809 
810   switch (a.dtype()) {
811     case DT_FLOAT:
812       return TensorsAreCloseImpl<float>(a, b, atol, rtol);
813     case DT_DOUBLE:
814       return TensorsAreCloseImpl<double>(a, b, atol, rtol);
815     case DT_COMPLEX64:
816       return TensorsAreCloseImpl<complex64>(a, b, atol, rtol);
817     case DT_INT32:
818       return TensorsAreEqualImpl<int32>(a, b);
819     case DT_INT64:
820       return TensorsAreEqualImpl<int64>(a, b);
821     case DT_BOOL:
822       return TensorsAreEqualImpl<bool>(a, b);
823     case DT_BFLOAT16:
824       return TensorsAreEqualImplBfloat16(a, b);
825     default:
826       LOG(FATAL) << "Unexpected type : " << DataTypeString(a.dtype());
827   }
828 }
829 
ExpectTfAndXlaOutputsAreClose(const OpTestBuilder & builder,double atol,double rtol)830 OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose(
831     const OpTestBuilder& builder, double atol, double rtol) {
832   const std::vector<OpTestBuilder::InputDescription>& inputs = builder.inputs();
833   std::vector<Tensor> input_tensors;
834   input_tensors.reserve(inputs.size());
835   for (const OpTestBuilder::InputDescription& input : inputs) {
836     if (input.type == DT_INVALID) {
837       input_tensors.push_back(input.tensor);
838     } else {
839       std::vector<int64> dims;
840       if (input.has_dims) {
841         dims = input.dims;
842       } else {
843         dims = RandomDims();
844       }
845       if (!TensorSizeIsOk(dims)) {
846         VLOG(1) << "Input: " << input.type << " "
847                 << TensorShape(input.dims).DebugString();
848         VLOG(1) << "Ignoring oversize dims.";
849         return kInvalid;
850       }
851       input_tensors.push_back(
852           RandomTensor(input.type, input.needs_unique_values, dims));
853     }
854     VLOG(1) << "Input: " << input_tensors.back().DebugString();
855   }
856 
857   string reference_device =
858       LocalDeviceToFullDeviceName(*tf_xla_reference_device_ptr);
859   string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr);
860 
861   DeviceNameUtils::ParsedName parsed_name;
862   if (!DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name)) {
863     LOG(ERROR) << "Could not parse device name: " << *tf_xla_test_device_ptr;
864     return kFatalError;
865   }
866   DeviceType test_device_type(parsed_name.type);
867   ++num_tests_;
868 
869   GraphDef graph;
870   std::vector<string> expected_inputs, test_inputs;
871   std::vector<string> expected_fetches, test_fetches;
872   Status status = builder.BuildGraph(
873       absl::StrCat("test", num_tests_, "_expected"), reference_device,
874       /*use_jit=*/false, &graph, /*test_node_def=*/nullptr, &expected_inputs,
875       &expected_fetches);
876   if (!status.ok()) {
877     LOG(ERROR) << "Expected graph construction failed: " << status;
878     return kFatalError;
879   }
880 
881   NodeDef* node_def;
882   status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"),
883                               test_device, tf_xla_test_use_jit, &graph,
884                               &node_def, &test_inputs, &test_fetches);
885   if (!status.ok()) {
886     LOG(ERROR) << "Test graph construction failed: " << status;
887     return kFatalError;
888   }
889 
890   // Check that there's a kernel corresponding to 'node_def' on the device under
891   // test.
892   status = FindKernelDef(test_device_type, *node_def, nullptr, nullptr);
893   if (!status.ok()) {
894     VLOG(1) << "Skipping test because there is no corresponding registered "
895             << "kernel on the test device: " << status;
896     return kInvalid;
897   }
898 
899   status = session_->Extend(graph);
900   if (!status.ok()) {
901     LOG(ERROR) << "Session::Extend() failed: " << status;
902     return kFatalError;
903   }
904 
905   std::vector<std::pair<string, Tensor>> expected_feeds(expected_inputs.size());
906   std::vector<std::pair<string, Tensor>> test_feeds(test_inputs.size());
907   CHECK_EQ(input_tensors.size(), expected_inputs.size());
908   CHECK_EQ(input_tensors.size(), test_inputs.size());
909 
910   for (int i = 0; i < input_tensors.size(); ++i) {
911     expected_feeds[i] = {expected_inputs[i], input_tensors[i]};
912     test_feeds[i] = {test_inputs[i], input_tensors[i]};
913   }
914 
915   std::vector<Tensor> expected_outputs, test_outputs;
916   VLOG(1) << "Running expected graph";
917   Status s =
918       session_->Run(expected_feeds, expected_fetches, {}, &expected_outputs);
919   if (!s.ok()) {
920     VLOG(1) << "Expected graph failed with status: " << s << ". Ignoring test";
921     return kInvalid;
922   }
923   for (const Tensor& expected : expected_outputs) {
924     VLOG(1) << "Expected: " << expected.DebugString();
925   }
926 
927   VLOG(1) << "Running test graph";
928   status = session_->Run(test_feeds, test_fetches, {}, &test_outputs);
929   if (!status.ok()) {
930     LOG(ERROR) << "Test graph failed: " << status;
931     return kFatalError;
932   }
933 
934   CHECK_EQ(expected_outputs.size(), test_outputs.size());
935   for (int j = 0; s.ok() && j < test_outputs.size(); ++j) {
936     s = TensorsAreClose(expected_outputs[j], test_outputs[j], atol, rtol);
937   }
938   TF_EXPECT_OK(s);
939 
940   return kOk;
941 }
942 
943 // Helper that converts 'values' to an int32 or int64 Tensor.
AsIntTensor(DataType dtype,const std::vector<int64> & values)944 Tensor AsIntTensor(DataType dtype, const std::vector<int64>& values) {
945   switch (dtype) {
946     case DT_INT32: {
947       std::vector<int32> values32(values.begin(), values.end());
948       return test::AsTensor<int32>(values32);
949     }
950     case DT_INT64:
951       return test::AsTensor<int64>(values);
952     default:
953       LOG(FATAL);
954   }
955 }
956 
// Abs on a single random tensor of a randomly chosen dtype.
TEST_F(OpTest, Abs) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Abs");
    builder.RandomInput(dtype).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
964 
// Acosh on a single random float tensor.
TEST_F(OpTest, Acosh) {
  Repeatedly([this]() {
    OpTestBuilder builder("Acosh");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
971 
// Add on two random tensors whose shapes come from BroadcastableDims().
TEST_F(OpTest, Add) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("Add");
    builder.RandomInput(dtype, shapes.first)
        .RandomInput(dtype, shapes.second)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
982 
// AddN over between one and five random tensors sharing one random shape.
TEST_F(OpTest, AddN) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const int num_inputs = std::uniform_int_distribution<int>(1, 5)(generator());
    const std::vector<int64> common_shape = RandomDims();

    OpTestBuilder builder("AddN");
    builder.Attr("T", dtype).Attr("N", num_inputs);
    for (int k = 0; k < num_inputs; ++k) builder.RandomInput(dtype, common_shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
999 
// All-reduction of a random bool tensor over random reduction indices.
TEST_F(OpTest, All) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims();
    const Tensor axes = RandomReductionIndices(shape.size());
    const bool keep_dims = Choose<bool>({false, true});
    OpTestBuilder builder("All");
    builder.RandomInput(DT_BOOL, shape).Input(axes).Attr("keep_dims", keep_dims);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1011 
// Angle of a random complex64 tensor.
TEST_F(OpTest, Angle) {
  Repeatedly([this]() {
    OpTestBuilder builder("Angle");
    builder.RandomInput(DT_COMPLEX64);
    builder.Attr("T", DT_COMPLEX64);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1019 
// Any-reduction of a random bool tensor over random reduction indices.
TEST_F(OpTest, Any) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims();
    const Tensor axes = RandomReductionIndices(shape.size());
    const bool keep_dims = Choose<bool>({false, true});
    OpTestBuilder builder("Any");
    builder.RandomInput(DT_BOOL, shape).Input(axes).Attr("keep_dims", keep_dims);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1031 
// ApproximateEqual on two random tensors of a randomly chosen dtype. The "T"
// attr must match the dtype actually used for both inputs. Otherwise the
// DT_COMPLEX64 iterations describe an inconsistent node.
TEST_F(OpTest, ApproximateEqual) {
  Repeatedly([this]() {
    auto dims = BroadcastableDims();
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual")
                                             .RandomInput(type, dims.first)
                                             .RandomInput(type, dims.second)
                                             .Attr("T", type));
  });
}
1042 
// ArgMax over a random axis of a float tensor with unique values (so the
// result is unambiguous).
TEST_F(OpTest, ArgMax) {
  Repeatedly([this]() {
    std::vector<int64> dims = RandomDims(1, 5, 1);
    int num_dims = dims.size();
    // Valid axes are [-num_dims, num_dims - 1]. uniform_int_distribution's
    // upper bound is inclusive, so an upper bound of num_dims would sometimes
    // draw an out-of-range axis. That iteration would presumably be rejected
    // by the reference kernel and wasted as kInvalid.
    int reduce_dim = std::uniform_int_distribution<int32>(
        -num_dims, num_dims - 1)(generator());
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("ArgMax")
            .RandomUniqueInput(DT_FLOAT, dims)
            .Input(test::AsScalar<int32>(reduce_dim))
            .Attr("T", DT_FLOAT)
            .Attr("Tidx", DT_INT32)
            .Attr("output_type", DT_INT32));
  });
}
1058 
// ArgMin over a random axis of a float tensor with unique values (so the
// result is unambiguous).
TEST_F(OpTest, ArgMin) {
  Repeatedly([this]() {
    std::vector<int64> dims = RandomDims(1, 5, 1);
    int num_dims = dims.size();
    // Valid axes are [-num_dims, num_dims - 1]. uniform_int_distribution's
    // upper bound is inclusive, so an upper bound of num_dims would sometimes
    // draw an out-of-range axis. That iteration would presumably be rejected
    // by the reference kernel and wasted as kInvalid.
    int reduce_dim = std::uniform_int_distribution<int32>(
        -num_dims, num_dims - 1)(generator());
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("ArgMin")
            .RandomUniqueInput(DT_FLOAT, dims)
            .Input(test::AsScalar<int32>(reduce_dim))
            .Attr("T", DT_FLOAT)
            .Attr("Tidx", DT_INT32)
            .Attr("output_type", DT_INT32));
  });
}
1074 
// Asinh on a single random float tensor.
TEST_F(OpTest, Asinh) {
  Repeatedly([this]() {
    OpTestBuilder builder("Asinh");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1081 
// Atanh on a single random float tensor.
TEST_F(OpTest, Atanh) {
  Repeatedly([this]() {
    OpTestBuilder builder("Atanh");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1088 
// Atan on a single random float tensor.
TEST_F(OpTest, Atan) {
  Repeatedly([this]() {
    OpTestBuilder builder("Atan");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1095 
// Atan2 on two random float tensors whose shapes come from BroadcastableDims().
TEST_F(OpTest, Atan2) {
  Repeatedly([this]() {
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("Atan2");
    builder.RandomInput(DT_FLOAT, shapes.first)
        .RandomInput(DT_FLOAT, shapes.second)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1105 
// 2-D average pooling on a rank-4 NHWC float input. Kernel extents never
// exceed the input's extents in dims 1 and 2. Strides are drawn from [1, 5].
TEST_F(OpTest, AvgPool) {
  Repeatedly([this]() {
    std::uniform_int_distribution<int> stride_dist(1, 5);
    const std::vector<int64> input_shape = RandomDims(4, 4, 1);
    const int ksize_h =
        std::uniform_int_distribution<int>(1, input_shape[1])(generator());
    const int ksize_w =
        std::uniform_int_distribution<int>(1, input_shape[2])(generator());
    const int stride_h = stride_dist(generator());
    const int stride_w = stride_dist(generator());
    const string padding = Choose<string>({"SAME", "VALID"});

    OpTestBuilder builder("AvgPool");
    builder.RandomInput(DT_FLOAT, input_shape)
        .Attr("T", DT_FLOAT)
        .Attr("ksize", {1, ksize_h, ksize_w, 1})
        .Attr("strides", {1, stride_h, stride_w, 1})
        .Attr("padding", padding)
        .Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
  // TODO(phawkins): the CPU device only implements spatial pooling. Add tests
  // for batch pooling when supported.
}
1129 
// 3-D average pooling in NDHWC layout.
// Of the five random dims, the first three are spatial, then batch, then
// feature. For each spatial dim, a kernel extent (at most that dim) and a
// stride in [1, 5] are drawn.
TEST_F(OpTest, AvgPool3D) {
  Repeatedly([this]() {
    std::uniform_int_distribution<int> stride_dist(1, 5);
    const std::vector<int64> dims = RandomDims(5, 5, 1);

    const std::vector<int64> spatial(dims.begin(), dims.begin() + 3);
    std::vector<int64> window, strides;
    for (const int64 extent : spatial) {
      window.push_back(std::uniform_int_distribution<int>(1, extent)(generator()));
      strides.push_back(stride_dist(generator()));
    }
    const int64 batch = dims[3];
    const int64 feature = dims[4];
    const string padding = Choose<string>({"SAME", "VALID"});

    OpTestBuilder builder("AvgPool3D");
    builder
        .RandomInput(DT_FLOAT, ImageDims(FORMAT_NHWC, batch, feature, spatial))
        .Attr("T", DT_FLOAT)
        .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, window))
        .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, strides))
        .Attr("padding", padding)
        .Attr("data_format", "NDHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
  // TODO(phawkins): test NCHW format (not supported by CPU)
}
1158 
// Gradient of 2-D average pooling (NHWC). Feeds the original input shape as an
// int32 tensor, plus a random float gradient with the pooled output's shape.
TEST_F(OpTest, AvgPoolGrad) {
  Repeatedly([this]() {
    const int batch = RandomDim(1);
    const int features = RandomDim(1);
    const WindowedSpatialDims window = ChooseWindowedSpatialDims(2);
    const std::vector<int32> orig_input_shape =
        AsInt32s(ImageDims(FORMAT_NHWC, batch, features, window.input_dims));
    const std::vector<int64> grad_shape =
        ImageDims(FORMAT_NHWC, batch, features, window.output_dims);
    const char* padding = (window.padding == SAME) ? "SAME" : "VALID";

    OpTestBuilder builder("AvgPoolGrad");
    builder.Input(test::AsTensor<int32>(orig_input_shape))
        .RandomInput(DT_FLOAT, grad_shape)
        .Attr("T", DT_FLOAT)
        .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, window.kernel_dims))
        .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, window.stride_dims))
        .Attr("padding", padding)
        .Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1178 
// Gradient of 3-D average pooling (NDHWC). Feeds the original input shape as
// an int32 tensor, plus a random float gradient with the pooled output's shape.
TEST_F(OpTest, AvgPool3DGrad) {
  Repeatedly([this]() {
    const int batch = RandomDim(1);
    const int features = RandomDim(1);
    const WindowedSpatialDims window = ChooseWindowedSpatialDims(3);
    const std::vector<int32> orig_input_shape =
        AsInt32s(ImageDims(FORMAT_NHWC, batch, features, window.input_dims));
    const std::vector<int64> grad_shape =
        ImageDims(FORMAT_NHWC, batch, features, window.output_dims);
    const char* padding = (window.padding == SAME) ? "SAME" : "VALID";

    OpTestBuilder builder("AvgPool3DGrad");
    builder.Input(test::AsTensor<int32>(orig_input_shape))
        .RandomInput(DT_FLOAT, grad_shape)
        .Attr("T", DT_FLOAT)
        .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, window.kernel_dims))
        .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, window.stride_dims))
        .Attr("padding", padding)
        .Attr("data_format", "NDHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1198 
// Batched matrix multiply.
// Operand shapes are derived from a random output shape by replacing the
// contracted dimension with a random inner size. The last two dims of each
// operand are swapped when the corresponding adj_* flag is set.
TEST_F(OpTest, BatchMatMul) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    const std::vector<int64> out_shape = RandomDims(2, 5, 0, 7);
    const int64 rank = out_shape.size();
    const int64 contracted = RandomDim();

    std::vector<int64> lhs_shape = out_shape;
    std::vector<int64> rhs_shape = out_shape;
    lhs_shape[rank - 1] = contracted;
    rhs_shape[rank - 2] = contracted;

    std::bernoulli_distribution coin;
    const bool adj_x = coin(generator());
    const bool adj_y = coin(generator());
    if (adj_x) std::swap(lhs_shape[rank - 1], lhs_shape[rank - 2]);
    if (adj_y) std::swap(rhs_shape[rank - 1], rhs_shape[rank - 2]);

    OpTestBuilder builder("BatchMatMul");
    builder.RandomInput(dtype, lhs_shape)
        .RandomInput(dtype, rhs_shape)
        .Attr("T", dtype)
        .Attr("adj_x", adj_x)
        .Attr("adj_y", adj_y);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1227 
// BatchToSpace with two spatial dims. The input layout is
// [batch, spatial0, spatial1, depth], where the batch extent is made a multiple
// of block_size^2. Crop amounts are drawn from [0, 4] and may be illegal.
TEST_F(OpTest, BatchToSpace) {
  Repeatedly([this]() {
    constexpr int kSpatialRank = 2;
    const std::vector<int64> spatial =
        RandomDims(kSpatialRank, kSpatialRank, 0, 5);
    const int64 block_size = RandomDim(2, 5);

    std::vector<int64> input_dims;
    input_dims.reserve(kSpatialRank + 2);
    int64 batch = RandomDim();
    for (int s = 0; s < kSpatialRank; ++s) batch *= block_size;
    input_dims.push_back(batch);
    input_dims.insert(input_dims.end(), spatial.begin(), spatial.end());
    input_dims.push_back(RandomDim());

    // Chooses crop values; does not always choose legal values.
    std::uniform_int_distribution<int> crop_dist(0, 4);
    std::vector<int64> crop_vals;
    for (int s = 0; s < kSpatialRank; ++s) {
      crop_vals.push_back(crop_dist(generator()));
      crop_vals.push_back(crop_dist(generator()));
    }
    Tensor crops;
    CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
                         TensorShape({kSpatialRank, 2})));

    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("BatchToSpace");
    builder.RandomInput(dtype, input_dims)
        .Input(crops)
        .Attr("T", dtype)
        .Attr("block_size", block_size);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1262 
// BatchToSpaceND with one to three block dims. The input layout is
// [batch, spatial..., trailing...], where the batch extent is made a multiple
// of the product of the block sizes. Crop amounts are drawn from [0, 3] and
// may be illegal.
TEST_F(OpTest, BatchToSpaceND) {
  Repeatedly([this]() {
    const std::vector<int64> block_shape = RandomDims(1, 3, 0, 5);
    const int block_rank = block_shape.size();
    const std::vector<int64> trailing_dims = RandomDims(0, 3);
    const std::vector<int64> spatial_extents =
        RandomDims(block_shape.size(), block_shape.size(), 0, 4);

    int64 batch = RandomDim();
    for (const int64 block : block_shape) batch *= block;
    std::vector<int64> input_dims;
    input_dims.reserve(1 + block_rank + trailing_dims.size());
    input_dims.push_back(batch);
    input_dims.insert(input_dims.end(), spatial_extents.begin(),
                      spatial_extents.end());
    input_dims.insert(input_dims.end(), trailing_dims.begin(),
                      trailing_dims.end());

    // Chooses crop values; does not always choose legal values.
    std::uniform_int_distribution<int> crop_dist(0, 3);
    std::vector<int64> crop_vals;
    for (int b = 0; b < block_rank; ++b) {
      crop_vals.push_back(crop_dist(generator()));
      crop_vals.push_back(crop_dist(generator()));
    }
    Tensor crops;
    CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals),
                         TensorShape({block_rank, 2})));

    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const std::vector<int32> block_shape32(block_shape.begin(),
                                           block_shape.end());
    OpTestBuilder builder("BatchToSpaceND");
    builder.RandomInput(dtype, input_dims)
        .Input(test::AsTensor<int32>(block_shape32))
        .Input(crops)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1302 
// BiasAdd on a random tensor of rank >= 2 plus a rank-1 bias whose length
// equals the input's last dimension.
TEST_F(OpTest, BiasAdd) {
  Repeatedly([this]() {
    auto x_dims = RandomDims(2, kDefaultMaxRank);
    // Spell out the vector type: `auto y = {...}` would deduce a
    // std::initializer_list, which is easy to misuse.
    std::vector<int64> y_dims = {x_dims.back()};
    // TODO(phawkins): test both data formats.
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAdd")
                                             .RandomInput(type, x_dims)
                                             .RandomInput(type, y_dims)
                                             .Attr("T", type));
  });
}
1315 
// BiasAddGrad on a single random tensor of a randomly chosen dtype.
TEST_F(OpTest, BiasAddGrad) {
  Repeatedly([this]() {
    // TODO(phawkins): test both data formats.
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("BiasAddGrad");
    builder.RandomInput(dtype).Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1324 
// BiasAddV1 on a random tensor of rank >= 2 plus a rank-1 bias whose length
// equals the input's last dimension.
TEST_F(OpTest, BiasAddV1) {
  Repeatedly([this]() {
    auto x_dims = RandomDims(2, kDefaultMaxRank);
    // Spell out the vector type: `auto y = {...}` would deduce a
    // std::initializer_list, which is easy to misuse.
    std::vector<int64> y_dims = {x_dims.back()};
    auto type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAddV1")
                                             .RandomInput(type, x_dims)
                                             .RandomInput(type, y_dims)
                                             .Attr("T", type));
  });
}
1336 
// BitwiseAnd on two random int32 tensors with shapes from BroadcastableDims().
TEST_F(OpTest, BitwiseAnd) {
  Repeatedly([this]() {
    const DataType dtype = DT_INT32;
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("BitwiseAnd");
    builder.RandomInput(dtype, shapes.first)
        .RandomInput(dtype, shapes.second)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1347 
// BitwiseOr on two random int32 tensors with shapes from BroadcastableDims().
TEST_F(OpTest, BitwiseOr) {
  Repeatedly([this]() {
    const DataType dtype = DT_INT32;
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("BitwiseOr");
    builder.RandomInput(dtype, shapes.first)
        .RandomInput(dtype, shapes.second)
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1358 
// BroadcastArgs on two shape vectors taken from BroadcastableDims().
TEST_F(OpTest, BroadcastArgs) {
  Repeatedly([this]() {
    // TODO(phawkins): only int32 seems to be implemented in Tensorflow.
    // auto type = Choose<DataType>({DT_INT32, DT_INT64});
    const DataType dtype = DT_INT32;
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("BroadcastArgs");
    builder.Input(AsIntTensor(dtype, shapes.first))
        .Input(AsIntTensor(dtype, shapes.second))
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1372 
// BroadcastGradientArgs on two shape vectors taken from BroadcastableDims().
TEST_F(OpTest, BroadcastGradientArgs) {
  Repeatedly([this]() {
    // TODO(phawkins): only int32 seems to be implemented in Tensorflow.
    // auto type = Choose<DataType>({DT_INT32, DT_INT64});
    const DataType dtype = DT_INT32;
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("BroadcastGradientArgs");
    builder.Input(AsIntTensor(dtype, shapes.first))
        .Input(AsIntTensor(dtype, shapes.second))
        .Attr("T", dtype);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1386 
// Cast between randomly chosen source and destination dtypes.
TEST_F(OpTest, Cast) {
  Repeatedly([this]() {
    const DataType from =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
    const DataType to =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64});
    OpTestBuilder builder("Cast");
    builder.RandomInput(from).Attr("SrcT", from).Attr("DstT", to);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1398 
// Cast from float to bfloat16 with Truncate=true. The single-element Choose
// calls are kept so the random-number stream is consumed as before.
TEST_F(OpTest, CastBF16) {
  Repeatedly([this]() {
    const DataType from = Choose<DataType>({DT_FLOAT});
    const DataType to = Choose<DataType>({DT_BFLOAT16});
    OpTestBuilder builder("Cast");
    builder.RandomInput(from)
        .Attr("SrcT", from)
        .Attr("DstT", to)
        .Attr("Truncate", true);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1411 
// Ceil on a single random float tensor.
TEST_F(OpTest, Ceil) {
  Repeatedly([this]() {
    OpTestBuilder builder("Ceil");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1418 
// Complex built from two random float tensors (real and imaginary parts) whose
// shapes come from BroadcastableDims().
TEST_F(OpTest, Complex) {
  Repeatedly([this]() {
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("Complex");
    builder.RandomInput(DT_FLOAT, shapes.first)
        .RandomInput(DT_FLOAT, shapes.second)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1428 
// Concat of 2-5 random tensors of one dtype. The tensors share every
// dimension except the randomly chosen concat axis, whose extent is re-drawn
// per input.
TEST_F(OpTest, Concat) {
  Repeatedly([this]() {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const int num_inputs = std::uniform_int_distribution<int>(2, 5)(generator());
    const std::vector<int64> base_shape = RandomDims(1);
    const int axis = std::uniform_int_distribution<int32>(
        0, base_shape.size() - 1)(generator());

    OpTestBuilder builder("Concat");
    builder.Input(test::AsScalar<int32>(axis))
        .Attr("T", dtype)
        .Attr("N", num_inputs);
    for (int k = 0; k < num_inputs; ++k) {
      std::vector<int64> piece_shape = base_shape;
      piece_shape[axis] = RandomDim();
      builder.RandomInput(dtype, piece_shape);
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1450 
// ConcatOffset with 2-5 int32 shape vectors that differ only along the
// randomly chosen concat axis.
TEST_F(OpTest, ConcatOffset) {
  Repeatedly([this]() {
    const int num_inputs = std::uniform_int_distribution<int>(2, 5)(generator());
    const std::vector<int64> base_shape = RandomDims(1);
    const int axis = std::uniform_int_distribution<int32>(
        0, base_shape.size() - 1)(generator());

    OpTestBuilder builder("ConcatOffset");
    builder.Input(test::AsScalar<int32>(axis)).Attr("N", num_inputs);
    for (int k = 0; k < num_inputs; ++k) {
      std::vector<int32> piece_shape(base_shape.begin(), base_shape.end());
      piece_shape[axis] = RandomDim();
      builder.Input(test::AsTensor<int32>(piece_shape));
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1470 
// Conj of a random complex64 tensor.
TEST_F(OpTest, Conj) {
  Repeatedly([this]() {
    OpTestBuilder builder("Conj");
    builder.RandomInput(DT_COMPLEX64);
    builder.Attr("T", DT_COMPLEX64);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1478 
// FFT over a random complex64 tensor of rank 1..kDefaultMaxRank.
TEST_F(OpTest, FFT) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims(1, kDefaultMaxRank);
    OpTestBuilder builder("FFT");
    builder.RandomInput(DT_COMPLEX64, shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1486 
// FFT2D over a random complex64 tensor of rank 2..kDefaultMaxRank.
TEST_F(OpTest, FFT2D) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims(2, kDefaultMaxRank);
    OpTestBuilder builder("FFT2D");
    builder.RandomInput(DT_COMPLEX64, shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1494 
// FFT3D over a random complex64 tensor of rank 3..kDefaultMaxRank.
TEST_F(OpTest, FFT3D) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims(3, kDefaultMaxRank);
    OpTestBuilder builder("FFT3D");
    builder.RandomInput(DT_COMPLEX64, shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1502 
// IFFT over a random complex64 tensor of rank 1..kDefaultMaxRank.
TEST_F(OpTest, IFFT) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims(1, kDefaultMaxRank);
    OpTestBuilder builder("IFFT");
    builder.RandomInput(DT_COMPLEX64, shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1510 
// IFFT2D over a random complex64 tensor of rank 2..kDefaultMaxRank.
TEST_F(OpTest, IFFT2D) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims(2, kDefaultMaxRank);
    OpTestBuilder builder("IFFT2D");
    builder.RandomInput(DT_COMPLEX64, shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1518 
// IFFT3D over a random complex64 tensor of rank 3..kDefaultMaxRank.
TEST_F(OpTest, IFFT3D) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims(3, kDefaultMaxRank);
    OpTestBuilder builder("IFFT3D");
    builder.RandomInput(DT_COMPLEX64, shape);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1526 
// RFFT of a random float tensor. fft_length equals the innermost extent.
TEST_F(OpTest, RFFT) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims(1, kDefaultMaxRank, 3);
    const int64 innermost = shape.back();
    const Tensor fft_length = test::AsTensor<int32>(AsInt32s({innermost}));
    OpTestBuilder builder("RFFT");
    builder.RandomInput(DT_FLOAT, shape).Input(fft_length);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1535 
// RFFT2D of a random float tensor. fft_length equals the two innermost
// extents.
TEST_F(OpTest, RFFT2D) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims(2, kDefaultMaxRank, 3);
    const size_t rank = shape.size();
    const Tensor fft_length = test::AsTensor<int32>(
        AsInt32s({shape[rank - 2], shape[rank - 1]}));
    OpTestBuilder builder("RFFT2D");
    builder.RandomInput(DT_FLOAT, shape).Input(fft_length);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1545 
// Compares TF and XLA results for RFFT3D on a random float input. The
// fft_length operand is the three innermost input dimensions.
TEST_F(OpTest, RFFT3D) {
  Repeatedly([this]() {
    const std::vector<int64> dims = RandomDims(3, kDefaultMaxRank, 3);
    const std::vector<int64> inner_dims(dims.end() - 3, dims.end());
    Tensor fft_length = test::AsTensor<int32>(AsInt32s(inner_dims));
    OpTestBuilder builder("RFFT3D");
    builder.RandomInput(DT_FLOAT, dims);
    builder.Input(fft_length);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1555 
// Compares TF and XLA results for IRFFT. A real signal length is drawn first.
// The complex input's innermost dimension is then set to length / 2 + 1, the
// number of non-redundant frequency bins.
TEST_F(OpTest, IRFFT) {
  Repeatedly([this]() {
    std::vector<int64> dims = RandomDims(1, kDefaultMaxRank, 3);
    int64& innermost = dims.back();
    const int64 signal_length = innermost;
    innermost = signal_length / 2 + 1;
    Tensor fft_length = test::AsTensor<int32>(AsInt32s({signal_length}));
    OpTestBuilder builder("IRFFT");
    builder.RandomInput(DT_COMPLEX64, dims);
    builder.Input(fft_length);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1567 
// Compares TF and XLA results for IRFFT2D. The two innermost dimensions form
// fft_length. Only the innermost one is shrunk to its half spectrum (n / 2 + 1).
TEST_F(OpTest, IRFFT2D) {
  Repeatedly([this]() {
    std::vector<int64> dims = RandomDims(2, kDefaultMaxRank, 3);
    const std::vector<int64> signal_dims(dims.end() - 2, dims.end());
    dims.back() = dims.back() / 2 + 1;
    Tensor fft_length = test::AsTensor<int32>(AsInt32s(signal_dims));
    OpTestBuilder builder("IRFFT2D");
    builder.RandomInput(DT_COMPLEX64, dims);
    builder.Input(fft_length);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1580 
// Compares TF and XLA results for IRFFT3D. The three innermost dimensions form
// fft_length. Only the innermost one is shrunk to its half spectrum (n / 2 + 1).
TEST_F(OpTest, IRFFT3D) {
  Repeatedly([this]() {
    std::vector<int64> dims = RandomDims(3, kDefaultMaxRank, 3);
    const std::vector<int64> signal_dims(dims.end() - 3, dims.end());
    dims.back() = dims.back() / 2 + 1;
    Tensor fft_length = test::AsTensor<int32>(AsInt32s(signal_dims));
    OpTestBuilder builder("IRFFT3D");
    builder.RandomInput(DT_COMPLEX64, dims);
    builder.Input(fft_length);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1593 
// Compares TF and XLA results for an NHWC float Conv2D. The filter shape is
// [kernel_h, kernel_w, in_channels, out_channels].
TEST_F(OpTest, Conv2D) {
  Repeatedly([this]() {
    const WindowedSpatialDims spatial = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> channel_dist(1, 5);
    const int in_channels = channel_dist(generator());
    const int out_channels = channel_dist(generator());
    const int64 batch = RandomDim();

    const std::vector<int64> input_shape =
        ImageDims(FORMAT_NHWC, batch, in_channels, spatial.input_dims);
    const std::vector<int64> filter_shape = {
        spatial.kernel_dims[0], spatial.kernel_dims[1], in_channels,
        out_channels};
    const DataType type = DT_FLOAT;
    const char* const padding = spatial.padding == SAME ? "SAME" : "VALID";

    OpTestBuilder builder("Conv2D");
    builder.RandomInput(type, input_shape)
        .RandomInput(type, filter_shape)
        .Attr("T", type)
        .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, spatial.stride_dims))
        .Attr("padding", padding)
        .Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1619 
// Compares TF and XLA results for an NHWC float Conv2DBackpropFilter. The
// filter shape is passed as a constant int32 tensor operand.
TEST_F(OpTest, Conv2DBackpropFilter) {
  Repeatedly([this]() {
    const WindowedSpatialDims spatial = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> channel_dist(1, 5);
    const int in_channels = channel_dist(generator());
    const int out_channels = channel_dist(generator());
    const int32 batch = RandomDim();

    const std::vector<int64> input_shape =
        ImageDims(FORMAT_NHWC, batch, in_channels, spatial.input_dims);
    const std::vector<int64> out_backprop_shape =
        ImageDims(FORMAT_NHWC, batch, out_channels, spatial.output_dims);
    Tensor filter_sizes = test::AsTensor<int32>(
        AsInt32s({spatial.kernel_dims[0], spatial.kernel_dims[1], in_channels,
                  out_channels}));
    const DataType type = DT_FLOAT;
    const char* const padding = spatial.padding == SAME ? "SAME" : "VALID";

    OpTestBuilder builder("Conv2DBackpropFilter");
    builder.RandomInput(type, input_shape)
        .Input(filter_sizes)
        .RandomInput(type, out_backprop_shape)
        .Attr("T", type)
        .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, spatial.stride_dims))
        .Attr("padding", padding)
        .Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1645 
// Compares TF and XLA results for an NHWC float Conv2DBackpropInput. The input
// shape is passed as a constant int32 tensor operand.
TEST_F(OpTest, Conv2DBackpropInput) {
  Repeatedly([this]() {
    const WindowedSpatialDims spatial = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> channel_dist(1, 5);
    const int in_channels = channel_dist(generator());
    const int out_channels = channel_dist(generator());
    const int32 batch = RandomDim();

    Tensor input_sizes = test::AsTensor<int32>(AsInt32s(
        ImageDims(FORMAT_NHWC, batch, in_channels, spatial.input_dims)));
    const std::vector<int64> out_backprop_shape =
        ImageDims(FORMAT_NHWC, batch, out_channels, spatial.output_dims);
    const std::vector<int64> filter_shape = {
        spatial.kernel_dims[0], spatial.kernel_dims[1], in_channels,
        out_channels};
    const DataType type = DT_FLOAT;
    const char* const padding = spatial.padding == SAME ? "SAME" : "VALID";

    OpTestBuilder builder("Conv2DBackpropInput");
    builder.Input(input_sizes)
        .RandomInput(type, filter_shape)
        .RandomInput(type, out_backprop_shape)
        .Attr("T", type)
        .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, spatial.stride_dims))
        .Attr("padding", padding)
        .Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1671 
// Compares TF and XLA results for a float Conv3D. The input is
// [batch, d0, d1, d2, in_channels] and the filter is
// [k0, k1, k2, in_channels, out_channels].
TEST_F(OpTest, Conv3D) {
  Repeatedly([this]() {
    const WindowedSpatialDims spatial = ChooseWindowedSpatialDims(3);
    std::uniform_int_distribution<int> channel_dist(1, 5);
    const int in_channels = channel_dist(generator());
    const int out_channels = channel_dist(generator());
    const int64 batch = RandomDim();

    const std::vector<int64> input_shape = {
        batch, spatial.input_dims[0], spatial.input_dims[1],
        spatial.input_dims[2], in_channels};
    const std::vector<int64> filter_shape = {
        spatial.kernel_dims[0], spatial.kernel_dims[1], spatial.kernel_dims[2],
        in_channels, out_channels};
    const DataType type = DT_FLOAT;
    const char* const padding = spatial.padding == SAME ? "SAME" : "VALID";

    OpTestBuilder builder("Conv3D");
    builder.RandomInput(type, input_shape)
        .RandomInput(type, filter_shape)
        .Attr("T", type)
        .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, spatial.stride_dims))
        .Attr("padding", padding);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1693 
// Compares TF and XLA results for a float Conv3DBackpropFilterV2. The filter
// shape is passed as a constant int32 tensor operand.
TEST_F(OpTest, Conv3DBackpropFilter) {
  Repeatedly([this]() {
    const WindowedSpatialDims spatial = ChooseWindowedSpatialDims(3);
    std::uniform_int_distribution<int> channel_dist(1, 5);
    const int in_channels = channel_dist(generator());
    const int out_channels = channel_dist(generator());
    const int32 batch = RandomDim(1);

    const std::vector<int64> input_shape =
        ImageDims(FORMAT_NHWC, batch, in_channels, spatial.input_dims);
    const std::vector<int64> out_backprop_shape =
        ImageDims(FORMAT_NHWC, batch, out_channels, spatial.output_dims);
    Tensor filter_sizes = test::AsTensor<int32>(
        AsInt32s({spatial.kernel_dims[0], spatial.kernel_dims[1],
                  spatial.kernel_dims[2], in_channels, out_channels}));
    const DataType type = DT_FLOAT;
    const char* const padding = spatial.padding == SAME ? "SAME" : "VALID";

    OpTestBuilder builder("Conv3DBackpropFilterV2");
    builder.RandomInput(type, input_shape)
        .Input(filter_sizes)
        .RandomInput(type, out_backprop_shape)
        .Attr("T", type)
        .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, spatial.stride_dims))
        .Attr("padding", padding);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1719 
// Compares TF and XLA results for Conv3DBackpropInputV2. The input shape is
// passed as a constant int32 tensor operand.
//
// The element type is fixed to DT_FLOAT, matching Conv3D and
// Conv3DBackpropFilter above. Conv3DBackpropInputV2's "T" attr only admits
// half/bfloat16/float/double, so drawing DT_COMPLEX64 (as before) produced an
// invalid op.
TEST_F(OpTest, Conv3DBackpropInput) {
  Repeatedly([this]() {
    WindowedSpatialDims d = ChooseWindowedSpatialDims(3);
    std::uniform_int_distribution<int> random_int(1, 5);
    int features_in = random_int(generator());
    int features_out = random_int(generator());
    int32 batch = RandomDim(1);
    Tensor in_shape = test::AsTensor<int32>(
        AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)));
    std::vector<int64> backprop =
        ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims);
    std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
                                 d.kernel_dims[2], features_in, features_out};
    DataType type = DT_FLOAT;
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Conv3DBackpropInputV2")
            .Input(in_shape)
            .RandomInput(type, kernel)
            .RandomInput(type, backprop)
            .Attr("T", type)
            .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
            .Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
  });
}
1744 
// Compares TF and XLA results for Cos on float or complex64 inputs.
TEST_F(OpTest, Cos) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Cos");
    builder.RandomInput(type);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1752 
// Compares TF and XLA results for Cosh on float or complex64 inputs.
TEST_F(OpTest, Cosh) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Cosh");
    builder.RandomInput(type);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1760 
// Compares TF and XLA results for DepthToSpace on a rank-4 input. Axes 1 and 2
// are divided by block_size (rounding up). Axis 3 is multiplied by
// block_size^2 so that it is always divisible as the op requires.
TEST_F(OpTest, DepthToSpace) {
  Repeatedly([this]() {
    const int64 block = RandomDim(2, 5);
    std::vector<int64> input_dims = RandomDims(4, 4);
    for (int axis : {1, 2}) {
      input_dims[axis] = (input_dims[axis] + block - 1) / block;
    }
    input_dims[3] *= block * block;
    const DataType type = Choose<DataType>(kAllXlaTypes);

    OpTestBuilder builder("DepthToSpace");
    builder.RandomInput(type, input_dims);
    builder.Attr("T", type);
    builder.Attr("block_size", block);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1775 
// Compares TF and XLA results for a float DepthwiseConv2dNative. The filter is
// [kernel_h, kernel_w, in_channels, depth_multiplier].
TEST_F(OpTest, DepthwiseConv2DNative) {
  Repeatedly([this]() {
    const WindowedSpatialDims spatial = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> small_int(1, 5);
    const int in_channels = small_int(generator());
    const int multiplier = small_int(generator());
    const int64 batch = RandomDim();

    const std::vector<int64> input_shape = {batch, spatial.input_dims[0],
                                            spatial.input_dims[1], in_channels};
    const std::vector<int64> filter_shape = {
        spatial.kernel_dims[0], spatial.kernel_dims[1], in_channels,
        multiplier};
    std::vector<int64> strides =
        ImageDims(FORMAT_NHWC, 1, 1, spatial.stride_dims);
    // The current implementation only supports equal H and W strides.
    strides[2] = strides[1];

    OpTestBuilder builder("DepthwiseConv2dNative");
    builder.RandomInput(DT_FLOAT, input_shape)
        .RandomInput(DT_FLOAT, filter_shape)
        .Attr("T", DT_FLOAT)
        .Attr("strides", strides)
        .Attr("padding", spatial.padding == SAME ? "SAME" : "VALID");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1798 
// Compares TF and XLA results for DepthwiseConv2dNativeBackpropFilter (NHWC,
// float). The output-gradient depth is in_channels * depth_multiplier.
TEST_F(OpTest, DepthwiseConv2DBackpropFilter) {
  Repeatedly([this]() {
    const WindowedSpatialDims spatial = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> small_int(1, 5);
    const int in_channels = small_int(generator());
    const int multiplier = small_int(generator());
    const int32 batch = RandomDim();

    const std::vector<int64> input_shape =
        ImageDims(FORMAT_NHWC, batch, in_channels, spatial.input_dims);
    const std::vector<int64> out_backprop_shape = ImageDims(
        FORMAT_NHWC, batch, in_channels * multiplier, spatial.output_dims);
    Tensor filter_sizes = test::AsTensor<int32>(AsInt32s(
        {spatial.kernel_dims[0], spatial.kernel_dims[1], in_channels,
         multiplier}));
    std::vector<int64> strides =
        ImageDims(FORMAT_NHWC, 1, 1, spatial.stride_dims);
    // The current implementation only supports equal H and W strides.
    strides[2] = strides[1];

    OpTestBuilder builder("DepthwiseConv2dNativeBackpropFilter");
    builder.RandomInput(DT_FLOAT, input_shape)
        .Input(filter_sizes)
        .RandomInput(DT_FLOAT, out_backprop_shape)
        .Attr("T", DT_FLOAT)
        .Attr("strides", strides)
        .Attr("padding", spatial.padding == SAME ? "SAME" : "VALID")
        .Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1825 
// Compares TF and XLA results for DepthwiseConv2dNativeBackpropInput (NHWC,
// float). The input shape is passed as a constant int32 tensor operand.
TEST_F(OpTest, DepthwiseConv2DBackpropInput) {
  Repeatedly([this]() {
    const WindowedSpatialDims spatial = ChooseWindowedSpatialDims(2);
    std::uniform_int_distribution<int> small_int(1, 5);
    const int in_channels = small_int(generator());
    const int multiplier = small_int(generator());
    const int32 batch = RandomDim();

    Tensor input_sizes = test::AsTensor<int32>(AsInt32s(
        ImageDims(FORMAT_NHWC, batch, in_channels, spatial.input_dims)));
    const std::vector<int64> out_backprop_shape = ImageDims(
        FORMAT_NHWC, batch, in_channels * multiplier, spatial.output_dims);
    const std::vector<int64> filter_shape = {
        spatial.kernel_dims[0], spatial.kernel_dims[1], in_channels,
        multiplier};
    std::vector<int64> strides =
        ImageDims(FORMAT_NHWC, 1, 1, spatial.stride_dims);
    // The current implementation only supports equal H and W strides.
    strides[2] = strides[1];

    OpTestBuilder builder("DepthwiseConv2dNativeBackpropInput");
    builder.Input(input_sizes)
        .RandomInput(DT_FLOAT, filter_shape)
        .RandomInput(DT_FLOAT, out_backprop_shape)
        .Attr("T", DT_FLOAT)
        .Attr("strides", strides)
        .Attr("padding", spatial.padding == SAME ? "SAME" : "VALID")
        .Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1852 
// Compares TF and XLA results for Diag. The output has num_elements^2 entries,
// so input shapes are resampled until that square fits within
// tf_xla_max_tensor_size.
TEST_F(OpTest, Diag) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64> dims;
    for (;;) {
      dims = RandomDims(1);
      const int64 count = TensorShape(dims).num_elements();
      if (count * count <= tf_xla_max_tensor_size) break;
    }
    OpTestBuilder builder("Diag");
    builder.RandomInput(type, dims);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1867 
// Compares TF and XLA results for DiagPart. The input shape is a random shape
// concatenated with itself, i.e. [d0..dk, d0..dk].
TEST_F(OpTest, DiagPart) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64> half = RandomDims(1, 3);
    std::vector<int64> doubled_dims = half;
    doubled_dims.insert(doubled_dims.end(), half.begin(), half.end());
    OpTestBuilder builder("DiagPart");
    builder.RandomInput(type, doubled_dims);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1881 
// Compares TF and XLA results for broadcasting Div on int32, float or
// complex64 operands.
TEST_F(OpTest, Div) {
  Repeatedly([this]() {
    const DataType type =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("Div");
    builder.RandomInput(type, shapes.first);
    builder.RandomInput(type, shapes.second);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1892 
// Compares TF and XLA results for DynamicStitch with 2..5 partitions. The
// index tensors together hold a random permutation of [0, total), so every
// output row is written exactly once.
TEST_F(OpTest, DynamicStitch) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    const int num_partitions =
        std::uniform_int_distribution<int>(2, 5)(generator());
    OpTestBuilder builder("DynamicStitch");
    builder.Attr("T", type);
    builder.Attr("N", num_partitions);

    // TODO(phawkins): the XLA implementation of DynamicStitch does not
    // accept an empty set of indices, so resample until at least one index
    // exists.
    std::vector<std::vector<int64>> index_shapes;
    int total = 0;
    while (total == 0) {
      index_shapes.clear();
      for (int p = 0; p < num_partitions; ++p) {
        index_shapes.push_back(RandomDims(0, 3, 0, 5));
        total += TensorShape(index_shapes.back()).num_elements();
      }
    }

    // TODO(phawkins): DynamicStitch's documentation does not require the
    // indices to cover the whole output, but the XLA implementation does, and
    // the native TF kernel leaves uncovered rows undefined, so only the
    // full-cover case is testable.
    std::vector<int32> permutation(total);
    std::iota(permutation.begin(), permutation.end(), 0);
    std::shuffle(permutation.begin(), permutation.end(), generator());

    const absl::Span<const int32> all_indices(permutation);
    int offset = 0;
    for (const std::vector<int64>& shape_dims : index_shapes) {
      TensorShape shape(shape_dims);
      Tensor slice = test::AsTensor<int32>(
          all_indices.subspan(offset, shape.num_elements()), shape);
      builder.Input(slice);
      offset += slice.NumElements();
    }

    // Each data tensor is its index shape followed by shared trailing dims.
    const std::vector<int64> trailing_dims = RandomDims(0, 3, 0, 5);
    for (const std::vector<int64>& shape_dims : index_shapes) {
      std::vector<int64> data_dims = shape_dims;
      data_dims.insert(data_dims.end(), trailing_dims.begin(),
                       trailing_dims.end());
      builder.RandomInput(type, data_dims);
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1944 
// Compares TF and XLA results for Elu on a random float input.
TEST_F(OpTest, Elu) {
  Repeatedly([this]() {
    OpTestBuilder builder("Elu");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1951 
// Compares TF and XLA results for EluGrad. Both float operands share one shape.
TEST_F(OpTest, EluGrad) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims();
    OpTestBuilder builder("EluGrad");
    builder.RandomInput(DT_FLOAT, shape);
    builder.RandomInput(DT_FLOAT, shape);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1961 
// Compares TF and XLA results for Selu on a random float input.
TEST_F(OpTest, Selu) {
  Repeatedly([this]() {
    OpTestBuilder builder("Selu");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1968 
// Compares TF and XLA results for SeluGrad. Both float operands share one
// shape.
TEST_F(OpTest, SeluGrad) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims();
    OpTestBuilder builder("SeluGrad");
    builder.RandomInput(DT_FLOAT, shape);
    builder.RandomInput(DT_FLOAT, shape);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1978 
// Compares TF and XLA results for broadcasting Equal on int32, float or
// complex64 operands.
TEST_F(OpTest, Equal) {
  Repeatedly([this]() {
    const DataType type =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("Equal");
    builder.RandomInput(type, shapes.first);
    builder.RandomInput(type, shapes.second);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1989 
// Compares TF and XLA results for Exp on float or complex64 inputs.
TEST_F(OpTest, Exp) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Exp");
    builder.RandomInput(type);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
1997 
// Compares TF and XLA results for Expm1 on float or complex64 inputs.
TEST_F(OpTest, Expm1) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Expm1");
    builder.RandomInput(type);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2005 
// Compares TF and XLA results for ExpandDims with a random insertion axis.
//
// ExpandDims accepts an axis in [-rank - 1, rank]. The bounds are computed in
// signed int32 arithmetic. The previous `-1 - in_dims.size()` was evaluated in
// unsigned size_t, wrapped around, and then narrowed to int32.
TEST_F(OpTest, ExpandDims) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64> in_dims = RandomDims();
    const int32 rank = static_cast<int32>(in_dims.size());
    Tensor dim(DT_INT32, TensorShape());
    std::uniform_int_distribution<int32> d(-1 - rank, rank);
    dim.scalar<int32>()() = d(generator());
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ExpandDims")
                                             .RandomInput(type, in_dims)
                                             .Input(dim)
                                             .Attr("T", type));
  });
}
2019 
// Compares TF and XLA results for Fill. The shape is a constant int32 tensor
// and the fill value is a random scalar.
TEST_F(OpTest, Fill) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64> dims = RandomDims();
    const std::vector<int32> dims32(dims.begin(), dims.end());
    Tensor shape_tensor = test::AsTensor<int32>(dims32);
    OpTestBuilder builder("Fill");
    builder.Input(shape_tensor);
    builder.RandomInput(type, {});
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2032 
// Compares TF and XLA results for Floor on a random float input.
TEST_F(OpTest, Floor) {
  Repeatedly([this]() {
    OpTestBuilder builder("Floor");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2039 
// Compares TF and XLA results for broadcasting FloorDiv on int32 operands.
TEST_F(OpTest, FloorDiv) {
  Repeatedly([this]() {
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("FloorDiv");
    builder.RandomInput(DT_INT32, shapes.first);
    builder.RandomInput(DT_INT32, shapes.second);
    builder.Attr("T", DT_INT32);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2050 
// Compares TF and XLA results for broadcasting FloorMod on int32 or float
// operands.
TEST_F(OpTest, FloorMod) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("FloorMod");
    builder.RandomInput(type, shapes.first);
    builder.RandomInput(type, shapes.second);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2061 
// Compares TF and XLA results for broadcasting Greater on int32 or float
// operands.
TEST_F(OpTest, Greater) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("Greater");
    builder.RandomInput(type, shapes.first);
    builder.RandomInput(type, shapes.second);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2072 
// Compares TF and XLA results for broadcasting GreaterEqual on int32 or float
// operands.
TEST_F(OpTest, GreaterEqual) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("GreaterEqual");
    builder.RandomInput(type, shapes.first);
    builder.RandomInput(type, shapes.second);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2083 
// Compares TF and XLA results for Imag on a random complex64 input.
TEST_F(OpTest, Imag) {
  Repeatedly([this]() {
    OpTestBuilder builder("Imag");
    builder.RandomInput(DT_COMPLEX64);
    builder.Attr("T", DT_COMPLEX64);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2091 
// Compares TF and XLA results for bitwise Invert on a random int32 input.
TEST_F(OpTest, Invert) {
  Repeatedly([this]() {
    OpTestBuilder builder("Invert");
    builder.RandomInput(DT_INT32);
    builder.Attr("T", DT_INT32);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2099 
// Compares TF and XLA results for L2Loss on a random float input.
TEST_F(OpTest, L2Loss) {
  Repeatedly([this]() {
    OpTestBuilder builder("L2Loss");
    builder.RandomInput(DT_FLOAT);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2107 
// Compares TF and XLA results for broadcasting Less on int32 or float operands.
TEST_F(OpTest, Less) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("Less");
    builder.RandomInput(type, shapes.first);
    builder.RandomInput(type, shapes.second);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2118 
// Compares TF and XLA results for broadcasting LessEqual on int32 or float
// operands.
TEST_F(OpTest, LessEqual) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("LessEqual");
    builder.RandomInput(type, shapes.first);
    builder.RandomInput(type, shapes.second);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2129 
// Compares TF and XLA results for LinSpace. The start and stop values are
// random float scalars. num is a scalar in [-50, 50] of type int32 or int64,
// matching the Tidx attr.
TEST_F(OpTest, LinSpace) {
  Repeatedly([this]() {
    auto make_count = [](DataType index_type, int value) {
      return index_type == DT_INT32 ? test::AsScalar<int32>(value)
                                    : test::AsScalar<int64>(value);
    };
    std::uniform_int_distribution<int> count_dist(-50, 50);
    const DataType index_type = Choose<DataType>({DT_INT32, DT_INT64});
    OpTestBuilder builder("LinSpace");
    builder.RandomInput(DT_FLOAT, {})
        .RandomInput(DT_FLOAT, {})
        .Input(make_count(index_type, count_dist(generator())))
        .Attr("T", DT_FLOAT)
        .Attr("Tidx", index_type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2147 
// Compares TF and XLA results for Log on float or complex64 inputs.
TEST_F(OpTest, Log) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Log");
    builder.RandomInput(type);
    builder.Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2155 
// Compares TF and XLA results for Log1p on float or complex64 inputs.
//
// The "T" attr must match the input element type. It was previously
// hard-coded to DT_FLOAT, which is inconsistent with a DT_COMPLEX64 input.
TEST_F(OpTest, Log1p) {
  Repeatedly([this]() {
    auto type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Log1p").RandomInput(type).Attr("T", type));
  });
}
2163 
// Compares TF and XLA results for broadcasting LogicalAnd on bool operands.
TEST_F(OpTest, LogicalAnd) {
  Repeatedly([this]() {
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("LogicalAnd");
    builder.RandomInput(DT_BOOL, shapes.first);
    builder.RandomInput(DT_BOOL, shapes.second);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2173 
// Compares TF and XLA results for LogicalNot on a random bool input.
TEST_F(OpTest, LogicalNot) {
  Repeatedly([this]() {
    OpTestBuilder builder("LogicalNot");
    builder.RandomInput(DT_BOOL);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2180 
// Compares TF and XLA results for broadcasting LogicalOr on bool operands.
TEST_F(OpTest, LogicalOr) {
  Repeatedly([this]() {
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("LogicalOr");
    builder.RandomInput(DT_BOOL, shapes.first);
    builder.RandomInput(DT_BOOL, shapes.second);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2190 
// Compares TF and XLA results for LogSoftmax on a float input shaped by
// RandomDims(2, 2).
TEST_F(OpTest, LogSoftmax) {
  Repeatedly([this]() {
    const std::vector<int64> logits_shape = RandomDims(2, 2);
    OpTestBuilder builder("LogSoftmax");
    builder.RandomInput(DT_FLOAT, logits_shape);
    builder.Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2199 
// Compares TF and XLA results for LRN. depth_radius is drawn from
// [1, depth], and bias/alpha/beta are each drawn from [0.01, 2.0).
TEST_F(OpTest, LRN) {
  Repeatedly([this]() {
    // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed.
    const std::vector<int64> input_shape = RandomDims(4, 4, 1, 8);
    // CuDNN requires depth_radius > 0.
    std::uniform_int_distribution<int> radius_dist(1, input_shape[3]);
    std::uniform_real_distribution<float> coeff_dist(0.01, 2.0);
    const int depth_radius = radius_dist(generator());
    const float bias = coeff_dist(generator());
    const float alpha = coeff_dist(generator());
    const float beta = coeff_dist(generator());

    OpTestBuilder builder("LRN");
    builder.RandomInput(DT_FLOAT, input_shape)
        .Attr("T", DT_FLOAT)
        .Attr("depth_radius", depth_radius)
        .Attr("bias", bias)
        .Attr("alpha", alpha)
        .Attr("beta", beta);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2217 
// Compares TF and XLA results for LRNGrad. The three float operands share one
// shape. depth_radius is drawn from [1, depth], and bias/alpha/beta are each
// drawn from [0.0, 2.0).
TEST_F(OpTest, LRNGrad) {
  Repeatedly([this]() {
    // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed.
    const std::vector<int64> shape = RandomDims(4, 4, 1, 8);
    // CuDNN requires depth_radius > 0.
    std::uniform_int_distribution<int> radius_dist(1, shape[3]);
    std::uniform_real_distribution<float> coeff_dist(0.0, 2.0);
    const int depth_radius = radius_dist(generator());
    const float bias = coeff_dist(generator());
    const float alpha = coeff_dist(generator());
    const float beta = coeff_dist(generator());

    OpTestBuilder builder("LRNGrad");
    for (int i = 0; i < 3; ++i) {
      builder.RandomInput(DT_FLOAT, shape);
    }
    builder.Attr("T", DT_FLOAT)
        .Attr("depth_radius", depth_radius)
        .Attr("bias", bias)
        .Attr("alpha", alpha)
        .Attr("beta", beta);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2237 
// Compares TF and XLA results for MatMul computing an (m x k) * (k x n)
// product. Each operand is stored transposed when its transpose flag is
// drawn true.
TEST_F(OpTest, MatMul) {
  Repeatedly([this]() {
    const int64 m = RandomDim();
    const int64 k = RandomDim();
    const int64 n = RandomDim();

    std::bernoulli_distribution coin;
    const bool transpose_a = coin(generator());
    const bool transpose_b = coin(generator());
    const std::vector<int64> a_dims =
        transpose_a ? std::vector<int64>{k, m} : std::vector<int64>{m, k};
    const std::vector<int64> b_dims =
        transpose_b ? std::vector<int64>{n, k} : std::vector<int64>{k, n};

    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("MatMul");
    builder.RandomInput(type, a_dims)
        .RandomInput(type, b_dims)
        .Attr("T", type)
        .Attr("transpose_a", transpose_a)
        .Attr("transpose_b", transpose_b);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2266 
// MatrixDiag: builds diagonal matrices from a random tensor of rank >= 1.
TEST_F(OpTest, MatrixDiag) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const std::vector<int64> diagonal_dims = RandomDims(1);
    OpTestBuilder builder("MatrixDiag");
    builder.RandomInput(type, diagonal_dims).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2275 
// MatrixDiagPart: extracts the diagonal from a random tensor of rank >= 2.
TEST_F(OpTest, MatrixDiagPart) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const std::vector<int64> matrix_dims = RandomDims(2);
    OpTestBuilder builder("MatrixDiagPart");
    builder.RandomInput(type, matrix_dims).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2284 
// Max: max-reduction of a random int32/float tensor over the axes returned by
// RandomReductionIndices, with keep_dims chosen at random.
TEST_F(OpTest, Max) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
    const std::vector<int64> input_dims = RandomDims();
    const Tensor axes = RandomReductionIndices(input_dims.size());
    const bool keep_dims = Choose<bool>({false, true});
    OpTestBuilder builder("Max");
    builder.RandomInput(type, input_dims)
        .Input(axes)
        .Attr("T", type)
        .Attr("keep_dims", keep_dims);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2298 
// Maximum: element-wise max of two random int32/float tensors whose shapes
// come from BroadcastableDims().
TEST_F(OpTest, Maximum) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("Maximum");
    builder.RandomInput(type, shapes.first)
        .RandomInput(type, shapes.second)
        .Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2309 
// MaxPool: 2-D max pooling in NHWC layout on a random 4-D float input. The
// window is at most the spatial size in each direction, strides are drawn
// from [1, 5], and padding is SAME or VALID.
TEST_F(OpTest, MaxPool) {
  Repeatedly([this]() {
    std::uniform_int_distribution<int> stride_dist(1, 5);
    const std::vector<int64> input_dims = RandomDims(4, 4, 1);
    const int window_rows =
        std::uniform_int_distribution<int>(1, input_dims[1])(generator());
    const int window_cols =
        std::uniform_int_distribution<int>(1, input_dims[2])(generator());
    const int stride_rows = stride_dist(generator());
    const int stride_cols = stride_dist(generator());
    const string padding = Choose<string>({"SAME", "VALID"});

    OpTestBuilder builder("MaxPool");
    builder.RandomInput(DT_FLOAT, input_dims)
        .Attr("T", DT_FLOAT)
        .Attr("ksize", {1, window_rows, window_cols, 1})
        .Attr("strides", {1, stride_rows, stride_cols, 1})
        .Attr("padding", padding)
        .Attr("data_format", "NHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
  // TODO(phawkins): test NCHW format (not supported by CPU)
}
2333 
// MaxPool3D: 3-D max pooling in NDHWC layout. Each of the three spatial
// dimensions gets a window size in [1, size] and a stride in [1, 5]. Batch and
// feature dimensions use window and stride 1.
TEST_F(OpTest, MaxPool3D) {
  Repeatedly([this]() {
    std::uniform_int_distribution<int> stride_dist(1, 5);
    // dims[0..2] are the spatial sizes, dims[3] the batch, dims[4] the
    // feature count.
    const std::vector<int64> dims = RandomDims(5, 5, 1);

    std::vector<int64> spatial_dims;
    std::vector<int64> ksize = {1};
    std::vector<int64> strides = {1};
    for (int axis = 0; axis < 3; ++axis) {
      const int64 extent = dims[axis];
      ksize.push_back(
          std::uniform_int_distribution<int>(1, extent)(generator()));
      spatial_dims.push_back(extent);
      strides.push_back(stride_dist(generator()));
    }
    ksize.push_back(1);
    strides.push_back(1);
    const int64 batch = dims[3];
    const int64 feature = dims[4];
    const string padding = Choose<string>({"SAME", "VALID"});

    OpTestBuilder builder("MaxPool3D");
    builder
        .RandomInput(DT_FLOAT,
                     ImageDims(FORMAT_NHWC, batch, feature, spatial_dims))
        .Attr("T", DT_FLOAT)
        .Attr("ksize", ksize)
        .Attr("strides", strides)
        .Attr("padding", padding)
        .Attr("data_format", "NDHWC");
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
  // TODO(phawkins): test NCDHW format (not supported by CPU)
}
2366 
// Mean: mean-reduction of a random tensor over the axes returned by
// RandomReductionIndices, with keep_dims chosen at random.
TEST_F(OpTest, Mean) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    // TODO(phawkins): CPU and XLA differ output for reducing across a
    // size-0 dimension (nan vs 0). For now, require size >= 1.
    const std::vector<int64> input_dims = RandomDims(0, kDefaultMaxRank, 1);
    const Tensor axes = RandomReductionIndices(input_dims.size());
    const bool keep_dims = Choose<bool>({false, true});
    OpTestBuilder builder("Mean");
    builder.RandomInput(type, input_dims)
        .Input(axes)
        .Attr("T", type)
        .Attr("keep_dims", keep_dims);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2382 
// Min: min-reduction of a random int32/float tensor over the axes returned by
// RandomReductionIndices, with keep_dims chosen at random.
TEST_F(OpTest, Min) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
    const std::vector<int64> input_dims = RandomDims();
    const Tensor axes = RandomReductionIndices(input_dims.size());
    const bool keep_dims = Choose<bool>({false, true});
    OpTestBuilder builder("Min");
    builder.RandomInput(type, input_dims)
        .Input(axes)
        .Attr("T", type)
        .Attr("keep_dims", keep_dims);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2396 
// Minimum: element-wise min of two random int32/float tensors whose shapes
// come from BroadcastableDims().
TEST_F(OpTest, Minimum) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("Minimum");
    builder.RandomInput(type, shapes.first)
        .RandomInput(type, shapes.second)
        .Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2407 
// Mod: element-wise int32 remainder of two random tensors whose shapes come
// from BroadcastableDims().
TEST_F(OpTest, Mod) {
  Repeatedly([this]() {
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("Mod");
    builder.RandomInput(DT_INT32, shapes.first)
        .RandomInput(DT_INT32, shapes.second)
        .Attr("T", DT_INT32);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2417 
// Mul: element-wise product of two random tensors whose shapes come from
// BroadcastableDims().
TEST_F(OpTest, Mul) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("Mul");
    builder.RandomInput(type, shapes.first)
        .RandomInput(type, shapes.second)
        .Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2428 
// Neg: element-wise negation of a random tensor.
TEST_F(OpTest, Neg) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Neg");
    builder.RandomInput(type).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2436 
// NotEqual: element-wise inequality of two random tensors whose shapes come
// from BroadcastableDims().
TEST_F(OpTest, NotEqual) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("NotEqual");
    builder.RandomInput(type, shapes.first)
        .RandomInput(type, shapes.second)
        .Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2447 
// OneHot: one-hot encodes random int32 indices drawn from [-2*depth, 2*depth],
// so some fall outside [0, depth). The axis is drawn from
// [-rank - 5, rank + 5], and the on/off values are random scalars.
// NOTE(review): the axis range is wider than the valid axes, presumably to
// also exercise rejected configurations; confirm intent.
TEST_F(OpTest, OneHot) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);

    const std::vector<int64> indices_dims = RandomDims();
    const int rank = indices_dims.size();

    const int32 depth = RandomDim();

    Tensor indices(DT_INT32, TensorShape(indices_dims));
    std::uniform_int_distribution<int32> index_dist(-depth * 2, depth * 2);
    test::FillFn<int32>(&indices, [this, &index_dist](int) -> int32 {
      return index_dist(generator());
    });

    const int axis =
        std::uniform_int_distribution<int32>(-rank - 5, rank + 5)(generator());

    OpTestBuilder builder("OneHot");
    builder.Attr("T", type)
        .Attr("TI", DT_INT32)
        .Attr("axis", axis)
        .Input(indices)
        .Input(test::AsScalar<int32>(depth))
        .RandomInput(type, {})
        .RandomInput(type, {});
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2477 
// OnesLike: produces a ones tensor shaped like a random input.
TEST_F(OpTest, OnesLike) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("OnesLike");
    builder.RandomInput(type).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2485 
// Pack: stacks between 1 and 5 random tensors of one shape along an axis drawn
// from [-(rank + 1), rank].
TEST_F(OpTest, Pack) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    const int num_inputs =
        std::uniform_int_distribution<int>(1, 5)(generator());

    const std::vector<int64> dims = RandomDims();
    const int rank = dims.size();
    const int axis =
        std::uniform_int_distribution<int32>(-rank - 1, rank)(generator());

    OpTestBuilder builder("Pack");
    builder.Attr("T", type).Attr("N", num_inputs).Attr("axis", axis);
    for (int input = 0; input < num_inputs; ++input) {
      builder.RandomInput(type, dims);
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2506 
// Pad: pads a random tensor with random (low, high) amounts per dimension.
// Each dimension gets a total pad in [0, size], split randomly into low and
// high parts. The input dimension is shrunk by that total, so the padded
// output has the originally drawn shape.
TEST_F(OpTest, Pad) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64> input_dims = RandomDims();

    const DataType tpaddings = Choose<DataType>({DT_INT32, DT_INT64});
    std::vector<int64> paddings_vec;
    for (int64& dim : input_dims) {
      const int total_pad =
          std::uniform_int_distribution<int>(0, dim)(generator());
      const int low_pad =
          std::uniform_int_distribution<int>(0, total_pad)(generator());
      paddings_vec.push_back(low_pad);
      paddings_vec.push_back(total_pad - low_pad);
      dim -= total_pad;
    }

    // The paddings tensor has one (low, high) row per input dimension.
    const int64 rank = input_dims.size();
    Tensor paddings;
    CHECK(paddings.CopyFrom(AsIntTensor(tpaddings, paddings_vec),
                            TensorShape({rank, 2})));

    OpTestBuilder builder("Pad");
    builder.RandomInput(type, input_dims)
        .Input(paddings)
        .Attr("T", type)
        .Attr("Tpaddings", tpaddings);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2534 
// Pow: element-wise power of two random float/complex64 tensors whose shapes
// come from BroadcastableDims().
TEST_F(OpTest, Pow) {
  // TODO(phawkins): Feeding large DT_INT32 values to Pow() leads to
  // nontermination.
  Repeatedly([this]() {
    const auto shapes = BroadcastableDims();
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Pow");
    builder.RandomInput(type, shapes.first)
        .RandomInput(type, shapes.second)
        .Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2547 
// Prod: product-reduction of a random tensor over the axes returned by
// RandomReductionIndices, with keep_dims chosen at random.
TEST_F(OpTest, Prod) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const std::vector<int64> input_dims = RandomDims();
    const Tensor axes = RandomReductionIndices(input_dims.size());
    const bool keep_dims = Choose<bool>({false, true});
    OpTestBuilder builder("Prod");
    builder.RandomInput(type, input_dims)
        .Input(axes)
        .Attr("T", type)
        .Attr("keep_dims", keep_dims);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2561 
// Range: feeds scalar start, limit and delta values drawn from [-50, 50],
// each encoded as the randomly chosen Tidx type.
TEST_F(OpTest, Range) {
  Repeatedly([this]() {
    // Wraps an int as a scalar tensor of the requested element type. Any
    // other type is a programming error and aborts.
    auto ToScalar = [](DataType type, int x) {
      switch (type) {
        case DT_INT32:
          return test::AsScalar<int32>(x);
        case DT_INT64:
          return test::AsScalar<int64>(x);
        case DT_FLOAT:
          return test::AsScalar<float>(x);
        case DT_DOUBLE:
          return test::AsScalar<double>(x);
        default:
          LOG(FATAL) << "Unknown type " << DataTypeString(type);
      }
    };
    std::uniform_int_distribution<int> value_dist(-50, 50);
    const DataType tidx =
        Choose<DataType>({DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE});
    const Tensor start = ToScalar(tidx, value_dist(generator()));
    const Tensor limit = ToScalar(tidx, value_dist(generator()));
    const Tensor delta = ToScalar(tidx, value_dist(generator()));

    OpTestBuilder builder("Range");
    builder.Input(start).Input(limit).Input(delta).Attr("Tidx", tidx);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2581 
// Rank: returns the rank of a random tensor.
TEST_F(OpTest, Rank) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Rank");
    builder.RandomInput(type).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2589 
// Real: extracts the real part of a random complex64 tensor.
TEST_F(OpTest, Real) {
  Repeatedly([this]() {
    OpTestBuilder builder("Real");
    builder.RandomInput(DT_COMPLEX64).Attr("T", DT_COMPLEX64);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2597 
// RealDiv: element-wise division of two random float/complex64 tensors whose
// shapes come from BroadcastableDims().
TEST_F(OpTest, RealDiv) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    const auto shapes = BroadcastableDims();
    OpTestBuilder builder("RealDiv");
    builder.RandomInput(type, shapes.first)
        .RandomInput(type, shapes.second)
        .Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2608 
// Reciprocal: element-wise 1/x of a random tensor.
TEST_F(OpTest, Reciprocal) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Reciprocal");
    builder.RandomInput(type).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2616 
// ReciprocalGrad: two random tensors of the same shape.
TEST_F(OpTest, ReciprocalGrad) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims();
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("ReciprocalGrad");
    builder.RandomInput(type, shape).RandomInput(type, shape).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
// Relu: element-wise rectified linear unit on a random float tensor.
TEST_F(OpTest, Relu) {
  Repeatedly([this]() {
    OpTestBuilder builder("Relu");
    builder.RandomInput(DT_FLOAT).Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2633 
// Relu6: element-wise Relu6 on a random float tensor.
TEST_F(OpTest, Relu6) {
  Repeatedly([this]() {
    OpTestBuilder builder("Relu6");
    builder.RandomInput(DT_FLOAT).Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2640 
// Relu6Grad: two random float tensors sharing one shape of rank >= 1.
TEST_F(OpTest, Relu6Grad) {
  Repeatedly([this]() {
    const auto shape = RandomDims(1);
    OpTestBuilder builder("Relu6Grad");
    builder.RandomInput(DT_FLOAT, shape)
        .RandomInput(DT_FLOAT, shape)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2650 
// ReluGrad: two random float tensors sharing one shape of rank >= 1.
TEST_F(OpTest, ReluGrad) {
  Repeatedly([this]() {
    const auto shape = RandomDims(1);
    OpTestBuilder builder("ReluGrad");
    builder.RandomInput(DT_FLOAT, shape)
        .RandomInput(DT_FLOAT, shape)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2660 
// Reshape: reshapes between two random regroupings of the same base
// dimensions, so both shapes have the same element count. Each regrouping
// shuffles the base dimensions and then either starts a new dimension or
// folds the next one into the previous dimension.
TEST_F(OpTest, Reshape) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64> base_dims = RandomDims();
    std::bernoulli_distribution coin;

    // Shuffles base_dims in place and returns a shape whose product equals
    // the product of base_dims.
    auto random_regrouping = [&]() {
      std::shuffle(base_dims.begin(), base_dims.end(), generator());
      std::vector<int64> shape;
      for (const int64 dim : base_dims) {
        if (!shape.empty() && !coin(generator())) {
          shape.back() *= dim;
        } else {
          shape.push_back(dim);
        }
      }
      return shape;
    };
    const std::vector<int64> dims_before = random_regrouping();
    const std::vector<int64> dims_after = random_regrouping();

    OpTestBuilder builder("Reshape");
    builder.RandomInput(type, dims_before)
        .Input(test::AsTensor<int32>(
            std::vector<int32>(dims_after.begin(), dims_after.end())))
        .Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2687 
// ResizeBilinear: resizes a random 4-D float image to a random 2-element
// output size, with align_corners = true.
TEST_F(OpTest, ResizeBilinear) {
  Repeatedly([this]() {
    const std::vector<int64> image_dims = RandomDims(4, 4);
    const std::vector<int64> size_dims = RandomDims(2, 2);
    const std::vector<int32> size(size_dims.begin(), size_dims.end());

    OpTestBuilder builder("ResizeBilinear");
    builder.RandomInput(DT_FLOAT, image_dims)
        .Input(test::AsTensor<int32>(size))
        .Attr("T", DT_FLOAT)
        .Attr("align_corners", true);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2702 
// ResizeBilinearGrad: the first input is a random 4-D float tensor. The second
// shares its dims 0 and 3 and takes dims 1 and 2 from a random 2-element
// shape. align_corners = true.
TEST_F(OpTest, ResizeBilinearGrad) {
  Repeatedly([this]() {
    const std::vector<int64> grads_dims = RandomDims(4, 4);
    const std::vector<int64> spatial = RandomDims(2, 2);
    const std::vector<int64> image_dims = {grads_dims[0], spatial[0],
                                           spatial[1], grads_dims[3]};

    OpTestBuilder builder("ResizeBilinearGrad");
    builder.RandomInput(DT_FLOAT, grads_dims)
        .RandomInput(DT_FLOAT, image_dims)
        .Attr("T", DT_FLOAT)
        .Attr("align_corners", true);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2717 
// Reverse: a random tensor of rank >= 1 plus a random bool mask with one entry
// per dimension.
TEST_F(OpTest, Reverse) {
  Repeatedly([this]() {
    const std::vector<int64> dims = RandomDims(1);
    const DataType type = Choose<DataType>(kAllXlaTypes);
    const int64 rank = dims.size();

    OpTestBuilder builder("Reverse");
    builder.RandomInput(type, dims)
        .RandomInput(DT_BOOL, {rank})
        .Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2729 
// ReverseSequence: picks two distinct random axes of a tensor of rank >= 2 as
// batch_dim and seq_dim. It then draws one sequence length per batch entry
// from [0, size of seq_dim].
TEST_F(OpTest, ReverseSequence) {
  Repeatedly([this]() {
    const std::vector<int64> dims = RandomDims(/*min_rank=*/2);
    const DataType type = Choose<DataType>(kAllXlaTypes);
    const int64 rank = dims.size();

    // Shuffle every axis id; the first two become the batch and seq axes.
    std::vector<int> axis_ids(rank);
    absl::c_iota(axis_ids, 0);
    absl::c_shuffle(axis_ids, generator());
    const int batch_dim = axis_ids[0];
    const int seq_dim = axis_ids[1];

    const int batch_size = dims[batch_dim];
    const int max_seq_len = dims[seq_dim];
    std::vector<int32> seq_lens(batch_size);
    std::uniform_int_distribution<int32> len_dist(0, max_seq_len);
    absl::c_generate(seq_lens, [&]() { return len_dist(generator()); });

    OpTestBuilder builder("ReverseSequence");
    builder.RandomInput(type, dims)
        .Input(test::AsTensor<int32>(seq_lens))
        .Attr("seq_dim", seq_dim)
        .Attr("batch_dim", batch_dim)
        .Attr("T", type)
        .Attr("Tlen", DT_INT32);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2760 
// ReverseV2: reverses a random tensor along the axes returned by
// RandomReductionIndices.
TEST_F(OpTest, ReverseV2) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64> input_dims = RandomDims();
    const Tensor axes = RandomReductionIndices(input_dims.size());

    OpTestBuilder builder("ReverseV2");
    builder.RandomInput(type, input_dims).Input(axes).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2772 
// Rint: element-wise Rint on a random float tensor.
TEST_F(OpTest, Rint) {
  Repeatedly([this]() {
    OpTestBuilder builder("Rint");
    builder.RandomInput(DT_FLOAT).Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2779 
// Round: element-wise Round on a random float tensor.
TEST_F(OpTest, Round) {
  Repeatedly([this]() {
    OpTestBuilder builder("Round");
    builder.RandomInput(DT_FLOAT).Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2786 
// Rsqrt: element-wise reciprocal square root of a random float/complex64
// tensor.
TEST_F(OpTest, Rsqrt) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Rsqrt");
    builder.RandomInput(type).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2794 
// RsqrtGrad: two random float/complex64 tensors of the same shape.
TEST_F(OpTest, RsqrtGrad) {
  Repeatedly([this]() {
    const auto shape = RandomDims();
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("RsqrtGrad");
    builder.RandomInput(type, shape).RandomInput(type, shape).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2805 
// Shape: returns the shape of a random tensor of any type in kAllXlaTypes.
TEST_F(OpTest, Shape) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    OpTestBuilder builder("Shape");
    builder.RandomInput(type).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2813 
// ShapeN: returns the shapes of between 1 and 5 random tensors of one type.
TEST_F(OpTest, ShapeN) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    const int num_inputs =
        std::uniform_int_distribution<int>(1, 5)(generator());

    OpTestBuilder builder("ShapeN");
    builder.Attr("T", type).Attr("N", num_inputs);
    for (int input = 0; input < num_inputs; ++input) {
      builder.RandomInput(type);
    }
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2827 
// Sigmoid: element-wise logistic function of a random float/complex64 tensor.
TEST_F(OpTest, Sigmoid) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Sigmoid");
    builder.RandomInput(type).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2835 
// SigmoidGrad: two random float/complex64 tensors of the same shape.
TEST_F(OpTest, SigmoidGrad) {
  Repeatedly([this]() {
    const auto shape = RandomDims();
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("SigmoidGrad");
    builder.RandomInput(type, shape).RandomInput(type, shape).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2846 
// Sign: element-wise sign of a random tensor.
TEST_F(OpTest, Sign) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Sign");
    builder.RandomInput(type).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2854 
// Sin: element-wise sine of a random float/complex64 tensor.
TEST_F(OpTest, Sin) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Sin");
    builder.RandomInput(type).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2862 
// Sinh: element-wise hyperbolic sine of a random float/complex64 tensor.
TEST_F(OpTest, Sinh) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    OpTestBuilder builder("Sinh");
    builder.RandomInput(type).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2870 
// Size: returns the element count of a random tensor of any type in
// kAllXlaTypes.
TEST_F(OpTest, Size) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    OpTestBuilder builder("Size");
    builder.RandomInput(type).Attr("T", type);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2878 
// Slice: for each dimension of a random tensor, draws a begin index in
// [0, size] and a slice size in [-1, size - begin]. A size of -1 requests
// "everything from begin to the end of the dimension".
TEST_F(OpTest, Slice) {
  Repeatedly([this]() {
    const DataType type = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64> input_dims = RandomDims();

    std::vector<int32> begin;
    std::vector<int32> size;
    for (const int64 dim : input_dims) {
      const int32 start =
          std::uniform_int_distribution<int32>(0, dim)(generator());
      begin.push_back(start);
      size.push_back(
          std::uniform_int_distribution<int32>(-1, dim - start)(generator()));
    }

    OpTestBuilder builder("Slice");
    builder.RandomInput(type, input_dims)
        .Input(test::AsTensor<int32>(begin))
        .Input(test::AsTensor<int32>(size))
        .Attr("T", type)
        .Attr("Index", DT_INT32);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2900 
// Softmax: softmax of a random rank-2 float tensor.
TEST_F(OpTest, Softmax) {
  Repeatedly([this]() {
    const std::vector<int64> logits_dims = RandomDims(2, 2);
    OpTestBuilder builder("Softmax");
    builder.RandomInput(DT_FLOAT, logits_dims).Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2909 
// SoftmaxCrossEntropyWithLogits: two random rank-2 float tensors (min
// dimension size 1) of identical shape.
TEST_F(OpTest, SoftmaxCrossEntropyWithLogits) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims(2, 2, 1);
    OpTestBuilder builder("SoftmaxCrossEntropyWithLogits");
    builder.RandomInput(DT_FLOAT, shape)
        .RandomInput(DT_FLOAT, shape)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2920 
// Softplus: element-wise softplus of a random float tensor.
TEST_F(OpTest, Softplus) {
  Repeatedly([this]() {
    OpTestBuilder builder("Softplus");
    builder.RandomInput(DT_FLOAT).Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2927 
// SoftplusGrad: two random float tensors of the same shape.
TEST_F(OpTest, SoftplusGrad) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims();
    OpTestBuilder builder("SoftplusGrad");
    builder.RandomInput(DT_FLOAT, shape)
        .RandomInput(DT_FLOAT, shape)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2937 
// Softsign: element-wise softsign of a random float tensor.
TEST_F(OpTest, Softsign) {
  Repeatedly([this]() {
    OpTestBuilder builder("Softsign");
    builder.RandomInput(DT_FLOAT).Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2944 
// SoftsignGrad: two random float tensors of the same shape.
TEST_F(OpTest, SoftsignGrad) {
  Repeatedly([this]() {
    const std::vector<int64> shape = RandomDims();
    OpTestBuilder builder("SoftsignGrad");
    builder.RandomInput(DT_FLOAT, shape)
        .RandomInput(DT_FLOAT, shape)
        .Attr("T", DT_FLOAT);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2954 
// SpaceToBatch: rank-4 input [batch, s0, s1, depth]. Each spatial size starts
// as a random multiple of block_size. Random before/after paddings in [0, 7]
// that fit inside it are then subtracted, so size + paddings is always a
// multiple of block_size.
TEST_F(OpTest, SpaceToBatch) {
  Repeatedly([this]() {
    // Only the first num_block_dims entries are used as multipliers.
    const std::vector<int64> block_multiples = RandomDims(4, 4, 0, 5);
    const int num_block_dims = 2;
    const int64 block_size = RandomDim(2, 5);

    std::vector<int64> input_dims(1 + num_block_dims + 1);
    input_dims.front() = RandomDim();
    for (int i = 0; i < num_block_dims; ++i) {
      input_dims[1 + i] = block_multiples[i] * block_size;
    }
    input_dims.back() = RandomDim();

    std::vector<int64> padding_vals;
    std::uniform_int_distribution<int> pad_dist(0, 7);
    for (int i = 0; i < num_block_dims; ++i) {
      int64& spatial = input_dims[1 + i];
      int64 pad_before;
      int64 pad_after;
      // Resample until the combined padding fits inside the dimension.
      do {
        pad_before = pad_dist(generator());
        pad_after = pad_dist(generator());
      } while (pad_before + pad_after > spatial);
      spatial -= pad_before + pad_after;
      padding_vals.push_back(pad_before);
      padding_vals.push_back(pad_after);
    }
    Tensor paddings;
    CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
                            TensorShape({num_block_dims, 2})));

    const DataType type = Choose<DataType>(kAllXlaTypes);
    OpTestBuilder builder("SpaceToBatch");
    builder.RandomInput(type, input_dims)
        .Input(paddings)
        .Attr("T", type)
        .Attr("block_size", block_size);
    return ExpectTfAndXlaOutputsAreClose(builder);
  });
}
2993 
// SpaceToBatchND: input of shape [batch] + spatial + remaining, with 1-3
// spatial (block) dimensions. Each spatial size starts as a random multiple
// (possibly 0) of the matching block_shape entry. Random before/after
// paddings in [0, 7] that fit inside it are then subtracted, so
// size + paddings stays divisible by that entry.
TEST_F(OpTest, SpaceToBatchND) {
  Repeatedly([this]() {
    // Entries of block_shape must be >= 1: the op requires a positive product
    // of block sizes, so a 0 entry would only yield an invalid iteration.
    // Zero-sized spatial inputs are still covered via block_multipliers.
    std::vector<int64> block_dims = RandomDims(1, 3, 1, 5);
    int num_block_dims = block_dims.size();
    std::vector<int64> remaining_dims = RandomDims(0, 3);
    std::vector<int64> block_multipliers =
        RandomDims(block_dims.size(), block_dims.size(), 0, 4);

    std::vector<int64> input_dims(1 + num_block_dims + remaining_dims.size());
    input_dims[0] = RandomDim();
    for (int i = 0; i < num_block_dims; ++i) {
      input_dims[1 + i] = block_dims[i] * block_multipliers[i];
    }
    std::copy(remaining_dims.begin(), remaining_dims.end(),
              input_dims.begin() + 1 + num_block_dims);

    std::vector<int64> padding_vals;
    std::uniform_int_distribution<int> distribution(0, 7);
    for (int i = 0; i < num_block_dims; ++i) {
      int64 pad_before;
      int64 pad_after;
      // Resample until the combined padding fits inside the dimension.
      do {
        pad_before = distribution(generator());
        pad_after = distribution(generator());
      } while (pad_before + pad_after > input_dims[1 + i]);
      input_dims[1 + i] -= pad_before + pad_after;
      padding_vals.push_back(pad_before);
      padding_vals.push_back(pad_after);
    }
    Tensor paddings;
    CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals),
                            TensorShape({num_block_dims, 2})));

    auto type = Choose<DataType>(kAllXlaTypes);
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("SpaceToBatchND")
            .RandomInput(type, input_dims)
            .Input(test::AsTensor<int32>(
                std::vector<int32>(block_dims.begin(), block_dims.end())))
            .Input(paddings)
            .Attr("T", type));
  });
}
3037 
// Exercises SpaceToDepth on a random rank-4 float tensor with a block size in
// [2, 5].
//
// Dimensions 1 and 2 are rounded up to a multiple of the block size. They are
// presumably the spatial axes under the op's default data format; no
// data_format attr is set here.
TEST_F(OpTest, SpaceToDepth) {
  Repeatedly([this] {
    const int64 block = RandomDim(2, 5);
    std::vector<int64> shape = RandomDims(4, 4);
    for (const int axis : {1, 2}) {
      shape[axis] = (shape[axis] + block - 1) / block * block;
    }
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToDepth")
                                             .RandomInput(DT_FLOAT, shape)
                                             .Attr("T", DT_FLOAT)
                                             .Attr("block_size", block));
  });
}
3051 
// Exercises SparseMatMul on random float matrices for an (rows x inner) *
// (inner x cols) product.
//
// Each operand's transpose flag is chosen at random. The stored operand shape
// is swapped whenever its flag is set.
TEST_F(OpTest, SparseMatMul) {
  Repeatedly([this] {
    const int64 rows = RandomDim();
    const int64 inner = RandomDim();
    const int64 cols = RandomDim();

    std::bernoulli_distribution coin;
    const bool transpose_a = coin(generator());
    const bool transpose_b = coin(generator());

    const std::vector<int64> a_shape = transpose_a
                                           ? std::vector<int64>{inner, rows}
                                           : std::vector<int64>{rows, inner};
    const std::vector<int64> b_shape = transpose_b
                                           ? std::vector<int64>{cols, inner}
                                           : std::vector<int64>{inner, cols};

    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul")
                                             .RandomInput(DT_FLOAT, a_shape)
                                             .RandomInput(DT_FLOAT, b_shape)
                                             .Attr("Ta", DT_FLOAT)
                                             .Attr("Tb", DT_FLOAT)
                                             .Attr("transpose_a", transpose_a)
                                             .Attr("transpose_b", transpose_b));
  });
}
3080 
// Exercises SparseSoftmaxCrossEntropyWithLogits on random float logits of shape
// [batch_size, num_classes], both at least 1.
//
// Each batch element gets one int32 label drawn uniformly from
// [0, num_classes).
TEST_F(OpTest, SparseSoftmaxCrossEntropyWithLogits) {
  Repeatedly([this] {
    const std::vector<int64> logits_shape = RandomDims(2, 2, 1);
    const int64 batch_size = logits_shape[0];
    const int64 num_classes = logits_shape[1];

    std::vector<int32> labels;
    labels.reserve(batch_size);
    for (int64 row = 0; row < batch_size; ++row) {
      labels.push_back(std::uniform_int_distribution<int32>(
          0, num_classes - 1)(generator()));
    }

    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("SparseSoftmaxCrossEntropyWithLogits")
            .RandomInput(DT_FLOAT, logits_shape)
            .Input(test::AsTensor<int32>(labels))
            .Attr("T", DT_FLOAT)
            .Attr("Tlabels", DT_INT32));
  });
}
3101 
// Exercises Split on a random tensor of rank >= 1, splitting it into n pieces
// (n in [1, 5]).
//
// The split axis is drawn from [-rank, rank - 1], so negative axes are
// exercised as well. That value is passed to the op unchanged. The extent of
// the chosen axis is first rounded down to a multiple of n so the split is
// even.
TEST_F(OpTest, Split) {
  Repeatedly([this]() {
    auto type = Choose<DataType>(kAllXlaTypes);
    std::vector<int64> dims = RandomDims(1);
    const int32 rank = static_cast<int32>(dims.size());
    int32 dim =
        std::uniform_int_distribution<int32>(-rank, rank - 1)(generator());
    int n = std::uniform_int_distribution<int>(1, 5)(generator());
    // `dim` may be negative. Indexing a std::vector with a negative value
    // converts it to a huge size_t (out-of-bounds UB), so map it to the
    // equivalent non-negative axis before touching `dims`.
    const int32 axis = dim < 0 ? dim + rank : dim;
    // Ensure the extent along the split axis is evenly divisible by 'n'.
    dims[axis] /= n;
    dims[axis] *= n;
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split")
                                             .Input(test::AsScalar<int32>(dim))
                                             .RandomInput(type, dims)
                                             .Attr("T", type)
                                             .Attr("num_split", n));
  });
}
3121 
// Exercises the unary Sqrt op on a random float or complex64 tensor.
TEST_F(OpTest, Sqrt) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Sqrt").RandomInput(dtype).Attr("T", dtype));
  });
}
3129 
// Exercises SqrtGrad with two random float or complex64 inputs of the same
// random shape.
TEST_F(OpTest, SqrtGrad) {
  Repeatedly([this] {
    const auto shape = RandomDims();
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad")
                                             .RandomInput(dtype, shape)
                                             .RandomInput(dtype, shape)
                                             .Attr("T", dtype));
  });
}
3140 
// Exercises the binary SquaredDifference op on float operands whose shapes
// come from BroadcastableDims().
TEST_F(OpTest, SquaredDifference) {
  Repeatedly([this] {
    const auto shapes = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SquaredDifference")
                                             .RandomInput(DT_FLOAT, shapes.first)
                                             .RandomInput(DT_FLOAT, shapes.second)
                                             .Attr("T", DT_FLOAT));
  });
}
3150 
// Exercises the unary Square op on a random int32, float, or complex64 tensor.
TEST_F(OpTest, Square) {
  Repeatedly([this] {
    const DataType dtype =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Square").RandomInput(dtype).Attr("T", dtype));
  });
}
3158 
// Exercises Squeeze on a random tensor whose dimensions are all in [0, 5].
//
// Each size-1 axis is independently kept or listed in squeeze_dims with
// probability 1/2. Only size-1 axes consume a random draw.
TEST_F(OpTest, Squeeze) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64> shape = RandomDims(0, kDefaultMaxRank, 0, 5);
    std::bernoulli_distribution coin_flip;
    std::vector<int> axes_to_squeeze;
    for (size_t axis = 0; axis < shape.size(); ++axis) {
      if (shape[axis] != 1) continue;
      if (coin_flip(generator())) {
        axes_to_squeeze.push_back(static_cast<int>(axis));
      }
    }
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Squeeze")
                                             .RandomInput(dtype, shape)
                                             .Attr("squeeze_dims", axes_to_squeeze)
                                             .Attr("T", dtype));
  });
}
3176 
// Exercises the binary Sub op on int32, float, or complex64 operands with
// shapes from BroadcastableDims().
TEST_F(OpTest, Sub) {
  Repeatedly([this] {
    const DataType dtype =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const auto shapes = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub")
                                             .RandomInput(dtype, shapes.first)
                                             .RandomInput(dtype, shapes.second)
                                             .Attr("T", dtype));
  });
}
3187 
// Exercises the Sum reduction over random axes (from RandomReductionIndices)
// of a random int32, float, or complex64 tensor. keep_dims is chosen at
// random.
TEST_F(OpTest, Sum) {
  Repeatedly([this] {
    const DataType dtype =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    const std::vector<int64> shape = RandomDims();
    const Tensor reduction_axes = RandomReductionIndices(shape.size());
    const bool keep_dims = Choose<bool>({false, true});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sum")
                                             .RandomInput(dtype, shape)
                                             .Input(reduction_axes)
                                             .Attr("T", dtype)
                                             .Attr("keep_dims", keep_dims));
  });
}
3201 
// Exercises StridedSlice with random int32 begin/end/stride vectors and random
// masks.
//
// Per axis of extent d:
//   - begin and end are drawn from [-2d, 2d], so they may lie outside the
//     tensor;
//   - the stride is +1 or -1.
// begin/end/new_axis/shrink_axis masks are arbitrary bit patterns over the
// axes. ellipsis_mask has at most one bit set.
TEST_F(OpTest, StridedSlice) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64> shape = RandomDims();
    const int rank = shape.size();

    std::vector<int32> begin;
    std::vector<int32> end;
    std::vector<int32> strides;
    begin.reserve(rank);
    end.reserve(rank);
    strides.reserve(rank);
    for (int axis = 0; axis < rank; ++axis) {
      const int64 extent = shape[axis];
      begin.push_back(std::uniform_int_distribution<int32>(
          -2 * extent, 2 * extent)(generator()));
      end.push_back(std::uniform_int_distribution<int32>(
          -2 * extent, 2 * extent)(generator()));
      // TODO(b/31360685): support strides other than 1 or -1
      strides.push_back(std::bernoulli_distribution()(generator()) ? 1 : -1);
    }

    const int64 all_axes_bits = (1LL << rank) - 1;
    std::uniform_int_distribution<int64> mask_dist(0, all_axes_bits);
    const int64 begin_mask = mask_dist(generator());
    const int64 end_mask = mask_dist(generator());

    // At most one axis is marked as the ellipsis.
    int64 ellipsis_mask = 0;
    if (rank > 0 && std::bernoulli_distribution()(generator())) {
      const int ellipsis_axis =
          std::uniform_int_distribution<int>(0, rank - 1)(generator());
      ellipsis_mask = 1LL << ellipsis_axis;
    }

    const int64 new_axis_mask = mask_dist(generator());
    const int64 shrink_axis_mask = mask_dist(generator());
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("StridedSlice")
            .RandomInput(dtype, shape)
            .Input(test::AsTensor<int32>(begin))
            .Input(test::AsTensor<int32>(end))
            .Input(test::AsTensor<int32>(strides))
            .Attr("T", dtype)
            .Attr("Index", DT_INT32)
            .Attr("begin_mask", begin_mask)
            .Attr("end_mask", end_mask)
            .Attr("ellipsis_mask", ellipsis_mask)
            .Attr("new_axis_mask", new_axis_mask)
            .Attr("shrink_axis_mask", shrink_axis_mask));
  });
}
3246 
// Exercises StridedSliceGrad with random int64 slice parameters and masks.
//
// The upstream-gradient input's shape (RandomDims(1)) is drawn independently
// of the slice parameters. As the TODO below notes, the reference op therefore
// succeeds only with low probability.
TEST_F(OpTest, StridedSliceGrad) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);

    // Shape of the forward op's input.
    const std::vector<int64> fwd_shape = RandomDims();
    const int rank = fwd_shape.size();

    std::vector<int64> begin;
    std::vector<int64> end;
    std::vector<int64> strides;
    begin.reserve(rank);
    end.reserve(rank);
    strides.reserve(rank);
    for (int axis = 0; axis < rank; ++axis) {
      const int64 lo = -2 * fwd_shape[axis];
      const int64 hi = 2 * fwd_shape[axis];
      begin.push_back(std::uniform_int_distribution<int64>(lo, hi)(generator()));
      end.push_back(std::uniform_int_distribution<int64>(lo, hi)(generator()));
      // NOTE(review): [lo, hi] always contains 0, and a zero stride is
      // presumably rejected by the op. Confirm whether to exclude it.
      strides.push_back(
          std::uniform_int_distribution<int64>(lo, hi)(generator()));
    }

    const int64 all_axes_bits = (1LL << rank) - 1;
    std::uniform_int_distribution<int64> mask_dist(0, all_axes_bits);
    const int64 begin_mask = mask_dist(generator());
    const int64 end_mask = mask_dist(generator());

    // At most one axis is marked as the ellipsis.
    int64 ellipsis_mask = 0;
    if (rank > 0 && std::bernoulli_distribution()(generator())) {
      const int ellipsis_axis =
          std::uniform_int_distribution<int>(0, rank - 1)(generator());
      ellipsis_mask = 1LL << ellipsis_axis;
    }

    const int64 new_axis_mask = mask_dist(generator());
    const int64 shrink_axis_mask = mask_dist(generator());

    // TODO(phawkins): use shape inference for the forward op to compute the
    // gradient shape for the backward op. At present, there is a low
    // probability of the golden op succeeding.
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("StridedSliceGrad")
            .Input(test::AsTensor<int64>(fwd_shape))
            .Input(test::AsTensor<int64>(begin))
            .Input(test::AsTensor<int64>(end))
            .Input(test::AsTensor<int64>(strides))
            .RandomInput(dtype, RandomDims(1))
            .Attr("T", dtype)
            .Attr("Index", DT_INT64)
            .Attr("begin_mask", begin_mask)
            .Attr("end_mask", end_mask)
            .Attr("ellipsis_mask", ellipsis_mask)
            .Attr("new_axis_mask", new_axis_mask)
            .Attr("shrink_axis_mask", shrink_axis_mask));
  });
}
3299 
// Exercises the unary Tan op on a random float or complex64 tensor.
TEST_F(OpTest, Tan) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Tan").RandomInput(dtype).Attr("T", dtype));
  });
}
3307 
// Exercises the unary Tanh op on a random float or complex64 tensor.
TEST_F(OpTest, Tanh) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Tanh").RandomInput(dtype).Attr("T", dtype));
  });
}
3315 
// Exercises TanhGrad with two random float or complex64 inputs of the same
// random shape.
TEST_F(OpTest, TanhGrad) {
  Repeatedly([this] {
    const auto shape = RandomDims();
    const DataType dtype = Choose<DataType>({DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad")
                                             .RandomInput(dtype, shape)
                                             .RandomInput(dtype, shape)
                                             .Attr("T", dtype));
  });
}
3326 
// Exercises Tile on a random tensor of rank >= 1, with a replication count in
// [1, 3] for each axis.
TEST_F(OpTest, Tile) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64> shape = RandomDims(1);
    std::vector<int32> multiples;
    multiples.reserve(shape.size());
    for (size_t axis = 0; axis < shape.size(); ++axis) {
      multiples.push_back(std::uniform_int_distribution<int>(1, 3)(generator()));
    }
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Tile")
            .RandomInput(dtype, shape)
            .Input(test::AsTensor<int32>(multiples))
            .Attr("T", dtype));
  });
}
3342 
// Exercises Transpose on a random tensor with a random permutation of its
// axes, produced by std::shuffle over 0..rank-1.
TEST_F(OpTest, Transpose) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>(kAllXlaTypes);
    const std::vector<int64> shape = RandomDims();
    const int32 rank = static_cast<int32>(shape.size());
    std::vector<int32> permutation;
    permutation.reserve(rank);
    for (int32 axis = 0; axis < rank; ++axis) {
      permutation.push_back(axis);
    }
    std::shuffle(permutation.begin(), permutation.end(), generator());
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("Transpose")
            .RandomInput(dtype, shape)
            .Input(test::AsTensor<int32>(permutation))
            .Attr("T", dtype));
  });
}
3356 
// Exercises the binary TruncateDiv op on int32 operands with shapes from
// BroadcastableDims().
TEST_F(OpTest, TruncateDiv) {
  Repeatedly([this] {
    const DataType dtype = DT_INT32;
    const auto shapes = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateDiv")
                                             .RandomInput(dtype, shapes.first)
                                             .RandomInput(dtype, shapes.second)
                                             .Attr("T", dtype));
  });
}
3367 
// Exercises the binary TruncateMod op on int32 or float operands with shapes
// from BroadcastableDims().
TEST_F(OpTest, TruncateMod) {
  Repeatedly([this] {
    const DataType dtype = Choose<DataType>({DT_INT32, DT_FLOAT});
    const auto shapes = BroadcastableDims();
    return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod")
                                             .RandomInput(dtype, shapes.first)
                                             .RandomInput(dtype, shapes.second)
                                             .Attr("T", dtype));
  });
}
3378 
// Exercises ZerosLike on a random int32, float, or complex64 tensor.
TEST_F(OpTest, ZerosLike) {
  Repeatedly([this] {
    const DataType dtype =
        Choose<DataType>({DT_INT32, DT_FLOAT, DT_COMPLEX64});
    return ExpectTfAndXlaOutputsAreClose(
        OpTestBuilder("ZerosLike").RandomInput(dtype).Attr("T", dtype));
  });
}
3386 
3387 // Example failing run:
3388 //   --tf_xla_reference_device=GPU:0
3389 //   --tf_xla_test_use_jit=true --tf_xla_test_device=GPU:0
3390 //   --tf_xla_test_repetitions=2
3391 //   --gunit_filter='OpTest.FusedBatchNormTraining'
3392 //   --tf_xla_random_seed=2838146746
TEST_F(OpTest,FusedBatchNormTraining)3393 TEST_F(OpTest, FusedBatchNormTraining) {
3394   bool is_nhwc = RandomBool();
3395   std::vector<int64> x_dims = RandomDims(/*min_rank=*/4, /*max_rank=*/4,
3396                                          /*min_size=*/5, /*max_size=*/20);
3397   std::vector<int64> scale_dims = {x_dims[is_nhwc ? 3 : 1]};
3398   std::vector<int64> offset_dims = {x_dims[is_nhwc ? 3 : 1]};
3399   std::vector<int64> mean_dims = {0};
3400   std::vector<int64> variance_dims = {0};
3401   DataType type = DT_FLOAT;
3402   Repeatedly([&] {
3403     return ExpectTfAndXlaOutputsAreClose(
3404         OpTestBuilder("FusedBatchNorm")
3405             .RandomInput(type, x_dims)
3406             .RandomInput(type, scale_dims)
3407             .RandomInput(type, offset_dims)
3408             .RandomInput(type, mean_dims)
3409             .RandomInput(type, variance_dims)
3410             .Attr("T", type)
3411             .Attr("data_format", is_nhwc ? "NHWC" : "NCHW")
3412             .Attr("epsilon", static_cast<float>(1.001e-05))
3413             .Attr("is_training", true));
3414   });
3415 }
3416 }  // anonymous namespace
3417 }  // namespace tensorflow
3418 
main(int argc,char ** argv)3419 int main(int argc, char** argv) {
3420   tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU:0");
3421   tensorflow::tf_xla_reference_device_ptr = new tensorflow::string("CPU:0");
3422   std::vector<tensorflow::Flag> flag_list = {
3423       tensorflow::Flag(
3424           "tf_xla_random_seed", &tensorflow::tf_xla_random_seed,
3425           "Random seed to use for XLA tests. <= 0 means choose a seed "
3426           "nondetermistically."),
3427       // TODO(phawkins): it might make more sense to run each test up to a
3428       // configurable time bound.
3429       tensorflow::Flag("tf_xla_test_repetitions",
3430                        &tensorflow::tf_xla_test_repetitions,
3431                        "Number of repetitions for each test."),
3432       tensorflow::Flag("tf_xla_max_tensor_size",
3433                        &tensorflow::tf_xla_max_tensor_size,
3434                        "Maximum number of elements for random input tensors."),
3435       tensorflow::Flag("tf_xla_test_device", tensorflow::tf_xla_test_device_ptr,
3436                        "Tensorflow device type to use for test"),
3437       tensorflow::Flag("tf_xla_reference_device",
3438                        tensorflow::tf_xla_reference_device_ptr,
3439                        "Tensorflow device type to use for reference"),
3440       tensorflow::Flag("tf_xla_test_use_jit", &tensorflow::tf_xla_test_use_jit,
3441                        "Use JIT compilation for the operator under test"),
3442   };
3443   tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
3444   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
3445   if (!parse_result) {
3446     LOG(ERROR) << "\n" << usage;
3447     return 2;
3448   }
3449   testing::InitGoogleTest(&argc, argv);
3450   if (argc > 1) {
3451     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
3452     return 2;
3453   }
3454   // XLA devices register kernels at construction time; create all known devices
3455   // to make sure the kernels are registered.
3456   std::vector<std::unique_ptr<tensorflow::Device>> devices;
3457   TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
3458       tensorflow::SessionOptions(), "", &devices));
3459   tensorflow::DeviceMgr device_mgr(std::move(devices));
3460 
3461   tensorflow::Device* ignored;
3462   TF_QCHECK_OK(
3463       device_mgr.LookupDevice(*tensorflow::tf_xla_test_device_ptr, &ignored))
3464       << "Unknown test device (" << *tensorflow::tf_xla_test_device_ptr
3465       << "). Did you build in the right configuration (e.g., is CUDA enabled)?";
3466 
3467   return RUN_ALL_TESTS();
3468 }
3469