1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for the Python extension-based XLA client.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import itertools 23import threading 24 25import numpy as np 26 27from tensorflow.compiler.xla import xla_data_pb2 28from tensorflow.compiler.xla.python import custom_call_for_test 29from tensorflow.compiler.xla.python import xla_client 30import unittest 31 32 33class EnumTest(unittest.TestCase): 34 """Verifies Python enumerations match their protocol buffer equivalents.""" 35 36 def testPrimitiveType(self): 37 for name, value in xla_client.PrimitiveType.__members__.items(): 38 self.assertEqual(value, getattr(xla_data_pb2, name)) 39 40 def testFormat(self): 41 for name, value in xla_client.Format.__members__.items(): 42 self.assertEqual(value, getattr(xla_data_pb2, name)) 43 44 45class ComputationTest(unittest.TestCase): 46 """Base class for running an XLA Computation through the local client.""" 47 48 def _NewComputation(self, name=None): 49 if name is None: 50 name = self.id() 51 return xla_client.ComputationBuilder(name) 52 53 def _Execute(self, c, arguments): 54 compiled_c = c.Build().CompileWithExampleArguments(arguments) 55 return compiled_c.ExecuteWithPythonValues(arguments) 56 57 def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): 58 assert expected is not None 59 result = self._Execute(c, arguments) 60 # Numpy's comparison methods are a bit too lenient by treating inputs as 61 # "array-like", meaning that scalar 4 will be happily compared equal to 62 # [[4]]. We'd like to be more strict so assert shapes as well. 63 self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape) 64 assert_func(result, expected) 65 66 def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): 67 self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) 68 69 def _ExecuteAndCompareClose(self, c, arguments=(), expected=None, rtol=1e-7, 70 atol=0): 71 self._ExecuteAndAssertWith( 72 functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), 73 c, arguments, expected) 74 75 76def NumpyArrayF32(*args, **kwargs): 77 """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" 78 return np.array(*args, dtype=np.float32, **kwargs) 79 80 81def NumpyArrayF64(*args, **kwargs): 82 """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" 83 return np.array(*args, dtype=np.float64, **kwargs) 84 85 86def NumpyArrayS32(*args, **kwargs): 87 """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" 88 return np.array(*args, dtype=np.int32, **kwargs) 89 90 91def NumpyArrayS64(*args, **kwargs): 92 """Convenience wrapper to create Numpy arrays with a np.int64 dtype.""" 93 return np.array(*args, dtype=np.int64, **kwargs) 94 95 96def NumpyArrayBool(*args, **kwargs): 97 """Convenience wrapper to create Numpy arrays with a np.bool dtype.""" 98 return np.array(*args, dtype=np.bool, **kwargs) 99 100 101class ComputationPrinting(unittest.TestCase): 102 103 def ExampleComputation(self): 104 builder = xla_client.ComputationBuilder("acomputation") 105 p0 = builder.ParameterFromNumpy(np.float32(0)) 106 p1 = builder.ParameterFromNumpy(np.zeros((4,), np.float32)) 107 builder.Mul(p0, p1) 108 return builder.Build() 109 110 def testComputationToHloText(self): 111 computation = self.ExampleComputation() 112 hlo_text = computation.GetHloText() 113 self.assertTrue(hlo_text.startswith("HloModule acomputation")) 114 115 def testComputationToHloGraph(self): 116 computation = self.ExampleComputation() 117 hlo_dot_graph = computation.GetHloDotGraph() 118 self.assertTrue(hlo_dot_graph.startswith("digraph ")) 119 120 121class ComputationsWithConstantsTest(ComputationTest): 122 """Tests focusing on Constant ops.""" 123 124 def testConstantScalarSumS8(self): 125 c = self._NewComputation() 126 root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2))) 127 self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) 128 self._ExecuteAndCompareExact(c, expected=np.int8(3)) 129 130 def testConstantScalarSumF32(self): 131 c = self._NewComputation() 132 root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) 133 self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) 134 self._ExecuteAndCompareClose(c, expected=4.25) 135 136 def testConstantScalarSumF64(self): 137 c = self._NewComputation() 138 c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14)) 139 self._ExecuteAndCompareClose(c, expected=4.25) 140 141 def testConstantScalarSumS32(self): 142 c = self._NewComputation() 143 c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2)) 144 self._ExecuteAndCompareClose(c, expected=3) 145 146 def testConstantScalarSumS64(self): 147 c = self._NewComputation() 148 c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2)) 149 self._ExecuteAndCompareClose(c, expected=3) 150 151 def testConstantVectorMulF32(self): 152 c = self._NewComputation() 153 c.Mul( 154 c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])), 155 c.Constant(NumpyArrayF32([-1.2, 2, -2, -3]))) 156 self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) 157 158 def testConstantVectorMulF64(self): 159 c = self._NewComputation() 160 c.Mul( 161 c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])), 162 c.Constant(NumpyArrayF64([-1.2, 2, -2, -3]))) 163 self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) 164 165 def testConstantVectorScalarDivF32(self): 166 c = self._NewComputation() 167 c.Div( 168 c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])), 169 c.ConstantF32Scalar(2.0)) 170 self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) 171 172 def testConstantVectorScalarDivF64(self): 173 c = self._NewComputation() 174 c.Div( 175 c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])), 176 c.ConstantF64Scalar(2.0)) 177 self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) 178 179 def testConstantVectorScalarPowF32(self): 180 c = self._NewComputation() 181 c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.)) 182 self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) 183 184 def testConstantVectorScalarPowF64(self): 185 c = self._NewComputation() 186 c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) 187 self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) 188 189 def testIota(self): 190 c = self._NewComputation() 191 c.Iota(np.float32, 10) 192 self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32)) 193 194 def testBroadcastedIota(self): 195 c = self._NewComputation() 196 c.BroadcastedIota(np.int64, (2, 3), 1) 197 expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64) 198 self._ExecuteAndCompareExact(c, expected=expected) 199 200 def testBooleanAnd(self): 201 c = self._NewComputation() 202 c.And( 203 c.Constant(NumpyArrayBool([True, False, True, False])), 204 c.Constant(NumpyArrayBool([True, True, False, False]))) 205 self._ExecuteAndCompareExact(c, expected=[True, False, False, False]) 206 207 def testBooleanOr(self): 208 c = self._NewComputation() 209 c.Or( 210 c.Constant(NumpyArrayBool([True, False, True, False])), 211 c.Constant(NumpyArrayBool([True, True, False, False]))) 212 self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) 213 214 def testBooleanXor(self): 215 c = self._NewComputation() 216 c.Xor( 217 c.Constant(NumpyArrayBool([True, False, True, False])), 218 c.Constant(NumpyArrayBool([True, True, False, False]))) 219 self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) 220 221 def testSum2DF32(self): 222 c = self._NewComputation() 223 c.Add( 224 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), 225 c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) 226 self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) 227 228 def testShiftLeft(self): 229 c = self._NewComputation() 230 c.ShiftLeft(c.Constant(NumpyArrayS32([3])), 231 c.Constant(NumpyArrayS32([2]))) 232 self._ExecuteAndCompareClose(c, expected=[12]) 233 234 def testShiftRightArithmetic(self): 235 c = self._NewComputation() 236 c.ShiftRightArithmetic(c.Constant(NumpyArrayS32([-2])), 237 c.Constant(NumpyArrayS32([1]))) 238 self._ExecuteAndCompareClose(c, expected=[-1]) 239 240 def testShiftRightLogical(self): 241 c = self._NewComputation() 242 c.ShiftRightLogical(c.Constant(NumpyArrayS32([-1])), 243 c.Constant(NumpyArrayS32([1]))) 244 self._ExecuteAndCompareClose(c, expected=[2**31 - 1]) 245 246 def testSum2DF64(self): 247 c = self._NewComputation() 248 c.Add( 249 c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])), 250 c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]]))) 251 self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) 252 253 def testSum2DWith1DBroadcastDim0F32(self): 254 # sum of a 2D array with a 1D array where the latter is replicated across 255 # dimension 0 to match the former's shape. 256 c = self._NewComputation() 257 c.Add( 258 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 259 c.Constant(NumpyArrayF32([10, 20, 30])), 260 broadcast_dimensions=(0,)) 261 self._ExecuteAndCompareClose( 262 c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) 263 264 def testSum2DWith1DBroadcastDim0F64(self): 265 # sum of a 2D array with a 1D array where the latter is replicated across 266 # dimension 0 to match the former's shape. 267 c = self._NewComputation() 268 c.Add( 269 c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 270 c.Constant(NumpyArrayF64([10, 20, 30])), 271 broadcast_dimensions=(0,)) 272 self._ExecuteAndCompareClose( 273 c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) 274 275 def testSum2DWith1DBroadcastDim1F32(self): 276 # sum of a 2D array with a 1D array where the latter is replicated across 277 # dimension 1 to match the former's shape. 278 c = self._NewComputation() 279 c.Add( 280 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 281 c.Constant(NumpyArrayF32([10, 20, 30])), 282 broadcast_dimensions=(1,)) 283 self._ExecuteAndCompareClose( 284 c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) 285 286 def testSum2DWith1DBroadcastDim1F64(self): 287 # sum of a 2D array with a 1D array where the latter is replicated across 288 # dimension 1 to match the former's shape. 289 c = self._NewComputation() 290 c.Add( 291 c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 292 c.Constant(NumpyArrayF64([10, 20, 30])), 293 broadcast_dimensions=(1,)) 294 self._ExecuteAndCompareClose( 295 c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) 296 297 def testConstantAxpyF32(self): 298 c = self._NewComputation() 299 c.Add( 300 c.Mul( 301 c.ConstantF32Scalar(2), 302 c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))), 303 c.Constant(NumpyArrayF32([100, -100, 200, -200]))) 304 self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) 305 306 def testConstantAxpyF64(self): 307 c = self._NewComputation() 308 c.Add( 309 c.Mul( 310 c.ConstantF64Scalar(2), 311 c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))), 312 c.Constant(NumpyArrayF64([100, -100, 200, -200]))) 313 self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) 314 315 def testCustomCall(self): 316 c = self._NewComputation() 317 for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): 318 xla_client.register_cpu_custom_call_target(name, fn) 319 c.CustomCall( 320 b"test_subtract_f32", 321 operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)), 322 shape_with_layout=xla_client.Shape.array_shape(np.float32, (), ()), 323 operand_shapes_with_layout=( 324 xla_client.Shape.array_shape(np.float32, (), ()), 325 xla_client.Shape.array_shape(np.float32, (), ()), 326 )) 327 self._ExecuteAndCompareClose(c, expected=0.75) 328 329 330class ParametersTest(ComputationTest): 331 """Tests focusing on Parameter ops and argument-passing.""" 332 333 def setUp(self): 334 self.f32_scalar_2 = NumpyArrayF32(2.0) 335 self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3]) 336 self.f64_scalar_2 = NumpyArrayF64(2.0) 337 self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3]) 338 self.s32_scalar_3 = NumpyArrayS32(3) 339 self.s32_4vector = NumpyArrayS32([10, 15, -2, 7]) 340 self.s64_scalar_3 = NumpyArrayS64(3) 341 self.s64_4vector = NumpyArrayS64([10, 15, -2, 7]) 342 343 def testScalarTimesVectorAutonumberF32(self): 344 c = self._NewComputation() 345 p0 = c.ParameterFromNumpy(self.f32_scalar_2) 346 p1 = c.ParameterFromNumpy(self.f32_4vector) 347 c.Mul(p0, p1) 348 self._ExecuteAndCompareClose( 349 c, 350 arguments=[self.f32_scalar_2, self.f32_4vector], 351 expected=[-4.6, 6.6, -8.6, 10.6]) 352 353 def testScalarTimesVectorAutonumberF64(self): 354 c = self._NewComputation() 355 p0 = c.ParameterFromNumpy(self.f64_scalar_2) 356 p1 = c.ParameterFromNumpy(self.f64_4vector) 357 c.Mul(p0, p1) 358 self._ExecuteAndCompareClose( 359 c, 360 arguments=[self.f64_scalar_2, self.f64_4vector], 361 expected=[-4.6, 6.6, -8.6, 10.6]) 362 363 def testScalarTimesVectorS32(self): 364 c = self._NewComputation() 365 p0 = c.ParameterFromNumpy(self.s32_scalar_3) 366 p1 = c.ParameterFromNumpy(self.s32_4vector) 367 c.Mul(p0, p1) 368 self._ExecuteAndCompareExact( 369 c, 370 arguments=[self.s32_scalar_3, self.s32_4vector], 371 expected=[30, 45, -6, 21]) 372 373 def testScalarTimesVectorS64(self): 374 c = self._NewComputation() 375 p0 = c.ParameterFromNumpy(self.s64_scalar_3) 376 p1 = c.ParameterFromNumpy(self.s64_4vector) 377 c.Mul(p0, p1) 378 self._ExecuteAndCompareExact( 379 c, 380 arguments=[self.s64_scalar_3, self.s64_4vector], 381 expected=[30, 45, -6, 21]) 382 383 def testScalarMinusVectorExplicitNumberingF32(self): 384 # Use explicit numbering and pass parameter_num first. Sub is used since 385 # it's not commutative and can help catch parameter reversal within the 386 # computation. 387 c = self._NewComputation() 388 p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1) 389 p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0) 390 c.Sub(p1, p0) 391 self._ExecuteAndCompareClose( 392 c, 393 arguments=[self.f32_scalar_2, self.f32_4vector], 394 expected=[-4.3, 1.3, -6.3, 3.3]) 395 396 def testScalarMinusVectorExplicitNumberingF64(self): 397 # Use explicit numbering and pass parameter_num first. Sub is used since 398 # it's not commutative and can help catch parameter reversal within the 399 # computation. 400 c = self._NewComputation() 401 p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1) 402 p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0) 403 c.Sub(p1, p0) 404 self._ExecuteAndCompareClose( 405 c, 406 arguments=[self.f64_scalar_2, self.f64_4vector], 407 expected=[-4.3, 1.3, -6.3, 3.3]) 408 409 410class LocalBufferTest(ComputationTest): 411 """Tests focusing on execution with LocalBuffers.""" 412 413 def _Execute(self, c, arguments): 414 compiled_c = c.Build().CompileWithExampleArguments(arguments) 415 arg_buffers = [xla_client.LocalBuffer.from_pyval(arg) for arg in arguments] 416 result_buffer = compiled_c.Execute(arg_buffers) 417 return result_buffer.to_py() 418 419 def testConstantSum(self): 420 c = self._NewComputation() 421 c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) 422 self._ExecuteAndCompareClose(c, expected=4.25) 423 424 def testOneParameterSum(self): 425 c = self._NewComputation() 426 c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) 427 self._ExecuteAndCompareClose( 428 c, 429 arguments=[NumpyArrayF32(1.11)], 430 expected=4.25) 431 432 def testTwoParameterSum(self): 433 c = self._NewComputation() 434 c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), 435 c.ParameterFromNumpy(NumpyArrayF32(0.))) 436 self._ExecuteAndCompareClose( 437 c, 438 arguments=[NumpyArrayF32(1.11), NumpyArrayF32(3.14)], 439 expected=4.25) 440 441 def testCannotCallWithDeletedBuffers(self): 442 c = self._NewComputation() 443 c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) 444 arg = NumpyArrayF32(1.11) 445 compiled_c = c.Build().CompileWithExampleArguments([arg]) 446 arg_buffer = xla_client.LocalBuffer.from_pyval(arg) 447 arg_buffer.delete() 448 with self.assertRaises(ValueError): 449 compiled_c.Execute([arg_buffer]) 450 451 def testDestructureTupleEmpty(self): 452 t = () 453 local_buffer = xla_client.LocalBuffer.from_pyval(t) 454 pieces = local_buffer.destructure() 455 self.assertTrue(local_buffer.is_deleted()) 456 self.assertEqual(len(pieces), 0) 457 458 def testDestructureTupleOneArrayElement(self): 459 t = (np.array([1, 2, 3, 4], dtype=np.int32),) 460 local_buffer = xla_client.LocalBuffer.from_pyval(t) 461 pieces = local_buffer.destructure() 462 self.assertTrue(local_buffer.is_deleted()) 463 self.assertEqual(len(pieces), 1) 464 array = pieces[0] 465 got = array.to_py() 466 want = NumpyArrayS32([1, 2, 3, 4]) 467 np.testing.assert_equal(want, got) 468 469 def testDestructureTupleTwoArrayElementDifferentType(self): 470 t = (np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), 471 np.array([2, 3, 4, 5], dtype=np.int32)) 472 local_buffer = xla_client.LocalBuffer.from_pyval(t) 473 pieces = local_buffer.destructure() 474 self.assertTrue(local_buffer.is_deleted()) 475 self.assertEqual(len(pieces), 2) 476 array0, array1 = pieces 477 got = array0.to_py() 478 want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0]) 479 np.testing.assert_equal(want, got) 480 got = array1.to_py() 481 want = NumpyArrayS32([2, 3, 4, 5]) 482 np.testing.assert_equal(want, got) 483 484 def testDestructureTupleNested(self): 485 t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5])) 486 local_buffer = xla_client.LocalBuffer.from_pyval(t) 487 pieces = local_buffer.destructure() 488 self.assertTrue(local_buffer.is_deleted()) 489 self.assertEqual(len(pieces), 2) 490 tuple0, array1 = pieces 491 got = array1.to_py() 492 want = NumpyArrayS32([5]) 493 np.testing.assert_equal(want, got) 494 got = tuple0.to_py() 495 self.assertEqual(type(got), tuple) 496 self.assertEqual(len(got), 2) 497 np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) 498 np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) 499 500 def testShape(self): 501 pyval = np.array([[1., 2.]], np.float32) 502 local_buffer = xla_client.LocalBuffer.from_pyval(pyval) 503 xla_shape = local_buffer.shape() 504 self.assertEqual(xla_shape.dimensions(), (1, 2,)) 505 self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) 506 507 508class SingleOpTest(ComputationTest): 509 """Tests for single ops. 510 511 The goal here is smoke testing - to exercise the most basic functionality of 512 single XLA ops. As minimal as possible number of additional ops are added 513 around the op being tested. 514 """ 515 516 def testConcatenateF32(self): 517 c = self._NewComputation() 518 c.Concatenate( 519 (c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])), 520 c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))), 521 dimension=0) 522 self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 523 524 def testConcatenateF64(self): 525 c = self._NewComputation() 526 c.Concatenate( 527 (c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])), 528 c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))), 529 dimension=0) 530 self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 531 532 def testConvertElementType(self): 533 xla_types = { 534 np.bool: xla_client.PrimitiveType.PRED, 535 np.int32: xla_client.PrimitiveType.S32, 536 np.int64: xla_client.PrimitiveType.S64, 537 np.float32: xla_client.PrimitiveType.F32, 538 np.float64: xla_client.PrimitiveType.F64, 539 } 540 541 def _ConvertAndTest(template, src_dtype, dst_dtype): 542 c = self._NewComputation() 543 x = c.Constant(np.array(template, dtype=src_dtype)) 544 c.ConvertElementType(x, xla_types[dst_dtype]) 545 546 result = c.Build().Compile().ExecuteWithPythonValues() 547 expected = np.array(template, dtype=dst_dtype) 548 549 self.assertEqual(result.shape, expected.shape) 550 self.assertEqual(result.dtype, expected.dtype) 551 np.testing.assert_equal(result, expected) 552 553 x = [0, 1, 0, 0, 1] 554 for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): 555 _ConvertAndTest(x, src_dtype, dst_dtype) 556 557 def testBitcastConvertType(self): 558 xla_x32_types = { 559 np.int32: xla_client.PrimitiveType.S32, 560 np.float32: xla_client.PrimitiveType.F32, 561 } 562 563 xla_x64_types = { 564 np.int64: xla_client.PrimitiveType.S64, 565 np.float64: xla_client.PrimitiveType.F64, 566 } 567 568 def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype): 569 c = self._NewComputation() 570 x = c.Constant(np.array(template, dtype=src_dtype)) 571 c.BitcastConvertType(x, dst_etype) 572 573 result = c.Build().Compile().ExecuteWithPythonValues() 574 expected = np.array(template, src_dtype).view(dst_dtype) 575 576 self.assertEqual(result.shape, expected.shape) 577 self.assertEqual(result.dtype, expected.dtype) 578 np.testing.assert_equal(result, expected) 579 580 x = [0, 1, 0, 0, 1] 581 for xla_types in [xla_x32_types, xla_x64_types]: 582 for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): 583 _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype]) 584 585 # TODO(b/123523486) implement AllToAll on CPU 586 def DISABLED_testAllToAllOneReplica(self): 587 samples = [ 588 NumpyArrayF32([97.0]), 589 NumpyArrayF32([64.0, 117.0]), 590 NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), 591 ] 592 for lhs in samples[:1]: 593 c = self._NewComputation() 594 c.AllToAll(c.Constant(lhs), 0, 0) 595 self._ExecuteAndCompareExact(c, expected=lhs) 596 597 def testCrossReplicaSumOneReplica(self): 598 samples = [ 599 NumpyArrayF32(42.0), 600 NumpyArrayF32([97.0]), 601 NumpyArrayF32([64.0, 117.0]), 602 NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), 603 ] 604 for lhs in samples: 605 c = self._NewComputation() 606 c.CrossReplicaSum(c.Constant(lhs)) 607 self._ExecuteAndCompareExact(c, expected=lhs) 608 609 def testReplicaId(self): 610 c = self._NewComputation() 611 _ = c.ReplicaId() 612 self._ExecuteAndCompareExact(c, expected=0) 613 614 def testCrossReplicaSumOneReplicaWithSingletonGroup(self): 615 samples = [ 616 NumpyArrayF32(42.0), 617 NumpyArrayF32([97.0]), 618 NumpyArrayF32([64.0, 117.0]), 619 NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), 620 ] 621 for lhs in samples: 622 c = self._NewComputation() 623 c.CrossReplicaSum(c.Constant(lhs), [[0]]) 624 self._ExecuteAndCompareExact(c, expected=lhs) 625 626 def testDotMatrixVectorF32(self): 627 c = self._NewComputation() 628 lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) 629 rhs = NumpyArrayF32([[10.0], [20.0]]) 630 c.Dot(c.Constant(lhs), c.Constant(rhs)) 631 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 632 633 def testDotMatrixVectorF64(self): 634 c = self._NewComputation() 635 lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) 636 rhs = NumpyArrayF64([[10.0], [20.0]]) 637 c.Dot(c.Constant(lhs), c.Constant(rhs)) 638 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 639 640 def testDotMatrixMatrixF32(self): 641 c = self._NewComputation() 642 lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) 643 rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]]) 644 c.Dot(c.Constant(lhs), c.Constant(rhs)) 645 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 646 647 def testDotMatrixMatrixF64(self): 648 c = self._NewComputation() 649 lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) 650 rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]]) 651 c.Dot(c.Constant(lhs), c.Constant(rhs)) 652 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 653 654 def testDotGeneral(self): 655 c = self._NewComputation() 656 rng = np.random.RandomState(0) 657 lhs = NumpyArrayF32(rng.randn(10, 3, 4)) 658 rhs = NumpyArrayF32(rng.randn(10, 4, 5)) 659 dimension_numbers = (([2], [1]), ([0], [0])) 660 c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) 661 self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) 662 663 def testDotGeneralWithDotDimensionNumbersProto(self): 664 c = self._NewComputation() 665 rng = np.random.RandomState(0) 666 lhs = NumpyArrayF32(rng.randn(10, 3, 4)) 667 rhs = NumpyArrayF32(rng.randn(10, 4, 5)) 668 669 dimension_numbers = xla_client.DotDimensionNumbers() 670 dimension_numbers.lhs_contracting_dimensions.append(2) 671 dimension_numbers.rhs_contracting_dimensions.append(1) 672 dimension_numbers.lhs_batch_dimensions.append(0) 673 dimension_numbers.rhs_batch_dimensions.append(0) 674 675 c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) 676 self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) 677 678 def testConvF32Same(self): 679 c = self._NewComputation() 680 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 681 lhs = a(1, 2, 3, 4) 682 rhs = a(1, 2, 1, 2) * 10 683 c.Conv(c.Constant(lhs), c.Constant(rhs), 684 [1, 1], xla_client.PaddingType.SAME) 685 result = np.array([[[[640., 700., 760., 300.], 686 [880., 940., 1000., 380.], 687 [1120., 1180., 1240., 460.]]]]) 688 self._ExecuteAndCompareClose(c, expected=result) 689 690 def testConvF32Valid(self): 691 c = self._NewComputation() 692 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 693 lhs = a(1, 2, 3, 4) 694 rhs = a(1, 2, 1, 2) * 10 695 c.Conv(c.Constant(lhs), c.Constant(rhs), 696 [2, 1], xla_client.PaddingType.VALID) 697 result = np.array([[[[640., 700., 760.], 698 [1120., 1180., 1240.]]]]) 699 self._ExecuteAndCompareClose(c, expected=result) 700 701 def testConvWithGeneralPaddingF32(self): 702 c = self._NewComputation() 703 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 704 lhs = a(1, 1, 2, 3) 705 rhs = a(1, 1, 1, 2) * 10 706 strides = [1, 1] 707 pads = [(1, 0), (0, 1)] 708 lhs_dilation = (2, 1) 709 rhs_dilation = (1, 1) 710 c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs), 711 strides, pads, lhs_dilation, rhs_dilation) 712 result = np.array([[[[0., 0., 0.], 713 [10., 20., 0.], 714 [0., 0., 0.], 715 [40., 50., 0.]]]]) 716 self._ExecuteAndCompareClose(c, expected=result) 717 718 def testConvGeneralDilatedF32(self): 719 c = self._NewComputation() 720 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 721 lhs = a(1, 1, 2, 3) 722 rhs = a(1, 1, 1, 2) * 10 723 strides = [1, 1] 724 pads = [(1, 0), (0, 1)] 725 lhs_dilation = (2, 1) 726 rhs_dilation = (1, 1) 727 dimension_numbers = ("NCHW", "OIHW", "NCHW") 728 c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), 729 strides, pads, lhs_dilation, rhs_dilation, 730 dimension_numbers) 731 result = np.array([[[[0., 0., 0.], 732 [10., 20., 0.], 733 [0., 0., 0.], 734 [40., 50., 0.]]]]) 735 self._ExecuteAndCompareClose(c, expected=result) 736 737 def testConvGeneralDilatedPermutedF32(self): 738 c = self._NewComputation() 739 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 740 lhs = a(1, 1, 2, 3) 741 rhs = a(1, 1, 1, 2) * 10 742 strides = [1, 1] 743 pads = [(1, 0), (0, 1)] 744 lhs_dilation = (2, 1) 745 rhs_dilation = (1, 1) 746 747 dimension_numbers = ("NHWC", "OIHW", "CWNH") 748 c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))), 749 c.Constant(rhs), 750 strides, pads, lhs_dilation, rhs_dilation, 751 dimension_numbers) 752 result = np.array([[[[0., 0., 0.], 753 [10., 20., 0.], 754 [0., 0., 0.], 755 [40., 50., 0.]]]]) 756 self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) 757 758 def testConvGeneralDilatedGroupedConvolutionF32(self): 759 c = self._NewComputation() 760 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 761 lhs = a(1, 2, 2, 3) 762 rhs = a(2, 1, 1, 2) * 10 763 strides = [1, 1] 764 pads = [(1, 0), (0, 1)] 765 lhs_dilation = (2, 1) 766 rhs_dilation = (1, 1) 767 dimension_numbers = ("NCHW", "OIHW", "NCHW") 768 feature_group_count = 2 769 c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), 770 strides, pads, lhs_dilation, rhs_dilation, 771 dimension_numbers, feature_group_count) 772 result = np.array([[[[0., 0., 0.], 773 [10., 20., 0.], 774 [0., 0., 0.], 775 [40., 50., 0.]], 776 [[0., 0., 0.], 777 [330., 380., 160.], 778 [0., 0., 0.], 779 [480., 530., 220.]]]]) 780 self._ExecuteAndCompareClose(c, expected=result) 781 782 def testBooleanNot(self): 783 c = self._NewComputation() 784 arr = NumpyArrayBool([True, False, True]) 785 c.Not(c.Constant(arr)) 786 self._ExecuteAndCompareClose(c, expected=~arr) 787 788 def testCountLeadingZeros(self): 789 c = self._NewComputation() 790 arr = NumpyArrayS32([0x7FFF, 0x12345678]) 791 c.Clz(c.Constant(arr)) 792 self._ExecuteAndCompareClose(c, expected=[17, 3]) 793 794 def testExp(self): 795 c = self._NewComputation() 796 arr = NumpyArrayF32([3.3, 12.1]) 797 c.Exp(c.Constant(arr)) 798 self._ExecuteAndCompareClose(c, expected=np.exp(arr)) 799 800 def testExpm1(self): 801 c = self._NewComputation() 802 arr = NumpyArrayF32([3.3, 12.1]) 803 c.Expm1(c.Constant(arr)) 804 self._ExecuteAndCompareClose(c, expected=np.expm1(arr)) 805 806 def testRound(self): 807 c = self._NewComputation() 808 arr = NumpyArrayF32([3.3, 12.1]) 809 c.Round(c.Constant(arr)) 810 self._ExecuteAndCompareClose(c, expected=np.round(arr)) 811 812 def testLog(self): 813 c = self._NewComputation() 814 arr = NumpyArrayF32([3.3, 12.1]) 815 c.Log(c.Constant(arr)) 816 self._ExecuteAndCompareClose(c, expected=np.log(arr)) 817 818 def testLog1p(self): 819 c = self._NewComputation() 820 arr = NumpyArrayF32([3.3, 12.1]) 821 c.Log1p(c.Constant(arr)) 822 self._ExecuteAndCompareClose(c, expected=np.log1p(arr)) 823 824 def testNeg(self): 825 c = self._NewComputation() 826 arr = NumpyArrayF32([3.3, 12.1]) 827 c.Neg(c.Constant(arr)) 828 self._ExecuteAndCompareClose(c, expected=-arr) 829 830 def testFloor(self): 831 c = self._NewComputation() 832 arr = NumpyArrayF32([3.3, 12.1]) 833 c.Floor(c.Constant(arr)) 834 self._ExecuteAndCompareClose(c, expected=np.floor(arr)) 835 836 def testCeil(self): 837 c = self._NewComputation() 838 arr = NumpyArrayF32([3.3, 12.1]) 839 c.Ceil(c.Constant(arr)) 840 self._ExecuteAndCompareClose(c, expected=np.ceil(arr)) 841 842 def testAbs(self): 843 c = self._NewComputation() 844 arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) 845 c.Abs(c.Constant(arr)) 846 self._ExecuteAndCompareClose(c, expected=np.abs(arr)) 847 848 def testTanh(self): 849 c = self._NewComputation() 850 arr = NumpyArrayF32([3.3, 12.1]) 851 c.Tanh(c.Constant(arr)) 852 self._ExecuteAndCompareClose(c, expected=np.tanh(arr)) 853 854 def testTrans(self): 855 856 def _TransposeAndTest(array): 857 c = self._NewComputation() 858 c.Trans(c.Constant(array)) 859 self._ExecuteAndCompareClose(c, expected=array.T) 860 861 # Test square and non-square matrices in both default (C) and F orders. 862 for array_fun in [NumpyArrayF32, NumpyArrayF64]: 863 _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]])) 864 _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F")) 865 _TransposeAndTest(array_fun([[1, 2], [4, 5]])) 866 _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F")) 867 868 def testTranspose(self): 869 870 def _TransposeAndTest(array, permutation): 871 c = self._NewComputation() 872 c.Transpose(c.Constant(array), permutation) 873 expected = np.transpose(array, permutation) 874 self._ExecuteAndCompareClose(c, expected=expected) 875 876 _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) 877 _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) 878 _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) 879 _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) 880 881 arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) 882 for permutation in itertools.permutations(range(arr.ndim)): 883 _TransposeAndTest(arr, permutation) 884 _TransposeAndTest(np.asfortranarray(arr), permutation) 885 886 def testEq(self): 887 c = self._NewComputation() 888 c.Eq( 889 c.Constant(NumpyArrayS32([1, 2, 3, 4])), 890 c.Constant(NumpyArrayS32([4, 2, 3, 1]))) 891 self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) 892 893 def testNe(self): 894 c = self._NewComputation() 895 c.Ne( 896 c.Constant(NumpyArrayS32([1, 2, 3, 4])), 897 c.Constant(NumpyArrayS32([4, 2, 3, 1]))) 898 self._ExecuteAndCompareExact(c, expected=[True, False, False, True]) 899 900 c.Ne( 901 c.Constant(NumpyArrayF32([-2.0, 0.0, 902 float("nan"), 903 float("nan")])), 904 c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")]))) 905 self._ExecuteAndAssertWith( 906 np.testing.assert_allclose, c, (), expected=[True, False, True, True]) 907 908 def testGt(self): 909 c = self._NewComputation() 910 c.Gt( 911 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 912 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 913 self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False]) 914 915 def testGe(self): 916 c = self._NewComputation() 917 c.Ge( 918 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 919 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 920 self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False]) 921 922 def testLt(self): 923 c = self._NewComputation() 924 c.Lt( 925 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 926 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 927 self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True]) 928 929 def testLe(self): 930 c = self._NewComputation() 931 c.Le( 932 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 933 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 934 self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True]) 935 936 def testMax(self): 937 c = self._NewComputation() 938 c.Max( 939 c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), 940 c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) 941 self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0]) 942 943 def testMaxExplicitBroadcastDim0(self): 944 c = self._NewComputation() 945 c.Max( 946 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 947 c.Constant(NumpyArrayF32([3, 4, 5])), 948 broadcast_dimensions=(0,)) 949 self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]]) 950 951 def testMaxExplicitBroadcastDim1(self): 952 c = self._NewComputation() 953 c.Max( 954 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 955 c.Constant(NumpyArrayF32([3, 4, 5])), 956 broadcast_dimensions=(1,)) 957 self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]]) 958 959 def testMin(self): 960 c = self._NewComputation() 961 c.Min( 962 c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), 963 c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) 964 self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0]) 965 966 def testPad(self): 967 c = self._NewComputation() 968 c.Pad( 969 c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), 970 c.Constant(NumpyArrayF32(0.0)), 971 [(1, 2, 1), (0, 1, 0)]) 972 self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], 973 [1.0, 2.0, 0.0], 974 [0.0, 0.0, 0.0], 975 [3.0, 4.0, 0.0], 976 [0.0, 0.0, 0.0], 977 [0.0, 0.0, 0.0]]) 978 979 def testPadWithPaddingConfig(self): 980 c = self._NewComputation() 981 padding_config = xla_client.PaddingConfig() 982 for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: 983 dimension = xla_client.PaddingConfigDimension() 984 dimension.edge_padding_low = lo 985 dimension.edge_padding_high = hi 986 dimension.interior_padding = interior 987 padding_config.dimensions.append(dimension) 988 c.Pad( 989 c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), 990 c.Constant(NumpyArrayF32(0.0)), 991 padding_config) 992 self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], 993 [1.0, 2.0, 0.0], 994 [0.0, 0.0, 0.0], 995 [3.0, 4.0, 0.0], 996 [0.0, 0.0, 0.0], 997 [0.0, 0.0, 0.0]]) 998 999 def testReshape(self): 1000 c = self._NewComputation() 1001 c.Reshape( 1002 c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), 1003 dimensions=[0, 1], 1004 new_sizes=[2, 3]) 1005 self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]]) 1006 1007 def testCollapse(self): 1008 c = self._NewComputation() 1009 c.Collapse( 1010 c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), 1011 dimensions=[1, 2]) 1012 self._ExecuteAndCompareExact(c, expected=[[1, 2, 3, 4], [5, 6, 7, 8]]) 1013 1014 def testRev(self): 1015 c = self._NewComputation() 1016 c.Rev( 1017 c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), 1018 dimensions=[0, 2]) 1019 self._ExecuteAndCompareExact( 1020 c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]) 1021 1022 def testClampF32(self): 1023 c = self._NewComputation() 1024 c.Clamp( 1025 c.Constant(NumpyArrayF32(-1)), 1026 c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])), 1027 c.Constant(NumpyArrayF32(2))) 1028 self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) 1029 1030 def testClampS32(self): 1031 c = self._NewComputation() 1032 c.Clamp( 1033 c.Constant(NumpyArrayS32(-1)), 1034 c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])), 1035 c.Constant(NumpyArrayS32(2))) 1036 self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) 1037 1038 def testSelect(self): 1039 c = self._NewComputation() 1040 c.Select( 1041 c.Constant(NumpyArrayBool([True, False, False, True, False])), 1042 c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])), 1043 c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5]))) 1044 self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5]) 1045 1046 def testSlice(self): 1047 c = self._NewComputation() 1048 c.Slice( 1049 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0], 1050 [3, 2]) 1051 self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) 1052 1053 def testSliceInDim(self): 1054 c = self._NewComputation() 1055 c.SliceInDim( 1056 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1057 start_index=1, 1058 limit_index=2, 1059 stride=1, 1060 dimno=1) 1061 self._ExecuteAndCompareExact(c, expected=[[2], [5], [8]]) 1062 c.SliceInDim( 1063 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1064 start_index=0, 1065 limit_index=3, 1066 stride=2, 1067 dimno=0) 1068 self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [7, 8, 9]]) 1069 1070 def testDynamicSlice(self): 1071 c = self._NewComputation() 1072 c.DynamicSlice( 1073 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1074 c.Constant(NumpyArrayS32([1, 0])), [2, 2]) 1075 self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) 1076 1077 def testDynamicUpdateSlice(self): 1078 c = self._NewComputation() 1079 c.DynamicUpdateSlice( 1080 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1081 c.Constant(NumpyArrayS32([[1, 2], [3, 4]])), 1082 c.Constant(NumpyArrayS32([1, 1]))) 1083 self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]]) 1084 1085 def testTuple(self): 1086 c = self._NewComputation() 1087 c.Tuple( 1088 c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), 1089 c.Constant(NumpyArrayBool([True, False, False, True]))) 1090 result = c.Build().Compile().ExecuteWithPythonValues() 1091 self.assertIsInstance(result, tuple) 1092 np.testing.assert_equal(result[0], 42) 1093 np.testing.assert_allclose(result[1], [1.0, 2.0]) 1094 np.testing.assert_equal(result[2], [True, False, False, True]) 1095 1096 def testGetTupleElement(self): 1097 c = self._NewComputation() 1098 c.GetTupleElement( 1099 c.Tuple( 1100 c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), 1101 c.Constant(NumpyArrayBool([True, False, False, True]))), 1) 1102 self._ExecuteAndCompareClose(c, expected=[1.0, 2.0]) 1103 1104 def testBroadcast(self): 1105 c = self._NewComputation() 1106 c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) 1107 self._ExecuteAndCompareExact( 1108 c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]) 1109 1110 def testBroadcastInDim(self): 1111 c = self._NewComputation() 1112 c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [0]) 1113 self._ExecuteAndCompareExact(c, expected=[[1, 1], [2, 2]]) 1114 c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [1]) 1115 self._ExecuteAndCompareExact(c, expected=[[1, 2], [1, 2]]) 1116 1117 def testRngNormal(self): 1118 shape = (2, 3) 1119 c = self._NewComputation() 1120 c.RngNormal(c.Constant(NumpyArrayF32(0.)), c.Constant(NumpyArrayF32(1.)), 1121 dims=shape) 1122 result = c.Build().Compile().ExecuteWithPythonValues() 1123 # since the result is random, we just check shape and uniqueness 1124 self.assertEqual(result.shape, shape) 1125 self.assertEqual(len(np.unique(result)), np.prod(shape)) 1126 1127 def testRngUniformF32(self): 1128 lo, hi = 2., 4. 1129 shape = (2, 3) 1130 c = self._NewComputation() 1131 c.RngUniform(c.Constant(NumpyArrayF32(lo)), c.Constant(NumpyArrayF32(hi)), 1132 dims=shape) 1133 result = c.Build().Compile().ExecuteWithPythonValues() 1134 # since the result is random, we just check shape, uniqueness, and range 1135 self.assertEqual(result.shape, shape) 1136 self.assertEqual(len(np.unique(result)), np.prod(shape)) 1137 self.assertTrue(np.all(lo <= result)) 1138 self.assertTrue(np.all(result < hi)) 1139 1140 def testRngUniformS32(self): 1141 lo, hi = 2, 4 1142 shape = (2, 3) 1143 c = self._NewComputation() 1144 c.RngUniform(c.Constant(NumpyArrayS32(lo)), c.Constant(NumpyArrayS32(hi)), 1145 dims=shape) 1146 result = c.Build().Compile().ExecuteWithPythonValues() 1147 # since the result is random, we just check shape, integrality, and range 1148 self.assertEqual(result.shape, shape) 1149 self.assertEqual(result.dtype, np.int32) 1150 self.assertTrue(np.all(lo <= result)) 1151 self.assertTrue(np.all(result < hi)) 1152 1153 def testCholesky(self): 1154 l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], 1155 dtype=np.float32) 1156 c = self._NewComputation() 1157 c.Cholesky(c.Constant(np.dot(l, l.T))) 1158 self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4) 1159 1160 def testQR(self): 1161 a = np.array( 1162 [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], 1163 dtype=np.float32) 1164 c = self._NewComputation() 1165 c.QR(c.Constant(a), full_matrices=True) 1166 q, r = self._Execute(c, ()) 1167 np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) 1168 1169 def testEigh(self): 1170 a = np.array( 1171 [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], 1172 dtype=np.float32) 1173 a = (a + a.T) / 2 1174 1175 c = self._NewComputation() 1176 c.Eigh(c.Constant(a), full_matrices=True) 1177 v, w = self._Execute(c, ()) 1178 self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3) 1179 1180 def testSVD(self): 1181 a = np.array( 1182 [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], 1183 dtype=np.float32) 1184 c = self._NewComputation() 1185 c.SVD(c.Constant(a)) 1186 u, d, v = self._Execute(c, ()) 1187 self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3) 1188 1189 def testTriangularSolve(self): 1190 a_vals = np.array( 1191 [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], 1192 dtype=np.float32) 1193 b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 1194 dtype=np.float32) 1195 1196 c = self._NewComputation() 1197 c.TriangularSolve(c.Constant(a_vals), c.Constant(b_vals), left_side=False, 1198 lower=True, transpose_a=True) 1199 self._ExecuteAndCompareClose(c, expected=np.array([ 1200 [0.5, 0.08333334, 0.04629629, 0.03367003], 1201 [2.5, -0.25, -0.1388889, -0.1010101], 1202 [4.5, -0.58333331, -0.32407406, -0.23569024], 1203 ], dtype=np.float32), rtol=1e-4) 1204 1205 def testIsConstant(self): 1206 c = self._NewComputation() 1207 a = c.ConstantS32Scalar(3) 1208 b = c.ConstantS32Scalar(1) 1209 x = c.ParameterFromNumpy(NumpyArrayS32(0)) 1210 const_expr = c.Sub(b, a) 1211 non_const_expr = c.Mul(const_expr, x) 1212 self.assertTrue(c.IsConstant(const_expr)) 1213 self.assertFalse(c.IsConstant(non_const_expr)) 1214 # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564) 1215 1216 def testGather(self): 1217 a = np.arange(9).astype(np.int32).reshape((3, 3)) 1218 indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) 1219 dnums = xla_client.GatherDimensionNumbers() 1220 dnums.offset_dims.append(1) 1221 dnums.offset_dims.append(2) 1222 dnums.start_index_map.append(0) 1223 dnums.start_index_map.append(1) 1224 dnums.index_vector_dim = 2 1225 c = self._NewComputation() 1226 c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1]) 1227 g = self._Execute(c, ()) 1228 expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) 1229 np.testing.assert_allclose(g, expected, rtol=1e-4) 1230 1231 1232class EmbeddedComputationsTest(ComputationTest): 1233 """Tests for XLA graphs with embedded computations (such as maps).""" 1234 1235 def _CreateConstantS32Computation(self): 1236 """Computation (f32) -> s32 that returns a constant 1 for any input.""" 1237 c = self._NewComputation("constant_s32_one") 1238 # TODO(eliben): consider adding a nicer way to create new parameters without 1239 # having to create dummy Numpy arrays or populating Shape messages. Perhaps 1240 # we need our own (Python-client-own) way to represent Shapes conveniently. 1241 c.ParameterFromNumpy(NumpyArrayF32(0)) 1242 c.ConstantS32Scalar(1) 1243 return c.Build() 1244 1245 def _CreateConstantS64Computation(self): 1246 """Computation (f64) -> s64 that returns a constant 1 for any input.""" 1247 c = self._NewComputation("constant_s64_one") 1248 # TODO(eliben): consider adding a nicer way to create new parameters without 1249 # having to create dummy Numpy arrays or populating Shape messages. Perhaps 1250 # we need our own (Python-client-own) way to represent Shapes conveniently. 1251 c.ParameterFromNumpy(NumpyArrayF64(0)) 1252 c.ConstantS64Scalar(1) 1253 return c.Build() 1254 1255 def _CreateConstantF32Computation(self): 1256 """Computation (f32) -> f32 that returns a constant 1.0 for any input.""" 1257 c = self._NewComputation("constant_f32_one") 1258 c.ParameterFromNumpy(NumpyArrayF32(0)) 1259 c.ConstantF32Scalar(1.0) 1260 return c.Build() 1261 1262 def _CreateConstantF64Computation(self): 1263 """Computation (f64) -> f64 that returns a constant 1.0 for any input.""" 1264 c = self._NewComputation("constant_f64_one") 1265 c.ParameterFromNumpy(NumpyArrayF64(0)) 1266 c.ConstantF64Scalar(1.0) 1267 return c.Build() 1268 1269 def _CreateMulF32By2Computation(self): 1270 """Computation (f32) -> f32 that multiplies its parameter by 2.""" 1271 c = self._NewComputation("mul_f32_by2") 1272 c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0)) 1273 return c.Build() 1274 1275 def _CreateMulF32ByParamComputation(self): 1276 """Computation (f32) -> f32 that multiplies one parameter by the other.""" 1277 c = self._NewComputation("mul_f32_by_param") 1278 c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), 1279 c.ParameterFromNumpy(NumpyArrayF32(0))) 1280 return c.Build() 1281 1282 def _CreateMulF64By2Computation(self): 1283 """Computation (f64) -> f64 that multiplies its parameter by 2.""" 1284 c = self._NewComputation("mul_f64_by2") 1285 c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) 1286 return c.Build() 1287 1288 def _CreateBinaryAddS32Computation(self): 1289 """Computation (s32, s32) -> s32 that adds its two parameters.""" 1290 c = self._NewComputation("add_param0_by_param1") 1291 c.Add( 1292 c.ParameterFromNumpy(NumpyArrayS32(0)), 1293 c.ParameterFromNumpy(NumpyArrayS32(0))) 1294 return c.Build() 1295 1296 def _CreateBinaryAddF32Computation(self): 1297 """Computation (f32, f32) -> f32 that adds its two parameters.""" 1298 c = self._NewComputation("add_param0_by_param1") 1299 c.Add( 1300 c.ParameterFromNumpy(NumpyArrayF32(0)), 1301 c.ParameterFromNumpy(NumpyArrayF32(0))) 1302 return c.Build() 1303 1304 def _CreateBinaryAddF64Computation(self): 1305 """Computation (f64, f64) -> f64 that adds its two parameters.""" 1306 c = self._NewComputation("add_param0_by_param1") 1307 c.Add( 1308 c.ParameterFromNumpy(NumpyArrayF64(0)), 1309 c.ParameterFromNumpy(NumpyArrayF64(0))) 1310 return c.Build() 1311 1312 def _CreateBinaryDivF32Computation(self): 1313 """Computation (f32, f32) -> f32 that divides its two parameters.""" 1314 c = self._NewComputation("div_param0_by_param1") 1315 c.Div( 1316 c.ParameterFromNumpy(NumpyArrayF32(0)), 1317 c.ParameterFromNumpy(NumpyArrayF32(0))) 1318 return c.Build() 1319 1320 def _CreateBinaryDivF64Computation(self): 1321 """Computation (f64, f64) -> f64 that divides its two parameters.""" 1322 c = self._NewComputation("div_param0_by_param1") 1323 c.Div( 1324 c.ParameterFromNumpy(NumpyArrayF64(0)), 1325 c.ParameterFromNumpy(NumpyArrayF64(0))) 1326 return c.Build() 1327 1328 def _CreateTestF32Lt10Computation(self): 1329 """Computation (f32) -> bool that tests if its parameter is less than 10.""" 1330 c = self._NewComputation("test_f32_lt_10") 1331 c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.)) 1332 return c.Build() 1333 1334 def _CreateTestF64Lt10Computation(self): 1335 """Computation (f64) -> bool that tests if its parameter is less than 10.""" 1336 c = self._NewComputation("test_f64_lt_10") 1337 c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.)) 1338 return c.Build() 1339 1340 def _CreateBinaryGeF32Computation(self): 1341 """Computation (f32, f32) -> bool that tests first_param >= second_param.""" 1342 c = self._NewComputation("param0_lt_param1") 1343 c.Ge(c.ParameterFromNumpy(NumpyArrayF32(0)), 1344 c.ParameterFromNumpy(NumpyArrayF32(0))) 1345 return c.Build() 1346 1347 def _CreateBinaryGeF64Computation(self): 1348 """Computation (f64, f64) -> bool that tests first_param >= second_param.""" 1349 c = self._NewComputation("param0_lt_param1") 1350 c.Ge(c.ParameterFromNumpy(NumpyArrayF64(0)), 1351 c.ParameterFromNumpy(NumpyArrayF64(0))) 1352 return c.Build() 1353 1354 def _MakeSample3DArrayF32(self): 1355 return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], 1356 [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) 1357 1358 def _MakeSample3DArrayF64(self): 1359 return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], 1360 [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) 1361 1362 def testCallF32(self): 1363 c = self._NewComputation() 1364 c.Call( 1365 self._CreateMulF32By2Computation(), 1366 operands=(c.ConstantF32Scalar(5.0),)) 1367 self._ExecuteAndCompareClose(c, expected=10.0) 1368 1369 def testCallF64(self): 1370 c = self._NewComputation() 1371 c.Call( 1372 self._CreateMulF64By2Computation(), 1373 operands=(c.ConstantF64Scalar(5.0),)) 1374 self._ExecuteAndCompareClose(c, expected=10.0) 1375 1376 def testMapEachElementToS32Constant(self): 1377 c = self._NewComputation() 1378 c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], 1379 self._CreateConstantS32Computation(), [0]) 1380 self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) 1381 1382 def testMapEachElementToS64Constant(self): 1383 c = self._NewComputation() 1384 c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], 1385 self._CreateConstantS64Computation(), [0]) 1386 self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) 1387 1388 def testMapMulBy2F32(self): 1389 c = self._NewComputation() 1390 c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], 1391 self._CreateMulF32By2Computation(), [0]) 1392 self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) 1393 1394 def testMapMulBy2F64(self): 1395 c = self._NewComputation() 1396 c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], 1397 self._CreateMulF64By2Computation(), [0]) 1398 self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) 1399 1400 def testSimpleMapChainF32(self): 1401 # Chains a map of constant-f32 with a map of mul-by-2 1402 c = self._NewComputation() 1403 const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], 1404 self._CreateConstantF32Computation(), [0]) 1405 c.Map([const_f32], self._CreateMulF32By2Computation(), [0]) 1406 self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) 1407 1408 def testSimpleMapChainF64(self): 1409 # Chains a map of constant-f64 with a map of mul-by-2 1410 c = self._NewComputation() 1411 const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], 1412 self._CreateConstantF64Computation(), [0]) 1413 c.Map([const_f64], self._CreateMulF64By2Computation(), [0]) 1414 self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) 1415 1416 def testDivVectorsWithMapF32(self): 1417 c = self._NewComputation() 1418 c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), 1419 c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))), 1420 self._CreateBinaryDivF32Computation(), [0]) 1421 self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) 1422 1423 def testDivVectorsWithMapF64(self): 1424 c = self._NewComputation() 1425 c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), 1426 c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))), 1427 self._CreateBinaryDivF64Computation(), [0]) 1428 self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) 1429 1430 def testSelectAndScatterF32(self): 1431 c = self._NewComputation() 1432 c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), 1433 select=self._CreateBinaryGeF32Computation(), 1434 window_dimensions=(2, 1), 1435 window_strides=(1, 2), 1436 padding=xla_client.PaddingType.VALID, 1437 source=c.Constant(NumpyArrayF32([[0.1, 0.2]])), 1438 init_value=c.Constant(NumpyArrayF32(1)), 1439 scatter=self._CreateBinaryAddF32Computation()) 1440 self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) 1441 1442 def testSelectAndScatterF64(self): 1443 c = self._NewComputation() 1444 c.SelectAndScatter(c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])), 1445 select=self._CreateBinaryGeF64Computation(), 1446 window_dimensions=(2, 1), 1447 window_strides=(1, 2), 1448 padding=xla_client.PaddingType.VALID, 1449 source=c.Constant(NumpyArrayF64([[0.1, 0.2]])), 1450 init_value=c.Constant(NumpyArrayF64(1)), 1451 scatter=self._CreateBinaryAddF64Computation()) 1452 self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) 1453 1454 def testReduce1DtoScalarF32(self): 1455 c = self._NewComputation() 1456 c.Reduce( 1457 operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), 1458 init_value=c.ConstantF32Scalar(0), 1459 computation_to_apply=self._CreateBinaryAddF32Computation(), 1460 dimensions=[0]) 1461 self._ExecuteAndCompareClose(c, expected=10) 1462 1463 def testReduce1DtoScalarF64(self): 1464 c = self._NewComputation() 1465 c.Reduce( 1466 operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), 1467 init_value=c.ConstantF64Scalar(0), 1468 computation_to_apply=self._CreateBinaryAddF64Computation(), 1469 dimensions=[0]) 1470 self._ExecuteAndCompareClose(c, expected=10) 1471 1472 def testReduce2DTo1DDim0F32(self): 1473 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1474 c = self._NewComputation() 1475 c.Reduce( 1476 operand=c.Constant(input_array), 1477 init_value=c.ConstantF32Scalar(0), 1478 computation_to_apply=self._CreateBinaryAddF32Computation(), 1479 dimensions=[0]) 1480 self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) 1481 1482 def testReduce2DTo1DDim0F64(self): 1483 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1484 c = self._NewComputation() 1485 c.Reduce( 1486 operand=c.Constant(input_array), 1487 init_value=c.ConstantF64Scalar(0), 1488 computation_to_apply=self._CreateBinaryAddF64Computation(), 1489 dimensions=[0]) 1490 self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) 1491 1492 def testReduce2DTo1DDim1F32(self): 1493 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1494 c = self._NewComputation() 1495 c.Reduce( 1496 operand=c.Constant(input_array), 1497 init_value=c.ConstantF32Scalar(0), 1498 computation_to_apply=self._CreateBinaryAddF32Computation(), 1499 dimensions=[1]) 1500 self._ExecuteAndCompareClose(c, expected=[6, 15]) 1501 1502 def testReduce2DTo1DDim1F64(self): 1503 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1504 c = self._NewComputation() 1505 c.Reduce( 1506 operand=c.Constant(input_array), 1507 init_value=c.ConstantF64Scalar(0), 1508 computation_to_apply=self._CreateBinaryAddF64Computation(), 1509 dimensions=[1]) 1510 self._ExecuteAndCompareClose(c, expected=[6, 15]) 1511 1512 def testReduce3DAllPossibleWaysF32(self): 1513 input_array = self._MakeSample3DArrayF32() 1514 1515 def _ReduceAndTest(*dims): 1516 c = self._NewComputation() 1517 c.Reduce( 1518 operand=c.Constant(input_array), 1519 init_value=c.ConstantF32Scalar(0), 1520 computation_to_apply=self._CreateBinaryAddF32Computation(), 1521 dimensions=dims) 1522 self._ExecuteAndCompareClose( 1523 c, expected=np.sum(input_array, axis=tuple(dims))) 1524 1525 _ReduceAndTest(0) 1526 _ReduceAndTest(0, 1) 1527 _ReduceAndTest(0, 2) 1528 _ReduceAndTest(1, 2) 1529 _ReduceAndTest(0, 1, 2) 1530 1531 def testReduce3DAllPossibleWaysF64(self): 1532 input_array = self._MakeSample3DArrayF64() 1533 1534 def _ReduceAndTest(*dims): 1535 c = self._NewComputation() 1536 c.Reduce( 1537 operand=c.Constant(input_array), 1538 init_value=c.ConstantF64Scalar(0), 1539 computation_to_apply=self._CreateBinaryAddF64Computation(), 1540 dimensions=dims) 1541 self._ExecuteAndCompareClose( 1542 c, expected=np.sum(input_array, axis=tuple(dims))) 1543 1544 _ReduceAndTest(0) 1545 _ReduceAndTest(0) 1546 _ReduceAndTest(0, 1) 1547 _ReduceAndTest(0, 2) 1548 _ReduceAndTest(1, 2) 1549 _ReduceAndTest(0, 1, 2) 1550 1551 def testReduceWindowValidUnitStridesF32(self): 1552 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1553 c = self._NewComputation() 1554 c.ReduceWindow(operand=c.Constant(input_array), 1555 init_value=c.ConstantF32Scalar(0), 1556 computation_to_apply=self._CreateBinaryAddF32Computation(), 1557 window_dimensions=(2, 1), window_strides=(1, 1), 1558 padding=xla_client.PaddingType.VALID) 1559 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) 1560 1561 def testReduceWindowSameUnitStridesF32(self): 1562 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1563 c = self._NewComputation() 1564 c.ReduceWindow(operand=c.Constant(input_array), 1565 init_value=c.ConstantF32Scalar(0), 1566 computation_to_apply=self._CreateBinaryAddF32Computation(), 1567 window_dimensions=(2, 1), window_strides=(1, 1), 1568 padding=xla_client.PaddingType.SAME) 1569 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) 1570 1571 def testReduceWindowValidGeneralStridesF32(self): 1572 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1573 c = self._NewComputation() 1574 c.ReduceWindow(operand=c.Constant(input_array), 1575 init_value=c.ConstantF32Scalar(0), 1576 computation_to_apply=self._CreateBinaryAddF32Computation(), 1577 window_dimensions=(2, 1), window_strides=(1, 2), 1578 padding=xla_client.PaddingType.VALID) 1579 self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) 1580 1581 def testReduceWindowValidUnitStridesF64(self): 1582 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1583 c = self._NewComputation() 1584 c.ReduceWindow(operand=c.Constant(input_array), 1585 init_value=c.ConstantF64Scalar(0), 1586 computation_to_apply=self._CreateBinaryAddF64Computation(), 1587 window_dimensions=(2, 1), window_strides=(1, 1), 1588 padding=xla_client.PaddingType.VALID) 1589 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) 1590 1591 def testReduceWindowSameUnitStridesF64(self): 1592 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1593 c = self._NewComputation() 1594 c.ReduceWindow(operand=c.Constant(input_array), 1595 init_value=c.ConstantF64Scalar(0), 1596 computation_to_apply=self._CreateBinaryAddF64Computation(), 1597 window_dimensions=(2, 1), window_strides=(1, 1), 1598 padding=xla_client.PaddingType.SAME) 1599 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) 1600 1601 def testReduceWindowValidGeneralStridesF64(self): 1602 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1603 c = self._NewComputation() 1604 c.ReduceWindow(operand=c.Constant(input_array), 1605 init_value=c.ConstantF64Scalar(0), 1606 computation_to_apply=self._CreateBinaryAddF64Computation(), 1607 window_dimensions=(2, 1), window_strides=(1, 2), 1608 padding=xla_client.PaddingType.VALID) 1609 self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) 1610 1611 def testWhileF32(self): 1612 cond = self._CreateTestF32Lt10Computation() 1613 body = self._CreateMulF32By2Computation() 1614 c = self._NewComputation() 1615 init = c.ConstantF32Scalar(1.) 1616 c.While(cond, body, init) 1617 self._ExecuteAndCompareClose(c, expected=16.) 1618 1619 def testWhileF64(self): 1620 cond = self._CreateTestF64Lt10Computation() 1621 body = self._CreateMulF64By2Computation() 1622 c = self._NewComputation() 1623 init = c.ConstantF64Scalar(1.) 1624 c.While(cond, body, init) 1625 self._ExecuteAndCompareClose(c, expected=16.) 1626 1627 def testConditionalTrue(self): 1628 c = self._NewComputation() 1629 pred = c.ConstantPredScalar(True) 1630 true_operand = c.ConstantF32Scalar(3.) 1631 true_computation = self._CreateMulF32By2Computation() 1632 false_operand = c.ConstantF32Scalar(2.) 1633 false_computation = self._CreateConstantF32Computation() 1634 c.Conditional(pred, true_operand, true_computation, false_operand, 1635 false_computation) 1636 self._ExecuteAndCompareClose(c, expected=6.) 1637 1638 def testConditionalFalse(self): 1639 c = self._NewComputation() 1640 pred = c.ConstantPredScalar(False) 1641 true_operand = c.ConstantF32Scalar(3.) 1642 true_computation = self._CreateMulF32By2Computation() 1643 false_operand = c.ConstantF32Scalar(2.) 1644 false_computation = self._CreateConstantF32Computation() 1645 c.Conditional(pred, true_operand, true_computation, false_operand, 1646 false_computation) 1647 self._ExecuteAndCompareClose(c, expected=1.) 1648 1649 def testInfeedS32Values(self): 1650 to_infeed = NumpyArrayS32([1, 2, 3, 4]) 1651 c = self._NewComputation() 1652 c.Infeed(xla_client.Shape.from_pyval(to_infeed[0])) 1653 compiled_c = c.Build().CompileWithExampleArguments() 1654 for item in to_infeed: 1655 xla_client.transfer_to_infeed(item) 1656 1657 for item in to_infeed: 1658 result = compiled_c.ExecuteWithPythonValues() 1659 self.assertEqual(result, item) 1660 1661 def testInfeedThenOutfeedS32(self): 1662 to_round_trip = NumpyArrayS32([1, 2, 3, 4]) 1663 c = self._NewComputation() 1664 x = c.Infeed(xla_client.Shape.from_pyval(to_round_trip[0])) 1665 c.Outfeed(x) 1666 1667 compiled_c = c.Build().CompileWithExampleArguments() 1668 1669 for want in to_round_trip: 1670 execution = threading.Thread(target=compiled_c.Execute) 1671 execution.start() 1672 xla_client.transfer_to_infeed(want) 1673 got = xla_client.transfer_from_outfeed( 1674 xla_client.Shape.from_pyval(to_round_trip[0])) 1675 execution.join() 1676 self.assertEqual(want, got) 1677 1678 def testScatter(self): 1679 a = np.arange(9).astype(np.int32).reshape((3, 3)) 1680 scatter_indices = np.array([0, 2], dtype=np.int32) 1681 updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) 1682 1683 dnums = xla_client.ScatterDimensionNumbers() 1684 dnums.update_window_dims.append(1) 1685 dnums.inserted_window_dims.append(0) 1686 dnums.scatter_dims_to_operand_dims.append(0) 1687 dnums.index_vector_dim = 1 1688 1689 c = self._NewComputation() 1690 c.Scatter(c.Constant(a), c.Constant(scatter_indices), c.Constant(updates), 1691 self._CreateBinaryAddS32Computation(), dnums) 1692 expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32) 1693 self._ExecuteAndCompareClose(c, expected=expected) 1694 1695 1696class ErrorTest(ComputationTest): 1697 1698 def setUp(self): 1699 self.f32_scalar_2 = NumpyArrayF32(2.0) 1700 self.s32_scalar_2 = NumpyArrayS32(2) 1701 1702 def testInvokeWithWrongElementType(self): 1703 c = self._NewComputation() 1704 c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) 1705 c.ParameterFromNumpy(self.s32_scalar_2) 1706 c.ClearOpMetadata() 1707 self.assertRaisesRegexp( 1708 RuntimeError, r"Invalid argument shape.*xla_client_test.py.*" 1709 r"expected s32\[\], got f32\[\]", 1710 lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) 1711 1712 1713class ComputationRootTest(ComputationTest): 1714 """Tests related to setting the root of the computation.""" 1715 1716 def testComputationRootDifferentFromLastOp(self): 1717 c = self._NewComputation() 1718 x = c.ParameterFromNumpy(NumpyArrayF32(2.0)) 1719 result = c.Add(x, c.ConstantF32Scalar(3.14)) 1720 extra = c.Add(result, c.ConstantF32Scalar(1.618)) # pylint: disable=unused-variable 1721 1722 arg = NumpyArrayF32(1.0) 1723 compiled_c = c.Build(result).CompileWithExampleArguments([arg]) 1724 ans = compiled_c.ExecuteWithPythonValues([arg]) 1725 np.testing.assert_allclose(ans, 4.14) 1726 1727 1728if __name__ == "__main__": 1729 unittest.main() 1730