1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Mathematical operations."""
16# pylint: disable=g-direct-tensorflow-import
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import numbers
23import sys
24
25import numpy as np
26import six
27
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import ops
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import bitwise_ops
34from tensorflow.python.ops import clip_ops
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import gen_math_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import nn_ops
39from tensorflow.python.ops import sort_ops
40from tensorflow.python.ops import special_math_ops
41from tensorflow.python.ops.numpy_ops import np_array_ops
42from tensorflow.python.ops.numpy_ops import np_arrays
43from tensorflow.python.ops.numpy_ops import np_dtypes
44from tensorflow.python.ops.numpy_ops import np_export
45from tensorflow.python.ops.numpy_ops import np_utils
46
47
# NumPy-compatible module constants. Each value is passed to
# np_export.np_export_constant together with this module's name and its public
# name (presumably registering it for the public API -- confirm in np_export).
pi = np_export.np_export_constant(__name__, 'pi', np.pi)
e = np_export.np_export_constant(__name__, 'e', np.e)
inf = np_export.np_export_constant(__name__, 'inf', np.inf)
51
52
@np_utils.np_doc_only('dot')
def dot(a, b):
  """Dot product following NumPy's `dot` rank rules.

  If either operand is a scalar (rank 0) the result is their product. If `b`
  has rank 1, the last axis of `a` is contracted with `b`'s only axis;
  otherwise the last axis of `a` is contracted with the second-to-last axis of
  `b`. Operands are dtype-promoted first.
  """

  def _dot_impl(lhs, rhs):

    def _scalar_product():
      return lhs * rhs

    def _tensor_product():
      return np_utils.cond(
          math_ops.equal(array_ops.rank(rhs), 1),
          lambda: math_ops.tensordot(lhs, rhs, axes=[[-1], [-1]]),
          lambda: math_ops.tensordot(lhs, rhs, axes=[[-1], [-2]]))

    has_scalar = np_utils.logical_or(
        math_ops.equal(array_ops.rank(lhs), 0),
        math_ops.equal(array_ops.rank(rhs), 0))
    return np_utils.cond(has_scalar, _scalar_product, _tensor_product)

  return _bin_op(_dot_impl, a, b)
68
69
70# TODO(wangpeng): Make element-wise ops `ufunc`s
def _bin_op(tf_fun, a, b, promote=True):
  """Converts `a` and `b` to arrays and applies the binary `tf_fun`.

  Args:
    tf_fun: callable taking two arrays.
    a: first operand (array_like).
    b: second operand (array_like).
    promote: if True, both operands are promoted to a common dtype with
      `np_array_ops._promote_dtype_binary`; otherwise each is converted
      independently with `np_array_ops.array`.

  Returns:
    Whatever `tf_fun` returns.
  """
  if not promote:
    return tf_fun(np_array_ops.array(a), np_array_ops.array(b))
  lhs, rhs = np_array_ops._promote_dtype_binary(a, b)  # pylint: disable=protected-access
  return tf_fun(lhs, rhs)
78
79
@np_utils.np_doc('add')
def add(x1, x2):
  """Element-wise sum after dtype promotion; bool inputs use logical OR."""

  def _add_impl(lhs, rhs):
    if lhs.dtype != dtypes.bool:
      return math_ops.add(lhs, rhs)
    assert rhs.dtype == dtypes.bool
    return math_ops.logical_or(lhs, rhs)

  return _bin_op(_add_impl, x1, x2)
90
91
@np_utils.np_doc('subtract')
def subtract(x1, x2):
  """Element-wise `x1 - x2` after dtype promotion."""
  return _bin_op(tf_fun=math_ops.subtract, a=x1, b=x2)
95
96
@np_utils.np_doc('multiply')
def multiply(x1, x2):
  """Element-wise product after dtype promotion; bool inputs use logical AND."""

  def _multiply_impl(lhs, rhs):
    if lhs.dtype != dtypes.bool:
      return math_ops.multiply(lhs, rhs)
    assert rhs.dtype == dtypes.bool
    return math_ops.logical_and(lhs, rhs)

  return _bin_op(_multiply_impl, x1, x2)
107
108
@np_utils.np_doc('true_divide')
def true_divide(x1, x2):
  """Element-wise true division after dtype promotion.

  Bool operands are cast to `np_dtypes.default_float_type()` first. When
  `np_dtypes.is_allow_float64()` is False, matching int32/int64 operands are
  cast to float32, because `math_ops.truediv` would otherwise produce float64
  for them.
  """
  wide_int_types = (dtypes.int32, dtypes.int64)

  def _true_divide_impl(numer, denom):
    if numer.dtype == dtypes.bool:
      assert denom.dtype == dtypes.bool
      default_float = np_dtypes.default_float_type()
      numer, denom = (math_ops.cast(numer, default_float),
                      math_ops.cast(denom, default_float))
    if (not np_dtypes.is_allow_float64() and numer.dtype == denom.dtype and
        numer.dtype in wide_int_types):
      numer = math_ops.cast(numer, dtype=dtypes.float32)
      denom = math_ops.cast(denom, dtype=dtypes.float32)
    return math_ops.truediv(numer, denom)

  return _bin_op(_true_divide_impl, x1, x2)
131
132
@np_utils.np_doc('divide')
def divide(x1, x2):
  """Alias of `true_divide`."""
  quotient = true_divide(x1, x2)
  return quotient
136
137
@np_utils.np_doc('floor_divide')
def floor_divide(x1, x2):
  """Element-wise floor division; bool operands are computed as int8."""

  def _floordiv_impl(lhs, rhs):
    if lhs.dtype == dtypes.bool:
      assert rhs.dtype == dtypes.bool
      lhs, rhs = [math_ops.cast(t, dtypes.int8) for t in (lhs, rhs)]
    return math_ops.floordiv(lhs, rhs)

  return _bin_op(_floordiv_impl, x1, x2)
149
150
@np_utils.np_doc('mod')
def mod(x1, x2):
  """Element-wise modulus; bool operands are computed as int8."""

  def _mod_impl(lhs, rhs):
    if lhs.dtype == dtypes.bool:
      assert rhs.dtype == dtypes.bool
      lhs, rhs = [math_ops.cast(t, dtypes.int8) for t in (lhs, rhs)]
    return math_ops.mod(lhs, rhs)

  return _bin_op(_mod_impl, x1, x2)
162
163
@np_utils.np_doc('remainder')
def remainder(x1, x2):
  """Alias of `mod`."""
  rem = mod(x1, x2)
  return rem


@np_utils.np_doc('divmod')
def divmod(x1, x2):  # pylint: disable=redefined-builtin
  """Returns the pair `(floor_divide(x1, x2), mod(x1, x2))`."""
  quotient = floor_divide(x1, x2)
  rem = mod(x1, x2)
  return quotient, rem
172
173
@np_utils.np_doc('maximum')
def maximum(x1, x2):
  """Element-wise maximum after dtype promotion; bool inputs use logical OR.

  Fast path: when `x2` is a real (non-bool) scalar equal to 0 and `x1` is a
  non-bool ndarray, `nn_ops.relu` is used instead. The fast path is skipped
  when `x2` is a non-integral scalar (e.g. `0.0`) and `x1` is not floating
  point. NumPy promotes that mix to a floating dtype, whereas relu would keep
  `x1`'s integer dtype.
  """
  use_relu = (
      isinstance(x2, numbers.Real) and not isinstance(x2, bool) and
      x2 == 0 and isinstance(x1, np_arrays.ndarray) and
      x1.dtype != dtypes.bool and
      # Only take the shortcut when no float promotion would occur.
      (isinstance(x2, numbers.Integral) or x1.dtype.is_floating))
  if use_relu:
    return nn_ops.relu(np_array_ops.asarray(x1))

  def max_or_or(x1, x2):
    if x1.dtype == dtypes.bool:
      assert x2.dtype == dtypes.bool
      return math_ops.logical_or(x1, x2)
    return math_ops.maximum(x1, x2)

  return _bin_op(max_or_or, x1, x2)
190
191
@np_utils.np_doc('minimum')
def minimum(x1, x2):
  """Element-wise minimum after dtype promotion; bool inputs use logical AND."""

  def _minimum_impl(lhs, rhs):
    if lhs.dtype != dtypes.bool:
      return math_ops.minimum(lhs, rhs)
    assert rhs.dtype == dtypes.bool
    return math_ops.logical_and(lhs, rhs)

  return _bin_op(_minimum_impl, x1, x2)
202
203
@np_utils.np_doc('clip')
def clip(a, a_min, a_max):
  """Clips `a` into `[a_min, a_max]`.

  Either bound may be None (meaning unbounded on that side), but not both.

  Raises:
    ValueError: if both `a_min` and `a_max` are None.
  """
  if a_min is None:
    if a_max is None:
      raise ValueError('Not more than one of `a_min` and `a_max` may be `None`.')
    return minimum(a, a_max)
  if a_max is None:
    return maximum(a, a_min)
  promoted = np_array_ops._promote_dtype(a, a_min, a_max)  # pylint: disable=protected-access
  return clip_ops.clip_by_value(*np_utils.tf_broadcast(*promoted))
215
216
@np_utils.np_doc('matmul')
def matmul(x1, x2):
  """Matrix product following NumPy's `matmul` rank rules.

  There are four cases:
  - If both operands have statically known rank 2, `MatMul` is called
    directly.
  - If `x2` has rank 1, it is contracted over one axis with `x1`.
  - If `x1` has rank 1, it is contracted with the second-to-last axis of
    `x2`.
  - Otherwise, `math_ops.matmul` handles the (batched) case.

  An `InvalidArgumentError` raised while computing is re-raised as a
  `ValueError` with the original traceback.
  """

  def _matmul_impl(lhs, rhs):
    try:
      if lhs._rank() == 2 and rhs._rank() == 2:  # pylint: disable=protected-access
        # Fast path for known ranks.
        return gen_math_ops.mat_mul(lhs, rhs)

      def _dispatch_on_lhs_rank():
        return np_utils.cond(
            math_ops.equal(np_utils.tf_rank(lhs), 1),
            lambda: math_ops.tensordot(lhs, rhs, axes=[[0], [-2]]),
            lambda: math_ops.matmul(lhs, rhs))

      return np_utils.cond(
          math_ops.equal(np_utils.tf_rank(rhs), 1),
          lambda: math_ops.tensordot(lhs, rhs, axes=1),
          _dispatch_on_lhs_rank)
    except errors.InvalidArgumentError as err:
      six.reraise(ValueError, ValueError(str(err)), sys.exc_info()[2])

  return _bin_op(_matmul_impl, x1, x2)
236
237
# Exposed as `ndarray._matmul` so it can be called from Tensor.__matmul__.
# NumPy's matmul also handles batched matmul, so simply adding dtype promotion
# to TF's existing __matmul__ implementation was not sufficient.
np_arrays.ndarray._matmul = matmul  # pylint: disable=protected-access
242
243
@np_utils.np_doc('tensordot')
def tensordot(a, b, axes=2):
  """Contracts `a` and `b` over `axes` (see `math_ops.tensordot`)."""

  def _tensordot_impl(lhs, rhs):
    return math_ops.tensordot(lhs, rhs, axes=axes)

  return _bin_op(_tensordot_impl, a, b)
247
248
@np_utils.np_doc_only('inner')
def inner(a, b):
  """Inner product over the last axes of `a` and `b`.

  If either operand is a scalar, the result is their product.
  """

  def _inner_impl(lhs, rhs):

    def _scalar_product():
      return lhs * rhs

    def _last_axis_product():
      return math_ops.tensordot(lhs, rhs, axes=[[-1], [-1]])

    has_scalar = np_utils.logical_or(
        math_ops.equal(array_ops.rank(lhs), 0),
        math_ops.equal(array_ops.rank(rhs), 0))
    return np_utils.cond(has_scalar, _scalar_product, _last_axis_product)

  return _bin_op(_inner_impl, a, b)
260
261
@np_utils.np_doc('cross')
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):  # pylint: disable=missing-docstring
  """Cross product of (arrays of) 2- or 3-component vectors.

  Args:
    a: first operand (array_like).
    b: second operand (array_like).
    axisa: axis of `a` holding the vector components.
    axisb: axis of `b` holding the vector components.
    axisc: axis of the result holding the vector components.
    axis: if not None, overrides `axisa`, `axisb` and `axisc`.

  Returns:
    The cross product computed with `math_ops.cross`. The inputs are first
    dtype-promoted, their vector axes are moved last, 2-vectors are
    zero-padded to 3-vectors, and the inputs are broadcast together. When both
    inputs are 2-vectors, only the third component (`c[..., 2]`) is returned.
  """

  def f(a, b):  # pylint: disable=missing-docstring
    # We can't assign to captured variable `axisa`, so make a new variable
    if axis is None:
      axis_a = axisa
      axis_b = axisb
      axis_c = axisc
    else:
      axis_a = axis
      axis_b = axis
      axis_c = axis
    # Normalize negative input axes against the (possibly dynamic) rank.
    if axis_a < 0:
      axis_a = np_utils.add(axis_a, array_ops.rank(a))
    if axis_b < 0:
      axis_b = np_utils.add(axis_b, array_ops.rank(b))

    # Transposes `a` so that `axis` becomes its last dimension; a no-op when
    # `axis` already is the last dimension.
    def maybe_move_axis_to_last(a, axis):

      def move_axis_to_last(a, axis):
        return array_ops.transpose(
            a,
            array_ops.concat([
                math_ops.range(axis),
                math_ops.range(axis + 1, array_ops.rank(a)), [axis]
            ],
                             axis=0))

      return np_utils.cond(axis == np_utils.subtract(array_ops.rank(a), 1),
                           lambda: a, lambda: move_axis_to_last(a, axis))

    a = maybe_move_axis_to_last(a, axis_a)
    b = maybe_move_axis_to_last(b, axis_b)
    # Number of vector components of each input (its last dimension).
    a_dim = np_utils.getitem(array_ops.shape(a), -1)
    b_dim = np_utils.getitem(array_ops.shape(b), -1)

    # Appends a zero component to the last dimension when it has size 2, so
    # 2-vectors become 3-vectors for `math_ops.cross`.
    def maybe_pad_0(a, size_of_last_dim):

      def pad_0(a):
        return array_ops.pad(
            a,
            array_ops.concat([
                array_ops.zeros([array_ops.rank(a) - 1, 2], dtypes.int32),
                constant_op.constant([[0, 1]], dtypes.int32)
            ],
                             axis=0))

      return np_utils.cond(
          math_ops.equal(size_of_last_dim, 2), lambda: pad_0(a), lambda: a)

    a = maybe_pad_0(a, a_dim)
    b = maybe_pad_0(b, b_dim)
    c = math_ops.cross(*np_utils.tf_broadcast(a, b))
    if axis_c < 0:
      axis_c = np_utils.add(axis_c, array_ops.rank(c))

    # Inverse of move_axis_to_last: moves the last dimension of `a` to `axis`.
    def move_last_to_axis(a, axis):
      r = array_ops.rank(a)
      return array_ops.transpose(
          a,
          array_ops.concat(
              [math_ops.range(axis), [r - 1],
               math_ops.range(axis, r - 1)],
              axis=0))

    # Two 2-vectors: keep only the z-component. Otherwise place the vector
    # (last) axis of the result at `axis_c`.
    c = np_utils.cond(
        (a_dim == 2) & (b_dim == 2),
        lambda: c[..., 2],
        lambda: np_utils.cond(  # pylint: disable=g-long-lambda
            axis_c == np_utils.subtract(array_ops.rank(c), 1), lambda: c,
            lambda: move_last_to_axis(c, axis_c)))
    return c

  return _bin_op(f, a, b)
337
338
@np_utils.np_doc_only('vdot')
def vdot(a, b):
  """Dot product of the flattened operands, conjugating `a` if complex.

  The operands are promoted to a common dtype first. `a` is conjugated only
  when that dtype is complex128 or complex64.
  """
  a, b = np_array_ops._promote_dtype(a, b)  # pylint: disable=protected-access
  flat_a, flat_b = (np_array_ops.reshape(t, [-1]) for t in (a, b))
  if flat_a.dtype in (np_dtypes.complex128, np_dtypes.complex64):
    flat_a = conj(flat_a)
  return dot(flat_a, flat_b)
347
348
@np_utils.np_doc('power')
def power(x1, x2):
  """Element-wise `x1 ** x2` after dtype promotion."""
  return _bin_op(tf_fun=math_ops.pow, a=x1, b=x2)
352
353
@np_utils.np_doc('float_power')
def float_power(x1, x2):
  """Element-wise `x1 ** x2`, always computed in an inexact dtype.

  Unlike `power`, operands whose promoted dtype is bool or integer are cast to
  `np_dtypes.default_float_type()` before exponentiation. This matches NumPy's
  `float_power`, which never computes in an integer dtype, so for example
  negative integer exponents are valid. Inexact (floating/complex) operands
  keep their dtype.
  """

  def f(x1, x2):
    # After `_bin_op` promotion both operands share a dtype.
    if not np.issubdtype(x1.dtype.as_numpy_dtype, np.inexact):
      float_ = np_dtypes.default_float_type()
      x1 = math_ops.cast(x1, float_)
      x2 = math_ops.cast(x2, float_)
    return math_ops.pow(x1, x2)

  return _bin_op(f, x1, x2)
357
358
@np_utils.np_doc('arctan2')
def arctan2(x1, x2):
  """Element-wise quadrant-aware arc tangent of `x1 / x2`.

  `math_ops.atan2` is defined for floating-point dtypes. As in NumPy, bool and
  integer operands are therefore cast to `np_dtypes.default_float_type()`
  after dtype promotion; inexact operands keep their dtype.
  """

  def f(x1, x2):
    # After `_bin_op` promotion both operands share a dtype.
    if not np.issubdtype(x1.dtype.as_numpy_dtype, np.inexact):
      float_ = np_dtypes.default_float_type()
      x1 = math_ops.cast(x1, float_)
      x2 = math_ops.cast(x2, float_)
    return math_ops.atan2(x1, x2)

  return _bin_op(f, x1, x2)
362
363
@np_utils.np_doc('nextafter')
def nextafter(x1, x2):
  """Next representable value after `x1` towards `x2`, element-wise.

  `math_ops.nextafter` is defined for floating-point dtypes. As in NumPy, bool
  and integer operands are therefore cast to `np_dtypes.default_float_type()`
  after dtype promotion; inexact operands keep their dtype.
  """

  def f(x1, x2):
    # After `_bin_op` promotion both operands share a dtype.
    if not np.issubdtype(x1.dtype.as_numpy_dtype, np.inexact):
      float_ = np_dtypes.default_float_type()
      x1 = math_ops.cast(x1, float_)
      x2 = math_ops.cast(x2, float_)
    return math_ops.nextafter(x1, x2)

  return _bin_op(f, x1, x2)
367
368
@np_utils.np_doc('heaviside')
def heaviside(x1, x2):
  """Heaviside step function.

  The result is 0 where `x1 < 0`, 1 where `x1 > 0`, and `x2` everywhere else.
  A result whose dtype is not inexact is cast to
  `np_dtypes.default_float_type()`.
  """

  def _step(x, at_zero):
    zero = constant_op.constant(0, dtype=at_zero.dtype)
    one = constant_op.constant(1, dtype=at_zero.dtype)
    non_negative_part = array_ops.where_v2(x > 0, one, at_zero)
    return array_ops.where_v2(x < 0, zero, non_negative_part)

  result = _bin_op(_step, x1, x2)
  if np.issubdtype(result.dtype.as_numpy_dtype, np.inexact):
    return result
  return result.astype(np_dtypes.default_float_type())
381
382
@np_utils.np_doc('hypot')
def hypot(x1, x2):
  """Element-wise `sqrt(x1**2 + x2**2)`."""
  sum_of_squares = square(x1) + square(x2)
  return sqrt(sum_of_squares)
386
387
@np_utils.np_doc('kron')
def kron(a, b):  # pylint: disable=missing-function-docstring
  """Kronecker product of `a` and `b`.

  The computation runs in four steps:
  1. Both inputs are promoted to a common dtype.
  2. The lower-rank input is reshaped to the other's rank using
     `np_array_ops._pad_left_to` (presumably padding with leading 1s --
     confirm).
  3. `a` is reshaped to interleave its dimensions with 1s
     ([d0, 1, d1, 1, ...]) and `b` likewise with the 1s first
     ([1, e0, 1, e1, ...]).
  4. Their broadcast product is reshaped to `shape(a) * shape(b)`.
  """
  # pylint: disable=protected-access,g-complex-comprehension
  a, b = np_array_ops._promote_dtype(a, b)
  t_a = np_utils.cond(
      a.ndim < b.ndim,
      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
          a, np_array_ops._pad_left_to(b.ndim, a.shape)),
      lambda: a)
  t_b = np_utils.cond(
      b.ndim < a.ndim,
      lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
          b, np_array_ops._pad_left_to(a.ndim, b.shape)),
      lambda: b)

  # Interleaves `shape` with ones: [s0, 1, s1, 1, ...], or [1, s0, 1, s1, ...]
  # when `prepend` is True.
  def _make_shape(shape, prepend):
    ones = array_ops.ones_like(shape)
    if prepend:
      shapes = [ones, shape]
    else:
      shapes = [shape, ones]
    return array_ops.reshape(array_ops.stack(shapes, axis=1), [-1])

  a_shape = array_ops.shape(t_a)
  b_shape = array_ops.shape(t_b)
  a_reshaped = np_array_ops.reshape(t_a, _make_shape(a_shape, False))
  b_reshaped = np_array_ops.reshape(t_b, _make_shape(b_shape, True))
  # Broadcasting yields shape [d0, e0, d1, e1, ...]; merge each pair of dims.
  out_shape = a_shape * b_shape
  return np_array_ops.reshape(a_reshaped * b_reshaped, out_shape)
417
418
@np_utils.np_doc('outer')
def outer(a, b):
  """Outer product of the flattened `a` and `b`."""

  def _outer_impl(lhs, rhs):
    as_column = array_ops.reshape(lhs, [-1, 1])
    as_row = array_ops.reshape(rhs, [-1])
    return as_column * as_row

  return _bin_op(_outer_impl, a, b)
426
427
# This could alternatively be implemented with tf.reduce_logsumexp.
@np_utils.np_doc('logaddexp')
def logaddexp(x1, x2):
  """Computes `log(exp(x1) + exp(x2))` stably.

  Uses `max(x1, x2) + log1p(exp(-|x1 - x2|))`. Where `x1 - x2` is NaN (a NaN
  input, or infinities of the same sign), `x1 + x2` is returned instead.
  """
  largest = maximum(x1, x2)
  gap = x1 - x2
  stable = largest + log1p(exp(-abs(gap)))
  return np_array_ops.where(isnan(gap), x1 + x2, stable)
438
@np_utils.np_doc('logaddexp2')
def logaddexp2(x1, x2):
  """Computes `log2(2**x1 + 2**x2)` stably.

  Uses `max(x1, x2) + log1p(2**-|x1 - x2|) / ln(2)`. Where `x1 - x2` is NaN (a
  NaN input, or infinities of the same sign), `x1 + x2` is returned instead.
  """
  largest = maximum(x1, x2)
  gap = x1 - x2
  stable = largest + log1p(exp2(-abs(gap))) / np.log(2)
  return np_array_ops.where(isnan(gap), x1 + x2, stable)
447
448
@np_utils.np_doc('polyval')
def polyval(p, x):
  """Evaluates the polynomial with coefficients `p` at `x`, as NumPy's polyval.

  A scalar `p` is treated as a single coefficient. For a 0-order polynomial
  the result is broadcast to `x`'s shape, as NumPy requires.
  """

  def _polyval_impl(coeffs, points):
    if coeffs.shape.rank == 0:
      coeffs = array_ops.reshape(coeffs, [1])
    # TODO(wangpeng): Make tf version take a tensor for p instead of a list.
    coeff_list = array_ops.unstack(coeffs)
    values = math_ops.polyval(coeff_list, points)
    if len(coeff_list) != 1:
      return values
    return array_ops.broadcast_to(values, points.shape)

  return _bin_op(_polyval_impl, p, x)
465
466
@np_utils.np_doc('isclose')
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
  """Element-wise test whether `a` and `b` are equal within a tolerance.

  For inexact (floating/complex) dtypes:
  - Elements that are finite in both inputs are close when
    `|a - b| <= atol + rtol * |b|`.
  - Other elements are close only when exactly equal, matching NumPy. So
    `inf` is close to `inf` but not to `-inf`. The tolerance formula alone
    gets both wrong: `nan <= ...` is False, and `inf <= inf` is True.
  - With `equal_nan=True`, NaNs at the same position are also close.

  Other dtypes use exact equality.
  """

  def _is_finite(x):
    # tf.math.is_finite is defined for real floating dtypes; for complex
    # inputs check the real and imaginary parts separately.
    if x.dtype.is_complex:
      return math_ops.logical_and(
          math_ops.is_finite(math_ops.real(x)),
          math_ops.is_finite(math_ops.imag(x)))
    return math_ops.is_finite(x)

  def f(a, b):
    dtype = a.dtype
    if not np.issubdtype(dtype.as_numpy_dtype, np.inexact):
      return a == b
    rtol_ = ops.convert_to_tensor(rtol, dtype.real_dtype)
    atol_ = ops.convert_to_tensor(atol, dtype.real_dtype)
    within_tol = math_ops.abs(a - b) <= atol_ + rtol_ * math_ops.abs(b)
    both_finite = math_ops.logical_and(_is_finite(a), _is_finite(b))
    result = array_ops.where_v2(both_finite, within_tol, math_ops.equal(a, b))
    if equal_nan:
      result = result | (math_ops.is_nan(a) & math_ops.is_nan(b))
    return result

  return _bin_op(f, a, b)
483
484
@np_utils.np_doc('allclose')
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
  """True if every element pair of `a` and `b` is close (see `isclose`)."""
  elementwise = isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
  return np_array_ops.all(elementwise)
489
490
def _tf_gcd(x1, x2):  # pylint: disable=missing-function-docstring
  """Element-wise greatest common divisor of integer tensors `x1` and `x2`.

  Both inputs are broadcast to a common shape. The Euclidean algorithm then
  runs on their absolute values in a `while_loop` until every remainder is 0.

  Raises:
    ValueError: if either input does not have an integer dtype.
  """

  # Loop while any element still has a nonzero second component.
  def _gcd_cond_fn(_, x2):
    return math_ops.reduce_any(x2 != 0)

  def _gcd_body_fn(x1, x2):
    # math_ops.mod will raise an error when any element of x2 is 0. To avoid
    # that, we change those zeros to ones. Their values don't matter because
    # they won't be used.
    x2_safe = array_ops.where_v2(x2 != 0, x2, constant_op.constant(1, x2.dtype))
    # One Euclid step where x2 != 0: (x1, x2) -> (x2, x1 mod x2). Elements
    # whose x2 is already 0 are finished and stay (x1, 0).
    x1, x2 = (array_ops.where_v2(x2 != 0, x2, x1),
              array_ops.where_v2(x2 != 0, math_ops.mod(x1, x2_safe),
                                 constant_op.constant(0, x2.dtype)))
    # Reorder each pair so the first component is the larger one.
    return (array_ops.where_v2(x1 < x2, x2,
                               x1), array_ops.where_v2(x1 < x2, x1, x2))

  if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
      not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
    raise ValueError('Arguments to gcd must be integers.')
  shape = array_ops.broadcast_dynamic_shape(
      array_ops.shape(x1), array_ops.shape(x2))
  x1 = array_ops.broadcast_to(x1, shape)
  x2 = array_ops.broadcast_to(x2, shape)
  value, _ = control_flow_ops.while_loop(_gcd_cond_fn, _gcd_body_fn,
                                         (math_ops.abs(x1), math_ops.abs(x2)))
  return value
517
518
# Note that np.gcd may not be present in some supported versions of numpy.
@np_utils.np_doc('gcd')
def gcd(x1, x2):
  """Element-wise greatest common divisor of integer arrays."""
  return _bin_op(tf_fun=_tf_gcd, a=x1, b=x2)
523
524
# Note that np.lcm may not be present in some supported versions of numpy.
@np_utils.np_doc('lcm')
def lcm(x1, x2):
  """Element-wise least common multiple of integer arrays.

  Computed as `|(x1 // gcd(x1, x2)) * x2|`. Dividing by the gcd before
  multiplying keeps the intermediate no larger than the result. The previous
  form `|x1 * x2| // gcd` could overflow the integer dtype even when the lcm
  itself fits. Where the gcd is 0 (both inputs 0), the result is 0.

  Raises:
    ValueError: if either input does not have an integer dtype.
  """

  def f(x1, x2):
    d = _tf_gcd(x1, x2)
    # gcd is 0 only where both inputs are 0; use 1 there so the division is
    # always defined (those positions are replaced by 0 below).
    d_safe = array_ops.where_v2(
        math_ops.equal(d, 0), constant_op.constant(1, d.dtype), d)
    # d divides x1 exactly, so this floor division is exact.
    return array_ops.where_v2(
        math_ops.equal(d, 0), constant_op.constant(0, d.dtype),
        math_ops.abs((x1 // d_safe) * x2))

  return _bin_op(f, x1, x2)
539
540
def _bitwise_binary_op(tf_fn, x1, x2):
  """Applies the bitwise binary `tf_fn` after dtype promotion.

  Bool operands are cast to int8, combined with `tf_fn`, and the result is
  cast back to bool.
  """

  def _bitwise_impl(lhs, rhs):
    if lhs.dtype != dtypes.bool:
      return tf_fn(lhs, rhs)
    assert rhs.dtype == dtypes.bool
    combined = tf_fn(
        math_ops.cast(lhs, dtypes.int8), math_ops.cast(rhs, dtypes.int8))
    return math_ops.cast(combined, dtypes.bool)

  return _bin_op(_bitwise_impl, x1, x2)
555
556
@np_utils.np_doc('bitwise_and')
def bitwise_and(x1, x2):
  """Element-wise bitwise AND (logical AND for bools)."""
  return _bitwise_binary_op(tf_fn=bitwise_ops.bitwise_and, x1=x1, x2=x2)


@np_utils.np_doc('bitwise_or')
def bitwise_or(x1, x2):
  """Element-wise bitwise OR (logical OR for bools)."""
  return _bitwise_binary_op(tf_fn=bitwise_ops.bitwise_or, x1=x1, x2=x2)


@np_utils.np_doc('bitwise_xor')
def bitwise_xor(x1, x2):
  """Element-wise bitwise XOR (logical XOR for bools)."""
  return _bitwise_binary_op(tf_fn=bitwise_ops.bitwise_xor, x1=x1, x2=x2)
570
571
@np_utils.np_doc('bitwise_not', link=np_utils.AliasOf('invert'))
def bitwise_not(x):
  """Element-wise bitwise NOT; for bool input this is logical NOT."""

  def _invert(t):
    if t.dtype != dtypes.bool:
      return bitwise_ops.invert(t)
    return math_ops.logical_not(t)

  return _scalar(_invert, x)
581
582
def _scalar(tf_fn, x, promote_to_float=False):
  """Converts `x` to an ndarray and applies the unary `tf_fn` to it.

  Args:
    tf_fn: function that takes a single Tensor argument.
    x: array_like. Could be an ndarray, a Tensor or any object accepted by
      `np_array_ops.asarray`.
    promote_to_float: if True and `x`'s dtype is not inexact (i.e. it is bool
      or integer), `x` is first cast to `np_dtypes.default_float_type()`.
      Inexact (floating or complex) inputs keep their dtype.

  Returns:
    Whatever `tf_fn` returns for the (possibly cast) array. The output dtype
    and shape are determined by `tf_fn` and are not constrained here (e.g.
    `diff` changes the shape, and predicates such as `isnan` return bools).
  """
  x = np_array_ops.asarray(x)
  if promote_to_float and not np.issubdtype(x.dtype.as_numpy_dtype, np.inexact):
    x = x.astype(np_dtypes.default_float_type())
  return tf_fn(x)
602
603
@np_utils.np_doc('log')
def log(x):
  """Element-wise natural logarithm; non-inexact input is cast to float."""
  return _scalar(math_ops.log, x, promote_to_float=True)


@np_utils.np_doc('exp')
def exp(x):
  """Element-wise exponential; non-inexact input is cast to float."""
  return _scalar(math_ops.exp, x, promote_to_float=True)


@np_utils.np_doc('sqrt')
def sqrt(x):
  """Element-wise square root; non-inexact input is cast to float."""
  return _scalar(math_ops.sqrt, x, promote_to_float=True)
617
618
@np_utils.np_doc('abs', link=np_utils.AliasOf('absolute'))
def abs(x):  # pylint: disable=redefined-builtin
  """Element-wise absolute value."""
  return _scalar(tf_fn=math_ops.abs, x=x)


@np_utils.np_doc('absolute')
def absolute(x):
  """Same as `abs`."""
  return _scalar(tf_fn=math_ops.abs, x=x)
627
628
@np_utils.np_doc('fabs')
def fabs(x):
  """Element-wise absolute value, returned in a floating-point dtype.

  Unlike `abs`, bool and integer inputs are cast to
  `np_dtypes.default_float_type()` first. This matches NumPy's `fabs`, which
  always returns floats.
  """
  return _scalar(math_ops.abs, x, True)
632
633
@np_utils.np_doc('ceil')
def ceil(x):
  """Element-wise ceiling; non-inexact input is cast to float."""
  return _scalar(math_ops.ceil, x, promote_to_float=True)


@np_utils.np_doc('floor')
def floor(x):
  """Element-wise floor; non-inexact input is cast to float."""
  return _scalar(math_ops.floor, x, promote_to_float=True)


@np_utils.np_doc('conj')
def conj(x):
  """Element-wise complex conjugate."""
  return _scalar(tf_fn=math_ops.conj, x=x)


@np_utils.np_doc('negative')
def negative(x):
  """Element-wise negation."""
  return _scalar(tf_fn=math_ops.negative, x=x)


@np_utils.np_doc('reciprocal')
def reciprocal(x):
  """Element-wise reciprocal `1 / x` in the input's dtype."""
  return _scalar(tf_fn=math_ops.reciprocal, x=x)
