1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for internal use."""
16# pylint: disable=g-direct-tensorflow-import
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import inspect
23import numbers
24import os
25import re
26import numpy as np
27
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import indexed_slices
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops.numpy_ops import np_arrays
35from tensorflow.python.ops.numpy_ops import np_dtypes
36from tensorflow.python.ops.numpy_ops import np_export
37from tensorflow.python.types import core
38from tensorflow.python.util import nest
39
40
def _canonicalize_axis(axis, rank):
  """Single-axis form of `_canonicalize_axes`.

  Args:
    axis: one axis index, possibly negative.
    rank: the rank to resolve negative indices against.

  Returns:
    `axis + rank` if `axis` is negative, otherwise `axis`.
  """
  (canonical,) = _canonicalize_axes([axis], rank)
  return canonical
43
44
def _canonicalize_axes(axes, rank):
  """Resolves negative axis indices against `rank`.

  Each negative entry of `axes` has `rank` added to it; non-negative entries
  are returned unchanged. No bounds checking is performed.

  Args:
    axes: an iterable of axis indices.
    rank: the rank, either a static value or a `core.Tensor`.

  Returns:
    A list with one canonicalized entry per element of `axes`.
  """
  rank = _maybe_static(rank)

  if isinstance(rank, core.Tensor):
    # The rank is not statically known, so the sign test goes through `cond`,
    # which falls back to a TF conditional when `axis < 0` isn't static.
    def fix(axis):
      return cond(axis < 0, lambda: axis + rank, lambda: axis)
  else:
    def fix(axis):
      return axis + rank if axis < 0 else axis

  return list(map(fix, axes))
55
56
def _supports_signature():
  """Returns True iff the running `inspect` module provides `signature`."""
  has_signature = hasattr(inspect, 'signature')
  return has_signature
59
60
def _to_tf_type(dtype):
  """Returns the TensorFlow `DType` corresponding to `dtype`.

  Args:
    dtype: a Python type, a NumPy type, or a TF `DType`; anything accepted by
      `dtypes.as_dtype`.

  Returns:
    A tensorflow `DType`.
  """
  as_tf_dtype = dtypes.as_dtype
  return as_tf_dtype(dtype)
71
72
def _to_numpy_type(dtype):
  """Returns the NumPy dtype corresponding to `dtype`.

  TF `DType`s are mapped through their `as_numpy_dtype` attribute; anything
  else (Python or NumPy types) is passed to `np.dtype`.

  Args:
    dtype: a Python type, a NumPy type, or a TF `DType`.

  Returns:
    A NumPy `dtype`.
  """
  return (dtype.as_numpy_dtype if isinstance(dtype, dtypes.DType)
          else np.dtype(dtype))
85
86
def isscalar(val):
  """Tells whether `val` is a scalar value or a rank-0 Tensor.

  Args:
    val: a tf-numpy `ndarray`, a Tensor, or any other value.

  Returns:
    A Python bool, except for a Tensor of statically unknown rank, where a
    boolean Tensor (`rank(val) == 0`) is returned.
  """
  if isinstance(val, np_arrays.ndarray):
    val = val.data  # Unwrap to the underlying tensor.
  if not isinstance(val, core.Tensor):
    return np.isscalar(val)
  static_ndims = val.shape.ndims
  if static_ndims is None:
    # Rank unknown statically: answer with a boolean tensor instead.
    return math_ops.equal(array_ops.rank(val), 0)
  return static_ndims == 0
99
100
def _has_docstring(f):
  """Returns a truthy value iff `f` is truthy and has a non-empty str docstring.

  Mirrors an `and`-chain: returns `f` itself when `f` is falsy, False when it
  lacks a string `__doc__`, and otherwise the docstring (which may be '').
  """
  if not f:
    return f
  if not hasattr(f, '__doc__'):
    return False
  doc = f.__doc__
  if not isinstance(doc, str):
    return False
  return doc
104
105
def _add_blank_line(s):
  """Terminates `s` with exactly one extra newline beyond its last line.

  A string already ending in '\\n' gains one more newline; otherwise two are
  appended, so the result always ends in a blank line.
  """
  suffix = '\n' if s.endswith('\n') else '\n\n'
  return s + suffix
111
112
def _np_signature(f):
  """Like `inspect.signature`, but also synthesizes signatures for `np.ufunc`s.

  For a ufunc the signature is built from `f.nin` and `f.nout`. It has
  positional-only inputs (`x`, or `x1`, `x2`, ...). When there are several
  outputs, these are followed by positional-only outputs `out1`, `out2`, ...,
  each defaulting to None. Then comes an `out` parameter and the usual ufunc
  keyword-only arguments.

  Args:
    f: a callable, a `np.ufunc`, or None.

  Returns:
    An `inspect.Signature`, or None when `f` is None, `inspect.signature` is
    unavailable, or `inspect.signature(f)` raises ValueError.
  """
  # TODO(wangpeng): consider migrating away from inspect.signature.
  # inspect.signature is supported in Python 3.3.
  if f is None or not hasattr(inspect, 'signature'):
    return None
  if not isinstance(f, np.ufunc):
    try:
      return inspect.signature(f)
    except ValueError:
      return None

  def numbered(prefix, count):
    # A single operand keeps the bare prefix; several get 1-based suffixes.
    if count <= 0:
      return []
    if count == 1:
      return [prefix]
    return ['%s%d' % (prefix, i) for i in range(1, count + 1)]

  param_cls = inspect.Parameter
  params = [param_cls(name, param_cls.POSITIONAL_ONLY)
            for name in numbered('x', f.nin)]
  if f.nout > 1:
    params.extend(
        param_cls(name, param_cls.POSITIONAL_ONLY, default=None)
        for name in numbered('out', f.nout))
  params.append(
      param_cls('out', param_cls.POSITIONAL_OR_KEYWORD,
                default=None if f.nout == 1 else (None,) * f.nout))
  ufunc_keywords = (('where', True), ('casting', 'same_kind'), ('order', 'K'),
                    ('dtype', None), ('subok', True), ('signature', None),
                    ('extobj', None))
  params.extend(param_cls(name, param_cls.KEYWORD_ONLY, default=default)
                for name, default in ufunc_keywords)
  return inspect.Signature(params)
162
163
def _is_compatible_param_kind(a, b):
  """Returns whether two `inspect.Parameter` kinds are treated as compatible.

  Python 2 doesn't allow keyword-only arguments, and Python before 3.8 doesn't
  allow positional-only arguments. So positional-only, keyword-only and
  positional-or-keyword are all conflated here. Other kinds (e.g. `*args`,
  `**kwargs`) must match exactly.

  Args:
    a: an `inspect.Parameter` kind.
    b: an `inspect.Parameter` kind.

  Returns:
    True if the kinds are equal after the relaxation above.
  """
  relaxable = (inspect.Parameter.POSITIONAL_ONLY,
               inspect.Parameter.KEYWORD_ONLY)
  relaxed_a = inspect.Parameter.POSITIONAL_OR_KEYWORD if a in relaxable else a
  relaxed_b = inspect.Parameter.POSITIONAL_OR_KEYWORD if b in relaxable else b
  return relaxed_a == relaxed_b
175
176
def _prepare_np_fun_name_and_fun(np_fun_name, np_fun):
  """Fills in whichever of `np_fun_name` / `np_fun` is missing, best effort.

  A missing `np_fun` is looked up as `numpy.<np_fun_name>`; it stays None if
  numpy has no such attribute. A missing `np_fun_name` is taken from
  `np_fun.__name__`.

  Args:
    np_fun_name: name for the np_fun symbol (a str) or None. At least one of
      np_fun or np_fun_name should be set.
    np_fun: the numpy function whose docstring will be used, or None. Must not
      be a str.

  Returns:
    The processed `(np_fun_name, np_fun)` pair.
  """
  assert np_fun_name is None or isinstance(np_fun_name, str)
  assert np_fun is None or not isinstance(np_fun, str)
  if np_fun is None:
    assert np_fun_name is not None
    np_fun = getattr(np, str(np_fun_name), None)
  if np_fun_name is None:
    assert np_fun is not None
    np_fun_name = np_fun.__name__
  return np_fun_name, np_fun
205
206
def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None,
                   link=None):
  """Builds the docstring of a tf-numpy function.

  The result has a one-line header, an optional list of unsupported
  arguments, `f`'s own docstring (if any), and then the numpy documentation
  reference produced by `_add_np_doc`.

  Args:
    f: the tf-numpy function being documented.
    np_f: the numpy counterpart, or None.
    np_fun_name: the numpy function name; defaults to `np_f.__name__`.
    unsupported_params: names of numpy parameters that `f` doesn't support.
    link: which link to use; see `np_doc`.

  Returns:
    The docstring as a str.
  """
  assert np_f or np_fun_name
  np_fun_name = np_fun_name or np_f.__name__
  pieces = ['TensorFlow variant of NumPy\'s `%s`.\n\n' % np_fun_name]
  if unsupported_params:
    quoted = ', '.join('`' + name + '`' for name in unsupported_params)
    pieces.append('Unsupported arguments: ' + quoted + '.\n\n')
  doc = ''.join(pieces)
  if _has_docstring(f):
    doc = _add_blank_line(doc + f.__doc__)
  # `_add_np_doc` chooses between inlining and linking via `set_np_doc_form`.
  return _add_np_doc(doc, np_fun_name, np_f, link=link)
224
225
# How tf-numpy docstrings refer to the original numpy docstrings. It is
# initialized from the TF_NP_DOC_FORM environment variable (default '1.16')
# and can be changed later via `set_np_doc_form`, which lists the accepted
# values.
_np_doc_form = os.getenv('TF_NP_DOC_FORM', '1.16')
227
228
def get_np_doc_form():
  """Returns the current numpy-docstring form setting.

  Returns:
    The value last stored by `set_np_doc_form` (or read from the
    TF_NP_DOC_FORM environment variable); see `set_np_doc_form` for the valid
    values.
  """
  current_form = _np_doc_form
  return current_form
236
237
def set_np_doc_form(value):
  r"""Chooses how tf-numpy docstrings refer to the original numpy docstrings.

  The value is stored in a module-level global that is consulted whenever a
  tf-numpy docstring is generated. With `'inlined'`, the numpy docstring is
  copied verbatim into the tf-numpy docstring. Any other value adds a link to
  the numpy docs instead, and the targeted numpy version is picked by `value`:
  * `'stable'`: the current stable version;
  * `'dev'`: the current development version;
  * pattern `\d+(\.\d+(\.\d+)?)?`: `value` is taken as a version number,
    e.g. '1.16'.

  Args:
    value: the new setting.
  """
  global _np_doc_form
  _np_doc_form = value
256
257
class Link:
  """A `link` argument for `np_doc` that supplies the complete URL verbatim.

  Attributes:
    value: the URL to use as the whole link.
  """

  def __init__(self, v):
    self.value = v
262
263
class AliasOf:
  """A `link` argument for `np_doc` naming the numpy function to link to.

  The generated link uses `value` in place of the documented function's name.

  Attributes:
    value: the numpy function name to generate the link from.
  """

  def __init__(self, v):
    self.value = v
268
269
class NoLink:
  """A `link` argument for `np_doc` meaning that no numpy link is added."""
272
273
def generate_link(flag, np_fun_name):
  """Generates a link to the NumPy documentation page of `np_fun_name`.

  Args:
    flag: the flag to control link form. See `set_np_doc_form`.
      - `'dev'` links to the development docs.
      - `'stable'` links to the stable docs.
      - A version number of one to three dot-separated ASCII-digit components
        (e.g. `'1.16'`) links to the docs of that NumPy version.
    np_fun_name: the numpy function name.

  Returns:
    The URL as a string, or None if `flag` has none of the forms above (e.g.
    `'inlined'`).
  """
  if flag == 'dev':
    template = 'https://numpy.org/devdocs/reference/generated/numpy.%s.html'
  elif flag == 'stable':
    template = (
        'https://numpy.org/doc/stable/reference/generated/numpy.%s.html')
  elif re.match(r'[0-9]+(\.[0-9]+(\.[0-9]+)?)?\Z', flag):
    # `flag` is the version number. `\Z` (not `$`) anchors at the true end of
    # the string: `$` would also accept a trailing newline (e.g. from an
    # environment variable) and splice it into the URL. `[0-9]` keeps
    # non-ASCII Unicode digits out.
    template = ('https://numpy.org/doc/' + flag +
                '/reference/generated/numpy.%s.html')
  else:
    return None
  return template % np_fun_name
297
298
# Whether `_add_np_doc` verifies each generated link with an HTTP HEAD request.
# Enabled when the TF_NP_CHECK_LINK environment variable is 'True', 'true' or
# '1'; can be changed later via `set_check_link`.
_is_check_link = (os.getenv('TF_NP_CHECK_LINK', 'False') in
                  ('True', 'true', '1'))
301
302
def is_check_link():
  """Returns the link-checking flag last set via `set_check_link` (or env)."""
  flag = _is_check_link
  return flag
305
306
def set_check_link(value):
  """Turns verification of generated numpy links on (truthy) or off.

  Args:
    value: stored as-is in the module-level `_is_check_link` flag.
  """
  global _is_check_link
  _is_check_link = value
310
311
def _add_np_doc(doc, np_fun_name, np_f, link):
  """Appends the numpy docstring to `doc`, according to `set_np_doc_form`.

  See `set_np_doc_form` for how it controls the form of the numpy docstring.

  Args:
    doc: the docstring to be appended to.
    np_fun_name: the name of the numpy function.
    np_f: (optional) the numpy function.
    link: (optional) which link to use. See `np_doc` for details.

  Returns:
    `doc` with numpy docstring appended.

  Raises:
    ValueError: if link checking is enabled (see `set_check_link`) and the
      HTTP HEAD request to the link does not end with status 200. Network
      errors (including timeouts) raised by `requests` propagate unchanged.
  """
  flag = get_np_doc_form()
  if flag == 'inlined':
    if _has_docstring(np_f):
      doc += 'Documentation for `numpy.%s`:\n\n' % np_fun_name
      # TODO(wangpeng): It looks like code snippets in numpy doc don't work
      # correctly with doctest. Fix that and remove the reformatting of the np_f
      # comment.
      doc += np_f.__doc__.replace('>>>', '>')
  elif isinstance(flag, str):
    if link is None:
      url = generate_link(flag, np_fun_name)
    elif isinstance(link, AliasOf):
      url = generate_link(flag, link.value)
    elif isinstance(link, Link):
      url = link.value
    else:
      # `NoLink` (or anything unrecognized): add no link.
      url = None
    if url is not None:
      if is_check_link():
        # Imports locally because some builds may not have `requests`
        import requests  # pylint: disable=g-import-not-at-top
        # `requests.head` neither follows redirects nor times out by default.
        # Follow redirects so a moved-but-valid page passes, and bound the
        # wait so an unresponsive server can't block docstring generation.
        timeout_secs = 30
        r = requests.head(url, allow_redirects=True, timeout=timeout_secs)
        if r.status_code != 200:
          raise ValueError("Can't open link for %s: %s" % (np_fun_name, url))
      doc += 'See the NumPy documentation for [`numpy.%s`](%s).' % (
          np_fun_name, url)
  return doc
353
354
# Whether `np_doc` raises TypeError when a decorated function's signature
# disagrees with its numpy counterpart's; when False, such mismatches are not
# reported. Enabled when TF_NP_SIG_MISMATCH_IS_ERROR is 'True', 'true' or '1';
# can be changed later via `set_is_sig_mismatch_an_error`.
_is_sig_mismatch_an_error = (
    os.getenv('TF_NP_SIG_MISMATCH_IS_ERROR', 'False') in ('True', 'true', '1'))
357
358
def is_sig_mismatch_an_error():
  """Returns whether `np_doc` treats signature mismatches as errors."""
  strict = _is_sig_mismatch_an_error
  return strict
361
362
def set_is_sig_mismatch_an_error(value):
  """Sets whether `np_doc` raises on signature mismatches.

  Args:
    value: stored as-is in the module-level `_is_sig_mismatch_an_error` flag.
  """
  global _is_sig_mismatch_an_error
  _is_sig_mismatch_an_error = value
366
367
def np_doc(np_fun_name, np_fun=None, export=True, link=None):
  """Attaches numpy docstring to a function.

  When the numpy function's signature is available, the decorated function's
  signature is compared against it. Numpy parameters that the decorated
  function lacks are listed as unsupported in the generated docstring. If
  `is_sig_mismatch_an_error()` is true, the returned decorator also raises
  `TypeError` in three cases:
  - a parameter of the decorated function is missing from the numpy
    signature;
  - a parameter has an incompatible kind;
  - a parameter disagrees with numpy about having a default value.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be
      a function directly under the `numpy` module, not under any submodule of
      `numpy` (e.g. `numpy.random`).
    link: (optional) which link to use. If `None`, a default link generated from
      `np_fun_name` will be used. If an instance of `AliasOf`, `link.value` will
      be used in place of `np_fun_name` for the link generation. If an instance
      of `Link`, `link.value` will be used as the whole link. If an instance of
      `NoLink`, no link will be added.

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.
  """
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)
  np_sig = _np_signature(np_fun)

  def decorator(f):
    """The decorator."""
    unsupported_params = []
    # `_np_signature` returns None when `inspect.signature` is unavailable, so
    # a non-None `np_sig` already implies signature support.
    if np_sig is not None:
      try:
        sig = inspect.signature(f)
      except ValueError:
        sig = None
      if sig is not None:
        strict = is_sig_mismatch_an_error()
        for name, param in sig.parameters.items():
          np_param = np_sig.parameters.get(name)
          if np_param is None:
            if strict:
              raise TypeError(
                  'Cannot find parameter "%s" in the numpy function\'s '
                  'signature (which has these parameters: %s)' %
                  (name, list(np_sig.parameters.keys())))
            continue
          if not strict:
            # The remaining checks only matter when mismatches are errors.
            continue
          if not _is_compatible_param_kind(param.kind, np_param.kind):
            raise TypeError(
                'Parameter "%s" is of kind %s while in numpy it is of '
                'kind %s' % (name, param.kind, np_param.kind))
          # Compare with the `empty` sentinel by identity: `!=` would invoke
          # the default value's own `__ne__`, which for array-like defaults
          # returns an array whose truth value is ambiguous.
          has_default = param.default is not inspect.Parameter.empty
          np_has_default = np_param.default is not inspect.Parameter.empty
          if has_default != np_has_default:
            raise TypeError('Parameter "%s" should%s have a default value' %
                            (name, '' if np_has_default else ' not'))
        unsupported_params = [
            name for name in np_sig.parameters if name not in sig.parameters
        ]
    f.__doc__ = _np_doc_helper(
        f, np_fun, np_fun_name=np_fun_name,
        unsupported_params=unsupported_params, link=link)
    if export:
      return np_export.np_export(np_fun_name)(f)
    return f

  return decorator
433
434
def np_doc_only(np_fun_name, np_fun=None, export=True, link=None):
  """Attaches numpy docstring to a function.

  This differs from np_doc in that it doesn't check for a match in signature,
  so no unsupported arguments are listed in the generated docstring.

  Args:
    np_fun_name: name for the np_fun symbol. At least one of np_fun or
      np_fun_name should be set.
    np_fun: (optional) the numpy function whose docstring will be used.
    export: whether to export this symbol under module
      `tf.experimental.numpy`. Note that if `export` is `True`, `np_fun` must be
      a function directly under the `numpy` module, not under any submodule of
      `numpy` (e.g. `numpy.random`).
    link: (optional) which link to use, with the same meaning as the `link`
      argument of `np_doc` (`None`, `AliasOf`, `Link` or `NoLink`). The default
      `None` keeps the link generated from `np_fun_name`.

  Returns:
    A function decorator that attaches the docstring from `np_fun` to the
    decorated function.
  """
  np_fun_name, np_fun = _prepare_np_fun_name_and_fun(np_fun_name, np_fun)

  def decorator(f):
    f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name, link=link)
    if export:
      return np_export.np_export(np_fun_name)(f)
    return f

  return decorator
463
464
# pylint: disable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
@np_doc('finfo')
def finfo(dtype):
  """Note that currently it just forwards to the numpy namesake, while
  tensorflow and numpy dtypes may have different properties."""
  numpy_dtype = _to_numpy_type(dtype)
  return np.finfo(numpy_dtype)
# pylint: enable=g-short-docstring-punctuation,g-no-space-after-docstring-summary,g-docstring-missing-newline,g-doc-return-or-yield,g-doc-args
472
473
def _maybe_get_dtype(x):
  """Maps `x` to something `np.result_type` can consume by dtype alone.

  Real numbers are returned unchanged, and so is anything not listed below.
  Tensors and `IndexedSlices` map to the numpy version of their dtype, and TF
  `DType`s to their numpy dtype. numpy ndarrays pass through unchanged.

  Raises:
    ValueError: if `x` is a list or tuple.
  """
  # np.ndarray is intentionally not converted: np.result_type looks at the
  # value (not just the dtype) of an np.ndarray to decide the result type.
  if isinstance(x, numbers.Real):
    result = x
  elif isinstance(x, (core.Tensor, indexed_slices.IndexedSlices)):
    result = _to_numpy_type(x.dtype)
  elif isinstance(x, dtypes.DType):
    result = x.as_numpy_dtype
  elif isinstance(x, (list, tuple)):
    raise ValueError('Got sequence')
  else:
    result = x
  return result
487
488
# Can't use np_doc because np.result_type is a builtin function.
@np_doc_only('result_type')
def result_type(*arrays_and_dtypes):  # pylint: disable=missing-function-docstring
  # Flatten nested structures, then reduce each item to a dtype where possible.
  candidates = [_maybe_get_dtype(item)
                for item in nest.flatten(arrays_and_dtypes)]
  if not candidates:
    # Nothing was given: let numpy decide, based on an empty array.
    candidates = [np.asarray([])]
  return np_dtypes._result_type(*candidates)  # pylint: disable=protected-access
499
500
def _result_type_binary(t1, t2):
  """Two-argument fast path of `result_type`.

  Converts both operands with `_maybe_get_dtype` and combines them directly.
  If any step raises ValueError (e.g. an operand is a list or tuple), it falls
  back to the general `result_type`.
  """
  try:
    dtype1 = _maybe_get_dtype(t1)
    dtype2 = _maybe_get_dtype(t2)
    return np_dtypes._result_type(dtype1, dtype2)  # pylint: disable=protected-access
  except ValueError:
    return result_type(t1, t2)
508
509
@np_doc('promote_types')
def promote_types(type1, type2):  # pylint: disable=missing-function-docstring
  # Convert both inputs to numpy dtypes, promote with numpy, then canonicalize
  # the promoted dtype for tf-numpy.
  promoted = np.promote_types(_to_numpy_type(type1), _to_numpy_type(type2))
  return np_dtypes.canonicalize_dtype(promoted)
515
516
def tf_broadcast(*args):
  """Broadcasts tensors against each other.

  Args:
    *args: tensors whose shapes are mutually broadcastable.

  Returns:
    With fewer than two arguments, `args` itself (the tuple, unchanged).
    Otherwise, a list with every argument broadcast to the common dynamic
    shape.
  """
  if len(args) < 2:
    return args
  target_shape = array_ops.shape(args[0])
  for other in args[1:]:
    target_shape = array_ops.broadcast_dynamic_shape(
        target_shape, array_ops.shape(other))
  return [array_ops.broadcast_to(t, target_shape) for t in args]
532
533
534# TODO(wangpeng): Move the following functions to a separate file and check for
535#   float dtypes in each of them.
536
537
def get_static_value(x):
  """A version of tf.get_static_value that returns None on inexact tensors.

  For a `core.Tensor` with a floating or complex dtype, None is returned so
  that gradients are not broken. Every other input, including non-Tensor
  values, is passed to `tensor_util.constant_value`.

  Args:
    x: a tensor or other value.

  Returns:
    None for float/complex Tensors, else `tensor_util.constant_value(x)`.
  """
  is_inexact_tensor = isinstance(x, core.Tensor) and (
      x.dtype.is_floating or x.dtype.is_complex)
  return None if is_inexact_tensor else tensor_util.constant_value(x)
553
554
def _maybe_static(x):
  """Returns the static value of `x` if `get_static_value` finds one, else `x`."""
  static = get_static_value(x)
  return x if static is None else static
561
562
# All the following functions exist because get_static_value can't see through
# the results of their TF counterparts; each one computes on static values
# whenever they are available.
565
566
def cond(pred, true_fn, false_fn):
  """A version of tf.cond that tries to evaluate the condition statically.

  If `pred` has a static value, only the selected branch function is called.
  Otherwise the call is delegated to `control_flow_ops.cond`.
  """
  static_pred = get_static_value(pred)
  if static_pred is None:
    return control_flow_ops.cond(pred, true_fn, false_fn)
  chosen = true_fn if static_pred else false_fn
  return chosen()
576
577
def add(a, b):
  """`a + b`, with each operand replaced by its static value when known."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs + rhs
