1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Math Operations.
16
17Note: Functions taking `Tensor` arguments can also take anything accepted by
18`tf.convert_to_tensor`.
19
20Note: Elementwise binary operations in TensorFlow follow [numpy-style
21broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
22
23TensorFlow provides a variety of math functions including:
24
25* Basic arithmetic operators and trigonometric functions.
26* Special math functions (like: `tf.math.igamma` and `tf.math.zeta`)
27* Complex number functions (like: `tf.math.imag` and `tf.math.angle`)
28* Reductions and scans (like: `tf.math.reduce_mean` and `tf.math.cumsum`)
29* Segment functions (like: `tf.math.segment_sum`)
30
31See: `tf.linalg` for matrix and tensor functions.
32
33<a id=Segmentation></a>
34
35## About Segmentation
36
37TensorFlow provides several operations that you can use to perform common
38math computations on tensor segments.
39Here a segmentation is a partitioning of a tensor along
40the first dimension, i.e. it  defines a mapping from the first dimension onto
41`segment_ids`. The `segment_ids` tensor should be the size of
42the first dimension, `d0`, with consecutive IDs in the range `0` to `k`,
43where `k<d0`.
44In particular, a segmentation of a matrix tensor is a mapping of rows to
45segments.
46
47For example:
48
49```python
50c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
51tf.math.segment_sum(c, tf.constant([0, 0, 1]))
52#  ==>  [[0 0 0 0]
53#        [5 6 7 8]]
54```
55
56The standard `segment_*` functions assert that the segment indices are sorted.
57If you have unsorted indices use the equivalent `unsorted_segment_` function.
58These functions take an additional argument `num_segments` so that the output
59tensor can be efficiently allocated.
60
61``` python
62c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
63tf.math.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2)
64# ==> [[ 6,  8, 10, 12],
65#       [-1, -2, -3, -4]]
66```
67
68"""
69from __future__ import absolute_import
70from __future__ import division
71from __future__ import print_function
72
73import numbers
74import numpy as np
75import six
76from six.moves import builtins
77from six.moves import xrange  # pylint: disable=redefined-builtin
78
79from tensorflow.python.eager import context
80from tensorflow.python.framework import constant_op
81from tensorflow.python.framework import dtypes
82from tensorflow.python.framework import graph_util
83from tensorflow.python.framework import ops
84from tensorflow.python.framework import sparse_tensor
85from tensorflow.python.framework import tensor_shape
86from tensorflow.python.framework import tensor_util
87from tensorflow.python.ops import array_ops
88from tensorflow.python.ops import gen_array_ops
89from tensorflow.python.ops import gen_bitwise_ops
90from tensorflow.python.ops import gen_data_flow_ops
91from tensorflow.python.ops import gen_math_ops
92from tensorflow.python.ops import gen_nn_ops
93from tensorflow.python.ops import gen_sparse_ops
94# go/tf-wildcard-import
95# pylint: disable=wildcard-import
96from tensorflow.python.ops.gen_math_ops import *
97# pylint: enable=wildcard-import
98from tensorflow.python.platform import tf_logging as logging
99from tensorflow.python.util import compat
100from tensorflow.python.util import deprecation
101from tensorflow.python.util import dispatch
102from tensorflow.python.util import nest
103from tensorflow.python.util import tf_decorator
104from tensorflow.python.util.compat import collections_abc
105from tensorflow.python.util.lazy_loader import LazyLoader
106from tensorflow.python.util.tf_export import tf_export
107
108
109np_dtypes = LazyLoader(
110    "np_dtypes", globals(),
111    "tensorflow.python.ops.numpy_ops.np_dtypes")
112
113
114# Aliases for some automatically-generated names.
115nextafter = gen_math_ops.next_after
116
117
118@tf_export("linspace", v1=["lin_space", "linspace"])
119@dispatch.add_dispatch_support
120@deprecation.deprecated_endpoints("lin_space")
121def linspace_nd(start, stop, num, name=None, axis=0):
122  r"""Generates evenly-spaced values in an interval along a given axis.
123
124  A sequence of `num` evenly-spaced values are generated beginning at `start`
125  along a given `axis`.
126  If `num > 1`, the values in the sequence increase by
127  `(stop - start) / (num - 1)`, so that the last one is exactly `stop`.
128  If `num <= 0`, `ValueError` is raised.
129
130  Matches
131  [np.linspace](https://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html)'s
132  behaviour
133  except when `num == 0`.
134
135  For example:
136
137  ```
138  tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0  11.0  12.0]
139  ```
140
141  `Start` and `stop` can be tensors of arbitrary size:
142
143  >>> tf.linspace([0., 5.], [10., 40.], 5, axis=0)
144  <tf.Tensor: shape=(5, 2), dtype=float32, numpy=
145  array([[ 0.  ,  5.  ],
146         [ 2.5 , 13.75],
147         [ 5.  , 22.5 ],
148         [ 7.5 , 31.25],
149         [10.  , 40.  ]], dtype=float32)>
150
151  `Axis` is where the values will be generated (the dimension in the
152  returned tensor which corresponds to the axis will be equal to `num`)
153
154  >>> tf.linspace([0., 5.], [10., 40.], 5, axis=-1)
155  <tf.Tensor: shape=(2, 5), dtype=float32, numpy=
156  array([[ 0.  ,  2.5 ,  5.  ,  7.5 , 10.  ],
157         [ 5.  , 13.75, 22.5 , 31.25, 40.  ]], dtype=float32)>
158
159
160
161  Args:
162    start: A `Tensor`. Must be one of the following types: `bfloat16`,
163      `float32`, `float64`. N-D tensor. First entry in the range.
164    stop: A `Tensor`. Must have the same type and shape as `start`. N-D tensor.
165      Last entry in the range.
166    num: A `Tensor`. Must be one of the following types: `int32`, `int64`. 0-D
167      tensor. Number of values to generate.
168    name: A name for the operation (optional).
169    axis: Axis along which the operation is performed (used only when N-D
170      tensors are provided).
171
172  Returns:
173    A `Tensor`. Has the same type as `start`.
174  """
175
176  with ops.name_scope(name, "linspace", [start, stop]):
177    start = ops.convert_to_tensor(start, name="start")
178    # stop must be convertible to the same dtype as start
179    stop = ops.convert_to_tensor(stop, name="stop", dtype=start.dtype)
180    num_int = array_ops.convert_to_int_tensor(num, name="num")
181    num = cast(num_int, dtype=start.dtype)
182
183    broadcast_shape = array_ops.broadcast_dynamic_shape(
184        array_ops.shape(start), array_ops.shape(stop))
185    start = array_ops.broadcast_to(start, broadcast_shape)
186    stop = array_ops.broadcast_to(stop, broadcast_shape)
187
188    expanded_start = array_ops.expand_dims(start, axis=axis)
189    expanded_stop = array_ops.expand_dims(stop, axis=axis)
190
191    shape = array_ops.shape(expanded_start)
192    ndims = array_ops.shape(shape)[0]
193
194    axis = array_ops.where_v2(axis >= 0, axis, ndims + axis)
195
196    # The purpose is to avoid having negative values when repeating.
197    num_fill = gen_math_ops.maximum(num_int - 2, 0)
198    # To avoid having negative values in the range or zero division
199    # the result is sliced in the end so a correct result is returned for
200    # num == 1, and num == 0.
201    n_steps = gen_math_ops.maximum(num_int - 1, 1)
202    delta = (expanded_stop - expanded_start) / cast(n_steps,
203                                                    expanded_stop.dtype)
204    # Re-cast tensors as delta.
205    expanded_start = cast(expanded_start, delta.dtype)
206    expanded_stop = cast(expanded_stop, delta.dtype)
207    # If num < 0, we will throw exception in the range
208    # otherwise use the same div for delta
209    range_end = array_ops.where_v2(num_int >= 0, n_steps, -1)
210    # Even though range supports an output dtype, its limited
211    # (e.g. doesn't support half at the moment).
212    desired_range = cast(range(1, range_end, dtype=dtypes.int64), delta.dtype)
213    mask = gen_math_ops.equal(axis, range(ndims))
214    # desired_range_shape is [1. 1. 1. ... 1. num_fill 1. 1. ... 1.], where the
215    # index of num_fill is equal to axis.
216    desired_range_shape = array_ops.where_v2(mask, num_fill, 1)
217    desired_range = array_ops.reshape(desired_range, desired_range_shape)
218
219    res = expanded_start + delta * desired_range
220
221    # Add the start and endpoints to the result, and slice out the desired
222    # portion.
223    all_tensors = (expanded_start, res, expanded_stop)
224    concatenated = array_ops.concat(all_tensors, axis=axis)
225    begin = array_ops.zeros_like(shape)
226    size = array_ops.where_v2(mask, num_int, shape)
227
228    return array_ops.slice(concatenated, begin, size)
229
230
231linspace = linspace_nd
232
233arg_max = deprecation.deprecated(None, "Use `tf.math.argmax` instead")(arg_max)  # pylint: disable=used-before-assignment
234arg_min = deprecation.deprecated(None, "Use `tf.math.argmin` instead")(arg_min)  # pylint: disable=used-before-assignment
235tf_export(v1=["arg_max"])(dispatch.add_dispatch_support(arg_max))
236tf_export(v1=["arg_min"])(dispatch.add_dispatch_support(arg_min))
237
238
239# This is set by resource_variable_ops.py. It is included in this way since
240# there is a circular dependency between math_ops and resource_variable_ops
241_resource_variable_type = None
242
243
244def _set_doc(doc):
245
246  def _decorator(func):
247    func.__doc__ = doc
248    return func
249
250  return _decorator
251
252
253# pylint: disable=redefined-builtin
254@tf_export(v1=["math.argmax", "argmax"])
255@dispatch.add_dispatch_support
256@deprecation.deprecated_args(None, "Use the `axis` argument instead",
257                             "dimension")
258@_set_doc(
259    gen_math_ops.arg_max.__doc__.replace("dimensions",
260                                         "axes").replace("dimension", "axis"))
261def argmax(input,
262           axis=None,
263           name=None,
264           dimension=None,
265           output_type=dtypes.int64):
266  axis = deprecation.deprecated_argument_lookup("axis", axis, "dimension",
267                                                dimension)
268  return argmax_v2(input, axis, output_type, name)
269
270
271@tf_export("math.argmax", "argmax", v1=[])
272@dispatch.add_dispatch_support
273def argmax_v2(input, axis=None, output_type=dtypes.int64, name=None):
274  """Returns the index with the largest value across axes of a tensor.
275
276  In case of identity returns the smallest index.
277
278  For example:
279
280  >>> A = tf.constant([2, 20, 30, 3, 6])
281  >>> tf.math.argmax(A)  # A[2] is maximum in tensor A
282  <tf.Tensor: shape=(), dtype=int64, numpy=2>
283  >>> B = tf.constant([[2, 20, 30, 3, 6], [3, 11, 16, 1, 8],
284  ...                  [14, 45, 23, 5, 27]])
285  >>> tf.math.argmax(B, 0)
286  <tf.Tensor: shape=(5,), dtype=int64, numpy=array([2, 2, 0, 2, 2])>
287  >>> tf.math.argmax(B, 1)
288  <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 2, 1])>
289  >>> C = tf.constant([0, 0, 0, 0])
290  >>> tf.math.argmax(C) # Returns smallest index in case of ties
291  <tf.Tensor: shape=(), dtype=int64, numpy=0>
292
293  Args:
294    input: A `Tensor`.
295    axis: An integer, the axis to reduce across. Default to 0.
296    output_type: An optional output dtype (`tf.int32` or `tf.int64`). Defaults
297      to `tf.int64`.
298    name: An optional name for the operation.
299
300  Returns:
301    A `Tensor` of type `output_type`.
302  """
303  if axis is None:
304    axis = 0
305  return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
306
307
308@tf_export(v1=["math.argmin", "argmin"])
309@dispatch.add_dispatch_support
310@deprecation.deprecated_args(None, "Use the `axis` argument instead",
311                             "dimension")
312@_set_doc(
313    gen_math_ops.arg_min.__doc__.replace("dimensions",
314                                         "axes").replace("dimension", "axis"))
315def argmin(input,
316           axis=None,
317           name=None,
318           dimension=None,
319           output_type=dtypes.int64):
320  axis = deprecation.deprecated_argument_lookup("axis", axis, "dimension",
321                                                dimension)
322  return argmin_v2(input, axis, output_type, name)
323
324
325@tf_export("math.argmin", "argmin", v1=[])
326@dispatch.add_dispatch_support
327def argmin_v2(input, axis=None, output_type=dtypes.int64, name=None):
328  """Returns the index with the smallest value across axes of a tensor.
329
330  Returns the smallest index in case of ties.
331
332  Args:
333    input: A `Tensor`. Must be one of the following types: `float32`, `float64`,
334      `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`,
335      `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`, `uint32`,
336      `uint64`.
337    axis: A `Tensor`. Must be one of the following types: `int32`, `int64`.
338      int32 or int64, must be in the range `-rank(input), rank(input))`.
339      Describes which axis of the input Tensor to reduce across. For vectors,
340      use axis = 0.
341    output_type: An optional `tf.DType` from: `tf.int32, tf.int64`. Defaults to
342      `tf.int64`.
343    name: A name for the operation (optional).
344
345  Returns:
346    A `Tensor` of type `output_type`.
347
348  Usage:
349  ```python
350  import tensorflow as tf
351  a = [1, 10, 26.9, 2.8, 166.32, 62.3]
352  b = tf.math.argmin(input = a)
353  c = tf.keras.backend.eval(b)
354  # c = 0
355  # here a[0] = 1 which is the smallest element of a across axis 0
356  ```
357  """
358  if axis is None:
359    axis = 0
360  return gen_math_ops.arg_min(input, axis, name=name, output_type=output_type)
361
362
363# pylint: enable=redefined-builtin
364
365
366# pylint: disable=anomalous-backslash-in-string,protected-access
367# pylint: disable=g-docstring-has-escape
368@tf_export("math.abs", "abs")
369@dispatch.add_dispatch_support
370def abs(x, name=None):  # pylint: disable=redefined-builtin
371  r"""Computes the absolute value of a tensor.
372
373  Given a tensor of integer or floating-point values, this operation returns a
374  tensor of the same type, where each element contains the absolute value of the
375  corresponding element in the input.
376
377  Given a tensor `x` of complex numbers, this operation returns a tensor of type
378  `float32` or `float64` that is the absolute value of each element in `x`. For
379  a complex number \\(a + bj\\), its absolute value is computed as
380  \\(\sqrt{a^2 + b^2}\\).
381
382  For example:
383
384  >>> # real number
385  >>> x = tf.constant([-2.25, 3.25])
386  >>> tf.abs(x)
387  <tf.Tensor: shape=(2,), dtype=float32,
388  numpy=array([2.25, 3.25], dtype=float32)>
389
390  >>> # complex number
391  >>> x = tf.constant([[-2.25 + 4.75j], [-3.25 + 5.75j]])
392  >>> tf.abs(x)
393  <tf.Tensor: shape=(2, 1), dtype=float64, numpy=
394  array([[5.25594901],
395         [6.60492241]])>
396
397  Args:
398    x: A `Tensor` or `SparseTensor` of type `float16`, `float32`, `float64`,
399      `int32`, `int64`, `complex64` or `complex128`.
400    name: A name for the operation (optional).
401
402  Returns:
403    A `Tensor` or `SparseTensor` of the same size, type and sparsity as `x`,
404      with absolute values. Note, for `complex64` or `complex128` input, the
405      returned `Tensor` will be of type `float32` or `float64`, respectively.
406  """
407  with ops.name_scope(name, "Abs", [x]) as name:
408    x = ops.convert_to_tensor(x, name="x")
409    if x.dtype.is_complex:
410      return gen_math_ops.complex_abs(x, Tout=x.dtype.real_dtype, name=name)
411    return gen_math_ops._abs(x, name=name)
412
413
414# pylint: enable=g-docstring-has-escape
415
416
417# pylint: disable=redefined-builtin
418def _bucketize(input, boundaries, name=None):
419  return gen_math_ops.bucketize(input=input, boundaries=boundaries, name=name)
420
421
422# pylint: enable=redefined-builtin
423
424
425class DivideDelegateWithName(object):
426  """Use Python2/Python3 division delegation to implement divide for tensors."""
427
428  def __init__(self, x, name):
429    """Construct DivideDelegateWithName.
430
431    Args:
432      x: Tensor to use as left operand in operator overloads
433      name: The name that is preferred for the op created.
434    """
435    self.x = x
436    self.name = name
437
438  def __truediv__(self, y):
439    return _truediv_python3(self.x, y, self.name)
440
441  def __floordiv__(self, y):
442    return floordiv(self.x, y, self.name)
443
444  def __div__(self, y):
445    return _div_python2(self.x, y, self.name)
446
447
448@tf_export("math.divide", "divide")
449@dispatch.add_dispatch_support
450def divide(x, y, name=None):
451  """Computes Python style division of `x` by `y`.
452
453  For example:
454
455  >>> x = tf.constant([16, 12, 11])
456  >>> y = tf.constant([4, 6, 2])
457  >>> tf.divide(x,y)
458  <tf.Tensor: shape=(3,), dtype=float64,
459  numpy=array([4. , 2. , 5.5])>
460
461  Args:
462    x: A `Tensor`
463    y: A `Tensor`
464    name: A name for the operation (optional).
465
466  Returns:
467    A `Tensor` with same shape as input
468  """
469
470  if name is not None:
471    # Cannot use tensors operator overload, because it has no way to track
472    # override names. Use a dummy class to track the runtime division behavior
473    return DivideDelegateWithName(x, name) / y
474  else:
475    # We do conversion here to make sure at least x is a tensor.
476    if not tensor_util.is_tf_type(x):
477      dtype = y.dtype.base_dtype if tensor_util.is_tf_type(y) else None
478      x = ops.convert_to_tensor(x, dtype=dtype)
479    return x / y
480
481
482@tf_export("math.multiply", "multiply")
483@dispatch.add_dispatch_support
484def multiply(x, y, name=None):
485  """Returns an element-wise x * y.
486
487  For example:
488
489  >>> x = tf.constant(([1, 2, 3, 4]))
490  >>> tf.math.multiply(x, x)
491  <tf.Tensor: shape=(4,), dtype=..., numpy=array([ 1,  4,  9, 16], dtype=int32)>
492
493  Since `tf.math.multiply` will convert its arguments to `Tensor`s, you can also
494  pass in non-`Tensor` arguments:
495
496  >>> tf.math.multiply(7,6)
497  <tf.Tensor: shape=(), dtype=int32, numpy=42>
498
499  If `x.shape` is not the same as `y.shape`, they will be broadcast to a
500  compatible shape. (More about broadcasting
501  [here](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).)
502
503  For example:
504
505  >>> x = tf.ones([1, 2]);
506  >>> y = tf.ones([2, 1]);
507  >>> x * y  # Taking advantage of operator overriding
508  <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
509  array([[1., 1.],
510       [1., 1.]], dtype=float32)>
511
512  The reduction version of this elementwise operation is `tf.math.reduce_prod`
513
514  Args:
515    x: A Tensor. Must be one of the following types: `bfloat16`,
516      `half`, `float32`, `float64`, `uint8`, `int8`, `uint16`,
517      `int16`, `int32`, `int64`, `complex64`, `complex128`.
518    y: A `Tensor`. Must have the same type as `x`.
519    name: A name for the operation (optional).
520
521  Returns:
522
523  A `Tensor`.  Has the same type as `x`.
524
525  Raises:
526
527   * InvalidArgumentError: When `x` and `y` have incompatible shapes or types.
528  """
529
530  return gen_math_ops.mul(x, y, name)
531
532
533# TODO(aselle): put deprecation in after another round of global code changes
534@deprecation.deprecated(
535    "2016-12-30",
536    "`tf.mul(x, y)` is deprecated; use `tf.math.multiply(x, y)` or `x * y`")
537def _mul(x, y, name=None):
538  return gen_math_ops.mul(x, y, name)
539
540
541_mul.__doc__ = (
542    gen_math_ops.mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__))
543
544
545@tf_export("math.subtract", "subtract")
546@dispatch.add_dispatch_support
547def subtract(x, y, name=None):
548  return gen_math_ops.sub(x, y, name)
549
550
551subtract.__doc__ = gen_math_ops.sub.__doc__
552
553
554# TODO(aselle): put deprecation in after another round of global code changes
555@deprecation.deprecated(
556    "2016-12-30",
557    "`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`")
558def _sub(x, y, name=None):
559  return gen_math_ops.sub(x, y, name)
560
561
562_sub.__doc__ = (
563    gen_math_ops.sub.__doc__ + ("" if _sub.__doc__ is None else _sub.__doc__))
564
565negative = gen_math_ops.neg
566
567
568# pylint: disable=g-docstring-has-escape
569@deprecation.deprecated(
570    "2016-12-30",
571    "`tf.neg(x)` is deprecated, please use `tf.negative(x)` or `-x`")
572def _neg(x, name=None):
573  """Computes numerical negative value element-wise.
574
575  I.e., \\(y = -x\\).
576
577  Args:
578    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
579      `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
580    name: A name for the operation (optional).
581
582  Returns:
583    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
584  """
585  return negative(x, name)
586
587
588# pylint: enable=g-docstring-has-escape
589
590
591@tf_export(v1=["math.scalar_mul", "scalar_mul"])
592@dispatch.add_dispatch_support
593def scalar_mul(scalar, x, name=None):
594  """Multiplies a scalar times a `Tensor` or `IndexedSlices` object.
595
596  Intended for use in gradient code which might deal with `IndexedSlices`
597  objects, which are easy to multiply by a scalar but more expensive to
598  multiply with arbitrary tensors.
599
600  Args:
601    scalar: A 0-D scalar `Tensor`. Must have known shape.
602    x: A `Tensor` or `IndexedSlices` to be scaled.
603    name: A name for the operation (optional).
604
605  Returns:
606    `scalar * x` of the same type (`Tensor` or `IndexedSlices`) as `x`.
607
608  Raises:
609    ValueError: if scalar is not a 0-D `scalar`.
610  """
611  scalar = ops.convert_to_tensor(
612      scalar, dtype=x.dtype.base_dtype, name="scalar")
613  shape = scalar.get_shape()
614  if shape.ndims == 0:
615    if isinstance(x, ops.IndexedSlices):
616      return ops.IndexedSlices(
617          gen_math_ops.mul(scalar, x.values, name), x.indices, x.dense_shape)
618    else:
619      return gen_math_ops.mul(scalar, x, name)
620  else:
621    raise ValueError("Only scalar multiply works, got shape %s" % shape)
622
623
624@tf_export("math.scalar_mul", "scalar_mul", v1=[])
625@dispatch.add_dispatch_support
626@_set_doc(scalar_mul.__doc__)
627def scalar_mul_v2(scalar, x, name=None):
628  with ops.name_scope(name, "scalar_mul", [x]) as name:
629    return scalar_mul(scalar, x, name)
630
631
632@tf_export("math.pow", "pow")
633@dispatch.add_dispatch_support
634def pow(x, y, name=None):  # pylint: disable=redefined-builtin
635  r"""Computes the power of one value to another.
636
637  Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
638  corresponding elements in `x` and `y`. For example:
639
640  ```python
641  x = tf.constant([[2, 2], [3, 3]])
642  y = tf.constant([[8, 16], [2, 3]])
643  tf.pow(x, y)  # [[256, 65536], [9, 27]]
644  ```
645
646  Args:
647    x: A `Tensor` of type `float16`, `float32`, `float64`, `int32`, `int64`,
648      `complex64`, or `complex128`.
649    y: A `Tensor` of type `float16`, `float32`, `float64`, `int32`, `int64`,
650      `complex64`, or `complex128`.
651    name: A name for the operation (optional).
652
653  Returns:
654    A `Tensor`.
655  """
656  with ops.name_scope(name, "Pow", [x]) as name:
657    return gen_math_ops._pow(x, y, name=name)
658
659
660# pylint: disable=redefined-builtin,redefined-outer-name
661@tf_export("dtypes.complex", "complex")
662@dispatch.add_dispatch_support
663def complex(real, imag, name=None):
664  r"""Converts two real numbers to a complex number.
665
666  Given a tensor `real` representing the real part of a complex number, and a
667  tensor `imag` representing the imaginary part of a complex number, this
668  operation returns complex numbers elementwise of the form \\(a + bj\\), where
669  *a* represents the `real` part and *b* represents the `imag` part.
670
671  The input tensors `real` and `imag` must have the same shape.
672
673  For example:
674
675  ```python
676  real = tf.constant([2.25, 3.25])
677  imag = tf.constant([4.75, 5.75])
678  tf.complex(real, imag)  # [[2.25 + 4.75j], [3.25 + 5.75j]]
679  ```
680
681  Args:
682    real: A `Tensor`. Must be one of the following types: `float32`, `float64`.
683    imag: A `Tensor`. Must have the same type as `real`.
684    name: A name for the operation (optional).
685
686  Returns:
687    A `Tensor` of type `complex64` or `complex128`.
688
689  Raises:
690    TypeError: Real and imag must be correct types
691  """
692  real = ops.convert_to_tensor(real, name="real")
693  imag = ops.convert_to_tensor(imag, name="imag")
694  with ops.name_scope(name, "Complex", [real, imag]) as name:
695    input_types = (real.dtype, imag.dtype)
696    if input_types == (dtypes.float64, dtypes.float64):
697      Tout = dtypes.complex128
698    elif input_types == (dtypes.float32, dtypes.float32):
699      Tout = dtypes.complex64
700    else:
701      raise TypeError("real and imag have incorrect types: "
702                      "{} {}".format(real.dtype.name, imag.dtype.name))
703    return gen_math_ops._complex(real, imag, Tout=Tout, name=name)
704
705
706@tf_export("math.sign", "sign")
707@dispatch.add_dispatch_support
708def sign(x, name=None):
709  r"""Returns an element-wise indication of the sign of a number.
710
711  `y = sign(x) = -1 if x < 0; 0 if x == 0; 1 if x > 0`.
712
713  For complex numbers, `y = sign(x) = x / |x| if x != 0, otherwise y = 0`.
714
715  Example usage:
716
717  >>> # real number
718  >>> tf.math.sign([0., 2., -3.])
719  <tf.Tensor: shape=(3,), dtype=float32,
720  numpy=array([ 0.,  1., -1.], dtype=float32)>
721
722  >>> # complex number
723  >>> tf.math.sign([1 + 1j, 0 + 0j])
724  <tf.Tensor: shape=(2,), dtype=complex128,
725  numpy=array([0.70710678+0.70710678j, 0.        +0.j        ])>
726
727  Args:
728   x: A Tensor. Must be one of the following types: bfloat16, half, float32,
729     float64, int32, int64, complex64, complex128.
730   name: A name for the operation (optional).
731
732  Returns:
733   A Tensor. Has the same type as x.
734
735   If x is a SparseTensor, returns SparseTensor(x.indices,
736     tf.math.sign(x.values, ...), x.dense_shape).
737  """
738  x = ops.convert_to_tensor(x)
739  if x.dtype.is_complex:
740    return gen_math_ops.div_no_nan(
741        x,
742        cast(
743            gen_math_ops.complex_abs(
744                x,
745                Tout=dtypes.float32
746                if x.dtype == dtypes.complex64 else dtypes.float64),
747            dtype=x.dtype),
748        name=name)
749  return gen_math_ops.sign(x, name=name)
750
751
752@tf_export("math.real", v1=["math.real", "real"])
753@dispatch.add_dispatch_support
754@deprecation.deprecated_endpoints("real")
755@dispatch.add_dispatch_support
756def real(input, name=None):
757  r"""Returns the real part of a complex (or real) tensor.
758
759  Given a tensor `input`, this operation returns a tensor of type `float` that
760  is the real part of each element in `input` considered as a complex number.
761
762  For example:
763
764  ```python
765  x = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j])
766  tf.math.real(x)  # [-2.25, 3.25]
767  ```
768
769  If `input` is already real, it is returned unchanged.
770
771  Args:
772    input: A `Tensor`. Must have numeric type.
773    name: A name for the operation (optional).
774
775  Returns:
776    A `Tensor` of type `float32` or `float64`.
777  """
778  with ops.name_scope(name, "Real", [input]) as name:
779    input = ops.convert_to_tensor(input, name="input")
780    if input.dtype.is_complex:
781      real_dtype = input.dtype.real_dtype
782      return gen_math_ops.real(input, Tout=real_dtype, name=name)
783    else:
784      return input
785
786
787@tf_export("math.imag", v1=["math.imag", "imag"])
788@dispatch.add_dispatch_support
789@deprecation.deprecated_endpoints("imag")
790@dispatch.add_dispatch_support
791def imag(input, name=None):
792  r"""Returns the imaginary part of a complex (or real) tensor.
793
794  Given a tensor `input`, this operation returns a tensor of type `float` that
795  is the imaginary part of each element in `input` considered as a complex
796  number. If `input` is real, a tensor of all zeros is returned.
797
798  For example:
799
800  ```python
801  x = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j])
802  tf.math.imag(x)  # [4.75, 5.75]
803  ```
804
805  Args:
806    input: A `Tensor`. Must be one of the following types: `float`, `double`,
807      `complex64`, `complex128`.
808    name: A name for the operation (optional).
809
810  Returns:
811    A `Tensor` of type `float32` or `float64`.
812  """
813  with ops.name_scope(name, "Imag", [input]) as name:
814    input = ops.convert_to_tensor(input, name="input")
815    if input.dtype.is_complex:
816      return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name)
817    else:
818      return array_ops.zeros_like(input)
819
820
821@tf_export("math.angle", v1=["math.angle", "angle"])
822@dispatch.add_dispatch_support
823@deprecation.deprecated_endpoints("angle")
824@dispatch.add_dispatch_support
825def angle(input, name=None):
826  r"""Returns the element-wise argument of a complex (or real) tensor.
827
828  Given a tensor `input`, this operation returns a tensor of type `float` that
829  is the argument of each element in `input` considered as a complex number.
830
831  The elements in `input` are considered to be complex numbers of the form
832  \\(a + bj\\), where *a* is the real part and *b* is the imaginary part.
833  If `input` is real then *b* is zero by definition.
834
835  The argument returned by this function is of the form \\(atan2(b, a)\\).
836  If `input` is real, a tensor of all zeros is returned.
837
838  For example:
839
840  ```
841  input = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j], dtype=tf.complex64)
842  tf.math.angle(input).numpy()
843  # ==> array([2.0131705, 1.056345 ], dtype=float32)
844  ```
845
846  Args:
847    input: A `Tensor`. Must be one of the following types: `float`, `double`,
848      `complex64`, `complex128`.
849    name: A name for the operation (optional).
850
851  Returns:
852    A `Tensor` of type `float32` or `float64`.
853  """
854  with ops.name_scope(name, "Angle", [input]) as name:
855    input = ops.convert_to_tensor(input, name="input")
856    if input.dtype.is_complex:
857      return gen_math_ops.angle(input, Tout=input.dtype.real_dtype, name=name)
858    else:
859      return array_ops.where(input < 0, np.pi * array_ops.ones_like(input),
860                             array_ops.zeros_like(input))
861
862
863# pylint: enable=redefined-outer-name,redefined-builtin
864
865
866@tf_export("math.round", "round")
867@dispatch.add_dispatch_support
868def round(x, name=None):  # pylint: disable=redefined-builtin
869  """Rounds the values of a tensor to the nearest integer, element-wise.
870
871  Rounds half to even.  Also known as bankers rounding. If you want to round
872  according to the current system rounding mode use tf::cint.
873  For example:
874
875  ```python
876  x = tf.constant([0.9, 2.5, 2.3, 1.5, -4.5])
877  tf.round(x)  # [ 1.0, 2.0, 2.0, 2.0, -4.0 ]
878  ```
879
880  Args:
881    x: A `Tensor` of type `float16`, `float32`, `float64`, `int32`, or `int64`.
882    name: A name for the operation (optional).
883
884  Returns:
885    A `Tensor` of same shape and type as `x`.
886  """
887  x = ops.convert_to_tensor(x, name="x")
888  if x.dtype.is_integer:
889    return x
890  else:
891    return gen_math_ops.round(x, name=name)
892
893
894@tf_export("cast", "dtypes.cast")
895@dispatch.add_dispatch_support
896def cast(x, dtype, name=None):
897  """Casts a tensor to a new type.
898
899  The operation casts `x` (in case of `Tensor`) or `x.values`
900  (in case of `SparseTensor` or `IndexedSlices`) to `dtype`.
901
902  For example:
903
904  >>> x = tf.constant([1.8, 2.2], dtype=tf.float32)
905  >>> tf.cast(x, tf.int32)
906  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2], dtype=int32)>
907
908  Notice `tf.cast` has an alias `tf.dtypes.cast`:
909
910  >>> x = tf.constant([1.8, 2.2], dtype=tf.float32)
911  >>> tf.dtypes.cast(x, tf.int32)
912  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2], dtype=int32)>
913
914  The operation supports data types (for `x` and `dtype`) of
915  `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
916  `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
917  In case of casting from complex types (`complex64`, `complex128`) to real
918  types, only the real part of `x` is returned. In case of casting from real
919  types to complex types (`complex64`, `complex128`), the imaginary part of the
920  returned value is set to `0`. The handling of complex types here matches the
921  behavior of numpy.
922
923  Note casting nan and inf values to integral types has undefined behavior.
924
925  Args:
926    x: A `Tensor` or `SparseTensor` or `IndexedSlices` of numeric type. It could
927      be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`,
928      `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`,
929      `bfloat16`.
930    dtype: The destination type. The list of supported dtypes is the same as
931      `x`.
932    name: A name for the operation (optional).
933
934  Returns:
935    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` and
936      same type as `dtype`.
937
938  Raises:
939    TypeError: If `x` cannot be cast to the `dtype`.
940  """
941  base_type = dtypes.as_dtype(dtype).base_dtype
942  if isinstance(x,
943                (ops.Tensor, _resource_variable_type)) and base_type == x.dtype:
944    return x
945  with ops.name_scope(name, "Cast", [x]) as name:
946    if isinstance(x, sparse_tensor.SparseTensor):
947      values_cast = cast(x.values, base_type, name=name)
948      x = sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape)
949    elif isinstance(x, ops.IndexedSlices):
950      values_cast = cast(x.values, base_type, name=name)
951      x = ops.IndexedSlices(values_cast, x.indices, x.dense_shape)
952    else:
953      # TODO(josh11b): If x is not already a Tensor, we could return
954      # ops.convert_to_tensor(x, dtype=dtype, ...)  here, but that
955      # allows some conversions that cast() can't do, e.g. casting numbers to
956      # strings.
957      x = ops.convert_to_tensor(x, name="x")
958      if x.dtype.base_dtype != base_type:
959        x = gen_math_ops.cast(x, base_type, name=name)
960    if x.dtype.is_complex and base_type.is_floating:
961      logging.warn("Casting complex to real discards imaginary part.")
962    return x
963
964
965@tf_export("dtypes.saturate_cast", "saturate_cast")
966@dispatch.add_dispatch_support
967def saturate_cast(value, dtype, name=None):
968  """Performs a safe saturating cast of `value` to `dtype`.
969
970  This function casts the input to `dtype` without applying any scaling.  If
971  there is a danger that values would over or underflow in the cast, this op
972  applies the appropriate clamping before the cast.
973
974  Args:
975    value: A `Tensor`.
976    dtype: The desired output `DType`.
977    name: A name for the operation (optional).
978
979  Returns:
980    `value` safely cast to `dtype`.
981  """
982  # When casting to a type with smaller representable range, clamp.
983  # Note that this covers casting to unsigned types as well.
984  with ops.name_scope(name, "saturate_cast", [value]) as name:
985    value = ops.convert_to_tensor(value, name="value")
986    dtype = dtypes.as_dtype(dtype).base_dtype
987    if value.dtype.min < dtype.min:
988      value = gen_math_ops.maximum(
989          value,
990          ops.convert_to_tensor(dtype.min, dtype=value.dtype, name="min"))
991    if value.dtype.max > dtype.max:
992      value = gen_math_ops.minimum(
993          value,
994          ops.convert_to_tensor(dtype.max, dtype=value.dtype, name="max"))
995    return cast(value, dtype, name=name)
996
997
998@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
999@tf_export(v1=["to_float"])
1000@dispatch.add_dispatch_support
1001def to_float(x, name="ToFloat"):
1002  """Casts a tensor to type `float32`.
1003
1004  Args:
1005    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
1006    name: A name for the operation (optional).
1007
1008  Returns:
1009    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
1010    type `float32`.
1011
1012  Raises:
1013    TypeError: If `x` cannot be cast to the `float32`.
1014  """
1015  return cast(x, dtypes.float32, name=name)
1016
1017
1018@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
1019@tf_export(v1=["to_double"])
1020@dispatch.add_dispatch_support
1021def to_double(x, name="ToDouble"):
1022  """Casts a tensor to type `float64`.
1023
1024  Args:
1025    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
1026    name: A name for the operation (optional).
1027
1028  Returns:
1029    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
1030    type `float64`.
1031
1032  Raises:
1033    TypeError: If `x` cannot be cast to the `float64`.
1034  """
1035  return cast(x, dtypes.float64, name=name)
1036
1037
1038@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
1039@tf_export(v1=["to_int32"])
1040@dispatch.add_dispatch_support
1041def to_int32(x, name="ToInt32"):
1042  """Casts a tensor to type `int32`.
1043
1044  Args:
1045    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
1046    name: A name for the operation (optional).
1047
1048  Returns:
1049    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
1050    type `int32`.
1051
1052  Raises:
1053    TypeError: If `x` cannot be cast to the `int32`.
1054  """
1055  return cast(x, dtypes.int32, name=name)
1056
1057
1058@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
1059@tf_export(v1=["to_int64"])
1060@dispatch.add_dispatch_support
1061def to_int64(x, name="ToInt64"):
1062  """Casts a tensor to type `int64`.
1063
1064  Args:
1065    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
1066    name: A name for the operation (optional).
1067
1068  Returns:
1069    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
1070    type `int64`.
1071
1072  Raises:
1073    TypeError: If `x` cannot be cast to the `int64`.
1074  """
1075  return cast(x, dtypes.int64, name=name)
1076
1077
1078@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
1079@tf_export(v1=["to_bfloat16"])
1080@dispatch.add_dispatch_support
1081def to_bfloat16(x, name="ToBFloat16"):
1082  """Casts a tensor to type `bfloat16`.
1083
1084  Args:
1085    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
1086    name: A name for the operation (optional).
1087
1088  Returns:
1089    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
1090    type `bfloat16`.
1091
1092  Raises:
1093    TypeError: If `x` cannot be cast to the `bfloat16`.
1094  """
1095  return cast(x, dtypes.bfloat16, name=name)
1096
1097
1098@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
1099@tf_export(v1=["to_complex64"])
1100@dispatch.add_dispatch_support
1101def to_complex64(x, name="ToComplex64"):
1102  """Casts a tensor to type `complex64`.
1103
1104  Args:
1105    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
1106    name: A name for the operation (optional).
1107
1108  Returns:
1109    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
1110    type `complex64`.
1111
1112  Raises:
1113    TypeError: If `x` cannot be cast to the `complex64`.
1114  """
1115  return cast(x, dtypes.complex64, name=name)
1116
1117
1118@deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.")
1119@tf_export(v1=["to_complex128"])
1120@dispatch.add_dispatch_support
1121def to_complex128(x, name="ToComplex128"):
1122  """Casts a tensor to type `complex128`.
1123
1124  Args:
1125    x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
1126    name: A name for the operation (optional).
1127
1128  Returns:
1129    A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
1130    type `complex128`.
1131
1132  Raises:
1133    TypeError: If `x` cannot be cast to the `complex128`.
1134  """
1135  return cast(x, dtypes.complex128, name=name)
1136
1137
1138ops.Tensor._override_operator("__neg__", gen_math_ops.neg)
1139ops.Tensor._override_operator("__abs__", abs)
1140
1141
1142def _maybe_get_dtype(x):
1143  """Returns a numpy type if available from x. Skips if x is numpy.ndarray."""
1144  # Don't put np.ndarray in this list, because np.result_type looks at the
1145  # value (not just dtype) of np.ndarray to decide the result type.
1146  if isinstance(x, numbers.Real):
1147    return x
1148  if isinstance(x, ops.Tensor):
1149    return x.dtype.as_numpy_dtype
1150  if isinstance(x, dtypes.DType):
1151    return x.as_numpy_dtype
1152  if isinstance(x, tensor_shape.TensorShape):
1153    return np.int32
1154  if isinstance(x, (list, tuple)):
1155    raise ValueError("Got sequence {}".format(x))
1156  return x
1157
1158
1159def maybe_promote_tensors(*tensors, force_same_dtype=True):
1160  """Promote tensors if numpy style promotion is enabled."""
1161  if not tensors:
1162    return tensors
1163  if not ops._numpy_style_type_promotion:
1164    if not force_same_dtype:
1165      return tensors
1166    promoted_tensors = []
1167    promoted_tensors.append(tensors[0])
1168    dtype = tensors[0].dtype.base_dtype
1169    for tensor in tensors[1:]:
1170      promoted_tensors.append(
1171          ops.convert_to_tensor(tensor, dtype, name="x"))
1172    return promoted_tensors
1173  result_type = np_dtypes._result_type(
1174      *[_maybe_get_dtype(x) for x in nest.flatten(tensors)])
1175  def _promote_or_cast(x):
1176    if isinstance(x, ops.Tensor):
1177      x = cast(x, result_type)
1178    else:
1179      x = ops.convert_to_tensor(x, result_type)
1180    return x
1181  return [_promote_or_cast(x) for x in tensors]
1182
1183
1184def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
1185  """Register operators with different tensor and scalar versions.
1186
1187  If `clazz_object` is `SparseTensor`, assumes `func` takes `(sp_indices,
1188  sp_values, sp_shape, dense)` and outputs `(new_sp_values)`.
1189
1190  Args:
1191    func: the operator
1192    op_name: name of the operator being overridden
1193    clazz_object: class to override for.  Either `Tensor` or `SparseTensor`.
1194  """
1195
1196  def binary_op_wrapper(x, y):
1197    with ops.name_scope(None, op_name, [x, y]) as name:
1198      try:
1199        # force_same_dtype=False to preserve existing TF behavior
1200        # TODO(b/178860388): Figure out why binary_op_wrapper and
1201        #   r_binary_op_wrapper use different force_same_dtype values.
1202        x, y = maybe_promote_tensors(x, y, force_same_dtype=False)
1203        return func(x, y, name=name)
1204      except (TypeError, ValueError) as e:
1205        # Even if dispatching the op failed, the RHS may be a tensor aware
1206        # object that can implement the operator with knowledge of itself
1207        # and the tensor.
1208        # If the RHS is not tensor aware we still want to raise the
1209        # original error from the LHS, because it may be more
1210        # informative.
1211        if hasattr(type(y), "__r%s__" % op_name):
1212          try:
1213            r_op = getattr(y, "__r%s__" % op_name)
1214            out = r_op(x)
1215            if out is NotImplemented:
1216              raise
1217            return out
1218          except (TypeError, ValueError):
1219            raise e
1220        else:
1221          raise
1222
1223  def binary_op_wrapper_sparse(sp_x, y):
1224    with ops.name_scope(None, op_name, [sp_x, y]) as name:
1225      y = ops.convert_to_tensor(y, dtype=sp_x.dtype.base_dtype, name="y")
1226      return sparse_tensor.SparseTensor(
1227          sp_x.indices,
1228          func(sp_x.indices, sp_x.values, sp_x.dense_shape, y, name=name),
1229          sp_x.dense_shape)
1230
1231  def r_binary_op_wrapper(y, x):
1232    with ops.name_scope(None, op_name, [x, y]) as name:
1233      # TODO(b/178860388): Figure out why binary_op_wrapper and
1234      #   r_binary_op_wrapper use different force_same_dtype values.
1235      y, x = maybe_promote_tensors(y, x)
1236      return func(x, y, name=name)
1237
1238  # Propagate func.__doc__ to the wrappers
1239  try:
1240    doc = func.__doc__
1241  except AttributeError:
1242    doc = None
1243  binary_op_wrapper.__doc__ = doc
1244  r_binary_op_wrapper.__doc__ = doc
1245  binary_op_wrapper_sparse.__doc__ = doc
1246
1247  if clazz_object is ops.Tensor:
1248    clazz_object._override_operator("__%s__" % op_name, binary_op_wrapper)
1249    del binary_op_wrapper
1250    clazz_object._override_operator("__r%s__" % op_name, r_binary_op_wrapper)
1251    del r_binary_op_wrapper
1252  else:
1253    clazz_object._override_operator("__%s__" % op_name,
1254                                    binary_op_wrapper_sparse)
1255    del binary_op_wrapper_sparse
1256
1257
1258# Conversion table for __truediv__.  None entries mean no conversion required.
1259_TRUEDIV_TABLE = {
1260    dtypes.uint8: dtypes.float32,
1261    dtypes.int8: dtypes.float32,
1262    dtypes.uint16: dtypes.float32,
1263    dtypes.int16: dtypes.float32,
1264    dtypes.int32: dtypes.float64,
1265    dtypes.int64: dtypes.float64,
1266    dtypes.bfloat16: None,
1267    dtypes.float16: None,
1268    dtypes.float32: None,
1269    dtypes.float64: None,
1270    dtypes.complex64: None,
1271    dtypes.complex128: None,
1272}
1273
1274
1275# NOTE: the support of "sparse (true)div dense" is currently not baked in into
1276# "tf.(true_)div()".  Until such an API decision is made, the supported usage is
1277# to explicitly use the "/" operator to invoke either truediv or div.
1278def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None):
1279  """Internal helper function for 'sp_t / dense_t'."""
1280  with ops.name_scope(name, "truediv",
1281                      [sp_indices, sp_values, sp_shape, y]) as name:
1282    sp_values = ops.convert_to_tensor(sp_values, name="sp_values")
1283    y = ops.convert_to_tensor(y, name="y")
1284    x_dtype = sp_values.dtype.base_dtype
1285    y_dtype = y.dtype.base_dtype
1286    if x_dtype != y_dtype:
1287      raise TypeError("x and y must have the same dtype, got %r != %r" %
1288                      (x_dtype, y_dtype))
1289    try:
1290      dtype = _TRUEDIV_TABLE[x_dtype]
1291    except KeyError:
1292      raise TypeError("Invalid dtype %r in __truediv__" % x_dtype)
1293    if dtype is not None:
1294      sp_values = cast(sp_values, dtype)
1295      y = cast(y, dtype)
1296    return gen_sparse_ops.sparse_dense_cwise_div(
1297        sp_indices, sp_values, sp_shape, y, name=name)
1298
1299
1300def _truediv_python3(x, y, name=None):
1301  with ops.name_scope(name, "truediv", [x, y]) as name:
1302    x = ops.convert_to_tensor(x, name="x")
1303    y = ops.convert_to_tensor(y, dtype_hint=x.dtype.base_dtype, name="y")
1304    x_dtype = x.dtype.base_dtype
1305    y_dtype = y.dtype.base_dtype
1306    if x_dtype != y_dtype:
1307      raise TypeError("x and y must have the same dtype, got %r != %r" %
1308                      (x_dtype, y_dtype))
1309    try:
1310      dtype = _TRUEDIV_TABLE[x_dtype]
1311    except KeyError:
1312      raise TypeError("Invalid dtype %r in __truediv__" % x_dtype)
1313    if dtype is not None:
1314      x = cast(x, dtype)
1315      y = cast(y, dtype)
1316    return gen_math_ops.real_div(x, y, name=name)
1317
1318
1319def _div_python2(x, y, name=None):
1320  """Divide two values using Python 2 semantics.
1321
1322  Used for Tensor.__div__.
1323
1324  Args:
1325    x: `Tensor` numerator of real numeric type.
1326    y: `Tensor` denominator of real numeric type.
1327    name: A name for the operation (optional).
1328
1329  Returns:
1330    `x / y` returns the quotient of x and y.
1331  """
1332
1333  with ops.name_scope(name, "div", [x, y]) as name:
1334    x = ops.convert_to_tensor(x, name="x")
1335    y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
1336    x_dtype = x.dtype.base_dtype
1337    y_dtype = y.dtype.base_dtype
1338    if x_dtype != y_dtype:
1339      raise TypeError("x and y must have the same dtype, got %r != %r" %
1340                      (x_dtype, y_dtype))
1341    if x_dtype.is_floating or x_dtype.is_complex:
1342      return gen_math_ops.real_div(x, y, name=name)
1343    else:
1344      return gen_math_ops.floor_div(x, y, name=name)
1345
1346
1347@tf_export("math.truediv", "truediv")
1348@dispatch.add_dispatch_support
1349def truediv(x, y, name=None):
1350  """Divides x / y elementwise (using Python 3 division operator semantics).
1351
1352  NOTE: Prefer using the Tensor operator or tf.divide which obey Python
1353  division operator semantics.
1354
1355  This function forces Python 3 division operator semantics where all integer
1356  arguments are cast to floating types first.   This op is generated by normal
1357  `x / y` division in Python 3 and in Python 2.7 with
1358  `from __future__ import division`.  If you want integer division that rounds
1359  down, use `x // y` or `tf.math.floordiv`.
1360
1361  `x` and `y` must have the same numeric type.  If the inputs are floating
1362  point, the output will have the same type.  If the inputs are integral, the
1363  inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
1364  and `int64` (matching the behavior of Numpy).
1365
1366  Args:
1367    x: `Tensor` numerator of numeric type.
1368    y: `Tensor` denominator of numeric type.
1369    name: A name for the operation (optional).
1370
1371  Returns:
1372    `x / y` evaluated in floating point.
1373
1374  Raises:
1375    TypeError: If `x` and `y` have different dtypes.
1376  """
1377  return _truediv_python3(x, y, name)
1378
1379
1380@deprecation.deprecated(
1381    date=None,
1382    instructions="Deprecated in favor of operator or tf.math.divide.")
1383@tf_export(v1=["div"])
1384@dispatch.add_dispatch_support
1385def div(x, y, name=None):
1386  """Divides x / y elementwise (using Python 2 division operator semantics).
1387
1388  NOTE: Prefer using the Tensor division operator or tf.divide which obey Python
1389  3 division operator semantics.
1390
1391  This function divides `x` and `y`, forcing Python 2 semantics. That is, if `x`
1392  and `y` are both integers then the result will be an integer. This is in
1393  contrast to Python 3, where division with `/` is always a float while division
1394  with `//` is always an integer.
1395
1396  Args:
1397    x: `Tensor` numerator of real numeric type.
1398    y: `Tensor` denominator of real numeric type.
1399    name: A name for the operation (optional).
1400
1401  Returns:
1402    `x / y` returns the quotient of x and y.
1403  """
1404  return _div_python2(x, y, name)
1405
1406
1407@tf_export("math.divide_no_nan", v1=["math.divide_no_nan", "div_no_nan"])
1408@dispatch.add_dispatch_support
1409@deprecation.deprecated_endpoints("div_no_nan")
1410@dispatch.add_dispatch_support
1411def div_no_nan(x, y, name=None):
1412  """Computes a safe divide which returns 0 if the y is zero.
1413
1414  Args:
1415    x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
1416    y: A `Tensor` whose dtype is compatible with `x`.
1417    name: A name for the operation (optional).
1418
1419  Returns:
1420    The element-wise value of the x divided by y.
1421  """
1422
1423  with ops.name_scope(name, "div_no_nan", [x, y]) as name:
1424    x = ops.convert_to_tensor(x, name="x")
1425    y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
1426    return gen_math_ops.div_no_nan(x, y, name=name)
1427
1428
1429@tf_export("math.multiply_no_nan")
1430@dispatch.add_dispatch_support
1431def multiply_no_nan(x, y, name=None):
1432  """Computes the product of x and y and returns 0 if the y is zero, even if x is NaN or infinite.
1433
1434  Args:
1435    x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
1436    y: A `Tensor` whose dtype is compatible with `x`.
1437    name: A name for the operation (optional).
1438
1439  Returns:
1440    The element-wise value of the x times y.
1441  """
1442
1443  with ops.name_scope(name, "multiply_no_nan", [x, y]) as name:
1444    x = ops.convert_to_tensor(x, name="x")
1445    y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
1446    x_dtype = x.dtype.base_dtype
1447    y_dtype = y.dtype.base_dtype
1448    if x_dtype != y_dtype:
1449      raise TypeError("x and y must have the same dtype, got %r != %r" %
1450                      (x_dtype, y_dtype))
1451    return gen_math_ops.mul_no_nan(x, y, name=name)
1452
1453
1454# TODO(aselle): This should be removed
1455mod = gen_math_ops.floor_mod
1456
1457
1458# TODO(aselle): Deprecate this once all internal functionality uses
1459# tf.truncatediv
1460@tf_export("math.floordiv", v1=["math.floordiv", "floordiv"])
1461@dispatch.add_dispatch_support
1462@deprecation.deprecated_endpoints("floordiv")
1463def floordiv(x, y, name=None):
1464  """Divides `x / y` elementwise, rounding toward the most negative integer.
1465
1466  The same as `tf.compat.v1.div(x,y)` for integers, but uses
1467  `tf.floor(tf.compat.v1.div(x,y))` for
1468  floating point arguments so that the result is always an integer (though
1469  possibly an integer represented as floating point).  This op is generated by
1470  `x // y` floor division in Python 3 and in Python 2.7 with
1471  `from __future__ import division`.
1472
1473  `x` and `y` must have the same type, and the result will have the same type
1474  as well.
1475
1476  Args:
1477    x: `Tensor` numerator of real numeric type.
1478    y: `Tensor` denominator of real numeric type.
1479    name: A name for the operation (optional).
1480
1481  Returns:
1482    `x / y` rounded down.
1483
1484  Raises:
1485    TypeError: If the inputs are complex.
1486  """
1487  with ops.name_scope(name, "floordiv", [x, y]) as name:
1488    return gen_math_ops.floor_div(x, y, name=name)
1489
1490
1491realdiv = gen_math_ops.real_div
1492truncatediv = gen_math_ops.truncate_div
1493# TODO(aselle): Rename this to floordiv when we can.
1494floor_div = gen_math_ops.floor_div
1495truncatemod = gen_math_ops.truncate_mod
1496floormod = gen_math_ops.floor_mod
1497
1498
1499@tf_export("__operators__.add", v1=[])
1500@dispatch.add_dispatch_support
1501def _add_dispatch(x, y, name=None):
1502  """The operation invoked by the `Tensor.__add__` operator.
1503
1504    Purpose in the API:
1505
1506      This method is exposed in TensorFlow's API so that library developers
1507      can register dispatching for `Tensor.__add__` to allow it to handle
1508      custom composite tensors & other custom objects.
1509
1510      The API symbol is not intended to be called by users directly and does
1511      appear in TensorFlow's generated documentation.
1512
1513  Args:
1514    x: The left-hand side of the `+` operator.
1515    y: The right-hand side of the `+` operator.
1516    name: an optional name for the operation.
1517
1518  Returns:
1519    The result of the elementwise `+` operation.
1520  """
1521  if not isinstance(y, ops.Tensor) and not isinstance(
1522      y, sparse_tensor.SparseTensor):
1523    y = ops.convert_to_tensor(y, dtype_hint=x.dtype.base_dtype, name="y")
1524  if x.dtype == dtypes.string:
1525    return gen_math_ops.add(x, y, name=name)
1526  else:
1527    return gen_math_ops.add_v2(x, y, name=name)
1528
1529
1530def _mul_dispatch(x, y, name=None):
1531  """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse"."""
1532  if isinstance(y, sparse_tensor.SparseTensor):  # Case: Dense * Sparse.
1533    new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values,
1534                                                     y.dense_shape, x, name)
1535    return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
1536  else:
1537    return multiply(x, y, name=name)
1538
1539
1540# NOTE(aselle): When integer division is added for sparse_dense_cwise,
1541# div, truediv, and floordiv should be delegated appropriately for
1542# Python semantics, analogous to dense cwise tensor operations.
1543_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_div, "div",
1544                              sparse_tensor.SparseTensor)
1545_OverrideBinaryOperatorHelper(_sparse_dense_truediv, "truediv",
1546                              sparse_tensor.SparseTensor)
1547_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul",
1548                              sparse_tensor.SparseTensor)
1549
1550_OverrideBinaryOperatorHelper(_add_dispatch, "add")
1551_OverrideBinaryOperatorHelper(subtract, "sub")
1552_OverrideBinaryOperatorHelper(_mul_dispatch, "mul")
1553_OverrideBinaryOperatorHelper(div, "div")
1554_OverrideBinaryOperatorHelper(truediv, "truediv")
1555_OverrideBinaryOperatorHelper(floordiv, "floordiv")
1556_OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod")
1557_OverrideBinaryOperatorHelper(pow, "pow")
1558
1559
1560@tf_export("math.logical_xor", v1=["math.logical_xor", "logical_xor"])
1561@dispatch.add_dispatch_support
1562@deprecation.deprecated_endpoints("logical_xor")
1563def logical_xor(x, y, name="LogicalXor"):
1564  """Logical XOR function.
1565
1566  x ^ y = (x | y) & ~(x & y)
1567
1568  Requires that `x` and `y` have the same shape or have
1569  [broadcast-compatible](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
1570  shapes. For example, `x` and `y` can be:
1571
1572  - Two single elements of type `bool`
1573  - One `tf.Tensor` of type `bool` and one single `bool`, where the result will
1574    be calculated by applying logical XOR with the single element to each
1575    element in the larger Tensor.
1576  - Two `tf.Tensor` objects of type `bool` of the same shape. In this case,
1577    the result will be the element-wise logical XOR of the two input tensors.
1578
1579  Usage:
1580
1581  >>> a = tf.constant([True])
1582  >>> b = tf.constant([False])
1583  >>> tf.math.logical_xor(a, b)
1584  <tf.Tensor: shape=(1,), dtype=bool, numpy=array([ True])>
1585
1586  >>> c = tf.constant([True])
1587  >>> x = tf.constant([False, True, True, False])
1588  >>> tf.math.logical_xor(c, x)
1589  <tf.Tensor: shape=(4,), dtype=bool, numpy=array([ True, False, False,  True])>
1590
1591  >>> y = tf.constant([False, False, True, True])
1592  >>> z = tf.constant([False, True, False, True])
1593  >>> tf.math.logical_xor(y, z)
1594  <tf.Tensor: shape=(4,), dtype=bool, numpy=array([False,  True,  True, False])>
1595
1596  Args:
1597      x: A `tf.Tensor` type bool.
1598      y: A `tf.Tensor` of type bool.
1599      name: A name for the operation (optional).
1600
1601  Returns:
1602    A `tf.Tensor` of type bool with the same size as that of x or y.
1603  """
1604  # TODO(alemi) Make this a cwise op if people end up relying on it.
1605  return gen_math_ops.logical_and(
1606      gen_math_ops.logical_or(x, y),
1607      gen_math_ops.logical_not(gen_math_ops.logical_and(x, y)),
1608      name=name)
1609
1610
1611def and_(x, y, name=None):
1612  if x.dtype == dtypes.bool:
1613    return gen_math_ops.logical_and(x, y, name)
1614  return gen_bitwise_ops.bitwise_and(x, y)
1615
1616
1617def or_(x, y, name=None):
1618  if x.dtype == dtypes.bool:
1619    return gen_math_ops.logical_or(x, y, name)
1620  return gen_bitwise_ops.bitwise_or(x, y)
1621
1622
1623def xor_(x, y, name=None):
1624  if x.dtype == dtypes.bool:
1625    return logical_xor(x, y, name)
1626  return gen_bitwise_ops.bitwise_xor(x, y)
1627
1628
1629def invert_(x, name=None):
1630  if x.dtype == dtypes.bool:
1631    return gen_math_ops.logical_not(x, name=name)
1632  return gen_bitwise_ops.invert(x, name=name)
1633
1634
1635_OverrideBinaryOperatorHelper(and_, "and")
1636_OverrideBinaryOperatorHelper(or_, "or")
1637_OverrideBinaryOperatorHelper(xor_, "xor")
1638ops.Tensor._override_operator("__invert__", invert_)
1639
1640
1641def _promote_dtypes_decorator(fn):
1642  def wrapper(x, y, *args, **kwargs):
1643    x, y = maybe_promote_tensors(x, y, force_same_dtype=False)
1644    return fn(x, y, *args, **kwargs)
1645  return tf_decorator.make_decorator(fn, wrapper)
1646
1647
1648ops.Tensor._override_operator("__lt__", _promote_dtypes_decorator(
1649    gen_math_ops.less))
1650ops.Tensor._override_operator("__le__", _promote_dtypes_decorator(
1651    gen_math_ops.less_equal))
1652ops.Tensor._override_operator("__gt__", _promote_dtypes_decorator(
1653    gen_math_ops.greater))
1654ops.Tensor._override_operator("__ge__", _promote_dtypes_decorator(
1655    gen_math_ops.greater_equal))
1656
1657
1658@tf_export("math.equal", "equal")
1659@dispatch.add_dispatch_support
1660def equal(x, y, name=None):
1661  """Returns the truth value of (x == y) element-wise.
1662
1663  Performs a [broadcast](
1664  https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) with the
1665  arguments and then an element-wise equality comparison, returning a Tensor of
1666  boolean values.
1667
1668  For example:
1669
1670  >>> x = tf.constant([2, 4])
1671  >>> y = tf.constant(2)
1672  >>> tf.math.equal(x, y)
1673  <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True,  False])>
1674
1675  >>> x = tf.constant([2, 4])
1676  >>> y = tf.constant([2, 4])
1677  >>> tf.math.equal(x, y)
1678  <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True,  True])>
1679
1680  Args:
1681    x: A `tf.Tensor` or `tf.sparse.SparseTensor` or `tf.IndexedSlices`.
1682    y: A `tf.Tensor` or `tf.sparse.SparseTensor` or `tf.IndexedSlices`.
1683    name: A name for the operation (optional).
1684
1685  Returns:
1686    A `tf.Tensor` of type bool with the same size as that of x or y.
1687
1688  Raises:
1689    `tf.errors.InvalidArgumentError`: If shapes of arguments are incompatible
1690  """
1691  return gen_math_ops.equal(x, y, name=name)
1692
1693
1694@tf_export("math.not_equal", "not_equal")
1695@dispatch.add_dispatch_support
1696def not_equal(x, y, name=None):
1697  """Returns the truth value of (x != y) element-wise.
1698
1699  Performs a [broadcast](
1700  https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) with the
1701  arguments and then an element-wise inequality comparison, returning a Tensor
1702  of boolean values.
1703
1704  For example:
1705
1706  >>> x = tf.constant([2, 4])
1707  >>> y = tf.constant(2)
1708  >>> tf.math.not_equal(x, y)
1709  <tf.Tensor: shape=(2,), dtype=bool, numpy=array([False,  True])>
1710
1711  >>> x = tf.constant([2, 4])
1712  >>> y = tf.constant([2, 4])
1713  >>> tf.math.not_equal(x, y)
1714  <tf.Tensor: shape=(2,), dtype=bool, numpy=array([False,  False])>
1715
1716  Args:
1717    x: A `tf.Tensor` or `tf.sparse.SparseTensor` or `tf.IndexedSlices`.
1718    y: A `tf.Tensor` or `tf.sparse.SparseTensor` or `tf.IndexedSlices`.
1719    name: A name for the operation (optional).
1720
1721  Returns:
1722    A `tf.Tensor` of type bool with the same size as that of x or y.
1723
1724  Raises:
1725    `tf.errors.InvalidArgumentError`: If shapes of arguments are incompatible
1726  """
1727  return gen_math_ops.not_equal(x, y, name=name)
1728
1729
1730@tf_export("__operators__.eq", v1=[])
1731@dispatch.add_dispatch_support
1732def tensor_equals(self, other):
1733  """The operation invoked by the `Tensor.__eq__` operator.
1734
1735  Compares two tensors element-wise for equality if they are
1736  broadcast-compatible; or returns False if they are not broadcast-compatible.
1737  (Note that this behavior differs from `tf.math.equal`, which raises an
1738  exception if the two tensors are not broadcast-compatible.)
1739
1740  Purpose in the API:
1741
1742    This method is exposed in TensorFlow's API so that library developers
1743    can register dispatching for `Tensor.__eq__` to allow it to handle
1744    custom composite tensors & other custom objects.
1745
1746    The API symbol is not intended to be called by users directly and does
1747    appear in TensorFlow's generated documentation.
1748
1749  Args:
1750    self: The left-hand side of the `==` operator.
1751    other: The right-hand side of the `==` operator.
1752
1753  Returns:
1754    The result of the elementwise `==` operation, or `False` if the arguments
1755    are not broadcast-compatible.
1756  """
1757  if other is None:
1758    return False
1759  g = getattr(self, "graph", None)
1760  if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
1761      (g is None or g.building_function)):
1762    self, other = maybe_promote_tensors(self, other)
1763    return gen_math_ops.equal(self, other, incompatible_shape_error=False)
1764  else:
1765    # In legacy graph mode, tensor equality is object equality
1766    return self is other
1767
1768
1769@tf_export("__operators__.ne", v1=[])
1770@dispatch.add_dispatch_support
1771def tensor_not_equals(self, other):
1772  """The operation invoked by the `Tensor.__ne__` operator.
1773
1774  Compares two tensors element-wise for inequality if they are
1775  broadcast-compatible; or returns True if they are not broadcast-compatible.
1776  (Note that this behavior differs from `tf.math.not_equal`, which raises an
1777  exception if the two tensors are not broadcast-compatible.)
1778
1779  Purpose in the API:
1780
1781    This method is exposed in TensorFlow's API so that library developers
1782    can register dispatching for `Tensor.__ne__` to allow it to handle
1783    custom composite tensors & other custom objects.
1784
1785    The API symbol is not intended to be called by users directly and does
1786    appear in TensorFlow's generated documentation.
1787
1788  Args:
1789    self: The left-hand side of the `!=` operator.
1790    other: The right-hand side of the `!=` operator.
1791
1792  Returns:
1793    The result of the elementwise `!=` operation, or `True` if the arguments
1794    are not broadcast-compatible.
1795  """
1796  if other is None:
1797    return True
1798  if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():
1799    self, other = maybe_promote_tensors(self, other)
1800    return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
1801  else:
1802    # In legacy graph mode, tensor equality is object equality
1803    return self is not other
1804
1805
1806ops.Tensor._override_operator("__eq__", tensor_equals)
1807ops.Tensor._override_operator("__ne__", tensor_not_equals)
1808
1809
1810@tf_export("range")
1811@dispatch.add_dispatch_support
1812def range(start, limit=None, delta=1, dtype=None, name="range"):  # pylint: disable=redefined-builtin
1813  """Creates a sequence of numbers.
1814
1815  Creates a sequence of numbers that begins at `start` and extends by
1816  increments of `delta` up to but not including `limit`.
1817
1818  The dtype of the resulting tensor is inferred from the inputs unless
1819  it is provided explicitly.
1820
1821  Like the Python builtin `range`, `start` defaults to 0, so that
1822  `range(n) = range(0, n)`.
1823
1824  For example:
1825
1826  >>> start = 3
1827  >>> limit = 18
1828  >>> delta = 3
1829  >>> tf.range(start, limit, delta)
1830  <tf.Tensor: shape=(5,), dtype=int32,
1831  numpy=array([ 3,  6,  9, 12, 15], dtype=int32)>
1832
1833  >>> start = 3
1834  >>> limit = 1
1835  >>> delta = -0.5
1836  >>> tf.range(start, limit, delta)
1837  <tf.Tensor: shape=(4,), dtype=float32,
1838  numpy=array([3. , 2.5, 2. , 1.5], dtype=float32)>
1839
1840  >>> limit = 5
1841  >>> tf.range(limit)
1842  <tf.Tensor: shape=(5,), dtype=int32,
1843  numpy=array([0, 1, 2, 3, 4], dtype=int32)>
1844
1845  Args:
1846    start: A 0-D `Tensor` (scalar). Acts as first entry in the range if `limit`
1847      is not None; otherwise, acts as range limit and first entry defaults to 0.
1848    limit: A 0-D `Tensor` (scalar). Upper limit of sequence, exclusive. If None,
1849      defaults to the value of `start` while the first entry of the range
1850      defaults to 0.
1851    delta: A 0-D `Tensor` (scalar). Number that increments `start`. Defaults to
1852      1.
1853    dtype: The type of the elements of the resulting tensor.
1854    name: A name for the operation. Defaults to "range".
1855
1856  Returns:
1857    An 1-D `Tensor` of type `dtype`.
1858
1859  @compatibility(numpy)
1860  Equivalent to np.arange
1861  @end_compatibility
1862  """
1863  if limit is None:
1864    start, limit = 0, start
1865
1866  with ops.name_scope(name, "Range", [start, limit, delta]) as name:
1867    if not isinstance(start, ops.Tensor):
1868      start = ops.convert_to_tensor(start, dtype=dtype, name="start")
1869    if not isinstance(limit, ops.Tensor):
1870      limit = ops.convert_to_tensor(limit, dtype=dtype, name="limit")
1871    if not isinstance(delta, ops.Tensor):
1872      delta = ops.convert_to_tensor(delta, dtype=dtype, name="delta")
1873
1874    # infer dtype if not explicitly provided
1875    if dtype is None:
1876      dtype_hierarchy = [
1877          dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
1878      ]
1879      assert all(arg.dtype in dtype_hierarchy for arg in [start, limit, delta])
1880      inferred_dtype = max([arg.dtype for arg in [start, limit, delta]],
1881                           key=dtype_hierarchy.index)
1882    else:
1883      inferred_dtype = dtype
1884    # Always try to perform a cast even when start/limit/delta are already
1885    # tensors. This will resolve the case where start/limit/delta's original's
1886    # dtype is different from provided dtype.
1887    start = cast(start, inferred_dtype)
1888    limit = cast(limit, inferred_dtype)
1889    delta = cast(delta, inferred_dtype)
1890
1891    return gen_math_ops._range(start, limit, delta, name=name)
1892
1893
1894def _range_tensor_conversion_function(value, dtype=None, name=None,
1895                                      as_ref=False):
1896  del as_ref
1897  return range(value.start, value.stop, value.step, dtype=dtype, name=name)
1898
1899
1900if not six.PY2:
1901  ops.register_tensor_conversion_function(builtins.range,
1902                                          _range_tensor_conversion_function)
1903
1904# Reduction operations
1905def _ReductionDims(x, axis):  # pylint: disable=invalid-name
1906  """Returns range(0, rank(x)) if axis is None."""
1907  if axis is not None:
1908    return axis
1909  else:
1910    x_rank = None
1911    if isinstance(x, ops.Tensor):
1912      x_rank = x.shape.rank
1913    elif (isinstance(x, sparse_tensor.SparseTensor) and
1914          x.dense_shape.shape.is_fully_defined()):
1915      x_rank = x.dense_shape.shape.dims[0].value  # sparse.dense_shape is 1-D.
1916    # Fast path: avoid creating Rank and Range ops if ndims is known.
1917    if x_rank:
1918      return constant_op.constant(np.arange(x_rank, dtype=np.int32))
1919    else:
1920      # Otherwise, we rely on Range and Rank to do the right thing at run-time.
1921      return range(0, array_ops.rank(x))
1922
1923
1924def _has_fully_defined_shape(tensor):
1925  """Returns true if tensor has a fully defined shape."""
1926  return isinstance(tensor, ops.EagerTensor) or tensor.shape.is_fully_defined()
1927
1928
1929def _may_reduce_to_scalar(keepdims, axis, output):
1930  """Set a reduction's output shape to be a scalar if we are certain."""
1931  if not _has_fully_defined_shape(output) and (not keepdims) and (
1932      axis is None):
1933    output.set_shape(())
1934  return output
1935
1936
1937@tf_export(v1=["math.reduce_sum", "reduce_sum"])
1938@dispatch.add_dispatch_support
1939@deprecation.deprecated_args(None,
1940                             "keep_dims is deprecated, use keepdims instead",
1941                             "keep_dims")
1942def reduce_sum_v1(input_tensor,
1943                  axis=None,
1944                  keepdims=None,
1945                  name=None,
1946                  reduction_indices=None,
1947                  keep_dims=None):
1948  """Computes the sum of elements across dimensions of a tensor.
1949
1950  This is the reduction operation for the elementwise `tf.math.add` op.
1951
1952  Reduces `input_tensor` along the dimensions given in `axis`.
1953  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1954  of the entries in `axis`, which must be unique. If `keepdims` is true, the
1955  reduced dimensions are retained with length 1.
1956
1957  If `axis` is None, all dimensions are reduced, and a
1958  tensor with a single element is returned.
1959
1960  For example:
1961
1962    >>> # x has a shape of (2, 3) (two rows and three columns):
1963    >>> x = tf.constant([[1, 1, 1], [1, 1, 1]])
1964    >>> x.numpy()
1965    array([[1, 1, 1],
1966           [1, 1, 1]], dtype=int32)
1967    >>> # sum all the elements
1968    >>> # 1 + 1 + 1 + 1 + 1+ 1 = 6
1969    >>> tf.reduce_sum(x).numpy()
1970    6
1971    >>> # reduce along the first dimension
1972    >>> # the result is [1, 1, 1] + [1, 1, 1] = [2, 2, 2]
1973    >>> tf.reduce_sum(x, 0).numpy()
1974    array([2, 2, 2], dtype=int32)
1975    >>> # reduce along the second dimension
1976    >>> # the result is [1, 1] + [1, 1] + [1, 1] = [3, 3]
1977    >>> tf.reduce_sum(x, 1).numpy()
1978    array([3, 3], dtype=int32)
1979    >>> # keep the original dimensions
1980    >>> tf.reduce_sum(x, 1, keepdims=True).numpy()
1981    array([[3],
1982           [3]], dtype=int32)
1983    >>> # reduce along both dimensions
1984    >>> # the result is 1 + 1 + 1 + 1 + 1 + 1 = 6
1985    >>> # or, equivalently, reduce along rows, then reduce the resultant array
1986    >>> # [1, 1, 1] + [1, 1, 1] = [2, 2, 2]
1987    >>> # 2 + 2 + 2 = 6
1988    >>> tf.reduce_sum(x, [0, 1]).numpy()
1989    6
1990
1991  Args:
1992    input_tensor: The tensor to reduce. Should have numeric type.
1993    axis: The dimensions to reduce. If `None` (the default), reduces all
1994      dimensions. Must be in the range `[-rank(input_tensor),
1995      rank(input_tensor))`.
1996    keepdims: If true, retains reduced dimensions with length 1.
1997    name: A name for the operation (optional).
1998    reduction_indices: The old (deprecated) name for axis.
1999    keep_dims: Deprecated alias for `keepdims`.
2000
2001  Returns:
2002    The reduced tensor, of the same dtype as the input_tensor.
2003
2004  @compatibility(numpy)
2005  Equivalent to np.sum apart the fact that numpy upcast uint8 and int32 to
2006  int64 while tensorflow returns the same dtype as the input.
2007  @end_compatibility
2008  """
2009  axis = deprecation.deprecated_argument_lookup("axis", axis,
2010                                                "reduction_indices",
2011                                                reduction_indices)
2012  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
2013                                                    "keep_dims", keep_dims)
2014  return reduce_sum(input_tensor, axis, keepdims, name)
2015
2016
2017@tf_export("math.reduce_sum", "reduce_sum", v1=[])
2018@dispatch.add_dispatch_support
2019def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
2020  """Computes the sum of elements across dimensions of a tensor.
2021
2022  This is the reduction operation for the elementwise `tf.math.add` op.
2023
2024  Reduces `input_tensor` along the dimensions given in `axis`.
2025  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2026  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2027  reduced dimensions are retained with length 1.
2028
2029  If `axis` is None, all dimensions are reduced, and a
2030  tensor with a single element is returned.
2031
2032  For example:
2033
2034    >>> # x has a shape of (2, 3) (two rows and three columns):
2035    >>> x = tf.constant([[1, 1, 1], [1, 1, 1]])
2036    >>> x.numpy()
2037    array([[1, 1, 1],
2038           [1, 1, 1]], dtype=int32)
2039    >>> # sum all the elements
2040    >>> # 1 + 1 + 1 + 1 + 1+ 1 = 6
2041    >>> tf.reduce_sum(x).numpy()
2042    6
2043    >>> # reduce along the first dimension
2044    >>> # the result is [1, 1, 1] + [1, 1, 1] = [2, 2, 2]
2045    >>> tf.reduce_sum(x, 0).numpy()
2046    array([2, 2, 2], dtype=int32)
2047    >>> # reduce along the second dimension
2048    >>> # the result is [1, 1] + [1, 1] + [1, 1] = [3, 3]
2049    >>> tf.reduce_sum(x, 1).numpy()
2050    array([3, 3], dtype=int32)
2051    >>> # keep the original dimensions
2052    >>> tf.reduce_sum(x, 1, keepdims=True).numpy()
2053    array([[3],
2054           [3]], dtype=int32)
2055    >>> # reduce along both dimensions
2056    >>> # the result is 1 + 1 + 1 + 1 + 1 + 1 = 6
2057    >>> # or, equivalently, reduce along rows, then reduce the resultant array
2058    >>> # [1, 1, 1] + [1, 1, 1] = [2, 2, 2]
2059    >>> # 2 + 2 + 2 = 6
2060    >>> tf.reduce_sum(x, [0, 1]).numpy()
2061    6
2062
2063  Args:
2064    input_tensor: The tensor to reduce. Should have numeric type.
2065    axis: The dimensions to reduce. If `None` (the default), reduces all
2066      dimensions. Must be in the range `[-rank(input_tensor),
2067      rank(input_tensor)]`.
2068    keepdims: If true, retains reduced dimensions with length 1.
2069    name: A name for the operation (optional).
2070
2071  Returns:
2072    The reduced tensor, of the same dtype as the input_tensor.
2073
2074  @compatibility(numpy)
2075  Equivalent to np.sum apart the fact that numpy upcast uint8 and int32 to
2076  int64 while tensorflow returns the same dtype as the input.
2077  @end_compatibility
2078  """
2079
2080  return reduce_sum_with_dims(input_tensor, axis, keepdims, name,
2081                              _ReductionDims(input_tensor, axis))
2082
2083
2084def reduce_sum_with_dims(input_tensor,
2085                         axis=None,
2086                         keepdims=False,
2087                         name=None,
2088                         dims=None):
2089  keepdims = False if keepdims is None else bool(keepdims)
2090  return _may_reduce_to_scalar(
2091      keepdims, axis,
2092      gen_math_ops._sum(input_tensor, dims, keepdims, name=name))
2093
2094
2095@tf_export("math.reduce_euclidean_norm")
2096@dispatch.add_dispatch_support
2097def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None):
2098  """Computes the Euclidean norm of elements across dimensions of a tensor.
2099
2100  Reduces `input_tensor` along the dimensions given in `axis`.
2101  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2102  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2103  reduced dimensions are retained with length 1.
2104
2105  If `axis` is None, all dimensions are reduced, and a
2106  tensor with a single element is returned.
2107
2108  For example:
2109
2110  ```python
2111  x = tf.constant([[1, 2, 3], [1, 1, 1]]) # x.dtype is tf.int32
2112  tf.math.reduce_euclidean_norm(x)  # returns 4 as dtype is tf.int32
2113  y = tf.constant([[1, 2, 3], [1, 1, 1]], dtype = tf.float32)
2114  tf.math.reduce_euclidean_norm(y)  # returns 4.1231055 which is sqrt(17)
2115  tf.math.reduce_euclidean_norm(y, 0)  # [sqrt(2), sqrt(5), sqrt(10)]
2116  tf.math.reduce_euclidean_norm(y, 1)  # [sqrt(14), sqrt(3)]
2117  tf.math.reduce_euclidean_norm(y, 1, keepdims=True)  # [[sqrt(14)], [sqrt(3)]]
2118  tf.math.reduce_euclidean_norm(y, [0, 1])  # sqrt(17)
2119  ```
2120
2121  Args:
2122    input_tensor: The tensor to reduce. Should have numeric type.
2123    axis: The dimensions to reduce. If `None` (the default), reduces all
2124      dimensions. Must be in the range `[-rank(input_tensor),
2125      rank(input_tensor))`.
2126    keepdims: If true, retains reduced dimensions with length 1.
2127    name: A name for the operation (optional).
2128
2129  Returns:
2130    The reduced tensor, of the same dtype as the input_tensor.
2131  """
2132  keepdims = bool(keepdims)
2133  return _may_reduce_to_scalar(
2134      keepdims, axis,
2135      gen_math_ops.euclidean_norm(
2136          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
2137          name=name))
2138
2139
2140@tf_export(v1=["math.count_nonzero", "count_nonzero"])
2141@dispatch.add_dispatch_support
2142@deprecation.deprecated_args(None,
2143                             "keep_dims is deprecated, use keepdims instead",
2144                             "keep_dims")
2145@deprecation.deprecated_args(
2146    None, "reduction_indices is deprecated, use axis instead",
2147    "reduction_indices")
2148def count_nonzero(input_tensor=None,
2149                  axis=None,
2150                  keepdims=None,
2151                  dtype=dtypes.int64,
2152                  name=None,
2153                  reduction_indices=None,
2154                  keep_dims=None,
2155                  input=None):  # pylint: disable=redefined-builtin
2156  """Computes number of nonzero elements across dimensions of a tensor.
2157
2158  Reduces `input_tensor` along the dimensions given in `axis`.
2159  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2160  entry in `axis`. If `keepdims` is true, the reduced dimensions
2161  are retained with length 1.
2162
2163  If `axis` has no entries, all dimensions are reduced, and a
2164  tensor with a single element is returned.
2165
2166  **NOTE** Floating point comparison to zero is done by exact floating point
2167  equality check.  Small values are **not** rounded to zero for purposes of
2168  the nonzero check.
2169
2170  For example:
2171
2172  ```python
2173  x = tf.constant([[0, 1, 0], [1, 1, 0]])
2174  tf.math.count_nonzero(x)  # 3
2175  tf.math.count_nonzero(x, 0)  # [1, 2, 0]
2176  tf.math.count_nonzero(x, 1)  # [1, 2]
2177  tf.math.count_nonzero(x, 1, keepdims=True)  # [[1], [2]]
2178  tf.math.count_nonzero(x, [0, 1])  # 3
2179  ```
2180
2181  **NOTE** Strings are compared against zero-length empty string `""`. Any
2182  string with a size greater than zero is already considered as nonzero.
2183
2184  For example:
2185  ```python
2186  x = tf.constant(["", "a", "  ", "b", ""])
2187  tf.math.count_nonzero(x) # 3, with "a", "  ", and "b" as nonzero strings.
2188  ```
2189
2190  Args:
2191    input_tensor: The tensor to reduce. Should be of numeric type, `bool`, or
2192      `string`.
2193    axis: The dimensions to reduce. If `None` (the default), reduces all
2194      dimensions. Must be in the range `[-rank(input_tensor),
2195      rank(input_tensor))`.
2196    keepdims: If true, retains reduced dimensions with length 1.
2197    dtype: The output dtype; defaults to `tf.int64`.
2198    name: A name for the operation (optional).
2199    reduction_indices: The old (deprecated) name for axis.
2200    keep_dims: Deprecated alias for `keepdims`.
2201    input: Overrides input_tensor. For compatibility.
2202
2203  Returns:
2204    The reduced tensor (number of nonzero values).
2205  """
2206  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
2207                                                    "keep_dims", keep_dims)
2208  input_tensor = deprecation.deprecated_argument_lookup("input", input,
2209                                                        "input_tensor",
2210                                                        input_tensor)
2211  axis = deprecation.deprecated_argument_lookup("axis", axis,
2212                                                "reduction_indices",
2213                                                reduction_indices)
2214
2215  return count_nonzero_v2(input_tensor, axis, keepdims, dtype, name)
2216
2217
2218@tf_export("math.count_nonzero", v1=[])
2219@dispatch.add_dispatch_support
2220def count_nonzero_v2(
2221    input,  # pylint: disable=redefined-builtin
2222    axis=None,
2223    keepdims=None,
2224    dtype=dtypes.int64,
2225    name=None):
2226  """Computes number of nonzero elements across dimensions of a tensor.
2227
2228  Reduces `input` along the dimensions given in `axis`.
2229  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2230  entry in `axis`. If `keepdims` is true, the reduced dimensions
2231  are retained with length 1.
2232
2233  If `axis` has no entries, all dimensions are reduced, and a
2234  tensor with a single element is returned.
2235
2236  **NOTE** Floating point comparison to zero is done by exact floating point
2237  equality check.  Small values are **not** rounded to zero for purposes of
2238  the nonzero check.
2239
2240  For example:
2241
2242  ```python
2243  x = tf.constant([[0, 1, 0], [1, 1, 0]])
2244  tf.math.count_nonzero(x)  # 3
2245  tf.math.count_nonzero(x, 0)  # [1, 2, 0]
2246  tf.math.count_nonzero(x, 1)  # [1, 2]
2247  tf.math.count_nonzero(x, 1, keepdims=True)  # [[1], [2]]
2248  tf.math.count_nonzero(x, [0, 1])  # 3
2249  ```
2250
2251  **NOTE** Strings are compared against zero-length empty string `""`. Any
2252  string with a size greater than zero is already considered as nonzero.
2253
2254  For example:
2255  ```python
2256  x = tf.constant(["", "a", "  ", "b", ""])
2257  tf.math.count_nonzero(x) # 3, with "a", "  ", and "b" as nonzero strings.
2258  ```
2259
2260  Args:
2261    input: The tensor to reduce. Should be of numeric type, `bool`, or `string`.
2262    axis: The dimensions to reduce. If `None` (the default), reduces all
2263      dimensions. Must be in the range `[-rank(input), rank(input))`.
2264    keepdims: If true, retains reduced dimensions with length 1.
2265    dtype: The output dtype; defaults to `tf.int64`.
2266    name: A name for the operation (optional).
2267
2268  Returns:
2269    The reduced tensor (number of nonzero values).
2270  """
2271  if keepdims is None:
2272    keepdims = False
2273  with ops.name_scope(name, "count_nonzero", [input]):
2274    input = ops.convert_to_tensor(input, name="input")
2275    # A scalar of 'zero' is enough as `not_equal` will broadcast.
2276    zero = array_ops.zeros([], dtype=input.dtype)
2277    return cast(
2278        reduce_sum(
2279            # int64 reduction happens on GPU
2280            cast(gen_math_ops.not_equal(input, zero), dtypes.int64),
2281            axis=axis,
2282            keepdims=keepdims),
2283        dtype=dtype)
2284
2285
2286@tf_export(v1=["math.reduce_mean", "reduce_mean"])
2287@dispatch.add_dispatch_support
2288def reduce_mean_v1(input_tensor,
2289                   axis=None,
2290                   keepdims=None,
2291                   name=None,
2292                   reduction_indices=None,
2293                   keep_dims=None):
2294  """Computes the mean of elements across dimensions of a tensor.
2295
2296  Reduces `input_tensor` along the dimensions given in `axis` by computing the
2297  mean of elements across the dimensions in `axis`.
2298  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2299  the entries in `axis`, which must be unique. If `keepdims` is true, the
2300  reduced dimensions are retained with length 1.
2301
2302  If `axis` is None, all dimensions are reduced, and a tensor with a single
2303  element is returned.
2304
2305  For example:
2306
2307  >>> x = tf.constant([[1., 1.], [2., 2.]])
2308  >>> tf.reduce_mean(x)
2309  <tf.Tensor: shape=(), dtype=float32, numpy=1.5>
2310  >>> tf.reduce_mean(x, 0)
2311  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1.5, 1.5], dtype=float32)>
2312  >>> tf.reduce_mean(x, 1)
2313  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>
2314
2315  Args:
2316    input_tensor: The tensor to reduce. Should have numeric type.
2317    axis: The dimensions to reduce. If `None` (the default), reduces all
2318      dimensions. Must be in the range `[-rank(input_tensor),
2319      rank(input_tensor))`.
2320    keepdims: If true, retains reduced dimensions with length 1.
2321    name: A name for the operation (optional).
2322    reduction_indices: The old (deprecated) name for axis.
2323    keep_dims: Deprecated alias for `keepdims`.
2324
2325  Returns:
2326    The reduced tensor.
2327
2328  @compatibility(numpy)
2329  Equivalent to np.mean
2330
2331  Please note that `np.mean` has a `dtype` parameter that could be used to
2332  specify the output type. By default this is `dtype=float64`. On the other
2333  hand, `tf.reduce_mean` has an aggressive type inference from `input_tensor`,
2334  for example:
2335
2336  >>> x = tf.constant([1, 0, 1, 0])
2337  >>> tf.reduce_mean(x)
2338  <tf.Tensor: shape=(), dtype=int32, numpy=0>
2339  >>> y = tf.constant([1., 0., 1., 0.])
2340  >>> tf.reduce_mean(y)
2341  <tf.Tensor: shape=(), dtype=float32, numpy=0.5>
2342
2343  @end_compatibility
2344  """
2345  axis = deprecation.deprecated_argument_lookup("axis", axis,
2346                                                "reduction_indices",
2347                                                reduction_indices)
2348  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
2349                                                    "keep_dims", keep_dims)
2350  return reduce_mean(input_tensor, axis, keepdims, name)
2351
2352
2353@tf_export("math.reduce_mean", "reduce_mean", v1=[])
2354@dispatch.add_dispatch_support
2355def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
2356  """Computes the mean of elements across dimensions of a tensor.
2357
2358  Reduces `input_tensor` along the dimensions given in `axis` by computing the
2359  mean of elements across the dimensions in `axis`.
2360  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2361  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2362  reduced dimensions are retained with length 1.
2363
2364  If `axis` is None, all dimensions are reduced, and a tensor with a single
2365  element is returned.
2366
2367  For example:
2368
2369  >>> x = tf.constant([[1., 1.], [2., 2.]])
2370  >>> tf.reduce_mean(x)
2371  <tf.Tensor: shape=(), dtype=float32, numpy=1.5>
2372  >>> tf.reduce_mean(x, 0)
2373  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1.5, 1.5], dtype=float32)>
2374  >>> tf.reduce_mean(x, 1)
2375  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>
2376
2377  Args:
2378    input_tensor: The tensor to reduce. Should have numeric type.
2379    axis: The dimensions to reduce. If `None` (the default), reduces all
2380      dimensions. Must be in the range `[-rank(input_tensor),
2381      rank(input_tensor))`.
2382    keepdims: If true, retains reduced dimensions with length 1.
2383    name: A name for the operation (optional).
2384
2385  Returns:
2386    The reduced tensor.
2387
2388  @compatibility(numpy)
2389  Equivalent to np.mean
2390
2391  Please note that `np.mean` has a `dtype` parameter that could be used to
2392  specify the output type. By default this is `dtype=float64`. On the other
2393  hand, `tf.reduce_mean` has an aggressive type inference from `input_tensor`,
2394  for example:
2395
2396  >>> x = tf.constant([1, 0, 1, 0])
2397  >>> tf.reduce_mean(x)
2398  <tf.Tensor: shape=(), dtype=int32, numpy=0>
2399  >>> y = tf.constant([1., 0., 1., 0.])
2400  >>> tf.reduce_mean(y)
2401  <tf.Tensor: shape=(), dtype=float32, numpy=0.5>
2402
2403  @end_compatibility
2404  """
2405  keepdims = False if keepdims is None else bool(keepdims)
2406  return _may_reduce_to_scalar(
2407      keepdims, axis,
2408      gen_math_ops.mean(
2409          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
2410          name=name))
2411
2412
2413@tf_export("math.reduce_variance")
2414@dispatch.add_dispatch_support
2415def reduce_variance(input_tensor, axis=None, keepdims=False, name=None):
2416  """Computes the variance of elements across dimensions of a tensor.
2417
2418  Reduces `input_tensor` along the dimensions given in `axis`.
2419  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2420  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2421  reduced dimensions are retained with length 1.
2422
2423  If `axis` is None, all dimensions are reduced, and a
2424  tensor with a single element is returned.
2425
2426  For example:
2427
2428  >>> x = tf.constant([[1., 2.], [3., 4.]])
2429  >>> tf.math.reduce_variance(x)
2430  <tf.Tensor: shape=(), dtype=float32, numpy=1.25>
2431  >>> tf.math.reduce_variance(x, 0)
2432  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], ...)>
2433  >>> tf.math.reduce_variance(x, 1)
2434  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.25, 0.25], ...)>
2435
2436  Args:
2437    input_tensor: The tensor to reduce. Should have real or complex type.
2438    axis: The dimensions to reduce. If `None` (the default), reduces all
2439      dimensions. Must be in the range `[-rank(input_tensor),
2440      rank(input_tensor))`.
2441    keepdims: If true, retains reduced dimensions with length 1.
2442    name: A name scope for the associated operations (optional).
2443
2444  Returns:
2445    The reduced tensor, of the same dtype as the input_tensor. Note,  for
2446    `complex64` or `complex128` input, the returned `Tensor` will be of type
2447    `float32` or `float64`, respectively.
2448
2449  @compatibility(numpy)
2450  Equivalent to np.var
2451
2452  Please note `np.var` has a `dtype` parameter that could be used to specify the
2453  output type. By default this is `dtype=float64`. On the other hand,
2454  `tf.math.reduce_variance` has aggressive type inference from `input_tensor`.
2455  @end_compatibility
2456  """
2457  name = name if name else "reduce_variance"
2458  with ops.name_scope(name):
2459    means = reduce_mean(input_tensor, axis=axis, keepdims=True)
2460    if means.dtype.is_integer:
2461      raise TypeError("Input must be either real or complex")
2462    diff = input_tensor - means
2463    if diff.dtype.is_complex:
2464      # For complex values we need to take the absolute value before squaring.
2465      # This is achieved by multiplying with the conjugate.
2466      real_dtype = diff.dtype.real_dtype
2467      squared_deviations = gen_math_ops.real(
2468          gen_math_ops.mul(gen_math_ops.conj(diff), diff), Tout=real_dtype)
2469    else:
2470      squared_deviations = gen_math_ops.square(diff)
2471    return reduce_mean(squared_deviations, axis=axis, keepdims=keepdims)
2472
2473
2474@tf_export("math.reduce_std")
2475@dispatch.add_dispatch_support
2476def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
2477  """Computes the standard deviation of elements across dimensions of a tensor.
2478
2479  Reduces `input_tensor` along the dimensions given in `axis`.
2480  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2481  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2482  reduced dimensions are retained with length 1.
2483
2484  If `axis` is None, all dimensions are reduced, and a
2485  tensor with a single element is returned.
2486
2487  For example:
2488
2489  >>> x = tf.constant([[1., 2.], [3., 4.]])
2490  >>> tf.math.reduce_std(x)
2491  <tf.Tensor: shape=(), dtype=float32, numpy=1.118034>
2492  >>> tf.math.reduce_std(x, 0)
2493  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
2494  >>> tf.math.reduce_std(x, 1)
2495  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.5, 0.5], dtype=float32)>
2496
2497  Args:
2498    input_tensor: The tensor to reduce. Should have real or complex type.
2499    axis: The dimensions to reduce. If `None` (the default), reduces all
2500      dimensions. Must be in the range `[-rank(input_tensor),
2501      rank(input_tensor))`.
2502    keepdims: If true, retains reduced dimensions with length 1.
2503    name: A name scope for the associated operations (optional).
2504
2505  Returns:
2506    The reduced tensor, of the same dtype as the input_tensor. Note,  for
2507    `complex64` or `complex128` input, the returned `Tensor` will be of type
2508    `float32` or `float64`, respectively.
2509
2510  @compatibility(numpy)
2511  Equivalent to np.std
2512
2513  Please note `np.std` has a `dtype` parameter that could be used to specify the
2514  output type. By default this is `dtype=float64`. On the other hand,
2515  `tf.math.reduce_std` has aggressive type inference from `input_tensor`.
2516  @end_compatibility
2517  """
2518  name = name if name else "reduce_std"
2519  with ops.name_scope(name):
2520    variance = reduce_variance(input_tensor, axis=axis, keepdims=keepdims)
2521    return gen_math_ops.sqrt(variance)
2522
2523
2524@tf_export("math.reduce_prod", "reduce_prod", v1=[])
2525@dispatch.add_dispatch_support
2526def reduce_prod(input_tensor, axis=None, keepdims=False, name=None):
2527  """Computes `tf.math.multiply` of elements across dimensions of a tensor.
2528
2529  This is the reduction operation for the elementwise `tf.math.multiply` op.
2530
2531  Reduces `input_tensor` along the dimensions given in `axis`.
2532  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2533  entry in `axis`. If `keepdims` is true, the reduced dimensions
2534  are retained with length 1.
2535
2536  If `axis` is None, all dimensions are reduced, and a
2537  tensor with a single element is returned.
2538
2539  For example:
2540
2541    >>> x = tf.constant([[1., 2.], [3., 4.]])
2542    >>> tf.math.reduce_prod(x)
2543    <tf.Tensor: shape=(), dtype=float32, numpy=24.>
2544    >>> tf.math.reduce_prod(x, 0)
2545    <tf.Tensor: shape=(2,), dtype=float32, numpy=array([3., 8.], dtype=float32)>
2546    >>> tf.math.reduce_prod(x, 1)
2547    <tf.Tensor: shape=(2,), dtype=float32, numpy=array([2., 12.],
2548    dtype=float32)>
2549
2550  Args:
2551    input_tensor: The tensor to reduce. Should have numeric type.
2552    axis: The dimensions to reduce. If `None` (the default), reduces all
2553      dimensions. Must be in the range `[-rank(input_tensor),
2554      rank(input_tensor))`.
2555    keepdims: If true, retains reduced dimensions with length 1.
2556    name: A name for the operation (optional).
2557
2558  Returns:
2559    The reduced tensor.
2560
2561  @compatibility(numpy)
2562  Equivalent to np.prod
2563  @end_compatibility
2564  """
2565  keepdims = False if keepdims is None else bool(keepdims)
2566  return _may_reduce_to_scalar(
2567      keepdims, axis,
2568      gen_math_ops.prod(
2569          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
2570          name=name))
2571
2572
2573@tf_export(v1=["math.reduce_prod", "reduce_prod"])
2574@dispatch.add_dispatch_support
2575@deprecation.deprecated_args(None,
2576                             "keep_dims is deprecated, use keepdims instead",
2577                             "keep_dims")
2578def reduce_prod_v1(input_tensor,
2579                   axis=None,
2580                   keepdims=None,
2581                   name=None,
2582                   reduction_indices=None,
2583                   keep_dims=None):
2584  """Computes `tf.math.multiply` of elements across dimensions of a tensor.
2585
2586  This is the reduction operation for the elementwise `tf.math.multiply` op.
2587
2588  Reduces `input_tensor` along the dimensions given in `axis`.
2589  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2590  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2591  reduced dimensions are retained with length 1.
2592
2593  If `axis` is None, all dimensions are reduced, and a
2594  tensor with a single element is returned.
2595
2596  For example:
2597
2598    >>> x = tf.constant([[1., 2.], [3., 4.]])
2599    >>> tf.math.reduce_prod(x)
2600    <tf.Tensor: shape=(), dtype=float32, numpy=24.>
2601    >>> tf.math.reduce_prod(x, 0)
2602    <tf.Tensor: shape=(2,), dtype=float32, numpy=array([3., 8.], dtype=float32)>
2603    >>> tf.math.reduce_prod(x, 1)
2604    <tf.Tensor: shape=(2,), dtype=float32, numpy=array([2., 12.],
2605    dtype=float32)>
2606
2607  Args:
2608    input_tensor: The tensor to reduce. Should have numeric type.
2609    axis: The dimensions to reduce. If `None` (the default), reduces all
2610      dimensions. Must be in the range `[-rank(input_tensor),
2611      rank(input_tensor))`.
2612    keepdims: If true, retains reduced dimensions with length 1.
2613    name: A name for the operation (optional).
2614    reduction_indices: The old (deprecated) name for axis.
2615    keep_dims: Deprecated alias for `keepdims`.
2616
2617  Returns:
2618    The reduced tensor.
2619
2620  @compatibility(numpy)
2621  Equivalent to np.prod
2622  @end_compatibility
2623  """
2624  axis = deprecation.deprecated_argument_lookup("axis", axis,
2625                                                "reduction_indices",
2626                                                reduction_indices)
2627  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
2628                                                    "keep_dims", keep_dims)
2629  return reduce_prod(input_tensor, axis, keepdims, name)
2630
2631
2632@tf_export(v1=["math.reduce_min", "reduce_min"])
2633@dispatch.add_dispatch_support
2634@deprecation.deprecated_args(None,
2635                             "keep_dims is deprecated, use keepdims instead",
2636                             "keep_dims")
2637def reduce_min_v1(input_tensor,
2638                  axis=None,
2639                  keepdims=None,
2640                  name=None,
2641                  reduction_indices=None,
2642                  keep_dims=None):
2643  """Computes the `tf.math.minimum` of elements across dimensions of a tensor.
2644
2645  This is the reduction operation for the elementwise `tf.math.minimum` op.
2646
2647  Reduces `input_tensor` along the dimensions given in `axis`.
2648  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2649  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2650  reduced dimensions are retained with length 1.
2651
2652  If `axis` is None, all dimensions are reduced, and a
2653  tensor with a single element is returned.
2654
2655  Usage example:
2656
2657    >>> x = tf.constant([5, 1, 2, 4])
2658    >>> tf.reduce_min(x)
2659    <tf.Tensor: shape=(), dtype=int32, numpy=1>
2660    >>> x = tf.constant([-5, -1, -2, -4])
2661    >>> tf.reduce_min(x)
2662    <tf.Tensor: shape=(), dtype=int32, numpy=-5>
2663    >>> x = tf.constant([4, float('nan')])
2664    >>> tf.reduce_min(x)
2665    <tf.Tensor: shape=(), dtype=float32, numpy=nan>
2666    >>> x = tf.constant([float('nan'), float('nan')])
2667    >>> tf.reduce_min(x)
2668    <tf.Tensor: shape=(), dtype=float32, numpy=nan>
2669    >>> x = tf.constant([float('-inf'), float('inf')])
2670    >>> tf.reduce_min(x)
2671    <tf.Tensor: shape=(), dtype=float32, numpy=-inf>
2672
2673  See the numpy docs for `np.amin` and `np.nanmin` behavior.
2674
2675  Args:
2676    input_tensor: The tensor to reduce. Should have real numeric type.
2677    axis: The dimensions to reduce. If `None` (the default), reduces all
2678      dimensions. Must be in the range `[-rank(input_tensor),
2679      rank(input_tensor))`.
2680    keepdims: If true, retains reduced dimensions with length 1.
2681    name: A name for the operation (optional).
2682    reduction_indices: The old (deprecated) name for axis.
2683    keep_dims: Deprecated alias for `keepdims`.
2684
2685  Returns:
2686    The reduced tensor.
2687  """
2688  axis = deprecation.deprecated_argument_lookup("axis", axis,
2689                                                "reduction_indices",
2690                                                reduction_indices)
2691  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
2692                                                    "keep_dims", keep_dims)
2693  return reduce_min(input_tensor, axis, keepdims, name)
2694
2695
2696@tf_export("math.reduce_min", "reduce_min", v1=[])
2697@dispatch.add_dispatch_support
2698def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
2699  """Computes the `tf.math.minimum` of elements across dimensions of a tensor.
2700
2701  This is the reduction operation for the elementwise `tf.math.minimum` op.
2702
2703  Reduces `input_tensor` along the dimensions given in `axis`.
2704  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2705  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2706  reduced dimensions are retained with length 1.
2707
2708  If `axis` is None, all dimensions are reduced, and a
2709  tensor with a single element is returned.
2710
2711  For example:
2712
2713  >>> a = tf.constant([
2714  ...   [[1, 2], [3, 4]],
2715  ...   [[1, 2], [3, 4]]
2716  ... ])
2717  >>> tf.reduce_min(a)
2718  <tf.Tensor: shape=(), dtype=int32, numpy=1>
2719
2720  Choosing a specific axis returns minimum element in the given axis:
2721
2722  >>> b = tf.constant([[1, 2, 3], [4, 5, 6]])
2723  >>> tf.reduce_min(b, axis=0)
2724  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>
2725  >>> tf.reduce_min(b, axis=1)
2726  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 4], dtype=int32)>
2727
2728  Setting `keepdims` to `True` retains the dimension of `input_tensor`:
2729
2730  >>> tf.reduce_min(a, keepdims=True)
2731  <tf.Tensor: shape=(1, 1, 1), dtype=int32, numpy=array([[[1]]], dtype=int32)>
2732  >>> tf.math.reduce_min(a, axis=0, keepdims=True)
2733  <tf.Tensor: shape=(1, 2, 2), dtype=int32, numpy=
2734  array([[[1, 2],
2735          [3, 4]]], dtype=int32)>
2736
2737  Args:
2738    input_tensor: The tensor to reduce. Should have real numeric type.
2739    axis: The dimensions to reduce. If `None` (the default), reduces all
2740      dimensions. Must be in the range `[-rank(input_tensor),
2741      rank(input_tensor))`.
2742    keepdims: If true, retains reduced dimensions with length 1.
2743    name: A name for the operation (optional).
2744
2745  Returns:
2746    The reduced tensor.
2747
2748  @compatibility(numpy)
2749  Equivalent to np.min
2750  @end_compatibility
2751  """
2752  keepdims = False if keepdims is None else bool(keepdims)
2753  return _may_reduce_to_scalar(
2754      keepdims, axis,
2755      gen_math_ops._min(
2756          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
2757          name=name))
2758
2759
2760@tf_export(v1=["math.reduce_max", "reduce_max"])
2761@dispatch.add_dispatch_support
2762@deprecation.deprecated_args(None,
2763                             "keep_dims is deprecated, use keepdims instead",
2764                             "keep_dims")
2765def reduce_max_v1(input_tensor,
2766                  axis=None,
2767                  keepdims=None,
2768                  name=None,
2769                  reduction_indices=None,
2770                  keep_dims=None):
2771  """Computes `tf.math.maximum` of elements across dimensions of a tensor.
2772
2773  This is the reduction operation for the elementwise `tf.math.maximum` op.
2774
2775  Reduces `input_tensor` along the dimensions given in `axis`.
2776  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2777  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2778  reduced dimensions are retained with length 1.
2779
2780  If `axis` is None, all dimensions are reduced, and a
2781  tensor with a single element is returned.
2782
2783  Usage example:
2784
2785    >>> x = tf.constant([5, 1, 2, 4])
2786    >>> tf.reduce_max(x)
2787    <tf.Tensor: shape=(), dtype=int32, numpy=5>
2788    >>> x = tf.constant([-5, -1, -2, -4])
2789    >>> tf.reduce_max(x)
2790    <tf.Tensor: shape=(), dtype=int32, numpy=-1>
2791    >>> x = tf.constant([4, float('nan')])
2792    >>> tf.reduce_max(x)
2793    <tf.Tensor: shape=(), dtype=float32, numpy=nan>
2794    >>> x = tf.constant([float('nan'), float('nan')])
2795    >>> tf.reduce_max(x)
2796    <tf.Tensor: shape=(), dtype=float32, numpy=nan>
2797    >>> x = tf.constant([float('-inf'), float('inf')])
2798    >>> tf.reduce_max(x)
2799    <tf.Tensor: shape=(), dtype=float32, numpy=inf>
2800
2801  See the numpy docs for `np.amax` and `np.nanmax` behavior.
2802
2803  Args:
2804    input_tensor: The tensor to reduce. Should have real numeric type.
2805    axis: The dimensions to reduce. If `None` (the default), reduces all
2806      dimensions. Must be in the range `[-rank(input_tensor),
2807      rank(input_tensor))`.
2808    keepdims: If true, retains reduced dimensions with length 1.
2809    name: A name for the operation (optional).
2810    reduction_indices: The old (deprecated) name for axis.
2811    keep_dims: Deprecated alias for `keepdims`.
2812
2813  Returns:
2814    The reduced tensor.
2815  """
2816  axis = deprecation.deprecated_argument_lookup("axis", axis,
2817                                                "reduction_indices",
2818                                                reduction_indices)
2819  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
2820                                                    "keep_dims", keep_dims)
2821  return reduce_max(input_tensor, axis, keepdims, name)
2822
2823
2824@tf_export("math.reduce_max", "reduce_max", v1=[])
2825@dispatch.add_dispatch_support
2826def reduce_max(input_tensor, axis=None, keepdims=False, name=None):
2827  """Computes `tf.math.maximum` of elements across dimensions of a tensor.
2828
2829  This is the reduction operation for the elementwise `tf.math.maximum` op.
2830
2831  Reduces `input_tensor` along the dimensions given in `axis`.
2832  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2833  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2834  reduced dimensions are retained with length 1.
2835
2836  If `axis` is None, all dimensions are reduced, and a
2837  tensor with a single element is returned.
2838
2839  Usage example:
2840
2841    >>> x = tf.constant([5, 1, 2, 4])
2842    >>> tf.reduce_max(x)
2843    <tf.Tensor: shape=(), dtype=int32, numpy=5>
2844    >>> x = tf.constant([-5, -1, -2, -4])
2845    >>> tf.reduce_max(x)
2846    <tf.Tensor: shape=(), dtype=int32, numpy=-1>
2847    >>> x = tf.constant([4, float('nan')])
2848    >>> tf.reduce_max(x)
2849    <tf.Tensor: shape=(), dtype=float32, numpy=nan>
2850    >>> x = tf.constant([float('nan'), float('nan')])
2851    >>> tf.reduce_max(x)
2852    <tf.Tensor: shape=(), dtype=float32, numpy=nan>
2853    >>> x = tf.constant([float('-inf'), float('inf')])
2854    >>> tf.reduce_max(x)
2855    <tf.Tensor: shape=(), dtype=float32, numpy=inf>
2856
2857  See the numpy docs for `np.amax` and `np.nanmax` behavior.
2858
2859  Args:
2860    input_tensor: The tensor to reduce. Should have real numeric type.
2861    axis: The dimensions to reduce. If `None` (the default), reduces all
2862      dimensions. Must be in the range `[-rank(input_tensor),
2863      rank(input_tensor))`.
2864    keepdims: If true, retains reduced dimensions with length 1.
2865    name: A name for the operation (optional).
2866
2867  Returns:
2868    The reduced tensor.
2869  """
2870  return reduce_max_with_dims(input_tensor, axis, keepdims, name,
2871                              _ReductionDims(input_tensor, axis))
2872
2873
2874def reduce_max_with_dims(input_tensor,
2875                         axis=None,
2876                         keepdims=False,
2877                         name=None,
2878                         dims=None):
2879  keepdims = False if keepdims is None else bool(keepdims)
2880  return _may_reduce_to_scalar(
2881      keepdims, axis,
2882      gen_math_ops._max(input_tensor, dims, keepdims, name=name))
2883
2884
2885@tf_export(v1=["math.reduce_all", "reduce_all"])
2886@dispatch.add_dispatch_support
2887@deprecation.deprecated_args(None,
2888                             "keep_dims is deprecated, use keepdims instead",
2889                             "keep_dims")
2890def reduce_all_v1(input_tensor,
2891                  axis=None,
2892                  keepdims=None,
2893                  name=None,
2894                  reduction_indices=None,
2895                  keep_dims=None):
2896  """Computes `tf.math.logical_and` of elements across dimensions of a tensor.
2897
2898  This is the reduction operation for the elementwise `tf.math.logical_and` op.
2899
2900  Reduces `input_tensor` along the dimensions given in `axis`.
2901  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2902  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2903  reduced dimensions are retained with length 1.
2904
2905  If `axis` is None, all dimensions are reduced, and a
2906  tensor with a single element is returned.
2907
2908  For example:
2909
2910    >>> x = tf.constant([[True,  True], [False, False]])
2911    >>> tf.math.reduce_all(x)
2912    <tf.Tensor: shape=(), dtype=bool, numpy=False>
2913    >>> tf.math.reduce_all(x, 0)
2914    <tf.Tensor: shape=(2,), dtype=bool, numpy=array([False, False])>
2915    >>> tf.math.reduce_all(x, 1)
2916    <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, False])>
2917
2918  Args:
2919    input_tensor: The boolean tensor to reduce.
2920    axis: The dimensions to reduce. If `None` (the default), reduces all
2921      dimensions. Must be in the range `[-rank(input_tensor),
2922      rank(input_tensor))`.
2923    keepdims: If true, retains reduced dimensions with length 1.
2924    name: A name for the operation (optional).
2925    reduction_indices: The old (deprecated) name for axis.
2926    keep_dims: Deprecated alias for `keepdims`.
2927
2928  Returns:
2929    The reduced tensor.
2930
2931  @compatibility(numpy)
2932  Equivalent to np.all
2933  @end_compatibility
2934  """
2935  axis = deprecation.deprecated_argument_lookup("axis", axis,
2936                                                "reduction_indices",
2937                                                reduction_indices)
2938  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
2939                                                    "keep_dims", keep_dims)
2940  return reduce_all(input_tensor, axis, keepdims, name)
2941
2942
2943@tf_export("math.reduce_all", "reduce_all", v1=[])
2944@dispatch.add_dispatch_support
2945def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
2946  """Computes `tf.math.logical_and` of elements across dimensions of a tensor.
2947
2948  This is the reduction operation for the elementwise `tf.math.logical_and` op.
2949
2950  Reduces `input_tensor` along the dimensions given in `axis`.
2951  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
2952  of the entries in `axis`, which must be unique. If `keepdims` is true, the
2953  reduced dimensions are retained with length 1.
2954
2955  If `axis` is None, all dimensions are reduced, and a
2956  tensor with a single element is returned.
2957
2958  For example:
2959
2960    >>> x = tf.constant([[True,  True], [False, False]])
2961    >>> tf.math.reduce_all(x)
2962    <tf.Tensor: shape=(), dtype=bool, numpy=False>
2963    >>> tf.math.reduce_all(x, 0)
2964    <tf.Tensor: shape=(2,), dtype=bool, numpy=array([False, False])>
2965    >>> tf.math.reduce_all(x, 1)
2966    <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, False])>
2967
2968  Args:
2969    input_tensor: The boolean tensor to reduce.
2970    axis: The dimensions to reduce. If `None` (the default), reduces all
2971      dimensions. Must be in the range `[-rank(input_tensor),
2972      rank(input_tensor))`.
2973    keepdims: If true, retains reduced dimensions with length 1.
2974    name: A name for the operation (optional).
2975
2976  Returns:
2977    The reduced tensor.
2978
2979  @compatibility(numpy)
2980  Equivalent to np.all
2981  @end_compatibility
2982  """
2983  keepdims = False if keepdims is None else bool(keepdims)
2984  return _may_reduce_to_scalar(
2985      keepdims, axis,
2986      gen_math_ops._all(
2987          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
2988          name=name))
2989
2990
2991@tf_export(v1=["math.reduce_any", "reduce_any"])
2992@dispatch.add_dispatch_support
2993@deprecation.deprecated_args(None,
2994                             "keep_dims is deprecated, use keepdims instead",
2995                             "keep_dims")
2996def reduce_any_v1(input_tensor,
2997                  axis=None,
2998                  keepdims=None,
2999                  name=None,
3000                  reduction_indices=None,
3001                  keep_dims=None):
3002  """Computes `tf.math.logical_or` of elements across dimensions of a tensor.
3003
3004  This is the reduction operation for the elementwise `tf.math.logical_or` op.
3005
3006  Reduces `input_tensor` along the dimensions given in `axis`.
3007  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
3008  of the entries in `axis`, which must be unique. If `keepdims` is true, the
3009  reduced dimensions are retained with length 1.
3010
3011  If `axis` is None, all dimensions are reduced, and a
3012  tensor with a single element is returned.
3013
3014  For example:
3015
3016    >>> x = tf.constant([[True,  True], [False, False]])
3017    >>> tf.reduce_any(x)
3018    <tf.Tensor: shape=(), dtype=bool, numpy=True>
3019    >>> tf.reduce_any(x, 0)
3020    <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True,  True])>
3021    >>> tf.reduce_any(x, 1)
3022    <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, False])>
3023
3024  Args:
3025    input_tensor: The boolean tensor to reduce.
3026    axis: The dimensions to reduce. If `None` (the default), reduces all
3027      dimensions. Must be in the range `[-rank(input_tensor),
3028      rank(input_tensor))`.
3029    keepdims: If true, retains reduced dimensions with length 1.
3030    name: A name for the operation (optional).
3031    reduction_indices: The old (deprecated) name for axis.
3032    keep_dims: Deprecated alias for `keepdims`.
3033
3034  Returns:
3035    The reduced tensor.
3036
3037  @compatibility(numpy)
3038  Equivalent to np.any
3039  @end_compatibility
3040  """
3041  axis = deprecation.deprecated_argument_lookup("axis", axis,
3042                                                "reduction_indices",
3043                                                reduction_indices)
3044  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
3045                                                    "keep_dims", keep_dims)
3046  return reduce_any(input_tensor, axis, keepdims, name)
3047
3048
3049@tf_export("math.reduce_any", "reduce_any", v1=[])
3050@dispatch.add_dispatch_support
3051def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
3052  """Computes `tf.math.logical_or` of elements across dimensions of a tensor.
3053
3054  This is the reduction operation for the elementwise `tf.math.logical_or` op.
3055
3056  Reduces `input_tensor` along the dimensions given in `axis`.
3057  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
3058  of the entries in `axis`, which must be unique. If `keepdims` is true, the
3059  reduced dimensions are retained with length 1.
3060
3061  If `axis` is None, all dimensions are reduced, and a
3062  tensor with a single element is returned.
3063
3064  For example:
3065
3066    >>> x = tf.constant([[True,  True], [False, False]])
3067    >>> tf.reduce_any(x)
3068    <tf.Tensor: shape=(), dtype=bool, numpy=True>
3069    >>> tf.reduce_any(x, 0)
3070    <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True,  True])>
3071    >>> tf.reduce_any(x, 1)
3072    <tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, False])>
3073
3074  Args:
3075    input_tensor: The boolean tensor to reduce.
3076    axis: The dimensions to reduce. If `None` (the default), reduces all
3077      dimensions. Must be in the range `[-rank(input_tensor),
3078      rank(input_tensor))`.
3079    keepdims: If true, retains reduced dimensions with length 1.
3080    name: A name for the operation (optional).
3081
3082  Returns:
3083    The reduced tensor.
3084
3085  @compatibility(numpy)
3086  Equivalent to np.any
3087  @end_compatibility
3088  """
3089  keepdims = False if keepdims is None else bool(keepdims)
3090  return _may_reduce_to_scalar(
3091      keepdims, axis,
3092      gen_math_ops._any(
3093          input_tensor, _ReductionDims(input_tensor, axis), keepdims,
3094          name=name))
3095
3096
3097@tf_export(v1=["math.reduce_logsumexp", "reduce_logsumexp"])
3098@dispatch.add_dispatch_support
3099@deprecation.deprecated_args(None,
3100                             "keep_dims is deprecated, use keepdims instead",
3101                             "keep_dims")
3102def reduce_logsumexp_v1(input_tensor,
3103                        axis=None,
3104                        keepdims=None,
3105                        name=None,
3106                        reduction_indices=None,
3107                        keep_dims=None):
3108  """Computes log(sum(exp(elements across dimensions of a tensor))).
3109
3110  Reduces `input_tensor` along the dimensions given in `axis`.
3111  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
3112  of the entries in `axis`, which must be unique. If `keepdims` is true, the
3113  reduced dimensions are retained with length 1.
3114
3115  If `axis` has no entries, all dimensions are reduced, and a
3116  tensor with a single element is returned.
3117
3118  This function is more numerically stable than log(sum(exp(input))). It avoids
3119  overflows caused by taking the exp of large inputs and underflows caused by
3120  taking the log of small inputs.
3121
3122  For example:
3123
3124  ```python
3125  x = tf.constant([[0., 0., 0.], [0., 0., 0.]])
3126  tf.reduce_logsumexp(x)  # log(6)
3127  tf.reduce_logsumexp(x, 0)  # [log(2), log(2), log(2)]
3128  tf.reduce_logsumexp(x, 1)  # [log(3), log(3)]
3129  tf.reduce_logsumexp(x, 1, keepdims=True)  # [[log(3)], [log(3)]]
3130  tf.reduce_logsumexp(x, [0, 1])  # log(6)
3131  ```
3132
3133  Args:
3134    input_tensor: The tensor to reduce. Should have numeric type.
3135    axis: The dimensions to reduce. If `None` (the default), reduces all
3136      dimensions. Must be in the range `[-rank(input_tensor),
3137      rank(input_tensor))`.
3138    keepdims: If true, retains reduced dimensions with length 1.
3139    name: A name for the operation (optional).
3140    reduction_indices: The old (deprecated) name for axis.
3141    keep_dims: Deprecated alias for `keepdims`.
3142
3143  Returns:
3144    The reduced tensor.
3145  """
3146  axis = deprecation.deprecated_argument_lookup("axis", axis,
3147                                                "reduction_indices",
3148                                                reduction_indices)
3149  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
3150                                                    "keep_dims", keep_dims)
3151  return reduce_logsumexp(input_tensor, axis, keepdims, name)
3152
3153
3154@tf_export("math.reduce_logsumexp", "reduce_logsumexp", v1=[])
3155@dispatch.add_dispatch_support
3156def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None):
3157  """Computes log(sum(exp(elements across dimensions of a tensor))).
3158
3159  Reduces `input_tensor` along the dimensions given in `axis`.
3160  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
3161  of the entries in `axis`, which must be unique. If `keepdims` is true, the
3162  reduced dimensions are retained with length 1.
3163
3164  If `axis` has no entries, all dimensions are reduced, and a
3165  tensor with a single element is returned.
3166
3167  This function is more numerically stable than log(sum(exp(input))). It avoids
3168  overflows caused by taking the exp of large inputs and underflows caused by
3169  taking the log of small inputs.
3170
3171  For example:
3172
3173  ```python
3174  x = tf.constant([[0., 0., 0.], [0., 0., 0.]])
3175  tf.reduce_logsumexp(x)  # log(6)
3176  tf.reduce_logsumexp(x, 0)  # [log(2), log(2), log(2)]
3177  tf.reduce_logsumexp(x, 1)  # [log(3), log(3)]
3178  tf.reduce_logsumexp(x, 1, keepdims=True)  # [[log(3)], [log(3)]]
3179  tf.reduce_logsumexp(x, [0, 1])  # log(6)
3180  ```
3181
3182  Args:
3183    input_tensor: The tensor to reduce. Should have numeric type.
3184    axis: The dimensions to reduce. If `None` (the default), reduces all
3185      dimensions. Must be in the range `[-rank(input_tensor),
3186      rank(input_tensor))`.
3187    keepdims: If true, retains reduced dimensions with length 1.
3188    name: A name for the operation (optional).
3189
3190  Returns:
3191    The reduced tensor.
3192  """
3193  keepdims = False if keepdims is None else keepdims
3194  input_tensor = ops.convert_to_tensor(input_tensor)
3195  with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
3196    reduce_dim = _ReductionDims(input_tensor, axis)
3197    raw_max = reduce_max_with_dims(
3198        input_tensor, axis=axis, keepdims=True, dims=reduce_dim)
3199    my_max = array_ops.stop_gradient(
3200        gen_math_ops.select(
3201            gen_math_ops.is_finite(raw_max), raw_max,
3202            gen_array_ops.zeros_like(raw_max)))
3203    result = gen_math_ops.log(
3204        reduce_sum_with_dims(
3205            gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)),
3206            axis=axis,
3207            keepdims=keepdims,
3208            dims=reduce_dim))
3209    if not keepdims:
3210      my_max = array_ops.reshape(my_max, gen_array_ops.shape(result))
3211    result = _add_dispatch(result, my_max, name=name)
3212    return _may_reduce_to_scalar(keepdims, axis, result)
3213
3214
3215@tf_export("linalg.trace", v1=["linalg.trace", "trace"])
3216@dispatch.add_dispatch_support
3217@deprecation.deprecated_endpoints("trace")
3218@dispatch.add_dispatch_support
3219def trace(x, name=None):
3220  """Compute the trace of a tensor `x`.
3221
3222  `trace(x)` returns the sum along the main diagonal of each inner-most matrix
3223  in x. If x is of rank `k` with shape `[I, J, K, ..., L, M, N]`, then output
3224  is a tensor of rank `k-2` with dimensions `[I, J, K, ..., L]` where
3225
3226  `output[i, j, k, ..., l] = trace(x[i, j, k, ..., l, :, :])`
3227
3228  For example:
3229
3230  ```python
3231  x = tf.constant([[1, 2], [3, 4]])
3232  tf.linalg.trace(x)  # 5
3233
3234  x = tf.constant([[1, 2, 3],
3235                   [4, 5, 6],
3236                   [7, 8, 9]])
3237  tf.linalg.trace(x)  # 15
3238
3239  x = tf.constant([[[1, 2, 3],
3240                    [4, 5, 6],
3241                    [7, 8, 9]],
3242                   [[-1, -2, -3],
3243                    [-4, -5, -6],
3244                    [-7, -8, -9]]])
3245  tf.linalg.trace(x)  # [15, -15]
3246  ```
3247
3248  Args:
3249    x: tensor.
3250    name: A name for the operation (optional).
3251
3252  Returns:
3253    The trace of input tensor.
3254  """
3255  with ops.name_scope(name, "Trace", [x]) as name:
3256    x = ops.convert_to_tensor(x, name="x")
3257    return reduce_sum(array_ops.matrix_diag_part(x), [-1], name=name)
3258
3259
3260@tf_export("linalg.matmul", "matmul")
3261@dispatch.add_dispatch_support
3262def matmul(a,
3263           b,
3264           transpose_a=False,
3265           transpose_b=False,
3266           adjoint_a=False,
3267           adjoint_b=False,
3268           a_is_sparse=False,
3269           b_is_sparse=False,
3270           name=None):
3271  """Multiplies matrix `a` by matrix `b`, producing `a` * `b`.
3272
3273  The inputs must, following any transpositions, be tensors of rank >= 2
3274  where the inner 2 dimensions specify valid matrix multiplication dimensions,
3275  and any further outer dimensions specify matching batch size.
3276
3277  Both matrices must be of the same type. The supported types are:
3278  `float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`.
3279
3280  Either matrix can be transposed or adjointed (conjugated and transposed) on
3281  the fly by setting one of the corresponding flag to `True`. These are `False`
3282  by default.
3283
3284  If one or both of the matrices contain a lot of zeros, a more efficient
3285  multiplication algorithm can be used by setting the corresponding
3286  `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
3287  This optimization is only available for plain matrices (rank-2 tensors) with
3288  datatypes `bfloat16` or `float32`.
3289
3290  A simple 2-D tensor matrix multiplication:
3291
3292  >>> a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
3293  >>> a  # 2-D tensor
3294  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
3295  array([[1, 2, 3],
3296         [4, 5, 6]], dtype=int32)>
3297  >>> b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2])
3298  >>> b  # 2-D tensor
3299  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
3300  array([[ 7,  8],
3301         [ 9, 10],
3302         [11, 12]], dtype=int32)>
3303  >>> c = tf.matmul(a, b)
3304  >>> c  # `a` * `b`
3305  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3306  array([[ 58,  64],
3307         [139, 154]], dtype=int32)>
3308
3309  A batch matrix multiplication with batch shape [2]:
3310
3311  >>> a = tf.constant(np.arange(1, 13, dtype=np.int32), shape=[2, 2, 3])
3312  >>> a  # 3-D tensor
3313  <tf.Tensor: shape=(2, 2, 3), dtype=int32, numpy=
3314  array([[[ 1,  2,  3],
3315          [ 4,  5,  6]],
3316         [[ 7,  8,  9],
3317          [10, 11, 12]]], dtype=int32)>
3318  >>> b = tf.constant(np.arange(13, 25, dtype=np.int32), shape=[2, 3, 2])
3319  >>> b  # 3-D tensor
3320  <tf.Tensor: shape=(2, 3, 2), dtype=int32, numpy=
3321  array([[[13, 14],
3322          [15, 16],
3323          [17, 18]],
3324         [[19, 20],
3325          [21, 22],
3326          [23, 24]]], dtype=int32)>
3327  >>> c = tf.matmul(a, b)
3328  >>> c  # `a` * `b`
3329  <tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
3330  array([[[ 94, 100],
3331          [229, 244]],
3332         [[508, 532],
3333          [697, 730]]], dtype=int32)>
3334
3335  Since python >= 3.5 the @ operator is supported
3336  (see [PEP 465](https://www.python.org/dev/peps/pep-0465/)). In TensorFlow,
3337  it simply calls the `tf.matmul()` function, so the following lines are
3338  equivalent:
3339
3340  >>> d = a @ b @ [[10], [11]]
3341  >>> d = tf.matmul(tf.matmul(a, b), [[10], [11]])
3342
3343  Args:
3344    a: `tf.Tensor` of type `float16`, `float32`, `float64`, `int32`,
3345      `complex64`, `complex128` and rank > 1.
3346    b: `tf.Tensor` with same type and rank as `a`.
3347    transpose_a: If `True`, `a` is transposed before multiplication.
3348    transpose_b: If `True`, `b` is transposed before multiplication.
3349    adjoint_a: If `True`, `a` is conjugated and transposed before
3350      multiplication.
3351    adjoint_b: If `True`, `b` is conjugated and transposed before
3352      multiplication.
3353    a_is_sparse: If `True`, `a` is treated as a sparse matrix. Notice, this
3354      **does not support `tf.sparse.SparseTensor`**, it just makes optimizations
3355      that assume most values in `a` are zero.
3356      See `tf.sparse.sparse_dense_matmul`
3357      for some support for `tf.sparse.SparseTensor` multiplication.
3358    b_is_sparse: If `True`, `b` is treated as a sparse matrix. Notice, this
3359      **does not support `tf.sparse.SparseTensor`**, it just makes optimizations
3360      that assume most values in `a` are zero.
3361      See `tf.sparse.sparse_dense_matmul`
3362      for some support for `tf.sparse.SparseTensor` multiplication.
3363    name: Name for the operation (optional).
3364
3365  Returns:
3366    A `tf.Tensor` of the same type as `a` and `b` where each inner-most matrix
3367    is the product of the corresponding matrices in `a` and `b`, e.g. if all
3368    transpose or adjoint attributes are `False`:
3369
3370    `output[..., i, j] = sum_k (a[..., i, k] * b[..., k, j])`,
3371    for all indices `i`, `j`.
3372
3373    Note: This is matrix product, not element-wise product.
3374
3375
3376  Raises:
3377    ValueError: If `transpose_a` and `adjoint_a`, or `transpose_b` and
3378      `adjoint_b` are both set to `True`.
3379  """
3380  with ops.name_scope(name, "MatMul", [a, b]) as name:
3381    if transpose_a and adjoint_a:
3382      raise ValueError("Only one of transpose_a and adjoint_a can be True.")
3383    if transpose_b and adjoint_b:
3384      raise ValueError("Only one of transpose_b and adjoint_b can be True.")
3385
3386    if context.executing_eagerly():
3387      if not isinstance(a, (ops.EagerTensor, _resource_variable_type)):
3388        a = ops.convert_to_tensor(a, name="a")
3389      if not isinstance(b, (ops.EagerTensor, _resource_variable_type)):
3390        b = ops.convert_to_tensor(b, dtype_hint=a.dtype.base_dtype, name="b")
3391    else:
3392      a = ops.convert_to_tensor(a, name="a")
3393      b = ops.convert_to_tensor(b, dtype_hint=a.dtype.base_dtype, name="b")
3394
3395    # TODO(apassos) remove _shape_tuple here when it is not needed.
3396    a_shape = a._shape_tuple()  # pylint: disable=protected-access
3397    b_shape = b._shape_tuple()  # pylint: disable=protected-access
3398
3399    output_may_have_non_empty_batch_shape = (
3400        (a_shape is None or len(a_shape) > 2) or
3401        (b_shape is None or len(b_shape) > 2))
3402
3403    if (not a_is_sparse and
3404        not b_is_sparse) and output_may_have_non_empty_batch_shape:
3405      # BatchMatmul does not support transpose, so we conjugate the matrix and
3406      # use adjoint instead. Conj() is a noop for real matrices.
3407      if transpose_a:
3408        a = conj(a)
3409        adjoint_a = True
3410      if transpose_b:
3411        b = conj(b)
3412        adjoint_b = True
3413      return gen_math_ops.batch_mat_mul_v2(
3414          a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
3415
3416    # Neither matmul nor sparse_matmul support adjoint, so we conjugate
3417    # the matrix and use transpose instead. Conj() is a noop for real
3418    # matrices.
3419    if adjoint_a:
3420      a = conj(a)
3421      transpose_a = True
3422    if adjoint_b:
3423      b = conj(b)
3424      transpose_b = True
3425
3426    use_sparse_matmul = False
3427    if a_is_sparse or b_is_sparse:
3428      sparse_matmul_types = [dtypes.bfloat16, dtypes.float32]
3429      use_sparse_matmul = (
3430          a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types)
3431    if ((a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16) and
3432        a.dtype != b.dtype):
3433      # matmul currently doesn't handle mixed-precision inputs.
3434      use_sparse_matmul = True
3435    if use_sparse_matmul:
3436      ret = sparse_matmul(
3437          a,
3438          b,
3439          transpose_a=transpose_a,
3440          transpose_b=transpose_b,
3441          a_is_sparse=a_is_sparse,
3442          b_is_sparse=b_is_sparse,
3443          name=name)
3444      # sparse_matmul always returns float32, even with
3445      # bfloat16 inputs. This prevents us from configuring bfloat16 training.
3446      # casting to bfloat16 also matches non-sparse matmul behavior better.
3447      if a.dtype == dtypes.bfloat16 and b.dtype == dtypes.bfloat16:
3448        ret = cast(ret, dtypes.bfloat16)
3449      return ret
3450    else:
3451      return gen_math_ops.mat_mul(
3452          a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
3453
3454
3455@tf_export("linalg.matvec")
3456@dispatch.add_dispatch_support
3457def matvec(a,
3458           b,
3459           transpose_a=False,
3460           adjoint_a=False,
3461           a_is_sparse=False,
3462           b_is_sparse=False,
3463           name=None):
3464  """Multiplies matrix `a` by vector `b`, producing `a` * `b`.
3465
3466  The matrix `a` must, following any transpositions, be a tensor of rank >= 2,
3467  with `shape(a)[-1] == shape(b)[-1]`, and `shape(a)[:-2]` able to broadcast
3468  with `shape(b)[:-1]`.
3469
3470  Both `a` and `b` must be of the same type. The supported types are:
3471  `float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`.
3472
3473  Matrix `a` can be transposed or adjointed (conjugated and transposed) on
3474  the fly by setting one of the corresponding flag to `True`. These are `False`
3475  by default.
3476
3477  If one or both of the inputs contain a lot of zeros, a more efficient
3478  multiplication algorithm can be used by setting the corresponding
3479  `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
3480  This optimization is only available for plain matrices/vectors (rank-2/1
3481  tensors) with datatypes `bfloat16` or `float32`.
3482
3483  For example:
3484
3485  ```python
3486  # 2-D tensor `a`
3487  # [[1, 2, 3],
3488  #  [4, 5, 6]]
3489  a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
3490
3491  # 1-D tensor `b`
3492  # [7, 9, 11]
3493  b = tf.constant([7, 9, 11], shape=[3])
3494
3495  # `a` * `b`
3496  # [ 58,  64]
3497  c = tf.linalg.matvec(a, b)
3498
3499
3500  # 3-D tensor `a`
3501  # [[[ 1,  2,  3],
3502  #   [ 4,  5,  6]],
3503  #  [[ 7,  8,  9],
3504  #   [10, 11, 12]]]
3505  a = tf.constant(np.arange(1, 13, dtype=np.int32),
3506                  shape=[2, 2, 3])
3507
3508  # 2-D tensor `b`
3509  # [[13, 14, 15],
3510  #  [16, 17, 18]]
3511  b = tf.constant(np.arange(13, 19, dtype=np.int32),
3512                  shape=[2, 3])
3513
3514  # `a` * `b`
3515  # [[ 86, 212],
3516  #  [410, 563]]
3517  c = tf.linalg.matvec(a, b)
3518  ```
3519
3520  Args:
3521    a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`,
3522      `complex128` and rank > 1.
3523    b: `Tensor` with same type as `a` and compatible dimensions.
3524    transpose_a: If `True`, `a` is transposed before multiplication.
3525    adjoint_a: If `True`, `a` is conjugated and transposed before
3526      multiplication.
3527    a_is_sparse: If `True`, `a` is treated as a sparse matrix.
3528    b_is_sparse: If `True`, `b` is treated as a sparse matrix.
3529    name: Name for the operation (optional).
3530
3531  Returns:
3532    A `Tensor` of the same type as `a` and `b` where each inner-most vector is
3533    the product of the corresponding matrices in `a` and vectors in `b`, e.g. if
3534    all transpose or adjoint attributes are `False`:
3535
3536    `output`[..., i] = sum_k (`a`[..., i, k] * `b`[..., k]), for all indices i.
3537
3538    Note: This is matrix-vector product, not element-wise product.
3539
3540
3541  Raises:
3542    ValueError: If transpose_a and adjoint_a are both set to True.
3543  """
3544  with ops.name_scope(name, "MatVec", [a, b]) as name:
3545    output = matmul(
3546        a,
3547        array_ops.expand_dims(b, axis=-1),
3548        transpose_a=transpose_a,
3549        adjoint_a=adjoint_a,
3550        a_is_sparse=a_is_sparse,
3551        b_is_sparse=b_is_sparse)
3552    return array_ops.squeeze(output, axis=-1)
3553
3554
3555# TODO(b/178650720): Also support numpy-style type promotion in freestanding TF
3556#   functions (e.g. tf.add).
3557def matmul_wrapper(a, b, name=None):  # pylint: disable=missing-function-docstring
3558  if ops._numpy_style_type_promotion:
3559    return a._matmul(b)
3560  return matmul(a, b, name=name)
3561matmul_wrapper.__doc__ = matmul.__doc__
3562_OverrideBinaryOperatorHelper(matmul_wrapper, "matmul")
3563
3564sparse_matmul = deprecation.deprecated(None, "Use `tf.linalg.matmul` instead")(
3565    gen_math_ops.sparse_mat_mul)
3566tf_export(v1=["sparse_matmul"])(sparse_matmul)
3567@dispatch.add_dispatch_support
3568
3569
3570@ops.RegisterStatistics("MatMul", "flops")
3571def _calc_mat_mul_flops(graph, node):
3572  """Calculates the compute resources needed for MatMul."""
3573  transpose_a = node.attr["transpose_a"].b
3574  a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
3575  a_shape.assert_is_fully_defined()
3576  if transpose_a:
3577    k = int(a_shape[0])
3578  else:
3579    k = int(a_shape[1])
3580  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
3581  output_shape.assert_is_fully_defined()
3582  output_count = np.prod(output_shape.as_list())
3583  return ops.OpStats("flops", (k * output_count * 2))
3584
3585
3586@ops.RegisterStatistics("BatchMatMul", "flops")
3587@ops.RegisterStatistics("BatchMatMulV2", "flops")
3588def _calc_batch_mat_mul_flops(graph, node):
3589  """Calculates the compute resources needed for BatchMatMul."""
3590  transpose_a = node.attr["transpose_a"].b
3591  a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
3592  a_shape.assert_is_fully_defined()
3593  if transpose_a:
3594    k = int(a_shape[-2])
3595  else:
3596    k = int(a_shape[-1])
3597  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
3598  output_shape.assert_is_fully_defined()
3599  output_count = np.prod(output_shape.as_list())
3600  return ops.OpStats("flops", (k * output_count * 2))
3601
3602
3603def _as_indexed_slices(x, optimize=True):
3604  """Convert 'x' to IndexedSlices.
3605
3606  Convert a dense Tensor to a block-sparse IndexedSlices.
3607
3608  Args:
3609    x: Either a Tensor object, or an IndexedSlices object.
3610    optimize: if true, attempt to optimize the conversion of 'x'.
3611
3612  Returns:
3613    An IndexedSlices object.
3614
3615  Raises:
3616    TypeError: If 'x' is not a Tensor or an IndexedSlices object.
3617  """
3618  # TODO(touts): op_scope
3619  if not isinstance(x, (ops.Tensor, ops.IndexedSlices)):
3620    raise TypeError("Not a Tensor or IndexedSlices: %s" % type(x))
3621  if isinstance(x, ops.IndexedSlices):
3622    return x
3623  x_shape = array_ops.shape_internal(x, optimize=optimize)
3624  return ops.IndexedSlices(x, range(0, x_shape[0]), x_shape)
3625
3626
3627def _as_indexed_slices_list(inputs, optimize=True):
3628  """Convert all elements of 'inputs' to IndexedSlices.
3629
3630  Additionally, homogenize the types of all the indices to
3631  either int32 or int64.
3632
3633  Args:
3634    inputs: List containing either Tensor or IndexedSlices objects.
3635    optimize: if true, attempt to optimize the conversion of each input.
3636
3637  Returns:
3638    A list of IndexedSlices objects.
3639
3640  Raises:
3641    TypeError: If 'inputs' is not a list or a tuple.
3642  """
3643  if not isinstance(inputs, (list, tuple)):
3644    raise TypeError("Expected a list or tuple, not a %s" % type(inputs))
3645  outputs = [_as_indexed_slices(i, optimize=optimize) for i in inputs]
3646  with_int32_index = [
3647      o.indices for o in outputs if o.indices.dtype == dtypes.int32
3648  ]
3649  if not with_int32_index or len(with_int32_index) == len(outputs):
3650    return outputs
3651  casted_outputs = []
3652  for o in outputs:
3653    if o.indices.dtype == dtypes.int32:
3654      casted_outputs.append(
3655          ops.IndexedSlices(o.values, cast(o.indices, dtypes.int64),
3656                            o.dense_shape))
3657    else:
3658      casted_outputs.append(o)
3659  return casted_outputs
3660
3661
3662@tf_export("math.add_n", "add_n")
3663@dispatch.add_dispatch_support
3664def add_n(inputs, name=None):
3665  """Adds all input tensors element-wise.
3666
3667  `tf.math.add_n` performs the same operation as `tf.math.accumulate_n`, but it
3668  waits for all of its inputs to be ready before beginning to sum.
3669  This buffering can result in higher memory consumption when inputs are ready
3670  at different times, since the minimum temporary storage required is
3671  proportional to the input size rather than the output size.
3672
3673  This op does not [broadcast](
3674  https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html)
3675  its inputs. If you need broadcasting, use `tf.math.add` (or the `+` operator)
3676  instead.
3677
3678  For example:
3679
3680  >>> a = tf.constant([[3, 5], [4, 8]])
3681  >>> b = tf.constant([[1, 6], [2, 9]])
3682  >>> tf.math.add_n([a, b, a])
3683  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
3684  array([[ 7, 16],
3685         [10, 25]], dtype=int32)>
3686
3687  Args:
3688    inputs: A list of `tf.Tensor` or `tf.IndexedSlices` objects, each with the
3689      same shape and type. `tf.IndexedSlices` objects will be converted into
3690      dense tensors prior to adding.
3691    name: A name for the operation (optional).
3692
3693  Returns:
3694    A `tf.Tensor` of the same shape and type as the elements of `inputs`.
3695
3696  Raises:
3697    ValueError: If `inputs` don't all have same shape and dtype or the shape
3698    cannot be inferred.
3699  """
3700  if not inputs or not isinstance(inputs, collections_abc.Iterable):
3701    raise ValueError("inputs must be an iterable of at least one "
3702                     "Tensor/IndexedSlices with the same dtype and shape")
3703  inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
3704  if not all(isinstance(x, (ops.Tensor, ops.IndexedSlices)) for x in inputs):
3705    raise ValueError("inputs must be an iterable of at least one "
3706                     "Tensor/IndexedSlices with the same dtype and shape")
3707
3708  if len(inputs) == 1:
3709    if isinstance(inputs[0], ops.IndexedSlices):
3710      values = ops.convert_to_tensor(inputs[0])
3711    else:
3712      values = inputs[0]
3713    if name:
3714      return array_ops.identity(values, name=name)
3715    return values
3716  return gen_math_ops.add_n(inputs, name=name)
3717
3718
3719@tf_export("math.accumulate_n", v1=["math.accumulate_n", "accumulate_n"])
3720@dispatch.add_dispatch_support
3721@deprecation.deprecated_endpoints("accumulate_n")
3722def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
3723  """Returns the element-wise sum of a list of tensors.
3724
3725  Optionally, pass `shape` and `tensor_dtype` for shape and type checking,
3726  otherwise, these are inferred.
3727
3728  `accumulate_n` performs the same operation as `tf.math.add_n`.
3729
3730  For example:
3731
3732  ```python
3733  a = tf.constant([[1, 2], [3, 4]])
3734  b = tf.constant([[5, 0], [0, 6]])
3735  tf.math.accumulate_n([a, b, a])  # [[7, 4], [6, 14]]
3736
3737  # Explicitly pass shape and type
3738  tf.math.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)
3739                                                                 # [[7,  4],
3740                                                                 #  [6, 14]]
3741  ```
3742
3743  Args:
3744    inputs: A list of `Tensor` objects, each with same shape and type.
3745    shape: Expected shape of elements of `inputs` (optional). Also controls the
3746      output shape of this op, which may affect type inference in other ops. A
3747      value of `None` means "infer the input shape from the shapes in `inputs`".
3748    tensor_dtype: Expected data type of `inputs` (optional). A value of `None`
3749      means "infer the input dtype from `inputs[0]`".
3750    name: A name for the operation (optional).
3751
3752  Returns:
3753    A `Tensor` of same shape and type as the elements of `inputs`.
3754
3755  Raises:
3756    ValueError: If `inputs` don't all have same shape and dtype or the shape
3757    cannot be inferred.
3758  """
3759
3760  def _input_error():
3761    return ValueError("inputs must be a list of at least one Tensor with the "
3762                      "same dtype and shape")
3763
3764  if not inputs or not isinstance(inputs, (list, tuple)):
3765    raise _input_error()
3766  inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
3767  if not all(isinstance(x, ops.Tensor) for x in inputs):
3768    raise _input_error()
3769  if not all(x.dtype == inputs[0].dtype for x in inputs):
3770    raise _input_error()
3771  if shape is not None:
3772    shape = tensor_shape.as_shape(shape)
3773  else:
3774    shape = tensor_shape.unknown_shape()
3775  for input_tensor in inputs:
3776    if isinstance(input_tensor, ops.Tensor):
3777      shape = shape.merge_with(input_tensor.get_shape())
3778
3779  # tensor_dtype is for safety only; operator's output type computed in C++
3780  if tensor_dtype is not None and tensor_dtype != inputs[0].dtype:
3781    raise TypeError("tensor_dtype is {}, but input is of type {}".format(
3782        tensor_dtype, inputs[0].dtype))
3783
3784  if len(inputs) == 1 and name is None:
3785    return inputs[0]
3786  elif len(inputs) == 1 and name is not None:
3787    return array_ops.identity(inputs[0], name=name)
3788  return add_n(inputs, name=name)
3789
3790
3791@ops.RegisterGradient("AccumulateNV2")
3792def _accumulate_n_grad(op, grad):
3793  """Same as gradient for AddN. Copies the gradient to all inputs."""
3794  # Not broadcasting.
3795  return [grad] * len(op.inputs)
3796
3797
3798@tf_export("math.sigmoid", "nn.sigmoid", "sigmoid")
3799@dispatch.add_dispatch_support
3800def sigmoid(x, name=None):
3801  r"""Computes sigmoid of `x` element-wise.
3802
3803  Formula for calculating $\mathrm{sigmoid}(x) = y = 1 / (1 + \exp(-x))$.
3804
3805  For $x \in (-\infty, \infty)$, $\mathrm{sigmoid}(x) \in (0, 1)$.
3806
3807  Example Usage:
3808
3809  If a positive number is large, then its sigmoid will approach to 1 since the
3810  formula will be `y = <large_num> / (1 + <large_num>)`
3811
3812  >>> x = tf.constant([0.0, 1.0, 50.0, 100.0])
3813  >>> tf.math.sigmoid(x)
3814  <tf.Tensor: shape=(4,), dtype=float32,
3815  numpy=array([0.5      , 0.7310586, 1.       , 1.       ], dtype=float32)>
3816
3817  If a negative number is large, its sigmoid will approach to 0 since the
3818  formula will be `y = 1 / (1 + <large_num>)`
3819
3820  >>> x = tf.constant([-100.0, -50.0, -1.0, 0.0])
3821  >>> tf.math.sigmoid(x)
3822  <tf.Tensor: shape=(4,), dtype=float32, numpy=
3823  array([0.0000000e+00, 1.9287499e-22, 2.6894143e-01, 0.5],
3824        dtype=float32)>
3825
3826  Args:
3827    x: A Tensor with type `float16`, `float32`, `float64`, `complex64`, or
3828      `complex128`.
3829    name: A name for the operation (optional).
3830
3831  Returns:
3832    A Tensor with the same type as `x`.
3833
3834  Usage Example:
3835
3836  >>> x = tf.constant([-128.0, 0.0, 128.0], dtype=tf.float32)
3837  >>> tf.sigmoid(x)
3838  <tf.Tensor: shape=(3,), dtype=float32,
3839  numpy=array([0. , 0.5, 1. ], dtype=float32)>
3840
3841  @compatibility(scipy)
3842  Equivalent to scipy.special.expit
3843  @end_compatibility
3844  """
3845  with ops.name_scope(name, "Sigmoid", [x]) as name:
3846    x = ops.convert_to_tensor(x, name="x")
3847    return gen_math_ops.sigmoid(x, name=name)
3848
3849
3850@tf_export("math.log_sigmoid", v1=["math.log_sigmoid", "log_sigmoid"])
3851@dispatch.add_dispatch_support
3852@deprecation.deprecated_endpoints("log_sigmoid")
3853def log_sigmoid(x, name=None):
3854  """Computes log sigmoid of `x` element-wise.
3855
3856  Specifically, `y = log(1 / (1 + exp(-x)))`.  For numerical stability,
3857  we use `y = -tf.nn.softplus(-x)`.
3858
3859  Args:
3860    x: A Tensor with type `float32` or `float64`.
3861    name: A name for the operation (optional).
3862
3863  Returns:
3864    A Tensor with the same type as `x`.
3865
3866  Usage Example:
3867
3868  If a positive number is large, then its log_sigmoid will approach to 0 since
3869  the formula will be `y = log( <large_num> / (1 + <large_num>) )` which
3870  approximates to `log (1)` which is 0.
3871
3872  >>> x = tf.constant([0.0, 1.0, 50.0, 100.0])
3873  >>> tf.math.log_sigmoid(x)
3874  <tf.Tensor: shape=(4,), dtype=float32, numpy=
3875  array([-6.9314718e-01, -3.1326169e-01, -1.9287499e-22, -0.0000000e+00],
3876        dtype=float32)>
3877
3878  If a negative number is large, its log_sigmoid will approach to the number
3879  itself since the formula will be `y = log( 1 / (1 + <large_num>) )` which is
3880  `log (1) - log ( (1 + <large_num>) )` which approximates to `- <large_num>`
3881  that is the number itself.
3882
3883  >>> x = tf.constant([-100.0, -50.0, -1.0, 0.0])
3884  >>> tf.math.log_sigmoid(x)
3885  <tf.Tensor: shape=(4,), dtype=float32, numpy=
3886  array([-100.       ,  -50.       ,   -1.3132616,   -0.6931472],
3887        dtype=float32)>
3888  """
3889  with ops.name_scope(name, "LogSigmoid", [x]) as name:
3890    x = ops.convert_to_tensor(x, name="x")
3891    return gen_math_ops.neg(gen_nn_ops.softplus(-x), name=name)
3892
3893
3894@tf_export("math.cumsum", "cumsum")
3895@dispatch.add_dispatch_support
3896def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
3897  """Compute the cumulative sum of the tensor `x` along `axis`.
3898
3899  By default, this op performs an inclusive cumsum, which means that the first
3900  element of the input is identical to the first element of the output:
3901  For example:
3902
3903  >>> # tf.cumsum([a, b, c])   # [a, a + b, a + b + c]
3904  >>> x = tf.constant([2, 4, 6, 8])
3905  >>> tf.cumsum(x)
3906  <tf.Tensor: shape=(4,), dtype=int32,
3907  numpy=array([ 2,  6, 12, 20], dtype=int32)>
3908
3909  >>> # using varying `axis` values
3910  >>> y = tf.constant([[2, 4, 6, 8], [1,3,5,7]])
3911  >>> tf.cumsum(y, axis=0)
3912  <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
3913  array([[ 2,  4,  6,  8],
3914         [ 3,  7, 11, 15]], dtype=int32)>
3915  >>> tf.cumsum(y, axis=1)
3916  <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
3917  array([[ 2,  6, 12, 20],
3918         [ 1,  4,  9, 16]], dtype=int32)>
3919
3920  By setting the `exclusive` kwarg to `True`, an exclusive cumsum is performed
3921  instead:
3922
3923  >>> # tf.cumsum([a, b, c], exclusive=True)  => [0, a, a + b]
3924  >>> x = tf.constant([2, 4, 6, 8])
3925  >>> tf.cumsum(x, exclusive=True)
3926  <tf.Tensor: shape=(4,), dtype=int32,
3927  numpy=array([ 0,  2,  6, 12], dtype=int32)>
3928
3929  By setting the `reverse` kwarg to `True`, the cumsum is performed in the
3930  opposite direction:
3931
3932  >>> # tf.cumsum([a, b, c], reverse=True)  # [a + b + c, b + c, c]
3933  >>> x = tf.constant([2, 4, 6, 8])
3934  >>> tf.cumsum(x, reverse=True)
3935  <tf.Tensor: shape=(4,), dtype=int32,
3936  numpy=array([20, 18, 14,  8], dtype=int32)>
3937
3938  This is more efficient than using separate `tf.reverse` ops.
3939  The `reverse` and `exclusive` kwargs can also be combined:
3940
3941  >>> # tf.cumsum([a, b, c], exclusive=True, reverse=True)  # [b + c, c, 0]
3942  >>> x = tf.constant([2, 4, 6, 8])
3943  >>> tf.cumsum(x, exclusive=True, reverse=True)
3944  <tf.Tensor: shape=(4,), dtype=int32,
3945  numpy=array([18, 14,  8,  0], dtype=int32)>
3946
3947  Args:
3948    x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
3949      `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
3950      `complex128`, `qint8`, `quint8`, `qint32`, `half`.
3951    axis: A `Tensor` of type `int32` (default: 0). Must be in the range
3952      `[-rank(x), rank(x))`.
3953    exclusive: If `True`, perform exclusive cumsum.
3954    reverse: A `bool` (default: False).
3955    name: A name for the operation (optional).
3956
3957  Returns:
3958    A `Tensor`. Has the same type as `x`.
3959  """
3960  with ops.name_scope(name, "Cumsum", [x]) as name:
3961    x = ops.convert_to_tensor(x, name="x")
3962    return gen_math_ops.cumsum(
3963        x, axis, exclusive=exclusive, reverse=reverse, name=name)
3964
3965
3966@tf_export("math.cumprod", v1=["math.cumprod", "cumprod"])
3967@dispatch.add_dispatch_support
3968@deprecation.deprecated_endpoints("cumprod")
3969def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
3970  """Compute the cumulative product of the tensor `x` along `axis`.
3971
3972  By default, this op performs an inclusive cumprod, which means that the
3973  first element of the input is identical to the first element of the output:
3974
3975  ```python
3976  tf.math.cumprod([a, b, c])  # [a, a * b, a * b * c]
3977  ```
3978
3979  By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
3980  performed
3981  instead:
3982
3983  ```python
3984  tf.math.cumprod([a, b, c], exclusive=True)  # [1, a, a * b]
3985  ```
3986
3987  By setting the `reverse` kwarg to `True`, the cumprod is performed in the
3988  opposite direction:
3989
3990  ```python
3991  tf.math.cumprod([a, b, c], reverse=True)  # [a * b * c, b * c, c]
3992  ```
3993
3994  This is more efficient than using separate `tf.reverse` ops.
3995  The `reverse` and `exclusive` kwargs can also be combined:
3996
3997  ```python
3998  tf.math.cumprod([a, b, c], exclusive=True, reverse=True)  # [b * c, c, 1]
3999  ```
4000
4001  Args:
4002    x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
4003      `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
4004      `complex128`, `qint8`, `quint8`, `qint32`, `half`.
4005    axis: A `Tensor` of type `int32` (default: 0). Must be in the range
4006      `[-rank(x), rank(x))`.
4007    exclusive: If `True`, perform exclusive cumprod.
4008    reverse: A `bool` (default: False).
4009    name: A name for the operation (optional).
4010
4011  Returns:
4012    A `Tensor`. Has the same type as `x`.
4013  """
4014  with ops.name_scope(name, "Cumprod", [x]) as name:
4015    x = ops.convert_to_tensor(x, name="x")
4016    return gen_math_ops.cumprod(
4017        x, axis, exclusive=exclusive, reverse=reverse, name=name)
4018
4019
4020@tf_export("math.cumulative_logsumexp", v1=["math.cumulative_logsumexp"])
4021@dispatch.add_dispatch_support
4022def cumulative_logsumexp(x, axis=0, exclusive=False, reverse=False, name=None):
4023  """Compute the cumulative log-sum-exp of the tensor `x` along `axis`.
4024
4025  By default, this op performs an inclusive cumulative log-sum-exp, which means
4026  that the first element of the input is identical to the first element of
4027  the output.
4028
4029  This operation is significantly more numerically stable than the equivalent
4030  tensorflow operation `tf.math.log(tf.math.cumsum(tf.math.exp(x)))`, although
4031  computes the same result given infinite numerical precision. However, note
4032  that in some cases, it may be less stable than `tf.math.reduce_logsumexp`
4033  for a given element, as it applies the "log-sum-exp trick" in a different
4034  way.
4035
4036  More precisely, where `tf.math.reduce_logsumexp` uses the following trick:
4037
4038  ```
4039  log(sum(exp(x))) == log(sum(exp(x - max(x)))) + max(x)
4040  ```
4041
4042  it cannot be directly used here as there is no fast way of applying it
4043  to each prefix `x[:i]`. Instead, this function implements a prefix
4044  scan using pairwise log-add-exp, which is a commutative and associative
4045  (up to floating point precision) operator:
4046
4047  ```
4048  log_add_exp(x, y) = log(exp(x) + exp(y))
4049                    = log(1 + exp(min(x, y) - max(x, y))) + max(x, y)
4050  ```
4051
4052  However, reducing using the above operator leads to a different computation
4053  tree (logs are taken repeatedly instead of only at the end), and the maximum
4054  is only computed pairwise instead of over the entire prefix. In general, this
4055  leads to a different and slightly less precise computation.
4056
4057  Args:
4058    x: A `Tensor`. Must be one of the following types: `float16`, `float32`,
4059      `float64`.
4060    axis: A `Tensor` of type `int32` or `int64` (default: 0). Must be in the
4061      range `[-rank(x), rank(x))`.
4062    exclusive: If `True`, perform exclusive cumulative log-sum-exp.
4063    reverse: If `True`, performs the cumulative log-sum-exp in the reverse
4064      direction.
4065    name: A name for the operation (optional).
4066
4067  Returns:
4068    A `Tensor`. Has the same shape and type as `x`.
4069  """
4070  with ops.name_scope(name, "CumulativeLogsumexp", [x]) as name:
4071    x = ops.convert_to_tensor(x, name="x")
4072    return gen_math_ops.cumulative_logsumexp(
4073        x, axis, exclusive=exclusive, reverse=reverse, name=name)
4074
4075
4076@tf_export("math.conj", v1=["math.conj", "conj"])
4077@dispatch.add_dispatch_support
4078@deprecation.deprecated_endpoints("conj")
4079def conj(x, name=None):
4080  r"""Returns the complex conjugate of a complex number.
4081
4082  Given a tensor `x` of complex numbers, this operation returns a tensor of
4083  complex numbers that are the complex conjugate of each element in `x`. The
4084  complex numbers in `x` must be of the form \\(a + bj\\), where `a` is the
4085  real part and `b` is the imaginary part.
4086
4087  The complex conjugate returned by this operation is of the form \\(a - bj\\).
4088
4089  For example:
4090
4091  >>> x = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j])
4092  >>> tf.math.conj(x)
4093  <tf.Tensor: shape=(2,), dtype=complex128,
4094  numpy=array([-2.25-4.75j,  3.25-5.75j])>
4095
4096  If `x` is real, it is returned unchanged.
4097
4098  For example:
4099
4100  >>> x = tf.constant([-2.25, 3.25])
4101  >>> tf.math.conj(x)
4102  <tf.Tensor: shape=(2,), dtype=float32,
4103  numpy=array([-2.25,  3.25], dtype=float32)>
4104
4105  Args:
4106    x: `Tensor` to conjugate.  Must have numeric or variant type.
4107    name: A name for the operation (optional).
4108
4109  Returns:
4110    A `Tensor` that is the conjugate of `x` (with the same type).
4111
4112  Raises:
4113    TypeError: If `x` is not a numeric tensor.
4114
4115  @compatibility(numpy)
4116  Equivalent to numpy.conj.
4117  @end_compatibility
4118  """
4119  if isinstance(x, ops.Tensor):
4120    dt = x.dtype
4121    if dt.is_floating or dt.is_integer:
4122      return x
4123  with ops.name_scope(name, "Conj", [x]) as name:
4124    x = ops.convert_to_tensor(x, name="x")
4125    if x.dtype.is_complex or x.dtype == dtypes.variant:
4126      return gen_math_ops.conj(x, name=name)
4127    elif x.dtype.is_floating or x.dtype.is_integer:
4128      return x
4129    else:
4130      raise TypeError("Expected numeric or variant tensor, got dtype %r" %
4131                      x.dtype)
4132
4133
4134def reduced_shape(input_shape, axes):
4135  """Helper function for reduction ops.
4136
4137  Args:
4138    input_shape: 1-D Tensor, the shape of the Tensor being reduced.
4139    axes: 1-D Tensor, the reduction axes.
4140
4141  Returns:
4142    A 1-D Tensor, the output shape as if keepdims were set to True.
4143  """
4144  # TODO(allenl): Refactor `reduced_shape` to take the tensor corresponding to
4145  # `input_shape` rather than `tf.shape` of it. Then we can check if the shape
4146  # is fully defined here, which may be faster executing eagerly than running
4147  # `tf.shape` and then fetching its constant value.
4148  constant_input_shape = tensor_util.constant_value(input_shape)
4149  if constant_input_shape is not None:
4150    constant_axes = tensor_util.constant_value(axes)
4151    if constant_axes is not None:
4152      constant_axes = np.array(constant_axes, dtype=np.int32)
4153      constant_input_shape = np.array(constant_input_shape, dtype=np.int32)
4154      constant_input_shape[constant_axes] = 1
4155      return constant_input_shape
4156
4157  # Example:
4158  # cast needed for SparseTensor reductions
4159  input_shape = cast(input_shape, dtypes.int32)  # [2, 3, 5, 7]
4160  axes = cast(axes, dtypes.int32)  # [1, 2]
4161
4162  input_rank = array_ops.size(input_shape)  # 4
4163  axes = (axes + input_rank) % input_rank
4164  axes_shape = array_ops.shape(axes)  # [2]
4165  return gen_data_flow_ops.dynamic_stitch(  # [2, 1, 1, 7]
4166      [
4167          range(input_rank),  # [0, 1, 2, 3]
4168          axes
4169      ],  # [1, 2]
4170      [
4171          input_shape,  # [2, 3, 5, 7]
4172          array_ops.fill(axes_shape, 1)
4173      ])  # [1, 1]
4174
4175
4176def _unsorted_segment_N(data, segment_ids, num_segments):
4177  """ Helper function for unsorted_segment_mean/_sqrtN.
4178
4179  Computes the number
4180      of segment entries with 0-entries set to 1 to allow division by N.
4181  """
4182  num_segments = ops.convert_to_tensor(num_segments)
4183  # bincount doesn't support negative indices so we use unsorted_segment_sum
4184  segment_ids_shape = array_ops.shape_internal(segment_ids)
4185  ones_tensor = array_ops.ones(segment_ids_shape, dtype=data.dtype)
4186  n = gen_math_ops.unsorted_segment_sum(ones_tensor, segment_ids, num_segments)
4187  # add dimensions for all non-reduced axes
4188  broadcastable_shape = array_ops.concat(
4189      [num_segments[array_ops.newaxis],
4190       array_ops.ones([array_ops.rank(data)
4191                       - array_ops.rank(segment_ids)],
4192                      dtype=num_segments.dtype)],
4193      axis=0)
4194  n = array_ops.reshape(n, broadcastable_shape)
4195  return gen_math_ops.maximum(n, 1)
4196
4197
4198@tf_export(
4199    "math.unsorted_segment_mean",
4200    v1=["math.unsorted_segment_mean", "unsorted_segment_mean"])
4201@dispatch.add_dispatch_support
4202@deprecation.deprecated_endpoints("unsorted_segment_mean")
4203@dispatch.add_dispatch_support
4204def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
4205  r"""Computes the mean along segments of a tensor.
4206
4207  Read [the section on
4208  segmentation](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math#about_segmentation)
4209  for an explanation of segments.
4210
4211  This operator is similar to the `tf.math.unsorted_segment_sum` operator.
4212  Instead of computing the sum over segments, it computes the mean of all
4213  entries belonging to a segment such that:
4214
4215  \\(output_i = 1/N_i \sum_{j...} data[j...]\\) where the sum is over tuples
4216  `j...` such that `segment_ids[j...] == i` with \\N_i\\ being the number of
4217  occurrences of id \\i\\.
4218
4219  If there is no entry for a given segment ID `i`, it outputs 0.
4220
4221  If the given segment ID `i` is negative, the value is dropped and will not
4222  be added to the sum of the segment.
4223
4224  Args:
4225    data: A `Tensor` with floating point or complex dtype.
4226    segment_ids: An integer tensor whose shape is a prefix of `data.shape`.
4227    num_segments: An integer scalar `Tensor`.  The number of distinct segment
4228      IDs.
4229    name: A name for the operation (optional).
4230
4231  Returns:
4232    A `Tensor`.  Has same shape as data, except for the first `segment_ids.rank`
4233    dimensions, which are replaced with a single dimension which has size
4234   `num_segments`.
4235  """
4236  with ops.name_scope(name, "UnsortedSegmentMean"):
4237    data = ops.convert_to_tensor(data)
4238    segment_ids = ops.convert_to_tensor(segment_ids)
4239    N = _unsorted_segment_N(data, segment_ids, num_segments)
4240    summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
4241    return summed / N
4242
4243
4244@tf_export(
4245    "math.unsorted_segment_sqrt_n",
4246    v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"])
4247@dispatch.add_dispatch_support
4248@deprecation.deprecated_endpoints("unsorted_segment_sqrt_n")
4249@dispatch.add_dispatch_support
4250def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
4251  r"""Computes the sum along segments of a tensor divided by the sqrt(N).
4252
4253  Read [the section on
4254  segmentation](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math#about_segmentation)
4255  for an explanation of segments.
4256
4257  This operator is similar to the `tf.math.unsorted_segment_sum` operator.
4258  Additionally to computing the sum over segments, it divides the results by
4259  sqrt(N).
4260
4261  \\(output_i = 1/sqrt(N_i) \sum_{j...} data[j...]\\) where the sum is over
4262  tuples `j...` such that `segment_ids[j...] == i` with \\N_i\\ being the
4263  number of occurrences of id \\i\\.
4264
4265  If there is no entry for a given segment ID `i`, it outputs 0.
4266
4267  Note that this op only supports floating point and complex dtypes,
4268  due to tf.sqrt only supporting these types.
4269
4270  If the given segment ID `i` is negative, the value is dropped and will not
4271  be added to the sum of the segment.
4272
4273  Args:
4274    data: A `Tensor` with floating point or complex dtype.
4275    segment_ids: An integer tensor whose shape is a prefix of `data.shape`.
4276    num_segments: An integer scalar `Tensor`.  The number of distinct segment
4277      IDs.
4278    name: A name for the operation (optional).
4279
4280  Returns:
4281    A `Tensor`.  Has same shape as data, except for the first `segment_ids.rank`
4282    dimensions, which are replaced with a single dimension which has size
4283   `num_segments`.
4284  """
4285  with ops.name_scope(name, "UnsortedSegmentSqrtN"):
4286    data = ops.convert_to_tensor(data)
4287    segment_ids = ops.convert_to_tensor(segment_ids)
4288    N = _unsorted_segment_N(data, segment_ids, num_segments)
4289    summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
4290    return summed / gen_math_ops.sqrt(N)
4291
4292
4293@tf_export(v1=["sparse.segment_sum", "sparse_segment_sum"])
4294@deprecation.deprecated_endpoints("sparse_segment_sum")
4295def sparse_segment_sum(data,
4296                       indices,
4297                       segment_ids,
4298                       name=None,
4299                       num_segments=None):
4300  r"""Computes the sum along sparse segments of a tensor.
4301
4302  Read [the section on
4303  segmentation](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math#about_segmentation)
4304  for an explanation of segments.
4305
4306  Like `tf.math.segment_sum`, but `segment_ids` can have rank less than `data`'s
4307  first dimension, selecting a subset of dimension 0, specified by `indices`.
4308  `segment_ids` is allowed to have missing ids, in which case the output will
4309  be zeros at those indices. In those cases `num_segments` is used to determine
4310  the size of the output.
4311
4312  For example:
4313
4314  ```python
4315  c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
4316
4317  # Select two rows, one segment.
4318  tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
4319  # => [[0 0 0 0]]
4320
4321  # Select two rows, two segment.
4322  tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
4323  # => [[ 1  2  3  4]
4324  #     [-1 -2 -3 -4]]
4325
4326  # With missing segment ids.
4327  tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 2]),
4328                        num_segments=4)
4329  # => [[ 1  2  3  4]
4330  #     [ 0  0  0  0]
4331  #     [-1 -2 -3 -4]
4332  #     [ 0  0  0  0]]
4333
4334  # Select all rows, two segments.
4335  tf.sparse.segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
4336  # => [[0 0 0 0]
4337  #     [5 6 7 8]]
4338
4339  # Which is equivalent to:
4340  tf.math.segment_sum(c, tf.constant([0, 0, 1]))
4341  ```
4342
4343  Args:
4344    data: A `Tensor` with data that will be assembled in the output.
4345    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
4346      `segment_ids`.
4347    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`. Values
4348      should be sorted and can be repeated.
4349    name: A name for the operation (optional).
4350    num_segments: An optional int32 scalar. Indicates the size of the output
4351      `Tensor`.
4352
4353  Returns:
4354    A `tensor` of the shape as data, except for dimension 0 which
4355    has size `k`, the number of segments specified via `num_segments` or
4356    inferred for the last element in `segments_ids`.
4357  """
4358  if num_segments is not None:
4359    return gen_math_ops.sparse_segment_sum_with_num_segments(
4360        data=data,
4361        indices=indices,
4362        segment_ids=segment_ids,
4363        num_segments=num_segments,
4364        name=name)
4365  else:
4366    return gen_math_ops.sparse_segment_sum(
4367        data=data, indices=indices, segment_ids=segment_ids, name=name)
4368
4369
4370@tf_export("sparse.segment_sum", v1=[])
4371def sparse_segment_sum_v2(data,
4372                          indices,
4373                          segment_ids,
4374                          num_segments=None,
4375                          name=None):
4376  r"""Computes the sum along sparse segments of a tensor.
4377
4378  Read [the section on
4379  segmentation](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math#about_segmentation)
4380  for an explanation of segments.
4381
4382  Like `tf.math.segment_sum`, but `segment_ids` can have rank less than `data`'s
4383  first dimension, selecting a subset of dimension 0, specified by `indices`.
4384  `segment_ids` is allowed to have missing ids, in which case the output will
4385  be zeros at those indices. In those cases `num_segments` is used to determine
4386  the size of the output.
4387
4388  For example:
4389
4390  ```python
4391  c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
4392
4393  # Select two rows, one segment.
4394  tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
4395  # => [[0 0 0 0]]
4396
4397  # Select two rows, two segment.
4398  tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
4399  # => [[ 1  2  3  4]
4400  #     [-1 -2 -3 -4]]
4401
4402  # With missing segment ids.
4403  tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 2]),
4404                        num_segments=4)
4405  # => [[ 1  2  3  4]
4406  #     [ 0  0  0  0]
4407  #     [-1 -2 -3 -4]
4408  #     [ 0  0  0  0]]
4409
4410  # Select all rows, two segments.
4411  tf.sparse.segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
4412  # => [[0 0 0 0]
4413  #     [5 6 7 8]]
4414
4415  # Which is equivalent to:
4416  tf.math.segment_sum(c, tf.constant([0, 0, 1]))
4417  ```
4418
4419  Args:
4420    data: A `Tensor` with data that will be assembled in the output.
4421    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
4422      `segment_ids`.
4423    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`. Values
4424      should be sorted and can be repeated.
4425    num_segments: An optional int32 scalar. Indicates the size of the output
4426      `Tensor`.
4427    name: A name for the operation (optional).
4428
4429  Returns:
4430    A `tensor` of the shape as data, except for dimension 0 which
4431    has size `k`, the number of segments specified via `num_segments` or
4432    inferred for the last element in `segments_ids`.
4433  """
4434  return sparse_segment_sum(
4435      data, indices, segment_ids, name=name, num_segments=num_segments)
4436
4437
4438@tf_export(v1=["sparse.segment_mean", "sparse_segment_mean"])
4439@deprecation.deprecated_endpoints("sparse_segment_mean")
4440def sparse_segment_mean(data,
4441                        indices,
4442                        segment_ids,
4443                        name=None,
4444                        num_segments=None):
4445  r"""Computes the mean along sparse segments of a tensor.
4446
4447  Read [the section on
4448  segmentation](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math#about_segmentation)
4449  for an explanation of segments.
4450
4451  Like `tf.math.segment_mean`, but `segment_ids` can have rank less than
4452  `data`'s first dimension, selecting a subset of dimension 0, specified by
4453  `indices`.
4454  `segment_ids` is allowed to have missing ids, in which case the output will
4455  be zeros at those indices. In those cases `num_segments` is used to determine
4456  the size of the output.
4457
4458  Args:
4459    data: A `Tensor` with data that will be assembled in the output.
4460    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
4461      `segment_ids`.
4462    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`. Values
4463      should be sorted and can be repeated.
4464    name: A name for the operation (optional).
4465    num_segments: An optional int32 scalar. Indicates the size of the output
4466      `Tensor`.
4467
4468  Returns:
4469    A `tensor` of the shape as data, except for dimension 0 which
4470    has size `k`, the number of segments specified via `num_segments` or
4471    inferred for the last element in `segments_ids`.
4472  """
4473  if num_segments is not None:
4474    return gen_math_ops.sparse_segment_mean_with_num_segments(
4475        data=data,
4476        indices=indices,
4477        segment_ids=segment_ids,
4478        num_segments=num_segments,
4479        name=name)
4480  else:
4481    return gen_math_ops.sparse_segment_mean(
4482        data=data, indices=indices, segment_ids=segment_ids, name=name)
4483
4484
4485@tf_export("sparse.segment_mean", v1=[])
4486def sparse_segment_mean_v2(data,
4487                           indices,
4488                           segment_ids,
4489                           num_segments=None,
4490                           name=None):
4491  r"""Computes the mean along sparse segments of a tensor.
4492
4493  Read [the section on
4494  segmentation](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math#about_segmentation)
4495  for an explanation of segments.
4496
4497  Like `tf.math.segment_mean`, but `segment_ids` can have rank less than
4498  `data`'s first dimension, selecting a subset of dimension 0, specified by
4499  `indices`.
4500  `segment_ids` is allowed to have missing ids, in which case the output will
4501  be zeros at those indices. In those cases `num_segments` is used to determine
4502  the size of the output.
4503
4504  Args:
4505    data: A `Tensor` with data that will be assembled in the output.
4506    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
4507      `segment_ids`.
4508    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`. Values
4509      should be sorted and can be repeated.
4510    num_segments: An optional int32 scalar. Indicates the size of the output
4511      `Tensor`.
4512    name: A name for the operation (optional).
4513
4514  Returns:
4515    A `tensor` of the shape as data, except for dimension 0 which
4516    has size `k`, the number of segments specified via `num_segments` or
4517    inferred for the last element in `segments_ids`.
4518  """
4519  return sparse_segment_mean(
4520      data, indices, segment_ids, name=name, num_segments=num_segments)
4521
4522
4523@tf_export(v1=["sparse.segment_sqrt_n", "sparse_segment_sqrt_n"])
4524@deprecation.deprecated_endpoints("sparse_segment_sqrt_n")
4525def sparse_segment_sqrt_n(data,
4526                          indices,
4527                          segment_ids,
4528                          name=None,
4529                          num_segments=None):
4530  r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
4531
4532  `N` is the size of the segment being reduced.
4533
4534  Args:
4535    data: A `Tensor` with data that will be assembled in the output.
4536    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
4537      `segment_ids`.
4538    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`. Values
4539      should be sorted and can be repeated.
4540    name: A name for the operation (optional).
4541    num_segments: An optional int32 scalar. Indicates the size of the output
4542      `Tensor`.
4543
4544  Returns:
4545    A `tensor` of the shape as data, except for dimension 0 which
4546    has size `k`, the number of segments specified via `num_segments` or
4547    inferred for the last element in `segments_ids`.
4548  """
4549  if num_segments is not None:
4550    return gen_math_ops.sparse_segment_sqrt_n_with_num_segments(
4551        data=data,
4552        indices=indices,
4553        segment_ids=segment_ids,
4554        num_segments=num_segments,
4555        name=name)
4556  else:
4557    return gen_math_ops.sparse_segment_sqrt_n(
4558        data=data, indices=indices, segment_ids=segment_ids, name=name)
4559
4560
4561@tf_export("sparse.segment_sqrt_n", v1=[])
4562def sparse_segment_sqrt_n_v2(data,
4563                             indices,
4564                             segment_ids,
4565                             num_segments=None,
4566                             name=None):
4567  r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
4568
4569  Read [the section on
4570  segmentation](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/math#about_segmentation)
4571  for an explanation of segments.
4572
4573  Like `tf.sparse.segment_mean`, but instead of dividing by the size of the
4574  segment, `N`, divide by `sqrt(N)` instead.
4575
4576  Args:
4577    data: A `Tensor` with data that will be assembled in the output.
4578    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
4579      `segment_ids`.
4580    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`. Values
4581      should be sorted and can be repeated.
4582    num_segments: An optional int32 scalar. Indicates the size of the output
4583      `Tensor`.
4584    name: A name for the operation (optional).
4585
4586  Returns:
4587    A `tensor` of the shape as data, except for dimension 0 which
4588    has size `k`, the number of segments specified via `num_segments` or
4589    inferred for the last element in `segments_ids`.
4590  """
4591  return sparse_segment_sqrt_n(
4592      data, indices, segment_ids, name=name, num_segments=num_segments)
4593
4594
4595@tf_export("tensordot", "linalg.tensordot")
4596@dispatch.add_dispatch_support
4597def tensordot(a, b, axes, name=None):
4598  r"""Tensor contraction of a and b along specified axes and outer product.
4599
4600  Tensordot (also known as tensor contraction) sums the product of elements
4601  from `a` and `b` over the indices specified by `a_axes` and `b_axes`.
4602  The lists `a_axes` and `b_axes` specify those pairs of axes along which to
4603  contract the tensors. The axis `a_axes[i]` of `a` must have the same dimension
4604  as axis `b_axes[i]` of `b` for all `i` in `range(0, len(a_axes))`. The lists
4605  `a_axes` and `b_axes` must have identical length and consist of unique
4606  integers that specify valid axes for each of the tensors. Additionally
4607  outer product is supported by passing `axes=0`.
4608
4609  This operation corresponds to `numpy.tensordot(a, b, axes)`.
4610
4611  Example 1: When `a` and `b` are matrices (order 2), the case `axes = 1`
4612  is equivalent to matrix multiplication.
4613
4614  Example 2: When `a` and `b` are matrices (order 2), the case
4615  `axes = [[1], [0]]` is equivalent to matrix multiplication.
4616
4617  Example 3: When `a` and `b` are matrices (order 2), the case `axes=0` gives
4618  the outer product, a tensor of order 4.
4619
4620  Example 4: Suppose that \\(a_{ijk}\\) and \\(b_{lmn}\\) represent two
4621  tensors of order 3. Then, `contract(a, b, [[0], [2]])` is the order 4 tensor
4622  \\(c_{jklm}\\) whose entry
4623  corresponding to the indices \\((j,k,l,m)\\) is given by:
4624
4625  \\( c_{jklm} = \sum_i a_{ijk} b_{lmi} \\).
4626
4627  In general, `order(c) = order(a) + order(b) - 2*len(axes[0])`.
4628
4629  Args:
4630    a: `Tensor` of type `float32` or `float64`.
4631    b: `Tensor` with the same type as `a`.
4632    axes: Either a scalar `N`, or a list or an `int32` `Tensor` of shape [2, k].
4633      If axes is a scalar, sum over the last N axes of a and the first N axes of
4634      b in order. If axes is a list or `Tensor` the first and second row contain
4635      the set of unique integers specifying axes along which the contraction is
4636      computed, for `a` and `b`, respectively. The number of axes for `a` and
4637      `b` must be equal. If `axes=0`, computes the outer product between `a` and
4638      `b`.
4639    name: A name for the operation (optional).
4640
4641  Returns:
4642    A `Tensor` with the same type as `a`.
4643
4644  Raises:
4645    ValueError: If the shapes of `a`, `b`, and `axes` are incompatible.
4646    IndexError: If the values in axes exceed the rank of the corresponding
4647      tensor.
4648  """
4649
4650  def _tensordot_reshape(a, axes, flipped=False):
4651    """Helper method to perform transpose and reshape for contraction op.
4652
4653    This method is helpful in reducing `math_ops.tensordot` to `math_ops.matmul`
4654    using `array_ops.transpose` and `array_ops.reshape`. The method takes a
4655    tensor and performs the correct transpose and reshape operation for a given
4656    set of indices. It returns the reshaped tensor as well as a list of indices
4657    necessary to reshape the tensor again after matrix multiplication.
4658
4659    Args:
4660      a: `Tensor`.
4661      axes: List or `int32` `Tensor` of unique indices specifying valid axes of
4662        `a`.
4663      flipped: An optional `bool`. Defaults to `False`. If `True`, the method
4664        assumes that `a` is the second argument in the contraction operation.
4665
4666    Returns:
4667      A tuple `(reshaped_a, free_dims, free_dims_static)` where `reshaped_a` is
4668      the tensor `a` reshaped to allow contraction via `matmul`, `free_dims` is
4669      either a list of integers or an `int32` `Tensor`, depending on whether
4670      the shape of a is fully specified, and free_dims_static is either a list
4671      of integers and None values, or None, representing the inferred
4672      static shape of the free dimensions
4673    """
4674    if a.get_shape().is_fully_defined() and isinstance(axes, (list, tuple)):
4675      shape_a = a.get_shape().as_list()
4676      axes = [i if i >= 0 else i + len(shape_a) for i in axes]
4677      free = [i for i in xrange(len(shape_a)) if i not in axes]
4678      free_dims = [shape_a[i] for i in free]
4679      prod_free = int(np.prod([shape_a[i] for i in free]))
4680      prod_axes = int(np.prod([shape_a[i] for i in axes]))
4681      perm = list(axes) + free if flipped else free + list(axes)
4682      new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes]
4683      if (perm != np.arange(len(shape_a))).any():
4684        a_trans = array_ops.transpose(a, perm)
4685      else:
4686        a_trans = a
4687      if a_trans.get_shape().as_list() != new_shape:
4688        reshaped_a = array_ops.reshape(a_trans, new_shape)
4689      else:
4690        reshaped_a = a_trans
4691      return reshaped_a, free_dims, free_dims
4692    else:
4693      if a.get_shape().ndims is not None and isinstance(axes, (list, tuple)):
4694        shape_a = a.get_shape().as_list()
4695        axes = [i if i >= 0 else i + len(shape_a) for i in axes]
4696        free = [i for i in xrange(len(shape_a)) if i not in axes]
4697        axes_dims = [shape_a[i] for i in axes]
4698        free_dims = [shape_a[i] for i in free]
4699        free_dims_static = free_dims
4700        axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
4701        free = ops.convert_to_tensor(free, dtype=dtypes.int32, name="free")
4702        shape_a = array_ops.shape(a)
4703      else:
4704        free_dims_static = None
4705        shape_a = array_ops.shape(a)
4706        rank_a = array_ops.rank(a)
4707        axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
4708        axes = array_ops.where(axes >= 0, axes, axes + rank_a)
4709        free, _ = gen_array_ops.list_diff(range(rank_a), axes, dtypes.int32)
4710      free_dims = array_ops.gather(shape_a, free)
4711      axes_dims = array_ops.gather(shape_a, axes)
4712      prod_free_dims = reduce_prod(free_dims)
4713      prod_axes_dims = reduce_prod(axes_dims)
4714      if flipped:
4715        perm = array_ops.concat([axes, free], 0)
4716        new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
4717      else:
4718        perm = array_ops.concat([free, axes], 0)
4719        new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
4720      reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
4721      return reshaped_a, free_dims, free_dims_static
4722
4723  def _tensordot_axes(a, axes):
4724    """Generates two sets of contraction axes for the two tensor arguments."""
4725    a_shape = a.get_shape()
4726    if isinstance(axes, compat.integral_types):
4727      if axes < 0:
4728        raise ValueError("'axes' must be at least 0.")
4729      if a_shape.ndims is not None:
4730        if axes > a_shape.ndims:
4731          raise ValueError("'axes' must not be larger than the number of "
4732                           "dimensions of tensor %s." % a)
4733        return (list(xrange(a_shape.ndims - axes,
4734                            a_shape.ndims)), list(xrange(axes)))
4735      else:
4736        rank = array_ops.rank(a)
4737        return (range(rank - axes, rank,
4738                      dtype=dtypes.int32), range(axes, dtype=dtypes.int32))
4739    elif isinstance(axes, (list, tuple)):
4740      if len(axes) != 2:
4741        raise ValueError("'axes' must be an integer or have length 2.")
4742      a_axes = axes[0]
4743      b_axes = axes[1]
4744      if isinstance(a_axes, compat.integral_types) and \
4745          isinstance(b_axes, compat.integral_types):
4746        a_axes = [a_axes]
4747        b_axes = [b_axes]
4748      if len(a_axes) != len(b_axes):
4749        raise ValueError(
4750            "Different number of contraction axes 'a' and 'b', %s != %s." %
4751            (len(a_axes), len(b_axes)))
4752      return a_axes, b_axes
4753    else:
4754      axes = ops.convert_to_tensor(axes, name="axes", dtype=dtypes.int32)
4755      return axes[0], axes[1]
4756
4757  with ops.name_scope(name, "Tensordot", [a, b, axes]) as name:
4758    a = ops.convert_to_tensor(a, name="a")
4759    b = ops.convert_to_tensor(b, name="b")
4760    a_axes, b_axes = _tensordot_axes(a, axes)
4761    a_reshape, a_free_dims, a_free_dims_static = _tensordot_reshape(a, a_axes)
4762    b_reshape, b_free_dims, b_free_dims_static = _tensordot_reshape(
4763        b, b_axes, True)
4764    ab_matmul = matmul(a_reshape, b_reshape)
4765    if isinstance(a_free_dims, list) and isinstance(b_free_dims, list):
4766      if (ab_matmul.get_shape().is_fully_defined() and
4767          ab_matmul.get_shape().as_list() == a_free_dims + b_free_dims):
4768        return ab_matmul
4769      else:
4770        return array_ops.reshape(
4771            ab_matmul, a_free_dims + b_free_dims, name=name)
4772    else:
4773      a_free_dims = ops.convert_to_tensor(a_free_dims, dtype=dtypes.int32)
4774      b_free_dims = ops.convert_to_tensor(b_free_dims, dtype=dtypes.int32)
4775      product = array_ops.reshape(
4776          ab_matmul, array_ops.concat([a_free_dims, b_free_dims], 0), name=name)
4777      if a_free_dims_static is not None and b_free_dims_static is not None:
4778        product.set_shape(a_free_dims_static + b_free_dims_static)
4779      return product
4780
4781
4782@tf_export("math.polyval")
4783@dispatch.add_dispatch_support
4784def polyval(coeffs, x, name=None):
4785  r"""Computes the elementwise value of a polynomial.
4786
4787  If `x` is a tensor and `coeffs` is a list n + 1 tensors,
4788  this function returns the value of the n-th order polynomial
4789
4790  `p(x) = coeffs[n-1] + coeffs[n-2] * x + ...  + coeffs[0] * x**(n-1)`
4791
4792  evaluated using Horner's method, i.e.
4793
4794  ```python
4795  p(x) = coeffs[n-1] + x * (coeffs[n-2] + ... + x * (coeffs[1] + x * coeffs[0]))
4796  ```
4797
4798  Usage Example:
4799
4800  >>> coefficients = [1.0, 2.5, -4.2]
4801  >>> x = 5.0
4802  >>> y = tf.math.polyval(coefficients, x)
4803  >>> y
4804  <tf.Tensor: shape=(), dtype=float32, numpy=33.3>
4805
4806  Usage Example:
4807
4808  >>> tf.math.polyval([2, 1, 0], 3) # evaluates 2 * (3**2) + 1 * (3**1) + 0 * (3**0)
4809  <tf.Tensor: shape=(), dtype=int32, numpy=21>
4810
4811  `tf.math.polyval` can also be used in polynomial regression. Taking
4812  advantage of this function can facilitate writing a polynomial equation
4813  as compared to explicitly writing it out, especially for higher degree
4814  polynomials.
4815
4816  >>> x = tf.constant(3)
4817  >>> theta1 = tf.Variable(2)
4818  >>> theta2 = tf.Variable(1)
4819  >>> theta3 = tf.Variable(0)
4820  >>> tf.math.polyval([theta1, theta2, theta3], x)
4821  <tf.Tensor: shape=(), dtype=int32, numpy=21>
4822
4823  Args:
4824    coeffs: A list of `Tensor` representing the coefficients of the polynomial.
4825    x: A `Tensor` representing the variable of the polynomial.
4826    name: A name for the operation (optional).
4827
4828  Returns:
4829    A `tensor` of the shape as the expression p(x) with usual broadcasting
4830    rules for element-wise addition and multiplication applied.
4831
4832  @compatibility(numpy)
4833  Equivalent to numpy.polyval.
4834  @end_compatibility
4835  """
4836  if not isinstance(coeffs, list):
4837    raise ValueError("Argument coeffs must be list type "
4838                     "found {}.".format(type(coeffs)))
4839
4840  with ops.name_scope(name, "polyval", nest.flatten(coeffs) + [x]) as name:
4841    x = ops.convert_to_tensor(x, name="x")
4842    if len(coeffs) < 1:
4843      return array_ops.zeros_like(x, name=name)
4844    coeffs = [
4845        ops.convert_to_tensor(coeff, name=("coeff_%d" % index))
4846        for index, coeff in enumerate(coeffs)
4847    ]
4848    p = coeffs[0]
4849    for c in coeffs[1:]:
4850      p = c + p * x
4851    return p
4852
4853
4854@tf_export("math.reciprocal_no_nan")
4855@dispatch.add_dispatch_support
4856def reciprocal_no_nan(x, name=None):
4857  """Performs a safe reciprocal operation, element wise.
4858
4859  If a particular element is zero, the reciprocal for that element is
4860  also set to zero.
4861
4862  For example:
4863  ```python
4864  x = tf.constant([2.0, 0.5, 0, 1], dtype=tf.float32)
4865  tf.math.reciprocal_no_nan(x)  # [ 0.5, 2, 0.0, 1.0 ]
4866  ```
4867
4868  Args:
4869    x: A `Tensor` of type `float16`, `float32`, `float64` `complex64` or
4870      `complex128`.
4871    name: A name for the operation (optional).
4872
4873  Returns:
4874    A `Tensor` of same shape and type as `x`.
4875
4876  Raises:
4877    TypeError: x must be of a valid dtype.
4878
4879  """
4880
4881  with ops.name_scope(name, "reciprocal_no_nan", [x]) as scope:
4882    x = ops.convert_to_tensor(x, name="x")
4883    one = constant_op.constant(1, dtype=x.dtype.base_dtype, name="one")
4884    return gen_math_ops.div_no_nan(one, x, name=scope)
4885
4886
4887@tf_export("math.xlog1py")
4888@dispatch.add_dispatch_support
4889def xlog1py(x, y, name=None):
4890  r"""Compute x * log1p(y).
4891
4892  Given `x` and `y`, compute `x * log1p(y)`. This function safely returns
4893  zero when `x = 0`, no matter what the value of `y` is.
4894
4895  Example:
4896
4897  >>> tf.math.xlog1py(0., 1.)
4898  <tf.Tensor: shape=(), dtype=float32, numpy=0.>
4899  >>> tf.math.xlog1py(1., 1.)
4900  <tf.Tensor: shape=(), dtype=float32, numpy=0.6931472>
4901  >>> tf.math.xlog1py(2., 2.)
4902  <tf.Tensor: shape=(), dtype=float32, numpy=2.1972246>
4903  >>> tf.math.xlog1py(0., -1.)
4904  <tf.Tensor: shape=(), dtype=float32, numpy=0.>
4905
4906  Args:
4907    x: A `tf.Tensor` of type `bfloat16`, `half`, `float32`, `float64`,
4908      `complex64`, `complex128`
4909    y: A `tf.Tensor` of type `bfloat16`, `half`, `float32`, `float64`,
4910      `complex64`, `complex128`
4911    name: A name for the operation (optional).
4912
4913  Returns:
4914    `x * log1p(y)`.
4915
4916  @compatibility(scipy)
4917  Equivalent to scipy.special.xlog1py
4918  @end_compatibility
4919  """
4920  with ops.name_scope(name, "xlog1py", [x]):
4921    return gen_math_ops.xlog1py(x, y)
4922
4923
4924@tf_export("math.erfinv")
4925@dispatch.add_dispatch_support
4926def erfinv(x, name=None):
4927  """Compute inverse error function.
4928
4929  Given `x`, compute the inverse error function of `x`. This function
4930  is the inverse of `tf.math.erf`.
4931
4932  Args:
4933    x: `Tensor` with type `float` or `double`.
4934    name: A name for the operation (optional).
4935  Returns:
4936    Inverse error function of `x`.
4937  """
4938  with ops.name_scope(name, "erfinv", [x]):
4939    return gen_math_ops.erfinv(x)
4940
4941
4942@tf_export("math.ndtri")
4943@dispatch.add_dispatch_support
4944def ndtri(x, name=None):
4945  """Compute quantile of Standard Normal.
4946
4947  Args:
4948    x: `Tensor` with type `float` or `double`.
4949    name: A name for the operation (optional).
4950  Returns:
4951    Inverse error function of `x`.
4952  """
4953  with ops.name_scope(name, "ndtri", [x]):
4954    return gen_math_ops.ndtri(x)
4955
4956
4957@tf_export("math.erfcinv")
4958@dispatch.add_dispatch_support
4959def erfcinv(x, name=None):
4960  """Computes the inverse of complementary error function.
4961
4962  Given `x`, compute the inverse complementary error function of `x`.
4963  This function is the inverse of `tf.math.erfc`, and is defined on
4964  `[0, 2]`.
4965
4966  >>> tf.math.erfcinv([0., 0.2, 1., 1.5, 2.])
4967  <tf.Tensor: shape=(5,), dtype=float32, numpy=
4968  array([       inf,  0.9061935, -0.       , -0.4769363,       -inf],
4969        dtype=float32)>
4970
4971  Args:
4972    x: `Tensor` with type `float` or `double`.
4973    name: A name for the operation (optional).
4974  Returns:
4975    Inverse complementary error function of `x`.
4976
4977  @compatibility(numpy)
4978  Equivalent to scipy.special.erfcinv
4979  @end_compatibility
4980  """
4981  with ops.name_scope(name, "erfcinv", [x]):
4982    x = ops.convert_to_tensor(x, name="start")
4983    return -ndtri(0.5 * x) * np.sqrt(0.5)
4984
4985
4986@tf_export("math.ceil", v1=["math.ceil", "ceil"])
4987@dispatch.add_dispatch_support
4988@deprecation.deprecated_endpoints("ceil")
4989@dispatch.add_dispatch_support
4990def ceil(x, name=None):
4991  """Return the ceiling of the input, element-wise.
4992
4993  For example:
4994
4995  >>> tf.math.ceil([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])
4996  <tf.Tensor: shape=(7,), dtype=float32,
4997  numpy=array([-1., -1., -0.,  1.,  2.,  2.,  2.], dtype=float32)>
4998
4999  Args:
5000    x: A `tf.Tensor`. Must be one of the following types: `bfloat16`, `half`,
5001      `float32`, `float64`. `int32`
5002    name: A name for the operation (optional).
5003
5004  Returns:
5005    A `tf.Tensor`. Has the same type as `x`.
5006
5007  @compatibility(numpy)
5008  Equivalent to np.ceil
5009  @end_compatibility
5010  """
5011  return gen_math_ops.ceil(x, name)
5012
5013
5014@tf_export("math.sqrt", "sqrt")
5015@dispatch.add_dispatch_support
5016def sqrt(x, name=None):  # pylint: disable=redefined-builtin
5017  r"""Computes element-wise square root of the input tensor.
5018
5019  Note: This operation does not support integer types.
5020
5021  >>> x = tf.constant([[4.0], [16.0]])
5022  >>> tf.sqrt(x)
5023  <tf.Tensor: shape=(2, 1), dtype=float32, numpy=
5024    array([[2.],
5025           [4.]], dtype=float32)>
5026  >>> y = tf.constant([[-4.0], [16.0]])
5027  >>> tf.sqrt(y)
5028  <tf.Tensor: shape=(2, 1), dtype=float32, numpy=
5029    array([[nan],
5030           [ 4.]], dtype=float32)>
5031  >>> z = tf.constant([[-1.0], [16.0]], dtype=tf.complex128)
5032  >>> tf.sqrt(z)
5033  <tf.Tensor: shape=(2, 1), dtype=complex128, numpy=
5034    array([[0.0+1.j],
5035           [4.0+0.j]])>
5036
5037  Note: In order to support complex type, please provide an input tensor
5038  of `complex64` or `complex128`.
5039
5040  Args:
5041    x: A `tf.Tensor` of type `bfloat16`, `half`, `float32`, `float64`,
5042      `complex64`, `complex128`
5043    name: A name for the operation (optional).
5044
5045  Returns:
5046    A `tf.Tensor` of same size, type and sparsity as `x`.
5047  """
5048  return gen_math_ops.sqrt(x, name)
5049
5050
5051# pylint: disable=g-docstring-has-escape
5052@tf_export("math.exp", "exp")
5053@dispatch.add_dispatch_support
5054def exp(x, name=None):
5055  r"""Computes exponential of x element-wise.  \\(y = e^x\\).
5056
5057  This function computes the exponential of the input tensor element-wise.
5058  i.e. `math.exp(x)` or \\(e^x\\), where `x` is the input tensor.
5059  \\(e\\) denotes Euler's number and is approximately equal to 2.718281.
5060  Output is positive for any real input.
5061
5062  >>> x = tf.constant(2.0)
5063  >>> tf.math.exp(x)
5064  <tf.Tensor: shape=(), dtype=float32, numpy=7.389056>
5065
5066  >>> x = tf.constant([2.0, 8.0])
5067  >>> tf.math.exp(x)
5068  <tf.Tensor: shape=(2,), dtype=float32,
5069  numpy=array([   7.389056, 2980.958   ], dtype=float32)>
5070
5071  For complex numbers, the exponential value is calculated as
5072  $$
5073  e^{x+iy} = {e^x} {e^{iy}} = {e^x} ({\cos (y) + i \sin (y)})
5074  $$
5075
5076  For `1+1j` the value would be computed as:
5077  $$
5078  e^1 (\cos (1) + i \sin (1)) = 2.7182817 \times (0.5403023+0.84147096j)
5079  $$
5080
5081  >>> x = tf.constant(1 + 1j)
5082  >>> tf.math.exp(x)
5083  <tf.Tensor: shape=(), dtype=complex128,
5084  numpy=(1.4686939399158851+2.2873552871788423j)>
5085
5086  Args:
5087    x: A `tf.Tensor`. Must be one of the following types: `bfloat16`, `half`,
5088      `float32`, `float64`, `complex64`, `complex128`.
5089    name: A name for the operation (optional).
5090
5091  Returns:
5092    A `tf.Tensor`. Has the same type as `x`.
5093
5094  @compatibility(numpy)
5095  Equivalent to np.exp
5096  @end_compatibility
5097  """
5098  return gen_math_ops.exp(x, name)
5099
5100
5101# pylint: enable=g-docstring-has-escape
5102
5103
5104@tf_export("math.sobol_sample")
5105@dispatch.add_dispatch_support
5106def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None):
5107  """Generates points from the Sobol sequence.
5108
5109  Creates a Sobol sequence with `num_results` samples. Each sample has dimension
5110  `dim`. Skips the first `skip` samples.
5111
5112  Args:
5113    dim: Positive scalar `Tensor` representing each sample's dimension.
5114    num_results: Positive scalar `Tensor` of dtype int32. The number of Sobol
5115        points to return in the output.
5116    skip: (Optional) Positive scalar `Tensor` of dtype int32. The number of
5117        initial points of the Sobol sequence to skip. Default value is 0.
5118    dtype: (Optional) The `tf.Dtype` of the sample. One of: `tf.float32` or
5119        `tf.float64`. Defaults to `tf.float32`.
5120    name: (Optional) Python `str` name prefixed to ops created by this function.
5121
5122  Returns:
5123    `Tensor` of samples from Sobol sequence with `shape` [num_results, dim].
5124  """
5125  with ops.name_scope(name, "sobol", [dim, num_results, skip]):
5126    return gen_math_ops.sobol_sample(dim, num_results, skip, dtype=dtype)
5127
5128
5129@tf_export("math.rsqrt", v1=["math.rsqrt", "rsqrt"])
5130@dispatch.add_dispatch_support
5131@deprecation.deprecated_endpoints("rsqrt")
5132@dispatch.add_dispatch_support
5133def rsqrt(x, name=None):
5134  """Computes reciprocal of square root of x element-wise.
5135
5136  For example:
5137
5138  >>> x = tf.constant([2., 0., -2.])
5139  >>> tf.math.rsqrt(x)
5140  <tf.Tensor: shape=(3,), dtype=float32,
5141  numpy=array([0.707, inf, nan], dtype=float32)>
5142
5143  Args:
5144    x: A `tf.Tensor`. Must be one of the following types: `bfloat16`, `half`,
5145      `float32`, `float64`.
5146    name: A name for the operation (optional).
5147
5148  Returns:
5149    A `tf.Tensor`. Has the same type as `x`.
5150  """
5151  return gen_math_ops.rsqrt(x, name)
5152
5153
5154@tf_export("math.acos", "acos")
5155@dispatch.add_dispatch_support
5156def acos(x, name=None):
5157  """Computes acos of x element-wise.
5158
5159  Provided an input tensor, the `tf.math.acos` operation
5160  returns the inverse cosine of each element of the tensor.
5161  If `y = tf.math.cos(x)` then, `x = tf.math.acos(y)`.
5162
5163  Input range is `[-1, 1]` and the output has a range of `[0, pi]`.
5164
5165  For example:
5166
5167  >>> x = tf.constant([1.0, -0.5, 3.4, 0.2, 0.0, -2], dtype = tf.float32)
5168  >>> tf.math.acos(x)
5169  <tf.Tensor: shape=(6,), dtype=float32,
5170  numpy= array([0. , 2.0943952, nan, 1.3694383, 1.5707964, nan],
5171  dtype=float32)>
5172
5173  Args:
5174    x: A `Tensor`. Must be one of the following types: `bfloat16`, `half`,
5175      `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`,
5176      `complex64`, `complex128`, `string`.
5177    name: A name for the operation (optional).
5178
5179  Returns:
5180    A `Tensor`. Has the same type as x.
5181  """
5182  return gen_math_ops.acos(x, name)
5183
5184
5185@tf_export("math.floor", "floor")
5186@dispatch.add_dispatch_support
5187def floor(x, name=None):
5188  """Returns element-wise largest integer not greater than x.
5189
5190  Both input range is `(-inf, inf)` and the
5191  output range consists of all integer values.
5192
5193  For example:
5194
5195  >>> x = tf.constant([1.3324, -1.5, 5.555, -2.532, 0.99, float("inf")])
5196  >>> tf.floor(x).numpy()
5197  array([ 1., -2.,  5., -3.,  0., inf], dtype=float32)
5198
5199  Args:
5200    x:  A `Tensor`. Must be one of the following types: `bfloat16`, `half`,
5201      `float32`, `float64`.
5202    name: A name for the operation (optional).
5203
5204  Returns:
5205    A `Tensor`. Has the same type as x.
5206  """
5207  return gen_math_ops.floor(x, name)
5208