657
658
@np_utils.np_doc('signbit')
def signbit(x):
  """Element-wise `x < 0`; bool inputs give an all-False result."""

  def _is_negative(t):
    if t.dtype != dtypes.bool:
      return t < 0
    return array_ops.fill(array_ops.shape(t), False)

  return _scalar(_is_negative, x)
668
669
@np_utils.np_doc('sin')
def sin(x):
  """Element-wise sine; non-inexact input is cast to float."""
  return _scalar(math_ops.sin, x, promote_to_float=True)


@np_utils.np_doc('cos')
def cos(x):
  """Element-wise cosine; non-inexact input is cast to float."""
  return _scalar(math_ops.cos, x, promote_to_float=True)


@np_utils.np_doc('tan')
def tan(x):
  """Element-wise tangent; non-inexact input is cast to float."""
  return _scalar(math_ops.tan, x, promote_to_float=True)


@np_utils.np_doc('sinh')
def sinh(x):
  """Element-wise hyperbolic sine; non-inexact input is cast to float."""
  return _scalar(math_ops.sinh, x, promote_to_float=True)


@np_utils.np_doc('cosh')
def cosh(x):
  """Element-wise hyperbolic cosine; non-inexact input is cast to float."""
  return _scalar(math_ops.cosh, x, promote_to_float=True)


