1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the Python extension-based XLA client."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import itertools
23import threading
24
25import numpy as np
26
27from tensorflow.compiler.xla import xla_data_pb2
28from tensorflow.compiler.xla.python import custom_call_for_test
29from tensorflow.compiler.xla.python import xla_client
30import unittest
31
32
class EnumTest(unittest.TestCase):
  """Checks that the Python-side enums mirror the constants in xla_data_pb2."""

  def _AssertMatchesProto(self, enum_type):
    # Each enum member must have an identically named xla_data_pb2 constant
    # carrying the same value.
    for member_name, member in enum_type.__members__.items():
      self.assertEqual(member, getattr(xla_data_pb2, member_name))

  def testPrimitiveType(self):
    self._AssertMatchesProto(xla_client.PrimitiveType)

  def testFormat(self):
    self._AssertMatchesProto(xla_client.Format)
43
44
class ComputationTest(unittest.TestCase):
  """Base class for running an XLA Computation through the local client."""

  def _NewComputation(self, name=None):
    """Returns a fresh ComputationBuilder.

    Args:
      name: optional builder name; when None, the current test id is used.
    """
    if name is None:
      name = self.id()
    return xla_client.ComputationBuilder(name)

  def _Execute(self, c, arguments):
    """Compiles `c` against `arguments` and runs it on those Python values."""
    compiled_c = c.Build().CompileWithExampleArguments(arguments)
    return compiled_c.ExecuteWithPythonValues(arguments)

  def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
    """Executes `c` and checks the result against `expected`.

    Args:
      assert_func: callable(actual, expected) that raises on mismatch.
      c: the computation builder to execute.
      arguments: sequence of argument values passed to `_Execute`.
      expected: the expected result; must not be None.

    Raises:
      AssertionError: if `expected` is None, if the result's shape differs
        from the expected shape, or if `assert_func` reports a mismatch.
    """
    # A bare `assert` is stripped when Python runs with -O, which would let a
    # test with a forgotten expectation pass silently. The unittest assertion
    # always runs.
    self.assertIsNotNone(expected, "an expected value must be provided")
    result = self._Execute(c, arguments)
    # Numpy's comparison methods are a bit too lenient by treating inputs as
    # "array-like", meaning that scalar 4 will be happily compared equal to
    # [[4]]. We'd like to be more strict so assert shapes as well.
    self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape)
    assert_func(result, expected)

  def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
    """Executes `c` and requires the result to equal `expected` exactly."""
    self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected)

  def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7,
                              atol=0):
    """Executes `c` and compares the result to `expected` within tolerance.

    Args:
      c: the computation builder to execute.
      arguments: sequence of argument values.
      expected: the expected result; must not be None.
      rtol: relative tolerance forwarded to np.testing.assert_allclose.
      atol: absolute tolerance forwarded to np.testing.assert_allclose.
    """
    self._ExecuteAndAssertWith(
        functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol),
        c, arguments, expected)
74
75
def NumpyArrayF32(*args, **kwargs):
  """Returns np.array(*args, **kwargs) with the element type pinned to float32.

  `dtype` is supplied by this wrapper, so callers must not pass it.
  """
  element_type = np.float32
  return np.array(*args, dtype=element_type, **kwargs)
79
80
def NumpyArrayF64(*args, **kwargs):
  """Builds a float64 NumPy array; all other np.array arguments pass through.

  `dtype` is fixed by this helper and must not be given by the caller.
  """
  element_type = np.float64
  return np.array(*args, dtype=element_type, **kwargs)
84
85
def NumpyArrayS32(*args, **kwargs):
  """Builds a signed 32-bit integer NumPy array from np.array-style arguments.

  The dtype is forced to np.int32; passing `dtype` explicitly is an error.
  """
  element_type = np.int32
  return np.array(*args, dtype=element_type, **kwargs)
89
90
def NumpyArrayS64(*args, **kwargs):
  """Builds a signed 64-bit integer NumPy array from np.array-style arguments.

  The dtype is forced to np.int64; passing `dtype` explicitly is an error.
  """
  element_type = np.int64
  return np.array(*args, dtype=element_type, **kwargs)
94
95
def NumpyArrayBool(*args, **kwargs):
  """Convenience wrapper to create Numpy arrays with a boolean dtype.

  All arguments are forwarded to np.array. `dtype` is supplied here, so
  callers must not pass it.
  """
  # np.bool was a deprecated alias of the builtin bool and is missing in
  # NumPy 1.24-1.26 (AttributeError). np.bool_ is NumPy's boolean scalar type
  # and exists in every NumPy version.
  return np.array(*args, dtype=np.bool_, **kwargs)
99
100
class ComputationPrinting(unittest.TestCase):
  """Checks the HLO text and DOT-graph renderings of a small computation."""

  def ExampleComputation(self):
    """Builds 'acomputation': an f32 scalar parameter times an f32[4] one."""
    builder = xla_client.ComputationBuilder("acomputation")
    scalar_param = builder.ParameterFromNumpy(np.float32(0))
    vector_param = builder.ParameterFromNumpy(np.zeros((4,), np.float32))
    builder.Mul(scalar_param, vector_param)
    return builder.Build()

  def testComputationToHloText(self):
    text = self.ExampleComputation().GetHloText()
    self.assertTrue(text.startswith("HloModule acomputation"))

  def testComputationToHloGraph(self):
    graph = self.ExampleComputation().GetHloDotGraph()
    self.assertTrue(graph.startswith("digraph "))