581
582
def subtract(a, b):
  """`a - b`, with each operand replaced by its static value when known."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs - rhs
586
587
def greater(a, b):
  """`a > b`, with each operand replaced by its static value when known."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs > rhs
591
592
def greater_equal(a, b):
  """`a >= b`, with each operand replaced by its static value when known."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs >= rhs
596
597
def less_equal(a, b):
  """`a <= b`, with each operand replaced by its static value when known."""
  lhs = _maybe_static(a)
  rhs = _maybe_static(b)
  return lhs <= rhs
601
602
def logical_and(a, b):
  """A version of tf.logical_and that eagerly evaluates if possible.

  When `a` has a static value that `np.isscalar` accepts, the result
  short-circuits like Python's `and`. A falsy `a` is returned as-is, and a
  truthy `a` yields the (static if available) `b`. Otherwise the result is
  `a & b`, using static values where available.
  """
  a_value = get_static_value(a)
  if a_value is None:
    return a & _maybe_static(b)
  if not np.isscalar(a_value):
    return a_value & _maybe_static(b)
  return _maybe_static(b) if a_value else a_value
616
617
def logical_or(a, b):
  """A version of tf.logical_or that eagerly evaluates if possible.

  When `a` has a static value that `np.isscalar` accepts, the result
  short-circuits like Python's `or`. A truthy `a` is returned as-is, and a
  falsy `a` yields the (static if available) `b`. Otherwise the result is
  `a | b`, using static values where available.
  """
  a_value = get_static_value(a)
  if a_value is None:
    return a | _maybe_static(b)
  if not np.isscalar(a_value):
    return a_value | _maybe_static(b)
  return a_value if a_value else _maybe_static(b)
631
632
def getitem(a, slice_spec):
  """Indexes the static value of `a` when available, otherwise `a` itself."""
  container = _maybe_static(a)
  return container[slice_spec]
636
637
def reduce_all(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_all that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: the dimension(s) to reduce; None reduces all dimensions.
    keepdims: whether to keep reduced dimensions with length 1.

  Returns:
    A NumPy result when `input_tensor` has a static value, otherwise the
    Tensor returned by `math_ops.reduce_all`.
  """
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims)
  # The static value is not guaranteed to be an ndarray. For example, the
  # comparison helpers above can produce a plain Python bool, which
  # `tensor_util.constant_value` presumably hands back unchanged. Such values
  # have no `.all` method, whereas `np.all` accepts ndarrays, NumPy scalars and
  # Python values alike.
  return np.all(v, axis=axis, keepdims=keepdims)
645
646
def reduce_any(input_tensor, axis=None, keepdims=False):
  """A version of tf.reduce_any that eagerly evaluates if possible.

  Args:
    input_tensor: the value to reduce.
    axis: the dimension(s) to reduce; None reduces all dimensions.
    keepdims: whether to keep reduced dimensions with length 1.

  Returns:
    A NumPy result when `input_tensor` has a static value, otherwise the
    Tensor returned by `math_ops.reduce_any`.
  """
  v = get_static_value(input_tensor)
  if v is None:
    return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims)
  # The static value is not guaranteed to be an ndarray (e.g. a plain Python
  # bool from the comparison helpers above), so use `np.any` rather than the
  # `.any` method, which such values lack.
  return np.any(v, axis=axis, keepdims=keepdims)
654
655
def tf_rank(t):
  """Returns the rank of `t`.

  The result is `t.shape.rank` when that is known. Otherwise it is the
  dynamic `array_ops.rank(t)`.
  """
  static_rank = t.shape.rank
  return array_ops.rank(t) if static_rank is None else static_rank
661