1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Experimental library that exposes XLA operations directly in TensorFlow.
16
It is sometimes useful to be able to build HLO programs directly from
TensorFlow. This file provides TensorFlow operators that mirror the semantics of
HLO operators as closely as possible.
20
21Note: There is no promise of backward or forward compatibility for operators
22defined in this module. This is primarily because the underlying HLO operators
23do not promise backward or forward compatibility.
24"""
25
26from __future__ import absolute_import
27from __future__ import division
28from __future__ import print_function
29
30from tensorflow.compiler.tf2xla.ops import gen_xla_ops
31from tensorflow.core.framework import attr_value_pb2
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import ops
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import bitwise_ops
37from tensorflow.python.ops import gen_math_ops
38from tensorflow.python.ops import gen_random_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops import random_ops
41from tensorflow.python.ops import special_math_ops
42
# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
# infeed/outfeed (available via tf.contrib.tpu)
# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
# conditional
# collapse
50
# Several public wrappers below deliberately reuse Python builtin names so that
# they match XLA's operator names (e.g. `xla.max`, `xla.min`, `xla.slice`).
# Keep handles on the original builtins before they are shadowed.
# pylint: disable=redefined-builtin
_max, _min = max, min
_slice = slice  # pylint: disable=invalid-name

# `xla.constant` is simply TensorFlow's constant constructor.
constant = constant_op.constant
59
60# Unary operators.
61
62# For most arithmetic operators there is a TensorFlow operator
63# that exactly corresponds to each XLA operator. Rather than defining
64# XLA-specific variants, we reuse the corresponding TensorFlow operator.
65# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
66# wrap every HLO operator, because that would allow us to be confident that the
67# semantics match.
68
69
70def _unary_op(fn):
71  """Wrapper that restricts `fn` to have the correct signature."""
72
73  def unary_op_wrapper(x, name=None):
74    return fn(x, name=name)
75
76  return unary_op_wrapper
77
78
# Elementwise unary operators, each restricted to the `(x, name=None)`
# signature by `_unary_op`.

# Sign, negation and complex-number parts.
abs = _unary_op(math_ops.abs)
neg = _unary_op(math_ops.neg)
sign = _unary_op(math_ops.sign)
conj = _unary_op(math_ops.conj)
real = _unary_op(math_ops.real)
imag = _unary_op(math_ops.imag)
# TODO(phawkins): implement clz.

# Rounding.
ceil = _unary_op(math_ops.ceil)
floor = _unary_op(math_ops.floor)
# TODO(phawkins): unlike xla::Round, which rounds numbers halfway between two
# integers away from zero, this rounds them to the nearest even integer.
round = _unary_op(math_ops.round)

# Exponentials and logarithms.
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)

# Trigonometric and hyperbolic functions.
cos = _unary_op(math_ops.cos)
sin = _unary_op(math_ops.sin)
tanh = _unary_op(math_ops.tanh)

# Special functions.
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
ndtri = _unary_op(math_ops.ndtri)
lgamma = _unary_op(math_ops.lgamma)

# Predicates and logic.
is_finite = _unary_op(math_ops.is_finite)
logical_not = _unary_op(math_ops.logical_not)

# Bessel functions.
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)
110
111# Binary operators
112
113# The main difference between TensorFlow and XLA binary ops is the broadcasting
114# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
115# requires an explicit specification of which dimensions to broadcast if the
116# arguments have different ranks.
117
118
def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting.

  Args:
    fn: a binary TensorFlow operator, called as `fn(x, y, name=name)`.

  Returns:
    A function `(x, y, broadcast_dims=None, name=None)` that reconciles the
    shapes of `x` and `y` with an `XlaBroadcastHelper` op and then applies
    `fn` to the reconciled operands.
  """

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Inner wrapper function.

    Args:
      x: the left operand.
      y: the right operand.
      broadcast_dims: optional dimension numbers telling XLA which dimensions
        to broadcast when `x` and `y` have different ranks. May be a Python
        list, a NumPy array or a tensor; `None` means no explicit dimensions.
      name: optional name for the operator produced by `fn`.

    Returns:
      The result of `fn` applied to the broadcast operands.
    """
    # Compare against None instead of writing `broadcast_dims or []`. Taking
    # the truth value of a multi-element NumPy array or tensor (or of any
    # symbolic graph Tensor) raises, even though those are valid inputs here.
    if broadcast_dims is None:
      broadcast_dims = []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper
133
134
# Signed TF integer dtype -> unsigned TF dtype of the same width.
_SIGNED_TO_UNSIGNED_TABLE = {
    dtypes.int8: dtypes.uint8,
    dtypes.int16: dtypes.uint16,
    dtypes.int32: dtypes.uint32,
    dtypes.int64: dtypes.uint64,
}

# Unsigned TF integer dtype -> signed TF dtype of the same width; the exact
# inverse of the table above.
_UNSIGNED_TO_SIGNED_TABLE = {
    unsigned: signed
    for signed, unsigned in _SIGNED_TO_UNSIGNED_TABLE.items()
}
150
151
def _shift_right_logical_helper(x, y, name=None):
  """Zero-filling right shift of `x` by `y`, whatever the integer signedness.

  `bitwise_ops.right_shift` shifts signed operands arithmetically, so signed
  inputs are cast to the unsigned dtype of the same width, shifted, and cast
  back to their original dtype.

  Args:
    x: integer tensor to shift.
    y: shift amounts; must have the same dtype as `x`.
    name: optional name for the shift operator.

  Returns:
    A tensor with the same dtype as `x`.
  """
  assert y.dtype == x.dtype
  input_dtype = x.dtype
  if input_dtype not in _SIGNED_TO_UNSIGNED_TABLE:
    # Dtypes outside the table (e.g. unsigned ones) are shifted as-is.
    return bitwise_ops.right_shift(x, y, name=name)
  as_unsigned = _SIGNED_TO_UNSIGNED_TABLE[input_dtype]
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, as_unsigned), math_ops.cast(y, as_unsigned), name=name)
  return math_ops.cast(shifted, input_dtype)
165
166
def _shift_right_arithmetic_helper(x, y, name=None):
  """Sign-propagating right shift of `x` by `y`, whatever the signedness.

  `bitwise_ops.right_shift` shifts unsigned operands logically, so unsigned
  inputs are cast to the signed dtype of the same width, shifted, and cast
  back to their original dtype.

  Args:
    x: integer tensor to shift.
    y: shift amounts; must have the same dtype as `x`.
    name: optional name for the shift operator.

  Returns:
    A tensor with the same dtype as `x`.
  """
  assert y.dtype == x.dtype
  input_dtype = x.dtype
  as_signed = _UNSIGNED_TO_SIGNED_TABLE.get(input_dtype)
  if as_signed is None:
    # Dtypes outside the table (e.g. signed ones) are shifted as-is.
    return bitwise_ops.right_shift(x, y, name=name)
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, as_signed), math_ops.cast(y, as_signed), name=name)
  return math_ops.cast(shifted, input_dtype)
180
181
# Elementwise binary operators. Each accepts an optional `broadcast_dims`
# argument and broadcasts its operands XLA-style via `_broadcasting_binary_op`.

# Arithmetic.
add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
pow = _broadcasting_binary_op(math_ops.pow)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)

# Comparisons.
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)

# Logical operators.
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)

# Bit shifts.
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

# Special functions and related gradients.
igamma = _broadcasting_binary_op(math_ops.igamma)
igammac = _broadcasting_binary_op(math_ops.igammac)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)
211
212
213def _binary_op(fn):
214  """Wrapper that restricts `fn` to have the correct signature."""
215
216  def binary_op_wrapper(x, y, name=None):
217    return fn(x, y, name=name)
218
219  return binary_op_wrapper
220
221
# Structural operators mapped directly onto TensorFlow array ops.
transpose = _binary_op(array_ops.transpose)  # (x, permutation, name=None)
rev = _binary_op(array_ops.reverse)  # (x, axes to reverse, name=None)

bitcast_convert_type = array_ops.bitcast
226
227
def broadcast(x, dims, name=None):
  """Wraps the XLA Broadcast operator by prepending new leading dimensions.

  The result has shape `dims + shape(x)`, with `x` replicated along the new
  leading dimensions.

  Args:
    x: the tensor (or tensor-convertible value) to broadcast.
    dims: sizes of the new leading dimensions, as a list of integers. May be
      empty, in which case the result has the same shape as `x`.
    name: an optional name for the resulting operator.

  Returns:
    A tensor of shape `dims + shape(x)`.
  """
  x = ops.convert_to_tensor(x)
  x_shape = array_ops.shape(x)
  # Pin `dims` to the dtype of `x_shape`. Without an explicit dtype an empty
  # `dims` list becomes a float32 constant, and the concat below then fails
  # with a dtype mismatch.
  leading_dims = constant_op.constant(dims, dtype=x_shape.dtype)
  shape = array_ops.concat([leading_dims, x_shape], axis=0)
  return array_ops.broadcast_to(x, shape, name=name)
234
235
def clamp(a, x, b, name=None):
  """Clamps `x` elementwise to the range `[a, b]`.

  Computes `min(max(a, x), b)` using this module's broadcasting `max` and
  `min` wrappers.

  Args:
    a: the lower bound.
    x: the operand to clamp.
    b: the upper bound.
    name: an optional name for the operator that produces the result.

  Returns:
    The clamped tensor.
  """
  # Only the outermost op carries `name`. Passing it to both ops made graph
  # mode uniquify the second one to `<name>_1`, so the returned tensor did
  # not carry the requested name.
  return min(max(a, x), b, name=name)


concatenate = array_ops.concat
241
242
def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         name=None):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimensions
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto.
    name: an optional name for the operator

  Returns:
    A tensor representing the output of the convolution.
  """
  # Proto arguments reach the op in serialized form; a missing precision
  # config is encoded as the empty string.
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=serialized_dimension_numbers,
      precision_config=serialized_precision,
      name=name)