119
120
class ComputationsWithConstantsTest(ComputationTest):
  """Tests focusing on Constant ops.

  Integer and boolean results are compared exactly. Floating-point results
  are compared with a relative tolerance.
  """

  def testConstantScalarSumS8(self):
    c = self._NewComputation()
    root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2)))
    self.assertEqual(c.GetShape(root), c.GetReturnValueShape())
    self._ExecuteAndCompareExact(c, expected=np.int8(3))

  def testConstantScalarSumF32(self):
    c = self._NewComputation()
    root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
    self.assertEqual(c.GetShape(root), c.GetReturnValueShape())
    self._ExecuteAndCompareClose(c, expected=4.25)

  def testConstantScalarSumF64(self):
    c = self._NewComputation()
    c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14))
    self._ExecuteAndCompareClose(c, expected=4.25)

  def testConstantScalarSumS32(self):
    c = self._NewComputation()
    c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2))
    # Integer arithmetic is exact; a relative float tolerance could mask
    # off-by-some errors.
    self._ExecuteAndCompareExact(c, expected=3)

  def testConstantScalarSumS64(self):
    c = self._NewComputation()
    c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2))
    self._ExecuteAndCompareExact(c, expected=3)

  def testConstantVectorMulF32(self):
    c = self._NewComputation()
    c.Mul(
        c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])),
        c.Constant(NumpyArrayF32([-1.2, 2, -2, -3])))
    self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])

  def testConstantVectorMulF64(self):
    c = self._NewComputation()
    c.Mul(
        c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])),
        c.Constant(NumpyArrayF64([-1.2, 2, -2, -3])))
    self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1])

  def testConstantVectorScalarDivF32(self):
    c = self._NewComputation()
    c.Div(
        c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])),
        c.ConstantF32Scalar(2.0))
    self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])

  def testConstantVectorScalarDivF64(self):
    c = self._NewComputation()
    c.Div(
        c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])),
        c.ConstantF64Scalar(2.0))
    self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4])

  def testConstantVectorScalarPowF32(self):
    c = self._NewComputation()
    c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.))
    self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])

  def testConstantVectorScalarPowF64(self):
    c = self._NewComputation()
    c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.))
    self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])

  def testIota(self):
    c = self._NewComputation()
    c.Iota(np.float32, 10)
    self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32))

  def testBroadcastedIota(self):
    c = self._NewComputation()
    c.BroadcastedIota(np.int64, (2, 3), 1)
    expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64)
    self._ExecuteAndCompareExact(c, expected=expected)

  def testBooleanAnd(self):
    c = self._NewComputation()
    c.And(
        c.Constant(NumpyArrayBool([True, False, True, False])),
        c.Constant(NumpyArrayBool([True, True, False, False])))
    self._ExecuteAndCompareExact(c, expected=[True, False, False, False])

  def testBooleanOr(self):
    c = self._NewComputation()
    c.Or(
        c.Constant(NumpyArrayBool([True, False, True, False])),
        c.Constant(NumpyArrayBool([True, True, False, False])))
    self._ExecuteAndCompareExact(c, expected=[True, True, True, False])

  def testBooleanXor(self):
    c = self._NewComputation()
    c.Xor(
        c.Constant(NumpyArrayBool([True, False, True, False])),
        c.Constant(NumpyArrayBool([True, True, False, False])))
    self._ExecuteAndCompareExact(c, expected=[False, True, True, False])

  def testSum2DF32(self):
    c = self._NewComputation()
    c.Add(
        c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])),
        c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]])))
    self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])

  def testShiftLeft(self):
    c = self._NewComputation()
    c.ShiftLeft(c.Constant(NumpyArrayS32([3])),
                c.Constant(NumpyArrayS32([2])))
    # 3 << 2 == 12. Shifts produce integers, so compare exactly.
    self._ExecuteAndCompareExact(c, expected=[12])

  def testShiftRightArithmetic(self):
    c = self._NewComputation()
    c.ShiftRightArithmetic(c.Constant(NumpyArrayS32([-2])),
                           c.Constant(NumpyArrayS32([1])))
    # An arithmetic shift keeps the sign: -2 >> 1 == -1.
    self._ExecuteAndCompareExact(c, expected=[-1])

  def testShiftRightLogical(self):
    c = self._NewComputation()
    c.ShiftRightLogical(c.Constant(NumpyArrayS32([-1])),
                        c.Constant(NumpyArrayS32([1])))
    # A logical shift fills with zeros: all-ones int32 >> 1 == 2**31 - 1.
    # The exact check matters here: with rtol=1e-7 a value this large would
    # tolerate an error of about 200.
    self._ExecuteAndCompareExact(c, expected=[2**31 - 1])

  def testSum2DF64(self):
    c = self._NewComputation()
    c.Add(
        c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])),
        c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]])))
    self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])

  def testSum2DWith1DBroadcastDim0F32(self):
    # sum of a 2D array with a 1D array where the latter is replicated across
    # dimension 0 to match the former's shape.
    c = self._NewComputation()
    c.Add(
        c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
        c.Constant(NumpyArrayF32([10, 20, 30])),
        broadcast_dimensions=(0,))
    self._ExecuteAndCompareClose(
        c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])

  def testSum2DWith1DBroadcastDim0F64(self):
    # sum of a 2D array with a 1D array where the latter is replicated across
    # dimension 0 to match the former's shape.
    c = self._NewComputation()
    c.Add(
        c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
        c.Constant(NumpyArrayF64([10, 20, 30])),
        broadcast_dimensions=(0,))
    self._ExecuteAndCompareClose(
        c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]])

  def testSum2DWith1DBroadcastDim1F32(self):
    # sum of a 2D array with a 1D array where the latter is replicated across
    # dimension 1 to match the former's shape.
    c = self._NewComputation()
    c.Add(
        c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
        c.Constant(NumpyArrayF32([10, 20, 30])),
        broadcast_dimensions=(1,))
    self._ExecuteAndCompareClose(
        c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])

  def testSum2DWith1DBroadcastDim1F64(self):
    # sum of a 2D array with a 1D array where the latter is replicated across
    # dimension 1 to match the former's shape.
    c = self._NewComputation()
    c.Add(
        c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
        c.Constant(NumpyArrayF64([10, 20, 30])),
        broadcast_dimensions=(1,))
    self._ExecuteAndCompareClose(
        c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]])

  def testConstantAxpyF32(self):
    # Computes a*x + y with scalar a = 2.
    c = self._NewComputation()
    c.Add(
        c.Mul(
            c.ConstantF32Scalar(2),
            c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))),
        c.Constant(NumpyArrayF32([100, -100, 200, -200])))
    self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])

  def testConstantAxpyF64(self):
    # Computes a*x + y with scalar a = 2.
    c = self._NewComputation()
    c.Add(
        c.Mul(
            c.ConstantF64Scalar(2),
            c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))),
        c.Constant(NumpyArrayF64([100, -100, 200, -200])))
    self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189])

  def testCustomCall(self):
    c = self._NewComputation()
    # Make the test-only CPU targets (including "test_subtract_f32") callable
    # by name from the computation.
    for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
      xla_client.register_cpu_custom_call_target(name, fn)
    c.CustomCall(
        b"test_subtract_f32",
        operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)),
        shape_with_layout=xla_client.Shape.array_shape(np.float32, (), ()),
        operand_shapes_with_layout=(
            xla_client.Shape.array_shape(np.float32, (), ()),
            xla_client.Shape.array_shape(np.float32, (), ()),
        ))
    self._ExecuteAndCompareClose(c, expected=0.75)
328
329
class ParametersTest(ComputationTest):
  """Exercises Parameter ops and how call arguments are bound to them."""

  def setUp(self):
    # One scalar and one 4-element vector fixture per element type.
    float_values = [-2.3, 3.3, -4.3, 5.3]
    int_values = [10, 15, -2, 7]
    self.f32_scalar_2, self.f32_4vector = (NumpyArrayF32(2.0),
                                           NumpyArrayF32(float_values))
    self.f64_scalar_2, self.f64_4vector = (NumpyArrayF64(2.0),
                                           NumpyArrayF64(float_values))
    self.s32_scalar_3, self.s32_4vector = (NumpyArrayS32(3),
                                           NumpyArrayS32(int_values))
    self.s64_scalar_3, self.s64_4vector = (NumpyArrayS64(3),
                                           NumpyArrayS64(int_values))

  def _ScalarTimesVector(self, scalar, vector):
    """Builds `scalar * vector` from two auto-numbered parameters."""
    builder = self._NewComputation()
    builder.Mul(builder.ParameterFromNumpy(scalar),
                builder.ParameterFromNumpy(vector))
    return builder

  def _VectorMinusScalarExplicit(self, scalar, vector):
    """Builds `vector - scalar`, declaring parameter 1 before parameter 0.

    Sub is not commutative, so any reversal of the parameters inside the
    computation shows up in the result.
    """
    builder = self._NewComputation()
    vector_param = builder.ParameterFromNumpy(vector, parameter_num=1)
    scalar_param = builder.ParameterFromNumpy(scalar, parameter_num=0)
    builder.Sub(vector_param, scalar_param)
    return builder

  def testScalarTimesVectorAutonumberF32(self):
    args = [self.f32_scalar_2, self.f32_4vector]
    self._ExecuteAndCompareClose(
        self._ScalarTimesVector(*args),
        arguments=args,
        expected=[-4.6, 6.6, -8.6, 10.6])

  def testScalarTimesVectorAutonumberF64(self):
    args = [self.f64_scalar_2, self.f64_4vector]
    self._ExecuteAndCompareClose(
        self._ScalarTimesVector(*args),
        arguments=args,
        expected=[-4.6, 6.6, -8.6, 10.6])

  def testScalarTimesVectorS32(self):
    args = [self.s32_scalar_3, self.s32_4vector]
    self._ExecuteAndCompareExact(
        self._ScalarTimesVector(*args),
        arguments=args,
        expected=[30, 45, -6, 21])

  def testScalarTimesVectorS64(self):
    args = [self.s64_scalar_3, self.s64_4vector]
    self._ExecuteAndCompareExact(
        self._ScalarTimesVector(*args),
        arguments=args,
        expected=[30, 45, -6, 21])

  def testScalarMinusVectorExplicitNumberingF32(self):
    args = [self.f32_scalar_2, self.f32_4vector]
    self._ExecuteAndCompareClose(
        self._VectorMinusScalarExplicit(*args),
        arguments=args,
        expected=[-4.3, 1.3, -6.3, 3.3])

  def testScalarMinusVectorExplicitNumberingF64(self):
    args = [self.f64_scalar_2, self.f64_4vector]
    self._ExecuteAndCompareClose(
        self._VectorMinusScalarExplicit(*args),
        arguments=args,
        expected=[-4.3, 1.3, -6.3, 3.3])