@np_utils.np_doc('tanh')
def tanh(x):
  """Element-wise hyperbolic tangent; non-inexact input is cast to float."""
  return _scalar(math_ops.tanh, x, promote_to_float=True)


@np_utils.np_doc('arcsin')
def arcsin(x):
  """Element-wise inverse sine; non-inexact input is cast to float."""
  return _scalar(math_ops.asin, x, promote_to_float=True)


@np_utils.np_doc('arccos')
def arccos(x):
  """Element-wise inverse cosine; non-inexact input is cast to float."""
  return _scalar(math_ops.acos, x, promote_to_float=True)


@np_utils.np_doc('arctan')
def arctan(x):
  """Element-wise inverse tangent; non-inexact input is cast to float."""
  return _scalar(math_ops.atan, x, promote_to_float=True)


@np_utils.np_doc('arcsinh')
def arcsinh(x):
  """Element-wise inverse hyperbolic sine; non-inexact input cast to float."""
  return _scalar(math_ops.asinh, x, promote_to_float=True)


@np_utils.np_doc('arccosh')
def arccosh(x):
  """Element-wise inverse hyperbolic cosine; non-inexact input cast to float."""
  return _scalar(math_ops.acosh, x, promote_to_float=True)


@np_utils.np_doc('arctanh')
def arctanh(x):
  """Element-wise inverse hyperbolic tangent; non-inexact input cast to float."""
  return _scalar(math_ops.atanh, x, promote_to_float=True)