288
289
# XLA's ConvertElementType is TensorFlow's cast.
convert_element_type = math_ops.cast


def dot(lhs, rhs, name=None):
  """Wraps the XLA Dot operator.

  Contracts the last dimension of `lhs` with the first dimension of `rhs`
  (a `tensordot` with a single contracted axis).
  """
  return math_ops.tensordot(lhs, rhs, 1, name=name)
295
296
def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
  """Wraps the XLA DotGeneral operator.

  Args:
    lhs: the left-hand operand.
    rhs: the right-hand operand.
    dimension_numbers: a proto message describing the contracting and batch
      dimensions; it is serialized before being passed to the op.
    precision_config: an optional `xla.PrecisionConfig` proto.
    name: an optional name for the operator.

  Returns:
    The output of `gen_xla_ops.xla_dot`.
  """
  # A missing precision config is encoded as the empty string.
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=serialized_dimension_numbers,
      precision_config=serialized_precision,
      name=name)
307
308
def self_adjoint_eig(a, lower, max_iter, epsilon, name=None):
  """Wraps the XLA SelfAdjointEig operator.

  Args:
    a: the input tensor.
    lower: forwarded unchanged to `gen_xla_ops.xla_self_adjoint_eig`.
    max_iter: forwarded unchanged to `gen_xla_ops.xla_self_adjoint_eig`.
    epsilon: forwarded unchanged to `gen_xla_ops.xla_self_adjoint_eig`.
    name: an optional name for the operator. It is accepted for consistency
      with the other wrappers in this module and defaults to the op's own name.

  Returns:
    The outputs of `gen_xla_ops.xla_self_adjoint_eig`.
  """
  return gen_xla_ops.xla_self_adjoint_eig(
      a, lower, max_iter, epsilon, name=name)