408
409
class LocalBufferTest(ComputationTest):
  """Runs computations on LocalBuffer arguments and checks buffer handling."""

  def _Execute(self, c, arguments):
    # Unlike the base class, transfer every argument into a LocalBuffer,
    # execute on those buffers, and convert the result buffer back to Python.
    executable = c.Build().CompileWithExampleArguments(arguments)
    buffers = [xla_client.LocalBuffer.from_pyval(value) for value in arguments]
    return executable.Execute(buffers).to_py()

  def testConstantSum(self):
    builder = self._NewComputation()
    builder.Add(builder.ConstantF32Scalar(1.11),
                builder.ConstantF32Scalar(3.14))
    self._ExecuteAndCompareClose(builder, expected=4.25)

  def testOneParameterSum(self):
    builder = self._NewComputation()
    builder.Add(builder.ParameterFromNumpy(NumpyArrayF32(0.)),
                builder.ConstantF32Scalar(3.14))
    self._ExecuteAndCompareClose(
        builder, arguments=[NumpyArrayF32(1.11)], expected=4.25)

  def testTwoParameterSum(self):
    builder = self._NewComputation()
    params = [builder.ParameterFromNumpy(NumpyArrayF32(0.)) for _ in range(2)]
    builder.Add(*params)
    self._ExecuteAndCompareClose(
        builder,
        arguments=[NumpyArrayF32(1.11), NumpyArrayF32(3.14)],
        expected=4.25)

  def testCannotCallWithDeletedBuffers(self):
    builder = self._NewComputation()
    builder.Add(builder.ParameterFromNumpy(NumpyArrayF32(0.)),
                builder.ConstantF32Scalar(3.14))
    value = NumpyArrayF32(1.11)
    executable = builder.Build().CompileWithExampleArguments([value])
    stale_buffer = xla_client.LocalBuffer.from_pyval(value)
    stale_buffer.delete()
    # Executing with an already-deleted buffer must raise ValueError.
    with self.assertRaises(ValueError):
      executable.Execute([stale_buffer])

  def testDestructureTupleEmpty(self):
    local_buffer = xla_client.LocalBuffer.from_pyval(())
    pieces = local_buffer.destructure()
    # After destructure() the tuple buffer itself reports as deleted.
    self.assertTrue(local_buffer.is_deleted())
    self.assertEqual(len(pieces), 0)

  def testDestructureTupleOneArrayElement(self):
    local_buffer = xla_client.LocalBuffer.from_pyval(
        (np.array([1, 2, 3, 4], dtype=np.int32),))
    pieces = local_buffer.destructure()
    self.assertTrue(local_buffer.is_deleted())
    self.assertEqual(len(pieces), 1)
    np.testing.assert_equal(NumpyArrayS32([1, 2, 3, 4]), pieces[0].to_py())

  def testDestructureTupleTwoArrayElementDifferentType(self):
    local_buffer = xla_client.LocalBuffer.from_pyval(
        (np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
         np.array([2, 3, 4, 5], dtype=np.int32)))
    pieces = local_buffer.destructure()
    self.assertTrue(local_buffer.is_deleted())
    self.assertEqual(len(pieces), 2)
    wanted = (NumpyArrayF32([1.0, 2.0, 3.0, 4.0]), NumpyArrayS32([2, 3, 4, 5]))
    for piece, want in zip(pieces, wanted):
      np.testing.assert_equal(want, piece.to_py())

  def testDestructureTupleNested(self):
    nested = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])),
              NumpyArrayS32([5]))
    local_buffer = xla_client.LocalBuffer.from_pyval(nested)
    pieces = local_buffer.destructure()
    self.assertTrue(local_buffer.is_deleted())
    self.assertEqual(len(pieces), 2)
    inner_tuple, trailing_array = pieces
    np.testing.assert_equal(NumpyArrayS32([5]), trailing_array.to_py())
    # Only the outer level is split: the first piece still holds a tuple.
    inner = inner_tuple.to_py()
    self.assertIs(type(inner), tuple)
    self.assertEqual(len(inner), 2)
    np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), inner[0])
    np.testing.assert_equal(NumpyArrayS32([3, 4]), inner[1])

  def testShape(self):
    local_buffer = xla_client.LocalBuffer.from_pyval(
        np.array([[1., 2.]], np.float32))
    shape = local_buffer.shape()
    self.assertEqual(shape.dimensions(), (1, 2))
    self.assertEqual(np.dtype(shape.element_type()), np.dtype(np.float32))
