1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Basic arithmetic operators.
16
17See the @{$python/math_ops} guide.
18
19@@add
20@@subtract
21@@multiply
22@@scalar_mul
23@@div
24@@divide
25@@truediv
26@@floordiv
27@@realdiv
28@@truncatediv
29@@floor_div
30@@truncatemod
31@@floormod
32@@mod
33@@cross
34@@add_n
35@@abs
36@@negative
37@@sign
38@@reciprocal
39@@square
40@@round
41@@sqrt
42@@rsqrt
43@@pow
44@@exp
45@@expm1
46@@log
47@@log1p
48@@sinh
49@@cosh
50@@asinh
51@@acosh
52@@atanh
53@@ceil
54@@floor
55@@maximum
56@@minimum
57@@cos
58@@sin
59@@lbeta
60@@tan
61@@acos
62@@asin
63@@atan
64@@atan2
65@@lgamma
66@@digamma
67@@erf
68@@erfc
69@@squared_difference
70@@igamma
71@@igammac
72@@zeta
73@@polygamma
74@@betainc
75@@rint
76@@diag
77@@diag_part
78@@trace
79@@transpose
80@@eye
81@@matrix_diag
82@@matrix_diag_part
83@@matrix_band_part
84@@matrix_set_diag
85@@matrix_transpose
86@@matmul
87@@norm
88@@matrix_determinant
89@@matrix_inverse
90@@cholesky
91@@cholesky_solve
92@@matrix_exponential
93@@matrix_logarithm
94@@matrix_solve
95@@matrix_triangular_solve
96@@matrix_solve_ls
97@@qr
98@@self_adjoint_eig
99@@self_adjoint_eigvals
100@@svd
101@@tensordot
102@@complex
103@@conj
104@@imag
105@@angle
106@@real
107@@fft
108@@ifft
109@@fft2d
110@@ifft2d
111@@fft3d
112@@ifft3d
113@@reduce_sum
114@@reduce_prod
115@@reduce_min
116@@reduce_max
117@@reduce_mean
118@@reduce_all
119@@reduce_any
120@@reduce_logsumexp
121@@count_nonzero
122@@accumulate_n
123@@einsum
124@@bincount
125@@cumsum
126@@cumprod
127@@segment_sum
128@@segment_prod
129@@segment_min
130@@segment_max
131@@segment_mean
132@@unsorted_segment_sum
133@@unsorted_segment_max
134@@unsorted_segment_min
135@@unsorted_segment_prod
136@@unsorted_segment_sqrt_n
137@@sparse_segment_sum
138@@sparse_segment_mean
139@@sparse_segment_sqrt_n
140@@argmin
141@@argmax
142@@setdiff1d
143@@where
144@@unique
145@@edit_distance
146@@invert_permutation
147"""
148from __future__ import absolute_import
149from __future__ import division
150from __future__ import print_function
151
152import numpy as np
153from six.moves import xrange  # pylint: disable=redefined-builtin
154
155from tensorflow.python.eager import context
156from tensorflow.python.framework import common_shapes
157from tensorflow.python.framework import constant_op
158from tensorflow.python.framework import dtypes
159from tensorflow.python.framework import graph_util
160from tensorflow.python.framework import ops
161from tensorflow.python.framework import sparse_tensor
162from tensorflow.python.framework import tensor_shape
163from tensorflow.python.ops import array_ops
164from tensorflow.python.ops import gen_control_flow_ops
165from tensorflow.python.ops import gen_data_flow_ops
166from tensorflow.python.ops import gen_math_ops
167from tensorflow.python.ops import gen_nn_ops
168from tensorflow.python.ops import gen_sparse_ops
169from tensorflow.python.ops import gen_spectral_ops
170from tensorflow.python.ops import gen_state_ops
171from tensorflow.python.ops import state_ops
172# go/tf-wildcard-import
173# pylint: disable=wildcard-import
174from tensorflow.python.ops.gen_math_ops import *
175# pylint: enable=wildcard-import
176from tensorflow.python.util import compat
177from tensorflow.python.util import deprecation
178from tensorflow.python.util.tf_export import tf_export
179
180# Aliases for some automatically-generated names.
181linspace = gen_math_ops.lin_space
182
183arg_max = deprecation.deprecated(None, "Use `argmax` instead")(arg_max)  # pylint: disable=used-before-assignment
184arg_min = deprecation.deprecated(None, "Use `argmin` instead")(arg_min)  # pylint: disable=used-before-assignment
185
186
187def _set_doc(doc):
188
189  def _decorator(func):
190    func.__doc__ = doc
191    return func
192
193  return _decorator
194
195
196# pylint: disable=redefined-builtin
197@tf_export("argmax")
198@deprecation.deprecated_args(None, "Use the `axis` argument instead",
199                             "dimension")
200@_set_doc(
201    gen_math_ops.arg_max.__doc__.replace("dimensions", "axes").replace(
202        "dimension", "axis"))
203def argmax(input,
204           axis=None,
205           name=None,
206           dimension=None,
207           output_type=dtypes.int64):
208  if dimension is not None:
209    if axis is not None:
210      raise ValueError("Cannot specify both 'axis' and 'dimension'")
211    axis = dimension
212  elif axis is None:
213    axis = 0
214  return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
215
216
217@tf_export("argmin")
218@deprecation.deprecated_args(None, "Use the `axis` argument instead",
219                             "dimension")
220@_set_doc(
221    gen_math_ops.arg_min.__doc__.replace("dimensions", "axes").replace(
222        "dimension", "axis"))
223def argmin(input,
224           axis=None,
225           name=None,
226           dimension=None,
227           output_type=dtypes.int64):
228  if dimension is not None:
229    if axis is not None:
230      raise ValueError("Cannot specify both 'axis' and 'dimension'")
231    axis = dimension
232  elif axis is None:
233    axis = 0
234  return gen_math_ops.arg_min(input, axis, name=name, output_type=output_type)
235
236
237# pylint: enable=redefined-builtin
238
239
240# pylint: disable=anomalous-backslash-in-string,protected-access
241# pylint: disable=g-docstring-has-escape
242@tf_export("abs")
243def abs(x, name=None):  # pylint: disable=redefined-builtin
244  r"""Computes the absolute value of a tensor.
245
246  Given a tensor `x` of complex numbers, this operation returns a tensor of type
247  `float32` or `float64` that is the absolute value of each element in `x`. All
248  elements in `x` must be complex numbers of the form \\(a + bj\\). The
249  absolute value is computed as \\( \sqrt{a^2 + b^2}\\).  For example:
250  ```python
251  x = tf.constant([[-2.25 + 4.75j], [-3.25 + 5.75j]])
252  tf.abs(x)  # [5.25594902, 6.60492229]
253  ```
254
255  Args:
256    x: A `Tensor` or `SparseTensor` of type `float32`, `float64`, `int32`,
257      `int64`, `complex64` or `complex128`.
258    name: A name for the operation (optional).
259
260  Returns:
261    A `Tensor` or `SparseTensor` the same size and type as `x` with absolute
262      values.
263    Note, for `complex64` or `complex128` input, the returned `Tensor` will be
264      of type `float32` or `float64`, respectively.
265  """
266  with ops.name_scope(name, "Abs", [x]) as name:
267    if isinstance(x, sparse_tensor.SparseTensor):
268      if x.values.dtype.is_complex:
269        x_abs = gen_math_ops._complex_abs(
270            x.values, Tout=x.values.dtype.real_dtype, name=name)
271        return sparse_tensor.SparseTensor(
272            indices=x.indices, values=x_abs, dense_shape=x.dense_shape)
273      x_abs = gen_math_ops._abs(x.values, name=name)
274      return sparse_tensor.SparseTensor(
275          indices=x.indices, values=x_abs, dense_shape=x.dense_shape)
276    else:
277      x = ops.convert_to_tensor(x, name="x")
278      if x.dtype.is_complex:
279        return gen_math_ops._complex_abs(x, Tout=x.dtype.real_dtype, name=name)
280      return gen_math_ops._abs(x, name=name)
281
282
283# pylint: enable=g-docstring-has-escape
284
285
286# pylint: disable=redefined-builtin
287def _bucketize(input, boundaries, name=None):
288  return gen_math_ops._bucketize(input=input, boundaries=boundaries, name=name)
289
290
291# pylint: enable=redefined-builtin
292
293
294class DivideDelegateWithName(object):
295  """Use Python2/Python3 division delegation to implement divide for tensors."""
296
297  def __init__(self, x, name):
298    """Construct DivideDelegateWithName.
299
300    Args:
301      x: Tensor to use as left operand in operator overloads
302      name: The name that is preferred for the op created.
303    """
304    self.x = x
305    self.name = name
306
307  def __truediv__(self, y):
308    return _truediv_python3(self.x, y, self.name)
309
310  def __floordiv__(self, y):
311    return floordiv(self.x, y, self.name)
312
313  def __div__(self, y):
314    return _div_python2(self.x, y, self.name)
315
316
317@tf_export("divide")
318def divide(x, y, name=None):
319  """Computes Python style division of `x` by `y`."""
320
321  if name is not None:
322    # Cannot use tensors operator overload, because it has no way to track
323    # override names. Use a dummy class to track the runtime division behavior
324    return DivideDelegateWithName(x, name) / y
325  else:
326    return x / y
327
328
329@tf_export("multiply")
330def multiply(x, y, name=None):
331  return gen_math_ops._mul(x, y, name)
332
333
334multiply.__doc__ = gen_math_ops._mul.__doc__.replace("Mul", "`tf.multiply`")
335
336
337# TODO(aselle): put deprecation in after another round of global code changes
338@deprecation.deprecated(
339    "2016-12-30",
340    "`tf.mul(x, y)` is deprecated, please use `tf.multiply(x, y)` or `x * y`")
341def _mul(x, y, name=None):
342  return gen_math_ops._mul(x, y, name)
343
344
345_mul.__doc__ = (
346    gen_math_ops._mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__))
347
348
349@tf_export("subtract")
350def subtract(x, y, name=None):
351  return gen_math_ops._sub(x, y, name)
352
353
354subtract.__doc__ = gen_math_ops._sub.__doc__.replace("`Sub`", "`tf.subtract`")
355
356
357# TODO(aselle): put deprecation in after another round of global code changes
358@deprecation.deprecated(
359    "2016-12-30",
360    "`tf.sub(x, y)` is deprecated, please use `tf.subtract(x, y)` or `x - y`")
361def _sub(x, y, name=None):
362  return gen_math_ops._sub(x, y, name)
363
364
365_sub.__doc__ = (
366    gen_math_ops._sub.__doc__ + ("" if _sub.__doc__ is None else _sub.__doc__))
367
368
369# pylint: disable=g-docstring-has-escape
370@tf_export("negative")
371def negative(x, name=None):
372  """Computes numerical negative value element-wise.
373
374  I.e., \\(y = -x\\).
375
376  Args:
377    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
378      `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
379    name: A name for the operation (optional).
380
381  Returns:
382    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
383  """
384  with ops.name_scope(name, "Neg", [x]) as name:
385    if isinstance(x, sparse_tensor.SparseTensor):
386      x_neg = gen_math_ops._neg(x.values, name=name)
387      return sparse_tensor.SparseTensor(
388          indices=x.indices, values=x_neg, dense_shape=x.dense_shape)
389    else:
390      return gen_math_ops._neg(x, name=name)
391
392
393# pylint: enable=g-docstring-has-escape
394
395
396# pylint: disable=g-docstring-has-escape
397@deprecation.deprecated(
398    "2016-12-30",
399    "`tf.neg(x)` is deprecated, please use `tf.negative(x)` or `-x`")
400def _neg(x, name=None):
401  """Computes numerical negative value element-wise.
402
403  I.e., \\(y = -x\\).
404
405  Args:
406    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
407      `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
408    name: A name for the operation (optional).
409
410  Returns:
411    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
412  """
413  return negative(x, name)
414
415
416# pylint: enable=g-docstring-has-escape
417
418
419@tf_export("sign")
420def sign(x, name=None):
421  """Returns an element-wise indication of the sign of a number.
422
423  `y = sign(x) = -1` if `x < 0`; 0 if `x == 0` or `tf.is_nan(x)`; 1 if `x > 0`.
424
425  Zero is returned for NaN inputs.
426
427  For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`.
428
429  Args:
430    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
431      `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
432    name: A name for the operation (optional).
433
434  Returns:
435    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
436
437  @compatibility(numpy)
438  Equivalent to numpy.sign except for the behavior for input values of NaN.
439  @end_compatibility
440  """
441  with ops.name_scope(name, "Sign", [x]) as name:
442    if isinstance(x, sparse_tensor.SparseTensor):
443      x_sign = gen_math_ops.sign(x.values, name=name)
444      return sparse_tensor.SparseTensor(
445          indices=x.indices, values=x_sign, dense_shape=x.dense_shape)
446    else:
447      return gen_math_ops.sign(x, name=name)
448
449
450@tf_export("square")
451def square(x, name=None):
452  r"""Computes square of x element-wise.
453
454  I.e., \\(y = x * x = x^2\\).
455
456  Args:
457    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
458      `float32`, `float64`, `int32`, `int64`, `complex64`, `complex128`.
459    name: A name for the operation (optional).
460
461  Returns:
462    A `Tensor` or `SparseTensor`. Has the same type as `x`.
463  """
464  with ops.name_scope(name, "Square", [x]) as name:
465    if isinstance(x, sparse_tensor.SparseTensor):
466      x_square = gen_math_ops.square(x.values, name=name)
467      return sparse_tensor.SparseTensor(
468          indices=x.indices, values=x_square, dense_shape=x.dense_shape)
469    else:
470      return gen_math_ops.square(x, name=name)
471
472
473@tf_export("sqrt")
474def sqrt(x, name=None):
475  r"""Computes square root of x element-wise.
476
477  I.e., \\(y = \sqrt{x} = x^{1/2}\\).
478
479  Args:
480    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
481      `float32`, `float64`, `complex64`, `complex128`.
482    name: A name for the operation (optional).
483
484  Returns:
485    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
486  """
487  with ops.name_scope(name, "Sqrt", [x]) as name:
488    if isinstance(x, sparse_tensor.SparseTensor):
489      x_sqrt = gen_math_ops.sqrt(x.values, name=name)
490      return sparse_tensor.SparseTensor(
491          indices=x.indices, values=x_sqrt, dense_shape=x.dense_shape)
492    else:
493      return gen_math_ops.sqrt(x, name=name)
494
495
496@tf_export("erf")
497def erf(x, name=None):
498  """Computes the Gauss error function of `x` element-wise.
499
500  Args:
501    x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`,
502      `float32`, `float64`.
503    name: A name for the operation (optional).
504
505  Returns:
506    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
507  """
508  with ops.name_scope(name, "Erf", [x]) as name:
509    if isinstance(x, sparse_tensor.SparseTensor):
510      x_erf = gen_math_ops.erf(x.values, name=name)
511      return sparse_tensor.SparseTensor(
512          indices=x.indices, values=x_erf, dense_shape=x.dense_shape)
513    else:
514      return gen_math_ops.erf(x, name=name)
515
516
517@tf_export("scalar_mul")
518def scalar_mul(scalar, x):
519  """Multiplies a scalar times a `Tensor` or `IndexedSlices` object.
520
521  Intended for use in gradient code which might deal with `IndexedSlices`
522  objects, which are easy to multiply by a scalar but more expensive to
523  multiply with arbitrary tensors.
524
525  Args:
526    scalar: A 0-D scalar `Tensor`. Must have known shape.
527    x: A `Tensor` or `IndexedSlices` to be scaled.
528
529  Returns:
530    `scalar * x` of the same type (`Tensor` or `IndexedSlices`) as `x`.
531
532  Raises:
533    ValueError: if scalar is not a 0-D `scalar`.
534  """
535  scalar = ops.convert_to_tensor(
536      scalar, dtype=x.dtype.base_dtype, name="scalar")
537  shape = scalar.get_shape()
538  if shape.ndims == 0:
539    if isinstance(x, ops.IndexedSlices):
540      return ops.IndexedSlices(scalar * x.values, x.indices, x.dense_shape)
541    else:
542      return scalar * x
543  else:
544    raise ValueError("Only scalar multiply works, got shape %s" % shape)
545
546
547@tf_export("pow")
548def pow(x, y, name=None):  # pylint: disable=redefined-builtin
549  r"""Computes the power of one value to another.
550
551  Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
552  corresponding elements in `x` and `y`. For example:
553
554  ```python
555  x = tf.constant([[2, 2], [3, 3]])
556  y = tf.constant([[8, 16], [2, 3]])
557  tf.pow(x, y)  # [[256, 65536], [9, 27]]
558  ```
559
560  Args:
561    x: A `Tensor` of type `float32`, `float64`, `int32`, `int64`, `complex64`,
562     or `complex128`.
563    y: A `Tensor` of type `float32`, `float64`, `int32`, `int64`, `complex64`,
564     or `complex128`.
565    name: A name for the operation (optional).
566
567  Returns:
568    A `Tensor`.
569  """
570  with ops.name_scope(name, "Pow", [x]) as name:
571    return gen_math_ops._pow(x, y, name=name)
572
573
574# pylint: disable=redefined-builtin,redefined-outer-name
575@tf_export("complex")
576def complex(real, imag, name=None):
577  r"""Converts two real numbers to a complex number.
578
579  Given a tensor `real` representing the real part of a complex number, and a
580  tensor `imag` representing the imaginary part of a complex number, this
581  operation returns complex numbers elementwise of the form \\(a + bj\\), where
582  *a* represents the `real` part and *b* represents the `imag` part.
583
584  The input tensors `real` and `imag` must have the same shape.
585
586  For example:
587
588  ```python
589  real = tf.constant([2.25, 3.25])
590  imag = tf.constant([4.75, 5.75])
591  tf.complex(real, imag)  # [[2.25 + 4.75j], [3.25 + 5.75j]]
592  ```
593
594  Args:
595    real: A `Tensor`. Must be one of the following types: `float32`,
596      `float64`.
597    imag: A `Tensor`. Must have the same type as `real`.
598    name: A name for the operation (optional).
599
600  Returns:
601    A `Tensor` of type `complex64` or `complex128`.
602  """
603  real = ops.convert_to_tensor(real, name="real")
604  imag = ops.convert_to_tensor(imag, name="imag")
605  with ops.name_scope(name, "Complex", [real, imag]) as name:
606    input_types = (real.dtype, imag.dtype)
607    if input_types == (dtypes.float64, dtypes.float64):
608      Tout = dtypes.complex128
609    elif input_types == (dtypes.float32, dtypes.float32):
610      Tout = dtypes.complex64
611    else:
612      raise TypeError("real and imag have incorrect types: "
613                      "{} {}".format(real.dtype.name, imag.dtype.name))
614    return gen_math_ops._complex(real, imag, Tout=Tout, name=name)
615
616
617@tf_export("real")
618def real(input, name=None):
619  r"""Returns the real part of a complex (or real) tensor.
620
621  Given a tensor `input`, this operation returns a tensor of type `float` that
622  is the real part of each element in `input` considered as a complex number.
623
624  For example:
625
626  ```python
627  x = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j])
628  tf.real(x)  # [-2.25, 3.25]
629  ```
630
631  If `input` is already real, it is returned unchanged.
632
633  Args:
634    input: A `Tensor`. Must have numeric type.
635    name: A name for the operation (optional).
636
637  Returns:
638    A `Tensor` of type `float32` or `float64`.
639  """
640  with ops.name_scope(name, "Real", [input]) as name:
641    if input.dtype.is_complex:
642      real_dtype = input.dtype.real_dtype
643      return gen_math_ops.real(input, Tout=real_dtype, name=name)
644    else:
645      return input
646
647
648@tf_export("imag")
649def imag(input, name=None):
650  r"""Returns the imaginary part of a complex (or real) tensor.
651
652  Given a tensor `input`, this operation returns a tensor of type `float` that
653  is the imaginary part of each element in `input` considered as a complex
654  number. If `input` is real, a tensor of all zeros is returned.
655
656  For example:
657
658  ```python
659  x = tf.constant([-2.25 + 4.75j, 3.25 + 5.75j])
660  tf.imag(x)  # [4.75, 5.75]
661  ```
662
663  Args:
664    input: A `Tensor`. Must be one of the following types: `float`, `double`,
665      `complex64`, `complex128`.
666    name: A name for the operation (optional).
667
668  Returns:
669    A `Tensor` of type `float32` or `float64`.
670  """
671  with ops.name_scope(name, "Imag", [input]) as name:
672    if input.dtype.is_complex:
673      return gen_math_ops.imag(input, Tout=input.dtype.real_dtype, name=name)
674    else:
675      return array_ops.zeros_like(input)
676
677
678@tf_export("angle")
679def angle(input, name=None):
680  r"""Returns the element-wise argument of a complex (or real) tensor.
681
682  Given a tensor `input`, this operation returns a tensor of type `float` that
683  is the argument of each element in `input` considered as a complex number.
684
685  The elements in `input` are considered to be complex numbers of the form
686  \\(a + bj\\), where *a* is the real part and *b* is the imaginary part.
687  If `input` is real then *b* is zero by definition.
688
689  The argument returned by this function is of the form \\(atan2(b, a)\\).
690  If `input` is real, a tensor of all zeros is returned.
691
692  For example:
693
694  ```
695  # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
696  tf.angle(input) ==> [2.0132, 1.056]
697  ```
698
699  Args:
700    input: A `Tensor`. Must be one of the following types: `float`, `double`,
701      `complex64`, `complex128`.
702    name: A name for the operation (optional).
703
704  Returns:
705    A `Tensor` of type `float32` or `float64`.
706  """
707  with ops.name_scope(name, "Angle", [input]) as name:
708    if input.dtype.is_complex:
709      return gen_math_ops.angle(input, Tout=input.dtype.real_dtype, name=name)
710    else:
711      return array_ops.zeros_like(input)
712
713
714# pylint: enable=redefined-outer-name,redefined-builtin
715
716
717@tf_export("round")
718def round(x, name=None):  # pylint: disable=redefined-builtin
719  """Rounds the values of a tensor to the nearest integer, element-wise.
720
721  Rounds half to even.  Also known as bankers rounding. If you want to round
722  according to the current system rounding mode use tf::cint.
723  For example:
724
725  ```python
726  x = tf.constant([0.9, 2.5, 2.3, 1.5, -4.5])
727  tf.round(x)  # [ 1.0, 2.0, 2.0, 2.0, -4.0 ]
728  ```
729
730  Args:
731    x: A `Tensor` of type `float32` or `float64`.
732    name: A name for the operation (optional).
733
734  Returns:
735    A `Tensor` of same shape and type as `x`.
736  """
737  x = ops.convert_to_tensor(x, name="x")
738  if x.dtype.is_integer:
739    return x
740  else:
741    return gen_math_ops.round(x, name=name)
742
743
744@tf_export("cast")
745def cast(x, dtype, name=None):
746  """Casts a tensor to a new type.
747
748  The operation casts `x` (in case of `Tensor`) or `x.values`
749  (in case of `SparseTensor`) to `dtype`.
750
751  For example:
752
753  ```python
754  x = tf.constant([1.8, 2.2], dtype=tf.float32)
755  tf.cast(x, tf.int32)  # [1, 2], dtype=tf.int32
756  ```
757
758  Args:
759    x: A `Tensor` or `SparseTensor`.
760    dtype: The destination type.
761    name: A name for the operation (optional).
762
763  Returns:
764    A `Tensor` or `SparseTensor` with same shape as `x`.
765
766  Raises:
767    TypeError: If `x` cannot be cast to the `dtype`.
768  """
769  base_type = dtypes.as_dtype(dtype).base_dtype
770  with ops.name_scope(name, "Cast", [x]) as name:
771    if isinstance(x, sparse_tensor.SparseTensor):
772      values_cast = cast(x.values, base_type, name=name)
773      return sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape)
774    else:
775      # TODO(josh11b): If x is not already a Tensor, we could return
776      # ops.convert_to_tensor(x, dtype=dtype, ...)  here, but that
777      # allows some conversions that cast() can't do, e.g. casting numbers to
778      # strings.
779      x = ops.convert_to_tensor(x, name="x")
780      if x.dtype.base_dtype == base_type:
781        return x
782      return gen_math_ops.cast(x, base_type, name=name)
783
784
785@tf_export("saturate_cast")
786def saturate_cast(value, dtype, name=None):
787  """Performs a safe saturating cast of `value` to `dtype`.
788
789  This function casts the input to `dtype` without applying any scaling.  If
790  there is a danger that values would over or underflow in the cast, this op
791  applies the appropriate clamping before the cast.
792
793  Args:
794    value: A `Tensor`.
795    dtype: The desired output `DType`.
796    name: A name for the operation (optional).
797
798  Returns:
799    `value` safely cast to `dtype`.
800  """
801  # When casting to a type with smaller representable range, clamp.
802  # Note that this covers casting to unsigned types as well.
803  with ops.name_scope(name, "saturate_cast", [value]) as name:
804    value = ops.convert_to_tensor(value, name="value")
805    dtype = dtypes.as_dtype(dtype).base_dtype
806    if value.dtype.min < dtype.min:
807      value = gen_math_ops.maximum(value,
808                                   ops.convert_to_tensor(
809                                       dtype.min, dtype=value.dtype,
810                                       name="min"))
811    if value.dtype.max > dtype.max:
812      value = gen_math_ops.minimum(value,
813                                   ops.convert_to_tensor(
814                                       dtype.max, dtype=value.dtype,
815                                       name="max"))
816    return cast(value, dtype, name=name)
817
818
819@tf_export("to_float")
820def to_float(x, name="ToFloat"):
821  """Casts a tensor to type `float32`.
822
823  Args:
824    x: A `Tensor` or `SparseTensor`.
825    name: A name for the operation (optional).
826
827  Returns:
828    A `Tensor` or `SparseTensor` with same shape as `x` with type `float32`.
829
830  Raises:
831    TypeError: If `x` cannot be cast to the `float32`.
832  """
833  return cast(x, dtypes.float32, name=name)
834
835
836@tf_export("to_double")
837def to_double(x, name="ToDouble"):
838  """Casts a tensor to type `float64`.
839
840  Args:
841    x: A `Tensor` or `SparseTensor`.
842    name: A name for the operation (optional).
843
844  Returns:
845    A `Tensor` or `SparseTensor` with same shape as `x` with type `float64`.
846
847  Raises:
848    TypeError: If `x` cannot be cast to the `float64`.
849  """
850  return cast(x, dtypes.float64, name=name)
851
852
853@tf_export("to_int32")
854def to_int32(x, name="ToInt32"):
855  """Casts a tensor to type `int32`.
856
857  Args:
858    x: A `Tensor` or `SparseTensor`.
859    name: A name for the operation (optional).
860
861  Returns:
862    A `Tensor` or `SparseTensor` with same shape as `x` with type `int32`.
863
864  Raises:
865    TypeError: If `x` cannot be cast to the `int32`.
866  """
867  return cast(x, dtypes.int32, name=name)
868
869
870@tf_export("to_int64")
871def to_int64(x, name="ToInt64"):
872  """Casts a tensor to type `int64`.
873
874  Args:
875    x: A `Tensor` or `SparseTensor`.
876    name: A name for the operation (optional).
877
878  Returns:
879    A `Tensor` or `SparseTensor` with same shape as `x` with type `int64`.
880
881  Raises:
882    TypeError: If `x` cannot be cast to the `int64`.
883  """
884  return cast(x, dtypes.int64, name=name)
885
886
887@tf_export("to_bfloat16")
888def to_bfloat16(x, name="ToBFloat16"):
889  """Casts a tensor to type `bfloat16`.
890
891  Args:
892    x: A `Tensor` or `SparseTensor`.
893    name: A name for the operation (optional).
894
895  Returns:
896    A `Tensor` or `SparseTensor` with same shape as `x` with type `bfloat16`.
897
898  Raises:
899    TypeError: If `x` cannot be cast to the `bfloat16`.
900  """
901  return cast(x, dtypes.bfloat16, name=name)
902
903
904ops.Tensor._override_operator("__neg__", gen_math_ops._neg)
905ops.Tensor._override_operator("__abs__", abs)
906# __invert__ corresponds to the ~ operator.  Here we follow the numpy convention
907# ~ marks an elementwise bit-wise inverse.  This is only implemented for boolean
908# tensors and will throw a TypeError if used on nonboolean arrays
909ops.Tensor._override_operator("__invert__", gen_math_ops.logical_not)
910
911
912def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
913  """Register operators with different tensor and scalar versions.
914
915  If `clazz_object` is `SparseTensor`, assumes `func` takes `(sp_indices,
916  sp_values, sp_shape, dense)` and outputs `(new_sp_values)`.
917
918  Args:
919    func: the operator
920    op_name: name of the operator being overridden
921    clazz_object: class to override for.  Either `Tensor` or `SparseTensor`.
922  """
923
924  def binary_op_wrapper(x, y):
925    with ops.name_scope(None, op_name, [x, y]) as name:
926      if not isinstance(y, sparse_tensor.SparseTensor):
927        try:
928          y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
929        except TypeError:
930          # If the RHS is not a tensor, it might be a tensor aware object
931          # that can implement the operator with knowledge of itself
932          # and the tensor.
933          if hasattr(type(y), "__r%s__" % op_name):
934            return NotImplemented
935          else:
936            raise
937      return func(x, y, name=name)
938
939  def binary_op_wrapper_sparse(sp_x, y):
940    with ops.name_scope(None, op_name, [sp_x, y]) as name:
941      y = ops.convert_to_tensor(y, dtype=sp_x.dtype.base_dtype, name="y")
942      return sparse_tensor.SparseTensor(sp_x.indices,
943                                        func(
944                                            sp_x.indices,
945                                            sp_x.values,
946                                            sp_x.dense_shape,
947                                            y,
948                                            name=name), sp_x.dense_shape)
949
950  def r_binary_op_wrapper(y, x):
951    with ops.name_scope(None, op_name, [x, y]) as name:
952      x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
953      return func(x, y, name=name)
954
955  # Propagate func.__doc__ to the wrappers
956  try:
957    doc = func.__doc__
958  except AttributeError:
959    doc = None
960  binary_op_wrapper.__doc__ = doc
961  r_binary_op_wrapper.__doc__ = doc
962  binary_op_wrapper_sparse.__doc__ = doc
963
964  if clazz_object is ops.Tensor:
965    clazz_object._override_operator("__%s__" % op_name, binary_op_wrapper)
966    del binary_op_wrapper
967    clazz_object._override_operator("__r%s__" % op_name, r_binary_op_wrapper)
968    del r_binary_op_wrapper
969  else:
970    clazz_object._override_operator("__%s__" % op_name,
971                                    binary_op_wrapper_sparse)
972    del binary_op_wrapper_sparse
973
974
975# Conversion table for __truediv__.  None entries mean no conversion required.
976_TRUEDIV_TABLE = {
977    dtypes.uint8: dtypes.float32,
978    dtypes.int8: dtypes.float32,
979    dtypes.uint16: dtypes.float32,
980    dtypes.int16: dtypes.float32,
981    dtypes.int32: dtypes.float64,
982    dtypes.int64: dtypes.float64,
983    dtypes.bfloat16: None,
984    dtypes.float16: None,
985    dtypes.float32: None,
986    dtypes.float64: None,
987    dtypes.complex64: None,
988    dtypes.complex128: None,
989}
990
991
992# NOTE: the support of "sparse (true)div dense" is currently not baked in into
993# "tf.(true_)div()".  Until such an API decision is made, the supported usage is
994# to explicitly use the "/" operator to invoke either truediv or div.
995def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None):
996  """Internal helper function for 'sp_t / dense_t'."""
997  with ops.name_scope(name, "truediv",
998                      [sp_indices, sp_values, sp_shape, y]) as name:
999    sp_values = ops.convert_to_tensor(sp_values, name="sp_values")
1000    y = ops.convert_to_tensor(y, name="y")
1001    x_dtype = sp_values.dtype.base_dtype
1002    y_dtype = y.dtype.base_dtype
1003    if x_dtype != y_dtype:
1004      raise TypeError("x and y must have the same dtype, got %r != %r" %
1005                      (x_dtype, y_dtype))
1006    try:
1007      dtype = _TRUEDIV_TABLE[x_dtype]
1008    except KeyError:
1009      raise TypeError("Invalid dtype %r in __truediv__" % x_dtype)
1010    if dtype is not None:
1011      sp_values = cast(sp_values, dtype)
1012      y = cast(y, dtype)
1013    return gen_sparse_ops.sparse_dense_cwise_div(
1014        sp_indices, sp_values, sp_shape, y, name=name)
1015
1016
1017def _truediv_python3(x, y, name=None):
1018  with ops.name_scope(name, "truediv", [x, y]) as name:
1019    x = ops.convert_to_tensor(x, name="x")
1020    y = ops.convert_to_tensor(y, name="y")
1021    x_dtype = x.dtype.base_dtype
1022    y_dtype = y.dtype.base_dtype
1023    if x_dtype != y_dtype:
1024      raise TypeError("x and y must have the same dtype, got %r != %r" %
1025                      (x_dtype, y_dtype))
1026    try:
1027      dtype = _TRUEDIV_TABLE[x_dtype]
1028    except KeyError:
1029      raise TypeError("Invalid dtype %r in __truediv__" % x_dtype)
1030    if dtype is not None:
1031      x = cast(x, dtype)
1032      y = cast(y, dtype)
1033    return gen_math_ops._real_div(x, y, name=name)
1034
1035
1036def _div_python2(x, y, name=None):
1037  """Divide two values using Python 2 semantics. Used for Tensor.__div__.
1038
1039  Args:
1040    x: `Tensor` numerator of real numeric type.
1041    y: `Tensor` denominator of real numeric type.
1042    name: A name for the operation (optional).
1043  Returns:
1044    `x / y` returns the quotient of x and y.
1045  """
1046
1047  with ops.name_scope(name, "div", [x, y]) as name:
1048    x = ops.convert_to_tensor(x, name="x")
1049    y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
1050    x_dtype = x.dtype.base_dtype
1051    y_dtype = y.dtype.base_dtype
1052    if x_dtype != y_dtype:
1053      raise TypeError("x and y must have the same dtype, got %r != %r" %
1054                      (x_dtype, y_dtype))
1055    if x_dtype.is_floating or x_dtype.is_complex:
1056      return gen_math_ops._real_div(x, y, name=name)
1057    else:
1058      return gen_math_ops._floor_div(x, y, name=name)
1059
1060
1061@tf_export("truediv")
1062def truediv(x, y, name=None):
1063  """Divides x / y elementwise (using Python 3 division operator semantics).
1064
1065  NOTE: Prefer using the Tensor operator or tf.divide which obey Python
1066  division operator semantics.
1067
1068  This function forces Python 3 division operator semantics where all integer
1069  arguments are cast to floating types first.   This op is generated by normal
1070  `x / y` division in Python 3 and in Python 2.7 with
1071  `from __future__ import division`.  If you want integer division that rounds
1072  down, use `x // y` or `tf.floordiv`.
1073
1074  `x` and `y` must have the same numeric type.  If the inputs are floating
1075  point, the output will have the same type.  If the inputs are integral, the
1076  inputs are cast to `float32` for `int8` and `int16` and `float64` for `int32`
1077  and `int64` (matching the behavior of Numpy).
1078
1079  Args:
1080    x: `Tensor` numerator of numeric type.
1081    y: `Tensor` denominator of numeric type.
1082    name: A name for the operation (optional).
1083
1084  Returns:
1085    `x / y` evaluated in floating point.
1086
1087  Raises:
1088    TypeError: If `x` and `y` have different dtypes.
1089  """
1090  return _truediv_python3(x, y, name)
1091
1092
1093@tf_export("div")
1094def div(x, y, name=None):
1095  """Divides x / y elementwise (using Python 2 division operator semantics).
1096
1097  NOTE: Prefer using the Tensor division operator or tf.divide which obey Python
1098  division operator semantics.
1099
1100  This function divides `x` and `y`, forcing Python 2.7 semantics. That is,
1101  if one of `x` or `y` is a float, then the result will be a float.
1102  Otherwise, the output will be an integer type. Flooring semantics are used
1103  for integer division.
1104
1105  Args:
1106    x: `Tensor` numerator of real numeric type.
1107    y: `Tensor` denominator of real numeric type.
1108    name: A name for the operation (optional).
1109  Returns:
1110    `x / y` returns the quotient of x and y.
1111  """
1112  return _div_python2(x, y, name)
1113
1114
1115# TODO(aselle): This should be removed
1116mod = gen_math_ops._floor_mod
1117
1118
1119# TODO(aselle): Deprecate this once all internal functionality uses
1120# tf.truncatediv
1121@tf_export("floordiv")
1122def floordiv(x, y, name=None):
1123  """Divides `x / y` elementwise, rounding toward the most negative integer.
1124
1125  The same as `tf.div(x,y)` for integers, but uses `tf.floor(tf.div(x,y))` for
1126  floating point arguments so that the result is always an integer (though
1127  possibly an integer represented as floating point).  This op is generated by
1128  `x // y` floor division in Python 3 and in Python 2.7 with
1129  `from __future__ import division`.
1130
1131  Note that for efficiency, `floordiv` uses C semantics for negative numbers
1132  (unlike Python and Numpy).
1133
1134  `x` and `y` must have the same type, and the result will have the same type
1135  as well.
1136
1137  Args:
1138    x: `Tensor` numerator of real numeric type.
1139    y: `Tensor` denominator of real numeric type.
1140    name: A name for the operation (optional).
1141
1142  Returns:
1143    `x / y` rounded down (except possibly towards zero for negative integers).
1144
1145  Raises:
1146    TypeError: If the inputs are complex.
1147  """
1148  with ops.name_scope(name, "floordiv", [x, y]) as name:
1149    return gen_math_ops._floor_div(x, y, name=name)
1150
1151
1152realdiv = gen_math_ops._real_div
1153truncatediv = gen_math_ops._truncate_div
1154# TODO(aselle): Rename this to floordiv when we can.
1155floor_div = gen_math_ops._floor_div
1156truncatemod = gen_math_ops._truncate_mod
1157floormod = gen_math_ops._floor_mod
1158
1159
1160def _mul_dispatch(x, y, name=None):
1161  """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse"."""
1162  is_tensor_y = isinstance(y, ops.Tensor)
1163  if is_tensor_y:
1164    return gen_math_ops._mul(x, y, name=name)
1165  else:
1166    assert isinstance(y, sparse_tensor.SparseTensor)  # Case: Dense * Sparse.
1167    new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values,
1168                                                     y.dense_shape, x, name)
1169    return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
1170
1171
1172# NOTE(aselle): When integer division is added for sparse_dense_cwise,
1173# div, truediv, and floordiv should be delegated appropriately for
1174# Python sematnics, analogous to dense cwise tensor operations.
1175_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_div, "div",
1176                              sparse_tensor.SparseTensor)
1177_OverrideBinaryOperatorHelper(_sparse_dense_truediv, "truediv",
1178                              sparse_tensor.SparseTensor)
1179_OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul",
1180                              sparse_tensor.SparseTensor)
1181
1182_OverrideBinaryOperatorHelper(gen_math_ops.add, "add")
1183_OverrideBinaryOperatorHelper(gen_math_ops._sub, "sub")
1184_OverrideBinaryOperatorHelper(_mul_dispatch, "mul")
1185_OverrideBinaryOperatorHelper(_div_python2, "div")
1186_OverrideBinaryOperatorHelper(_truediv_python3, "truediv")
1187_OverrideBinaryOperatorHelper(floordiv, "floordiv")
1188_OverrideBinaryOperatorHelper(gen_math_ops._floor_mod, "mod")
1189_OverrideBinaryOperatorHelper(pow, "pow")
1190
1191
1192@tf_export("logical_xor")
1193def logical_xor(x, y, name="LogicalXor"):
1194  """x ^ y = (x | y) & ~(x & y)."""
1195  # TODO(alemi) Make this a cwise op if people end up relying on it.
1196  return gen_math_ops.logical_and(
1197      gen_math_ops.logical_or(x, y),
1198      gen_math_ops.logical_not(gen_math_ops.logical_and(x, y)),
1199      name=name)
1200
1201
1202_OverrideBinaryOperatorHelper(gen_math_ops.logical_and, "and")
1203_OverrideBinaryOperatorHelper(gen_math_ops.logical_or, "or")
1204_OverrideBinaryOperatorHelper(logical_xor, "xor")
1205
1206ops.Tensor._override_operator("__lt__", gen_math_ops.less)
1207ops.Tensor._override_operator("__le__", gen_math_ops.less_equal)
1208ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
1209ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
1210
1211
1212@tf_export("range")
1213def range(start, limit=None, delta=1, dtype=None, name="range"):  # pylint: disable=redefined-builtin
1214  """Creates a sequence of numbers.
1215
1216  Creates a sequence of numbers that begins at `start` and extends by
1217  increments of `delta` up to but not including `limit`.
1218
1219  The dtype of the resulting tensor is inferred from the inputs unless
1220  it is provided explicitly.
1221
1222  Like the Python builtin `range`, `start` defaults to 0, so that
1223  `range(n) = range(0, n)`.
1224
1225  For example:
1226
1227  ```python
1228  start = 3
1229  limit = 18
1230  delta = 3
1231  tf.range(start, limit, delta)  # [3, 6, 9, 12, 15]
1232
1233  start = 3
1234  limit = 1
1235  delta = -0.5
1236  tf.range(start, limit, delta)  # [3, 2.5, 2, 1.5]
1237
1238  limit = 5
1239  tf.range(limit)  # [0, 1, 2, 3, 4]
1240  ```
1241
1242  Args:
1243    start: A 0-D `Tensor` (scalar). Acts as first entry in the range if
1244      `limit` is not None; otherwise, acts as range limit and first entry
1245      defaults to 0.
1246    limit: A 0-D `Tensor` (scalar). Upper limit of sequence,
1247      exclusive. If None, defaults to the value of `start` while the first
1248      entry of the range defaults to 0.
1249    delta: A 0-D `Tensor` (scalar). Number that increments
1250      `start`. Defaults to 1.
1251    dtype: The type of the elements of the resulting tensor.
1252    name: A name for the operation. Defaults to "range".
1253
1254  Returns:
1255    An 1-D `Tensor` of type `dtype`.
1256
1257  @compatibility(numpy)
1258  Equivalent to np.arange
1259  @end_compatibility
1260  """
1261  if limit is None:
1262    start, limit = 0, start
1263
1264  with ops.name_scope(name, "Range", [start, limit, delta]) as name:
1265    start = ops.convert_to_tensor(start, dtype=dtype, name="start")
1266    limit = ops.convert_to_tensor(limit, dtype=dtype, name="limit")
1267    delta = ops.convert_to_tensor(delta, dtype=dtype, name="delta")
1268
1269    # infer dtype if not explicitly provided
1270    if dtype is None:
1271      dtype_hierarchy = [
1272          dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
1273      ]
1274      assert all(arg.dtype in dtype_hierarchy for arg in [start, limit, delta])
1275      inferred_dtype = max(
1276          [arg.dtype for arg in [start, limit, delta]],
1277          key=dtype_hierarchy.index)
1278
1279      start = cast(start, inferred_dtype)
1280      limit = cast(limit, inferred_dtype)
1281      delta = cast(delta, inferred_dtype)
1282
1283    return gen_math_ops._range(start, limit, delta, name=name)
1284
1285
1286# Reduction operations
1287def _ReductionDims(x, axis, reduction_indices):
1288  """Returns range(0, rank(x)) if reduction_indices is None."""
1289  # TODO(aselle): Remove this after deprecation
1290  if reduction_indices is not None:
1291    if axis is not None:
1292      raise ValueError("Can't specify both axis' and 'reduction_indices'.")
1293    axis = reduction_indices
1294  if axis is not None:
1295    return axis
1296  else:
1297    # Fast path: avoid creating Rank and Range ops if ndims is known.
1298    if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None:
1299      return constant_op.constant(
1300          np.arange(x.get_shape().ndims), dtype=dtypes.int32)
1301    if (isinstance(x, sparse_tensor.SparseTensor) and
1302        x.dense_shape.get_shape().is_fully_defined()):
1303      rank = x.dense_shape.get_shape()[0].value  # sparse.dense_shape is 1-D.
1304      return constant_op.constant(np.arange(rank), dtype=dtypes.int32)
1305
1306    # Otherwise, we rely on Range and Rank to do the right thing at run-time.
1307    return range(0, array_ops.rank(x))
1308
1309
1310def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output):
1311  """Set a reduction's output's shape to be a scalar if we are certain."""
1312  if (not output.shape.is_fully_defined()) and (not keepdims) and (
1313      axis is None) and (reduction_indices is None):
1314    output.set_shape(())
1315  return output
1316
1317
1318@tf_export("reduce_sum")
1319@deprecation.deprecated_args(
1320    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1321def reduce_sum(input_tensor,
1322               axis=None,
1323               keepdims=None,
1324               name=None,
1325               reduction_indices=None,
1326               keep_dims=None):
1327  """Computes the sum of elements across dimensions of a tensor.
1328
1329  Reduces `input_tensor` along the dimensions given in `axis`.
1330  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1331  entry in `axis`. If `keepdims` is true, the reduced dimensions
1332  are retained with length 1.
1333
1334  If `axis` has no entries, all dimensions are reduced, and a
1335  tensor with a single element is returned.
1336
1337  For example:
1338
1339  ```python
1340  x = tf.constant([[1, 1, 1], [1, 1, 1]])
1341  tf.reduce_sum(x)  # 6
1342  tf.reduce_sum(x, 0)  # [2, 2, 2]
1343  tf.reduce_sum(x, 1)  # [3, 3]
1344  tf.reduce_sum(x, 1, keepdims=True)  # [[3], [3]]
1345  tf.reduce_sum(x, [0, 1])  # 6
1346  ```
1347
1348  Args:
1349    input_tensor: The tensor to reduce. Should have numeric type.
1350    axis: The dimensions to reduce. If `None` (the default),
1351      reduces all dimensions. Must be in the range
1352      `[-rank(input_tensor), rank(input_tensor))`.
1353    keepdims: If true, retains reduced dimensions with length 1.
1354    name: A name for the operation (optional).
1355    reduction_indices: The old (deprecated) name for axis.
1356    keep_dims: Deprecated alias for `keepdims`.
1357
1358  Returns:
1359    The reduced tensor.
1360
1361  @compatibility(numpy)
1362  Equivalent to np.sum
1363  @end_compatibility
1364  """
1365  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1366                                                    "keep_dims", keep_dims)
1367  if keepdims is None:
1368    keepdims = False
1369
1370  return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
1371                               gen_math_ops._sum(
1372                                   input_tensor,
1373                                   _ReductionDims(input_tensor, axis,
1374                                                  reduction_indices),
1375                                   keepdims,
1376                                   name=name))
1377
1378
1379@tf_export("count_nonzero")
1380@deprecation.deprecated_args(
1381    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1382def count_nonzero(input_tensor,
1383                  axis=None,
1384                  keepdims=None,
1385                  dtype=dtypes.int64,
1386                  name=None,
1387                  reduction_indices=None,
1388                  keep_dims=None):
1389  """Computes number of nonzero elements across dimensions of a tensor.
1390
1391  Reduces `input_tensor` along the dimensions given in `axis`.
1392  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1393  entry in `axis`. If `keepdims` is true, the reduced dimensions
1394  are retained with length 1.
1395
1396  If `axis` has no entries, all dimensions are reduced, and a
1397  tensor with a single element is returned.
1398
1399  **NOTE** Floating point comparison to zero is done by exact floating point
1400  equality check.  Small values are **not** rounded to zero for purposes of
1401  the nonzero check.
1402
1403  For example:
1404
1405  ```python
1406  x = tf.constant([[0, 1, 0], [1, 1, 0]])
1407  tf.count_nonzero(x)  # 3
1408  tf.count_nonzero(x, 0)  # [1, 2, 0]
1409  tf.count_nonzero(x, 1)  # [1, 2]
1410  tf.count_nonzero(x, 1, keepdims=True)  # [[1], [2]]
1411  tf.count_nonzero(x, [0, 1])  # 3
1412  ```
1413
1414  Args:
1415    input_tensor: The tensor to reduce. Should be of numeric type, or `bool`.
1416    axis: The dimensions to reduce. If `None` (the default),
1417      reduces all dimensions. Must be in the range
1418      `[-rank(input_tensor), rank(input_tensor))`.
1419    keepdims: If true, retains reduced dimensions with length 1.
1420    dtype: The output dtype; defaults to `tf.int64`.
1421    name: A name for the operation (optional).
1422    reduction_indices: The old (deprecated) name for axis.
1423    keep_dims: Deprecated alias for `keepdims`.
1424
1425  Returns:
1426    The reduced tensor (number of nonzero values).
1427  """
1428  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1429                                                    "keep_dims", keep_dims)
1430  if keepdims is None:
1431    keepdims = False
1432
1433  with ops.name_scope(name, "count_nonzero", [input_tensor]):
1434    input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
1435    zero = input_tensor.dtype.as_numpy_dtype()
1436    return cast(
1437        reduce_sum(
1438            # int64 reduction happens on GPU
1439            to_int64(gen_math_ops.not_equal(input_tensor, zero)),
1440            axis=axis,
1441            keepdims=keepdims,
1442            reduction_indices=reduction_indices),
1443        dtype=dtype)
1444
1445
1446@tf_export("reduce_mean")
1447@deprecation.deprecated_args(
1448    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1449def reduce_mean(input_tensor,
1450                axis=None,
1451                keepdims=None,
1452                name=None,
1453                reduction_indices=None,
1454                keep_dims=None):
1455  """Computes the mean of elements across dimensions of a tensor.
1456
1457  Reduces `input_tensor` along the dimensions given in `axis`.
1458  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1459  entry in `axis`. If `keepdims` is true, the reduced dimensions
1460  are retained with length 1.
1461
1462  If `axis` has no entries, all dimensions are reduced, and a
1463  tensor with a single element is returned.
1464
1465  For example:
1466
1467  ```python
1468  x = tf.constant([[1., 1.], [2., 2.]])
1469  tf.reduce_mean(x)  # 1.5
1470  tf.reduce_mean(x, 0)  # [1.5, 1.5]
1471  tf.reduce_mean(x, 1)  # [1.,  2.]
1472  ```
1473
1474  Args:
1475    input_tensor: The tensor to reduce. Should have numeric type.
1476    axis: The dimensions to reduce. If `None` (the default),
1477      reduces all dimensions. Must be in the range
1478      `[-rank(input_tensor), rank(input_tensor)]`.
1479    keepdims: If true, retains reduced dimensions with length 1.
1480    name: A name for the operation (optional).
1481    reduction_indices: The old (deprecated) name for axis.
1482    keep_dims: Deprecated alias for `keepdims`.
1483
1484  Returns:
1485    The reduced tensor.
1486
1487  @compatibility(numpy)
1488  Equivalent to np.mean
1489
1490  Please note that `np.mean` has a `dtype` parameter that could be used to
1491  specify the output type. By default this is `dtype=float64`. On the other
1492  hand, `tf.reduce_mean` has an aggressive type inference from `input_tensor`,
1493  for example:
1494
1495  ```python
1496  x = tf.constant([1, 0, 1, 0])
1497  tf.reduce_mean(x)  # 0
1498  y = tf.constant([1., 0., 1., 0.])
1499  tf.reduce_mean(y)  # 0.5
1500  ```
1501
1502  @end_compatibility
1503  """
1504  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1505                                                    "keep_dims", keep_dims)
1506
1507  if keepdims is None:
1508    keepdims = False
1509  return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
1510                               gen_math_ops._mean(
1511                                   input_tensor,
1512                                   _ReductionDims(input_tensor, axis,
1513                                                  reduction_indices),
1514                                   keepdims,
1515                                   name=name))
1516
1517
1518@tf_export("reduce_prod")
1519@deprecation.deprecated_args(
1520    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1521def reduce_prod(input_tensor,
1522                axis=None,
1523                keepdims=None,
1524                name=None,
1525                reduction_indices=None,
1526                keep_dims=None):
1527  """Computes the product of elements across dimensions of a tensor.
1528
1529  Reduces `input_tensor` along the dimensions given in `axis`.
1530  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1531  entry in `axis`. If `keepdims` is true, the reduced dimensions
1532  are retained with length 1.
1533
1534  If `axis` has no entries, all dimensions are reduced, and a
1535  tensor with a single element is returned.
1536
1537  Args:
1538    input_tensor: The tensor to reduce. Should have numeric type.
1539    axis: The dimensions to reduce. If `None` (the default),
1540      reduces all dimensions. Must be in the range
1541      `[-rank(input_tensor), rank(input_tensor))`.
1542    keepdims: If true, retains reduced dimensions with length 1.
1543    name: A name for the operation (optional).
1544    reduction_indices: The old (deprecated) name for axis.
1545    keep_dims: Deprecated alias for `keepdims`.
1546
1547  Returns:
1548    The reduced tensor.
1549
1550  @compatibility(numpy)
1551  Equivalent to np.prod
1552  @end_compatibility
1553  """
1554  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1555                                                    "keep_dims", keep_dims)
1556
1557  if keepdims is None:
1558    keepdims = False
1559  return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
1560                               gen_math_ops._prod(
1561                                   input_tensor,
1562                                   _ReductionDims(input_tensor, axis,
1563                                                  reduction_indices),
1564                                   keepdims,
1565                                   name=name))
1566
1567
1568@tf_export("reduce_min")
1569@deprecation.deprecated_args(
1570    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1571def reduce_min(input_tensor,
1572               axis=None,
1573               keepdims=None,
1574               name=None,
1575               reduction_indices=None,
1576               keep_dims=None):
1577  """Computes the minimum of elements across dimensions of a tensor.
1578
1579  Reduces `input_tensor` along the dimensions given in `axis`.
1580  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1581  entry in `axis`. If `keepdims` is true, the reduced dimensions
1582  are retained with length 1.
1583
1584  If `axis` has no entries, all dimensions are reduced, and a
1585  tensor with a single element is returned.
1586
1587  Args:
1588    input_tensor: The tensor to reduce. Should have numeric type.
1589    axis: The dimensions to reduce. If `None` (the default),
1590      reduces all dimensions. Must be in the range
1591      `[-rank(input_tensor), rank(input_tensor))`.
1592    keepdims: If true, retains reduced dimensions with length 1.
1593    name: A name for the operation (optional).
1594    reduction_indices: The old (deprecated) name for axis.
1595    keep_dims: Deprecated alias for `keepdims`.
1596
1597  Returns:
1598    The reduced tensor.
1599
1600  @compatibility(numpy)
1601  Equivalent to np.min
1602  @end_compatibility
1603  """
1604  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1605                                                    "keep_dims", keep_dims)
1606  if keepdims is None:
1607    keepdims = False
1608  return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
1609                               gen_math_ops._min(
1610                                   input_tensor,
1611                                   _ReductionDims(input_tensor, axis,
1612                                                  reduction_indices),
1613                                   keepdims,
1614                                   name=name))
1615
1616
1617@tf_export("reduce_max")
1618@deprecation.deprecated_args(
1619    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1620def reduce_max(input_tensor,
1621               axis=None,
1622               keepdims=None,
1623               name=None,
1624               reduction_indices=None,
1625               keep_dims=None):
1626  """Computes the maximum of elements across dimensions of a tensor.
1627
1628  Reduces `input_tensor` along the dimensions given in `axis`.
1629  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1630  entry in `axis`. If `keepdims` is true, the reduced dimensions
1631  are retained with length 1.
1632
1633  If `axis` has no entries, all dimensions are reduced, and a
1634  tensor with a single element is returned.
1635
1636  Args:
1637    input_tensor: The tensor to reduce. Should have numeric type.
1638    axis: The dimensions to reduce. If `None` (the default),
1639      reduces all dimensions. Must be in the range
1640      `[-rank(input_tensor), rank(input_tensor))`.
1641    keepdims: If true, retains reduced dimensions with length 1.
1642    name: A name for the operation (optional).
1643    reduction_indices: The old (deprecated) name for axis.
1644    keep_dims: Deprecated alias for `keepdims`.
1645
1646  Returns:
1647    The reduced tensor.
1648
1649  @compatibility(numpy)
1650  Equivalent to np.max
1651  @end_compatibility
1652  """
1653  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1654                                                    "keep_dims", keep_dims)
1655  if keepdims is None:
1656    keepdims = False
1657  return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
1658                               gen_math_ops._max(
1659                                   input_tensor,
1660                                   _ReductionDims(input_tensor, axis,
1661                                                  reduction_indices),
1662                                   keepdims,
1663                                   name=name))
1664
1665
1666@tf_export("reduce_all")
1667@deprecation.deprecated_args(
1668    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1669def reduce_all(input_tensor,
1670               axis=None,
1671               keepdims=None,
1672               name=None,
1673               reduction_indices=None,
1674               keep_dims=None):
1675  """Computes the "logical and" of elements across dimensions of a tensor.
1676
1677  Reduces `input_tensor` along the dimensions given in `axis`.
1678  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1679  entry in `axis`. If `keepdims` is true, the reduced dimensions
1680  are retained with length 1.
1681
1682  If `axis` has no entries, all dimensions are reduced, and a
1683  tensor with a single element is returned.
1684
1685  For example:
1686
1687  ```python
1688  x = tf.constant([[True,  True], [False, False]])
1689  tf.reduce_all(x)  # False
1690  tf.reduce_all(x, 0)  # [False, False]
1691  tf.reduce_all(x, 1)  # [True, False]
1692  ```
1693
1694  Args:
1695    input_tensor: The boolean tensor to reduce.
1696    axis: The dimensions to reduce. If `None` (the default),
1697      reduces all dimensions. Must be in the range
1698      `[-rank(input_tensor), rank(input_tensor))`.
1699    keepdims: If true, retains reduced dimensions with length 1.
1700    name: A name for the operation (optional).
1701    reduction_indices: The old (deprecated) name for axis.
1702    keep_dims: Deprecated alias for `keepdims`.
1703
1704  Returns:
1705    The reduced tensor.
1706
1707  @compatibility(numpy)
1708  Equivalent to np.all
1709  @end_compatibility
1710  """
1711  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1712                                                    "keep_dims", keep_dims)
1713  if keepdims is None:
1714    keepdims = False
1715  return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
1716                               gen_math_ops._all(
1717                                   input_tensor,
1718                                   _ReductionDims(input_tensor, axis,
1719                                                  reduction_indices),
1720                                   keepdims,
1721                                   name=name))
1722
1723
1724@tf_export("reduce_any")
1725@deprecation.deprecated_args(
1726    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1727def reduce_any(input_tensor,
1728               axis=None,
1729               keepdims=None,
1730               name=None,
1731               reduction_indices=None,
1732               keep_dims=None):
1733  """Computes the "logical or" of elements across dimensions of a tensor.
1734
1735  Reduces `input_tensor` along the dimensions given in `axis`.
1736  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1737  entry in `axis`. If `keepdims` is true, the reduced dimensions
1738  are retained with length 1.
1739
1740  If `axis` has no entries, all dimensions are reduced, and a
1741  tensor with a single element is returned.
1742
1743  For example:
1744
1745  ```python
1746  x = tf.constant([[True,  True], [False, False]])
1747  tf.reduce_any(x)  # True
1748  tf.reduce_any(x, 0)  # [True, True]
1749  tf.reduce_any(x, 1)  # [True, False]
1750  ```
1751
1752  Args:
1753    input_tensor: The boolean tensor to reduce.
1754    axis: The dimensions to reduce. If `None` (the default),
1755      reduces all dimensions. Must be in the range
1756      `[-rank(input_tensor), rank(input_tensor))`.
1757    keepdims: If true, retains reduced dimensions with length 1.
1758    name: A name for the operation (optional).
1759    reduction_indices: The old (deprecated) name for axis.
1760    keep_dims: Deprecated alias for `keepdims`.
1761
1762  Returns:
1763    The reduced tensor.
1764
1765  @compatibility(numpy)
1766  Equivalent to np.any
1767  @end_compatibility
1768  """
1769  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1770                                                    "keep_dims", keep_dims)
1771  if keepdims is None:
1772    keepdims = False
1773  return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
1774                               gen_math_ops._any(
1775                                   input_tensor,
1776                                   _ReductionDims(input_tensor, axis,
1777                                                  reduction_indices),
1778                                   keepdims,
1779                                   name=name))
1780
1781
1782@tf_export("reduce_logsumexp")
1783@deprecation.deprecated_args(
1784    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
1785def reduce_logsumexp(input_tensor,
1786                     axis=None,
1787                     keepdims=None,
1788                     name=None,
1789                     reduction_indices=None,
1790                     keep_dims=None):
1791  """Computes log(sum(exp(elements across dimensions of a tensor))).
1792
1793  Reduces `input_tensor` along the dimensions given in `axis`.
1794  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
1795  entry in `axis`. If `keepdims` is true, the reduced dimensions
1796  are retained with length 1.
1797
1798  If `axis` has no entries, all dimensions are reduced, and a
1799  tensor with a single element is returned.
1800
1801  This function is more numerically stable than log(sum(exp(input))). It avoids
1802  overflows caused by taking the exp of large inputs and underflows caused by
1803  taking the log of small inputs.
1804
1805  For example:
1806
1807  ```python
1808  x = tf.constant([[0., 0., 0.], [0., 0., 0.]])
1809  tf.reduce_logsumexp(x)  # log(6)
1810  tf.reduce_logsumexp(x, 0)  # [log(2), log(2), log(2)]
1811  tf.reduce_logsumexp(x, 1)  # [log(3), log(3)]
1812  tf.reduce_logsumexp(x, 1, keepdims=True)  # [[log(3)], [log(3)]]
1813  tf.reduce_logsumexp(x, [0, 1])  # log(6)
1814  ```
1815
1816  Args:
1817    input_tensor: The tensor to reduce. Should have numeric type.
1818    axis: The dimensions to reduce. If `None` (the default),
1819      reduces all dimensions. Must be in the range
1820      `[-rank(input_tensor), rank(input_tensor))`.
1821    keepdims: If true, retains reduced dimensions with length 1.
1822    name: A name for the operation (optional).
1823    reduction_indices: The old (deprecated) name for axis.
1824    keep_dims: Deprecated alias for `keepdims`.
1825
1826  Returns:
1827    The reduced tensor.
1828  """
1829  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
1830                                                    "keep_dims", keep_dims)
1831  if keepdims is None:
1832    keepdims = False
1833  with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
1834    raw_max = reduce_max(
1835        input_tensor,
1836        axis=axis,
1837        reduction_indices=reduction_indices,
1838        keepdims=True)
1839    my_max = array_ops.stop_gradient(
1840        array_ops.where(
1841            gen_math_ops.is_finite(raw_max), raw_max,
1842            array_ops.zeros_like(raw_max)))
1843    result = gen_math_ops.log(
1844        reduce_sum(
1845            gen_math_ops.exp(input_tensor - my_max),
1846            axis,
1847            keepdims=keepdims,
1848            reduction_indices=reduction_indices))
1849    if not keepdims:
1850      my_max = array_ops.reshape(my_max, array_ops.shape(result))
1851    result += my_max
1852    return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)
1853
1854
1855@tf_export("trace", "linalg.trace")
1856def trace(x, name=None):
1857  """Compute the trace of a tensor `x`.
1858
1859  `trace(x)` returns the sum along the main diagonal of each inner-most matrix
1860  in x. If x is of rank `k` with shape `[I, J, K, ..., L, M, N]`, then output
1861  is a tensor of rank `k-2` with dimensions `[I, J, K, ..., L]` where
1862
1863  `output[i, j, k, ..., l] = trace(x[i, j, i, ..., l, :, :])`
1864
1865  For example:
1866
1867  ```python
1868  x = tf.constant([[1, 2], [3, 4]])
1869  tf.trace(x)  # 5
1870
1871  x = tf.constant([[1, 2, 3],
1872                   [4, 5, 6],
1873                   [7, 8, 9]])
1874  tf.trace(x)  # 15
1875
1876  x = tf.constant([[[1, 2, 3],
1877                    [4, 5, 6],
1878                    [7, 8, 9]],
1879                   [[-1, -2, -3],
1880                    [-4, -5, -6],
1881                    [-7, -8, -9]]])
1882  tf.trace(x)  # [15, -15]
1883  ```
1884
1885  Args:
1886    x: tensor.
1887    name: A name for the operation (optional).
1888
1889  Returns:
1890    The trace of input tensor.
1891  """
1892  with ops.name_scope(name, "Trace", [x]) as name:
1893    x = ops.convert_to_tensor(x, name="x")
1894    return reduce_sum(array_ops.matrix_diag_part(x), [-1], name=name)
1895
1896
1897@tf_export("matmul")
1898def matmul(a,
1899           b,
1900           transpose_a=False,
1901           transpose_b=False,
1902           adjoint_a=False,
1903           adjoint_b=False,
1904           a_is_sparse=False,
1905           b_is_sparse=False,
1906           name=None):
1907  """Multiplies matrix `a` by matrix `b`, producing `a` * `b`.
1908
1909  The inputs must, following any transpositions, be tensors of rank >= 2
1910  where the inner 2 dimensions specify valid matrix multiplication arguments,
1911  and any further outer dimensions match.
1912
1913  Both matrices must be of the same type. The supported types are:
1914  `float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`.
1915
1916  Either matrix can be transposed or adjointed (conjugated and transposed) on
1917  the fly by setting one of the corresponding flag to `True`. These are `False`
1918  by default.
1919
1920  If one or both of the matrices contain a lot of zeros, a more efficient
1921  multiplication algorithm can be used by setting the corresponding
1922  `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
1923  This optimization is only available for plain matrices (rank-2 tensors) with
1924  datatypes `bfloat16` or `float32`.
1925
1926  For example:
1927
1928  ```python
1929  # 2-D tensor `a`
1930  # [[1, 2, 3],
1931  #  [4, 5, 6]]
1932  a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
1933
1934  # 2-D tensor `b`
1935  # [[ 7,  8],
1936  #  [ 9, 10],
1937  #  [11, 12]]
1938  b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2])
1939
1940  # `a` * `b`
1941  # [[ 58,  64],
1942  #  [139, 154]]
1943  c = tf.matmul(a, b)
1944
1945
1946  # 3-D tensor `a`
1947  # [[[ 1,  2,  3],
1948  #   [ 4,  5,  6]],
1949  #  [[ 7,  8,  9],
1950  #   [10, 11, 12]]]
1951  a = tf.constant(np.arange(1, 13, dtype=np.int32),
1952                  shape=[2, 2, 3])
1953
1954  # 3-D tensor `b`
1955  # [[[13, 14],
1956  #   [15, 16],
1957  #   [17, 18]],
1958  #  [[19, 20],
1959  #   [21, 22],
1960  #   [23, 24]]]
1961  b = tf.constant(np.arange(13, 25, dtype=np.int32),
1962                  shape=[2, 3, 2])
1963
1964  # `a` * `b`
1965  # [[[ 94, 100],
1966  #   [229, 244]],
1967  #  [[508, 532],
1968  #   [697, 730]]]
1969  c = tf.matmul(a, b)
1970
1971  # Since python >= 3.5 the @ operator is supported (see PEP 465).
1972  # In TensorFlow, it simply calls the `tf.matmul()` function, so the
1973  # following lines are equivalent:
1974  d = a @ b @ [[10.], [11.]]
1975  d = tf.matmul(tf.matmul(a, b), [[10.], [11.]])
1976  ```
1977
1978  Args:
1979    a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`,
1980      `complex128` and rank > 1.
1981    b: `Tensor` with same type and rank as `a`.
1982    transpose_a: If `True`, `a` is transposed before multiplication.
1983    transpose_b: If `True`, `b` is transposed before multiplication.
1984    adjoint_a: If `True`, `a` is conjugated and transposed before
1985      multiplication.
1986    adjoint_b: If `True`, `b` is conjugated and transposed before
1987      multiplication.
1988    a_is_sparse: If `True`, `a` is treated as a sparse matrix.
1989    b_is_sparse: If `True`, `b` is treated as a sparse matrix.
1990    name: Name for the operation (optional).
1991
1992  Returns:
1993    A `Tensor` of the same type as `a` and `b` where each inner-most matrix is
1994    the product of the corresponding matrices in `a` and `b`, e.g. if all
1995    transpose or adjoint attributes are `False`:
1996
1997    `output`[..., i, j] = sum_k (`a`[..., i, k] * `b`[..., k, j]),
1998    for all indices i, j.
1999
2000    Note: This is matrix product, not element-wise product.
2001
2002
2003  Raises:
2004    ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b
2005      are both set to True.
2006  """
2007  with ops.name_scope(name, "MatMul", [a, b]) as name:
2008    if transpose_a and adjoint_a:
2009      raise ValueError("Only one of transpose_a and adjoint_a can be True.")
2010    if transpose_b and adjoint_b:
2011      raise ValueError("Only one of transpose_b and adjoint_b can be True.")
2012
2013    a = ops.convert_to_tensor(a, name="a")
2014    b = ops.convert_to_tensor(b, name="b")
2015    # TODO(apassos) remove _shape_tuple here when it is not needed.
2016    a_shape = a._shape_tuple()  # pylint: disable=protected-access
2017    b_shape = b._shape_tuple()  # pylint: disable=protected-access
2018    if (not a_is_sparse and
2019        not b_is_sparse) and ((a_shape is None or len(a_shape) > 2) and
2020                              (b_shape is None or len(b_shape) > 2)):
2021      # BatchMatmul does not support transpose, so we conjugate the matrix and
2022      # use adjoint instead. Conj() is a noop for real matrices.
2023      if transpose_a:
2024        a = conj(a)
2025        adjoint_a = True
2026      if transpose_b:
2027        b = conj(b)
2028        adjoint_b = True
2029      return gen_math_ops._batch_mat_mul(
2030          a, b, adj_x=adjoint_a, adj_y=adjoint_b, name=name)
2031
2032    # Neither matmul nor sparse_matmul support adjoint, so we conjugate
2033    # the matrix and use transpose instead. Conj() is a noop for real
2034    # matrices.
2035    if adjoint_a:
2036      a = conj(a)
2037      transpose_a = True
2038    if adjoint_b:
2039      b = conj(b)
2040      transpose_b = True
2041
2042    use_sparse_matmul = False
2043    if a_is_sparse or b_is_sparse:
2044      sparse_matmul_types = [dtypes.bfloat16, dtypes.float32]
2045      use_sparse_matmul = (
2046          a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types)
2047    if a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16:
2048      # matmul currently doesn't handle bfloat16 inputs.
2049      use_sparse_matmul = True
2050    if use_sparse_matmul:
2051      ret = sparse_matmul(
2052          a,
2053          b,
2054          transpose_a=transpose_a,
2055          transpose_b=transpose_b,
2056          a_is_sparse=a_is_sparse,
2057          b_is_sparse=b_is_sparse,
2058          name=name)
2059      # sparse_matmul always returns float32, even with
2060      # bfloat16 inputs. This prevents us from configuring bfloat16 training.
2061      # casting to bfloat16 also matches non-sparse matmul behavior better.
2062      if a.dtype == dtypes.bfloat16 and b.dtype == dtypes.bfloat16:
2063        ret = cast(ret, dtypes.bfloat16)
2064      return ret
2065    else:
2066      return gen_math_ops._mat_mul(
2067          a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
2068
2069
2070_OverrideBinaryOperatorHelper(matmul, "matmul")
2071
2072sparse_matmul = gen_math_ops._sparse_mat_mul
2073
2074
2075@ops.RegisterStatistics("MatMul", "flops")
2076def _calc_mat_mul_flops(graph, node):
2077  """Calculates the compute resources needed for MatMul."""
2078  transpose_a = node.attr["transpose_a"].b
2079  a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
2080  a_shape.assert_is_fully_defined()
2081  if transpose_a:
2082    k = int(a_shape[0])
2083  else:
2084    k = int(a_shape[1])
2085  output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
2086  output_shape.assert_is_fully_defined()
2087  output_count = np.prod(output_shape.as_list())
2088  return ops.OpStats("flops", (k * output_count * 2))
2089
2090
2091def _as_indexed_slices(x, optimize=True):
2092  """Convert 'x' to IndexedSlices.
2093
2094  Convert a dense Tensor to a block-sparse IndexedSlices.
2095
2096  Args:
2097    x: Either a Tensor object, or an IndexedSlices object.
2098    optimize: if true, attempt to optimize the conversion of 'x'.
2099
2100  Returns:
2101    An IndexedSlices object.
2102
2103  Raises:
2104    TypeError: If 'x' is not a Tensor or an IndexedSlices object.
2105  """
2106  # TODO(touts): op_scope
2107  if not isinstance(x, (ops.Tensor, ops.IndexedSlices)):
2108    raise TypeError("Not a Tensor or IndexedSlices: %s" % type(x))
2109  if isinstance(x, ops.IndexedSlices):
2110    return x
2111  x_shape = array_ops.shape_internal(x, optimize=optimize)
2112  return ops.IndexedSlices(x, range(0, x_shape[0]), x_shape)
2113
2114
2115def _as_indexed_slices_list(inputs, optimize=True):
2116  """Convert all elements of 'inputs' to IndexedSlices.
2117
2118  Additionally, homogenize the types of all the indices to
2119  either int32 or int64.
2120
2121  Args:
2122    inputs: List containing either Tensor or IndexedSlices objects.
2123    optimize: if true, attempt to optimize the conversion of each input.
2124
2125  Returns:
2126    A list of IndexedSlices objects.
2127
2128  Raises:
2129    TypeError: If 'inputs' is not a list or a tuple.
2130  """
2131  if not isinstance(inputs, (list, tuple)):
2132    raise TypeError("Expected a list or tuple, not a %s" % type(inputs))
2133  outputs = [_as_indexed_slices(i, optimize=optimize) for i in inputs]
2134  with_int32_index = [
2135      o.indices for o in outputs if o.indices.dtype == dtypes.int32
2136  ]
2137  if not with_int32_index or len(with_int32_index) == len(outputs):
2138    return outputs
2139  casted_outputs = []
2140  for o in outputs:
2141    if o.indices.dtype == dtypes.int32:
2142      casted_outputs.append(
2143          ops.IndexedSlices(o.values, cast(o.indices, dtypes.int64),
2144                            o.dense_shape))
2145    else:
2146      casted_outputs.append(o)
2147  return casted_outputs
2148
2149
2150@tf_export("add_n")
2151def add_n(inputs, name=None):
2152  """Adds all input tensors element-wise.
2153
2154  Args:
2155    inputs: A list of `Tensor` objects, each with same shape and type.
2156    name: A name for the operation (optional).
2157
2158  Returns:
2159    A `Tensor` of same shape and type as the elements of `inputs`.
2160
2161  Raises:
2162    ValueError: If `inputs` don't all have same shape and dtype or the shape
2163    cannot be inferred.
2164  """
2165  if not inputs or not isinstance(inputs, (list, tuple)):
2166    raise ValueError("inputs must be a list of at least one Tensor with the "
2167                     "same dtype and shape")
2168  inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
2169  if not all(isinstance(x, ops.Tensor) for x in inputs):
2170    raise ValueError("inputs must be a list of at least one Tensor with the "
2171                     "same dtype and shape")
2172
2173  if len(inputs) == 1:
2174    if name:
2175      return array_ops.identity(inputs[0], name=name)
2176    return inputs[0]
2177  return gen_math_ops._add_n(inputs, name=name)
2178
2179
2180@tf_export("accumulate_n")
2181def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
2182  """Returns the element-wise sum of a list of tensors.
2183
2184  Optionally, pass `shape` and `tensor_dtype` for shape and type checking,
2185  otherwise, these are inferred.
2186
2187  NOTE: This operation is not differentiable and cannot be used if inputs depend
2188  on trainable variables. Please use `tf.add_n` for such cases.
2189
2190  Aside from differentiability, `tf.accumulate_n` performs the same operation as
2191  `tf.add_n`, but does not wait for all of its inputs to be ready before
2192  beginning to sum. This can save memory if inputs are ready at different times,
2193  since minimum temporary storage is proportional to the output size rather than
2194  the inputs size.
2195
2196  For example:
2197
2198  ```python
2199  a = tf.constant([[1, 2], [3, 4]])
2200  b = tf.constant([[5, 0], [0, 6]])
2201  tf.accumulate_n([a, b, a])  # [[7, 4], [6, 14]]
2202
2203  # Explicitly pass shape and type
2204  tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)  # [[7,  4],
2205                                                                   #  [6, 14]]
2206  ```
2207
2208  Args:
2209    inputs: A list of `Tensor` objects, each with same shape and type.
2210    shape: Shape of elements of `inputs`.
2211    tensor_dtype: The type of `inputs`.
2212    name: A name for the operation (optional).
2213
2214  Returns:
2215    A `Tensor` of same shape and type as the elements of `inputs`.
2216
2217  Raises:
2218    ValueError: If `inputs` don't all have same shape and dtype or the shape
2219    cannot be inferred.
2220  """
2221  if context.in_eager_mode():
2222    # TODO(apassos) remove this once the lifetime of eager variables gets
2223    # addressed.
2224    raise ValueError("accumulate_n not supported in eager mode")
2225  if not inputs or not isinstance(inputs, (list, tuple)):
2226    raise ValueError("inputs must be a list of at least one Tensor with the "
2227                     "same dtype and shape")
2228  inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
2229  if not all(isinstance(x, ops.Tensor) for x in inputs):
2230    raise ValueError("inputs must be a list of at least one Tensor with the "
2231                     "same dtype and shape")
2232  if not all(x.dtype == inputs[0].dtype for x in inputs):
2233    raise ValueError("inputs must be a list of at least one Tensor with the "
2234                     "same dtype and shape")
2235  if shape is not None:
2236    shape = tensor_shape.as_shape(shape)
2237  else:
2238    shape = tensor_shape.unknown_shape()
2239  for input_tensor in inputs:
2240    if isinstance(input_tensor, ops.Tensor):
2241      shape = shape.merge_with(input_tensor.get_shape())
2242  if tensor_dtype is None:
2243    tensor_dtype = inputs[0].dtype
2244  if tensor_dtype != inputs[0].dtype:
2245    raise TypeError("tensor_dtype is {}, but input is of type {}".format(
2246        tensor_dtype, inputs[0].dtype))
2247  if len(inputs) == 1:
2248    return inputs[0]
2249  with ops.name_scope(name, "AccumulateN", inputs) as name:
2250    var = gen_state_ops._temporary_variable(
2251        shape=tensor_shape.vector(0), dtype=tensor_dtype)
2252    with ops.colocate_with(var):
2253      zeros = array_ops.zeros_like(gen_control_flow_ops._merge(inputs)[0])
2254      zeros.set_shape(shape)
2255      ref = state_ops.assign(var, zeros, validate_shape=False)
2256      update_ops = [
2257          state_ops.assign_add(ref, input_tensor, use_locking=True)
2258          for input_tensor in inputs
2259      ]
2260      with ops.control_dependencies(update_ops):
2261        return gen_state_ops._destroy_temporary_variable(
2262            ref, var_name=var.op.name, name=name)
2263
2264
2265@tf_export("nn.sigmoid", "sigmoid")
2266def sigmoid(x, name=None):
2267  """Computes sigmoid of `x` element-wise.
2268
2269  Specifically, `y = 1 / (1 + exp(-x))`.
2270
2271  Args:
2272    x: A Tensor with type `float16`, `float32`, `float64`, `complex64`,
2273      or `complex128`.
2274    name: A name for the operation (optional).
2275
2276  Returns:
2277    A Tensor with the same type as `x`.
2278
2279  @compatibility(numpy)
2280  Equivalent to np.scipy.special.expit
2281  @end_compatibility
2282  """
2283  with ops.name_scope(name, "Sigmoid", [x]) as name:
2284    x = ops.convert_to_tensor(x, name="x")
2285    return gen_math_ops._sigmoid(x, name=name)
2286
2287
2288@tf_export("log_sigmoid")
2289def log_sigmoid(x, name=None):
2290  """Computes log sigmoid of `x` element-wise.
2291
2292  Specifically, `y = log(1 / (1 + exp(-x)))`.  For numerical stability,
2293  we use `y = -tf.nn.softplus(-x)`.
2294
2295  Args:
2296    x: A Tensor with type `float32` or `float64`.
2297    name: A name for the operation (optional).
2298
2299  Returns:
2300    A Tensor with the same type as `x`.
2301  """
2302  with ops.name_scope(name, "LogSigmoid", [x]) as name:
2303    x = ops.convert_to_tensor(x, name="x")
2304    return gen_math_ops._neg(gen_nn_ops.softplus(-x), name=name)
2305
2306
2307@tf_export("nn.tanh", "tanh")
2308def tanh(x, name=None):
2309  """Computes hyperbolic tangent of `x` element-wise.
2310
2311  Args:
2312    x: A Tensor or SparseTensor with type `float16`, `float32`, `double`,
2313      `complex64`, or `complex128`.
2314    name: A name for the operation (optional).
2315
2316  Returns:
2317    A Tensor or SparseTensor respectively with the same type as `x`.
2318  """
2319  with ops.name_scope(name, "Tanh", [x]) as name:
2320    if isinstance(x, sparse_tensor.SparseTensor):
2321      x_tanh = gen_math_ops._tanh(x.values, name=name)
2322      return sparse_tensor.SparseTensor(
2323          indices=x.indices, values=x_tanh, dense_shape=x.dense_shape)
2324    else:
2325      return gen_math_ops._tanh(x, name=name)
2326
2327
2328@tf_export("bincount")
2329def bincount(arr,
2330             weights=None,
2331             minlength=None,
2332             maxlength=None,
2333             dtype=dtypes.int32):
2334  """Counts the number of occurrences of each value in an integer array.
2335
2336  If `minlength` and `maxlength` are not given, returns a vector with length
2337  `tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise.
2338  If `weights` are non-None, then index `i` of the output stores the sum of the
2339  value in `weights` at each index where the corresponding value in `arr` is
2340  `i`.
2341
2342  Args:
2343    arr: An int32 tensor of non-negative values.
2344    weights: If non-None, must be the same shape as arr. For each value in
2345        `arr`, the bin will be incremented by the corresponding weight instead
2346        of 1.
2347    minlength: If given, ensures the output has length at least `minlength`,
2348        padding with zeros at the end if necessary.
2349    maxlength: If given, skips values in `arr` that are equal or greater than
2350        `maxlength`, ensuring that the output has length at most `maxlength`.
2351    dtype: If `weights` is None, determines the type of the output bins.
2352
2353  Returns:
2354    A vector with the same dtype as `weights` or the given `dtype`. The bin
2355    values.
2356  """
2357  arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32)
2358  array_is_nonempty = reduce_prod(array_ops.shape(arr)) > 0
2359  output_size = cast(array_is_nonempty, dtypes.int32) * (reduce_max(arr) + 1)
2360  if minlength is not None:
2361    minlength = ops.convert_to_tensor(
2362        minlength, name="minlength", dtype=dtypes.int32)
2363    output_size = gen_math_ops.maximum(minlength, output_size)
2364  if maxlength is not None:
2365    maxlength = ops.convert_to_tensor(
2366        maxlength, name="maxlength", dtype=dtypes.int32)
2367    output_size = gen_math_ops.minimum(maxlength, output_size)
2368  if weights is not None:
2369    weights = ops.convert_to_tensor(weights, name="weights")
2370    return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
2371  weights = constant_op.constant([], dtype)
2372  return gen_math_ops.bincount(arr, output_size, weights)
2373
2374
2375@tf_export("cumsum")
2376def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
2377  """Compute the cumulative sum of the tensor `x` along `axis`.
2378
2379  By default, this op performs an inclusive cumsum, which means that the first
2380  element of the input is identical to the first element of the output:
2381
2382  ```python
2383  tf.cumsum([a, b, c])  # [a, a + b, a + b + c]
2384  ```
2385
2386  By setting the `exclusive` kwarg to `True`, an exclusive cumsum is performed
2387  instead:
2388
2389  ```python
2390  tf.cumsum([a, b, c], exclusive=True)  # [0, a, a + b]
2391  ```
2392
2393  By setting the `reverse` kwarg to `True`, the cumsum is performed in the
2394  opposite direction:
2395
2396  ```python
2397  tf.cumsum([a, b, c], reverse=True)  # [a + b + c, b + c, c]
2398  ```
2399
2400  This is more efficient than using separate `tf.reverse` ops.
2401
2402  The `reverse` and `exclusive` kwargs can also be combined:
2403
2404  ```python
2405  tf.cumsum([a, b, c], exclusive=True, reverse=True)  # [b + c, c, 0]
2406  ```
2407
2408  Args:
2409    x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
2410       `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
2411       `complex128`, `qint8`, `quint8`, `qint32`, `half`.
2412    axis: A `Tensor` of type `int32` (default: 0). Must be in the range
2413      `[-rank(x), rank(x))`.
2414    exclusive: If `True`, perform exclusive cumsum.
2415    reverse: A `bool` (default: False).
2416    name: A name for the operation (optional).
2417
2418  Returns:
2419    A `Tensor`. Has the same type as `x`.
2420  """
2421  with ops.name_scope(name, "Cumsum", [x]) as name:
2422    x = ops.convert_to_tensor(x, name="x")
2423    return gen_math_ops.cumsum(
2424        x, axis, exclusive=exclusive, reverse=reverse, name=name)
2425
2426
2427@tf_export("cumprod")
2428def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
2429  """Compute the cumulative product of the tensor `x` along `axis`.
2430
2431  By default, this op performs an inclusive cumprod, which means that the
2432  first element of the input is identical to the first element of the output:
2433
2434  ```python
2435  tf.cumprod([a, b, c])  # [a, a * b, a * b * c]
2436  ```
2437
2438  By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
2439  performed
2440  instead:
2441
2442  ```python
2443  tf.cumprod([a, b, c], exclusive=True)  # [1, a, a * b]
2444  ```
2445
2446  By setting the `reverse` kwarg to `True`, the cumprod is performed in the
2447  opposite direction:
2448
2449  ```python
2450  tf.cumprod([a, b, c], reverse=True)  # [a * b * c, b * c, c]
2451  ```
2452
2453  This is more efficient than using separate `tf.reverse` ops.
2454  The `reverse` and `exclusive` kwargs can also be combined:
2455
2456  ```python
2457  tf.cumprod([a, b, c], exclusive=True, reverse=True)  # [b * c, c, 1]
2458  ```
2459
2460  Args:
2461    x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
2462       `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
2463       `complex128`, `qint8`, `quint8`, `qint32`, `half`.
2464    axis: A `Tensor` of type `int32` (default: 0). Must be in the range
2465      `[-rank(x), rank(x))`.
2466    exclusive: If `True`, perform exclusive cumprod.
2467    reverse: A `bool` (default: False).
2468    name: A name for the operation (optional).
2469
2470  Returns:
2471    A `Tensor`. Has the same type as `x`.
2472  """
2473  with ops.name_scope(name, "Cumprod", [x]) as name:
2474    x = ops.convert_to_tensor(x, name="x")
2475    return gen_math_ops.cumprod(
2476        x, axis, exclusive=exclusive, reverse=reverse, name=name)
2477
2478
2479@tf_export("conj")
2480def conj(x, name=None):
2481  r"""Returns the complex conjugate of a complex number.
2482
2483  Given a tensor `input` of complex numbers, this operation returns a tensor of
2484  complex numbers that are the complex conjugate of each element in `input`. The
2485  complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the
2486  real part and *b* is the imaginary part.
2487
2488  The complex conjugate returned by this operation is of the form \\(a - bj\\).
2489
2490  For example:
2491
2492      # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
2493      tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
2494
2495  If `x` is real, it is returned unchanged.
2496
2497  Args:
2498    x: `Tensor` to conjugate.  Must have numeric or variant type.
2499    name: A name for the operation (optional).
2500
2501  Returns:
2502    A `Tensor` that is the conjugate of `x` (with the same type).
2503
2504  Raises:
2505    TypeError: If `x` is not a numeric tensor.
2506  """
2507  if isinstance(x, ops.Tensor):
2508    dt = x.dtype
2509    if dt.is_floating or dt.is_integer:
2510      return x
2511  with ops.name_scope(name, "Conj", [x]) as name:
2512    x = ops.convert_to_tensor(x, name="x")
2513    if x.dtype.is_complex or x.dtype == dtypes.variant:
2514      return gen_math_ops._conj(x, name=name)
2515    elif x.dtype.is_floating or x.dtype.is_integer:
2516      return x
2517    else:
2518      raise TypeError(
2519          "Expected numeric or variant tensor, got dtype %r" % x.dtype)
2520
2521
2522def _BroadcastShape(op):
2523  """Common shape function for binary operators that broadcast their inputs."""
2524  return [
2525      common_shapes.broadcast_shape(op.inputs[0].get_shape(),
2526                                    op.inputs[1].get_shape())
2527  ]
2528
2529
2530def reduced_shape(input_shape, axes):
2531  """Helper function for reduction ops.
2532
2533  Args:
2534    input_shape: 1-D Tensor, the shape of the Tensor being reduced.
2535    axes: 1-D Tensor, the reduction axes.
2536  Returns:
2537    A 1-D Tensor, the output shape as if keepdims were set to True.
2538  """
2539  # Example:
2540  # cast needed for SparseTensor reductions
2541  input_shape = to_int32(input_shape)  # [2, 3, 5, 7]
2542  axes = to_int32(axes)  # [1, 2]
2543
2544  input_rank = array_ops.size(input_shape)  # 4
2545  axes = (axes + input_rank) % input_rank
2546  axes_shape = array_ops.shape(axes)  # [2]
2547  return gen_data_flow_ops.dynamic_stitch(  # [2, 1, 1, 7]
2548      [
2549          range(input_rank),  # [0, 1, 2, 3]
2550          axes
2551      ],  # [1, 2]
2552      [
2553          input_shape,  # [2, 3, 5, 7]
2554          array_ops.fill(axes_shape, 1)
2555      ])  # [1, 1]
2556
2557
2558def _unsorted_segment_N(data, segment_ids, num_segments):
2559  """ Helper function for unsorted_segment_mean/_sqrtN. Computes the number
2560      of segment entries with 0-entries set to 1 to allow division by N.
2561  """
2562  # bincount doesn't support negative indices so we use unsorted_segment_sum
2563  ones_tensor = array_ops.ones(segment_ids.shape, dtype=data.dtype)
2564  N = gen_math_ops.unsorted_segment_sum(ones_tensor, segment_ids, num_segments)
2565  # add dimensions for all non-reduced axes
2566  ndims_output = data.shape.ndims - segment_ids.shape.ndims
2567  broadcast_shape = [num_segments] + [1] * ndims_output
2568  N = array_ops.reshape(N, broadcast_shape)
2569  return gen_math_ops.maximum(N, 1)
2570
2571
2572@tf_export("unsorted_segment_mean")
2573def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
2574  r""" Computes the mean along segments of a tensor.
2575
2576  Read @{$math_ops#segmentation$the section on segmentation} for an explanation
2577  of segments.
2578
2579  This operator is similar to the unsorted segment sum operator found
2580  [here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
2581  Instead of computing the sum over segments, it computes the mean of all
2582  entries belonging to a segment such that:
2583
2584  \\(output_i = 1/N_i \sum data_j\\) where the sum is over `j` such
2585  that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences
2586  of id \\i\\.
2587
2588  If there is no entry for a given segment ID `i`, it outputs 0.
2589
2590  segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
2591  first dimension.
2592
2593  output: Has same shape as data, except for dimension 0 which
2594  has size `num_segments`.
2595  """
2596  with ops.name_scope(name, "UnsortedSegmentMean"):
2597    data = ops.convert_to_tensor(data)
2598    segment_ids = ops.convert_to_tensor(segment_ids)
2599    N = _unsorted_segment_N(data, segment_ids, num_segments)
2600    summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
2601    return summed / N
2602
2603
2604@tf_export("unsorted_segment_sqrt_n")
2605def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
2606  r"""Computes the sum along segments of a tensor divided by the sqrt(N).
2607
2608  Read @{$math_ops#segmentation$the section on segmentation} for an explanation
2609  of segments.
2610
2611  This operator is similar to the unsorted segment sum operator found
2612  [here](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
2613  Additionally to computing the sum over segments, it divides the results by
2614  sqrt(N).
2615
2616  \\(output_i = 1/sqrt(N_i) \sum data_j\\) where the sum is over `j` such
2617  that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences
2618  of id \\i\\.
2619
2620  If there is no entry for a given segment ID `i`, it outputs 0.
2621
2622  Note that this op only supports floating point and complex dtypes,
2623  due to tf.sqrt only supporting these types.
2624
2625  segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
2626  first dimension.
2627
2628  output: Has same shape as data, except for dimension 0 which
2629  has size `num_segments`.
2630  """
2631  with ops.name_scope(name, "UnsortedSegmentSqrtN"):
2632    data = ops.convert_to_tensor(data)
2633    segment_ids = ops.convert_to_tensor(segment_ids)
2634    N = _unsorted_segment_N(data, segment_ids, num_segments)
2635    summed = gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
2636    return summed / gen_math_ops.sqrt(N)
2637
2638
2639@tf_export("sparse_segment_sum")
2640def sparse_segment_sum(data, indices, segment_ids, name=None,
2641                       num_segments=None):
2642  r"""Computes the sum along sparse segments of a tensor.
2643
2644  Read @{$math_ops#Segmentation$the section on segmentation} for an explanation
2645  of segments.
2646
2647  Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
2648  dimension, selecting a subset of dimension 0, specified by `indices`.
2649  `segment_ids` is allowed to have missing ids, in which case the output will
2650  be zeros at those indices. In those cases `num_segments` is used to determine
2651  the size of the output.
2652
2653  For example:
2654
2655  ```python
2656  c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
2657
2658  # Select two rows, one segment.
2659  tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
2660  # => [[0 0 0 0]]
2661
2662  # Select two rows, two segment.
2663  tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
2664  # => [[ 1  2  3  4]
2665  #     [-1 -2 -3 -4]]
2666
2667  # With missing segment ids.
2668  tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 2]),
2669                        num_segments=4)
2670  # => [[ 1  2  3  4]
2671  #     [ 0  0  0  0]
2672  #     [-1 -2 -3 -4]
2673  #     [ 0  0  0  0]]
2674
2675  # Select all rows, two segments.
2676  tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
2677  # => [[0 0 0 0]
2678  #     [5 6 7 8]]
2679
2680  # Which is equivalent to:
2681  tf.segment_sum(c, tf.constant([0, 0, 1]))
2682  ```
2683
2684  Args:
2685    data: A `Tensor` with data that will be assembled in the output.
2686    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
2687      `segment_ids`.
2688    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
2689      Values should be sorted and can be repeated.
2690    name: A name for the operation (optional).
2691    num_segments: An optional int32 scalar. Indicates the size of the output
2692      `Tensor`.
2693
2694  Returns:
2695    A `tensor` of the shape as data, except for dimension 0 which
2696    has size `k`, the number of segments specified via `num_segments` or
2697    inferred for the last element in `segments_ids`.
2698  """
2699  if num_segments is not None:
2700    return gen_math_ops.sparse_segment_sum_with_num_segments(
2701        data=data,
2702        indices=indices,
2703        segment_ids=segment_ids,
2704        num_segments=num_segments,
2705        name=name)
2706  else:
2707    return gen_math_ops.sparse_segment_sum(
2708        data=data,
2709        indices=indices,
2710        segment_ids=segment_ids,
2711        name=name)
2712
2713
2714@tf_export("sparse_segment_mean")
2715def sparse_segment_mean(data, indices, segment_ids, name=None,
2716                        num_segments=None):
2717  r"""Computes the mean along sparse segments of a tensor.
2718
2719  Read @{$math_ops#Segmentation$the section on segmentation} for an explanation
2720  of segments.
2721
2722  Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
2723  dimension, selecting a subset of dimension 0, specified by `indices`.
2724  `segment_ids` is allowed to have missing ids, in which case the output will
2725  be zeros at those indices. In those cases `num_segments` is used to determine
2726  the size of the output.
2727
2728  Args:
2729    data: A `Tensor` with data that will be assembled in the output.
2730    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
2731      `segment_ids`.
2732    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
2733      Values should be sorted and can be repeated.
2734    name: A name for the operation (optional).
2735    num_segments: An optional int32 scalar. Indicates the size of the output
2736      `Tensor`.
2737
2738  Returns:
2739    A `tensor` of the shape as data, except for dimension 0 which
2740    has size `k`, the number of segments specified via `num_segments` or
2741    inferred for the last element in `segments_ids`.
2742  """
2743  if num_segments is not None:
2744    return gen_math_ops.sparse_segment_mean_with_num_segments(
2745        data=data,
2746        indices=indices,
2747        segment_ids=segment_ids,
2748        num_segments=num_segments,
2749        name=name)
2750  else:
2751    return gen_math_ops.sparse_segment_mean(
2752        data=data,
2753        indices=indices,
2754        segment_ids=segment_ids,
2755        name=name)
2756
2757
2758@tf_export("sparse_segment_sqrt_n")
2759def sparse_segment_sqrt_n(data, indices, segment_ids, name=None,
2760                          num_segments=None):
2761  r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
2762
2763  `N` is the size of the segment being reduced.
2764
2765  Args:
2766    data: A `Tensor` with data that will be assembled in the output.
2767    indices: A 1-D `Tensor` with indices into `data`. Has same rank as
2768      `segment_ids`.
2769    segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
2770      Values should be sorted and can be repeated.
2771    name: A name for the operation (optional).
2772    num_segments: An optional int32 scalar. Indicates the size of the output
2773      `Tensor`.
2774
2775  Returns:
2776    A `tensor` of the shape as data, except for dimension 0 which
2777    has size `k`, the number of segments specified via `num_segments` or
2778    inferred for the last element in `segments_ids`.
2779  """
2780  if num_segments is not None:
2781    return gen_math_ops.sparse_segment_sqrt_n_with_num_segments(
2782        data=data,
2783        indices=indices,
2784        segment_ids=segment_ids,
2785        num_segments=num_segments,
2786        name=name)
2787  else:
2788    return gen_math_ops.sparse_segment_sqrt_n(
2789        data=data,
2790        indices=indices,
2791        segment_ids=segment_ids,
2792        name=name)
2793
2794
2795@tf_export("tensordot", "linalg.tensordot")
2796def tensordot(a, b, axes, name=None):
2797  r"""Tensor contraction of a and b along specified axes.
2798
2799  Tensordot (also known as tensor contraction) sums the product of elements
2800  from `a` and `b` over the indices specified by `a_axes` and `b_axes`.
2801  The lists `a_axes` and `b_axes` specify those pairs of axes along which to
2802  contract the tensors. The axis `a_axes[i]` of `a` must have the same dimension
2803  as axis `b_axes[i]` of `b` for all `i` in `range(0, len(a_axes))`. The lists
2804  `a_axes` and `b_axes` must have identical length and consist of unique
2805  integers that specify valid axes for each of the tensors.
2806
2807  This operation corresponds to `numpy.tensordot(a, b, axes)`.
2808
2809  Example 1: When `a` and `b` are matrices (order 2), the case `axes = 1`
2810  is equivalent to matrix multiplication.
2811
2812  Example 2: When `a` and `b` are matrices (order 2), the case
2813  `axes = [[1], [0]]` is equivalent to matrix multiplication.
2814
2815  Example 3: Suppose that \\(a_{ijk}\\) and \\(b_{lmn}\\) represent two
2816  tensors of order 3. Then, `contract(a, b, [[0], [2]])` is the order 4 tensor
2817  \\(c_{jklm}\\) whose entry
2818  corresponding to the indices \\((j,k,l,m)\\) is given by:
2819
2820  \\( c_{jklm} = \sum_i a_{ijk} b_{lmi} \\).
2821
2822  In general, `order(c) = order(a) + order(b) - 2*len(axes[0])`.
2823
2824  Args:
2825    a: `Tensor` of type `float32` or `float64`.
2826    b: `Tensor` with the same type as `a`.
2827    axes: Either a scalar `N`, or a list or an `int32` `Tensor` of shape [2, k].
2828     If axes is a scalar, sum over the last N axes of a and the first N axes
2829     of b in order.
2830     If axes is a list or `Tensor` the first and second row contain the set of
2831     unique integers specifying axes along which the contraction is computed,
2832     for `a` and `b`, respectively. The number of axes for `a` and `b` must
2833     be equal.
2834    name: A name for the operation (optional).
2835
2836  Returns:
2837    A `Tensor` with the same type as `a`.
2838
2839  Raises:
2840    ValueError: If the shapes of `a`, `b`, and `axes` are incompatible.
2841    IndexError: If the values in axes exceed the rank of the corresponding
2842      tensor.
2843  """
2844
2845  def _tensordot_reshape(a, axes, flipped=False):
2846    """Helper method to perform transpose and reshape for contraction op.
2847
2848    This method is helpful in reducing `math_ops.tensordot` to `math_ops.matmul`
2849    using `array_ops.transpose` and `array_ops.reshape`. The method takes a
2850    tensor and performs the correct transpose and reshape operation for a given
2851    set of indices. It returns the reshaped tensor as well as a list of indices
2852    necessary to reshape the tensor again after matrix multiplication.
2853
2854    Args:
2855      a: `Tensor`.
2856      axes: List or `int32` `Tensor` of unique indices specifying valid axes of
2857       `a`.
2858      flipped: An optional `bool`. Defaults to `False`. If `True`, the method
2859        assumes that `a` is the second argument in the contraction operation.
2860
2861    Returns:
2862      A tuple `(reshaped_a, free_dims, free_dims_static)` where `reshaped_a` is
2863      the tensor `a` reshaped to allow contraction via `matmul`, `free_dims` is
2864      either a list of integers or an `int32` `Tensor`, depending on whether
2865      the shape of a is fully specified, and free_dims_static is either a list
2866      of integers and None values, or None, representing the inferred
2867      static shape of the free dimensions
2868    """
2869    if a.get_shape().is_fully_defined() and isinstance(axes, (list, tuple)):
2870      shape_a = a.get_shape().as_list()
2871      axes = [i if i >= 0 else i + len(shape_a) for i in axes]
2872      free = [i for i in xrange(len(shape_a)) if i not in axes]
2873      free_dims = [shape_a[i] for i in free]
2874      prod_free = int(np.prod([shape_a[i] for i in free]))
2875      prod_axes = int(np.prod([shape_a[i] for i in axes]))
2876      perm = list(axes) + free if flipped else free + list(axes)
2877      new_shape = [prod_axes, prod_free] if flipped else [prod_free, prod_axes]
2878      reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
2879      return reshaped_a, free_dims, free_dims
2880    else:
2881      if a.get_shape().ndims is not None and isinstance(axes, (list, tuple)):
2882        shape_a = a.get_shape().as_list()
2883        axes = [i if i >= 0 else i + len(shape_a) for i in axes]
2884        free = [i for i in xrange(len(shape_a)) if i not in axes]
2885        free_dims_static = [shape_a[i] for i in free]
2886      else:
2887        free_dims_static = None
2888      shape_a = array_ops.shape(a)
2889      rank_a = array_ops.rank(a)
2890      axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
2891      axes = cast(axes >= 0, dtypes.int32) * axes + cast(
2892          axes < 0, dtypes.int32) * (
2893              axes + rank_a)
2894      free, _ = array_ops.setdiff1d(range(rank_a), axes)
2895      free_dims = array_ops.gather(shape_a, free)
2896      axes_dims = array_ops.gather(shape_a, axes)
2897      prod_free_dims = reduce_prod(free_dims)
2898      prod_axes_dims = reduce_prod(axes_dims)
2899      perm = array_ops.concat([axes_dims, free_dims], 0)
2900      if flipped:
2901        perm = array_ops.concat([axes, free], 0)
2902        new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
2903      else:
2904        perm = array_ops.concat([free, axes], 0)
2905        new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
2906      reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
2907      return reshaped_a, free_dims, free_dims_static
2908
2909  def _tensordot_axes(a, axes):
2910    """Generates two sets of contraction axes for the two tensor arguments."""
2911    a_shape = a.get_shape()
2912    if isinstance(axes, compat.integral_types):
2913      if axes < 0:
2914        raise ValueError("'axes' must be at least 0.")
2915      if a_shape.ndims is not None:
2916        if axes > a_shape.ndims:
2917          raise ValueError("'axes' must not be larger than the number of "
2918                           "dimensions of tensor %s." % a)
2919        return (list(xrange(a_shape.ndims - axes, a_shape.ndims)),
2920                list(xrange(axes)))
2921      else:
2922        rank = array_ops.rank(a)
2923        return (range(rank - axes, rank, dtype=dtypes.int32),
2924                range(axes, dtype=dtypes.int32))
2925    elif isinstance(axes, (list, tuple)):
2926      if len(axes) != 2:
2927        raise ValueError("'axes' must be an integer or have length 2.")
2928      a_axes = axes[0]
2929      b_axes = axes[1]
2930      if isinstance(a_axes, compat.integral_types) and \
2931          isinstance(b_axes, compat.integral_types):
2932        a_axes = [a_axes]
2933        b_axes = [b_axes]
2934      if len(a_axes) != len(b_axes):
2935        raise ValueError(
2936            "Different number of contraction axes 'a' and 'b', %s != %s." %
2937            (len(a_axes), len(b_axes)))
2938      return a_axes, b_axes
2939    else:
2940      axes = ops.convert_to_tensor(axes, name="axes", dtype=dtypes.int32)
2941      return axes[0], axes[1]
2942
2943  with ops.name_scope(name, "Tensordot", [a, b, axes]) as name:
2944    a = ops.convert_to_tensor(a, name="a")
2945    b = ops.convert_to_tensor(b, name="b")
2946    a_axes, b_axes = _tensordot_axes(a, axes)
2947    a_reshape, a_free_dims, a_free_dims_static = _tensordot_reshape(a, a_axes)
2948    b_reshape, b_free_dims, b_free_dims_static = _tensordot_reshape(
2949        b, b_axes, True)
2950    ab_matmul = matmul(a_reshape, b_reshape)
2951    if isinstance(a_free_dims, list) and isinstance(b_free_dims, list):
2952      return array_ops.reshape(ab_matmul, a_free_dims + b_free_dims, name=name)
2953    else:
2954      a_free_dims = ops.convert_to_tensor(a_free_dims, dtype=dtypes.int32)
2955      b_free_dims = ops.convert_to_tensor(b_free_dims, dtype=dtypes.int32)
2956      product = array_ops.reshape(
2957          ab_matmul, array_ops.concat([a_free_dims, b_free_dims], 0), name=name)
2958      if a_free_dims_static is not None and b_free_dims_static is not None:
2959        product.set_shape(a_free_dims_static + b_free_dims_static)
2960      return product
2961
2962
2963# FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow
2964# 1.0 API so we leave these here for backwards compatibility.
2965fft = gen_spectral_ops.fft
2966ifft = gen_spectral_ops.ifft
2967fft2d = gen_spectral_ops.fft2d
2968ifft2d = gen_spectral_ops.ifft2d
2969fft3d = gen_spectral_ops.fft3d
2970ifft3d = gen_spectral_ops.ifft3d
2971