728
729
@np_utils.np_doc('deg2rad')
def deg2rad(x):
  """Converts angles from degrees to radians (input promoted to float)."""
  radians_per_degree = np.pi / 180.0
  return _scalar(lambda t: t * radians_per_degree, x, True)
737
738
@np_utils.np_doc('rad2deg')
def rad2deg(x):
  """Converts angles from radians to degrees.

  Like `deg2rad`, the input goes through `_scalar` with float promotion. This
  accepts array_likes such as Python lists (a raw `list * float` raises
  TypeError) and integer arrays, which now yield a floating-point result.
  """
  degrees_per_radian = 180.0 / np.pi
  return _scalar(lambda t: t * degrees_per_radian, x, True)
742
743
# Real floating-point TF dtypes; `angle` special-cases inputs of these dtypes.
_tf_float_types = [dtypes.bfloat16, dtypes.float16, dtypes.float32,
                   dtypes.float64]
747
748
@np_utils.np_doc('angle')
def angle(z, deg=False):
  """Returns the angle of the complex argument `z`, in radians by default.

  Non-inexact inputs are first promoted to `np_dtypes.default_float_type()`.
  Real floating-point inputs take a separate path (workaround for
  b/147515503): the angle is pi for negative values and 0 otherwise. It is
  produced in the input's own dtype; the Python scalars used previously made
  the result float32 regardless of input dtype. If `deg` is True, the result
  is converted to degrees.
  """

  def f(x):
    if x.dtype in _tf_float_types:
      # Workaround for b/147515503
      return array_ops.where_v2(
          x < 0, constant_op.constant(np.pi, dtype=x.dtype),
          constant_op.constant(0, dtype=x.dtype))
    else:
      return math_ops.angle(x)

  y = _scalar(f, z, True)
  if deg:
    y = rad2deg(y)
  return y