506
507
508class SingleOpTest(ComputationTest):
509  """Tests for single ops.
510
511  The goal here is smoke testing - to exercise the most basic functionality of
512  single XLA ops. As minimal as possible number of additional ops are added
513  around the op being tested.
514  """
515
516  def testConcatenateF32(self):
517    c = self._NewComputation()
518    c.Concatenate(
519        (c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])),
520         c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))),
521        dimension=0)
522    self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
523
524  def testConcatenateF64(self):
525    c = self._NewComputation()
526    c.Concatenate(
527        (c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])),
528         c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))),
529        dimension=0)
530    self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
531
532  def testConvertElementType(self):
533    xla_types = {
534        np.bool: xla_client.PrimitiveType.PRED,
535        np.int32: xla_client.PrimitiveType.S32,
536        np.int64: xla_client.PrimitiveType.S64,
537        np.float32: xla_client.PrimitiveType.F32,
538        np.float64: xla_client.PrimitiveType.F64,
539    }
540
541    def _ConvertAndTest(template, src_dtype, dst_dtype):
542      c = self._NewComputation()
543      x = c.Constant(np.array(template, dtype=src_dtype))
544      c.ConvertElementType(x, xla_types[dst_dtype])
545
546      result = c.Build().Compile().ExecuteWithPythonValues()
547      expected = np.array(template, dtype=dst_dtype)
548
549      self.assertEqual(result.shape, expected.shape)
550      self.assertEqual(result.dtype, expected.dtype)
551      np.testing.assert_equal(result, expected)
552
553    x = [0, 1, 0, 0, 1]
554    for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
555      _ConvertAndTest(x, src_dtype, dst_dtype)
556
557  def testBitcastConvertType(self):
558    xla_x32_types = {
559        np.int32: xla_client.PrimitiveType.S32,
560        np.float32: xla_client.PrimitiveType.F32,
561    }
562
563    xla_x64_types = {
564        np.int64: xla_client.PrimitiveType.S64,
565        np.float64: xla_client.PrimitiveType.F64,
566    }
567
568    def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype):
569      c = self._NewComputation()
570      x = c.Constant(np.array(template, dtype=src_dtype))
571      c.BitcastConvertType(x, dst_etype)
572
573      result = c.Build().Compile().ExecuteWithPythonValues()
574      expected = np.array(template, src_dtype).view(dst_dtype)
575
576      self.assertEqual(result.shape, expected.shape)
577      self.assertEqual(result.dtype, expected.dtype)
578      np.testing.assert_equal(result, expected)
579
580    x = [0, 1, 0, 0, 1]
581    for xla_types in [xla_x32_types, xla_x64_types]:
582      for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
583        _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype])
584
585  # TODO(b/123523486) implement AllToAll on CPU
586  def DISABLED_testAllToAllOneReplica(self):
587    samples = [
588        NumpyArrayF32([97.0]),
589        NumpyArrayF32([64.0, 117.0]),
590        NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
591    ]
592    for lhs in samples[:1]:
593      c = self._NewComputation()
594      c.AllToAll(c.Constant(lhs), 0, 0)
595      self._ExecuteAndCompareExact(c, expected=lhs)
596
597  def testCrossReplicaSumOneReplica(self):
598    samples = [
599        NumpyArrayF32(42.0),
600        NumpyArrayF32([97.0]),
601        NumpyArrayF32([64.0, 117.0]),
602        NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
603    ]
604    for lhs in samples:
605      c = self._NewComputation()
606      c.CrossReplicaSum(c.Constant(lhs))
607      self._ExecuteAndCompareExact(c, expected=lhs)
608
609  def testReplicaId(self):
610    c = self._NewComputation()
611    _ = c.ReplicaId()
612    self._ExecuteAndCompareExact(c, expected=0)
613
614  def testCrossReplicaSumOneReplicaWithSingletonGroup(self):
615    samples = [
616        NumpyArrayF32(42.0),
617        NumpyArrayF32([97.0]),
618        NumpyArrayF32([64.0, 117.0]),
619        NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
620    ]
621    for lhs in samples:
622      c = self._NewComputation()
623      c.CrossReplicaSum(c.Constant(lhs), [[0]])
624      self._ExecuteAndCompareExact(c, expected=lhs)
625
626  def testDotMatrixVectorF32(self):
627    c = self._NewComputation()
628    lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
629    rhs = NumpyArrayF32([[10.0], [20.0]])
630    c.Dot(c.Constant(lhs), c.Constant(rhs))
631    self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
632
633  def testDotMatrixVectorF64(self):
634    c = self._NewComputation()
635    lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
636    rhs = NumpyArrayF64([[10.0], [20.0]])
637    c.Dot(c.Constant(lhs), c.Constant(rhs))
638    self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
639
640  def testDotMatrixMatrixF32(self):
641    c = self._NewComputation()
642    lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
643    rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]])
644    c.Dot(c.Constant(lhs), c.Constant(rhs))
645    self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
646
647  def testDotMatrixMatrixF64(self):
648    c = self._NewComputation()
649    lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
650    rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]])
651    c.Dot(c.Constant(lhs), c.Constant(rhs))
652    self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
653
654  def testDotGeneral(self):
655    c = self._NewComputation()
656    rng = np.random.RandomState(0)
657    lhs = NumpyArrayF32(rng.randn(10, 3, 4))
658    rhs = NumpyArrayF32(rng.randn(10, 4, 5))
659    dimension_numbers = (([2], [1]), ([0], [0]))
660    c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
661    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
662
663  def testDotGeneralWithDotDimensionNumbersProto(self):
664    c = self._NewComputation()
665    rng = np.random.RandomState(0)
666    lhs = NumpyArrayF32(rng.randn(10, 3, 4))
667    rhs = NumpyArrayF32(rng.randn(10, 4, 5))
668
669    dimension_numbers = xla_client.DotDimensionNumbers()
670    dimension_numbers.lhs_contracting_dimensions.append(2)
671    dimension_numbers.rhs_contracting_dimensions.append(1)
672    dimension_numbers.lhs_batch_dimensions.append(0)
673    dimension_numbers.rhs_batch_dimensions.append(0)
674
675    c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
676    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
677
678  def testConvF32Same(self):
679    c = self._NewComputation()
680    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
681    lhs = a(1, 2, 3, 4)
682    rhs = a(1, 2, 1, 2) * 10
683    c.Conv(c.Constant(lhs), c.Constant(rhs),
684           [1, 1], xla_client.PaddingType.SAME)
685    result = np.array([[[[640., 700., 760., 300.],
686                         [880., 940., 1000., 380.],
687                         [1120., 1180., 1240., 460.]]]])
688    self._ExecuteAndCompareClose(c, expected=result)
689
690  def testConvF32Valid(self):
691    c = self._NewComputation()
692    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
693    lhs = a(1, 2, 3, 4)
694    rhs = a(1, 2, 1, 2) * 10
695    c.Conv(c.Constant(lhs), c.Constant(rhs),
696           [2, 1], xla_client.PaddingType.VALID)
697    result = np.array([[[[640., 700., 760.],
698                         [1120., 1180., 1240.]]]])
699    self._ExecuteAndCompareClose(c, expected=result)
700
701  def testConvWithGeneralPaddingF32(self):
702    c = self._NewComputation()
703    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
704    lhs = a(1, 1, 2, 3)
705    rhs = a(1, 1, 1, 2) * 10
706    strides = [1, 1]
707    pads = [(1, 0), (0, 1)]
708    lhs_dilation = (2, 1)
709    rhs_dilation = (1, 1)
710    c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs),
711                             strides, pads, lhs_dilation, rhs_dilation)
712    result = np.array([[[[0., 0., 0.],
713                         [10., 20., 0.],
714                         [0., 0., 0.],
715                         [40., 50., 0.]]]])
716    self._ExecuteAndCompareClose(c, expected=result)
717
718  def testConvGeneralDilatedF32(self):
719    c = self._NewComputation()
720    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
721    lhs = a(1, 1, 2, 3)
722    rhs = a(1, 1, 1, 2) * 10
723    strides = [1, 1]
724    pads = [(1, 0), (0, 1)]
725    lhs_dilation = (2, 1)
726    rhs_dilation = (1, 1)
727    dimension_numbers = ("NCHW", "OIHW", "NCHW")
728    c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
729                         strides, pads, lhs_dilation, rhs_dilation,
730                         dimension_numbers)
731    result = np.array([[[[0., 0., 0.],
732                         [10., 20., 0.],
733                         [0., 0., 0.],
734                         [40., 50., 0.]]]])
735    self._ExecuteAndCompareClose(c, expected=result)
736
737  def testConvGeneralDilatedPermutedF32(self):
738    c = self._NewComputation()
739    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
740    lhs = a(1, 1, 2, 3)
741    rhs = a(1, 1, 1, 2) * 10
742    strides = [1, 1]
743    pads = [(1, 0), (0, 1)]
744    lhs_dilation = (2, 1)
745    rhs_dilation = (1, 1)
746
747    dimension_numbers = ("NHWC", "OIHW", "CWNH")
748    c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))),
749                         c.Constant(rhs),
750                         strides, pads, lhs_dilation, rhs_dilation,
751                         dimension_numbers)
752    result = np.array([[[[0., 0., 0.],
753                         [10., 20., 0.],
754                         [0., 0., 0.],
755                         [40., 50., 0.]]]])
756    self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
757
758  def testConvGeneralDilatedGroupedConvolutionF32(self):
759    c = self._NewComputation()
760    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
761    lhs = a(1, 2, 2, 3)
762    rhs = a(2, 1, 1, 2) * 10
763    strides = [1, 1]
764    pads = [(1, 0), (0, 1)]
765    lhs_dilation = (2, 1)
766    rhs_dilation = (1, 1)
767    dimension_numbers = ("NCHW", "OIHW", "NCHW")
768    feature_group_count = 2
769    c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
770                         strides, pads, lhs_dilation, rhs_dilation,
771                         dimension_numbers, feature_group_count)
772    result = np.array([[[[0., 0., 0.],
773                         [10., 20., 0.],
774                         [0., 0., 0.],
775                         [40., 50., 0.]],
776                        [[0., 0., 0.],
777                         [330., 380., 160.],
778                         [0., 0., 0.],
779                         [480., 530., 220.]]]])
780    self._ExecuteAndCompareClose(c, expected=result)
781
782  def testBooleanNot(self):
783    c = self._NewComputation()
784    arr = NumpyArrayBool([True, False, True])
785    c.Not(c.Constant(arr))
786    self._ExecuteAndCompareClose(c, expected=~arr)
787
788  def testCountLeadingZeros(self):
789    c = self._NewComputation()
790    arr = NumpyArrayS32([0x7FFF, 0x12345678])
791    c.Clz(c.Constant(arr))
792    self._ExecuteAndCompareClose(c, expected=[17, 3])
793
794  def testExp(self):
795    c = self._NewComputation()
796    arr = NumpyArrayF32([3.3, 12.1])
797    c.Exp(c.Constant(arr))
798    self._ExecuteAndCompareClose(c, expected=np.exp(arr))
799
800  def testExpm1(self):
801    c = self._NewComputation()
802    arr = NumpyArrayF32([3.3, 12.1])
803    c.Expm1(c.Constant(arr))
804    self._ExecuteAndCompareClose(c, expected=np.expm1(arr))
805
806  def testRound(self):
807    c = self._NewComputation()
808    arr = NumpyArrayF32([3.3, 12.1])
809    c.Round(c.Constant(arr))
810    self._ExecuteAndCompareClose(c, expected=np.round(arr))
811
812  def testLog(self):
813    c = self._NewComputation()
814    arr = NumpyArrayF32([3.3, 12.1])
815    c.Log(c.Constant(arr))
816    self._ExecuteAndCompareClose(c, expected=np.log(arr))
817
818  def testLog1p(self):
819    c = self._NewComputation()
820    arr = NumpyArrayF32([3.3, 12.1])
821    c.Log1p(c.Constant(arr))
822    self._ExecuteAndCompareClose(c, expected=np.log1p(arr))
823
824  def testNeg(self):
825    c = self._NewComputation()
826    arr = NumpyArrayF32([3.3, 12.1])
827    c.Neg(c.Constant(arr))
828    self._ExecuteAndCompareClose(c, expected=-arr)
829
830  def testFloor(self):
831    c = self._NewComputation()
832    arr = NumpyArrayF32([3.3, 12.1])
833    c.Floor(c.Constant(arr))
834    self._ExecuteAndCompareClose(c, expected=np.floor(arr))
835
836  def testCeil(self):
837    c = self._NewComputation()
838    arr = NumpyArrayF32([3.3, 12.1])
839    c.Ceil(c.Constant(arr))
840    self._ExecuteAndCompareClose(c, expected=np.ceil(arr))
841
842  def testAbs(self):
843    c = self._NewComputation()
844    arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
845    c.Abs(c.Constant(arr))
846    self._ExecuteAndCompareClose(c, expected=np.abs(arr))
847
848  def testTanh(self):
849    c = self._NewComputation()
850    arr = NumpyArrayF32([3.3, 12.1])
851    c.Tanh(c.Constant(arr))
852    self._ExecuteAndCompareClose(c, expected=np.tanh(arr))
853
854  def testTrans(self):
855
856    def _TransposeAndTest(array):
857      c = self._NewComputation()
858      c.Trans(c.Constant(array))
859      self._ExecuteAndCompareClose(c, expected=array.T)
860
861    # Test square and non-square matrices in both default (C) and F orders.
862    for array_fun in [NumpyArrayF32, NumpyArrayF64]:
863      _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]]))
864      _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F"))
865      _TransposeAndTest(array_fun([[1, 2], [4, 5]]))
866      _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F"))
867
868  def testTranspose(self):
869
870    def _TransposeAndTest(array, permutation):
871      c = self._NewComputation()
872      c.Transpose(c.Constant(array), permutation)
873      expected = np.transpose(array, permutation)
874      self._ExecuteAndCompareClose(c, expected=expected)
875
876    _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
877    _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
878    _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
879    _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])
880
881    arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
882    for permutation in itertools.permutations(range(arr.ndim)):
883      _TransposeAndTest(arr, permutation)
884      _TransposeAndTest(np.asfortranarray(arr), permutation)
885
886  def testEq(self):
887    c = self._NewComputation()
888    c.Eq(
889        c.Constant(NumpyArrayS32([1, 2, 3, 4])),
890        c.Constant(NumpyArrayS32([4, 2, 3, 1])))
891    self._ExecuteAndCompareExact(c, expected=[False, True, True, False])
892
893  def testNe(self):
894    c = self._NewComputation()
895    c.Ne(
896        c.Constant(NumpyArrayS32([1, 2, 3, 4])),
897        c.Constant(NumpyArrayS32([4, 2, 3, 1])))
898    self._ExecuteAndCompareExact(c, expected=[True, False, False, True])
899
900    c.Ne(
901        c.Constant(NumpyArrayF32([-2.0, 0.0,
902                                  float("nan"),
903                                  float("nan")])),
904        c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")])))
905    self._ExecuteAndAssertWith(
906        np.testing.assert_allclose, c, (), expected=[True, False, True, True])
907
908  def testGt(self):
909    c = self._NewComputation()
910    c.Gt(
911        c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
912        c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
913    self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False])
914
915  def testGe(self):
916    c = self._NewComputation()
917    c.Ge(
918        c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
919        c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
920    self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False])
921
922  def testLt(self):
923    c = self._NewComputation()
924    c.Lt(
925        c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
926        c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
927    self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True])
928
929  def testLe(self):
930    c = self._NewComputation()
931    c.Le(
932        c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
933        c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
934    self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True])
935
936  def testMax(self):
937    c = self._NewComputation()
938    c.Max(
939        c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
940        c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
941    self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0])
942
943  def testMaxExplicitBroadcastDim0(self):
944    c = self._NewComputation()
945    c.Max(
946        c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
947        c.Constant(NumpyArrayF32([3, 4, 5])),
948        broadcast_dimensions=(0,))
949    self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]])
950
951  def testMaxExplicitBroadcastDim1(self):
952    c = self._NewComputation()
953    c.Max(
954        c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
955        c.Constant(NumpyArrayF32([3, 4, 5])),
956        broadcast_dimensions=(1,))
957    self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]])
958
959  def testMin(self):
960    c = self._NewComputation()
961    c.Min(
962        c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
963        c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
964    self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0])
965
966  def testPad(self):
967    c = self._NewComputation()
968    c.Pad(
969        c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
970        c.Constant(NumpyArrayF32(0.0)),
971        [(1, 2, 1), (0, 1, 0)])
972    self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0],
973                                              [1.0, 2.0, 0.0],
974                                              [0.0, 0.0, 0.0],
975                                              [3.0, 4.0, 0.0],
976                                              [0.0, 0.0, 0.0],
977                                              [0.0, 0.0, 0.0]])
978
979  def testPadWithPaddingConfig(self):
980    c = self._NewComputation()
981    padding_config = xla_client.PaddingConfig()
982    for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]:
983      dimension = xla_client.PaddingConfigDimension()
984      dimension.edge_padding_low = lo
985      dimension.edge_padding_high = hi
986      dimension.interior_padding = interior
987      padding_config.dimensions.append(dimension)
988    c.Pad(
989        c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
990        c.Constant(NumpyArrayF32(0.0)),
991        padding_config)
992    self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0],
993                                              [1.0, 2.0, 0.0],
994                                              [0.0, 0.0, 0.0],
995                                              [3.0, 4.0, 0.0],
996                                              [0.0, 0.0, 0.0],
997                                              [0.0, 0.0, 0.0]])
998
999  def testReshape(self):
1000    c = self._NewComputation()
1001    c.Reshape(
1002        c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
1003        dimensions=[0, 1],
1004        new_sizes=[2, 3])
1005    self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]])
1006
1007  def testCollapse(self):
1008    c = self._NewComputation()
1009    c.Collapse(
1010        c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
1011        dimensions=[1, 2])
1012    self._ExecuteAndCompareExact(c, expected=[[1, 2, 3, 4], [5, 6, 7, 8]])
1013
1014  def testRev(self):
1015    c = self._NewComputation()
1016    c.Rev(
1017        c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
1018        dimensions=[0, 2])
1019    self._ExecuteAndCompareExact(
1020        c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]])
1021
1022  def testClampF32(self):
1023    c = self._NewComputation()
1024    c.Clamp(
1025        c.Constant(NumpyArrayF32(-1)),
1026        c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
1027        c.Constant(NumpyArrayF32(2)))
1028    self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])
1029
1030  def testClampS32(self):
1031    c = self._NewComputation()
1032    c.Clamp(
1033        c.Constant(NumpyArrayS32(-1)),
1034        c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
1035        c.Constant(NumpyArrayS32(2)))
1036    self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])
1037
1038  def testSelect(self):
1039    c = self._NewComputation()
1040    c.Select(
1041        c.Constant(NumpyArrayBool([True, False, False, True, False])),
1042        c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])),
1043        c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5])))
1044    self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5])
1045
1046  def testSlice(self):
1047    c = self._NewComputation()
1048    c.Slice(
1049        c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0],
1050        [3, 2])
1051    self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
1052
1053  def testSliceInDim(self):
1054    c = self._NewComputation()
1055    c.SliceInDim(
1056        c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1057        start_index=1,
1058        limit_index=2,
1059        stride=1,
1060        dimno=1)
1061    self._ExecuteAndCompareExact(c, expected=[[2], [5], [8]])
1062    c.SliceInDim(
1063        c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1064        start_index=0,
1065        limit_index=3,
1066        stride=2,
1067        dimno=0)
1068    self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [7, 8, 9]])
1069
1070  def testDynamicSlice(self):
1071    c = self._NewComputation()
1072    c.DynamicSlice(
1073        c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1074        c.Constant(NumpyArrayS32([1, 0])), [2, 2])
1075    self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
1076
1077  def testDynamicUpdateSlice(self):
1078    c = self._NewComputation()
1079    c.DynamicUpdateSlice(
1080        c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1081        c.Constant(NumpyArrayS32([[1, 2], [3, 4]])),
1082        c.Constant(NumpyArrayS32([1, 1])))
1083    self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]])
1084
1085  def testTuple(self):
1086    c = self._NewComputation()
1087    c.Tuple(
1088        c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
1089        c.Constant(NumpyArrayBool([True, False, False, True])))
1090    result = c.Build().Compile().ExecuteWithPythonValues()
1091    self.assertIsInstance(result, tuple)
1092    np.testing.assert_equal(result[0], 42)
1093    np.testing.assert_allclose(result[1], [1.0, 2.0])
1094    np.testing.assert_equal(result[2], [True, False, False, True])
1095
1096  def testGetTupleElement(self):
1097    c = self._NewComputation()
1098    c.GetTupleElement(
1099        c.Tuple(
1100            c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
1101            c.Constant(NumpyArrayBool([True, False, False, True]))), 1)
1102    self._ExecuteAndCompareClose(c, expected=[1.0, 2.0])
1103
1104  def testBroadcast(self):
1105    c = self._NewComputation()
1106    c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
1107    self._ExecuteAndCompareExact(
1108        c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]])
1109
1110  def testBroadcastInDim(self):
1111    c = self._NewComputation()
1112    c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [0])
1113    self._ExecuteAndCompareExact(c, expected=[[1, 1], [2, 2]])
1114    c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [1])
1115    self._ExecuteAndCompareExact(c, expected=[[1, 2], [1, 2]])
1116
1117  def testRngNormal(self):
1118    shape = (2, 3)
1119    c = self._NewComputation()
1120    c.RngNormal(c.Constant(NumpyArrayF32(0.)), c.Constant(NumpyArrayF32(1.)),
1121                dims=shape)
1122    result = c.Build().Compile().ExecuteWithPythonValues()
1123    # since the result is random, we just check shape and uniqueness
1124    self.assertEqual(result.shape, shape)
1125    self.assertEqual(len(np.unique(result)), np.prod(shape))
1126
1127  def testRngUniformF32(self):
1128    lo, hi = 2., 4.
1129    shape = (2, 3)
1130    c = self._NewComputation()
1131    c.RngUniform(c.Constant(NumpyArrayF32(lo)), c.Constant(NumpyArrayF32(hi)),
1132                 dims=shape)
1133    result = c.Build().Compile().ExecuteWithPythonValues()
1134    # since the result is random, we just check shape, uniqueness, and range
1135    self.assertEqual(result.shape, shape)
1136    self.assertEqual(len(np.unique(result)), np.prod(shape))
1137    self.assertTrue(np.all(lo <= result))
1138    self.assertTrue(np.all(result < hi))
1139
1140  def testRngUniformS32(self):
1141    lo, hi = 2, 4
1142    shape = (2, 3)
1143    c = self._NewComputation()
1144    c.RngUniform(c.Constant(NumpyArrayS32(lo)), c.Constant(NumpyArrayS32(hi)),
1145                 dims=shape)
1146    result = c.Build().Compile().ExecuteWithPythonValues()
1147    # since the result is random, we just check shape, integrality, and range
1148    self.assertEqual(result.shape, shape)
1149    self.assertEqual(result.dtype, np.int32)
1150    self.assertTrue(np.all(lo <= result))
1151    self.assertTrue(np.all(result < hi))
1152
1153  def testCholesky(self):
1154    l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]],
1155                 dtype=np.float32)
1156    c = self._NewComputation()
1157    c.Cholesky(c.Constant(np.dot(l, l.T)))
1158    self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4)
1159
1160  def testQR(self):
1161    a = np.array(
1162        [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]],
1163        dtype=np.float32)
1164    c = self._NewComputation()
1165    c.QR(c.Constant(a), full_matrices=True)
1166    q, r = self._Execute(c, ())
1167    np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4)
1168
1169  def testEigh(self):
1170    a = np.array(
1171        [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]],
1172        dtype=np.float32)
1173    a = (a + a.T) / 2
1174
1175    c = self._NewComputation()
1176    c.Eigh(c.Constant(a), full_matrices=True)
1177    v, w = self._Execute(c, ())
1178    self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3)
1179
1180  def testSVD(self):
1181    a = np.array(
1182        [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]],
1183        dtype=np.float32)
1184    c = self._NewComputation()
1185    c.SVD(c.Constant(a))
1186    u, d, v = self._Execute(c, ())
1187    self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3)
1188
1189  def testTriangularSolve(self):
1190    a_vals = np.array(
1191        [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]],
1192        dtype=np.float32)
1193    b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
1194                      dtype=np.float32)
1195
1196    c = self._NewComputation()
1197    c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False,
1198                      lower=True, transpose_a=True)
1199    self._ExecuteAndCompareClose(c, expected=np.array([
1200        [0.5, 0.08333334, 0.04629629, 0.03367003],
1201        [2.5, -0.25, -0.1388889, -0.1010101],
1202        [4.5, -0.58333331, -0.32407406, -0.23569024],
1203    ], dtype=np.float32), rtol=1e-4)
1204
1205  def testIsConstant(self):
1206    c = self._NewComputation()
1207    a = c.ConstantS32Scalar(3)
1208    b = c.ConstantS32Scalar(1)
1209    x = c.ParameterFromNumpy(NumpyArrayS32(0))
1210    const_expr = c.Sub(b, a)
1211    non_const_expr = c.Mul(const_expr, x)
1212    self.assertTrue(c.IsConstant(const_expr))
1213    self.assertFalse(c.IsConstant(non_const_expr))
1214    # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x)))  # TODO(b/77245564)
1215
1216  def testGather(self):
1217    a = np.arange(9).astype(np.int32).reshape((3, 3))
1218    indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32)
1219    dnums = xla_client.GatherDimensionNumbers()
1220    dnums.offset_dims.append(1)
1221    dnums.offset_dims.append(2)
1222    dnums.start_index_map.append(0)
1223    dnums.start_index_map.append(1)
1224    dnums.index_vector_dim = 2
1225    c = self._NewComputation()
1226    c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1])
1227    g = self._Execute(c, ())
1228    expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32)
1229    np.testing.assert_allclose(g, expected, rtol=1e-4)
1230
1231
1232class EmbeddedComputationsTest(ComputationTest):
1233  """Tests for XLA graphs with embedded computations (such as maps)."""
1234
1235  def _CreateConstantS32Computation(self):
1236    """Computation (f32) -> s32 that returns a constant 1 for any input."""
1237    c = self._NewComputation("constant_s32_one")
1238    # TODO(eliben): consider adding a nicer way to create new parameters without
1239    # having to create dummy Numpy arrays or populating Shape messages. Perhaps
1240    # we need our own (Python-client-own) way to represent Shapes conveniently.
1241    c.ParameterFromNumpy(NumpyArrayF32(0))
1242    c.ConstantS32Scalar(1)
1243    return c.Build()
1244
1245  def _CreateConstantS64Computation(self):
1246    """Computation (f64) -> s64 that returns a constant 1 for any input."""
1247    c = self._NewComputation("constant_s64_one")
1248    # TODO(eliben): consider adding a nicer way to create new parameters without
1249    # having to create dummy Numpy arrays or populating Shape messages. Perhaps
1250    # we need our own (Python-client-own) way to represent Shapes conveniently.
1251    c.ParameterFromNumpy(NumpyArrayF64(0))
1252    c.ConstantS64Scalar(1)
1253    return c.Build()
1254
1255  def _CreateConstantF32Computation(self):
1256    """Computation (f32) -> f32 that returns a constant 1.0 for any input."""
1257    c = self._NewComputation("constant_f32_one")
1258    c.ParameterFromNumpy(NumpyArrayF32(0))
1259    c.ConstantF32Scalar(1.0)
1260    return c.Build()
1261
1262  def _CreateConstantF64Computation(self):
1263    """Computation (f64) -> f64 that returns a constant 1.0 for any input."""
1264    c = self._NewComputation("constant_f64_one")
1265    c.ParameterFromNumpy(NumpyArrayF64(0))
1266    c.ConstantF64Scalar(1.0)
1267    return c.Build()
1268
1269  def _CreateMulF32By2Computation(self):
1270    """Computation (f32) -> f32 that multiplies its parameter by 2."""
1271    c = self._NewComputation("mul_f32_by2")
1272    c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0))
1273    return c.Build()
1274
1275  def _CreateMulF32ByParamComputation(self):
1276    """Computation (f32) -> f32 that multiplies one parameter by the other."""
1277    c = self._NewComputation("mul_f32_by_param")
1278    c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)),
1279          c.ParameterFromNumpy(NumpyArrayF32(0)))
1280    return c.Build()
1281
1282  def _CreateMulF64By2Computation(self):
1283    """Computation (f64) -> f64 that multiplies its parameter by 2."""
1284    c = self._NewComputation("mul_f64_by2")
1285    c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0))
1286    return c.Build()
1287
1288  def _CreateBinaryAddS32Computation(self):
1289    """Computation (s32, s32) -> s32 that adds its two parameters."""
1290    c = self._NewComputation("add_param0_by_param1")
1291    c.Add(
1292        c.ParameterFromNumpy(NumpyArrayS32(0)),
1293        c.ParameterFromNumpy(NumpyArrayS32(0)))
1294    return c.Build()
1295
1296  def _CreateBinaryAddF32Computation(self):
1297    """Computation (f32, f32) -> f32 that adds its two parameters."""
1298    c = self._NewComputation("add_param0_by_param1")
1299    c.Add(
1300        c.ParameterFromNumpy(NumpyArrayF32(0)),
1301        c.ParameterFromNumpy(NumpyArrayF32(0)))
1302    return c.Build()
1303
1304  def _CreateBinaryAddF64Computation(self):
1305    """Computation (f64, f64) -> f64 that adds its two parameters."""
1306    c = self._NewComputation("add_param0_by_param1")
1307    c.Add(
1308        c.ParameterFromNumpy(NumpyArrayF64(0)),
1309        c.ParameterFromNumpy(NumpyArrayF64(0)))
1310    return c.Build()
1311
1312  def _CreateBinaryDivF32Computation(self):
1313    """Computation (f32, f32) -> f32 that divides its two parameters."""
1314    c = self._NewComputation("div_param0_by_param1")
1315    c.Div(
1316        c.ParameterFromNumpy(NumpyArrayF32(0)),
1317        c.ParameterFromNumpy(NumpyArrayF32(0)))
1318    return c.Build()
1319
1320  def _CreateBinaryDivF64Computation(self):
1321    """Computation (f64, f64) -> f64 that divides its two parameters."""
1322    c = self._NewComputation("div_param0_by_param1")
1323    c.Div(
1324        c.ParameterFromNumpy(NumpyArrayF64(0)),
1325        c.ParameterFromNumpy(NumpyArrayF64(0)))
1326    return c.Build()
1327
1328  def _CreateTestF32Lt10Computation(self):
1329    """Computation (f32) -> bool that tests if its parameter is less than 10."""
1330    c = self._NewComputation("test_f32_lt_10")
1331    c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.))
1332    return c.Build()
1333
1334  def _CreateTestF64Lt10Computation(self):
1335    """Computation (f64) -> bool that tests if its parameter is less than 10."""
1336    c = self._NewComputation("test_f64_lt_10")
1337    c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.))
1338    return c.Build()
1339
1340  def _CreateBinaryGeF32Computation(self):
1341    """Computation (f32, f32) -> bool that tests first_param >= second_param."""
1342    c = self._NewComputation("param0_lt_param1")
1343    c.Ge(c.ParameterFromNumpy(NumpyArrayF32(0)),
1344         c.ParameterFromNumpy(NumpyArrayF32(0)))
1345    return c.Build()
1346
1347  def _CreateBinaryGeF64Computation(self):
1348    """Computation (f64, f64) -> bool that tests first_param >= second_param."""
1349    c = self._NewComputation("param0_lt_param1")
1350    c.Ge(c.ParameterFromNumpy(NumpyArrayF64(0)),
1351         c.ParameterFromNumpy(NumpyArrayF64(0)))
1352    return c.Build()
1353
1354  def _MakeSample3DArrayF32(self):
1355    return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
1356                          [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
1357
1358  def _MakeSample3DArrayF64(self):
1359    return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
1360                          [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
1361
1362  def testCallF32(self):
1363    c = self._NewComputation()
1364    c.Call(
1365        self._CreateMulF32By2Computation(),
1366        operands=(c.ConstantF32Scalar(5.0),))
1367    self._ExecuteAndCompareClose(c, expected=10.0)
1368
1369  def testCallF64(self):
1370    c = self._NewComputation()
1371    c.Call(
1372        self._CreateMulF64By2Computation(),
1373        operands=(c.ConstantF64Scalar(5.0),))
1374    self._ExecuteAndCompareClose(c, expected=10.0)
1375
1376  def testMapEachElementToS32Constant(self):
1377    c = self._NewComputation()
1378    c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
1379          self._CreateConstantS32Computation(), [0])
1380    self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
1381
1382  def testMapEachElementToS64Constant(self):
1383    c = self._NewComputation()
1384    c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
1385          self._CreateConstantS64Computation(), [0])
1386    self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
1387
1388  def testMapMulBy2F32(self):
1389    c = self._NewComputation()
1390    c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
1391          self._CreateMulF32By2Computation(), [0])
1392    self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
1393
1394  def testMapMulBy2F64(self):
1395    c = self._NewComputation()
1396    c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
1397          self._CreateMulF64By2Computation(), [0])
1398    self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
1399
1400  def testSimpleMapChainF32(self):
1401    # Chains a map of constant-f32 with a map of mul-by-2
1402    c = self._NewComputation()
1403    const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
1404                      self._CreateConstantF32Computation(), [0])
1405    c.Map([const_f32], self._CreateMulF32By2Computation(), [0])
1406    self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
1407
1408  def testSimpleMapChainF64(self):
1409    # Chains a map of constant-f64 with a map of mul-by-2
1410    c = self._NewComputation()
1411    const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
1412                      self._CreateConstantF64Computation(), [0])
1413    c.Map([const_f64], self._CreateMulF64By2Computation(), [0])
1414    self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
1415
1416  def testDivVectorsWithMapF32(self):
1417    c = self._NewComputation()
1418    c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
1419           c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))),
1420          self._CreateBinaryDivF32Computation(), [0])
1421    self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
1422
1423  def testDivVectorsWithMapF64(self):
1424    c = self._NewComputation()
1425    c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
1426           c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))),
1427          self._CreateBinaryDivF64Computation(), [0])
1428    self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
1429
1430  def testSelectAndScatterF32(self):
1431    c = self._NewComputation()
1432    c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])),
1433                       select=self._CreateBinaryGeF32Computation(),
1434                       window_dimensions=(2, 1),
1435                       window_strides=(1, 2),
1436                       padding=xla_client.PaddingType.VALID,
1437                       source=c.Constant(NumpyArrayF32([[0.1, 0.2]])),
1438                       init_value=c.Constant(NumpyArrayF32(1)),
1439                       scatter=self._CreateBinaryAddF32Computation())
1440    self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
1441
1442  def testSelectAndScatterF64(self):
1443    c = self._NewComputation()
1444    c.SelectAndScatter(c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])),
1445                       select=self._CreateBinaryGeF64Computation(),
1446                       window_dimensions=(2, 1),
1447                       window_strides=(1, 2),
1448                       padding=xla_client.PaddingType.VALID,
1449                       source=c.Constant(NumpyArrayF64([[0.1, 0.2]])),
1450                       init_value=c.Constant(NumpyArrayF64(1)),
1451                       scatter=self._CreateBinaryAddF64Computation())
1452    self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
1453
1454  def testReduce1DtoScalarF32(self):
1455    c = self._NewComputation()
1456    c.Reduce(
1457        operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
1458        init_value=c.ConstantF32Scalar(0),
1459        computation_to_apply=self._CreateBinaryAddF32Computation(),
1460        dimensions=[0])
1461    self._ExecuteAndCompareClose(c, expected=10)
1462
1463  def testReduce1DtoScalarF64(self):
1464    c = self._NewComputation()
1465    c.Reduce(
1466        operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
1467        init_value=c.ConstantF64Scalar(0),
1468        computation_to_apply=self._CreateBinaryAddF64Computation(),
1469        dimensions=[0])
1470    self._ExecuteAndCompareClose(c, expected=10)
1471
1472  def testReduce2DTo1DDim0F32(self):
1473    input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1474    c = self._NewComputation()
1475    c.Reduce(
1476        operand=c.Constant(input_array),
1477        init_value=c.ConstantF32Scalar(0),
1478        computation_to_apply=self._CreateBinaryAddF32Computation(),
1479        dimensions=[0])
1480    self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
1481
1482  def testReduce2DTo1DDim0F64(self):
1483    input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1484    c = self._NewComputation()
1485    c.Reduce(
1486        operand=c.Constant(input_array),
1487        init_value=c.ConstantF64Scalar(0),
1488        computation_to_apply=self._CreateBinaryAddF64Computation(),
1489        dimensions=[0])
1490    self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
1491
1492  def testReduce2DTo1DDim1F32(self):
1493    input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1494    c = self._NewComputation()
1495    c.Reduce(
1496        operand=c.Constant(input_array),
1497        init_value=c.ConstantF32Scalar(0),
1498        computation_to_apply=self._CreateBinaryAddF32Computation(),
1499        dimensions=[1])
1500    self._ExecuteAndCompareClose(c, expected=[6, 15])
1501
1502  def testReduce2DTo1DDim1F64(self):
1503    input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1504    c = self._NewComputation()
1505    c.Reduce(
1506        operand=c.Constant(input_array),
1507        init_value=c.ConstantF64Scalar(0),
1508        computation_to_apply=self._CreateBinaryAddF64Computation(),
1509        dimensions=[1])
1510    self._ExecuteAndCompareClose(c, expected=[6, 15])
1511
1512  def testReduce3DAllPossibleWaysF32(self):
1513    input_array = self._MakeSample3DArrayF32()
1514
1515    def _ReduceAndTest(*dims):
1516      c = self._NewComputation()
1517      c.Reduce(
1518          operand=c.Constant(input_array),
1519          init_value=c.ConstantF32Scalar(0),
1520          computation_to_apply=self._CreateBinaryAddF32Computation(),
1521          dimensions=dims)
1522      self._ExecuteAndCompareClose(
1523          c, expected=np.sum(input_array, axis=tuple(dims)))
1524
1525    _ReduceAndTest(0)
1526    _ReduceAndTest(0, 1)
1527    _ReduceAndTest(0, 2)
1528    _ReduceAndTest(1, 2)
1529    _ReduceAndTest(0, 1, 2)
1530
1531  def testReduce3DAllPossibleWaysF64(self):
1532    input_array = self._MakeSample3DArrayF64()
1533
1534    def _ReduceAndTest(*dims):
1535      c = self._NewComputation()
1536      c.Reduce(
1537          operand=c.Constant(input_array),
1538          init_value=c.ConstantF64Scalar(0),
1539          computation_to_apply=self._CreateBinaryAddF64Computation(),
1540          dimensions=dims)
1541      self._ExecuteAndCompareClose(
1542          c, expected=np.sum(input_array, axis=tuple(dims)))
1543
1544    _ReduceAndTest(0)
1545    _ReduceAndTest(0)
1546    _ReduceAndTest(0, 1)
1547    _ReduceAndTest(0, 2)
1548    _ReduceAndTest(1, 2)
1549    _ReduceAndTest(0, 1, 2)
1550
1551  def testReduceWindowValidUnitStridesF32(self):
1552    input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1553    c = self._NewComputation()
1554    c.ReduceWindow(operand=c.Constant(input_array),
1555                   init_value=c.ConstantF32Scalar(0),
1556                   computation_to_apply=self._CreateBinaryAddF32Computation(),
1557                   window_dimensions=(2, 1), window_strides=(1, 1),
1558                   padding=xla_client.PaddingType.VALID)
1559    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
1560
1561  def testReduceWindowSameUnitStridesF32(self):
1562    input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1563    c = self._NewComputation()
1564    c.ReduceWindow(operand=c.Constant(input_array),
1565                   init_value=c.ConstantF32Scalar(0),
1566                   computation_to_apply=self._CreateBinaryAddF32Computation(),
1567                   window_dimensions=(2, 1), window_strides=(1, 1),
1568                   padding=xla_client.PaddingType.SAME)
1569    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
1570
1571  def testReduceWindowValidGeneralStridesF32(self):
1572    input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1573    c = self._NewComputation()
1574    c.ReduceWindow(operand=c.Constant(input_array),
1575                   init_value=c.ConstantF32Scalar(0),
1576                   computation_to_apply=self._CreateBinaryAddF32Computation(),
1577                   window_dimensions=(2, 1), window_strides=(1, 2),
1578                   padding=xla_client.PaddingType.VALID)
1579    self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
1580
1581  def testReduceWindowValidUnitStridesF64(self):
1582    input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1583    c = self._NewComputation()
1584    c.ReduceWindow(operand=c.Constant(input_array),
1585                   init_value=c.ConstantF64Scalar(0),
1586                   computation_to_apply=self._CreateBinaryAddF64Computation(),
1587                   window_dimensions=(2, 1), window_strides=(1, 1),
1588                   padding=xla_client.PaddingType.VALID)
1589    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
1590
1591  def testReduceWindowSameUnitStridesF64(self):
1592    input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1593    c = self._NewComputation()
1594    c.ReduceWindow(operand=c.Constant(input_array),
1595                   init_value=c.ConstantF64Scalar(0),
1596                   computation_to_apply=self._CreateBinaryAddF64Computation(),
1597                   window_dimensions=(2, 1), window_strides=(1, 1),
1598                   padding=xla_client.PaddingType.SAME)
1599    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
1600
1601  def testReduceWindowValidGeneralStridesF64(self):
1602    input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1603    c = self._NewComputation()
1604    c.ReduceWindow(operand=c.Constant(input_array),
1605                   init_value=c.ConstantF64Scalar(0),
1606                   computation_to_apply=self._CreateBinaryAddF64Computation(),
1607                   window_dimensions=(2, 1), window_strides=(1, 2),
1608                   padding=xla_client.PaddingType.VALID)
1609    self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
1610
1611  def testWhileF32(self):
1612    cond = self._CreateTestF32Lt10Computation()
1613    body = self._CreateMulF32By2Computation()
1614    c = self._NewComputation()
1615    init = c.ConstantF32Scalar(1.)
1616    c.While(cond, body, init)
1617    self._ExecuteAndCompareClose(c, expected=16.)
1618
1619  def testWhileF64(self):
1620    cond = self._CreateTestF64Lt10Computation()
1621    body = self._CreateMulF64By2Computation()
1622    c = self._NewComputation()
1623    init = c.ConstantF64Scalar(1.)
1624    c.While(cond, body, init)
1625    self._ExecuteAndCompareClose(c, expected=16.)
1626
1627  def testConditionalTrue(self):
1628    c = self._NewComputation()
1629    pred = c.ConstantPredScalar(True)
1630    true_operand = c.ConstantF32Scalar(3.)
1631    true_computation = self._CreateMulF32By2Computation()
1632    false_operand = c.ConstantF32Scalar(2.)
1633    false_computation = self._CreateConstantF32Computation()
1634    c.Conditional(pred, true_operand, true_computation, false_operand,
1635                  false_computation)
1636    self._ExecuteAndCompareClose(c, expected=6.)
1637
1638  def testConditionalFalse(self):
1639    c = self._NewComputation()
1640    pred = c.ConstantPredScalar(False)
1641    true_operand = c.ConstantF32Scalar(3.)
1642    true_computation = self._CreateMulF32By2Computation()
1643    false_operand = c.ConstantF32Scalar(2.)
1644    false_computation = self._CreateConstantF32Computation()
1645    c.Conditional(pred, true_operand, true_computation, false_operand,
1646                  false_computation)
1647    self._ExecuteAndCompareClose(c, expected=1.)
1648
1649  def testInfeedS32Values(self):
1650    to_infeed = NumpyArrayS32([1, 2, 3, 4])
1651    c = self._NewComputation()
1652    c.Infeed(xla_client.Shape.from_pyval(to_infeed[0]))
1653    compiled_c = c.Build().CompileWithExampleArguments()
1654    for item in to_infeed:
1655      xla_client.transfer_to_infeed(item)
1656
1657    for item in to_infeed:
1658      result = compiled_c.ExecuteWithPythonValues()
1659      self.assertEqual(result, item)
1660
1661  def testInfeedThenOutfeedS32(self):
1662    to_round_trip = NumpyArrayS32([1, 2, 3, 4])
1663    c = self._NewComputation()
1664    x = c.Infeed(xla_client.Shape.from_pyval(to_round_trip[0]))
1665    c.Outfeed(x)
1666
1667    compiled_c = c.Build().CompileWithExampleArguments()
1668
1669    for want in to_round_trip:
1670      execution = threading.Thread(target=compiled_c.Execute)
1671      execution.start()
1672      xla_client.transfer_to_infeed(want)
1673      got = xla_client.transfer_from_outfeed(
1674          xla_client.Shape.from_pyval(to_round_trip[0]))
1675      execution.join()
1676      self.assertEqual(want, got)
1677
1678  def testScatter(self):
1679    a = np.arange(9).astype(np.int32).reshape((3, 3))
1680    scatter_indices = np.array([0, 2], dtype=np.int32)
1681    updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32)
1682
1683    dnums = xla_client.ScatterDimensionNumbers()
1684    dnums.update_window_dims.append(1)
1685    dnums.inserted_window_dims.append(0)
1686    dnums.scatter_dims_to_operand_dims.append(0)
1687    dnums.index_vector_dim = 1
1688
1689    c = self._NewComputation()
1690    c.Scatter(c.Constant(a), c.Constant(scatter_indices), c.Constant(updates),
1691              self._CreateBinaryAddS32Computation(), dnums)
1692    expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32)
1693    self._ExecuteAndCompareClose(c, expected=expected)
1694
1695
class ErrorTest(ComputationTest):
  """Tests that misuse of a computation raises a descriptive error."""

  def setUp(self):
    super(ErrorTest, self).setUp()
    self.f32_scalar_2 = NumpyArrayF32(2.0)
    self.s32_scalar_2 = NumpyArrayS32(2)

  def testInvokeWithWrongElementType(self):
    # Declare an s32 parameter with source-location metadata attached, then
    # compile with an f32 example argument. The expected message names both
    # shapes and (presumably via the metadata) this test file.
    c = self._NewComputation()
    c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
    c.ParameterFromNumpy(self.s32_scalar_2)
    c.ClearOpMetadata()
    # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
    # Prefer assertRaisesRegex and fall back to the old name only where the
    # new one is unavailable (Python 2.7's unittest).
    assert_raises_regex = getattr(self, "assertRaisesRegex", None)
    if assert_raises_regex is None:
      assert_raises_regex = self.assertRaisesRegexp
    assert_raises_regex(
        RuntimeError, r"Invalid argument shape.*xla_client_test.py.*"
        r"expected s32\[\], got f32\[\]",
        lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2]))
1711
1712
class ComputationRootTest(ComputationTest):
  """Checks that Build() honours an explicitly requested root operation."""

  def testComputationRootDifferentFromLastOp(self):
    # Add a trailing op after the intended root, then pass the earlier op to
    # Build(). The compiled result must be x + 3.14 (4.14 for x = 1.0), not
    # the value of the most recently added op.
    builder = self._NewComputation()
    param = builder.ParameterFromNumpy(NumpyArrayF32(2.0))
    root = builder.Add(param, builder.ConstantF32Scalar(3.14))
    unused_trailing_op = builder.Add(root, builder.ConstantF32Scalar(1.618))

    arg = NumpyArrayF32(1.0)
    executable = builder.Build(root).CompileWithExampleArguments([arg])
    np.testing.assert_allclose(executable.ExecuteWithPythonValues([arg]), 4.14)
1726
1727
if __name__ == "__main__":
  # Running the file directly discovers and runs every TestCase in this module.
  unittest.main()
1730