def svd(a, max_iter, epsilon, precision_config=None, name=None):
  """Wraps the XLA Svd operator.

  Args:
    a: the input tensor.
    max_iter: forwarded unchanged to `gen_xla_ops.xla_svd`.
    epsilon: forwarded unchanged to `gen_xla_ops.xla_svd`.
    precision_config: an optional `xla.PrecisionConfig` proto; serialized
      before being passed to the op (empty string when absent).
    name: an optional name for the operator. It is accepted for consistency
      with the other wrappers in this module and defaults to the op's own name.

  Returns:
    The outputs of `gen_xla_ops.xla_svd`.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  return gen_xla_ops.xla_svd(
      a, max_iter, epsilon, precision_config_proto, name=name)


dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad
327
328
def random_normal(mu, sigma, dims, name=None):
  """Samples a tensor of shape `dims` from a normal distribution.

  The result's dtype is the dtype of `mu` after conversion to a tensor.
  """
  mean = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      shape=dims, mean=mean, stddev=sigma, dtype=mean.dtype, name=name)


def random_uniform(minval, maxval, dims, name=None):
  """Samples a tensor of shape `dims` uniformly from `[minval, maxval)`.

  The result's dtype is the dtype of `minval` after conversion to a tensor.
  """
  low = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      shape=dims, minval=low, maxval=maxval, dtype=low.dtype, name=name)


recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce

# Declare XlaVariadicReduce as non-differentiable.
ops.no_gradient("XlaVariadicReduce")
346
347
def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: dilation factors applied to `operand`, as a list of
      integers. Optional; if omitted, defaults to dilations of 1.
    window_dilations: dilation factors applied to the window, as a list of
      integers. Optional; if omitted, defaults to dilations of 1.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """

  def _or_default(value, fill):
    # `None` or an empty Python list/tuple selects one `fill` per window
    # dimension. This is an explicit check rather than `value or ...` because
    # taking the truth value of a multi-element NumPy array or a Tensor raises.
    if value is None or (isinstance(value, (list, tuple)) and not value):
      return [fill] * len(window_dimensions)
    return value

  window_strides = _or_default(window_strides, 1)
  base_dilations = _or_default(base_dilations, 1)
  window_dilations = _or_default(window_dilations, 1)
  padding = _or_default(padding, (0, 0))
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)
391
392
replica_id = gen_xla_ops.xla_replica_id

