1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""An XLA client in Python, supporting AOT compilation.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22import collections 23import enum # pylint: disable=g-bad-import-order 24import inspect 25import itertools 26import os 27 28import numpy as np 29 30import six 31from six.moves import xrange 32 33# Note this module does *not* depend on any Python protocol buffers. The XLA 34# Python bindings are currently packaged both as part of jaxlib and as part 35# of TensorFlow. If we use protocol buffers here, then importing both jaxlib 36# and TensorFlow may fail with duplicate protocol buffer message definitions. 37 38from tensorflow.compiler.xla.python import pywrap_xla as c_api 39 40# Import the XRT backend, if available. 41try: 42 # pylint: disable=g-import-not-at-top 43 from tensorflow.compiler.xla.python import pywrap_xrt as xrt_api 44except ImportError: 45 xrt_api = None 46 47 48# Most functions are snake_case for consistency with other modules, whereas 49# method names of ComputationBuilder and Computation are CamelCase for 50# consistency with XLA. 51# pylint: disable=invalid-name 52 53 54# Version of the XLA Python client. 55# 56# JAX packages the XLA python plugin as a binary pip module (jaxlib) that is 57# packaged separately from the Python code that consumes it (jax). 58# 59# We occasionally need to make backwards-incompatible changes to jaxlib, in 60# which case we need to be able to detect when incompatible versions are 61# installed. 62def version(): 63 return (0, 1, 8) 64 65 66_OP_METADATA_FIELDS = [ 67 'op_type', 68 'op_name', 69 'source_file', 70 'source_line', 71] 72OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS) 73 74 75@six.add_metaclass(abc.ABCMeta) 76class Backend(object): 77 """Abstract base class for XLA backends.""" 78 79 @abc.abstractmethod 80 def device_count(self): 81 """Returns the number of devices known to the backend.""" 82 83 @abc.abstractmethod 84 def buffer_from_pyval(self, pyval, device=0): 85 """Allocates a fresh buffer and populates it with `pyval`.""" 86 87 @abc.abstractmethod 88 def delete_buffer(self, c_buffer): 89 """Deletes buffer `c_buffer`.""" 90 91 @abc.abstractmethod 92 def destructure_tuple(self, c_buffer): 93 """Destructures a tuple buffer into a sequence of buffers.""" 94 95 @abc.abstractmethod 96 def compile(self, computation, argument_shapes, result_shape, 97 compile_options): 98 """Compiles a computation. Returns an executable.""" 99 100 @abc.abstractmethod 101 def delete_executable(self, executable): 102 """Deletes an executable.""" 103 104 @abc.abstractmethod 105 def execute(self, executable, args): 106 """Runs an executable without replication.""" 107 108 @abc.abstractmethod 109 def execute_replicated(self, executable, per_replica_args): 110 """Runs an executable in a replicated manner.""" 111 112 113def _maybe_encode_string(s): 114 if six.PY3: 115 return s.encode('utf-8') 116 else: 117 return s 118 119 120class XlaLocalBackend(Backend): 121 """XLA backend implemented using the in-process xla::LocalClient API.""" 122 123 def __init__(self, platform=None): 124 platform = platform or _get_default_platform_name() 125 self.client = c_api.LocalClient.Get(_maybe_encode_string(platform)) 126 self._delete_buffer = c_api.DeleteLocalShapedBuffer 127 self._delete_executable = c_api.DeleteLocalExecutable 128 129 def device_count(self): 130 return self.client.DeviceCount() 131 132 def buffer_from_pyval(self, pyval, device=0): 133 return c_api.LocalShapedBuffer.FromLiteral(pyval, None, self.client, device) 134 135 def delete_buffer(self, c_buffer): 136 self._delete_buffer(c_buffer) 137 138 def destructure_tuple(self, c_buffer): 139 result = c_buffer.DestructureTuple() 140 return [result.Release(i) for i in xrange(result.size())] 141 142 def compile(self, c_computation, argument_shapes, result_shape, 143 compile_options): 144 return c_computation.Compile(argument_shapes, compile_options, self.client) 145 146 def delete_executable(self, executable): 147 self._delete_executable(executable) 148 149 def execute(self, executable, args): 150 return executable.Execute(args) 151 152 def execute_replicated(self, executable, per_replica_args): 153 output_buffer_tup = executable.ExecutePerReplica(per_replica_args) 154 size = output_buffer_tup.size() 155 return [output_buffer_tup.Release(i) for i in xrange(size)] 156 157 158class XrtBackend(Backend): 159 """XLA backend implemented using XRT.""" 160 161 def __init__(self, target): 162 self.target = target 163 self._delete_buffer = xrt_api.DeleteXrtAllocation 164 self._delete_executable = xrt_api.DeleteXrtExecutable 165 166 def device_count(self): 167 return 1 # Multidevice execution not implemented. 168 169 def buffer_from_pyval(self, pyval, device=0): 170 if device != 0: 171 raise NotImplementedError( 172 'Multi-replica execution is not yet supported via the XRT backend.') 173 return xrt_api.XrtAllocation.FromLiteral(pyval, 174 _maybe_encode_string(self.target)) 175 176 def delete_buffer(self, c_buffer): 177 self._delete_buffer(c_buffer) 178 179 def destructure_tuple(self, c_buffer): 180 result = xrt_api.DestructureXrtAllocationTuple( 181 c_buffer, _maybe_encode_string(self.target)) 182 return [result.Release(i) for i in xrange(result.size())] 183 184 def compile(self, c_computation, argument_shapes, result_shape, 185 compile_options): 186 return xrt_api.XrtExecutable.CompileForXrt( 187 c_computation.GetSerializedProto(), argument_shapes, result_shape, 188 _maybe_encode_string(self.target)) 189 190 def delete_executable(self, executable): 191 self._delete_executable(executable) 192 193 def execute(self, executable, args): 194 return executable.Execute(args) 195 196 def execute_replicated(self, executable, per_replica_args): 197 if len(per_replica_args) != 1: 198 raise NotImplementedError( 199 'Multi-replica execution is not yet supported via the XRT backend.') 200 return [executable.Execute(per_replica_args[0])] 201 202 203_default_platform_name = 'Host' 204_default_backend = None 205 206 207def _get_default_platform_name(): 208 return _default_platform_name 209 210 211def _get_default_local_backend(): 212 global _default_backend 213 global _default_platform_name 214 if _default_backend is None: 215 _default_backend = XlaLocalBackend(_default_platform_name) 216 return _default_backend 217 218 219class BackendType(enum.Enum): 220 XLA_LOCAL = 1 221 XRT = 2 222 223 224def BackendSpec(backend, target): 225 """Compatibility wrapper to support older clients. Do not use in new code.""" 226 if backend == BackendType.XLA_LOCAL: 227 return _get_default_local_backend() 228 elif backend == BackendType.XRT: 229 return XrtBackend(target) 230 else: 231 raise ValueError('Unknown backend {}'.format(backend)) 232 233 234def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): 235 """Helper for use in source mapping that returns an OpMetadata object.""" 236 full_filename, lineno = inspect.stack()[skip_frames][1:3] 237 filename = os.path.basename(full_filename) 238 return OpMetadata( 239 op_type=op_type, 240 op_name=op_name, 241 source_file=filename, 242 source_line=lineno) 243 244 245class PaddingType(enum.Enum): 246 VALID = 1 247 SAME = 2 248 249 250def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, 251 window_strides): 252 """Maps PaddingType or string to pad values (list of pairs of ints).""" 253 if not isinstance(padding_type, (str, PaddingType)): 254 msg = 'padding_type must be str or PaddingType, got {}.' 255 raise TypeError(msg.format(type(padding_type))) 256 257 if isinstance(padding_type, str): 258 if padding_type.upper() == 'VALID': 259 padding_type = PaddingType.VALID 260 elif padding_type.upper() == 'SAME': 261 padding_type = PaddingType.SAME 262 else: 263 msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.' 264 raise ValueError(msg.format(padding_type)) 265 266 if padding_type == PaddingType.VALID: 267 return [(0, 0)] * len(window_strides) 268 elif padding_type == PaddingType.SAME: 269 out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) 270 pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0) 271 for out_size, stride, filter_size, in_size 272 in zip(out_shape, window_strides, rhs_dims, lhs_dims)] 273 return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes] 274 else: 275 msg = 'Unexpected PaddingType value: {}' 276 raise ValueError(msg.format(padding_type)) 277 278 279_UNARY_OPS = [ 280 'Not', 281 'Clz', 282 'Abs', 283 'Exp', 284 'Expm1', 285 'Floor', 286 'Round', 287 'Ceil', 288 'Log', 289 'Log1p', 290 'Sign', 291 'Cos', 292 'Sin', 293 'Tanh', 294 'IsFinite', 295 'Sqrt', 296 'Rsqrt', 297 'Square', 298 'Reciprocal', 299 'Neg', 300 'Erf', 301 'Erfc', 302 'ErfInv', 303 'Lgamma', 304 'Digamma', 305 'Acos', 306 'Asin', 307 'Atan', 308 'Tan', 309 'Acosh', 310 'Asinh', 311 'Atanh', 312 'Cosh', 313 'Sinh', 314 'Real', 315 'Imag', 316 'Conj', 317] 318 319_BINARY_OPS = [ 320 'Eq', 321 'Ne', 322 'Ge', 323 'Gt', 324 'Lt', 325 'Le', 326 'Add', 327 'Sub', 328 'Mul', 329 'Div', 330 'Rem', 331 'Max', 332 'Min', 333 'And', 334 'Or', 335 'Xor', 336 'Pow', 337 'ShiftLeft', 338 'ShiftRightArithmetic', 339 'ShiftRightLogical', 340 'Atan2', 341 'Complex', 342] 343 344 345class PrimitiveType(enum.IntEnum): 346 """Python copy of the XLA PrimitiveType enum. 347 348 Must match the corresponding protocol buffer. 349 """ 350 PRIMITIVE_TYPE_INVALID = 0 351 PRED = 1 352 S8 = 2 353 S16 = 3 354 S32 = 4 355 S64 = 5 356 U8 = 6 357 U16 = 7 358 U32 = 8 359 U64 = 9 360 BF16 = 16 361 F16 = 10 362 F32 = 11 363 F64 = 12 364 C64 = 15 365 C128 = 18 366 TUPLE = 13 367 OPAQUE = 14 368 TOKEN = 17 369 370 371XLA_ELEMENT_TYPE_TO_DTYPE = { 372 PrimitiveType.PRED: np.dtype('bool'), 373 PrimitiveType.S8: np.dtype('int8'), 374 PrimitiveType.S16: np.dtype('int16'), 375 PrimitiveType.S32: np.dtype('int32'), 376 PrimitiveType.S64: np.dtype('int64'), 377 PrimitiveType.U8: np.dtype('uint8'), 378 PrimitiveType.U16: np.dtype('uint16'), 379 PrimitiveType.U32: np.dtype('uint32'), 380 PrimitiveType.U64: np.dtype('uint64'), 381 PrimitiveType.F16: np.dtype('float16'), 382 PrimitiveType.F32: np.dtype('float32'), 383 PrimitiveType.F64: np.dtype('float64'), 384 PrimitiveType.C64: np.dtype('complex64'), 385 PrimitiveType.C128: np.dtype('complex128'), 386 PrimitiveType.TUPLE: np.dtype(np.object), 387} 388 389# Note the conversion on the key. Numpy has a known issue wherein dtype hashing 390# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, 391# when keying by dtype in this dict, we use the string form of dtypes. 392DTYPE_TO_XLA_ELEMENT_TYPE = { 393 str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items() 394} 395 396 397def dtype_to_etype(dtype): 398 """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" 399 return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] 400 401 402class LocalBuffer(object): 403 """Represents a handle to data owned by XLA. 404 405 The referent is ready for use in executing a local, compiled 406 Computation. On XLA platforms involving a device (e.g. GPU), this 407 means the referent is in device memory. 408 """ 409 410 def __init__(self, c_buffer, backend, device): 411 self.c_buffer = c_buffer 412 self._backend = backend 413 self._device = device 414 415 @staticmethod 416 def from_pyval(pyval, device=0, backend=None): 417 """Allocate and copy to XLA the given python value.""" 418 backend = backend or _get_default_local_backend() 419 pyval = require_numpy_array_layout(pyval) 420 cbuf = backend.buffer_from_pyval(pyval, device) 421 return LocalBuffer(cbuf, backend, device) 422 423 def to_py(self): 424 return self.c_buffer.ToLiteral() 425 426 def shape(self): 427 return _wrap_shape(self.c_buffer.shape()) 428 429 def device(self): 430 return self._device 431 432 def delete(self): 433 if self.c_buffer is not None: 434 # Python may have freed c_api first. 435 if c_api: 436 self._backend.delete_buffer(self.c_buffer) 437 self.c_buffer = None 438 439 def destructure(self): 440 """Assuming a tuple buffer, unpack it into constituent tuple elements.""" 441 assert self.c_buffer is not None 442 result = self._backend.destructure_tuple(self.c_buffer) 443 self.delete() 444 return tuple( 445 LocalBuffer(sub_buffer, device=self._device, backend=self._backend) 446 for sub_buffer in result) 447 448 def is_deleted(self): 449 return self.c_buffer is None 450 451 def __del__(self): 452 self.delete() 453 454 455class Format(enum.IntEnum): 456 """Python copy of the Format protocol buffer enum.""" 457 INVALID_FORMAT = 0 458 DENSE = 1 459 SPARSE = 2 460 461 462class Shape(object): 463 """Represents an XLA shape. 464 465 A shape is either an array shape, having rank-many integer 466 dimensions and an element type (represented by a Numpy dtype), or it 467 is a tuple shape, having a shape for every tuple component: 468 469 type shape = 470 TupleShape of shape list 471 | ArrayShape of { dimensions: int list; element_type: dtype } 472 473 Callers are expected to instantiate this class only via the static 474 constructors: tuple_shape, array_shape, and from_pyval. 475 """ 476 477 @staticmethod 478 def tuple_shape(tuple_shapes): 479 """Construct a tuple shape.""" 480 if (not isinstance(tuple_shapes, (tuple, list)) or 481 not all(isinstance(t, Shape) for t in tuple_shapes)): 482 raise TypeError('tuple_shapes must be a tuple of Shapes') 483 return Shape(tuple_shapes, tuple) 484 485 @staticmethod 486 def array_shape(element_type, dimensions, minor_to_major=None): 487 """Construct an array shape.""" 488 if (not isinstance(dimensions, tuple) or 489 not all(isinstance(i, int) for i in dimensions)): 490 dimensions = tuple(int(i) for i in dimensions) 491 return Shape( 492 dimensions, np.dtype(element_type), minor_to_major=minor_to_major) 493 494 @staticmethod 495 def from_pyval(pyval): 496 def convert(pyval): 497 if isinstance(pyval, tuple): 498 return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) 499 else: 500 pyval = require_numpy_array_layout(pyval) 501 return Shape.array_shape(pyval.dtype, np.shape(pyval)) 502 return convert(pyval) 503 504 def __init__(self, dimensions, dtype, minor_to_major=None): 505 assert isinstance(dimensions, tuple) 506 self._dimensions = dimensions 507 self._dtype = dtype 508 self._is_tuple = dtype == tuple 509 self._minor_to_major = minor_to_major 510 self._check_minor_to_major() 511 512 def __eq__(self, other): 513 # pylint: disable=protected-access 514 return (self._dtype == other._dtype and 515 self._dimensions == other._dimensions and 516 self._minor_to_major == other._minor_to_major) 517 518 def __ne__(self, other): 519 return not self == other 520 521 def __hash__(self): 522 return hash((self._dtype, self._dimensions, self._minor_to_major)) 523 524 def __repr__(self): 525 return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, ' 526 '_is_tuple={!r}, _minor_to_major={!r})').format( 527 self._dtype, self._dimensions, self._is_tuple, 528 self._minor_to_major) 529 530 def is_tuple(self): 531 return self._is_tuple 532 533 def is_array(self): 534 return not self._is_tuple 535 536 def tuple_shapes(self): 537 if not self.is_tuple(): 538 raise ValueError('not a tuple shape') 539 return self._dimensions 540 541 def numpy_dtype(self): 542 """Like element_type(), but returns dtype('O') in case of a tuple shape.""" 543 if self.is_tuple(): 544 return np.dtype(np.object) 545 else: 546 return self.element_type() 547 548 def xla_element_type(self): 549 return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.numpy_dtype())] 550 551 def element_type(self): 552 if not self.is_array(): 553 raise ValueError('not an array shape') 554 return self._dtype 555 556 def dimensions(self): 557 if not self.is_array(): 558 raise ValueError('not an array shape') 559 return self._dimensions 560 561 def rank(self): 562 return len(self.dimensions()) 563 564 def minor_to_major(self): 565 return self._minor_to_major 566 567 def map_leaves(self, f): 568 """Map f over each leaf-level array subshape. 569 570 Args: 571 f: The function to apply. Whenever f returns None, the identity is applied 572 instead. 573 574 Returns: 575 A new Shape with the mapped leaves. 576 """ 577 if self.is_tuple(): 578 children = tuple(child.map_leaves(f) for child in self.tuple_shapes()) 579 return Shape.tuple_shape(children) 580 else: 581 mapped = f(self) 582 return self if mapped is None else mapped 583 584 def _check_minor_to_major(self): 585 mtm = self._minor_to_major 586 if self.is_tuple(): 587 assert mtm is None, self 588 if mtm is not None: 589 assert self.rank() == len(mtm), self 590 assert sorted(mtm) == list(range(len(mtm))), self 591 592 def update_minor_to_major(self, minor_to_major): 593 if not self.is_array(): 594 raise ValueError('not an array shape') 595 if not isinstance(minor_to_major, tuple): 596 raise TypeError('minor_to_major must be a tuple') 597 updated = Shape.array_shape(self.element_type(), self.dimensions(), 598 minor_to_major) 599 updated._check_minor_to_major() # pylint: disable=protected-access 600 return updated 601 602 def with_major_to_minor_layout_if_absent(self): 603 """Returns a copy of a shape with missing layouts set to major-to-minor.""" 604 605 def f(a): 606 if a.minor_to_major(): 607 return None 608 return a.update_minor_to_major(tuple(xrange(a.rank() - 1, -1, -1))) 609 610 return self.map_leaves(f) 611 612 def serialize(self, proto): 613 """Serializes 'shape' into proto.""" 614 if self.is_tuple(): 615 proto.element_type = PrimitiveType.TUPLE 616 for shape in self.tuple_shapes(): 617 shape.serialize(proto.tuple_shapes.add()) 618 else: 619 proto.element_type = dtype_to_etype(self.element_type()) 620 proto.dimensions.extend(self.dimensions()) 621 proto.is_dynamic_dimension.extend([False for _ in self.dimensions()]) 622 if self.minor_to_major(): 623 proto.layout.format = Format.DENSE 624 proto.layout.minor_to_major.extend(self.minor_to_major()) 625 626 627ProgramShape = collections.namedtuple('ProgramShape', 628 ('parameter_shapes', 'result_shape')) 629 630 631def _wrap_shape(shape_info): 632 dtype, dims = shape_info 633 element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] 634 if element_type == PrimitiveType.TUPLE: 635 shapes = tuple(_wrap_shape(subshape_info) for subshape_info in dims) 636 return Shape.tuple_shape(shapes) 637 else: 638 return Shape.array_shape(dtype, dims) 639 640 641def _wrap_program_shape(shape_info): 642 arg_shapes, result_shape = shape_info 643 return ProgramShape([_wrap_shape(arg) for arg in arg_shapes], 644 _wrap_shape(result_shape)) 645 646 647def require_numpy_array_layout(value): 648 if isinstance(value, tuple): 649 return tuple(require_numpy_array_layout(x) for x in value) 650 else: 651 return np.require(value, requirements=['C', 'A']) 652 653 654class CompileOptions(object): 655 """Python object for XLA compile options. 656 657 These options can be passed to the 'compile' step when using a local XLA 658 client. 659 """ 660 661 def __init__(self): 662 self.xla_dump_to = None 663 self.dump_hlo_pass_re = None 664 self.dump_hlo_module_re = None 665 self.dump_hlo_as_text = None 666 self.dump_hlo_as_proto = None 667 self.hlo_profile = None 668 self.num_replicas = get_replica_count() 669 670 671def transfer_to_infeed(value, device_ordinal=0): 672 """Transfers the given value into the XLA infeed queue. 673 674 XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with 675 a totally ordered stream of values. This is dequeued from XLA computations via 676 the Infeed() operation. 677 678 Args: 679 value: the value that the caller would like to enqueue into the XLA infeed 680 queue 681 device_ordinal: the device to infeed the value to. Each device has a 682 distinct infeed queue. 683 """ 684 # TODO(phawkins): support non-default backends. 685 backend = _get_default_local_backend() 686 backend.client.TransferToInfeed( 687 require_numpy_array_layout(value), device_ordinal) 688 689 690def transfer_from_outfeed(shape, device_ordinal=0): 691 """Transfers a literal of the given shape from `device_ordinal`'s outfeed. 692 693 Args: 694 shape: The shape of the value to transfer from outfeed. 695 device_ordinal: The device ordinal to transfer the outfeed value from. Each 696 device has a distinct outfeed queue.. 697 698 Returns: 699 The literal value that is produced from the outfeed queue. 700 """ 701 # TODO(phawkins): support non-default backends. 702 backend = _get_default_local_backend() 703 return backend.client.TransferFromOutfeed(shape, device_ordinal) 704 705 706class Computation(object): 707 """Python wrapper for an XLA Computation. 708 709 A Computation can be compiled to form an Executable, or used as a 710 subcomputation in ComputationBuilder methods. 711 """ 712 713 def __init__(self, c_computation, backend=None): 714 self._c_computation = c_computation 715 # The backend argument is deprecated. Pass a backend to Compile() instead. 716 self._backend = backend 717 self._delete_computation = c_api.DeleteComputation 718 719 @property 720 def computation(self): 721 return self._c_computation 722 723 def GetSerializedProto(self): 724 """Gets the serialized HloModuleProto proto object in this computation. 725 726 Returns: 727 A string containing a serialized HloModuleProto proto containing the 728 computation and its dependencies. 729 """ 730 return self.computation.GetSerializedProto() 731 732 def GetHloText(self): 733 """Get the textual HLO representation of this computation. 734 735 Returns: 736 A string containing the textual HLO. 737 """ 738 return self.computation.GetHloText() 739 740 def GetHloDotGraph(self): 741 """Get a Graphviz Dot representation of this computation. 742 743 Returns: 744 A string containing the graphviz dot graph. 745 """ 746 return self.computation.GetHloDotGraph() 747 748 def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None, 749 backend=None): 750 """Compiles a computation. 751 752 Computations are the result of a "ComputationBuild'ing" process. 753 754 Arguments: 755 argument_shapes: parameter shapes -- they are first laid out by layout_fn 756 if layout_fn is provided. Otherwise, the default layout for those shapes 757 will be used. 758 compile_options: options to use for compilation, includes an optional laid 759 out result shape for the computation. 760 layout_fn: lambda that is used to lay out the argument/result shapes. 761 backend: a `Backend` for which an executable should be generated. 762 763 Returns: 764 A Executable instance. 765 """ 766 backend = backend or self._backend or _get_default_local_backend() 767 result_shape = _wrap_shape(self.computation.GetReturnValueShape()) 768 769 if layout_fn: 770 argument_shapes = [ 771 shape.map_leaves(layout_fn) for shape in argument_shapes 772 ] 773 result_shape = result_shape.map_leaves(layout_fn) 774 775 argument_shapes = list(argument_shapes) 776 777 compile_options = compile_options or CompileOptions() 778 compile_options.result_shape = result_shape 779 c = backend.compile(self.computation, argument_shapes, result_shape, 780 compile_options) 781 return Executable(c, backend=backend) 782 783 def CompileWithExampleArguments(self, 784 arguments=(), 785 compile_options=None, 786 layout_fn=None, 787 backend=None): 788 return self.Compile( 789 argument_shapes=[Shape.from_pyval(arg) for arg in arguments], 790 compile_options=compile_options, 791 layout_fn=layout_fn, 792 backend=backend) 793 794 def GetProgramShape(self): 795 return _wrap_program_shape(self._c_computation.GetProgramShape()) 796 797 def GetReturnValueShape(self): 798 return _wrap_shape(self._c_computation.GetReturnValueShape()) 799 800 def __del__(self): 801 if self._c_computation: 802 self._delete_computation(self._c_computation) 803 804 805class Executable(object): 806 """Python wrapper for an XLA Executable.""" 807 808 def __init__(self, c_executable, backend=None): 809 self._c_executable = c_executable 810 self._device_ordinals = c_executable.DeviceOrdinals() 811 self._backend = backend 812 813 def DeviceOrdinals(self): 814 """Returns a list containing the device ordinals for each replica.""" 815 return self._device_ordinals 816 817 def Execute(self, arguments=(), check_for_deleted_args=True): 818 """Execute on one replica with LocalBuffer arguments and return value.""" 819 if check_for_deleted_args and any(arg.is_deleted() for arg in arguments): 820 raise ValueError('Executing with deleted local buffer argument') 821 raw_args = [arg.c_buffer for arg in arguments] 822 output_buffer = self._backend.execute(self._c_executable, raw_args) 823 return LocalBuffer( 824 output_buffer, backend=self._backend, device=self._device_ordinals[0]) 825 826 def ExecutePerReplica(self, arguments=None): 827 """Execute on many replicas with LocalBuffer arguments and return value. 828 829 Args: 830 arguments: A sequence of sequences of LocalBuffers. The i'th inner 831 sequence comprises the arguments for execution on the i'th replica. 832 833 Returns: 834 A list of the computation's outputs for each replica, as a LocalBuffer. If 835 a shallow sequence of arguments was passed in for `arguments`, then the 836 sole, zero'th replica's output is returned instead, as a LocalBuffer. 837 """ 838 if arguments is None: 839 arguments = ((),) * len(self._device_ordinals) 840 else: 841 arguments = [list(replica_args) for replica_args in arguments] 842 843 # Check arguments 844 for replica, replica_args in enumerate(arguments): 845 for arg in replica_args: 846 if arg.is_deleted(): 847 raise ValueError('Executing with deleted local buffer argument') 848 if arg.device() != self._device_ordinals[replica]: 849 raise ValueError( 850 'Executing on device {} with argument from device {}'.format( 851 self._device_ordinals[replica], arg.device())) 852 853 # Pull out argument buffer handles 854 # pylint: disable=g-complex-comprehension 855 stripped_args = [ 856 [arg.c_buffer for arg in replica_args] for replica_args in arguments 857 ] 858 859 # Execute 860 output_buffers = self._backend.execute_replicated(self._c_executable, 861 stripped_args) 862 863 # Wrap output handles in LocalBuffer instances 864 return tuple( 865 LocalBuffer( 866 output_buffer, 867 backend=self._backend, 868 device=self._device_ordinals[replica]) 869 for replica, output_buffer in enumerate(output_buffers)) 870 871 def ExecuteWithPythonValues(self, arguments=()): 872 """Execute on one replica with Python values as arguments and output.""" 873 874 def put(arg): 875 return LocalBuffer.from_pyval( 876 arg, device=self._device_ordinals[0], backend=self._backend) 877 878 arguments = [put(arg) for arg in arguments] 879 return self.Execute(arguments).to_py() 880 881 def ExecuteWithPythonValuesPerReplica(self, arguments): 882 """Execute on many replicas with Python values as arguments and output.""" 883 884 def put(arg, device): 885 return LocalBuffer.from_pyval(arg, device, backend=self._backend) 886 887 # pylint: disable=g-complex-comprehension 888 arguments = [[ 889 put(arg, self._device_ordinals[replica]) for arg in replica_args 890 ] for replica, replica_args in enumerate(arguments)] 891 return [out.to_py() for out in self.ExecutePerReplica(arguments)] 892 893 def __del__(self): 894 # Python may have freed c_api first. 895 if c_api and self._c_executable: 896 self._backend.delete_executable(self._c_executable) 897 898 899class ComputationBuilder(object): 900 """XLA computation builder. 901 902 Enqueues XLA ops in sequence and in order to build a 903 Computation, which in turn can be compiled into a 904 LocalExecutable, which in turn can be locally executed. 905 """ 906 907 # The methods of this class map 1-to-1 onto the XLA C++ 908 # computation builder API. Therefore, there's no need to laboriously list 909 # arguments and return values for every method, especially where it's obvious. 910 # 911 # pylint: disable=g-doc-return-or-yield 912 # pylint: disable=g-doc-args 913 914 def __init__(self, name): 915 self._client = c_api.ComputationBuilder(name.encode('utf8')) 916 self._parameter_numbering = itertools.count() 917 918 def Build(self, root=None, backend=None): 919 """Builds a `Computation` from the contents of the builder. 920 921 Args: 922 root: if not None, the operator containing the return value of the 923 computation. 924 backend: deprecated. Pass a `backend` to `Computation.Compile` instead. 925 926 Returns: 927 A `Computation`. 928 """ 929 if root is not None: 930 return Computation(self._client.BuildWithRoot(root), backend=backend) 931 else: 932 return Computation(self._client.Build(), backend=backend) 933 934 def SetOpMetadata(self, op_metadata): 935 """Set metadata for operations that are about to be enqueued.""" 936 self._client.SetOpMetadata(op_metadata) 937 938 def ClearOpMetadata(self): 939 """Clear metadata for operations that are about to be enqueued.""" 940 self._client.ClearOpMetadata() 941 942 def Infeed(self, shape): 943 """Enqueues an infeed op onto the computation. 944 945 Infeed operations dequeue data of the given shape from the device's infeed 946 queue for subsequent use in the computation. 947 948 Returns: 949 A LocalOp. 950 """ 951 return self._client.Infeed(shape) 952 953 def Outfeed(self, operand): 954 """Enqueues an outfeed op onto the computation. 955 956 Outfeed operations enqueue data, using the given operand, onto the XLA 957 outfeed queue for subsequent dequeue via the client API. 958 """ 959 self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8')) 960 961 def Constant(self, value): 962 """Enqueues a constant op onto the computation. 963 964 Args: 965 value: value for the constant, as a np.array with an explicit dtype set to 966 one of the supported types. 967 968 Returns: 969 A LocalOp. 970 """ 971 value = require_numpy_array_layout(value) 972 return self._client.ConstantLiteral(value) 973 974 def ConstantF32Scalar(self, value): 975 """Convenience method to enqueue a scalar F32 constant op. 976 977 Args: 978 value: a floating-point number. 979 980 Returns: 981 A LocalOp. 982 """ 983 return self.Constant(np.array(value, dtype=np.float32)) 984 985 def ConstantF64Scalar(self, value): 986 """Convenience method to enqueue a scalar F32 constant op. 987 988 Args: 989 value: a floating-point number. 990 991 Returns: 992 A LocalOp. 993 """ 994 return self.Constant(np.array(value, dtype=np.float64)) 995 996 def ConstantS32Scalar(self, value): 997 """Convenience method to enqueue a scalar S32 constant op. 998 999 Args: 1000 value: a floating-point number. 1001 1002 Returns: 1003 A LocalOp. 1004 """ 1005 return self.Constant(np.array(value, dtype=np.int32)) 1006 1007 def ConstantS64Scalar(self, value): 1008 """Convenience method to enqueue a scalar S64 constant op. 1009 1010 Args: 1011 value: a floating-point number. 1012 1013 Returns: 1014 A LocalOp. 1015 """ 1016 return self.Constant(np.array(value, dtype=np.int64)) 1017 1018 def ConstantPredScalar(self, value): 1019 """Convenience method to enqueue a scalar PRED constant op. 1020 1021 Args: 1022 value: a boolean value. 1023 1024 Returns: 1025 A LocalOp. 1026 """ 1027 return self.Constant(np.array(value, dtype=np.bool)) 1028 1029 def ParameterWithShape(self, shape, name=None, parameter_num=None): 1030 """Enqueues a Parameter op onto the computation, given a shape. 1031 1032 Args: 1033 shape: the parameter's shape as a Shape object. 1034 name: optional string name for the parameter. 1035 parameter_num: parameter number in the computation function. If None, the 1036 next linear parameter number is used. The default value capability can 1037 be used for auto-numbering. If you're using auto-numbering for some 1038 parameters, use it for *all* parameters to avoid clashes. 1039 1040 Returns: 1041 A LocalOp. 1042 """ 1043 if name is None: 1044 name = '' 1045 if parameter_num is None: 1046 parameter_num = next(self._parameter_numbering) 1047 1048 return self._client.Parameter(parameter_num, shape, name.encode('utf8')) 1049 1050 def ParameterFromNumpy(self, value, name=None, parameter_num=None): 1051 """Enqueues a Parameter op onto the computation. 1052 1053 Args: 1054 value: a Numpy array, or a nested tuple thereof, from which the shape is 1055 inferred. 1056 name: as in ParameterWithShape. 1057 parameter_num: as in ParameterWithShape. 1058 1059 Returns: 1060 A LocalOp. 1061 """ 1062 return self.ParameterWithShape( 1063 Shape.from_pyval(value), name=name, parameter_num=parameter_num) 1064 1065 def Iota(self, dtype, size): 1066 """Enqueues an iota constant onto the computation. 1067 1068 Args: 1069 dtype: expected numpy dtype of the output. 1070 size: integer, the number of elements in the array. 1071 1072 Returns: 1073 A LocalOp representing the added iota constant. 1074 """ 1075 element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] 1076 return self._client.Iota(element_type, size) 1077 1078 def BroadcastedIota(self, dtype, shape, dimension): 1079 """Enqueues a broadcasted iota constant onto the computation. 1080 1081 Args: 1082 dtype: expected numpy dtype of the output. 1083 shape: tuple of integers, the expected output shape (dimensions). 1084 dimension: positive integer, dimension along which to increment values. 1085 1086 Returns: 1087 A LocalOp representing the added broadcasted iota constant. 1088 """ 1089 xla_shape = Shape.array_shape(dtype, shape) 1090 return self._client.BroadcastedIota(xla_shape, dimension) 1091 1092 def Broadcast(self, operand, sizes): 1093 """Enqueues a broadcast operation onto the computation. 1094 1095 Args: 1096 operand: the operand LocalOp to broadcast. 1097 sizes: an iterable of broadcast sizes. 1098 1099 Returns: 1100 A LocalOp representing the added broadcast op. 1101 """ 1102 return self._client.Broadcast(operand, sizes) 1103 1104 def BroadcastInDim(self, operand, shape, broadcast_dimensions): 1105 """Enqueues a broadcast-in-dimensions operation onto the computation. 1106 1107 Args: 1108 operand: the operand LocalOp to broadcast. 1109 shape: tuple of integers, the expected output shape. 1110 broadcast_dimensions: tuple of integers identifying which dimensions of 1111 the output are to be broadcast into. 1112 1113 Returns: 1114 A LocalOp representing the added broadcast-in-dimensions op. 1115 """ 1116 return self._client.BroadcastInDim(operand, shape, broadcast_dimensions) 1117 1118 def Concatenate(self, operands, dimension): 1119 """Enqueues a concatenate operation onto the computation. 1120 1121 Args: 1122 operands: the operands to concatenate. 1123 dimension: the dimension in which to perform the concatenation. 1124 1125 Returns: 1126 A LocalOp representing the added concatenate op. 1127 """ 1128 return self._client.ConcatInDim(operands, dimension) 1129 1130 def ConvertElementType(self, operand, new_element_type): 1131 """Enqueues an element type conversion operation onto the computation. 1132 1133 Args: 1134 operand: the operand to convert. 1135 new_element_type: the target primitive type. 1136 1137 Returns: 1138 A LocalOp representing the added conversion op. 1139 """ 1140 return self._client.ConvertElementType(operand, new_element_type) 1141 1142 def BitcastConvertType(self, operand, new_element_type): 1143 """Enqueues a bitcast type conversion operation onto the computation. 1144 1145 Args: 1146 operand: the operand to convert. 1147 new_element_type: the target primitive type. 1148 1149 Returns: 1150 A LocalOp representing the added conversion op. 1151 """ 1152 return self._client.BitcastConvertType(operand, new_element_type) 1153 1154 def GetShape(self, operand): 1155 return _wrap_shape(self._client.GetShape(operand)) 1156 1157 def GetReturnValueShape(self): 1158 return _wrap_shape(self._client.GetReturnValueShape()) 1159 1160 def GetComputationStats(self): 1161 raise NotImplementedError() 1162 1163 def ReplicaId(self): 1164 """Enqueues a ReplicaId operation onto the computation. 1165 1166 Returns: 1167 A LocalOp representing the replica id. 1168 """ 1169 return self._client.ReplicaId() 1170 1171 def Pad(self, operand, padding_value, padding_config): 1172 """Enqueues a Pad operation onto the computation. 1173 1174 Args: 1175 operand: LocalOp representing the array to pad. 1176 padding_value: LocalOp representing the scalar pad value. 1177 padding_config: either a PaddingConfig or a list of integer triples 1178 (edge_padding_low, edge_padding_high, interior_padding) representing the 1179 configuration of the padding operation. 1180 1181 Returns: 1182 A LocalOp representing the added Pad op. 1183 """ 1184 if isinstance(padding_config, tuple) or isinstance(padding_config, list): 1185 padding_config = GetPaddingConfigFromTriples(padding_config) 1186 return self._client.Pad(operand, padding_value, padding_config) 1187 1188 def Reshape(self, operand, dimensions, new_sizes): 1189 """Enqueues a reshape op onto the computation. 1190 1191 Args: 1192 operand: LocalOp representing the array to be reshaped. 1193 dimensions: sequence of integers encoding the order in which dimensions 1194 are collapsed or None, in which case dimensions are flattened in order. 1195 new_sizes: sequence of integers encoding the new dimension sizes (shape). 1196 1197 Returns: 1198 A LocalOp representing the added Reshape op. 1199 """ 1200 if dimensions is None: 1201 ndim = len(self.GetShape(operand).dimensions()) 1202 dimensions = tuple(range(ndim)) 1203 return self._client.Reshape(operand, dimensions, new_sizes) 1204 1205 def AllToAll(self, 1206 operand, 1207 split_dimension, 1208 concat_dimension, 1209 replica_groups=None): 1210 """AllToAll op. 1211 1212 Args: 1213 operand: LocalOp representing the input array 1214 split_dimension: the dimension along which the operand is split 1215 concat_dimension: the dimension along which the split blocks are 1216 concatenated 1217 replica_groups: optional, list of lists of ints encoding a partition of 1218 the set {0, 1, ..., num_replicas} into equally-sized replica groups 1219 within which the all-to-all is performed. If not supplied or None (the 1220 default), all replicas belong to the same group. 1221 1222 Returns: 1223 A LocalOp that represents the all-to-all concatenation. 1224 """ 1225 if replica_groups is None: 1226 replica_groups_protos = [] # special value for XLA API 1227 else: 1228 replica_groups = list(replica_groups) 1229 replica_groups_protos = [ 1230 _make_replica_group_proto(group) for group in replica_groups 1231 ] 1232 if not replica_groups: 1233 split_count = get_replica_count() 1234 else: 1235 split_count = len(replica_groups[0]) 1236 if not all(split_count == len(g) for g in replica_groups): 1237 raise ValueError('Replica groups must be equally sized') 1238 return self._client.AllToAll(operand, split_dimension, concat_dimension, 1239 split_count, replica_groups_protos) 1240 1241 def CrossReplicaSum(self, operand, replica_groups=None): 1242 """CrossReplicaSum op. 1243 1244 Args: 1245 operand: the operand to sum across replica instances. 1246 replica_groups: optional, list of lists of ints encoding a partition of 1247 the set {0, 1, ..., num_replicas} into equally-sized replica groups 1248 within which the cross-replica sum is performed. If not supplied or None 1249 (the default), all replicas belong to the same group. 1250 1251 Returns: 1252 A LocalOp that represents on each replica the sum of its group's values. 1253 """ 1254 if replica_groups is None: 1255 replica_groups = [] # special value for XLA API 1256 else: 1257 replica_groups = [ 1258 _make_replica_group_proto(group) for group in replica_groups 1259 ] 1260 return self._client.CrossReplicaSum(operand, replica_groups) 1261 1262 def Collapse(self, operand, dimensions): 1263 """Collapse op.""" 1264 return self._client.Collapse(operand, dimensions) 1265 1266 def Trans(self, operand): 1267 """Specialized matrix transpose op.""" 1268 return self._client.Transpose(operand, [1, 0]) 1269 1270 def Transpose(self, operand, permutation): 1271 """Transpose op.""" 1272 return self._client.Transpose(operand, permutation) 1273 1274 def Rev(self, operand, dimensions): 1275 """Rev op.""" 1276 return self._client.Rev(operand, dimensions) 1277 1278 def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin 1279 """Clamp op.""" 1280 return self._client.Clamp(min, operand, max) 1281 1282 def SelectAndScatter(self, operand, select, window_dimensions, window_strides, 1283 padding, source, init_value, scatter): 1284 """Select and scatter op, used by the gradient of ReduceWindow. 1285 1286 Args: 1287 operand: LocalOp for array of dimension N and type T over which the 1288 windows slide. 1289 select: Computation of type (T, T) -> Pred to apply to the elements of 1290 each window to indicate which element is selected. 1291 window_dimensions: sequence of N integers for dimensions of the window. 1292 window_strides: sequence of N integers for the strides of the window. 1293 padding: PaddingType representing either 'SAME' or 'VALID ' padding. 1294 source: LocalOp for array of type T with values to scatter. 1295 init_value: LocalOp of scalar type T for initial out value. 1296 scatter: Computation of type (T, T) -> T to apply to each scatter source 1297 element with its destination element. 1298 1299 Returns: 1300 A LocalOp representing the added SelectAndScatter op. 1301 """ 1302 pads = _convert_padding_type_to_pad_values( 1303 padding, self.GetShape(operand).dimensions(), window_dimensions, 1304 window_strides) 1305 return self._client.SelectAndScatterWithGeneralPadding( 1306 operand, select.computation, window_dimensions, window_strides, pads, 1307 source, init_value, scatter.computation) 1308 1309 def Select(self, pred, on_true, on_false): 1310 """Element-wise selection op. 1311 1312 Constructs an output array from elements of two input arrays, based on the 1313 values of a predicate array. 1314 """ 1315 return self._client.Select(pred, on_true, on_false) 1316 1317 def Slice(self, operand, start_indices, limit_indices, strides=None): 1318 """Enqueues a slice operation onto the computation. 1319 1320 Args: 1321 operand: LocalOp for the N dimensional array to be sliced. 1322 start_indices: iterable of N integers containing the starting indices of 1323 the slice for each dimension. 1324 limit_indices: iterable of N integers containing the ending indices 1325 (exclusive) of the slice for each dimension. 1326 strides: optional iterable of N integers containing the stride sizes for 1327 each dimension. 1328 1329 Returns: 1330 A LocalOp representing the added Slice op. 1331 """ 1332 if strides is None: 1333 start_indices = list(start_indices) 1334 strides = [1] * len(start_indices) 1335 return self._client.Slice(operand, start_indices, limit_indices, strides) 1336 1337 def SliceInDim(self, operand, start_index, limit_index, stride, dimno): 1338 """Enqueues a slice-in-dimension operation onto the computation. 1339 1340 Args: 1341 operand: LocalOp for the N dimensional array to be sliced. 1342 start_index: an integer containing the start index of the slice. 1343 limit_index: an integer containing the end index of the slice. 1344 stride: an integer containing the stride size for the slice. 1345 dimno: an integer indicating the dimension along which to slice. 1346 1347 Returns: 1348 A LocalOp representing the added Slice op. 1349 """ 1350 return self._client.SliceInDim(operand, start_index, limit_index, stride, 1351 dimno) 1352 1353 def DynamicSlice(self, operand, start_indices, slice_sizes): 1354 """Enqueues a slice op with dynamic start indices onto the computation. 1355 1356 Args: 1357 operand: LocalOp for the N dimensional array to be sliced. 1358 start_indices: LocalOp for the 1D array of N integers containing the 1359 starting indices of the slice. 1360 slice_sizes: iterable of N integers containing the slice sizes in each 1361 dimension. 1362 1363 Returns: 1364 A LocalOp representing the added DynamicSlice op. 1365 """ 1366 return self._client.DynamicSlice(operand, start_indices, slice_sizes) 1367 1368 def DynamicUpdateSlice(self, operand, update, start_indices): 1369 """Enqueues a dynamic update slice operation onto the computation. 1370 1371 Args: 1372 operand: LocalOp for the N dimensional array to be updated. 1373 update: N dimensional array comprising the slice update. 1374 start_indices: Rank-1 array of N integers comprising the starting indices 1375 of the slice along each dimension. 1376 1377 Returns: 1378 A LocalOp representing the added DynamicUpdateSlice op. 1379 """ 1380 return self._client.DynamicUpdateSlice(operand, update, start_indices) 1381 1382 def Tuple(self, *ops): 1383 """Enqueues a tuple operation onto the computation. 1384 1385 Args: 1386 ops: a sequence of tuple operands (each a LocalOp). 1387 1388 Returns: 1389 A LocalOp representing the added Tuple op. 1390 """ 1391 return self._client.Tuple(ops) 1392 1393 def GetTupleElement(self, tup, index): 1394 """Enqueues a 'get tuple element' operation onto the computation. 1395 1396 Args: 1397 tup: the tuple operand (a LocalOp). 1398 index: numeric index to select from the tuple. 1399 1400 Returns: 1401 A LocalOp representing the added GetTupleElement op. 1402 """ 1403 return self._client.GetTupleElement(tup, index) 1404 1405 def Call(self, computation_to_apply, operands): 1406 """Enqueues a call operation onto the computation. 1407 1408 Args: 1409 computation_to_apply: a Computation object. 1410 operands: an iterable of LocalOp. The number and types of operands must 1411 match the arity of computation_to_apply. 1412 1413 Returns: 1414 A LocalOp representing the added call op. 1415 """ 1416 return self._client.Call(computation_to_apply.computation, operands) 1417 1418 def CustomCall(self, 1419 call_target_name, 1420 operands, 1421 shape_with_layout, 1422 operand_shapes_with_layout, 1423 opaque=None): 1424 """Enqueues a custom call operation onto the computation. 1425 1426 Args: 1427 call_target_name: the name of the function to call. 1428 operands: an iterable of LocalOp. The number and types of operands must 1429 match the arity of `operand_shapes_with_layout`. 1430 shape_with_layout: the shape of the operator's output, with layout. 1431 operand_shapes_with_layout: the shapes of `operands`, including the 1432 expected layouts. 1433 opaque: an opaque string passed to the backend. 1434 1435 Returns: 1436 A LocalOp representing the added custom call op. 1437 """ 1438 opaque = opaque or b'' 1439 return self._client.CustomCall(call_target_name, operands, 1440 shape_with_layout, 1441 operand_shapes_with_layout, opaque) 1442 1443 def Map(self, operands, computation_to_apply, dimensions): 1444 """Enqueues a map operation onto the computation. 1445 1446 Args: 1447 operands: an iterable of LocalOp. 1448 computation_to_apply: a Computation object. 1449 dimensions: dimensions over which to apply map the function. 1450 1451 Returns: 1452 A LocalOp representing the added Map op. 1453 """ 1454 return self._client.Map(operands, computation_to_apply.computation, 1455 dimensions) 1456 1457 def Reduce(self, operand, init_value, computation_to_apply, dimensions): 1458 """Enqueues a reduction operation onto the computation. 1459 1460 Args: 1461 operand: reduction operand (LocalOp). 1462 init_value: reduction initial value (LocalOp). 1463 computation_to_apply: a Computation object - binary reduction function. 1464 dimensions: sequence of dimensions (integers) to reduce on. 1465 1466 Returns: 1467 A LocalOp representing the added Reduce op. 1468 """ 1469 return self._client.Reduce(operand, init_value, 1470 computation_to_apply.computation, dimensions) 1471 1472 def ReduceWindow(self, operand, init_value, computation_to_apply, 1473 window_dimensions, window_strides, padding): 1474 """Enqueues a windowed reduction operation onto the computation. 1475 1476 Args: 1477 operand: reduction operand (LocalOp). 1478 init_value: reduction initial value (LocalOp). 1479 computation_to_apply: a binary reduction function (Computation). 1480 window_dimensions: dimensions of window (sequence of integers). 1481 window_strides: strides for window (sequence of integers). 1482 padding: PaddingType representing either 'SAME' or 'VALID' padding. 1483 1484 Returns: 1485 A LocalOp representing the added ReduceWindow op. 1486 """ 1487 pads = _convert_padding_type_to_pad_values( 1488 padding, 1489 self.GetShape(operand).dimensions(), window_dimensions, window_strides) 1490 return self._client.ReduceWindowWithGeneralPadding( 1491 operand, init_value, computation_to_apply.computation, 1492 window_dimensions, window_strides, (), (), pads) 1493 1494 def ReduceWindowWithGeneralPadding( 1495 self, operand, init_value, computation_to_apply, window_dimensions, 1496 window_strides, base_dilations, window_dilations, padding): 1497 """Enqueues a windowed reduction operation onto the computation. 1498 1499 Args: 1500 operand: reduction operand (LocalOp). 1501 init_value: reduction initial value (LocalOp). 1502 computation_to_apply: a binary reduction function (Computation). 1503 window_dimensions: dimensions of window (sequence of integers). 1504 window_strides: strides for window (sequence of integers). 1505 base_dilations: dilations for the base (sequence of integers). 1506 window_dilations: dilations for window (sequence of integers). 1507 padding: length-N array-like of pairs of integers of (low, high) padding. 1508 1509 Returns: 1510 A LocalOp representing the added ReduceWindow op. 1511 """ 1512 return self._client.ReduceWindowWithGeneralPadding( 1513 operand, init_value, computation_to_apply.computation, 1514 window_dimensions, window_strides, base_dilations, window_dilations, 1515 padding) 1516 1517 def RngNormal(self, mu, sigma, dims): 1518 """Enqueues an RngNormal operation onto the computation. 1519 1520 Args: 1521 mu: A LocalOp to an F32 scalar specifying the mean. 1522 sigma: A LocalOp to an F32 scalar specifying the standard deviation. 1523 dims: A 1D array-like of nonnegative integers specifying the dimensions. 1524 Returns: a LocalOp to the generated array of F32 values. 1525 """ 1526 shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) 1527 return self._client.RngNormal(mu, sigma, shape) 1528 1529 def RngUniform(self, a, b, dims): 1530 """Enqueues an RngUniform operation onto the computation. 1531 1532 Args: 1533 a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b) 1534 specifying the low end of the interval [a, b) over which values are 1535 generated. 1536 b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a) 1537 specifying the high end of the interval [a, b) over which values are 1538 generated. 1539 dims: A 1D array-like of nonnegative integers specifying the dimensions. 1540 Returns: a LocalOp to the generated array of values with the same numeric 1541 type (F32, S32, or U32) as the arguments a and b. 1542 """ 1543 shape = Shape.array_shape(self.GetShape(a).element_type(), dims) 1544 return self._client.RngUniform(a, b, shape) 1545 1546 def While(self, cond, body, init): 1547 """Enqueues a While operation onto the computation. 1548 1549 Args: 1550 cond: a Computation for the loop condition, which has type T -> PRED 1551 body: a Computation for the loop body, which has type T -> T 1552 init: a LocalOp for the initial parameter, which has type T 1553 Returns: a LocalOp representing the While operation. 1554 """ 1555 return self._client.While(cond.computation, body.computation, init) 1556 1557 def Conditional(self, pred, true_operand, true_computation, false_operand, 1558 false_computation): 1559 """Enqueues a Conditional operation onto the computation. 1560 1561 Args: 1562 predicate: a LocalOp to test, which has scalar type PRED 1563 true_operand: a LocalOp of type T_0 1564 true_computation: a Computation to apply to true_operand, type T_0 -> S 1565 false_operand: a ComputationDatahandle of type T_1 1566 false_computation: a Computation to apply to false_operand, type T_1 -> S 1567 Returns: a LocalOp representing the Conditional operation. 1568 """ 1569 return self._client.Conditional(pred, true_operand, 1570 true_computation.computation, false_operand, 1571 false_computation.computation) 1572 1573 def IsConstant(self, operand): 1574 """Checks whether the given operand is a compile-time constant. 1575 1576 Args: 1577 operand: a ComputationDataHandle to test. 1578 Returns: bool indicating whether `operand` is a compile-time constant, 1579 meaning its value does not depend on any parametersor, or on stateful 1580 operators such as `RngNormal` or `Infeed`. 1581 """ 1582 return self._client.IsConstant(operand) 1583 1584 def BuildConstantSubGraph(self, operand): 1585 """Builds a constant sub graph. 1586 1587 Args: 1588 operand: a LocalOp to test. 1589 Returns: a Computation that is rooted on the given `operand` which is a 1590 compile-time constant. 1591 """ 1592 return self._client.BuildConstantSubGraph(operand) 1593 1594 def Dot(self, lhs, rhs): 1595 """Enqueues a dot operation onto the computation. 1596 1597 Args: 1598 lhs: LocalOp for the rank 1 or rank 2 left-hand-side array. 1599 rhs: LocalOp for the rank 1 or rank 2 right-hand-side array. 1600 Returns: a LocalOp representing the Dot operation. 1601 """ 1602 return self._client.Dot(lhs, rhs) 1603 1604 def DotGeneral(self, lhs, rhs, dimension_numbers): 1605 """Enqueues a general dot operation onto the computation. 1606 1607 Args: 1608 lhs: LocalOp for the left-hand-side array. 1609 rhs: LocalOp for the right-hand-side array. 1610 dimension_numbers: either a DotDimensionNumbers or a nested tuple 1611 ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of 1612 integers representing the dimensions to treat as contracting dimensions 1613 and batch dimensions on each input operand. 1614 Returns: a LocalOp representing the DotGeneral operation. 1615 """ 1616 if isinstance(dimension_numbers, tuple): 1617 dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) 1618 return self._client.DotGeneral(lhs, rhs, dimension_numbers) 1619 1620 def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1): 1621 """Enqueues a Conv operation onto the computation. 1622 1623 Args: 1624 lhs: LocalOp for the rank N+2 array of inputs. 1625 rhs: LocalOp for the rank N+2 array of kernel weights. 1626 window_strides: length-N array-like of integer kernel strides. 1627 padding: PaddingType representing either 'SAME' or 'VALID' padding. 1628 feature_group_count: number of feature groups for grouped convolution. 1629 Returns: a LocalOp representing the Conv operation. 1630 """ 1631 pads = _convert_padding_type_to_pad_values( 1632 padding, 1633 self.GetShape(lhs).dimensions()[2:], 1634 self.GetShape(rhs).dimensions()[2:], window_strides) 1635 return self.ConvGeneralDilated( 1636 lhs, rhs, window_strides, pads, (), (), dimension_numbers=None, 1637 feature_group_count=feature_group_count) 1638 1639 def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, 1640 lhs_dilation, rhs_dilation, feature_group_count=1): 1641 """Enqueues a ConvWithGeneralPadding operation onto the computation. 1642 1643 Args: 1644 lhs: LocalOp for the rank N+2 array of inputs. 1645 rhs: LocalOp for the rank N+2 array of kernel weights. 1646 window_strides: length-N array-like of kernel strides. 1647 padding: length-N array-like of pairs of integers of (low, high) padding. 1648 lhs_dilation: length-N array-like of dilation factors. 1649 rhs_dilation: length-N array-like of dilation factors. 1650 feature_group_count: number of feature groups for grouped convolution. 1651 1652 Returns: 1653 A ComputationdataHandle representing the added ConvWithGeneralPadding op. 1654 """ 1655 return self.ConvGeneralDilated( 1656 lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, 1657 dimension_numbers=None, feature_group_count=feature_group_count) 1658 1659 def _GetConvDimensionNumbers(self, num_spatial_dims): 1660 """Create ConvolutionDimensionNumbers proto for convolutions.""" 1661 nd = num_spatial_dims 1662 dimension_numbers = ConvolutionDimensionNumbers() 1663 dimension_numbers.input_batch_dimension = 0 1664 dimension_numbers.input_feature_dimension = 1 1665 dimension_numbers.output_batch_dimension = 0 1666 dimension_numbers.output_feature_dimension = 1 1667 dimension_numbers.kernel_output_feature_dimension = 0 1668 dimension_numbers.kernel_input_feature_dimension = 1 1669 dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) 1670 dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) 1671 dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) 1672 return dimension_numbers 1673 1674 def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, 1675 rhs_dilation, dimension_numbers=None, 1676 feature_group_count=1): 1677 """Enqueues a ConvGeneralDilated operation onto the computation. 1678 1679 Args: 1680 lhs: LocalOp for the rank N+2 array of inputs. 1681 rhs: LocalOp for the rank N+2 array of kernel weights. 1682 window_strides: length-N array-like of integer kernel strides. 1683 padding: length-N array-like of pairs of integers of (low, high) padding. 1684 lhs_dilation: length-N array-like of integer dilation factors. 1685 rhs_dilation: length-N array-like of integer dilation factors. 1686 dimension_numbers: optional, either a ConvolutionDimensionNumbers object 1687 or a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of 1688 length N+2 identifying by position: (1) batch dimensions in lhs, rhs, 1689 and the output with the character 'N', (2) feature dimensions in lhs 1690 and the output with the character 'C', (3) input and output feature 1691 dimensions in rhs with the characters 'I' and 'O' respectively, and 1692 (4) spatial dimension correspondences between lhs, rhs, and the output 1693 using any distinct characters. For example, to indicate dimension 1694 numbers consistent with the Conv operation with two spatial 1695 dimensions, one could use ('NCHW', 'OIHW', 'NCHW'). As another 1696 example, to indicate dimension numbers consistent with the TensorFlow 1697 Conv2D operation, one could use ('NHWC', 'HWIO', 'NHWC'). When using 1698 the latter form of convolution dimension specification, window strides 1699 are associated with spatial dimension character labels according to 1700 the order in which the labels appear in the rhs_spec string, so that 1701 window_strides[0] is matched with the dimension corresponding to the 1702 first character appearing in rhs_spec that is not 'I' or 'O'. By 1703 default, use the same dimension numbering as Conv and 1704 ConvWithGeneralPadding. 1705 feature_group_count: number of feature groups for grouped convolution. 1706 Returns: a LocalOp representing the ConvGenralDilated operation. 1707 """ 1708 if dimension_numbers is None: 1709 dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) 1710 elif isinstance(dimension_numbers, tuple): 1711 lhs_spec, rhs_spec, out_spec = dimension_numbers 1712 dimension_numbers = ConvolutionDimensionNumbers() 1713 1714 dimension_numbers.input_batch_dimension = lhs_spec.index('N') 1715 dimension_numbers.input_feature_dimension = lhs_spec.index('C') 1716 dimension_numbers.output_batch_dimension = out_spec.index('N') 1717 dimension_numbers.output_feature_dimension = out_spec.index('C') 1718 dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') 1719 dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') 1720 1721 dimension_numbers.kernel_spatial_dimensions.extend( 1722 i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) 1723 dimension_numbers.input_spatial_dimensions.extend( 1724 sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), 1725 key=lambda i: rhs_spec.index(lhs_spec[i]))) 1726 dimension_numbers.output_spatial_dimensions.extend( 1727 sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), 1728 key=lambda i: rhs_spec.index(out_spec[i]))) 1729 return self._client.ConvGeneralDilated( 1730 lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, 1731 dimension_numbers, feature_group_count) 1732 1733 def Sort(self, operand, dimension=-1): 1734 """Enqueues a sort operation onto the computation.""" 1735 return self._client.Sort(operand, dimension) 1736 1737 def SortKeyVal(self, keys, values, dimension=-1): 1738 """Enqueues a key-value sort operation onto the computation.""" 1739 return self._client.SortKeyVal(keys, values, dimension) 1740 1741 def Cholesky(self, a, lower=True): 1742 """Enqueues a Cholesky decomposition onto the computation.""" 1743 return self._client.Cholesky(a, lower) 1744 1745 def QR(self, a, full_matrices=True): 1746 """Enqueues a QR decomposition onto the computation.""" 1747 return self._client.QR(a, full_matrices) 1748 1749 def TriangularSolve(self, 1750 a, 1751 b, 1752 left_side=False, 1753 lower=False, 1754 transpose_a=False, 1755 conjugate_a=False, 1756 unit_diagonal=False): 1757 """Enqueues a triangular-solve operation onto the computation.""" 1758 if not transpose_a: 1759 transpose = 1 1760 if conjugate_a: 1761 a = self.Conj(a) 1762 else: 1763 transpose = 3 if conjugate_a else 2 1764 return self._client.TriangularSolve(a, b, left_side, lower, unit_diagonal, 1765 transpose) 1766 1767 def Eigh(self, a, full_matrices=True): 1768 """Enqueues a symmetric/Hermitian eigendecomposition.""" 1769 return self._client.Eigh(a, full_matrices) 1770 1771 def SVD(self, a): 1772 """Enqueues a singular value decomposition.""" 1773 return self._client.SVD(a) 1774 1775 def Gather(self, a, start_indices, dimension_numbers, slice_sizes): 1776 """Enqueues a Gather operation onto the computation.""" 1777 return self._client.Gather(a, start_indices, dimension_numbers, slice_sizes) 1778 1779 def Scatter(self, a, scatter_indices, updates, update_computation, 1780 dimension_numbers): 1781 """Enqueues a Scatter operation onto the computation.""" 1782 return self._client.Scatter( 1783 a, scatter_indices, updates, update_computation.computation, 1784 dimension_numbers) 1785 1786 1787def _forward_methods_to_local_builder(): 1788 """Forward remaining ComputationBuilder methods to the C API. 1789 1790 Set up methods, corresponding to unary and binary XLA operations, 1791 whose calls are forwarded in a boilerplate manner to the underlying 1792 ComputationBuilder C-extension API. 1793 """ 1794 1795 def forward_to_local_builder_with_handles(target_method, is_binop=False): 1796 """Generate a forwarding method that wraps/unwraps data handles.""" 1797 1798 def forward(self, *args, **kwargs): 1799 arg_list = list(args) 1800 1801 if is_binop and len(arg_list) < 3: 1802 arg_list.append(kwargs.get('broadcast_dimensions', ())) 1803 1804 return target_method( 1805 self._client, # pylint: disable=protected-access 1806 *arg_list) 1807 1808 return forward 1809 1810 for method_name in _UNARY_OPS: 1811 forward = forward_to_local_builder_with_handles( 1812 getattr(c_api.ComputationBuilder, method_name)) 1813 forward.__name__ = method_name 1814 setattr(ComputationBuilder, method_name, forward) 1815 1816 for method_name in _BINARY_OPS: 1817 forward = forward_to_local_builder_with_handles( 1818 getattr(c_api.ComputationBuilder, method_name), is_binop=True) 1819 forward.__name__ = method_name 1820 setattr(ComputationBuilder, method_name, forward) 1821 1822 1823_forward_methods_to_local_builder() 1824 1825_default_replica_count = 1 1826 1827 1828def initialize_replica_count(replica_count): 1829 """Initializes the default replica count to use. 1830 1831 Deprecated; pass `num_replicas` as an option to `Computation.Compile()` 1832 instead. 1833 1834 Args: 1835 replica_count: number of replicas that are desired for set up during XLA 1836 initialization. 1837 1838 Raises: 1839 A runtime exception if the XLA service has already been initialized. 1840 """ 1841 global _default_replica_count 1842 _default_replica_count = replica_count 1843 1844 1845def get_replica_count(): 1846 """Returns the default replica count. 1847 1848 Deprecated; pass `num_replicas` as an option to `Computation.Compile()` 1849 instead. 1850 """ 1851 return _default_replica_count 1852 1853 1854def initialize_platform_name(platform_name): 1855 """Initializes the default platform name to use for XLA. 1856 1857 Args: 1858 platform_name: string name of platform. 1859 """ 1860 global _default_platform_name 1861 _default_platform_name = platform_name 1862 1863 # Make sure the platform is valid by trying to instantiate it. 1864 _get_default_local_backend() 1865 1866 1867def register_cpu_custom_call_target(name, fn): 1868 """Registers a CPU custom call target. 1869 1870 Args: 1871 name: bytes containing the name of the function. 1872 fn: a PyCapsule object containing the function pointer. 1873 """ 1874 c_api.RegisterCpuCustomCallTarget(name, fn) 1875 1876 1877class PaddingConfigDimension(object): 1878 """Python representation of a xla.PaddingConfigDimension protobuf.""" 1879 __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding') 1880 1881 def __init__(self): 1882 self.edge_padding_low = [] 1883 self.edge_padding_high = [] 1884 self.interior_padding = [] 1885 1886 1887class PaddingConfig(object): 1888 """Python representation of a xla.PaddingConfig protobuf.""" 1889 __slots__ = ('dimensions',) 1890 1891 def __init__(self): 1892 self.dimensions = [] 1893 1894 1895def GetPaddingConfigFromTriples(triples): 1896 """Create PaddingConfig proto from list of triples of integers.""" 1897 padding_config = PaddingConfig() 1898 for lo, hi, interior in triples: 1899 dimension = PaddingConfigDimension() 1900 dimension.edge_padding_low = lo 1901 dimension.edge_padding_high = hi 1902 dimension.interior_padding = interior 1903 padding_config.dimensions.append(dimension) 1904 return padding_config 1905 1906 1907class DotDimensionNumbers(object): 1908 """Python representation of a xla.DotDimensionNumbers protobuf.""" 1909 __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions', 1910 'lhs_batch_dimensions', 'rhs_batch_dimensions') 1911 1912 def __init__(self): 1913 self.lhs_contracting_dimensions = [] 1914 self.rhs_contracting_dimensions = [] 1915 self.lhs_batch_dimensions = [] 1916 self.rhs_batch_dimensions = [] 1917 1918 1919def GetDotDimensionsFromLists(dimension_numbers): 1920 (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers 1921 dot_dims_proto = DotDimensionNumbers() 1922 dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) 1923 dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) 1924 dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) 1925 dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) 1926 return dot_dims_proto 1927 1928 1929class ConvolutionDimensionNumbers(object): 1930 """Python representation of a xla.ConvolutionDimensionNumbers protobuf.""" 1931 __slots__ = ('input_batch_dimension', 'input_feature_dimension', 1932 'input_spatial_dimensions', 'kernel_input_feature_dimension', 1933 'kernel_output_feature_dimension', 'kernel_spatial_dimensions', 1934 'output_batch_dimension', 'output_feature_dimension', 1935 'output_spatial_dimensions') 1936 1937 def __init__(self): 1938 self.input_batch_dimension = 0 1939 self.input_feature_dimension = 0 1940 self.input_spatial_dimensions = [] 1941 self.kernel_input_feature_dimension = 0 1942 self.kernel_output_feature_dimension = 0 1943 self.kernel_spatial_dimensions = [] 1944 self.output_batch_dimension = 0 1945 self.output_feature_dimension = 0 1946 self.output_spatial_dimensions = [] 1947 1948 1949class GatherDimensionNumbers(object): 1950 """Python representation of a xla.GatherDimensionNumbers protobuf.""" 1951 __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map', 1952 'index_vector_dim') 1953 1954 def __init__(self): 1955 self.offset_dims = [] 1956 self.collapsed_slice_dims = [] 1957 self.start_index_map = [] 1958 self.index_vector_dim = 0 1959 1960 1961class ScatterDimensionNumbers(object): 1962 """Python representation of a xla.ScatterDimensionNumbers protobuf.""" 1963 __slots__ = ('update_window_dims', 'inserted_window_dims', 1964 'scatter_dims_to_operand_dims', 'index_vector_dim') 1965 1966 def __init__(self): 1967 self.update_window_dims = [] 1968 self.inserted_window_dims = [] 1969 self.scatter_dims_to_operand_dims = [] 1970 self.index_vector_dim = 0 1971 1972 1973class ReplicaGroup(object): 1974 """Python representation of a xla.ReplicaGroup protobuf.""" 1975 __slots__ = ('replica_ids',) 1976 1977 def __init__(self): 1978 self.replica_ids = [] 1979 1980 1981def _make_replica_group_proto(replica_group): 1982 replica_group_proto = ReplicaGroup() 1983 replica_group_proto.replica_ids.extend(replica_group) 1984 return replica_group_proto 1985