763
764
@np_utils.np_doc('cbrt')
def cbrt(x):
  """Element-wise real cube root (input promoted to float)."""

  def _cube_root(t):
    # __pow__ can't handle a negative base, so take the root of |t| and
    # restore the sign afterwards.
    magnitude_root = math_ops.abs(t)**(1.0 / 3)
    return array_ops.where_v2(t < 0, -magnitude_root, magnitude_root)

  return _scalar(_cube_root, x, True)
774
775
@np_utils.np_doc('conjugate', link=np_utils.AliasOf('conj'))
def conjugate(x):
  """Element-wise complex conjugate (alias of `conj`)."""
  return _scalar(tf_fn=math_ops.conj, x=x)
779
780
@np_utils.np_doc('exp2')
def exp2(x):
  """Element-wise `2**x`; non-inexact input is cast to float first."""

  def _two_to_the(t):
    return 2**t

  return _scalar(_two_to_the, x, promote_to_float=True)


@np_utils.np_doc('expm1')
def expm1(x):
  """Element-wise `exp(x) - 1`; non-inexact input is cast to float first."""
  return _scalar(math_ops.expm1, x, promote_to_float=True)
793
794
@np_utils.np_doc('fix')
def fix(x):
  """Rounds element-wise toward zero (input promoted to float)."""

  def _round_toward_zero(t):
    rounded_up = math_ops.ceil(t)
    rounded_down = math_ops.floor(t)
    return array_ops.where_v2(t < 0, rounded_up, rounded_down)

  return _scalar(_round_toward_zero, x, True)
802
803
@np_utils.np_doc('iscomplex')
def iscomplex(x):
  """Element-wise test for a nonzero imaginary part."""
  imag_part = np_array_ops.imag(x)
  return imag_part != 0


@np_utils.np_doc('isreal')
def isreal(x):
  """Element-wise test for a zero imaginary part."""
  imag_part = np_array_ops.imag(x)
  return imag_part == 0


@np_utils.np_doc('iscomplexobj')
def iscomplexobj(x):
  """Returns True if `x` (as an array) has a complex dtype."""
  np_type = np_array_ops.array(x).dtype.as_numpy_dtype
  return np.issubdtype(np_type, np.complexfloating)


@np_utils.np_doc('isrealobj')
def isrealobj(x):
  """Returns True if `x` (as an array) does not have a complex dtype."""
  return not iscomplexobj(x)
823
824
@np_utils.np_doc('isnan')
def isnan(x):
  """Element-wise NaN test; non-inexact input is cast to float first."""
  return _scalar(tf_fn=math_ops.is_nan, x=x, promote_to_float=True)
828
829
def _make_nan_reduction(np_fun_name, reduction, init_val):
  """Builds a NaN-ignoring reduction documented after NumPy's `np_fun_name`.

  The returned function replaces NaN entries of its input with `init_val`
  (converted to the input's dtype) and then applies `reduction`.
  """

  @np_utils.np_doc(np_fun_name)
  def nan_reduction(a, axis=None, dtype=None, keepdims=False):
    arr = np_array_ops.array(a)
    fill = np_array_ops.array(init_val, dtype=arr.dtype)
    cleaned = np_array_ops.where(isnan(arr), fill, arr)
    return reduction(cleaned, axis=axis, dtype=dtype, keepdims=keepdims)

  return nan_reduction


nansum = _make_nan_reduction('nansum', np_array_ops.sum, 0)
nanprod = _make_nan_reduction('nanprod', np_array_ops.prod, 1)
848
849
@np_utils.np_doc('nanmean')
def nanmean(a, axis=None, dtype=None, keepdims=None):
  """Mean ignoring NaNs: `nansum` divided by the count of non-NaN entries.

  Bool and integer inputs go straight to `np_array_ops.mean`.
  """
  a = np_array_ops.array(a)
  np_type = a.dtype.as_numpy_dtype
  if np.issubdtype(np_type, np.bool_) or np.issubdtype(np_type, np.integer):
    return np_array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims)
  not_nan = logical_not(isnan(a))
  dtype = np_type if dtype is None else dtype
  count = np_array_ops.sum(not_nan, axis=axis, dtype=dtype, keepdims=keepdims)
  total = nansum(a, axis=axis, dtype=dtype, keepdims=keepdims)
  return total / count
862
863
@np_utils.np_doc('isfinite')
def isfinite(x):
  """Element-wise finiteness test; non-inexact input is cast to float."""
  return _scalar(tf_fn=math_ops.is_finite, x=x, promote_to_float=True)


@np_utils.np_doc('isinf')
def isinf(x):
  """Element-wise infinity test; non-inexact input is cast to float."""
  return _scalar(tf_fn=math_ops.is_inf, x=x, promote_to_float=True)
872
873
@np_utils.np_doc('isneginf')
def isneginf(x):
  """Element-wise test for negative infinity.

  The input goes through `_scalar` with float promotion. Bool and integer
  arrays (which cannot hold infinities) therefore yield all-False, as in
  NumPy, instead of being compared against -inf materialized in an integer
  dtype.
  """
  return _scalar(
      lambda t: math_ops.equal(t, constant_op.constant(-np.inf, t.dtype)), x,
      True)


