1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Common array methods."""
16# pylint: disable=g-direct-tensorflow-import
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import enum
23import functools
24import math
25import numbers
26import numpy as np
27import six
28
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_shape
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import clip_ops
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import linalg_ops
37from tensorflow.python.ops import manip_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import sort_ops
40from tensorflow.python.ops.numpy_ops import np_arrays
41from tensorflow.python.ops.numpy_ops import np_dtypes
42from tensorflow.python.ops.numpy_ops import np_export
43from tensorflow.python.ops.numpy_ops import np_utils
44from tensorflow.python.util import nest
45
46
47newaxis = np_export.np_export_constant(__name__, 'newaxis', np.newaxis)
48
49
50@np_utils.np_doc('empty')
51def empty(shape, dtype=float):  # pylint: disable=redefined-outer-name
52  return zeros(shape, dtype)
53
54
55@np_utils.np_doc('empty_like')
56def empty_like(a, dtype=None):
57  return zeros_like(a, dtype)
58
59
60@np_utils.np_doc('zeros')
61def zeros(shape, dtype=float):  # pylint: disable=redefined-outer-name
62  dtype = (
63      np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
64  return array_ops.zeros(shape, dtype=dtype)
65
66
67@np_utils.np_doc('zeros_like')
68def zeros_like(a, dtype=None):  # pylint: disable=missing-docstring
69  if dtype is None:
70    # We need to let np_utils.result_type decide the dtype, not tf.zeros_like
71    dtype = np_utils.result_type(a)
72  else:
73    # TF and numpy has different interpretations of Python types such as
74    # `float`, so we let `np_utils.result_type` decide.
75    dtype = np_utils.result_type(dtype)
76  dtype = dtypes.as_dtype(dtype)  # Work around b/149877262
77  return array_ops.zeros_like(a, dtype)
78
79
80@np_utils.np_doc('ones')
81def ones(shape, dtype=float):  # pylint: disable=redefined-outer-name
82  if dtype:
83    dtype = np_utils.result_type(dtype)
84  return array_ops.ones(shape, dtype=dtype)
85
86
87@np_utils.np_doc('ones_like')
88def ones_like(a, dtype=None):
89  if dtype is None:
90    dtype = np_utils.result_type(a)
91  else:
92    dtype = np_utils.result_type(dtype)
93  return array_ops.ones_like(a, dtype)
94
95
96@np_utils.np_doc('eye')
97def eye(N, M=None, k=0, dtype=float):  # pylint: disable=invalid-name,missing-docstring
98  if dtype:
99    dtype = np_utils.result_type(dtype)
100  if not M:
101    M = N
102  # Making sure N, M and k are `int`
103  N = int(N)
104  M = int(M)
105  k = int(k)
106  if k >= M or -k >= N:
107    # tf.linalg.diag will raise an error in this case
108    return zeros([N, M], dtype=dtype)
109  if k == 0:
110    return linalg_ops.eye(N, M, dtype=dtype)
111  # We need the precise length, otherwise tf.linalg.diag will raise an error
112  diag_len = min(N, M)
113  if k > 0:
114    if N >= M:
115      diag_len -= k
116    elif N + k > M:
117      diag_len = M - k
118  elif k <= 0:
119    if M >= N:
120      diag_len += k
121    elif M - k > N:
122      diag_len = N + k
123  diagonal_ = array_ops.ones([diag_len], dtype=dtype)
124  return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)
125
126
127@np_utils.np_doc('identity')
128def identity(n, dtype=float):
129  return eye(N=n, M=n, dtype=dtype)
130
131
132@np_utils.np_doc('full')
133def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
134  if not isinstance(shape, np_arrays.ndarray):
135    shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
136  shape = atleast_1d(shape)
137  fill_value = asarray(fill_value, dtype=dtype)
138  return array_ops.broadcast_to(fill_value, shape)
139
140
141# Using doc only here since np full_like signature doesn't seem to have the
142# shape argument (even though it exists in the documentation online).
143@np_utils.np_doc_only('full_like')
144def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  # pylint: disable=missing-docstring,redefined-outer-name
145  """order, subok and shape arguments mustn't be changed."""
146  if order != 'K':
147    raise ValueError('Non-standard orders are not supported.')
148  if not subok:
149    raise ValueError('subok being False is not supported.')
150  if shape:
151    raise ValueError('Overriding the shape is not supported.')
152
153  a = asarray(a)
154  dtype = dtype or np_utils.result_type(a)
155  fill_value = asarray(fill_value, dtype=dtype)
156  return array_ops.broadcast_to(fill_value, array_ops.shape(a))
157
158
159def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
160  """Main implementation of np.array()."""
161  result_t = val
162
163  if not isinstance(result_t, ops.Tensor):
164    if not dtype:
165      dtype = np_utils.result_type(result_t)
166    # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because
167    # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int)
168    # while np.array allows them. We need to convert-then-cast.
169
170    # EagerTensor conversion complains about "mixed types" when converting
171    # tensors with no dtype information. This is because it infers types based
172    # on one selected item in the list. So e.g. when converting [2., 2j]
173    # to a tensor, it will select float32 as the inferred type and not be able
174    # to convert the list to a float 32 tensor.
175    # Since we have some information about the final dtype we care about, we
176    # supply that information so that convert_to_tensor will do best-effort
177    # conversion to that dtype first.
178    result_t = np_arrays.convert_to_tensor(result_t, dtype_hint=dtype)
179    result_t = math_ops.cast(result_t, dtype=dtype)
180  elif dtype:
181    result_t = math_ops.cast(result_t, dtype)
182
183  if copy:
184    result_t = array_ops.identity(result_t)
185
186  if ndmin == 0:
187    return result_t
188
189  ndims = array_ops.rank(result_t)
190
191  def true_fn():
192    old_shape = array_ops.shape(result_t)
193    new_shape = array_ops.concat(
194        [array_ops.ones(ndmin - ndims, dtypes.int32), old_shape], axis=0)
195    return array_ops.reshape(result_t, new_shape)
196
197  result_t = np_utils.cond(
198      np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
199  return result_t
200
201
202# TODO(wangpeng): investigate whether we can make `copy` default to False.
203# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
204@np_utils.np_doc_only('array')
205def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
206  """Since Tensors are immutable, a copy is made only if val is placed on a
207
208  different device than the current one. Even if `copy` is False, a new Tensor
209  may need to be built to satisfy `dtype` and `ndim`. This is used only if `val`
210  is an ndarray or a Tensor.
211  """  # pylint:disable=g-docstring-missing-newline
212  if dtype:
213    dtype = np_utils.result_type(dtype)
214  return _array_internal(val, dtype, copy, ndmin)
215
216
217# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-doc-return-or-yield,g-doc-args
218
219
220@np_utils.np_doc('asarray')
221def asarray(a, dtype=None):
222  if dtype:
223    dtype = np_utils.result_type(dtype)
224  if isinstance(a, np_arrays.ndarray) and (
225      not dtype or dtype == a.dtype.as_numpy_dtype):
226    return a
227  return array(a, dtype, copy=False)
228
229
230@np_utils.np_doc('asanyarray')
231def asanyarray(a, dtype=None):
232  return asarray(a, dtype)
233
234
235@np_utils.np_doc('ascontiguousarray')
236def ascontiguousarray(a, dtype=None):
237  return array(a, dtype, ndmin=1)
238
239
240# Numerical ranges.
241@np_utils.np_doc('arange')
242def arange(start, stop=None, step=1, dtype=None):
243  """Returns `step`-separated values in the range [start, stop).
244
245  Args:
246    start: Start of the interval. Included in the range.
247    stop: End of the interval. If not specified, `start` is treated as 0 and
248      `start` value is used as `stop`. If specified, it is not included in the
249      range if `step` is integer. When `step` is floating point, it may or may
250      not be included.
251    step: The difference between 2 consecutive values in the output range. It is
252      recommended to use `linspace` instead of using non-integer values for
253      `step`.
254    dtype: Optional. Type of the resulting ndarray. Could be a python type, a
255      NumPy type or a TensorFlow `DType`. If not provided, the largest type of
256      `start`, `stop`, `step` is used.
257
258  Raises:
259    ValueError: If step is zero.
260  """
261  if not step:
262    raise ValueError('step must be non-zero.')
263  if dtype:
264    dtype = np_utils.result_type(dtype)
265  else:
266    if stop is None:
267      dtype = np_utils.result_type(start, step)
268    else:
269      dtype = np_utils.result_type(start, step, stop)
270  if step > 0 and ((stop is not None and start > stop) or
271                   (stop is None and start < 0)):
272    return array([], dtype=dtype)
273  if step < 0 and ((stop is not None and start < stop) or
274                   (stop is None and start > 0)):
275    return array([], dtype=dtype)
276  # TODO(srbs): There are some bugs when start or stop is float type and dtype
277  # is integer type.
278  return math_ops.cast(
279      math_ops.range(start, limit=stop, delta=step), dtype=dtype)
280
281
282# Building matrices.
283@np_utils.np_doc('diag')
284def diag(v, k=0):  # pylint: disable=missing-docstring
285  """Raises an error if input is not 1- or 2-d."""
286  v = asarray(v)
287  v_rank = array_ops.rank(v)
288
289  v.shape.with_rank_at_most(2)
290
291  # TODO(nareshmodi): Consider a np_utils.Assert version that will fail during
292  # tracing time if the shape is known.
293  control_flow_ops.Assert(
294      np_utils.logical_or(math_ops.equal(v_rank, 1), math_ops.equal(v_rank, 2)),
295      [v_rank])
296
297  def _diag(v, k):
298    return np_utils.cond(
299        math_ops.equal(array_ops.size(v), 0),
300        lambda: array_ops.zeros([abs(k), abs(k)], dtype=v.dtype),
301        lambda: array_ops.matrix_diag(v, k=k))
302
303  def _diag_part(v, k):
304    v_shape = array_ops.shape(v)
305    v, k = np_utils.cond(
306        np_utils.logical_or(
307            np_utils.less_equal(k, -1 * np_utils.getitem(v_shape, 0)),
308            np_utils.greater_equal(k, np_utils.getitem(v_shape, 1)),
309        ), lambda: (array_ops.zeros([0, 0], dtype=v.dtype), 0), lambda: (v, k))
310    result = array_ops.matrix_diag_part(v, k=k)
311    return result
312
313  result = np_utils.cond(
314      math_ops.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k))
315  return result
316
317
318@np_utils.np_doc('diagonal')
319def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
320  a = asarray(a)
321
322  maybe_rank = a.shape.rank
323  if maybe_rank is not None and offset == 0 and (
324      axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or
325                                                   axis2 == -1):
326    return array_ops.matrix_diag_part(a)
327
328  a = moveaxis(a, (axis1, axis2), (-2, -1))
329
330  a_shape = array_ops.shape(a)
331
332  def _zeros():  # pylint: disable=missing-docstring
333    return (array_ops.zeros(
334        array_ops.concat([a_shape[:-1], [0]], 0), dtype=a.dtype), 0)
335
336  # All zeros since diag_part doesn't handle all possible k (aka offset).
337  # Written this way since cond will run shape inference on both branches,
338  # and diag_part shape inference will fail when offset is out of bounds.
339  a, offset = np_utils.cond(
340      np_utils.logical_or(
341          np_utils.less_equal(offset, -1 * np_utils.getitem(a_shape, -2)),
342          np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)),
343      ), _zeros, lambda: (a, offset))
344
345  a = array_ops.matrix_diag_part(a, k=offset)
346  return a
347
348
349@np_utils.np_doc('diagflat')
350def diagflat(v, k=0):
351  v = asarray(v)
352  return diag(array_ops.reshape(v, [-1]), k)
353
354
355def _promote_dtype(*arrays):
356  dtype = np_utils.result_type(*arrays)
357  def _fast_asarray(a):
358    if isinstance(a, np_arrays.ndarray) and dtype == a.dtype.as_numpy_dtype:
359      return a
360    return _array_internal(a, dtype=dtype, copy=False)
361  return [_fast_asarray(a) for a in arrays]
362
363
364def _promote_dtype_binary(t1, t2):
365  dtype = np_utils._result_type_binary(t1, t2)  # pylint: disable=protected-access
366  if not(
367      isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype.as_numpy_dtype):
368    t1 = _array_internal(t1, dtype=dtype, copy=False)
369  if not(
370      isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype.as_numpy_dtype):
371    t2 = _array_internal(t2, dtype=dtype, copy=False)
372  return t1, t2
373
374
375@np_utils.np_doc('all')
376def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
377  a = asarray(a, dtype=bool)
378  return math_ops.reduce_all(input_tensor=a, axis=axis, keepdims=keepdims)
379
380
381@np_utils.np_doc('any')
382def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
383  a = asarray(a, dtype=bool)
384  return math_ops.reduce_any(input_tensor=a, axis=axis, keepdims=keepdims)
385
386
387@np_utils.np_doc('compress')
388def compress(condition, a, axis=None):  # pylint: disable=redefined-outer-name,missing-function-docstring
389  condition = asarray(condition, dtype=bool)
390  a = asarray(a)
391
392  if condition.ndim != 1:
393    raise ValueError('condition must be a 1-d array.')
394  # `np.compress` treats scalars as 1-d arrays.
395  if a.ndim == 0:
396    a = ravel(a)
397
398  if axis is None:
399    a = ravel(a)
400    axis = 0
401
402  if axis < 0:
403    axis += a.ndim
404
405  assert axis >= 0 and axis < a.ndim
406
407  # `tf.boolean_mask` requires the first dimensions of array and condition to
408  # match. `np.compress` pads condition with False when it is shorter.
409  condition_t = condition
410  a_t = a
411  if condition.shape[0] < a.shape[axis]:
412    padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False)
413    condition_t = array_ops.concat([condition_t, padding], axis=0)
414  return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis)
415
416
417@np_utils.np_doc('copy')
418def copy(a):
419  return array(a, copy=True)
420
421
422def _maybe_promote_to_int(a):
423  if dtypes.as_dtype(a.dtype).is_integer:
424    # If a is an integer type and its precision is less than that of `int`,
425    # the output type will be `int`.
426    a_numpy_dtype = a.dtype.as_numpy_dtype
427    output_type = np.promote_types(a_numpy_dtype, int)
428    if output_type != a_numpy_dtype:
429      a = asarray(a, dtype=output_type)
430
431  return a
432
433
434@np_utils.np_doc('cumprod')
435def cumprod(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
436  a = asarray(a, dtype=dtype)
437
438  if dtype is None:
439    a = _maybe_promote_to_int(a)
440
441  # If axis is None, the input is flattened.
442  if axis is None:
443    a = ravel(a)
444    axis = 0
445  elif axis < 0:
446    axis += array_ops.rank(a)
447  return math_ops.cumprod(a, axis)
448
449
450@np_utils.np_doc('cumsum')
451def cumsum(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
452  a = asarray(a, dtype=dtype)
453
454  if dtype is None:
455    a = _maybe_promote_to_int(a)
456
457  # If axis is None, the input is flattened.
458  if axis is None:
459    a = ravel(a)
460    axis = 0
461  elif axis < 0:
462    axis += array_ops.rank(a)
463  return math_ops.cumsum(a, axis)
464
465
466@np_utils.np_doc('imag')
467def imag(val):
468  val = asarray(val)
469  # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we always
470  # return an ndarray.
471  return math_ops.imag(val)
472
473
474_TO_INT_ = 0
475_TO_FLOAT = 1
476
477
478def _reduce(tf_fn,
479            a,
480            axis=None,
481            dtype=None,
482            keepdims=None,
483            promote_int=_TO_INT_,
484            tf_bool_fn=None,
485            preserve_bool=False):
486  """A general reduction function.
487
488  Args:
489    tf_fn: the TF reduction function.
490    a: the array to be reduced.
491    axis: (optional) the axis along which to do the reduction. If None, all
492      dimensions are reduced.
493    dtype: (optional) the dtype of the result.
494    keepdims: (optional) whether to keep the reduced dimension(s).
495    promote_int: how to promote integer and bool inputs. There are three
496      choices. (1) `_TO_INT_` always promotes them to np.int_ or np.uint; (2)
497      `_TO_FLOAT` always promotes them to a float type (determined by
498      dtypes.default_float_type); (3) None: don't promote.
499    tf_bool_fn: (optional) the TF reduction function for bool inputs. It will
500      only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s dtype
501      is `np.bool_` and `preserve_bool` is True.
502    preserve_bool: a flag to control whether to use `tf_bool_fn` if `a`'s dtype
503      is `np.bool_` (some reductions such as np.sum convert bools to integers,
504      while others such as np.max preserve bools.
505
506  Returns:
507    An ndarray.
508  """
509  if dtype:
510    dtype = np_utils.result_type(dtype)
511  if keepdims is None:
512    keepdims = False
513  a = asarray(a, dtype=dtype)
514  if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) and
515      tf_bool_fn is not None):
516    return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims)
517  if dtype is None:
518    dtype = a.dtype.as_numpy_dtype
519    if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
520      if promote_int == _TO_INT_:
521        # If a is an integer/bool type and whose bit width is less than np.int_,
522        # numpy up-casts it to np.int_ based on the documentation at
523        # https://numpy.org/doc/1.18/reference/generated/numpy.sum.html
524        if dtype == np.bool_:
525          is_signed = True
526          width = 8  # We can use any number here that is less than 64
527        else:
528          is_signed = np.issubdtype(dtype, np.signedinteger)
529          width = np.iinfo(dtype).bits
530        # Numpy int_ and uint are defined as 'long' and 'unsigned long', so
531        # should have the same bit width.
532        if width < np.iinfo(np.int_).bits:
533          if is_signed:
534            dtype = np.int_
535          else:
536            dtype = np.uint
537          a = math_ops.cast(a, dtype)
538      elif promote_int == _TO_FLOAT:
539        a = math_ops.cast(a, np_dtypes.default_float_type())
540
541  if isinstance(axis, ops.Tensor) and axis.dtype not in (
542      dtypes.int32, dtypes.int64):
543    axis = math_ops.cast(axis, dtypes.int64)
544
545  return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims)
546
547
548# TODO (DarrenZhang01): Add `axis` support to the `size` API.
549@np_utils.np_doc('size')
550def size(x, axis=None):  # pylint: disable=missing-docstring
551  if axis is not None:
552    raise NotImplementedError('axis argument is not supported in the current '
553                              '`np.size` implementation')
554  if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
555    return 1
556  x = asarray(x)
557  if x.shape.is_fully_defined():
558    return np.prod(x.shape.as_list(), dtype=int)
559  else:
560    return array_ops.size_v2(x)
561
562
563@np_utils.np_doc('sum')
564def sum(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=redefined-builtin
565  return _reduce(
566      math_ops.reduce_sum,
567      a,
568      axis=axis,
569      dtype=dtype,
570      keepdims=keepdims,
571      tf_bool_fn=math_ops.reduce_any)
572
573
574@np_utils.np_doc('prod')
575def prod(a, axis=None, dtype=None, keepdims=None):
576  return _reduce(
577      math_ops.reduce_prod,
578      a,
579      axis=axis,
580      dtype=dtype,
581      keepdims=keepdims,
582      tf_bool_fn=math_ops.reduce_all)
583
584
585@np_utils.np_doc('mean')
586def mean(a, axis=None, dtype=None, keepdims=None):
587  return _reduce(
588      math_ops.reduce_mean,
589      a,
590      axis=axis,
591      dtype=dtype,
592      keepdims=keepdims,
593      promote_int=_TO_FLOAT)
594
595
596@np_utils.np_doc('amax')
597def amax(a, axis=None, keepdims=None):
598  return _reduce(
599      math_ops.reduce_max,
600      a,
601      axis=axis,
602      dtype=None,
603      keepdims=keepdims,
604      promote_int=None,
605      tf_bool_fn=math_ops.reduce_any,
606      preserve_bool=True)
607
608
609@np_utils.np_doc('amin')
610def amin(a, axis=None, keepdims=None):
611  return _reduce(
612      math_ops.reduce_min,
613      a,
614      axis=axis,
615      dtype=None,
616      keepdims=keepdims,
617      promote_int=None,
618      tf_bool_fn=math_ops.reduce_all,
619      preserve_bool=True)
620
621
622@np_utils.np_doc('var')
623def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: disable=missing-docstring
624  if dtype:
625    working_dtype = np_utils.result_type(a, dtype)
626  else:
627    working_dtype = None
628  if out is not None:
629    raise ValueError('Setting out is not supported.')
630  if ddof != 0:
631    # TF reduce_variance doesn't support ddof, so calculate it using raw ops.
632    def reduce_fn(input_tensor, axis, keepdims):
633      means = math_ops.reduce_mean(input_tensor, axis=axis, keepdims=True)
634      centered = input_tensor - means
635      if input_tensor.dtype in (dtypes.complex64, dtypes.complex128):
636        centered = math_ops.cast(
637            math_ops.real(centered * math_ops.conj(centered)),
638            input_tensor.dtype)
639      else:
640        centered = math_ops.square(centered)
641      squared_deviations = math_ops.reduce_sum(
642          centered, axis=axis, keepdims=keepdims)
643
644      if axis is None:
645        n = array_ops.size(input_tensor)
646      else:
647        if axis < 0:
648          axis += array_ops.rank(input_tensor)
649        n = math_ops.reduce_prod(
650            array_ops.gather(array_ops.shape(input_tensor), axis))
651      n = math_ops.cast(n - ddof, input_tensor.dtype)
652
653      return math_ops.cast(math_ops.divide(squared_deviations, n), dtype)
654  else:
655    reduce_fn = math_ops.reduce_variance
656
657  result = _reduce(
658      reduce_fn,
659      a,
660      axis=axis,
661      dtype=working_dtype,
662      keepdims=keepdims,
663      promote_int=_TO_FLOAT)
664  if dtype:
665    result = math_ops.cast(result, dtype)
666  return result
667
668
669@np_utils.np_doc('std')
670def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstring
671  return _reduce(
672      math_ops.reduce_std,
673      a,
674      axis=axis,
675      dtype=None,
676      keepdims=keepdims,
677      promote_int=_TO_FLOAT)
678
679
680@np_utils.np_doc('ravel')
681def ravel(a):  # pylint: disable=missing-docstring
682  a = asarray(a)
683  return array_ops.reshape(a, [-1])
684
685
686@np_utils.np_doc('real')
687def real(val):
688  val = asarray(val)
689  # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
690  # return an ndarray.
691  return math_ops.real(val)
692
693
694@np_utils.np_doc('repeat')
695def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
696  a = asarray(a)
697  original_shape = a._shape_as_list()  # pylint: disable=protected-access
698  # Best effort recovery of the shape.
699  known_shape = original_shape is not None and None not in original_shape
700  if known_shape:
701    if not original_shape:
702      original_shape = (repeats,)
703    else:
704      repeats_np = np.ravel(np.array(repeats))
705      if repeats_np.size == 1:
706        repeats_np = repeats_np.item()
707        if axis is None:
708          original_shape = (repeats_np * np.prod(original_shape),)
709        else:
710          original_shape[axis] = repeats_np * original_shape[axis]
711      else:
712        if axis is None:
713          original_shape = (repeats_np.sum(),)
714        else:
715          original_shape[axis] = repeats_np.sum()
716
717  repeats = asarray(repeats)
718  result = array_ops.repeat(a, repeats, axis)
719  if known_shape:
720    result.set_shape(original_shape)
721
722  return result
723
724
725@np_utils.np_doc('around')
726def around(a, decimals=0):  # pylint: disable=missing-docstring
727  a = asarray(a)
728  dtype = a.dtype.as_numpy_dtype
729  factor = math.pow(10, decimals)
730  if np.issubdtype(dtype, np.inexact):
731    factor = math_ops.cast(factor, dtype)
732  else:
733    # Use float as the working dtype when a.dtype is exact (e.g. integer),
734    # because `decimals` can be negative.
735    float_dtype = np_dtypes.default_float_type()
736    a = a.astype(float_dtype)
737    factor = math_ops.cast(factor, float_dtype)
738  a = math_ops.multiply(a, factor)
739  a = math_ops.round(a)
740  a = math_ops.divide(a, factor)
741  return a.astype(dtype)
742
743
744setattr(np_arrays.ndarray, '__round__', around)
745
746
747@np_utils.np_doc('reshape')
748def reshape(a, newshape, order='C'):
749  """order argument can only b 'C' or 'F'."""
750  if order not in {'C', 'F'}:
751    raise ValueError('Unsupported order argument {}'.format(order))
752
753  a = asarray(a)
754  if isinstance(newshape, int):
755    newshape = [newshape]
756
757  if order == 'F':
758    r = array_ops.transpose(
759        array_ops.reshape(array_ops.transpose(a), newshape[::-1]))
760  else:
761    r = array_ops.reshape(a, newshape)
762
763  return r
764
765
766def _reshape_method_wrapper(a, *newshape, **kwargs):
767  order = kwargs.pop('order', 'C')
768  if kwargs:
769    raise ValueError('Unsupported arguments: {}'.format(kwargs.keys()))
770
771  if len(newshape) == 1 and not isinstance(newshape[0], int):
772    newshape = newshape[0]
773
774  return reshape(a, newshape, order=order)
775
776
777@np_utils.np_doc('expand_dims')
778def expand_dims(a, axis):
779  a = asarray(a)
780  return array_ops.expand_dims(a, axis=axis)
781
782
783@np_utils.np_doc('squeeze')
784def squeeze(a, axis=None):
785  a = asarray(a)
786  return array_ops.squeeze(a, axis)
787
788
789@np_utils.np_doc('transpose')
790def transpose(a, axes=None):
791  a = asarray(a)
792  if axes is not None:
793    axes = asarray(axes)
794  return array_ops.transpose(a=a, perm=axes)
795
796
797@np_utils.np_doc('swapaxes')
798def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
799  a = asarray(a)
800  def adjust_axes(axes, rank):
801    def f(x):
802      if isinstance(x, int):
803        if x < 0:
804          x = x + rank
805      else:
806        x = array_ops.where_v2(x < 0, np_utils.add(x, a_rank), x)
807      return x
808    return nest.map_structure(f, axes)
809
810  if (a.shape.rank is not None and
811      isinstance(axis1, int) and isinstance(axis2, int)):
812    # This branch makes sure `perm` is statically known, to avoid a
813    # not-compile-time-constant XLA error.
814    a_rank = a.shape.rank
815    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
816    perm = list(range(a_rank))
817    perm[axis1] = axis2
818    perm[axis2] = axis1
819  else:
820    a_rank = array_ops.rank(a)
821    axis1, axis2 = adjust_axes((axis1, axis2), a_rank)
822    perm = math_ops.range(a_rank)
823    perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
824                                           [axis2, axis1])
825  a = array_ops.transpose(a, perm)
826  return a
827
828
829@np_utils.np_doc('moveaxis')
830def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
831  """Raises ValueError if source, destination not in (-ndim(a), ndim(a))."""
832  if not source and not destination:
833    return a
834
835  a = asarray(a)
836
837  if isinstance(source, int):
838    source = (source,)
839  if isinstance(destination, int):
840    destination = (destination,)
841  if len(source) != len(destination):
842    raise ValueError('The lengths of source and destination must equal')
843
844  a_rank = np_utils._maybe_static(array_ops.rank(a))  # pylint: disable=protected-access
845
846  def _correct_axis(axis, rank):
847    if axis < 0:
848      return axis + rank
849    return axis
850
851  source = tuple(_correct_axis(axis, a_rank) for axis in source)
852  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)
853
854  if a.shape.rank is not None:
855    perm = [i for i in range(a_rank) if i not in source]
856    for dest, src in sorted(zip(destination, source)):
857      assert dest <= len(perm)
858      perm.insert(dest, src)
859  else:
860    r = math_ops.range(a_rank)
861
862    def _remove_indices(a, b):
863      """Remove indices (`b`) from `a`."""
864      items = array_ops.unstack(sort_ops.sort(array_ops.stack(b)), num=len(b))
865
866      i = 0
867      result = []
868
869      for item in items:
870        result.append(a[i:item])
871        i = item + 1
872
873      result.append(a[i:])
874
875      return array_ops.concat(result, 0)
876
877    minus_sources = _remove_indices(r, source)
878    minus_dest = _remove_indices(r, destination)
879
880    perm = array_ops.scatter_nd(
881        array_ops.expand_dims(minus_dest, 1), minus_sources, [a_rank])
882    perm = array_ops.tensor_scatter_update(
883        perm, array_ops.expand_dims(destination, 1), source)
884  a = array_ops.transpose(a, perm)
885
886  return a
887
888
889@np_utils.np_doc('pad')
890def pad(array, pad_width, mode, **kwargs):  # pylint: disable=redefined-outer-name
891  """Only supports modes 'constant', 'reflect' and 'symmetric' currently."""
892  constant_values = kwargs.get('constant_values', 0)
893  if not (mode == 'constant' or mode == 'reflect' or mode == 'symmetric'):
894    raise ValueError('Unsupported padding mode: ' + mode)
895  mode = mode.upper()
896  array = asarray(array)
897  pad_width = asarray(pad_width, dtype=dtypes.int32)
898  return array_ops.pad(
899      tensor=array,
900      paddings=pad_width,
901      mode=mode,
902      constant_values=constant_values)
903
904
905@np_utils.np_doc('take')
906def take(a, indices, axis=None, out=None, mode='clip'):
907  """out argument is not supported, and default mode is clip."""
908  if out is not None:
909    raise ValueError('out argument is not supported in take.')
910
911  if mode not in {'raise', 'clip', 'wrap'}:
912    raise ValueError("Invalid mode '{}' for take".format(mode))
913
914  a = asarray(a)
915  indices = asarray(indices)
916
917  if axis is None:
918    a = array_ops.reshape(a, [-1])
919    axis = 0
920
921  axis_size = array_ops.shape(a, out_type=indices.dtype)[axis]
922  if mode == 'clip':
923    indices = clip_ops.clip_by_value(indices, 0, axis_size - 1)
924  elif mode == 'wrap':
925    indices = math_ops.floormod(indices, axis_size)
926  else:
927    raise ValueError("The 'raise' mode to take is not supported.")
928
929  return array_ops.gather(a, indices, axis=axis)
930
931
932@np_utils.np_doc_only('where')
933def where(condition, x=None, y=None):
934  """Raises ValueError if exactly one of x or y is not None."""
935  condition = asarray(condition, dtype=np.bool_)
936  if x is None and y is None:
937    return nonzero(condition)
938  elif x is not None and y is not None:
939    x, y = _promote_dtype(x, y)
940    return array_ops.where_v2(condition, x, y)
941  raise ValueError('Both x and y must be ndarrays, or both must be None.')
942
943
944@np_utils.np_doc('select')
945def select(condlist, choicelist, default=0):  # pylint: disable=missing-docstring
946  if len(condlist) != len(choicelist):
947    msg = 'condlist must have length equal to choicelist ({} vs {})'
948    raise ValueError(msg.format(len(condlist), len(choicelist)))
949  if not condlist:
950    raise ValueError('condlist must be non-empty')
951  choices = _promote_dtype(default, *choicelist)
952  choicelist = choices[1:]
953  output = choices[0]
954  # The traversal is in reverse order so we can return the first value in
955  # choicelist where condlist is True.
956  for cond, choice in zip(condlist[::-1], choicelist[::-1]):
957    output = where(cond, choice, output)
958  return output
959
960
961@np_utils.np_doc('shape', link=np_utils.Link(
962    'https://numpy.org/doc/1.18/reference/generated/numpy.shape.html'))
963def shape(a):
964  a = asarray(a)
965  return a.shape
966
967
968@np_utils.np_doc('ndim', link=np_utils.NoLink())
969def ndim(a):
970  a = asarray(a)
971  return a.ndim
972
973
974@np_utils.np_doc('isscalar')
975def isscalar(num):
976  return ndim(num) == 0
977
978
979def _boundaries_to_sizes(a, boundaries, axis):
980  """Converting boundaries of splits to sizes of splits.
981
982  Args:
983    a: the array to be split.
984    boundaries: the boundaries, as in np.split.
985    axis: the axis along which to split.
986
987  Returns:
988    A list of sizes of the splits, as in tf.split.
989  """
990  if axis >= len(a.shape):
991    raise ValueError('axis %s is out of bound for shape %s' % (axis, a.shape))
992  total_size = a.shape[axis]
993  sizes = []
994  sizes_sum = 0
995  prev = 0
996  for i, b in enumerate(boundaries):
997    size = b - prev
998    if size < 0:
999      raise ValueError('The %s-th boundary %s is smaller than the previous '
1000                       'boundary %s' % (i, b, prev))
1001    size = min(size, max(0, total_size - sizes_sum))
1002    sizes.append(size)
1003    sizes_sum += size
1004    prev = b
1005  sizes.append(max(0, total_size - sizes_sum))
1006  return sizes
1007
1008
1009@np_utils.np_doc('split')
1010def split(ary, indices_or_sections, axis=0):
1011  ary = asarray(ary)
1012  if not isinstance(indices_or_sections, six.integer_types):
1013    indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis)
1014  return array_ops.split(ary, indices_or_sections, axis=axis)
1015
1016
1017def _split_on_axis(np_fun_name, axis):
1018
1019  @np_utils.np_doc(np_fun_name)
1020  def f(ary, indices_or_sections):
1021    return split(ary, indices_or_sections, axis=axis)
1022
1023  return f
1024
1025
1026vsplit = _split_on_axis('vsplit', axis=0)
1027hsplit = _split_on_axis('hsplit', axis=1)
1028dsplit = _split_on_axis('dsplit', axis=2)
1029
1030
1031@np_utils.np_doc('broadcast_to')
1032def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
1033  return full(shape, array)
1034
1035
1036@np_utils.np_doc('stack')
1037def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
1038  if isinstance(arrays, (np_arrays.ndarray, ops.Tensor)):
1039    arrays = asarray(arrays)
1040    if axis == 0:
1041      return arrays
1042    else:
1043      return swapaxes(arrays, 0, axis)
1044  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
1045  unwrapped_arrays = [
1046      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
1047  ]
1048  return asarray(array_ops.stack(unwrapped_arrays, axis))
1049
1050
1051@np_utils.np_doc('hstack')
1052def hstack(tup):
1053  arrays = [atleast_1d(a) for a in tup]
1054  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
1055  unwrapped_arrays = [
1056      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
1057  ]
1058  rank = array_ops.rank(unwrapped_arrays[0])
1059  return np_utils.cond(
1060      math_ops.equal(rank,
1061                     1), lambda: array_ops.concat(unwrapped_arrays, axis=0),
1062      lambda: array_ops.concat(unwrapped_arrays, axis=1))
1063
1064
1065@np_utils.np_doc('vstack')
1066def vstack(tup):
1067  arrays = [atleast_2d(a) for a in tup]
1068  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
1069  unwrapped_arrays = [
1070      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
1071  ]
1072  return array_ops.concat(unwrapped_arrays, axis=0)
1073
1074
1075@np_utils.np_doc('dstack')
1076def dstack(tup):
1077  arrays = [atleast_3d(a) for a in tup]
1078  arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
1079  unwrapped_arrays = [
1080      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
1081  ]
1082  return array_ops.concat(unwrapped_arrays, axis=2)
1083
1084
1085def _pad_left_to(n, old_shape):
1086  old_shape = asarray(old_shape, dtype=np.int32)
1087  new_shape = array_ops.pad(
1088      old_shape, [[math_ops.maximum(n - array_ops.size(old_shape), 0), 0]],
1089      constant_values=1)
1090  return asarray(new_shape)
1091
1092
1093def _atleast_nd(n, new_shape, *arys):
1094  """Reshape arrays to be at least `n`-dimensional.
1095
1096  Args:
1097    n: The minimal rank.
1098    new_shape: a function that takes `n` and the old shape and returns the
1099      desired new shape.
1100    *arys: ndarray(s) to be reshaped.
1101
1102  Returns:
1103    The reshaped array(s).
1104  """
1105
1106  def f(x):
1107    # pylint: disable=g-long-lambda
1108    x = asarray(x)
1109    return asarray(
1110        np_utils.cond(
1111            np_utils.greater(n, array_ops.rank(x)),
1112            lambda: reshape(x, new_shape(n, array_ops.shape(x))),
1113            lambda: x))
1114
1115  arys = list(map(f, arys))
1116  if len(arys) == 1:
1117    return arys[0]
1118  else:
1119    return arys
1120
1121
1122@np_utils.np_doc('atleast_1d')
1123def atleast_1d(*arys):
1124  return _atleast_nd(1, _pad_left_to, *arys)
1125
1126
1127@np_utils.np_doc('atleast_2d')
1128def atleast_2d(*arys):
1129  return _atleast_nd(2, _pad_left_to, *arys)
1130
1131
1132@np_utils.np_doc('atleast_3d')
1133def atleast_3d(*arys):  # pylint: disable=missing-docstring
1134
1135  def new_shape(_, old_shape):
1136    # pylint: disable=g-long-lambda
1137    ndim_ = array_ops.size(old_shape)
1138    return np_utils.cond(
1139        math_ops.equal(ndim_, 0),
1140        lambda: constant_op.constant([1, 1, 1], dtype=dtypes.int32),
1141        lambda: np_utils.cond(
1142            math_ops.equal(ndim_, 1), lambda: array_ops.pad(
1143                old_shape, [[1, 1]], constant_values=1), lambda: array_ops.pad(
1144                    old_shape, [[0, 1]], constant_values=1)))
1145
1146  return _atleast_nd(3, new_shape, *arys)
1147
1148
1149@np_utils.np_doc('nonzero')
1150def nonzero(a):
1151  a = atleast_1d(a)
1152  if a.shape.rank is None:
1153    raise ValueError("The rank of `a` is unknown, so we can't decide how many "
1154                     'arrays to return.')
1155  return array_ops.unstack(
1156            array_ops.where_v2(math_ops.cast(a, dtypes.bool)),
1157            a.shape.rank,
1158            axis=1)
1159
1160
1161@np_utils.np_doc('diag_indices')
1162def diag_indices(n, ndim=2):  # pylint: disable=missing-docstring,redefined-outer-name
1163  if n < 0:
1164    raise ValueError(
1165        'n argument to diag_indices must be nonnegative, got {}'.format(n))
1166  if ndim < 0:
1167    raise ValueError(
1168        'ndim argument to diag_indices must be nonnegative, got {}'.format(
1169            ndim))
1170
1171  return (math_ops.range(n),) * ndim
1172
1173
1174@np_utils.np_doc('tri')
1175def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-docstring
1176  M = M if M is not None else N
1177  if dtype is not None:
1178    dtype = np_utils.result_type(dtype)
1179  else:
1180    dtype = np_dtypes.default_float_type()
1181
1182  if k < 0:
1183    lower = -k - 1
1184    if lower > N:
1185      r = array_ops.zeros([N, M], dtype)
1186    else:
1187      # Keep as tf bool, since we create an upper triangular matrix and invert
1188      # it.
1189      o = array_ops.ones([N, M], dtype=dtypes.bool)
1190      r = math_ops.cast(
1191          math_ops.logical_not(array_ops.matrix_band_part(o, lower, -1)), dtype)
1192  else:
1193    o = array_ops.ones([N, M], dtype)
1194    if k > M:
1195      r = o
1196    else:
1197      r = array_ops.matrix_band_part(o, -1, k)
1198  return r
1199
1200
1201@np_utils.np_doc('tril')
1202def tril(m, k=0):  # pylint: disable=missing-docstring
1203  m = asarray(m)
1204  if m.shape.ndims is None:
1205    raise ValueError('Argument to tril should have known rank')
1206  m_shape = m.shape.as_list()
1207
1208  if len(m_shape) < 2:
1209    raise ValueError('Argument to tril must have rank at least 2')
1210
1211  if m_shape[-1] is None or m_shape[-2] is None:
1212    raise ValueError('Currently, the last two dimensions of the input array '
1213                     'need to be known.')
1214
1215  z = constant_op.constant(0, m.dtype)
1216
1217  mask = tri(*m_shape[-2:], k=k, dtype=bool)
1218  return array_ops.where_v2(
1219      array_ops.broadcast_to(mask, array_ops.shape(m)), m, z)
1220
1221
1222@np_utils.np_doc('triu')
1223def triu(m, k=0):  # pylint: disable=missing-docstring
1224  m = asarray(m)
1225  if m.shape.ndims is None:
1226    raise ValueError('Argument to triu should have known rank')
1227  m_shape = m.shape.as_list()
1228
1229  if len(m_shape) < 2:
1230    raise ValueError('Argument to triu must have rank at least 2')
1231
1232  if m_shape[-1] is None or m_shape[-2] is None:
1233    raise ValueError('Currently, the last two dimensions of the input array '
1234                     'need to be known.')
1235
1236  z = constant_op.constant(0, m.dtype)
1237
1238  mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
1239  return array_ops.where_v2(
1240      array_ops.broadcast_to(mask, array_ops.shape(m)), z, m)
1241
1242
1243@np_utils.np_doc('flip')
1244def flip(m, axis=None):  # pylint: disable=missing-docstring
1245  m = asarray(m)
1246
1247  if axis is None:
1248    return array_ops.reverse(m, math_ops.range(array_ops.rank(m)))
1249
1250  axis = np_utils._canonicalize_axis(axis, array_ops.rank(m))  # pylint: disable=protected-access
1251
1252  return array_ops.reverse(m, [axis])
1253
1254
1255@np_utils.np_doc('flipud')
1256def flipud(m):  # pylint: disable=missing-docstring
1257  return flip(m, 0)
1258
1259
1260@np_utils.np_doc('fliplr')
1261def fliplr(m):  # pylint: disable=missing-docstring
1262  return flip(m, 1)
1263
1264
1265@np_utils.np_doc('roll')
1266def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
1267  a = asarray(a)
1268
1269  if axis is not None:
1270    return manip_ops.roll(a, shift, axis)
1271
1272  # If axis is None, the roll happens as a 1-d tensor.
1273  original_shape = array_ops.shape(a)
1274  a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
1275  return array_ops.reshape(a, original_shape)
1276
1277
1278@np_utils.np_doc('rot90')
1279def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
1280  m_rank = array_ops.rank(m)
1281  ax1, ax2 = np_utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access
1282
1283  k = k % 4
1284  if k == 0:
1285    return m
1286  elif k == 2:
1287    return flip(flip(m, ax1), ax2)
1288  else:
1289    perm = math_ops.range(m_rank)
1290    perm = array_ops.tensor_scatter_update(perm, [[ax1], [ax2]], [ax2, ax1])
1291
1292    if k == 1:
1293      return transpose(flip(m, ax2), perm)
1294    else:
1295      return flip(transpose(m, perm), ax2)
1296
1297
1298@np_utils.np_doc('vander')
1299def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,invalid-name
1300  x = asarray(x)
1301
1302  x_shape = array_ops.shape(x)
1303  N = N or x_shape[0]
1304
1305  N_temp = np_utils.get_static_value(N)  # pylint: disable=invalid-name
1306  if N_temp is not None:
1307    N = N_temp
1308    if N < 0:
1309      raise ValueError('N must be nonnegative')
1310  else:
1311    control_flow_ops.Assert(N >= 0, [N])
1312
1313  rank = array_ops.rank(x)
1314  rank_temp = np_utils.get_static_value(rank)
1315  if rank_temp is not None:
1316    rank = rank_temp
1317    if rank != 1:
1318      raise ValueError('x must be a one-dimensional array')
1319  else:
1320    control_flow_ops.Assert(math_ops.equal(rank, 1), [rank])
1321
1322  if increasing:
1323    start = 0
1324    limit = N
1325    delta = 1
1326  else:
1327    start = N - 1
1328    limit = -1
1329    delta = -1
1330
1331  x = array_ops.expand_dims(x, -1)
1332  return math_ops.pow(
1333      x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype))
1334
1335
1336@np_utils.np_doc('ix_')
1337def ix_(*args):  # pylint: disable=missing-docstring
1338  n = len(args)
1339  output = []
1340  for i, a in enumerate(args):
1341    a = asarray(a)
1342    a_rank = array_ops.rank(a)
1343    a_rank_temp = np_utils.get_static_value(a_rank)
1344    if a_rank_temp is not None:
1345      a_rank = a_rank_temp
1346      if a_rank != 1:
1347        raise ValueError('Arguments must be 1-d, got arg {} of rank {}'.format(
1348            i, a_rank))
1349    else:
1350      control_flow_ops.Assert(math_ops.equal(a_rank, 1), [a_rank])
1351
1352    new_shape = [1] * n
1353    new_shape[i] = -1
1354    dtype = a.dtype
1355    if dtype == dtypes.bool:
1356      output.append(array_ops.reshape(nonzero(a)[0], new_shape))
1357    elif dtype.is_integer:
1358      output.append(array_ops.reshape(a, new_shape))
1359    else:
1360      raise ValueError(
1361          'Only integer and bool dtypes are supported, got {}'.format(dtype))
1362
1363  return output
1364
1365
1366@np_utils.np_doc('broadcast_arrays')
1367def broadcast_arrays(*args, **kwargs):  # pylint: disable=missing-docstring
1368  subok = kwargs.pop('subok', False)
1369  if subok:
1370    raise ValueError('subok=True is not supported.')
1371  if kwargs:
1372    raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))
1373
1374  args = [asarray(arg) for arg in args]
1375  return np_utils.tf_broadcast(*args)
1376
1377
1378@np_utils.np_doc_only('sign')
1379def sign(x, out=None, where=None, **kwargs):  # pylint: disable=missing-docstring,redefined-outer-name
1380  if out:
1381    raise ValueError('tf.numpy doesnt support setting out.')
1382  if where:
1383    raise ValueError('tf.numpy doesnt support setting where.')
1384  if kwargs:
1385    raise ValueError('tf.numpy doesnt support setting {}'.format(kwargs.keys()))
1386
1387  x = asarray(x)
1388  dtype = x.dtype.as_numpy_dtype
1389  if np.issubdtype(dtype, np.complex):
1390    result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype)
1391  else:
1392    result = math_ops.sign(x)
1393
1394  return result
1395
1396
1397# Note that np.take_along_axis may not be present in some supported versions of
1398# numpy.
1399@np_utils.np_doc('take_along_axis')
1400def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
1401  arr = asarray(arr)
1402  indices = asarray(indices)
1403
1404  if axis is None:
1405    return take_along_axis(arr.ravel(), indices, 0)
1406
1407  rank = array_ops.rank(arr)
1408  axis = axis + rank if axis < 0 else axis
1409
1410  # Broadcast shapes to match, ensure that the axis of interest is not
1411  # broadcast.
1412  arr_shape_original = array_ops.shape(arr)
1413  indices_shape_original = array_ops.shape(indices)
1414  arr_shape = array_ops.tensor_scatter_update(arr_shape_original, [[axis]], [1])
1415  indices_shape = array_ops.tensor_scatter_update(indices_shape_original,
1416                                                  [[axis]], [1])
1417  broadcasted_shape = array_ops.broadcast_dynamic_shape(arr_shape,
1418                                                        indices_shape)
1419  arr_shape = array_ops.tensor_scatter_update(broadcasted_shape, [[axis]],
1420                                              [arr_shape_original[axis]])
1421  indices_shape = array_ops.tensor_scatter_update(
1422      broadcasted_shape, [[axis]], [indices_shape_original[axis]])
1423  arr = array_ops.broadcast_to(arr, arr_shape)
1424  indices = array_ops.broadcast_to(indices, indices_shape)
1425
1426  # Save indices shape so we can restore it later.
1427  possible_result_shape = indices.shape
1428
1429  # Correct indices since gather doesn't correctly handle negative indices.
1430  indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)
1431
1432  swapaxes_ = lambda t: swapaxes(t, axis, -1)
1433
1434  dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1))
1435  arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
1436                      lambda: swapaxes_(arr))
1437  indices = np_utils.cond(dont_move_axis_to_end, lambda: indices,
1438                          lambda: swapaxes_(indices))
1439
1440  arr_shape = array_ops.shape(arr)
1441  arr = array_ops.reshape(arr, [-1, arr_shape[-1]])
1442
1443  indices_shape = array_ops.shape(indices)
1444  indices = array_ops.reshape(indices, [-1, indices_shape[-1]])
1445
1446  result = array_ops.gather(arr, indices, batch_dims=1)
1447  result = array_ops.reshape(result, indices_shape)
1448  result = np_utils.cond(dont_move_axis_to_end, lambda: result,
1449                         lambda: swapaxes_(result))
1450  result.set_shape(possible_result_shape)
1451
1452  return result
1453
1454
1455_SLICE_ERORR = (
1456    'only integers, slices (`:`), ellipsis (`...`), '
1457    'numpy.newaxis (`None`) and integer or boolean arrays are valid indices')
1458
1459
1460def _as_index(idx, need_scalar=True):
1461  """Helper function to parse idx as an index.
1462
1463  Args:
1464    idx: index
1465    need_scalar: If idx needs to be a scalar value.
1466
1467  Returns:
1468    A pair, (indx, bool). First one is the parsed index and can be a tensor,
1469    or scalar integer / Dimension. Second one is True if rank is known to be 0.
1470
1471  Raises:
1472    IndexError: For incorrect indices.
1473  """
1474  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
1475    return idx, True
1476  data = asarray(idx)
1477  if data.dtype == dtypes.bool:
1478    if data.shape.ndims != 1:
1479      # TODO(agarwal): handle higher rank boolean masks.
1480      raise NotImplementedError('Need rank 1 for bool index %s' % idx)
1481    data = array_ops.where_v2(data)
1482    data = array_ops.reshape(data, [-1])
1483  if need_scalar and data.shape.rank not in (None, 0):
1484    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
1485  np_dtype = data.dtype.as_numpy_dtype
1486  if not np.issubdtype(np_dtype, np.integer):
1487    raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
1488  if data.dtype not in (dtypes.int64, dtypes.int32):
1489    # TF slicing can only handle int32/int64. So we need to cast.
1490    promoted_dtype = np.promote_types(np.int32, np_dtype)
1491    if promoted_dtype == np.int32:
1492      data = math_ops.cast(data, dtypes.int32)
1493    elif promoted_dtype == np.int64:
1494      data = math_ops.cast(data, dtypes.int64)
1495    else:
1496      raise IndexError(_SLICE_ERORR + ', got {!r}'.format(idx))
1497  return data, data.shape.rank == 0
1498
1499
1500class _UpdateMethod(enum.Enum):
1501  UPDATE = 0
1502  ADD = 1
1503  MIN = 2
1504  MAX = 3
1505
1506
1507def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
1508  """Helper function for __getitem__ and _with_index_update_helper.
1509
1510  This function collects the indices in `slice_spec` into two buckets, which we
1511  can call "idx1" and "idx2" here. idx1 is intended for `strided_slice`, idx2
1512  `gather`.  They also correspond to "basic indices" and "advanced indices" in
1513  numpy.  This function supports both reading and writing at the indices. The
1514  reading path can be summarized as `gather(stride_slice(tensor, idx1),
1515  idx2)`. The writing path can be summarized as `strided_slice_update(tensor,
1516  idx1, scatter(strided_slice(tensor, idx1), idx2, updates))`.  (`gather` here
1517  means `tf.gather` or `tf.gather_nd`; `scatter` here means
1518  `tf.tensor_scatter_update`.)  The writing path is inefficient because it needs
1519  to first read out a portion (probably much larger than `updates`) of `tensor`
1520  using `strided_slice`, update it, and then write the portion back. An
1521  alternative approach is to only use `scatter`, which amounts to using the
1522  indexing mechanism of gather/scatter to implement
1523  strided_slice/strided_slice_update. This is feasible for XLA Gather/Scatter
1524  because they support spans (e.g. `2:5`) in indices (as begin/end pairs), but
1525  not TF gather/scatter because they don't support spans (except those that
1526  cover entire dimensions, i.e. `:`).  If we materialize spans into individual
1527  indices, the size of the index tensor would explode.  (Note that XLA
1528  Gather/Scatter have a similar problem for stride > 1 because they don't
1529  support strides.  Indices such as `1:2:8` will need to be materialized into
1530  individual indices such as [1, 3, 5, 7].)
1531
1532  Args:
1533    tensor: the tensor to be read from or write into.
1534    slice_spec: the indices.
1535    update_method: (optional) a member of `_UpdateMethod`, indicating how to
1536      update the values (replacement, add, etc.). `None` indicates just reading.
1537    updates: (optional) the new values to write into `tensor`. It must have the
1538      same dtype as `tensor`.
1539
1540  Returns:
1541    The result of reading (if `update_method` is `None`) or the updated `tensor`
1542    after writing.
1543  """
1544  begin, end, strides = [], [], []
1545  new_axis_mask, shrink_axis_mask = 0, 0
1546  begin_mask, end_mask = 0, 0
1547  ellipsis_mask = 0
1548  advanced_indices = []
1549  shrink_indices = []
1550  for index, s in enumerate(slice_spec):
1551    if isinstance(s, slice):
1552      if s.start is not None:
1553        begin.append(_as_index(s.start)[0])
1554      else:
1555        begin.append(0)
1556        begin_mask |= (1 << index)
1557      if s.stop is not None:
1558        end.append(_as_index(s.stop)[0])
1559      else:
1560        end.append(0)
1561        end_mask |= (1 << index)
1562      if s.step is not None:
1563        strides.append(_as_index(s.step)[0])
1564      else:
1565        strides.append(1)
1566    elif s is Ellipsis:
1567      begin.append(0)
1568      end.append(0)
1569      strides.append(1)
1570      ellipsis_mask |= (1 << index)
1571    elif s is array_ops.newaxis:
1572      begin.append(0)
1573      end.append(0)
1574      strides.append(1)
1575      new_axis_mask |= (1 << index)
1576    else:
1577      s, is_scalar = _as_index(s, False)
1578      if is_scalar:
1579        begin.append(s)
1580        end.append(s + 1)
1581        strides.append(1)
1582        shrink_axis_mask |= (1 << index)
1583        shrink_indices.append(index)
1584      else:
1585        begin.append(0)
1586        end.append(0)
1587        strides.append(1)
1588        begin_mask |= (1 << index)
1589        end_mask |= (1 << index)
1590        advanced_indices.append((index, s, ellipsis_mask != 0))
1591
1592  # stack possibly involves no tensors, so we must use op_scope correct graph.
1593  with ops.name_scope(
1594      None,
1595      'strided_slice', [tensor] + begin + end + strides,
1596      skip_on_eager=False) as name:
1597    if begin:
1598      packed_begin, packed_end, packed_strides = (array_ops.stack(begin),
1599                                                  array_ops.stack(end),
1600                                                  array_ops.stack(strides))
1601      if (packed_begin.dtype == dtypes.int64 or
1602          packed_end.dtype == dtypes.int64 or
1603          packed_strides.dtype == dtypes.int64):
1604        if packed_begin.dtype != dtypes.int64:
1605          packed_begin = math_ops.cast(packed_begin, dtypes.int64)
1606        if packed_end.dtype != dtypes.int64:
1607          packed_end = math_ops.cast(packed_end, dtypes.int64)
1608        if packed_strides.dtype != dtypes.int64:
1609          packed_strides = math_ops.cast(packed_strides, dtypes.int64)
1610    else:
1611      var_empty = constant_op.constant([], dtype=dtypes.int32)
1612      packed_begin = packed_end = packed_strides = var_empty
1613    if update_method == _UpdateMethod.UPDATE and not advanced_indices:
1614      return array_ops.tensor_strided_slice_update(
1615          tensor,
1616          packed_begin,
1617          packed_end,
1618          packed_strides,
1619          updates,
1620          begin_mask=begin_mask,
1621          end_mask=end_mask,
1622          shrink_axis_mask=shrink_axis_mask,
1623          new_axis_mask=new_axis_mask,
1624          ellipsis_mask=ellipsis_mask,
1625          name=name)
1626    else:
1627      # TODO(b/164251540): Find a better way to support update that does not
1628      #   involve one read + two writes.
1629      if updates is not None:
1630        original_tensor = tensor
1631      # TODO(agarwal): set_shape on tensor to set rank.
1632      tensor = array_ops.strided_slice(
1633          tensor,
1634          packed_begin,
1635          packed_end,
1636          packed_strides,
1637          begin_mask=begin_mask,
1638          end_mask=end_mask,
1639          shrink_axis_mask=shrink_axis_mask,
1640          new_axis_mask=new_axis_mask,
1641          ellipsis_mask=ellipsis_mask,
1642          name=name)
1643    if not advanced_indices:
1644      if update_method is None:
1645        return tensor
1646      assert update_method != _UpdateMethod.UPDATE
1647      # TF lacks TensorStridedSliceAdd and alike, so we need to do
1648      # read+add+update.
1649      if update_method == _UpdateMethod.ADD:
1650        update_op = math_ops.add
1651      elif update_method == _UpdateMethod.MIN:
1652        update_op = math_ops.minimum
1653      elif update_method == _UpdateMethod.MAX:
1654        update_op = math_ops.maximum
1655      return array_ops.tensor_strided_slice_update(
1656          original_tensor,
1657          packed_begin,
1658          packed_end,
1659          packed_strides,
1660          update_op(tensor, updates),
1661          begin_mask=begin_mask,
1662          end_mask=end_mask,
1663          shrink_axis_mask=shrink_axis_mask,
1664          new_axis_mask=new_axis_mask,
1665          ellipsis_mask=ellipsis_mask,
1666          name=name + '_2')
1667    advanced_indices_map = {}
1668    for index, data, had_ellipsis in advanced_indices:
1669      if had_ellipsis:
1670        num_shrink = len([x for x in shrink_indices if x > index])
1671        dim = index - len(slice_spec) + num_shrink
1672      else:
1673        num_shrink = len([x for x in shrink_indices if x < index])
1674        dim = index - num_shrink
1675      advanced_indices_map[dim] = data
1676    dims = sorted(advanced_indices_map.keys())
1677    dims_contiguous = True
1678    if len(dims) > 1:
1679      if dims[0] < 0 and dims[-1] >= 0:  # not all same sign
1680        dims_contiguous = False
1681      else:
1682        for i in range(len(dims) - 1):
1683          if dims[i] + 1 != dims[i + 1]:
1684            dims_contiguous = False
1685            break
1686    indices = [advanced_indices_map[x] for x in dims]
1687    indices = _promote_dtype(*indices)
1688    indices = np_utils.tf_broadcast(*indices)
1689    stacked_indices = array_ops.stack(indices, axis=-1)
1690    # Skip the contiguous-dims optimization for update because there is no
1691    # tf.*scatter* op that supports the `axis` argument.
1692    if not dims_contiguous or updates is not None:
1693      if range(len(dims)) != dims:
1694        tensor = moveaxis(tensor, dims, range(len(dims)))
1695      tensor_shape_prefix = array_ops.shape(
1696          tensor, out_type=stacked_indices.dtype)[:len(dims)]
1697      stacked_indices = array_ops.where_v2(
1698          stacked_indices < 0, stacked_indices + tensor_shape_prefix,
1699          stacked_indices)
1700      if updates is None:
1701        return array_ops.gather_nd(tensor, stacked_indices)
1702      else:
1703        # We only need to move-axis `updates` in the contiguous case becausce
1704        # only in this case the result dimensions of advanced indexing are in
1705        # the middle of `updates`. In the non-contiguous case, those dimensions
1706        # are always at the front.
1707        if dims_contiguous:
1708          # TODO(wangpeng): Support unknown rank (e.g. by partially flattening
1709          #   `updates`)
1710          if stacked_indices.shape.rank is None:
1711            raise NotImplementedError(
1712                'Rank of the advanced indices must currently be known')
1713          batch_size = stacked_indices.shape.rank - 1
1714          batch_start = dims[0]
1715          if batch_start < 0:
1716            batch_start += len(dims) - batch_size
1717          def range_(start, length):
1718            return range(start, start + length)
1719          updates = moveaxis(updates, range_(batch_start, batch_size),
1720                             range(batch_size))
1721        if update_method == _UpdateMethod.UPDATE:
1722          update_op = array_ops.tensor_scatter_update
1723        elif update_method == _UpdateMethod.ADD:
1724          update_op = array_ops.tensor_scatter_add
1725        elif update_method == _UpdateMethod.MIN:
1726          update_op = array_ops.tensor_scatter_min
1727        elif update_method == _UpdateMethod.MAX:
1728          update_op = array_ops.tensor_scatter_max
1729        tensor = update_op(
1730            tensor, stacked_indices, updates)
1731        if range(len(dims)) != dims:
1732          tensor = moveaxis(tensor, range(len(dims)), dims)
1733        return array_ops.tensor_strided_slice_update(
1734            original_tensor,
1735            packed_begin,
1736            packed_end,
1737            packed_strides,
1738            tensor,
1739            begin_mask=begin_mask,
1740            end_mask=end_mask,
1741            shrink_axis_mask=shrink_axis_mask,
1742            new_axis_mask=new_axis_mask,
1743            ellipsis_mask=ellipsis_mask,
1744            name=name + '_2')
1745    # Note that gather_nd does not support gathering from inside the array.
1746    # To avoid shuffling data back and forth, we transform the indices and
1747    # do a gather instead.
1748    rank = np_utils._maybe_static(array_ops.rank(tensor))  # pylint: disable=protected-access
1749    dims = [(x + rank if x < 0 else x) for x in dims]
1750    shape_tensor = array_ops.shape(tensor)
1751    dim_sizes = array_ops.gather(shape_tensor, dims)
1752    if len(dims) == 1:
1753      stacked_indices = indices[0]
1754    stacked_indices = math_ops.cast(stacked_indices, dtypes.int32)
1755    stacked_indices = array_ops.where_v2(stacked_indices < 0,
1756                                         stacked_indices + dim_sizes,
1757                                         stacked_indices)
1758    axis = dims[0]
1759    if len(dims) > 1:
1760      index_scaling = math_ops.cumprod(
1761          dim_sizes, reverse=True, exclusive=True)
1762      def _tensordot(a, b):
1763        # TODO(b/168657656): This function should be replaced by
1764        # tensordot(axis=1) once MatMul has int32 XLA kernel.
1765        b = array_ops.broadcast_to(b, array_ops.shape(a))
1766        return math_ops.reduce_sum(a * b, axis=-1)
1767      stacked_indices = _tensordot(stacked_indices, index_scaling)
1768      flat_shape = array_ops.concat(
1769          [shape_tensor[:axis], [-1], shape_tensor[axis + len(dims):]],
1770          axis=0)
1771      tensor = array_ops.reshape(tensor, flat_shape)
1772
1773    return array_ops.gather(tensor, stacked_indices, axis=axis)
1774
1775
1776def _as_spec_tuple(slice_spec):
1777  """Convert slice_spec to tuple."""
1778  if isinstance(slice_spec,
1779                (list, tuple)) and not isinstance(slice_spec, np.ndarray):
1780    is_index = True
1781    for s in slice_spec:
1782      if s is None or s is Ellipsis or isinstance(s, (list, tuple, slice)):
1783        is_index = False
1784        break
1785      elif isinstance(s, (np_arrays.ndarray, np.ndarray)) and s.ndim != 0:
1786        is_index = False
1787        break
1788    if not is_index:
1789      return tuple(slice_spec)
1790  return (slice_spec,)
1791
1792
1793def _getitem(self, slice_spec):
1794  """Implementation of ndarray.__getitem__."""
1795  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
1796                                       slice_spec.dtype == dtypes.bool) or
1797      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
1798       slice_spec.dtype == np.bool)):
1799    return array_ops.boolean_mask(tensor=self, mask=slice_spec)
1800
1801  if not isinstance(slice_spec, tuple):
1802    slice_spec = _as_spec_tuple(slice_spec)
1803
1804  result_t = _slice_helper(self, slice_spec)
1805  return result_t
1806
1807
1808def _with_index_update_helper(update_method, a, slice_spec, updates):
1809  """Implementation of ndarray._with_index_*."""
1810  if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
1811                                       slice_spec.dtype == dtypes.bool) or
1812      (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
1813       slice_spec.dtype == np.bool)):
1814    slice_spec = nonzero(slice_spec)
1815
1816  if not isinstance(slice_spec, tuple):
1817    slice_spec = _as_spec_tuple(slice_spec)
1818
1819  a_dtype = a.dtype
1820  a, updates = _promote_dtype_binary(a, updates)
1821  result_t = _slice_helper(a, slice_spec, update_method, updates)
1822  return result_t.astype(a_dtype)
1823
1824
1825setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem)
1826setattr(np_arrays.ndarray, '_with_index_update',
1827        functools.partial(_with_index_update_helper, _UpdateMethod.UPDATE))
1828setattr(np_arrays.ndarray, '_with_index_add',
1829        functools.partial(_with_index_update_helper, _UpdateMethod.ADD))
1830setattr(np_arrays.ndarray, '_with_index_min',
1831        functools.partial(_with_index_update_helper, _UpdateMethod.MIN))
1832setattr(np_arrays.ndarray, '_with_index_max',
1833        functools.partial(_with_index_update_helper, _UpdateMethod.MAX))
1834