# Sets a static upper bound on the given input value as a hint to the XLA
# compiler; the same value is returned.
# Usage:
# def f(t, p):
#   p = xla.set_bound(p, 3) # Tells xla the constraint that p <= 3.
#   return t[:p]            # xla knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound


# Makes a static dimension into an XLA bounded dynamic dimension. The current
# static dimension size becomes the bound, and the `size` argument (the third
# argument) becomes the dynamic size of the dimension.
#
# This should mostly be used for testing.
#
# def f():
#   array = tf.convert_to_tensor([[1, 2, 3, 4, 5]])  # shape [1, 5]
#   # Tells xla that only the first 3 entries along dimension 1 are valid.
#   dim = 1
#   p = xla.set_dynamic_dimension_size(array, dim, 3)
#   assert(reduce_sum(p) == 6) # xla knows only the first 3 elements are valid.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size
417
418
def reshape(x, new_sizes, dimensions=None, name=None):
  """Wraps the XLA Reshape operator.

  When `dimensions` is given, `x` is first transposed by that permutation and
  then reshaped to `new_sizes`.
  """
  if dimensions is None:
    operand = x
  else:
    operand = array_ops.transpose(x, dimensions)
  return array_ops.reshape(operand, new_sizes, name=name)


def select(condition, x, y, name=None):
  """Wraps the XLA Select operator via `array_ops.where(condition, x, y)`."""
  return array_ops.where(condition, x, y, name=name)


select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send
432
433
def slice(x, start_dims, limit_dims, strides):
  """Wraps the XLA Slice operator using Python slicing syntax.

  Builds one `start:limit:stride` slice per leading dimension and indexes `x`
  with the resulting tuple. `zip` stops at the shortest of the three inputs.
  """
  index = tuple(
      _slice(start, limit, stride)
      for start, limit, stride in zip(start_dims, limit_dims, strides))
  return x[index]


sharding = gen_xla_ops.xla_sharding
443
444
@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Gradient of XlaSharding: annotate `grad` with the forward op's sharding."""
  forward_sharding = op.get_attr("sharding")
  sharded_grad = gen_xla_ops.xla_sharding(grad, sharding=forward_sharding)
  # Also record the sharding in the new op's `_XlaSharding` attribute.
  # pylint: disable=protected-access
  sharded_grad.op._set_attr(
      "_XlaSharding", attr_value_pb2.AttrValue(s=forward_sharding))
  return [sharded_grad]
453
454
spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape


@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  """Gradient of full-to-shard: map `grad` back to the full shape.

  The full shape is the static shape of the forward op's input, and the
  forward op's `manual_sharding` attribute is reused.
  """
  manual_sharding = op.get_attr("manual_sharding")
  full_shape = op.inputs[0].shape.as_list()
  return [
      gen_xla_ops.xla_spmd_shard_to_full_shape(
          grad, manual_sharding=manual_sharding, full_shape=full_shape)
  ]


@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  """Gradient of shard-to-full: map `grad` back to the shard shape."""
  manual_sharding = op.get_attr("manual_sharding")
  return [
      gen_xla_ops.xla_spmd_full_to_shard_shape(
          grad, manual_sharding=manual_sharding)
  ]
473
474
sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize


def gather(operand, start_indices, dimension_numbers, slice_sizes,
           indices_are_sorted=False, name=None):
  """Wraps the XLA Gather operator.

  Args:
    operand: the tensor to gather from.
    start_indices: the tensor of starting indices.
    dimension_numbers: a proto message (presumably
      `xla.GatherDimensionNumbers`); it is serialized before being passed on.
    slice_sizes: forwarded unchanged to `gen_xla_ops.xla_gather`.
    indices_are_sorted: forwarded unchanged to `gen_xla_ops.xla_gather`.
    name: an optional name for the operator.

  Returns:
    The output of `gen_xla_ops.xla_gather`.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)
491
492
def scatter(operand, scatter_indices, updates, update_computation,
            dimension_numbers, indices_are_sorted=False, name=None):
  """Wraps the XLA Scatter operator.

  Args:
    operand: the tensor to scatter into.
    scatter_indices: the tensor of scatter indices.
    updates: the tensor of update values.
    update_computation: forwarded unchanged to `gen_xla_ops.xla_scatter`.
    dimension_numbers: a proto message (presumably
      `xla.ScatterDimensionNumbers`); it is serialized before being passed on.
    indices_are_sorted: forwarded unchanged to `gen_xla_ops.xla_scatter`.
    name: an optional name for the operator.

  Returns:
    The output of `gen_xla_ops.xla_scatter`.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)
503