@np_utils.np_doc('isposinf')
def isposinf(x):
  """Element-wise test for positive infinity.

  The input goes through `_scalar` with float promotion. Bool and integer
  arrays (which cannot hold infinities) therefore yield all-False, as in
  NumPy, instead of being compared against inf materialized in an integer
  dtype.
  """
  return _scalar(
      lambda t: math_ops.equal(t, constant_op.constant(np.inf, t.dtype)), x,
      True)
882
883
@np_utils.np_doc('log2')
def log2(x):
  """Element-wise base-2 logarithm, computed as `log(x) / ln(2)`."""
  natural = log(x)
  return natural / np.log(2)


@np_utils.np_doc('log10')
def log10(x):
  """Element-wise base-10 logarithm, computed as `log(x) / ln(10)`."""
  natural = log(x)
  return natural / np.log(10)
892
893
@np_utils.np_doc('log1p')
def log1p(x):
  """Element-wise `log(1 + x)`; non-inexact input is cast to float first."""
  return _scalar(tf_fn=math_ops.log1p, x=x, promote_to_float=True)


@np_utils.np_doc('positive')
def positive(x):
  """Returns `x` converted to an ndarray, values unchanged."""

  def _identity(t):
    return t

  return _scalar(_identity, x)
902
903
@np_utils.np_doc('sinc')
def sinc(x):
  """Normalized sinc `sin(pi x) / (pi x)`, defined as 1 at `x == 0`."""

  def _normalized_sinc(t):
    scaled = t * np.pi
    ratio = math_ops.sin(scaled) / scaled
    return array_ops.where_v2(t == 0, array_ops.ones_like(t), ratio)

  return _scalar(_normalized_sinc, x, True)
913
914
@np_utils.np_doc('square')
def square(x):
  """Element-wise `x * x`, keeping the input dtype."""
  return _scalar(tf_fn=math_ops.square, x=x)
918
919
@np_utils.np_doc('diff')
def diff(a, n=1, axis=-1):
  """n-th discrete difference of `a` along `axis`.

  Bool inputs use `not_equal` instead of subtraction at each step.

  Raises:
    ValueError: if the rank of `a` is not statically known, if `axis` is out
      of bounds in either direction (including any axis for a rank-0 input),
      or if `n` is negative.
  """

  def f(a):
    # TODO(agarwal): transpose and reshape to N, H, 1 and do a 1D convolution
    # TODO(agarwal): avoid depending on static rank.
    nd = a.shape.rank
    if nd is None:
      raise ValueError('diff currently requires known rank for input `a`')
    # Check both bounds; a too-negative axis used to escape this check and
    # surface as an IndexError from the slice lists below.
    if not -nd <= axis < nd:
      raise ValueError('axis %s is out of bounds for array of dimension %s' %
                       (axis, nd))
    if n < 0:
      raise ValueError('order must be non-negative but got %s' % n)
    slice1 = [slice(None)] * nd
    slice2 = [slice(None)] * nd
    slice1[axis] = slice(1, None)
    slice2[axis] = slice(None, -1)
    slice1 = tuple(slice1)
    slice2 = tuple(slice2)
    op = math_ops.not_equal if a.dtype == dtypes.bool else math_ops.subtract
    for _ in range(n):
      a = op(a[slice1], a[slice2])
    return a

  return _scalar(f, a)
946
947
def _wrap(f, reverse=False):
  """Adapts binary `f` so it can be installed as an ndarray operator overload.

  If `reverse` is True the operands are swapped before use (presumably for
  reflected operators such as `__radd__`). The wrapper returns NotImplemented
  when the second operand, after any swap, has a higher `__array_priority__`
  than `np_arrays.ndarray`.
  """

  def _f(a, b):
    lhs, rhs = (b, a) if reverse else (a, b)
    rhs_priority = getattr(rhs, '__array_priority__', 0)
    if rhs_priority > np_arrays.ndarray.__array_priority__:
      return NotImplemented
    return f(lhs, rhs)

  return _f
962
963
def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
  """Helper function for comparison ops.

  Both operands are converted to their common result type. If
  `cast_bool_to_int` is True and that type is bool, both are cast to int32
  before `tf_fun` is applied (used by the ordering comparisons).
  """
  common_dtype = np_utils.result_type(x1, x2)
  lhs = np_array_ops.array(x1, dtype=common_dtype)
  rhs = np_array_ops.array(x2, dtype=common_dtype)
  if cast_bool_to_int and lhs.dtype == dtypes.bool:
    lhs = math_ops.cast(lhs, dtypes.int32)
    rhs = math_ops.cast(rhs, dtypes.int32)
  return tf_fun(lhs, rhs)
974
975
@np_utils.np_doc('equal')
def equal(x1, x2):
  """Element-wise `x1 == x2` after promotion to a common type."""
  return _comparison(tf_fun=math_ops.equal, x1=x1, x2=x2)


@np_utils.np_doc('not_equal')
def not_equal(x1, x2):
  """Element-wise `x1 != x2` after promotion to a common type."""
  return _comparison(tf_fun=math_ops.not_equal, x1=x1, x2=x2)


@np_utils.np_doc('greater')
def greater(x1, x2):
  """Element-wise `x1 > x2`; bools are compared as int32."""
  return _comparison(math_ops.greater, x1, x2, cast_bool_to_int=True)


@np_utils.np_doc('greater_equal')
def greater_equal(x1, x2):
  """Element-wise `x1 >= x2`; bools are compared as int32."""
  return _comparison(math_ops.greater_equal, x1, x2, cast_bool_to_int=True)


@np_utils.np_doc('less')
def less(x1, x2):
  """Element-wise `x1 < x2`; bools are compared as int32."""
  return _comparison(math_ops.less, x1, x2, cast_bool_to_int=True)


@np_utils.np_doc('less_equal')
def less_equal(x1, x2):
  """Element-wise `x1 <= x2`; bools are compared as int32."""
  return _comparison(math_ops.less_equal, x1, x2, cast_bool_to_int=True)
1004
1005
@np_utils.np_doc('array_equal')
def array_equal(a1, a2):
  """True iff `a1` and `a2` have the same rank, shape and elements."""

  def _all_equal(lhs, rhs):

    def _false():
      return constant_op.constant(False)

    def _compare_given_same_rank():
      same_shape = np_utils.reduce_all(
          math_ops.equal(array_ops.shape(lhs), array_ops.shape(rhs)))
      return np_utils.cond(
          same_shape, lambda: math_ops.reduce_all(math_ops.equal(lhs, rhs)),
          _false)

    same_rank = math_ops.equal(array_ops.rank(lhs), array_ops.rank(rhs))
    return np_utils.cond(same_rank, _compare_given_same_rank, _false)

  return _comparison(_all_equal, a1, a2)
1021
1022
def _logical_binary_op(tf_fun, x1, x2):
  """Converts both operands to bool arrays and applies `tf_fun`."""
  lhs, rhs = (np_array_ops.array(v, dtype=np.bool_) for v in (x1, x2))
  return tf_fun(lhs, rhs)


@np_utils.np_doc('logical_and')
def logical_and(x1, x2):
  """Element-wise logical AND of the truth values of `x1` and `x2`."""
  return _logical_binary_op(tf_fun=math_ops.logical_and, x1=x1, x2=x2)


@np_utils.np_doc('logical_or')
def logical_or(x1, x2):
  """Element-wise logical OR of the truth values of `x1` and `x2`."""
  return _logical_binary_op(tf_fun=math_ops.logical_or, x1=x1, x2=x2)


@np_utils.np_doc('logical_xor')
def logical_xor(x1, x2):
  """Element-wise logical XOR of the truth values of `x1` and `x2`."""
  return _logical_binary_op(tf_fun=math_ops.logical_xor, x1=x1, x2=x2)


@np_utils.np_doc('logical_not')
def logical_not(x):
  """Element-wise logical NOT of the truth value of `x`."""
  as_bool = np_array_ops.array(x, dtype=np.bool_)
  return math_ops.logical_not(as_bool)
