1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""An XLA client in Python, supporting AOT compilation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import collections
23import enum  # pylint: disable=g-bad-import-order
24import inspect
25import itertools
26import os
27
28import numpy as np
29
30import six
31from six.moves import xrange
32
33# Note this module does *not* depend on any Python protocol buffers. The XLA
34# Python bindings are currently packaged both as part of jaxlib and as part
35# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
36# and TensorFlow may fail with duplicate protocol buffer message definitions.
37
38from tensorflow.compiler.xla.python import pywrap_xla as c_api
39
40# Import the XRT backend, if available.
41try:
42  # pylint: disable=g-import-not-at-top
43  from tensorflow.compiler.xla.python import pywrap_xrt as xrt_api
44except ImportError:
45  xrt_api = None
46
47
48# Most functions are snake_case for consistency with other modules, whereas
49# method names of ComputationBuilder and Computation are CamelCase for
50# consistency with XLA.
51# pylint: disable=invalid-name
52
53
54# Version of the XLA Python client.
55#
56# JAX packages the XLA python plugin as a binary pip module (jaxlib) that is
57# packaged separately from the Python code that consumes it (jax).
58#
59# We occasionally need to make backwards-incompatible changes to jaxlib, in
60# which case we need to be able to detect when incompatible versions are
61# installed.
62def version():
63  return (0, 1, 8)
64
65
66_OP_METADATA_FIELDS = [
67    'op_type',
68    'op_name',
69    'source_file',
70    'source_line',
71]
72OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS)
73
74
75@six.add_metaclass(abc.ABCMeta)
76class Backend(object):
77  """Abstract base class for XLA backends."""
78
79  @abc.abstractmethod
80  def device_count(self):
81    """Returns the number of devices known to the backend."""
82
83  @abc.abstractmethod
84  def buffer_from_pyval(self, pyval, device=0):
85    """Allocates a fresh buffer and populates it with `pyval`."""
86
87  @abc.abstractmethod
88  def delete_buffer(self, c_buffer):
89    """Deletes buffer `c_buffer`."""
90
91  @abc.abstractmethod
92  def destructure_tuple(self, c_buffer):
93    """Destructures a tuple buffer into a sequence of buffers."""
94
95  @abc.abstractmethod
96  def compile(self, computation, argument_shapes, result_shape,
97              compile_options):
98    """Compiles a computation. Returns an executable."""
99
100  @abc.abstractmethod
101  def delete_executable(self, executable):
102    """Deletes an executable."""
103
104  @abc.abstractmethod
105  def execute(self, executable, args):
106    """Runs an executable without replication."""
107
108  @abc.abstractmethod
109  def execute_replicated(self, executable, per_replica_args):
110    """Runs an executable in a replicated manner."""
111
112
113def _maybe_encode_string(s):
114  if six.PY3:
115    return s.encode('utf-8')
116  else:
117    return s
118
119
120class XlaLocalBackend(Backend):
121  """XLA backend implemented using the in-process xla::LocalClient API."""
122
123  def __init__(self, platform=None):
124    platform = platform or _get_default_platform_name()
125    self.client = c_api.LocalClient.Get(_maybe_encode_string(platform))
126    self._delete_buffer = c_api.DeleteLocalShapedBuffer
127    self._delete_executable = c_api.DeleteLocalExecutable
128
129  def device_count(self):
130    return self.client.DeviceCount()
131
132  def buffer_from_pyval(self, pyval, device=0):
133    return c_api.LocalShapedBuffer.FromLiteral(pyval, None, self.client, device)
134
135  def delete_buffer(self, c_buffer):
136    self._delete_buffer(c_buffer)
137
138  def destructure_tuple(self, c_buffer):
139    result = c_buffer.DestructureTuple()
140    return [result.Release(i) for i in xrange(result.size())]
141
142  def compile(self, c_computation, argument_shapes, result_shape,
143              compile_options):
144    return c_computation.Compile(argument_shapes, compile_options, self.client)
145
146  def delete_executable(self, executable):
147    self._delete_executable(executable)
148
149  def execute(self, executable, args):
150    return executable.Execute(args)
151
152  def execute_replicated(self, executable, per_replica_args):
153    output_buffer_tup = executable.ExecutePerReplica(per_replica_args)
154    size = output_buffer_tup.size()
155    return [output_buffer_tup.Release(i) for i in xrange(size)]
156
157
158class XrtBackend(Backend):
159  """XLA backend implemented using XRT."""
160
161  def __init__(self, target):
162    self.target = target
163    self._delete_buffer = xrt_api.DeleteXrtAllocation
164    self._delete_executable = xrt_api.DeleteXrtExecutable
165
166  def device_count(self):
167    return 1  # Multidevice execution not implemented.
168
169  def buffer_from_pyval(self, pyval, device=0):
170    if device != 0:
171      raise NotImplementedError(
172          'Multi-replica execution is not yet supported via the XRT backend.')
173    return xrt_api.XrtAllocation.FromLiteral(pyval,
174                                             _maybe_encode_string(self.target))
175
176  def delete_buffer(self, c_buffer):
177    self._delete_buffer(c_buffer)
178
179  def destructure_tuple(self, c_buffer):
180    result = xrt_api.DestructureXrtAllocationTuple(
181        c_buffer, _maybe_encode_string(self.target))
182    return [result.Release(i) for i in xrange(result.size())]
183
184  def compile(self, c_computation, argument_shapes, result_shape,
185              compile_options):
186    return xrt_api.XrtExecutable.CompileForXrt(
187        c_computation.GetSerializedProto(), argument_shapes, result_shape,
188        _maybe_encode_string(self.target))
189
190  def delete_executable(self, executable):
191    self._delete_executable(executable)
192
193  def execute(self, executable, args):
194    return executable.Execute(args)
195
196  def execute_replicated(self, executable, per_replica_args):
197    if len(per_replica_args) != 1:
198      raise NotImplementedError(
199          'Multi-replica execution is not yet supported via the XRT backend.')
200    return [executable.Execute(per_replica_args[0])]
201
202
203_default_platform_name = 'Host'
204_default_backend = None
205
206
207def _get_default_platform_name():
208  return _default_platform_name
209
210
211def _get_default_local_backend():
212  global _default_backend
213  global _default_platform_name
214  if _default_backend is None:
215    _default_backend = XlaLocalBackend(_default_platform_name)
216  return _default_backend
217
218
219class BackendType(enum.Enum):
220  XLA_LOCAL = 1
221  XRT = 2
222
223
224def BackendSpec(backend, target):
225  """Compatibility wrapper to support older clients. Do not use in new code."""
226  if backend == BackendType.XLA_LOCAL:
227    return _get_default_local_backend()
228  elif backend == BackendType.XRT:
229    return XrtBackend(target)
230  else:
231    raise ValueError('Unknown backend {}'.format(backend))
232
233
234def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
235  """Helper for use in source mapping that returns an OpMetadata object."""
236  full_filename, lineno = inspect.stack()[skip_frames][1:3]
237  filename = os.path.basename(full_filename)
238  return OpMetadata(
239      op_type=op_type,
240      op_name=op_name,
241      source_file=filename,
242      source_line=lineno)
243
244
245class PaddingType(enum.Enum):
246  VALID = 1
247  SAME = 2
248
249
250def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
251                                        window_strides):
252  """Maps PaddingType or string to pad values (list of pairs of ints)."""
253  if not isinstance(padding_type, (str, PaddingType)):
254    msg = 'padding_type must be str or PaddingType, got {}.'
255    raise TypeError(msg.format(type(padding_type)))
256
257  if isinstance(padding_type, str):
258    if padding_type.upper() == 'VALID':
259      padding_type = PaddingType.VALID
260    elif padding_type.upper() == 'SAME':
261      padding_type = PaddingType.SAME
262    else:
263      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
264      raise ValueError(msg.format(padding_type))
265
266  if padding_type == PaddingType.VALID:
267    return [(0, 0)] * len(window_strides)
268  elif padding_type == PaddingType.SAME:
269    out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
270    pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0)
271                 for out_size, stride, filter_size, in_size
272                 in zip(out_shape, window_strides, rhs_dims, lhs_dims)]
273    return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
274  else:
275    msg = 'Unexpected PaddingType value: {}'
276    raise ValueError(msg.format(padding_type))
277
278
279_UNARY_OPS = [
280    'Not',
281    'Clz',
282    'Abs',
283    'Exp',
284    'Expm1',
285    'Floor',
286    'Round',
287    'Ceil',
288    'Log',
289    'Log1p',
290    'Sign',
291    'Cos',
292    'Sin',
293    'Tanh',
294    'IsFinite',
295    'Sqrt',
296    'Rsqrt',
297    'Square',
298    'Reciprocal',
299    'Neg',
300    'Erf',
301    'Erfc',
302    'ErfInv',
303    'Lgamma',
304    'Digamma',
305    'Acos',
306    'Asin',
307    'Atan',
308    'Tan',
309    'Acosh',
310    'Asinh',
311    'Atanh',
312    'Cosh',
313    'Sinh',
314    'Real',
315    'Imag',
316    'Conj',
317]
318
319_BINARY_OPS = [
320    'Eq',
321    'Ne',
322    'Ge',
323    'Gt',
324    'Lt',
325    'Le',
326    'Add',
327    'Sub',
328    'Mul',
329    'Div',
330    'Rem',
331    'Max',
332    'Min',
333    'And',
334    'Or',
335    'Xor',
336    'Pow',
337    'ShiftLeft',
338    'ShiftRightArithmetic',
339    'ShiftRightLogical',
340    'Atan2',
341    'Complex',
342]
343
344
345class PrimitiveType(enum.IntEnum):
346  """Python copy of the XLA PrimitiveType enum.
347
348  Must match the corresponding protocol buffer.
349  """
350  PRIMITIVE_TYPE_INVALID = 0
351  PRED = 1
352  S8 = 2
353  S16 = 3
354  S32 = 4
355  S64 = 5
356  U8 = 6
357  U16 = 7
358  U32 = 8
359  U64 = 9
360  BF16 = 16
361  F16 = 10
362  F32 = 11
363  F64 = 12
364  C64 = 15
365  C128 = 18
366  TUPLE = 13
367  OPAQUE = 14
368  TOKEN = 17
369
370
371XLA_ELEMENT_TYPE_TO_DTYPE = {
372    PrimitiveType.PRED: np.dtype('bool'),
373    PrimitiveType.S8: np.dtype('int8'),
374    PrimitiveType.S16: np.dtype('int16'),
375    PrimitiveType.S32: np.dtype('int32'),
376    PrimitiveType.S64: np.dtype('int64'),
377    PrimitiveType.U8: np.dtype('uint8'),
378    PrimitiveType.U16: np.dtype('uint16'),
379    PrimitiveType.U32: np.dtype('uint32'),
380    PrimitiveType.U64: np.dtype('uint64'),
381    PrimitiveType.F16: np.dtype('float16'),
382    PrimitiveType.F32: np.dtype('float32'),
383    PrimitiveType.F64: np.dtype('float64'),
384    PrimitiveType.C64: np.dtype('complex64'),
385    PrimitiveType.C128: np.dtype('complex128'),
386    PrimitiveType.TUPLE: np.dtype(np.object),
387}
388
389# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
390# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
391# when keying by dtype in this dict, we use the string form of dtypes.
392DTYPE_TO_XLA_ELEMENT_TYPE = {
393    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
394}
395
396
397def dtype_to_etype(dtype):
398  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE."""
399  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
400
401
402class LocalBuffer(object):
403  """Represents a handle to data owned by XLA.
404
405  The referent is ready for use in executing a local, compiled
406  Computation. On XLA platforms involving a device (e.g. GPU), this
407  means the referent is in device memory.
408  """
409
410  def __init__(self, c_buffer, backend, device):
411    self.c_buffer = c_buffer
412    self._backend = backend
413    self._device = device
414
415  @staticmethod
416  def from_pyval(pyval, device=0, backend=None):
417    """Allocate and copy to XLA the given python value."""
418    backend = backend or _get_default_local_backend()
419    pyval = require_numpy_array_layout(pyval)
420    cbuf = backend.buffer_from_pyval(pyval, device)
421    return LocalBuffer(cbuf, backend, device)
422
423  def to_py(self):
424    return self.c_buffer.ToLiteral()
425
426  def shape(self):
427    return _wrap_shape(self.c_buffer.shape())
428
429  def device(self):
430    return self._device
431
432  def delete(self):
433    if self.c_buffer is not None:
434      # Python may have freed c_api first.
435      if c_api:
436        self._backend.delete_buffer(self.c_buffer)
437      self.c_buffer = None
438
439  def destructure(self):
440    """Assuming a tuple buffer, unpack it into constituent tuple elements."""
441    assert self.c_buffer is not None
442    result = self._backend.destructure_tuple(self.c_buffer)
443    self.delete()
444    return tuple(
445        LocalBuffer(sub_buffer, device=self._device, backend=self._backend)
446        for sub_buffer in result)
447
448  def is_deleted(self):
449    return self.c_buffer is None
450
451  def __del__(self):
452    self.delete()
453
454
455class Format(enum.IntEnum):
456  """Python copy of the Format protocol buffer enum."""
457  INVALID_FORMAT = 0
458  DENSE = 1
459  SPARSE = 2
460
461
462class Shape(object):
463  """Represents an XLA shape.
464
465  A shape is either an array shape, having rank-many integer
466  dimensions and an element type (represented by a Numpy dtype), or it
467  is a tuple shape, having a shape for every tuple component:
468
469    type shape =
470        TupleShape of shape list
471      | ArrayShape of { dimensions: int list; element_type: dtype }
472
473  Callers are expected to instantiate this class only via the static
474  constructors: tuple_shape, array_shape, and from_pyval.
475  """
476
477  @staticmethod
478  def tuple_shape(tuple_shapes):
479    """Construct a tuple shape."""
480    if (not isinstance(tuple_shapes, (tuple, list)) or
481        not all(isinstance(t, Shape) for t in tuple_shapes)):
482      raise TypeError('tuple_shapes must be a tuple of Shapes')
483    return Shape(tuple_shapes, tuple)
484
485  @staticmethod
486  def array_shape(element_type, dimensions, minor_to_major=None):
487    """Construct an array shape."""
488    if (not isinstance(dimensions, tuple) or
489        not all(isinstance(i, int) for i in dimensions)):
490      dimensions = tuple(int(i) for i in dimensions)
491    return Shape(
492        dimensions, np.dtype(element_type), minor_to_major=minor_to_major)
493
494  @staticmethod
495  def from_pyval(pyval):
496    def convert(pyval):
497      if isinstance(pyval, tuple):
498        return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
499      else:
500        pyval = require_numpy_array_layout(pyval)
501        return Shape.array_shape(pyval.dtype, np.shape(pyval))
502    return convert(pyval)
503
504  def __init__(self, dimensions, dtype, minor_to_major=None):
505    assert isinstance(dimensions, tuple)
506    self._dimensions = dimensions
507    self._dtype = dtype
508    self._is_tuple = dtype == tuple
509    self._minor_to_major = minor_to_major
510    self._check_minor_to_major()
511
512  def __eq__(self, other):
513    # pylint: disable=protected-access
514    return (self._dtype == other._dtype and
515            self._dimensions == other._dimensions and
516            self._minor_to_major == other._minor_to_major)
517
518  def __ne__(self, other):
519    return not self == other
520
521  def __hash__(self):
522    return hash((self._dtype, self._dimensions, self._minor_to_major))
523
524  def __repr__(self):
525    return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, '
526            '_is_tuple={!r}, _minor_to_major={!r})').format(
527                self._dtype, self._dimensions, self._is_tuple,
528                self._minor_to_major)
529
530  def is_tuple(self):
531    return self._is_tuple
532
533  def is_array(self):
534    return not self._is_tuple
535
536  def tuple_shapes(self):
537    if not self.is_tuple():
538      raise ValueError('not a tuple shape')
539    return self._dimensions
540
541  def numpy_dtype(self):
542    """Like element_type(), but returns dtype('O') in case of a tuple shape."""
543    if self.is_tuple():
544      return np.dtype(np.object)
545    else:
546      return self.element_type()
547
548  def xla_element_type(self):
549    return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.numpy_dtype())]
550
551  def element_type(self):
552    if not self.is_array():
553      raise ValueError('not an array shape')
554    return self._dtype
555
556  def dimensions(self):
557    if not self.is_array():
558      raise ValueError('not an array shape')
559    return self._dimensions
560
561  def rank(self):
562    return len(self.dimensions())
563
564  def minor_to_major(self):
565    return self._minor_to_major
566
567  def map_leaves(self, f):
568    """Map f over each leaf-level array subshape.
569
570    Args:
571      f: The function to apply. Whenever f returns None, the identity is applied
572        instead.
573
574    Returns:
575      A new Shape with the mapped leaves.
576    """
577    if self.is_tuple():
578      children = tuple(child.map_leaves(f) for child in self.tuple_shapes())
579      return Shape.tuple_shape(children)
580    else:
581      mapped = f(self)
582      return self if mapped is None else mapped
583
584  def _check_minor_to_major(self):
585    mtm = self._minor_to_major
586    if self.is_tuple():
587      assert mtm is None, self
588    if mtm is not None:
589      assert self.rank() == len(mtm), self
590      assert sorted(mtm) == list(range(len(mtm))), self
591
592  def update_minor_to_major(self, minor_to_major):
593    if not self.is_array():
594      raise ValueError('not an array shape')
595    if not isinstance(minor_to_major, tuple):
596      raise TypeError('minor_to_major must be a tuple')
597    updated = Shape.array_shape(self.element_type(), self.dimensions(),
598                                minor_to_major)
599    updated._check_minor_to_major()  # pylint: disable=protected-access
600    return updated
601
602  def with_major_to_minor_layout_if_absent(self):
603    """Returns a copy of a shape with missing layouts set to major-to-minor."""
604
605    def f(a):
606      if a.minor_to_major():
607        return None
608      return a.update_minor_to_major(tuple(xrange(a.rank() - 1, -1, -1)))
609
610    return self.map_leaves(f)
611
612  def serialize(self, proto):
613    """Serializes 'shape' into proto."""
614    if self.is_tuple():
615      proto.element_type = PrimitiveType.TUPLE
616      for shape in self.tuple_shapes():
617        shape.serialize(proto.tuple_shapes.add())
618    else:
619      proto.element_type = dtype_to_etype(self.element_type())
620      proto.dimensions.extend(self.dimensions())
621      proto.is_dynamic_dimension.extend([False for _ in self.dimensions()])
622      if self.minor_to_major():
623        proto.layout.format = Format.DENSE
624        proto.layout.minor_to_major.extend(self.minor_to_major())
625
626
627ProgramShape = collections.namedtuple('ProgramShape',
628                                      ('parameter_shapes', 'result_shape'))
629
630
631def _wrap_shape(shape_info):
632  dtype, dims = shape_info
633  element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
634  if element_type == PrimitiveType.TUPLE:
635    shapes = tuple(_wrap_shape(subshape_info) for subshape_info in dims)
636    return Shape.tuple_shape(shapes)
637  else:
638    return Shape.array_shape(dtype, dims)
639
640
641def _wrap_program_shape(shape_info):
642  arg_shapes, result_shape = shape_info
643  return ProgramShape([_wrap_shape(arg) for arg in arg_shapes],
644                      _wrap_shape(result_shape))
645
646
647def require_numpy_array_layout(value):
648  if isinstance(value, tuple):
649    return tuple(require_numpy_array_layout(x) for x in value)
650  else:
651    return np.require(value, requirements=['C', 'A'])
652
653
654class CompileOptions(object):
655  """Python object for XLA compile options.
656
657  These options can be passed to the 'compile' step when using a local XLA
658  client.
659  """
660
661  def __init__(self):
662    self.xla_dump_to = None
663    self.dump_hlo_pass_re = None
664    self.dump_hlo_module_re = None
665    self.dump_hlo_as_text = None
666    self.dump_hlo_as_proto = None
667    self.hlo_profile = None
668    self.num_replicas = get_replica_count()
669
670
671def transfer_to_infeed(value, device_ordinal=0):
672  """Transfers the given value into the XLA infeed queue.
673
674  XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with
675  a totally ordered stream of values. This is dequeued from XLA computations via
676  the Infeed() operation.
677
678  Args:
679    value: the value that the caller would like to enqueue into the XLA infeed
680      queue
681    device_ordinal: the device to infeed the value to. Each device has a
682      distinct infeed queue.
683  """
684  # TODO(phawkins): support non-default backends.
685  backend = _get_default_local_backend()
686  backend.client.TransferToInfeed(
687      require_numpy_array_layout(value), device_ordinal)
688
689
690def transfer_from_outfeed(shape, device_ordinal=0):
691  """Transfers a literal of the given shape from `device_ordinal`'s outfeed.
692
693  Args:
694    shape: The shape of the value to transfer from outfeed.
695    device_ordinal: The device ordinal to transfer the outfeed value from. Each
696      device has a distinct outfeed queue..
697
698  Returns:
699    The literal value that is produced from the outfeed queue.
700  """
701  # TODO(phawkins): support non-default backends.
702  backend = _get_default_local_backend()
703  return backend.client.TransferFromOutfeed(shape, device_ordinal)
704
705
706class Computation(object):
707  """Python wrapper for an XLA Computation.
708
709  A Computation can be compiled to form an Executable, or used as a
710  subcomputation in ComputationBuilder methods.
711  """
712
713  def __init__(self, c_computation, backend=None):
714    self._c_computation = c_computation
715    # The backend argument is deprecated. Pass a backend to Compile() instead.
716    self._backend = backend
717    self._delete_computation = c_api.DeleteComputation
718
719  @property
720  def computation(self):
721    return self._c_computation
722
723  def GetSerializedProto(self):
724    """Gets the serialized HloModuleProto proto object in this computation.
725
726    Returns:
727       A string containing a serialized HloModuleProto proto containing the
728       computation and its dependencies.
729    """
730    return self.computation.GetSerializedProto()
731
732  def GetHloText(self):
733    """Get the textual HLO representation of this computation.
734
735    Returns:
736       A string containing the textual HLO.
737    """
738    return self.computation.GetHloText()
739
740  def GetHloDotGraph(self):
741    """Get a Graphviz Dot representation of this computation.
742
743    Returns:
744       A string containing the graphviz dot graph.
745    """
746    return self.computation.GetHloDotGraph()
747
748  def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None,
749              backend=None):
750    """Compiles a computation.
751
752    Computations are the result of a "ComputationBuild'ing" process.
753
754    Arguments:
755      argument_shapes: parameter shapes -- they are first laid out by layout_fn
756        if layout_fn is provided. Otherwise, the default layout for those shapes
757        will be used.
758      compile_options: options to use for compilation, includes an optional laid
759        out result shape for the computation.
760      layout_fn: lambda that is used to lay out the argument/result shapes.
761      backend: a `Backend` for which an executable should be generated.
762
763    Returns:
764      A Executable instance.
765    """
766    backend = backend or self._backend or _get_default_local_backend()
767    result_shape = _wrap_shape(self.computation.GetReturnValueShape())
768
769    if layout_fn:
770      argument_shapes = [
771          shape.map_leaves(layout_fn) for shape in argument_shapes
772      ]
773      result_shape = result_shape.map_leaves(layout_fn)
774
775    argument_shapes = list(argument_shapes)
776
777    compile_options = compile_options or CompileOptions()
778    compile_options.result_shape = result_shape
779    c = backend.compile(self.computation, argument_shapes, result_shape,
780                        compile_options)
781    return Executable(c, backend=backend)
782
783  def CompileWithExampleArguments(self,
784                                  arguments=(),
785                                  compile_options=None,
786                                  layout_fn=None,
787                                  backend=None):
788    return self.Compile(
789        argument_shapes=[Shape.from_pyval(arg) for arg in arguments],
790        compile_options=compile_options,
791        layout_fn=layout_fn,
792        backend=backend)
793
794  def GetProgramShape(self):
795    return _wrap_program_shape(self._c_computation.GetProgramShape())
796
797  def GetReturnValueShape(self):
798    return _wrap_shape(self._c_computation.GetReturnValueShape())
799
800  def __del__(self):
801    if self._c_computation:
802      self._delete_computation(self._c_computation)
803
804
805class Executable(object):
806  """Python wrapper for an XLA Executable."""
807
808  def __init__(self, c_executable, backend=None):
809    self._c_executable = c_executable
810    self._device_ordinals = c_executable.DeviceOrdinals()
811    self._backend = backend
812
813  def DeviceOrdinals(self):
814    """Returns a list containing the device ordinals for each replica."""
815    return self._device_ordinals
816
817  def Execute(self, arguments=(), check_for_deleted_args=True):
818    """Execute on one replica with LocalBuffer arguments and return value."""
819    if check_for_deleted_args and any(arg.is_deleted() for arg in arguments):
820      raise ValueError('Executing with deleted local buffer argument')
821    raw_args = [arg.c_buffer for arg in arguments]
822    output_buffer = self._backend.execute(self._c_executable, raw_args)
823    return LocalBuffer(
824        output_buffer, backend=self._backend, device=self._device_ordinals[0])
825
826  def ExecutePerReplica(self, arguments=None):
827    """Execute on many replicas with LocalBuffer arguments and return value.
828
829    Args:
830      arguments: A sequence of sequences of LocalBuffers. The i'th inner
831        sequence comprises the arguments for execution on the i'th replica.
832
833    Returns:
834      A list of the computation's outputs for each replica, as a LocalBuffer. If
835      a shallow sequence of arguments was passed in for `arguments`, then the
836      sole, zero'th replica's output is returned instead, as a LocalBuffer.
837    """
838    if arguments is None:
839      arguments = ((),) * len(self._device_ordinals)
840    else:
841      arguments = [list(replica_args) for replica_args in arguments]
842
843    # Check arguments
844    for replica, replica_args in enumerate(arguments):
845      for arg in replica_args:
846        if arg.is_deleted():
847          raise ValueError('Executing with deleted local buffer argument')
848        if arg.device() != self._device_ordinals[replica]:
849          raise ValueError(
850              'Executing on device {} with argument from device {}'.format(
851                  self._device_ordinals[replica], arg.device()))
852
853    # Pull out argument buffer handles
854    # pylint: disable=g-complex-comprehension
855    stripped_args = [
856        [arg.c_buffer for arg in replica_args] for replica_args in arguments
857    ]
858
859    # Execute
860    output_buffers = self._backend.execute_replicated(self._c_executable,
861                                                      stripped_args)
862
863    # Wrap output handles in LocalBuffer instances
864    return tuple(
865        LocalBuffer(
866            output_buffer,
867            backend=self._backend,
868            device=self._device_ordinals[replica])
869        for replica, output_buffer in enumerate(output_buffers))
870
871  def ExecuteWithPythonValues(self, arguments=()):
872    """Execute on one replica with Python values as arguments and output."""
873
874    def put(arg):
875      return LocalBuffer.from_pyval(
876          arg, device=self._device_ordinals[0], backend=self._backend)
877
878    arguments = [put(arg) for arg in arguments]
879    return self.Execute(arguments).to_py()
880
881  def ExecuteWithPythonValuesPerReplica(self, arguments):
882    """Execute on many replicas with Python values as arguments and output."""
883
884    def put(arg, device):
885      return LocalBuffer.from_pyval(arg, device, backend=self._backend)
886
887    # pylint: disable=g-complex-comprehension
888    arguments = [[
889        put(arg, self._device_ordinals[replica]) for arg in replica_args
890    ] for replica, replica_args in enumerate(arguments)]
891    return [out.to_py() for out in self.ExecutePerReplica(arguments)]
892
893  def __del__(self):
894    # Python may have freed c_api first.
895    if c_api and self._c_executable:
896      self._backend.delete_executable(self._c_executable)
897
898
899class ComputationBuilder(object):
900  """XLA computation builder.
901
902  Enqueues XLA ops in sequence and in order to build a
903  Computation, which in turn can be compiled into a
904  LocalExecutable, which in turn can be locally executed.
905  """
906
907  # The methods of this class map 1-to-1 onto the XLA C++
908  # computation builder API. Therefore, there's no need to laboriously list
909  # arguments and return values for every method, especially where it's obvious.
910  #
911  # pylint: disable=g-doc-return-or-yield
912  # pylint: disable=g-doc-args
913
914  def __init__(self, name):
915    self._client = c_api.ComputationBuilder(name.encode('utf8'))
916    self._parameter_numbering = itertools.count()
917
918  def Build(self, root=None, backend=None):
919    """Builds a `Computation` from the contents of the builder.
920
921    Args:
922      root: if not None, the operator containing the return value of the
923        computation.
924      backend: deprecated. Pass a `backend` to `Computation.Compile` instead.
925
926    Returns:
927      A `Computation`.
928    """
929    if root is not None:
930      return Computation(self._client.BuildWithRoot(root), backend=backend)
931    else:
932      return Computation(self._client.Build(), backend=backend)
933
934  def SetOpMetadata(self, op_metadata):
935    """Set metadata for operations that are about to be enqueued."""
936    self._client.SetOpMetadata(op_metadata)
937
938  def ClearOpMetadata(self):
939    """Clear metadata for operations that are about to be enqueued."""
940    self._client.ClearOpMetadata()
941
942  def Infeed(self, shape):
943    """Enqueues an infeed op onto the computation.
944
945    Infeed operations dequeue data of the given shape from the device's infeed
946    queue for subsequent use in the computation.
947
948    Returns:
949      A LocalOp.
950    """
951    return self._client.Infeed(shape)
952
953  def Outfeed(self, operand):
954    """Enqueues an outfeed op onto the computation.
955
956    Outfeed operations enqueue data, using the given operand, onto the XLA
957    outfeed queue for subsequent dequeue via the client API.
958    """
959    self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8'))
960
961  def Constant(self, value):
962    """Enqueues a constant op onto the computation.
963
964    Args:
965      value: value for the constant, as a np.array with an explicit dtype set to
966        one of the supported types.
967
968    Returns:
969      A LocalOp.
970    """
971    value = require_numpy_array_layout(value)
972    return self._client.ConstantLiteral(value)
973
974  def ConstantF32Scalar(self, value):
975    """Convenience method to enqueue a scalar F32 constant op.
976
977    Args:
978      value: a floating-point number.
979
980    Returns:
981      A LocalOp.
982    """
983    return self.Constant(np.array(value, dtype=np.float32))
984
985  def ConstantF64Scalar(self, value):
986    """Convenience method to enqueue a scalar F32 constant op.
987
988    Args:
989      value: a floating-point number.
990
991    Returns:
992      A LocalOp.
993    """
994    return self.Constant(np.array(value, dtype=np.float64))
995
996  def ConstantS32Scalar(self, value):
997    """Convenience method to enqueue a scalar S32 constant op.
998
999    Args:
1000      value: a floating-point number.
1001
1002    Returns:
1003      A LocalOp.
1004    """
1005    return self.Constant(np.array(value, dtype=np.int32))
1006
1007  def ConstantS64Scalar(self, value):
1008    """Convenience method to enqueue a scalar S64 constant op.
1009
1010    Args:
1011      value: a floating-point number.
1012
1013    Returns:
1014      A LocalOp.
1015    """
1016    return self.Constant(np.array(value, dtype=np.int64))
1017
1018  def ConstantPredScalar(self, value):
1019    """Convenience method to enqueue a scalar PRED constant op.
1020
1021    Args:
1022      value: a boolean value.
1023
1024    Returns:
1025      A LocalOp.
1026    """
1027    return self.Constant(np.array(value, dtype=np.bool))
1028
1029  def ParameterWithShape(self, shape, name=None, parameter_num=None):
1030    """Enqueues a Parameter op onto the computation, given a shape.
1031
1032    Args:
1033      shape: the parameter's shape as a Shape object.
1034      name: optional string name for the parameter.
1035      parameter_num: parameter number in the computation function. If None, the
1036        next linear parameter number is used. The default value capability can
1037        be used for auto-numbering. If you're using auto-numbering for some
1038        parameters, use it for *all* parameters to avoid clashes.
1039
1040    Returns:
1041      A LocalOp.
1042    """
1043    if name is None:
1044      name = ''
1045    if parameter_num is None:
1046      parameter_num = next(self._parameter_numbering)
1047
1048    return self._client.Parameter(parameter_num, shape, name.encode('utf8'))
1049
1050  def ParameterFromNumpy(self, value, name=None, parameter_num=None):
1051    """Enqueues a Parameter op onto the computation.
1052
1053    Args:
1054      value: a Numpy array, or a nested tuple thereof, from which the shape is
1055        inferred.
1056      name: as in ParameterWithShape.
1057      parameter_num: as in ParameterWithShape.
1058
1059    Returns:
1060      A LocalOp.
1061    """
1062    return self.ParameterWithShape(
1063        Shape.from_pyval(value), name=name, parameter_num=parameter_num)
1064
1065  def Iota(self, dtype, size):
1066    """Enqueues an iota constant onto the computation.
1067
1068    Args:
1069      dtype: expected numpy dtype of the output.
1070      size: integer, the number of elements in the array.
1071
1072    Returns:
1073      A LocalOp representing the added iota constant.
1074    """
1075    element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
1076    return self._client.Iota(element_type, size)
1077
1078  def BroadcastedIota(self, dtype, shape, dimension):
1079    """Enqueues a broadcasted iota constant onto the computation.
1080
1081    Args:
1082      dtype: expected numpy dtype of the output.
1083      shape: tuple of integers, the expected output shape (dimensions).
1084      dimension: positive integer, dimension along which to increment values.
1085
1086    Returns:
1087      A LocalOp representing the added broadcasted iota constant.
1088    """
1089    xla_shape = Shape.array_shape(dtype, shape)
1090    return self._client.BroadcastedIota(xla_shape, dimension)
1091
1092  def Broadcast(self, operand, sizes):
1093    """Enqueues a broadcast operation onto the computation.
1094
1095    Args:
1096      operand: the operand LocalOp to broadcast.
1097      sizes: an iterable of broadcast sizes.
1098
1099    Returns:
1100      A LocalOp representing the added broadcast op.
1101    """
1102    return self._client.Broadcast(operand, sizes)
1103
1104  def BroadcastInDim(self, operand, shape, broadcast_dimensions):
1105    """Enqueues a broadcast-in-dimensions operation onto the computation.
1106
1107    Args:
1108      operand: the operand LocalOp to broadcast.
1109      shape: tuple of integers, the expected output shape.
1110      broadcast_dimensions: tuple of integers identifying which dimensions of
1111        the output are to be broadcast into.
1112
1113    Returns:
1114      A LocalOp representing the added broadcast-in-dimensions op.
1115    """
1116    return self._client.BroadcastInDim(operand, shape, broadcast_dimensions)
1117
1118  def Concatenate(self, operands, dimension):
1119    """Enqueues a concatenate operation onto the computation.
1120
1121    Args:
1122      operands: the operands to concatenate.
1123      dimension: the dimension in which to perform the concatenation.
1124
1125    Returns:
1126      A LocalOp representing the added concatenate op.
1127    """
1128    return self._client.ConcatInDim(operands, dimension)
1129
1130  def ConvertElementType(self, operand, new_element_type):
1131    """Enqueues an element type conversion operation onto the computation.
1132
1133    Args:
1134      operand: the operand to convert.
1135      new_element_type: the target primitive type.
1136
1137    Returns:
1138      A LocalOp representing the added conversion op.
1139    """
1140    return self._client.ConvertElementType(operand, new_element_type)
1141
1142  def BitcastConvertType(self, operand, new_element_type):
1143    """Enqueues a bitcast type conversion operation onto the computation.
1144
1145    Args:
1146      operand: the operand to convert.
1147      new_element_type: the target primitive type.
1148
1149    Returns:
1150      A LocalOp representing the added conversion op.
1151    """
1152    return self._client.BitcastConvertType(operand, new_element_type)
1153
1154  def GetShape(self, operand):
1155    return _wrap_shape(self._client.GetShape(operand))
1156
1157  def GetReturnValueShape(self):
1158    return _wrap_shape(self._client.GetReturnValueShape())
1159
1160  def GetComputationStats(self):
1161    raise NotImplementedError()
1162
1163  def ReplicaId(self):
1164    """Enqueues a ReplicaId operation onto the computation.
1165
1166    Returns:
1167      A LocalOp representing the replica id.
1168    """
1169    return self._client.ReplicaId()
1170
1171  def Pad(self, operand, padding_value, padding_config):
1172    """Enqueues a Pad operation onto the computation.
1173
1174    Args:
1175      operand: LocalOp representing the array to pad.
1176      padding_value: LocalOp representing the scalar pad value.
1177      padding_config: either a PaddingConfig or a list of integer triples
1178        (edge_padding_low, edge_padding_high, interior_padding) representing the
1179        configuration of the padding operation.
1180
1181    Returns:
1182      A LocalOp representing the added Pad op.
1183    """
1184    if isinstance(padding_config, tuple) or isinstance(padding_config, list):
1185      padding_config = GetPaddingConfigFromTriples(padding_config)
1186    return self._client.Pad(operand, padding_value, padding_config)
1187
1188  def Reshape(self, operand, dimensions, new_sizes):
1189    """Enqueues a reshape op onto the computation.
1190
1191    Args:
1192      operand: LocalOp representing the array to be reshaped.
1193      dimensions: sequence of integers encoding the order in which dimensions
1194        are collapsed or None, in which case dimensions are flattened in order.
1195      new_sizes: sequence of integers encoding the new dimension sizes (shape).
1196
1197    Returns:
1198      A LocalOp representing the added Reshape op.
1199    """
1200    if dimensions is None:
1201      ndim = len(self.GetShape(operand).dimensions())
1202      dimensions = tuple(range(ndim))
1203    return self._client.Reshape(operand, dimensions, new_sizes)
1204
1205  def AllToAll(self,
1206               operand,
1207               split_dimension,
1208               concat_dimension,
1209               replica_groups=None):
1210    """AllToAll op.
1211
1212    Args:
1213      operand: LocalOp representing the input array
1214      split_dimension: the dimension along which the operand is split
1215      concat_dimension: the dimension along which the split blocks are
1216        concatenated
1217      replica_groups: optional, list of lists of ints encoding a partition of
1218        the set {0, 1, ..., num_replicas} into equally-sized replica groups
1219        within which the all-to-all is performed. If not supplied or None (the
1220        default), all replicas belong to the same group.
1221
1222    Returns:
1223      A LocalOp that represents the all-to-all concatenation.
1224    """
1225    if replica_groups is None:
1226      replica_groups_protos = []  # special value for XLA API
1227    else:
1228      replica_groups = list(replica_groups)
1229      replica_groups_protos = [
1230          _make_replica_group_proto(group) for group in replica_groups
1231      ]
1232    if not replica_groups:
1233      split_count = get_replica_count()
1234    else:
1235      split_count = len(replica_groups[0])
1236      if not all(split_count == len(g) for g in replica_groups):
1237        raise ValueError('Replica groups must be equally sized')
1238    return self._client.AllToAll(operand, split_dimension, concat_dimension,
1239                                 split_count, replica_groups_protos)
1240
1241  def CrossReplicaSum(self, operand, replica_groups=None):
1242    """CrossReplicaSum op.
1243
1244    Args:
1245      operand: the operand to sum across replica instances.
1246      replica_groups: optional, list of lists of ints encoding a partition of
1247        the set {0, 1, ..., num_replicas} into equally-sized replica groups
1248        within which the cross-replica sum is performed. If not supplied or None
1249        (the default), all replicas belong to the same group.
1250
1251    Returns:
1252      A LocalOp that represents on each replica the sum of its group's values.
1253    """
1254    if replica_groups is None:
1255      replica_groups = []  # special value for XLA API
1256    else:
1257      replica_groups = [
1258          _make_replica_group_proto(group) for group in replica_groups
1259      ]
1260    return self._client.CrossReplicaSum(operand, replica_groups)
1261
1262  def Collapse(self, operand, dimensions):
1263    """Collapse op."""
1264    return self._client.Collapse(operand, dimensions)
1265
1266  def Trans(self, operand):
1267    """Specialized matrix transpose op."""
1268    return self._client.Transpose(operand, [1, 0])
1269
1270  def Transpose(self, operand, permutation):
1271    """Transpose op."""
1272    return self._client.Transpose(operand, permutation)
1273
1274  def Rev(self, operand, dimensions):
1275    """Rev op."""
1276    return self._client.Rev(operand, dimensions)
1277
1278  def Clamp(self, min, operand, max):  # pylint: disable=redefined-builtin
1279    """Clamp op."""
1280    return self._client.Clamp(min, operand, max)
1281
1282  def SelectAndScatter(self, operand, select, window_dimensions, window_strides,
1283                       padding, source, init_value, scatter):
1284    """Select and scatter op, used by the gradient of ReduceWindow.
1285
1286    Args:
1287      operand: LocalOp for array of dimension N and type T over which the
1288        windows slide.
1289      select: Computation of type (T, T) -> Pred to apply to the elements of
1290        each window to indicate which element is selected.
1291      window_dimensions: sequence of N integers for dimensions of the window.
1292      window_strides: sequence of N integers for the strides of the window.
1293      padding: PaddingType representing either 'SAME' or 'VALID ' padding.
1294      source: LocalOp for array of type T with values to scatter.
1295      init_value: LocalOp of scalar type T for initial out value.
1296      scatter: Computation of type (T, T) -> T to apply to each scatter source
1297        element with its destination element.
1298
1299    Returns:
1300      A LocalOp representing the added SelectAndScatter op.
1301    """
1302    pads = _convert_padding_type_to_pad_values(
1303        padding, self.GetShape(operand).dimensions(), window_dimensions,
1304        window_strides)
1305    return self._client.SelectAndScatterWithGeneralPadding(
1306        operand, select.computation, window_dimensions, window_strides, pads,
1307        source, init_value, scatter.computation)
1308
1309  def Select(self, pred, on_true, on_false):
1310    """Element-wise selection op.
1311
1312    Constructs an output array from elements of two input arrays, based on the
1313    values of a predicate array.
1314    """
1315    return self._client.Select(pred, on_true, on_false)
1316
1317  def Slice(self, operand, start_indices, limit_indices, strides=None):
1318    """Enqueues a slice operation onto the computation.
1319
1320    Args:
1321      operand: LocalOp for the N dimensional array to be sliced.
1322      start_indices: iterable of N integers containing the starting indices of
1323        the slice for each dimension.
1324      limit_indices: iterable of N integers containing the ending indices
1325        (exclusive) of the slice for each dimension.
1326      strides: optional iterable of N integers containing the stride sizes for
1327        each dimension.
1328
1329    Returns:
1330      A LocalOp representing the added Slice op.
1331    """
1332    if strides is None:
1333      start_indices = list(start_indices)
1334      strides = [1] * len(start_indices)
1335    return self._client.Slice(operand, start_indices, limit_indices, strides)
1336
1337  def SliceInDim(self, operand, start_index, limit_index, stride, dimno):
1338    """Enqueues a slice-in-dimension operation onto the computation.
1339
1340    Args:
1341      operand: LocalOp for the N dimensional array to be sliced.
1342      start_index: an integer containing the start index of the slice.
1343      limit_index: an integer containing the end index of the slice.
1344      stride: an integer containing the stride size for the slice.
1345      dimno: an integer indicating the dimension along which to slice.
1346
1347    Returns:
1348      A LocalOp representing the added Slice op.
1349    """
1350    return self._client.SliceInDim(operand, start_index, limit_index, stride,
1351                                   dimno)
1352
1353  def DynamicSlice(self, operand, start_indices, slice_sizes):
1354    """Enqueues a slice op with dynamic start indices onto the computation.
1355
1356    Args:
1357      operand: LocalOp for the N dimensional array to be sliced.
1358      start_indices: LocalOp for the 1D array of N integers containing the
1359        starting indices of the slice.
1360      slice_sizes: iterable of N integers containing the slice sizes in each
1361        dimension.
1362
1363    Returns:
1364      A LocalOp representing the added DynamicSlice op.
1365    """
1366    return self._client.DynamicSlice(operand, start_indices, slice_sizes)
1367
1368  def DynamicUpdateSlice(self, operand, update, start_indices):
1369    """Enqueues a dynamic update slice operation onto the computation.
1370
1371    Args:
1372      operand: LocalOp for the N dimensional array to be updated.
1373      update: N dimensional array comprising the slice update.
1374      start_indices: Rank-1 array of N integers comprising the starting indices
1375        of the slice along each dimension.
1376
1377    Returns:
1378      A LocalOp representing the added DynamicUpdateSlice op.
1379    """
1380    return self._client.DynamicUpdateSlice(operand, update, start_indices)
1381
1382  def Tuple(self, *ops):
1383    """Enqueues a tuple operation onto the computation.
1384
1385    Args:
1386      ops: a sequence of tuple operands (each a LocalOp).
1387
1388    Returns:
1389      A LocalOp representing the added Tuple op.
1390    """
1391    return self._client.Tuple(ops)
1392
1393  def GetTupleElement(self, tup, index):
1394    """Enqueues a 'get tuple element' operation onto the computation.
1395
1396    Args:
1397      tup: the tuple operand (a LocalOp).
1398      index: numeric index to select from the tuple.
1399
1400    Returns:
1401      A LocalOp representing the added GetTupleElement op.
1402    """
1403    return self._client.GetTupleElement(tup, index)
1404
1405  def Call(self, computation_to_apply, operands):
1406    """Enqueues a call operation onto the computation.
1407
1408    Args:
1409      computation_to_apply: a Computation object.
1410      operands: an iterable of LocalOp. The number and types of operands must
1411        match the arity of computation_to_apply.
1412
1413    Returns:
1414      A LocalOp representing the added call op.
1415    """
1416    return self._client.Call(computation_to_apply.computation, operands)
1417
1418  def CustomCall(self,
1419                 call_target_name,
1420                 operands,
1421                 shape_with_layout,
1422                 operand_shapes_with_layout,
1423                 opaque=None):
1424    """Enqueues a custom call operation onto the computation.
1425
1426    Args:
1427      call_target_name: the name of the function to call.
1428      operands: an iterable of LocalOp. The number and types of operands must
1429        match the arity of `operand_shapes_with_layout`.
1430      shape_with_layout: the shape of the operator's output, with layout.
1431      operand_shapes_with_layout: the shapes of `operands`, including the
1432        expected layouts.
1433      opaque: an opaque string passed to the backend.
1434
1435    Returns:
1436      A LocalOp representing the added custom call op.
1437    """
1438    opaque = opaque or b''
1439    return self._client.CustomCall(call_target_name, operands,
1440                                   shape_with_layout,
1441                                   operand_shapes_with_layout, opaque)
1442
1443  def Map(self, operands, computation_to_apply, dimensions):
1444    """Enqueues a map operation onto the computation.
1445
1446    Args:
1447      operands: an iterable of LocalOp.
1448      computation_to_apply: a Computation object.
1449      dimensions: dimensions over which to apply map the function.
1450
1451    Returns:
1452      A LocalOp representing the added Map op.
1453    """
1454    return self._client.Map(operands, computation_to_apply.computation,
1455                            dimensions)
1456
1457  def Reduce(self, operand, init_value, computation_to_apply, dimensions):
1458    """Enqueues a reduction operation onto the computation.
1459
1460    Args:
1461      operand: reduction operand (LocalOp).
1462      init_value: reduction initial value (LocalOp).
1463      computation_to_apply: a Computation object - binary reduction function.
1464      dimensions: sequence of dimensions (integers) to reduce on.
1465
1466    Returns:
1467      A LocalOp representing the added Reduce op.
1468    """
1469    return self._client.Reduce(operand, init_value,
1470                               computation_to_apply.computation, dimensions)
1471
1472  def ReduceWindow(self, operand, init_value, computation_to_apply,
1473                   window_dimensions, window_strides, padding):
1474    """Enqueues a windowed reduction operation onto the computation.
1475
1476    Args:
1477      operand: reduction operand (LocalOp).
1478      init_value: reduction initial value (LocalOp).
1479      computation_to_apply: a binary reduction function (Computation).
1480      window_dimensions: dimensions of window (sequence of integers).
1481      window_strides: strides for window (sequence of integers).
1482      padding: PaddingType representing either 'SAME' or 'VALID' padding.
1483
1484    Returns:
1485      A LocalOp representing the added ReduceWindow op.
1486    """
1487    pads = _convert_padding_type_to_pad_values(
1488        padding,
1489        self.GetShape(operand).dimensions(), window_dimensions, window_strides)
1490    return self._client.ReduceWindowWithGeneralPadding(
1491        operand, init_value, computation_to_apply.computation,
1492        window_dimensions, window_strides, (), (), pads)
1493
1494  def ReduceWindowWithGeneralPadding(
1495      self, operand, init_value, computation_to_apply, window_dimensions,
1496      window_strides, base_dilations, window_dilations, padding):
1497    """Enqueues a windowed reduction operation onto the computation.
1498
1499    Args:
1500      operand: reduction operand (LocalOp).
1501      init_value: reduction initial value (LocalOp).
1502      computation_to_apply: a binary reduction function (Computation).
1503      window_dimensions: dimensions of window (sequence of integers).
1504      window_strides: strides for window (sequence of integers).
1505      base_dilations: dilations for the base (sequence of integers).
1506      window_dilations: dilations for window (sequence of integers).
1507      padding: length-N array-like of pairs of integers of (low, high) padding.
1508
1509    Returns:
1510      A LocalOp representing the added ReduceWindow op.
1511    """
1512    return self._client.ReduceWindowWithGeneralPadding(
1513        operand, init_value, computation_to_apply.computation,
1514        window_dimensions, window_strides, base_dilations, window_dilations,
1515        padding)
1516
1517  def RngNormal(self, mu, sigma, dims):
1518    """Enqueues an RngNormal operation onto the computation.
1519
1520    Args:
1521      mu: A LocalOp to an F32 scalar specifying the mean.
1522      sigma: A LocalOp to an F32 scalar specifying the standard deviation.
1523      dims: A 1D array-like of nonnegative integers specifying the dimensions.
1524    Returns: a LocalOp to the generated array of F32 values.
1525    """
1526    shape = Shape.array_shape(self.GetShape(mu).element_type(), dims)
1527    return self._client.RngNormal(mu, sigma, shape)
1528
1529  def RngUniform(self, a, b, dims):
1530    """Enqueues an RngUniform operation onto the computation.
1531
1532    Args:
1533      a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b)
1534        specifying the low end of the interval [a, b) over which values are
1535        generated.
1536      b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a)
1537        specifying the high end of the interval [a, b) over which values are
1538        generated.
1539      dims: A 1D array-like of nonnegative integers specifying the dimensions.
1540    Returns: a LocalOp to the generated array of values with the same numeric
1541      type (F32, S32, or U32) as the arguments a and b.
1542    """
1543    shape = Shape.array_shape(self.GetShape(a).element_type(), dims)
1544    return self._client.RngUniform(a, b, shape)
1545
1546  def While(self, cond, body, init):
1547    """Enqueues a While operation onto the computation.
1548
1549    Args:
1550      cond: a Computation for the loop condition, which has type T -> PRED
1551      body: a Computation for the loop body, which has type T -> T
1552      init: a LocalOp for the initial parameter, which has type T
1553    Returns: a LocalOp representing the While operation.
1554    """
1555    return self._client.While(cond.computation, body.computation, init)
1556
1557  def Conditional(self, pred, true_operand, true_computation, false_operand,
1558                  false_computation):
1559    """Enqueues a Conditional operation onto the computation.
1560
1561    Args:
1562      predicate: a LocalOp to test, which has scalar type PRED
1563      true_operand: a LocalOp of type T_0
1564      true_computation: a Computation to apply to true_operand, type T_0 -> S
1565      false_operand: a ComputationDatahandle of type T_1
1566      false_computation: a Computation to apply to false_operand, type T_1 -> S
1567    Returns: a LocalOp representing the Conditional operation.
1568    """
1569    return self._client.Conditional(pred, true_operand,
1570                                    true_computation.computation, false_operand,
1571                                    false_computation.computation)
1572
1573  def IsConstant(self, operand):
1574    """Checks whether the given operand is a compile-time constant.
1575
1576    Args:
1577      operand: a ComputationDataHandle to test.
1578    Returns: bool indicating whether `operand` is a compile-time constant,
1579      meaning its value does not depend on any parametersor, or on stateful
1580      operators such as `RngNormal` or `Infeed`.
1581    """
1582    return self._client.IsConstant(operand)
1583
1584  def BuildConstantSubGraph(self, operand):
1585    """Builds a constant sub graph.
1586
1587    Args:
1588      operand: a LocalOp to test.
1589    Returns: a Computation that is rooted on the given `operand` which is a
1590      compile-time constant.
1591    """
1592    return self._client.BuildConstantSubGraph(operand)
1593
1594  def Dot(self, lhs, rhs):
1595    """Enqueues a dot operation onto the computation.
1596
1597    Args:
1598      lhs: LocalOp for the rank 1 or rank 2 left-hand-side array.
1599      rhs: LocalOp for the rank 1 or rank 2 right-hand-side array.
1600    Returns: a LocalOp representing the Dot operation.
1601    """
1602    return self._client.Dot(lhs, rhs)
1603
1604  def DotGeneral(self, lhs, rhs, dimension_numbers):
1605    """Enqueues a general dot operation onto the computation.
1606
1607    Args:
1608      lhs: LocalOp for the left-hand-side array.
1609      rhs: LocalOp for the right-hand-side array.
1610      dimension_numbers: either a DotDimensionNumbers or a nested tuple
1611        ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of
1612        integers representing the dimensions to treat as contracting dimensions
1613        and batch dimensions on each input operand.
1614    Returns: a LocalOp representing the DotGeneral operation.
1615    """
1616    if isinstance(dimension_numbers, tuple):
1617      dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
1618    return self._client.DotGeneral(lhs, rhs, dimension_numbers)
1619
1620  def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1):
1621    """Enqueues a Conv operation onto the computation.
1622
1623    Args:
1624      lhs: LocalOp for the rank N+2 array of inputs.
1625      rhs: LocalOp for the rank N+2 array of kernel weights.
1626      window_strides: length-N array-like of integer kernel strides.
1627      padding: PaddingType representing either 'SAME' or 'VALID' padding.
1628      feature_group_count: number of feature groups for grouped convolution.
1629    Returns: a LocalOp representing the Conv operation.
1630    """
1631    pads = _convert_padding_type_to_pad_values(
1632        padding,
1633        self.GetShape(lhs).dimensions()[2:],
1634        self.GetShape(rhs).dimensions()[2:], window_strides)
1635    return self.ConvGeneralDilated(
1636        lhs, rhs, window_strides, pads, (), (), dimension_numbers=None,
1637        feature_group_count=feature_group_count)
1638
1639  def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
1640                             lhs_dilation, rhs_dilation, feature_group_count=1):
1641    """Enqueues a ConvWithGeneralPadding operation onto the computation.
1642
1643    Args:
1644      lhs: LocalOp for the rank N+2 array of inputs.
1645      rhs: LocalOp for the rank N+2 array of kernel weights.
1646      window_strides: length-N array-like of kernel strides.
1647      padding: length-N array-like of pairs of integers of (low, high) padding.
1648      lhs_dilation: length-N array-like of dilation factors.
1649      rhs_dilation: length-N array-like of dilation factors.
1650      feature_group_count: number of feature groups for grouped convolution.
1651
1652    Returns:
1653      A ComputationdataHandle representing the added ConvWithGeneralPadding op.
1654    """
1655    return self.ConvGeneralDilated(
1656        lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
1657        dimension_numbers=None, feature_group_count=feature_group_count)
1658
1659  def _GetConvDimensionNumbers(self, num_spatial_dims):
1660    """Create ConvolutionDimensionNumbers proto for convolutions."""
1661    nd = num_spatial_dims
1662    dimension_numbers = ConvolutionDimensionNumbers()
1663    dimension_numbers.input_batch_dimension = 0
1664    dimension_numbers.input_feature_dimension = 1
1665    dimension_numbers.output_batch_dimension = 0
1666    dimension_numbers.output_feature_dimension = 1
1667    dimension_numbers.kernel_output_feature_dimension = 0
1668    dimension_numbers.kernel_input_feature_dimension = 1
1669    dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
1670    dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
1671    dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
1672    return dimension_numbers
1673
1674  def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation,
1675                         rhs_dilation, dimension_numbers=None,
1676                         feature_group_count=1):
1677    """Enqueues a ConvGeneralDilated operation onto the computation.
1678
1679    Args:
1680      lhs: LocalOp for the rank N+2 array of inputs.
1681      rhs: LocalOp for the rank N+2 array of kernel weights.
1682      window_strides: length-N array-like of integer kernel strides.
1683      padding: length-N array-like of pairs of integers of (low, high) padding.
1684      lhs_dilation: length-N array-like of integer dilation factors.
1685      rhs_dilation: length-N array-like of integer dilation factors.
1686      dimension_numbers: optional, either a ConvolutionDimensionNumbers object
1687        or a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of
1688        length N+2 identifying by position: (1) batch dimensions in lhs, rhs,
1689          and the output with the character 'N', (2) feature dimensions in lhs
1690          and the output with the character 'C', (3) input and output feature
1691          dimensions in rhs with the characters 'I' and 'O' respectively, and
1692          (4) spatial dimension correspondences between lhs, rhs, and the output
1693          using any distinct characters. For example, to indicate dimension
1694          numbers consistent with the Conv operation with two spatial
1695          dimensions, one could use ('NCHW', 'OIHW', 'NCHW'). As another
1696          example, to indicate dimension numbers consistent with the TensorFlow
1697          Conv2D operation, one could use ('NHWC', 'HWIO', 'NHWC'). When using
1698          the latter form of convolution dimension specification, window strides
1699          are associated with spatial dimension character labels according to
1700          the order in which the labels appear in the rhs_spec string, so that
1701          window_strides[0] is matched with the dimension corresponding to the
1702          first character appearing in rhs_spec that is not 'I' or 'O'. By
1703          default, use the same dimension numbering as Conv and
1704          ConvWithGeneralPadding.
1705      feature_group_count: number of feature groups for grouped convolution.
1706    Returns: a LocalOp representing the ConvGenralDilated operation.
1707    """
1708    if dimension_numbers is None:
1709      dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
1710    elif isinstance(dimension_numbers, tuple):
1711      lhs_spec, rhs_spec, out_spec = dimension_numbers
1712      dimension_numbers = ConvolutionDimensionNumbers()
1713
1714      dimension_numbers.input_batch_dimension = lhs_spec.index('N')
1715      dimension_numbers.input_feature_dimension = lhs_spec.index('C')
1716      dimension_numbers.output_batch_dimension = out_spec.index('N')
1717      dimension_numbers.output_feature_dimension = out_spec.index('C')
1718      dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
1719      dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')
1720
1721      dimension_numbers.kernel_spatial_dimensions.extend(
1722          i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'})
1723      dimension_numbers.input_spatial_dimensions.extend(
1724          sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
1725                 key=lambda i: rhs_spec.index(lhs_spec[i])))
1726      dimension_numbers.output_spatial_dimensions.extend(
1727          sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
1728                 key=lambda i: rhs_spec.index(out_spec[i])))
1729    return self._client.ConvGeneralDilated(
1730        lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
1731        dimension_numbers, feature_group_count)
1732
1733  def Sort(self, operand, dimension=-1):
1734    """Enqueues a sort operation onto the computation."""
1735    return self._client.Sort(operand, dimension)
1736
1737  def SortKeyVal(self, keys, values, dimension=-1):
1738    """Enqueues a key-value sort operation onto the computation."""
1739    return self._client.SortKeyVal(keys, values, dimension)
1740
1741  def Cholesky(self, a, lower=True):
1742    """Enqueues a Cholesky decomposition onto the computation."""
1743    return self._client.Cholesky(a, lower)
1744
1745  def QR(self, a, full_matrices=True):
1746    """Enqueues a QR decomposition onto the computation."""
1747    return self._client.QR(a, full_matrices)
1748
1749  def TriangularSolve(self,
1750                      a,
1751                      b,
1752                      left_side=False,
1753                      lower=False,
1754                      transpose_a=False,
1755                      conjugate_a=False,
1756                      unit_diagonal=False):
1757    """Enqueues a triangular-solve operation onto the computation."""
1758    if not transpose_a:
1759      transpose = 1
1760      if conjugate_a:
1761        a = self.Conj(a)
1762    else:
1763      transpose = 3 if conjugate_a else 2
1764    return self._client.TriangularSolve(a, b, left_side, lower, unit_diagonal,
1765                                        transpose)
1766
1767  def Eigh(self, a, full_matrices=True):
1768    """Enqueues a symmetric/Hermitian eigendecomposition."""
1769    return self._client.Eigh(a, full_matrices)
1770
1771  def SVD(self, a):
1772    """Enqueues a singular value decomposition."""
1773    return self._client.SVD(a)
1774
1775  def Gather(self, a, start_indices, dimension_numbers, slice_sizes):
1776    """Enqueues a Gather operation onto the computation."""
1777    return self._client.Gather(a, start_indices, dimension_numbers, slice_sizes)
1778
1779  def Scatter(self, a, scatter_indices, updates, update_computation,
1780              dimension_numbers):
1781    """Enqueues a Scatter operation onto the computation."""
1782    return self._client.Scatter(
1783        a, scatter_indices, updates, update_computation.computation,
1784        dimension_numbers)
1785
1786
1787def _forward_methods_to_local_builder():
1788  """Forward remaining ComputationBuilder methods to the C API.
1789
1790  Set up methods, corresponding to unary and binary XLA operations,
1791  whose calls are forwarded in a boilerplate manner to the underlying
1792  ComputationBuilder C-extension API.
1793  """
1794
1795  def forward_to_local_builder_with_handles(target_method, is_binop=False):
1796    """Generate a forwarding method that wraps/unwraps data handles."""
1797
1798    def forward(self, *args, **kwargs):
1799      arg_list = list(args)
1800
1801      if is_binop and len(arg_list) < 3:
1802        arg_list.append(kwargs.get('broadcast_dimensions', ()))
1803
1804      return target_method(
1805          self._client,  # pylint: disable=protected-access
1806          *arg_list)
1807
1808    return forward
1809
1810  for method_name in _UNARY_OPS:
1811    forward = forward_to_local_builder_with_handles(
1812        getattr(c_api.ComputationBuilder, method_name))
1813    forward.__name__ = method_name
1814    setattr(ComputationBuilder, method_name, forward)
1815
1816  for method_name in _BINARY_OPS:
1817    forward = forward_to_local_builder_with_handles(
1818        getattr(c_api.ComputationBuilder, method_name), is_binop=True)
1819    forward.__name__ = method_name
1820    setattr(ComputationBuilder, method_name, forward)
1821
1822
1823_forward_methods_to_local_builder()
1824
1825_default_replica_count = 1
1826
1827
1828def initialize_replica_count(replica_count):
1829  """Initializes the default replica count to use.
1830
1831  Deprecated; pass `num_replicas` as an option to `Computation.Compile()`
1832  instead.
1833
1834  Args:
1835    replica_count: number of replicas that are desired for set up during XLA
1836      initialization.
1837
1838  Raises:
1839    A runtime exception if the XLA service has already been initialized.
1840  """
1841  global _default_replica_count
1842  _default_replica_count = replica_count
1843
1844
1845def get_replica_count():
1846  """Returns the default replica count.
1847
1848  Deprecated; pass `num_replicas` as an option to `Computation.Compile()`
1849  instead.
1850  """
1851  return _default_replica_count
1852
1853
1854def initialize_platform_name(platform_name):
1855  """Initializes the default platform name to use for XLA.
1856
1857  Args:
1858    platform_name: string name of platform.
1859  """
1860  global _default_platform_name
1861  _default_platform_name = platform_name
1862
1863  # Make sure the platform is valid by trying to instantiate it.
1864  _get_default_local_backend()
1865
1866
1867def register_cpu_custom_call_target(name, fn):
1868  """Registers a CPU custom call target.
1869
1870  Args:
1871    name: bytes containing the name of the function.
1872    fn: a PyCapsule object containing the function pointer.
1873  """
1874  c_api.RegisterCpuCustomCallTarget(name, fn)
1875
1876
1877class PaddingConfigDimension(object):
1878  """Python representation of a xla.PaddingConfigDimension protobuf."""
1879  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')
1880
1881  def __init__(self):
1882    self.edge_padding_low = []
1883    self.edge_padding_high = []
1884    self.interior_padding = []
1885
1886
1887class PaddingConfig(object):
1888  """Python representation of a xla.PaddingConfig protobuf."""
1889  __slots__ = ('dimensions',)
1890
1891  def __init__(self):
1892    self.dimensions = []
1893
1894
1895def GetPaddingConfigFromTriples(triples):
1896  """Create PaddingConfig proto from list of triples of integers."""
1897  padding_config = PaddingConfig()
1898  for lo, hi, interior in triples:
1899    dimension = PaddingConfigDimension()
1900    dimension.edge_padding_low = lo
1901    dimension.edge_padding_high = hi
1902    dimension.interior_padding = interior
1903    padding_config.dimensions.append(dimension)
1904  return padding_config
1905
1906
1907class DotDimensionNumbers(object):
1908  """Python representation of a xla.DotDimensionNumbers protobuf."""
1909  __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions',
1910               'lhs_batch_dimensions', 'rhs_batch_dimensions')
1911
1912  def __init__(self):
1913    self.lhs_contracting_dimensions = []
1914    self.rhs_contracting_dimensions = []
1915    self.lhs_batch_dimensions = []
1916    self.rhs_batch_dimensions = []
1917
1918
1919def GetDotDimensionsFromLists(dimension_numbers):
1920  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
1921  dot_dims_proto = DotDimensionNumbers()
1922  dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
1923  dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
1924  dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
1925  dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
1926  return dot_dims_proto
1927
1928
1929class ConvolutionDimensionNumbers(object):
1930  """Python representation of a xla.ConvolutionDimensionNumbers protobuf."""
1931  __slots__ = ('input_batch_dimension', 'input_feature_dimension',
1932               'input_spatial_dimensions', 'kernel_input_feature_dimension',
1933               'kernel_output_feature_dimension', 'kernel_spatial_dimensions',
1934               'output_batch_dimension', 'output_feature_dimension',
1935               'output_spatial_dimensions')
1936
1937  def __init__(self):
1938    self.input_batch_dimension = 0
1939    self.input_feature_dimension = 0
1940    self.input_spatial_dimensions = []
1941    self.kernel_input_feature_dimension = 0
1942    self.kernel_output_feature_dimension = 0
1943    self.kernel_spatial_dimensions = []
1944    self.output_batch_dimension = 0
1945    self.output_feature_dimension = 0
1946    self.output_spatial_dimensions = []
1947
1948
1949class GatherDimensionNumbers(object):
1950  """Python representation of a xla.GatherDimensionNumbers protobuf."""
1951  __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
1952               'index_vector_dim')
1953
1954  def __init__(self):
1955    self.offset_dims = []
1956    self.collapsed_slice_dims = []
1957    self.start_index_map = []
1958    self.index_vector_dim = 0
1959
1960
1961class ScatterDimensionNumbers(object):
1962  """Python representation of a xla.ScatterDimensionNumbers protobuf."""
1963  __slots__ = ('update_window_dims', 'inserted_window_dims',
1964               'scatter_dims_to_operand_dims', 'index_vector_dim')
1965
1966  def __init__(self):
1967    self.update_window_dims = []
1968    self.inserted_window_dims = []
1969    self.scatter_dims_to_operand_dims = []
1970    self.index_vector_dim = 0
1971
1972
1973class ReplicaGroup(object):
1974  """Python representation of a xla.ReplicaGroup protobuf."""
1975  __slots__ = ('replica_ids',)
1976
1977  def __init__(self):
1978    self.replica_ids = []
1979
1980
1981def _make_replica_group_proto(replica_group):
1982  replica_group_proto = ReplicaGroup()
1983  replica_group_proto.replica_ids.extend(replica_group)
1984  return replica_group_proto
1985