1048
1049
@np_utils.np_doc('linspace')
def linspace(  # pylint: disable=missing-docstring
    start,
    stop,
    num=50,
    endpoint=True,
    retstep=False,
    dtype=float,
    axis=0):
  """Returns `num` evenly spaced samples over `[start, stop]`.

  If `endpoint` is False, the interval is `[start, stop)`. When `retstep` is
  True, the pair `(samples, step)` is returned. As in NumPy, `step` is
  `(stop - start) / div`, where `div` is `num - 1` if `endpoint` else `num`.
  It is NaN when `div` is 0. In particular, `endpoint=False, num=1` now yields
  `step == stop - start` instead of NaN.

  Raises:
    ValueError: if `num` is negative.
  """
  if dtype:
    dtype = np_utils.result_type(dtype)
  start = np_array_ops.array(start, dtype=dtype)
  stop = np_array_ops.array(stop, dtype=dtype)
  if num < 0:
    raise ValueError('Number of samples {} must be non-negative.'.format(num))
  step = ops.convert_to_tensor(np.nan)
  if endpoint:
    result = math_ops.linspace(start, stop, num, axis=axis)
    if num > 1:
      step = (stop - start) / (num - 1)
  else:
    # math_ops.linspace does not support endpoint=False so we manually handle it
    # here.
    if num > 0:
      # NumPy divides by `num` here, so a single sample still has a step.
      step = ((stop - start) / num)
    if num > 1:
      new_stop = math_ops.cast(stop, step.dtype) - step
      start = math_ops.cast(start, new_stop.dtype)
      result = math_ops.linspace(start, new_stop, num, axis=axis)
    else:
      result = math_ops.linspace(start, stop, num, axis=axis)
  if dtype:
    result = math_ops.cast(result, dtype)
  if retstep:
    return (result, step)
  else:
    return result
1086
1087
@np_utils.np_doc('logspace')
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
  """Computes `base ** linspace(start, stop)`, then casts to `dtype`.

  As in NumPy, the exponents are interpolated in a floating-point type and
  the cast to `dtype` happens only after exponentiation. An integer `dtype`
  therefore does not truncate fractional exponents, e.g. 0.5 to 0.
  """
  dtype = np_utils.result_type(start, stop, dtype)
  # Interpolate exponents with at least float32 precision. Passing an integer
  # `dtype` to linspace would cast the exponents to integers before the pow.
  computation_dtype = dtypes.as_dtype(
      np.promote_types(dtype.as_numpy_dtype, np.float32))
  result = linspace(
      start, stop, num=num, endpoint=endpoint, dtype=computation_dtype,
      axis=axis)
  result = math_ops.pow(math_ops.cast(base, result.dtype), result)
  if dtype:
    result = math_ops.cast(result, dtype)
  return result
1097
1098
@np_utils.np_doc('geomspace')
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):  # pylint: disable=missing-docstring
  """Computed as a sign-corrected `logspace` between `log10` of the endpoints.

  The sequence is generated along axis 0 in a floating-point computation dtype
  of at least float32 precision. It is then moved to `axis` and cast to
  `dtype`.
  """
  # With no explicit `dtype`, infer one from the endpoints, `num` (as a float)
  # and the dtype of `zeros((), None)` (presumably the default float type --
  # TODO confirm).
  dtype = dtypes.as_dtype(dtype) if dtype else np_utils.result_type(
      start, stop, float(num), np_array_ops.zeros((), dtype))
  # Compute in at least float32 so logs/powers are never done in an integer
  # type.
  computation_dtype = np.promote_types(dtype.as_numpy_dtype, np.float32)
  start = np_array_ops.asarray(start, dtype=computation_dtype)
  stop = np_array_ops.asarray(stop, dtype=computation_dtype)
  # follow the numpy geomspace convention for negative and complex endpoints
  # Each *_sign is 1 - sign(real part): 0 if positive, 1 if zero, 2 if
  # negative.
  start_sign = 1 - np_array_ops.sign(np_array_ops.real(start))
  stop_sign = 1 - np_array_ops.sign(np_array_ops.real(stop))
  # `*` and `//` group left to right, i.e. 1 - (start_sign * stop_sign) // 2.
  # The result is -1 when both real parts are negative, 0 when one is
  # negative and the other is zero, and 1 otherwise. For two negative
  # endpoints this makes the log10 arguments positive; the outer
  # multiplication by `signflip` restores the sign.
  signflip = 1 - start_sign * stop_sign // 2
  res = signflip * logspace(
      log10(signflip * start),
      log10(signflip * stop),
      num,
      endpoint=endpoint,
      base=10.0,
      dtype=computation_dtype,
      axis=0)
  # The sequence was generated along axis 0; move it to the requested axis.
  if axis != 0:
    res = np_array_ops.moveaxis(res, 0, axis)
  return math_ops.cast(res, dtype)
1121
1122
@np_utils.np_doc('ptp')
def ptp(a, axis=None, keepdims=None):
  """Computed as `amax(a) - amin(a)` with the same `axis`/`keepdims`."""
  reduce_kwargs = dict(axis=axis, keepdims=keepdims)
  largest = np_array_ops.amax(a, **reduce_kwargs)
  smallest = np_array_ops.amin(a, **reduce_kwargs)
  return largest - smallest
1127
1128
@np_utils.np_doc_only('concatenate')
def concatenate(arys, axis=0):
  """Joins the inputs along `axis` after converting them to a common dtype.

  An input that is not a list or tuple is treated as a one-element list.
  """
  pieces = arys if isinstance(arys, (list, tuple)) else [arys]
  if len(pieces) == 0:
    raise ValueError('Need at least one array to concatenate.')
  common_dtype = np_utils.result_type(*pieces)
  converted = [np_array_ops.array(piece, dtype=common_dtype)
               for piece in pieces]
  return array_ops.concat(converted, axis)
1138
1139
@np_utils.np_doc_only('tile')
def tile(a, reps):  # pylint: disable=missing-function-docstring
  """Repeats `a` according to `reps`, following NumPy's padding rules.

  Whichever of `reps` and the shape of `a` is shorter is left-padded with ones
  so that both have the same length before tiling.
  """
  a = np_array_ops.array(a)
  # Use array_ops.reshape rather than the `Tensor.reshape` method: that method
  # is installed by enable_numpy_methods_on_tensor() and may not be present.
  reps = array_ops.reshape(np_array_ops.array(reps, dtype=dtypes.int32), [-1])

  a_rank = array_ops.rank(a)
  reps_size = array_ops.size(reps)
  # Left-pad `reps` with ones when it is shorter than the rank of `a`.
  reps = array_ops.pad(
      reps, [[math_ops.maximum(a_rank - reps_size, 0), 0]], constant_values=1)
  # Left-pad the shape of `a` with ones when `reps` was longer than its rank.
  a_shape = array_ops.pad(
      array_ops.shape(a), [[math_ops.maximum(reps_size - a_rank, 0), 0]],
      constant_values=1)
  a = array_ops.reshape(a, a_shape)

  return array_ops.tile(a, reps)
1155
1156
@np_utils.np_doc('count_nonzero')
def count_nonzero(a, axis=None):
  """Counts the non-zero entries of `a` (converted to an array) along `axis`."""
  arr = np_array_ops.array(a)
  return math_ops.count_nonzero(arr, axis)
1160
1161
@np_utils.np_doc('argsort')
def argsort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  """Only `kind` 'quicksort' or 'stable' is accepted; `order` must be None.

  A rank-0 input yields `[0]`. With `axis=None` the input is flattened first.
  """
  # TODO(nareshmodi): make string tensors also work.
  if kind not in ('quicksort', 'stable'):
    raise ValueError("Only 'quicksort' and 'stable' arguments are supported.")
  if order is not None:
    raise ValueError("'order' argument to sort is not supported.")

  arr = np_array_ops.array(a)
  use_stable = kind == 'stable'

  def _sorted_indices():
    values, sort_axis = arr, axis
    if sort_axis is None:
      # Like NumPy, sort the flattened array when no axis is given.
      values, sort_axis = array_ops.reshape(values, [-1]), 0
    return sort_ops.argsort(values, sort_axis, stable=use_stable)

  indices = np_utils.cond(
      math_ops.equal(array_ops.rank(arr), 0),
      lambda: constant_op.constant([0]),
      _sorted_indices)
  return np_array_ops.array(indices, dtype=np.intp)
1185
1186
@np_utils.np_doc('sort')
def sort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  """Only `kind='quicksort'` is accepted and `order` must be None.

  With `axis=None` the input is flattened before sorting.
  """
  if kind != 'quicksort':
    raise ValueError("Only 'quicksort' is supported.")
  if order is not None:
    raise ValueError("'order' argument to sort is not supported.")

  arr = np_array_ops.array(a)
  if axis is not None:
    return sort_ops.sort(arr, axis)
  return sort_ops.sort(array_ops.reshape(arr, [-1]), 0)
1200
1201
def _argminmax(fn, a, axis=None):
  """Applies the arg-reduction `fn` to `a`, flattening it when `axis` is None."""
  arr = np_array_ops.array(a)
  # NumPy flattens the input when no axis is given; otherwise the input is
  # promoted to at least one dimension.
  operand = (array_ops.reshape(arr, [-1]) if axis is None
             else np_array_ops.atleast_1d(arr))
  return fn(input=operand, axis=axis)
1210
1211
@np_utils.np_doc('argmax')
def argmax(a, axis=None):
  """The input is flattened when `axis` is None."""
  return _argminmax(fn=math_ops.argmax, a=a, axis=axis)
1215
1216
@np_utils.np_doc('argmin')
def argmin(a, axis=None):
  """The input is flattened when `axis` is None."""
  return _argminmax(fn=math_ops.argmin, a=a, axis=axis)
1220
1221
@np_utils.np_doc('append')
def append(arr, values, axis=None):
  """With `axis=None` both inputs are flattened before being joined."""
  if axis is not None:
    return concatenate([arr, values], axis=axis)
  flattened = [np_array_ops.ravel(arr), np_array_ops.ravel(values)]
  return concatenate(flattened, 0)
1228
1229
@np_utils.np_doc('average')
def average(a, axis=None, weights=None, returned=False):  # pylint: disable=missing-docstring
  """`axis` must be None or a single integer (tuples are not supported yet).

  Non-floating inputs are averaged in a floating-point dtype. When `weights`
  is given, it must either have the same shape as `a`, or (with an explicit
  `axis`) be 1-D. In the 1-D case it is contracted against `a` along `axis`.
  """
  if axis is not None and not isinstance(axis, six.integer_types):
    # TODO(wangpeng): Support tuple of ints as `axis`
    raise ValueError('`axis` must be an integer. Tuple of ints is not '
                     'supported yet. Got type: %s' % type(axis))
  a = np_array_ops.array(a)
  if weights is None:  # Treat all weights as 1
    if not np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
      # Call math_ops.cast directly rather than the `Tensor.astype` method,
      # which is installed by enable_numpy_methods_on_tensor() and may be
      # absent.
      a = math_ops.cast(
          a, np_utils.result_type(a.dtype, np_dtypes.default_float_type()))
    avg = math_ops.reduce_mean(a, axis=axis)
    if returned:
      # With unit weights, the sum of weights is the number of averaged items.
      if axis is None:
        weights_sum = array_ops.size(a)
      else:
        weights_sum = array_ops.shape(a)[axis]
      weights_sum = math_ops.cast(weights_sum, a.dtype)
  else:
    # For non-floating `a`, include the default float type in the promotion
    # so that the weighted average is computed in floating point.
    if np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
      out_dtype = np_utils.result_type(a.dtype, weights)
    else:
      out_dtype = np_utils.result_type(a.dtype, weights,
                                       np_dtypes.default_float_type())
    a = np_array_ops.array(a, out_dtype)
    weights = np_array_ops.array(weights, out_dtype)

    def rank_equal_case():
      # `a` and `weights` are expected to have identical shapes.
      control_flow_ops.Assert(
          math_ops.reduce_all(array_ops.shape(a) == array_ops.shape(weights)),
          [array_ops.shape(a), array_ops.shape(weights)])
      weights_sum = math_ops.reduce_sum(weights, axis=axis)
      avg = math_ops.reduce_sum(a * weights, axis=axis) / weights_sum
      return avg, weights_sum

    if axis is None:
      avg, weights_sum = rank_equal_case()
    else:

      def rank_not_equal_case():
        # 1-D `weights` contracted against `a` along `axis`.
        control_flow_ops.Assert(
            array_ops.rank(weights) == 1, [array_ops.rank(weights)])
        weights_sum = math_ops.reduce_sum(weights)
        axes = ops.convert_to_tensor([[axis], [0]])
        avg = math_ops.tensordot(a, weights, axes) / weights_sum
        return avg, weights_sum

      # We condition on rank rather than shape equality, because if we do the
      # latter, when the shapes are partially unknown but the ranks are known
      # and different, np_utils.cond will run shape checking on the true branch,
      # which will raise a shape-checking error.
      avg, weights_sum = np_utils.cond(
          math_ops.equal(array_ops.rank(a), array_ops.rank(weights)),
          rank_equal_case, rank_not_equal_case)

  avg = np_array_ops.array(avg)
  if returned:
    weights_sum = np_array_ops.broadcast_to(weights_sum, array_ops.shape(avg))
    return avg, weights_sum
  return avg
1290
1291
@np_utils.np_doc('trace')
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
  """Sums the diagonal selected by `offset`, `axis1` and `axis2`.

  When the main diagonal of the last two axes is requested and the rank is
  statically known, `math_ops.trace` is used directly.
  """
  if dtype:
    dtype = np_utils.result_type(dtype)
  arr = np_array_ops.asarray(a, dtype)

  if offset == 0 and arr.shape.rank is not None:
    ndim = len(arr.shape)
    if axis1 in (-2, ndim - 2) and axis2 in (-1, ndim - 1):
      return math_ops.trace(arr)

  diag = np_array_ops.diagonal(arr, offset, axis1, axis2)
  return np_array_ops.sum(diag, -1, dtype)
1308
1309
@np_utils.np_doc('meshgrid')
def meshgrid(*xi, **kwargs):
  """Only copy=True and sparse=False are currently supported.

  The `indexing` keyword (default 'xy') is forwarded to `array_ops.meshgrid`.
  """
  if kwargs.get('sparse', False):
    raise ValueError('meshgrid doesnt support returning sparse arrays yet')
  if not kwargs.get('copy', True):
    raise ValueError('meshgrid only supports copy=True')

  index_mode = kwargs.get('indexing', 'xy')
  arrays = [np_array_ops.asarray(x) for x in xi]
  return array_ops.meshgrid(*arrays, indexing=index_mode)
1329
1330
# Uses np_doc_only here because np.einsum (in 1.16) doesn't have argument
# `subscripts`, even though the doc says it has.
@np_utils.np_doc_only('einsum')
def einsum(subscripts, *operands, **kwargs):  # pylint: disable=missing-docstring
  """Supports `casting` in ('safe', 'no') and a limited set of `optimize` values.

  `optimize` may be False, True, 'greedy' or 'optimal'. TF has no
  unoptimized mode, so a falsy `optimize` uses the greedy strategy.
  """
  casting = kwargs.get('casting', 'safe')
  optimize = kwargs.get('optimize', False)

  if casting == 'safe':
    operands = np_array_ops._promote_dtype(*operands)  # pylint: disable=protected-access
  elif casting == 'no':
    operands = [np_array_ops.asarray(x) for x in operands]
  else:
    raise ValueError('casting policy not supported: %s' % casting)

  # TODO(wangpeng): Print a warning that np and tf use different
  #   optimizations.
  if not optimize or optimize == True or optimize == 'greedy':  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
    tf_optimize = 'greedy'
  elif optimize == 'optimal':
    tf_optimize = 'optimal'
  else:
    raise ValueError('`optimize` method not supported: %s' % optimize)
  return special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize)
1358
1359
def _tensor_t(self):
  """Implements the `T` property by delegating to `self.transpose()`."""
  transposed = self.transpose()
  return transposed
1363
1364
def _tensor_ndim(self):
  """Implements the `ndim` property from the static shape's `ndims`."""
  static_shape = self.shape
  return static_shape.ndims
1368
1369
def _tensor_pos(self):
  """Implements unary `+`: the Tensor itself is returned unchanged."""
  unchanged = self
  return unchanged
1373
1374
def _tensor_size(self):
  """Returns the number of elements in this Tensor, if fully known.

  Returns:
    A Python int (1 for a rank-0 Tensor), or None when the static shape is
    not fully defined.
  """
  if not self.shape.is_fully_defined():
    return None
  # np.prod([]) is the float 1.0 and other products are np.int64; convert so
  # the result is always a Python int, like NumPy's `ndarray.size`.
  return int(np.prod(self.shape.as_list()))
1380
1381
def _tensor_tolist(self):
  """Converts an eager Tensor to Python values via NumPy's `tolist()`.

  Raises:
    ValueError: if the Tensor is not an `ops.EagerTensor` (i.e. symbolic).
  """
  if not isinstance(self, ops.EagerTensor):
    raise ValueError('Symbolic Tensors do not support the tolist API.')
  return self._numpy().tolist()  # pylint: disable=protected-access
1387
1388
def enable_numpy_methods_on_tensor():
  """Adds additional NumPy methods on tf.Tensor class."""
  # (attribute name, value) pairs, installed on ops.Tensor in this order.
  # `property(...)` values become read-only attributes; the rest are methods.
  additions = [
      ('T', property(_tensor_t)),
      ('ndim', property(_tensor_ndim)),
      ('size', property(_tensor_size)),
      ('__pos__', _tensor_pos),
      ('tolist', _tensor_tolist),
      # TODO(b/178540516): Make a custom `setattr` that changes the method's
      #   docstring to the TF one.
      ('transpose', np_array_ops.transpose),
      ('reshape', np_array_ops._reshape_method_wrapper),  # pylint: disable=protected-access
      ('ravel', np_array_ops.ravel),
      ('clip', clip),
      ('astype', math_ops.cast),
      ('__round__', np_array_ops.around),
      # TODO(wangpeng): Remove `data` when all uses of it are removed
      ('data', property(lambda self: self)),
  ]
  for attr_name, attr_value in additions:
    setattr(ops.Tensor, attr_name, attr_value)
1415