1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
"""Utilities for marking functions, arguments, and aliases as deprecated."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23import re
24
25from tensorflow.python.platform import tf_logging as logging
26from tensorflow.python.util import decorator_utils
27from tensorflow.python.util import is_in_graph_mode
28from tensorflow.python.util import tf_contextlib
29from tensorflow.python.util import tf_decorator
30from tensorflow.python.util import tf_inspect
31from tensorflow.python.util import tf_stack
32
33
# Global switch read by every deprecation wrapper before it logs anything;
# the `silence()` context manager turns it off temporarily.
_PRINT_DEPRECATION_WARNINGS = True

# Keys of deprecation warnings that were already logged, used to honor
# `warn_once`.
_PRINTED_WARNING = dict()
39
40
class DeprecatedNamesAlreadySet(Exception):
  """Raised when deprecated endpoint names are set twice on the same symbol."""
44
45
def _add_deprecated_function_notice_to_docstring(doc, date, instructions):
  """Returns `doc` with a notice that the function is deprecated.

  Args:
    doc: The docstring of the function being deprecated.
    date: Removal date string, or None when no date has been scheduled.
    instructions: Update instructions forwarded to the notice. The
      'Instructions for updating:' header is only emitted when this is truthy.

  Returns:
    The docstring built by `decorator_utils.add_notice_to_docstring`.
  """
  if date is None:
    removal = 'in a future version'
  else:
    removal = 'after %s' % date

  notice_lines = ['THIS FUNCTION IS DEPRECATED. It will be removed %s.' %
                  removal]
  if instructions:
    notice_lines.append('Instructions for updating:')

  return decorator_utils.add_notice_to_docstring(
      doc, instructions, 'DEPRECATED FUNCTION', '(deprecated)', notice_lines)
56
57
def _add_deprecated_arg_notice_to_docstring(doc, date, instructions,
                                            deprecated_names):
  """Returns `doc` with a notice listing deprecated arguments.

  Args:
    doc: The docstring of the function whose arguments are deprecated.
    date: Removal date string, or None when no date has been scheduled.
    instructions: Update instructions forwarded to the notice.
    deprecated_names: Iterable of deprecated argument names; they are sorted
      and comma-joined in the notice.

  Returns:
    The docstring built by `decorator_utils.add_notice_to_docstring`.
  """
  removal = 'in a future version' if date is None else 'after %s' % date
  names = ', '.join(sorted(deprecated_names))
  headline = ('SOME ARGUMENTS ARE DEPRECATED: `(%s)`. '
              'They will be removed %s.' % (names, removal))

  return decorator_utils.add_notice_to_docstring(
      doc, instructions, 'DEPRECATED FUNCTION ARGUMENTS',
      '(deprecated arguments)', [headline, 'Instructions for updating:'])
72
73
def _add_deprecated_arg_value_notice_to_docstring(doc, date, instructions,
                                                  deprecated_name_value_dict):
  """Returns `doc` with a notice listing deprecated argument values.

  Args:
    doc: The docstring of the function whose argument values are deprecated.
    date: Removal date string, or None when no date has been scheduled.
    instructions: Update instructions forwarded to the notice.
    deprecated_name_value_dict: Mapping from argument name to its deprecated
      value; rendered as sorted `name=repr(value)` pairs.

  Returns:
    The docstring built by `decorator_utils.add_notice_to_docstring`.
  """
  pairs = []
  for key, value in sorted(deprecated_name_value_dict.items()):
    pairs.append('%s=%r' % (key, value))
  deprecation_string = ', '.join(pairs)

  if date is None:
    when = 'in a future version'
  else:
    when = 'after %s' % date

  headline = ('SOME ARGUMENT VALUES ARE DEPRECATED: `(%s)`. '
              'They will be removed %s.' % (deprecation_string, when))
  return decorator_utils.add_notice_to_docstring(
      doc, instructions, 'DEPRECATED FUNCTION ARGUMENT VALUES',
      '(deprecated argument values)',
      [headline, 'Instructions for updating:'])
91
92
def _validate_deprecation_args(date, instructions):
  """Validates the arguments shared by the deprecation decorators.

  Args:
    date: String or None. If not None, it must be exactly of the form
      YYYY-MM-DD with a year in 2000-2099. This is a loose shape check
      (month digits start with 0-1, day digits with 0-3), not full calendar
      validation.
    instructions: Instructions for updating; must be non-empty.

  Raises:
    ValueError: If `date` is not None and not of the form YYYY-MM-DD, or if
      `instructions` is empty.
  """
  # `re.match` only anchors at the start of the string. The trailing `\Z`
  # also anchors the end, so values such as '2016-07-041' or
  # '2016-07-04 junk' are rejected instead of silently accepted.
  if date is not None and not re.match(r'20\d\d-[01]\d-[0123]\d\Z', date):
    raise ValueError('Date must be YYYY-MM-DD. Received: %r' % (date,))
  if not instructions:
    raise ValueError('Don\'t deprecate things without conversion instructions!')
98
99
def _call_location(outer=False):
  """Returns 'filename:lineno' for a frame a fixed distance up the stack.

  The frame is taken from `tf_stack.extract_stack()`: the third entry from
  the end by default, or the fourth when `outer` is True. The index is
  clamped to 0 for shallow stacks. The offsets assume this function is
  called directly (no extra helper frames) from the deprecation wrappers.

  Args:
    outer: Whether to look one frame further out.

  Returns:
    A 'filename:lineno' string, or 'UNKNOWN' if the stack is empty.
  """
  stack = tf_stack.extract_stack()
  depth = len(stack)
  if depth == 0:  # should never happen as we're in a function
    return 'UNKNOWN'
  offset = 4 if outer else 3
  frame = stack[max(depth - offset, 0)]
  filename, lineno = frame[0], frame[1]
  return '{filename}:{lineno}'.format(filename=filename, lineno=lineno)
111
112
def _wrap_decorator(wrapped_function):
  """Returns a decorator declaring that its argument wraps `wrapped_function`.

  The returned decorator hands both functions to
  `tf_decorator.make_decorator`, so doc generation scripts can pick up the
  original function signature. `@functools.wraps` would be the usual choice,
  but it would not update the function signature to match the wrapped
  function in Python 2.

  Args:
    wrapped_function: The function that the decorated function wraps.

  Returns:
    A function that accepts the wrapper function and returns
    `tf_decorator.make_decorator(wrapped_function, wrapper)`.
  """
  def decorate(wrapper_func):
    return tf_decorator.make_decorator(wrapped_function, wrapper_func)

  return decorate
132
133
def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
  """Deprecate a symbol in favor of a new name with identical semantics.

  This function is meant to be used when defining a backwards-compatibility
  alias for a symbol which has been moved. For example:

  module1.py:
  ```python
  class NewNameForClass: pass
  ```

  module2.py:
  ```python
  import module1

  DeprecatedNameForClass = deprecated_alias(
    deprecated_name='module2.DeprecatedNameForClass',
    name='module1.NewNameForClass',
    func_or_class=module1.NewNameForClass)
  ```

  This function works for classes and functions.

  For classes, it creates a new class which is functionally identical (it
  inherits from the original, and overrides its constructor), but which prints
  a deprecation warning when an instance is created. It also adds a deprecation
  notice to the class' docstring.

  For functions, it returns a function wrapped by `tf_decorator.make_decorator`.
  That function prints a warning when used, and has a deprecation notice in its
  docstring. This is more or less equivalent (the deprecation warning has
  slightly different text) to writing:

  ```python
  @deprecated(None, 'Please use <name> instead.')
  def deprecated_alias(original_args):
    return real_function(original_args)
  ```

  Args:
    deprecated_name: The name of the symbol that is being deprecated, to be used
      in the warning message. This should be its fully qualified name to avoid
      confusion.
    name: The name of the symbol that is to be used instead of the deprecated
      name. This should be a fully qualified name to avoid confusion.
    func_or_class: The (non-deprecated) class or function for which a deprecated
      alias should be created.
    warn_once: If True (the default), only print a deprecation warning the first
      time this function is used, or the class is instantiated.

  Returns:
    A wrapped version of `func_or_class` which prints a deprecation warning on
    use and has a modified docstring.
  """
  if tf_inspect.isclass(func_or_class):

    # Make a new class with __init__ wrapped in a warning.
    class _NewClass(func_or_class):  # pylint: disable=missing-docstring
      __doc__ = decorator_utils.add_notice_to_docstring(
          func_or_class.__doc__, 'Please use %s instead.' % name,
          'DEPRECATED CLASS',
          '(deprecated)', ['THIS CLASS IS DEPRECATED. '
                           'It will be removed in a future version. '])
      __name__ = func_or_class.__name__
      # NOTE(review): this stores the 'filename:lineno' string returned by
      # _call_location, not a module name; presumably meant to point at the
      # code that created the alias -- confirm.
      __module__ = _call_location(outer=True)

      @_wrap_decorator(func_or_class.__init__)
      def __init__(self, *args, **kwargs):
        # Copy the original constructor's docstring onto the wrapping
        # __init__ (this runs on every instantiation).
        if hasattr(_NewClass.__init__, '__func__'):
          # Python 2
          _NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__
        else:
          # Python 3
          _NewClass.__init__.__doc__ = func_or_class.__init__.__doc__

        if _PRINT_DEPRECATION_WARNINGS:
          # We're making the alias as we speak. The original may have other
          # aliases, so we cannot use it to check for whether it's already been
          # warned about.
          if _NewClass.__init__ not in _PRINTED_WARNING:
            if warn_once:
              _PRINTED_WARNING[_NewClass.__init__] = True
            logging.warning(
                'From %s: The name %s is deprecated. Please use %s instead.\n',
                _call_location(), deprecated_name, name)
        super(_NewClass, self).__init__(*args, **kwargs)

    return _NewClass
  else:
    decorator_utils.validate_callable(func_or_class, 'deprecated')

    # Make a wrapper for the original
    @functools.wraps(func_or_class)
    def new_func(*args, **kwargs):  # pylint: disable=missing-docstring
      if _PRINT_DEPRECATION_WARNINGS:
        # We're making the alias as we speak. The original may have other
        # aliases, so we cannot use it to check for whether it's already been
        # warned about.
        if new_func not in _PRINTED_WARNING:
          if warn_once:
            _PRINTED_WARNING[new_func] = True
          logging.warning(
              'From %s: The name %s is deprecated. Please use %s instead.\n',
              _call_location(), deprecated_name, name)
      return func_or_class(*args, **kwargs)
    # The notice is built with date=None, so it reads "in a future version".
    return tf_decorator.make_decorator(
        func_or_class, new_func, 'deprecated',
        _add_deprecated_function_notice_to_docstring(
            func_or_class.__doc__, None, 'Please use %s instead.' % name))
243
244
def deprecated_endpoints(*args):
  """Decorator for marking endpoints deprecated.

  The decorator only records the names on the symbol; it does not print
  deprecation messages.
  TODO(annarev): eventually start printing deprecation warnings when
  @deprecation_endpoints decorator is added.

  Args:
    *args: Deprecated endpoint names.

  Returns:
    A decorator that stores `args` (a tuple) on its argument as
    `_tf_deprecated_api_names` and returns that same argument.

  Raises:
    DeprecatedNamesAlreadySet: Raised by the returned decorator if the
      symbol's own `__dict__` already contains `_tf_deprecated_api_names`.
  """
  def deprecated_wrapper(func):
    # pylint: disable=protected-access
    # Only the symbol's own __dict__ is inspected, so names inherited from a
    # base class do not count as "already set".
    already_set = '_tf_deprecated_api_names' in func.__dict__
    if already_set:
      message = ('Cannot set deprecated names for %s to %s. '
                 'Deprecated names are already set to %s.' %
                 (func.__name__, str(args),
                  str(func._tf_deprecated_api_names)))
      raise DeprecatedNamesAlreadySet(message)
    func._tf_deprecated_api_names = args
    # pylint: enable=protected-access
    return func

  return deprecated_wrapper
272
273
def deprecated(date, instructions, warn_once=True):
  """Decorator for marking functions or methods deprecated.

  Calling the decorated function logs a warning of the form:

    <function> (from <module>) is deprecated and will be removed after <date>.
    Instructions for updating:
    <instructions>

  When `date` is None, 'after <date>' becomes 'in a future version'. For
  methods, <function> includes the class name. Nothing is logged while
  warnings are disabled via `silence()`.

  The function's docstring is edited as well: ' (deprecated)' is appended to
  its first line and a deprecation notice is prepended to the remainder.

  Args:
    date: String or None. The date the function is scheduled to be removed.
      Must be ISO 8601 (YYYY-MM-DD), or None.
    instructions: String. Instructions on how to update code using the
      deprecated function.
    warn_once: Boolean. Set to `True` to warn only the first time the decorated
      function is called. Otherwise, every call will log a warning.

  Returns:
    A decorator that produces the deprecated version of a function or method.

  Raises:
    ValueError: If date is neither None nor in ISO 8601 format, or
      instructions are empty.
  """
  _validate_deprecation_args(date, instructions)
  removal = 'in a future version' if date is None else ('after %s' % date)

  def deprecated_wrapper(func):
    """Returns `func` wrapped so that calling it logs a deprecation warning."""
    decorator_utils.validate_callable(func, 'deprecated')

    @functools.wraps(func)
    def new_func(*args, **kwargs):  # pylint: disable=missing-docstring
      should_log = (_PRINT_DEPRECATION_WARNINGS and
                    func not in _PRINTED_WARNING)
      if should_log:
        if warn_once:
          _PRINTED_WARNING[func] = True
        # _call_location() must be called directly from here: it locates the
        # caller by a fixed offset from the end of the stack.
        logging.warning(
            'From %s: %s (from %s) is deprecated and will be removed %s.\n'
            'Instructions for updating:\n%s',
            _call_location(), decorator_utils.get_qualified_name(func),
            func.__module__, removal, instructions)
      return func(*args, **kwargs)

    notice = _add_deprecated_function_notice_to_docstring(
        func.__doc__, date, instructions)
    return tf_decorator.make_decorator(func, new_func, 'deprecated', notice)

  return deprecated_wrapper
330
331
# Spec for one deprecated argument: its index in the signature (-1 until it
# has been resolved against a function), whether an "ok" value was given, and
# that value.
DeprecatedArgSpec = collections.namedtuple(
    'DeprecatedArgSpec', 'position has_ok_value ok_value')
334
335
def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples,
                    **kwargs):
  """Decorator for marking specific function arguments as deprecated.

  This decorator logs a deprecation warning whenever the decorated function is
  called with the deprecated argument. It has the following format:

    Calling <function> (from <module>) with <arg> is deprecated and will be
    removed after <date>. Instructions for updating:
      <instructions>

  If `date` is None, 'after <date>' is replaced with 'in a future version'.
  <function> includes the class name if it is a method.

  The check only runs when `is_in_graph_mode.IS_IN_GRAPH_MODE()` is true and
  warnings have not been disabled via `silence()`.

  It also edits the docstring of the function: ' (deprecated arguments)' is
  appended to the first line of the docstring and a deprecation notice is
  prepended to the rest of the docstring.

  Args:
    date: String or None. The date the function is scheduled to be removed.
      Must be ISO 8601 (YYYY-MM-DD), or None.
    instructions: String. Instructions on how to update code using the
      deprecated function.
    *deprecated_arg_names_or_tuples: String or 2-Tuple(String, ok_value).
      The string is the deprecated argument name. It may be a named parameter
      or the name of the function's `*args`/`**kwargs` parameter. Optionally,
      an ok-value may be provided for a named parameter; if the user-provided
      argument equals this value, the warning is suppressed.
    **kwargs: If `warn_once=False` is passed, every call with a deprecated
      argument will log a warning. The default behavior is to only warn the
      first time the function is called with any given deprecated argument.
      All other kwargs raise `ValueError`.

  Returns:
    Decorated function or method.

  Raises:
    ValueError: If date is neither None nor in ISO 8601 format, instructions
      are empty, no deprecated arguments are given, the deprecated arguments
      are not present in the function signature, or if a kwarg other than
      `warn_once` is passed.
  """
  _validate_deprecation_args(date, instructions)
  if not deprecated_arg_names_or_tuples:
    raise ValueError('Specify which argument is deprecated.')
  if kwargs and list(kwargs.keys()) != ['warn_once']:
    kwargs.pop('warn_once', None)
    raise ValueError('Illegal argument to deprecated_args: %s' % kwargs)
  warn_once = kwargs.get('warn_once', True)

  def _get_arg_names_to_ok_vals():
    """Returns a dict mapping arg_name to DeprecatedArgSpec w/o position."""
    d = {}
    for name_or_tuple in deprecated_arg_names_or_tuples:
      if isinstance(name_or_tuple, tuple):
        d[name_or_tuple[0]] = DeprecatedArgSpec(-1, True, name_or_tuple[1])
      else:
        d[name_or_tuple] = DeprecatedArgSpec(-1, False, None)
    return d

  def _get_deprecated_positional_arguments(names_to_ok_vals, arg_spec):
    """Builds a dictionary from deprecated named parameters to their spec.

    Only names found in `arg_spec.args` are included. Names referring to
    `*args`/`**kwargs`, or missing from the signature, are omitted.
    Each value is a DeprecatedArgSpec with the following fields:
       position: The zero-based position of the argument in `arg_spec.args`.
       has_ok_value: Whether an ok-value was supplied for the argument.
       ok_value: Value of this argument for which the warning is suppressed.

    Args:
      names_to_ok_vals: dict from string arg_name to a DeprecatedArgSpec
        (position not yet set), as built by `_get_arg_names_to_ok_vals`.
      arg_spec: Output from tf_inspect.getfullargspec on the called function.

    Returns:
      Dictionary from arg_name to DeprecatedArgSpec.
    """
    arg_name_to_pos = {
        name: pos for pos, name in enumerate(arg_spec.args)}
    deprecated_positional_args = {}
    for arg_name, spec in names_to_ok_vals.items():
      if arg_name in arg_name_to_pos:
        pos = arg_name_to_pos[arg_name]
        deprecated_positional_args[arg_name] = DeprecatedArgSpec(
            pos, spec.has_ok_value, spec.ok_value)
    return deprecated_positional_args

  deprecated_arg_names = _get_arg_names_to_ok_vals()

  def deprecated_wrapper(func):
    """Deprecation decorator."""
    decorator_utils.validate_callable(func, 'deprecated_args')

    arg_spec = tf_inspect.getfullargspec(func)
    deprecated_positions = _get_deprecated_positional_arguments(
        deprecated_arg_names, arg_spec)

    is_varargs_deprecated = arg_spec.varargs in deprecated_arg_names
    is_kwargs_deprecated = arg_spec.varkw in deprecated_arg_names

    # Compare against the de-duplicated names: listing the same name twice must
    # not be reported as a missing argument.
    if (len(deprecated_positions) + is_varargs_deprecated + is_kwargs_deprecated
        != len(deprecated_arg_names)):
      known_args = arg_spec.args + [arg_spec.varargs, arg_spec.varkw]
      missing_args = [arg_name for arg_name in deprecated_arg_names
                      if arg_name not in known_args]
      raise ValueError('The following deprecated arguments are not present '
                       'in the function signature: %s. '
                       'Found next arguments: %s.' % (missing_args, known_args))

    # Keyword names that bind to declared parameters. Any other keyword passed
    # at call time is collected by the function's `**kwargs`.
    named_params = set(arg_spec.args) | set(arg_spec.kwonlyargs or [])

    def _same_value(a, b):
      """A comparison operation that works for multiple object types.

      Returns True for two empty lists, two numeric values with the
      same value, etc.

      Returns False for (pd.DataFrame, None), and other pairs which
      should not be considered equivalent.

      Args:
        a: value one of the comparison.
        b: value two of the comparison.

      Returns:
        A boolean indicating whether the two inputs are the same value
        for the purposes of deprecation.
      """
      if a is b:
        return True
      try:
        equality = a == b
        if isinstance(equality, bool):
          return equality
      except TypeError:
        return False
      return False

    @functools.wraps(func)
    def new_func(*args, **kwargs):
      """Deprecation wrapper."""
      # TODO(apassos) figure out a way to have reasonable performance with
      # deprecation warnings and eager mode.
      if is_in_graph_mode.IS_IN_GRAPH_MODE() and _PRINT_DEPRECATION_WARNINGS:
        invalid_args = []
        named_args = tf_inspect.getcallargs(func, *args, **kwargs)
        # Deprecated named parameters that were passed positionally.
        for arg_name, spec in deprecated_positions.items():
          if (spec.position < len(args) and
              not (spec.has_ok_value and
                   _same_value(named_args[arg_name], spec.ok_value))):
            invalid_args.append(arg_name)
        # Extra positional arguments collected by a deprecated `*args`.
        if is_varargs_deprecated and len(args) > len(arg_spec.args):
          invalid_args.append(arg_spec.varargs)
        # Extra keyword arguments collected by a deprecated `**kwargs`.
        # Keywords that bind to named parameters never reach `**kwargs`.
        if is_kwargs_deprecated and any(
            key not in named_params for key in kwargs):
          invalid_args.append(arg_spec.varkw)
        # Deprecated named parameters that were passed by keyword. Only names
        # with a recorded position are checked: a keyword spelled like the
        # `*args`/`**kwargs` parameter is an ordinary `**kwargs` entry, and
        # looking it up in `deprecated_positions` would raise KeyError.
        for arg_name, spec in deprecated_positions.items():
          if (arg_name in kwargs and
              not (spec.has_ok_value and
                   _same_value(named_args[arg_name], spec.ok_value))):
            invalid_args.append(arg_name)
        for arg_name in invalid_args:
          if (func, arg_name) not in _PRINTED_WARNING:
            if warn_once:
              _PRINTED_WARNING[(func, arg_name)] = True
            logging.warning(
                'From %s: calling %s (from %s) with %s is deprecated and will '
                'be removed %s.\nInstructions for updating:\n%s',
                _call_location(), decorator_utils.get_qualified_name(func),
                func.__module__, arg_name,
                'in a future version' if date is None else ('after %s' % date),
                instructions)
      return func(*args, **kwargs)

    doc = _add_deprecated_arg_notice_to_docstring(
        func.__doc__, date, instructions, sorted(deprecated_arg_names.keys()))
    return tf_decorator.make_decorator(func, new_func, 'deprecated', doc)

  return deprecated_wrapper
514
515
def deprecated_arg_values(date, instructions, warn_once=True,
                          **deprecated_kwargs):
  """Decorator for marking specific function argument values as deprecated.

  This decorator logs a deprecation warning whenever the decorated function is
  called with the deprecated argument values. It has the following format:

    Calling <function> (from <module>) with <arg>=<value> is deprecated and
    will be removed after <date>. Instructions for updating:
      <instructions>

  If `date` is None, 'after <date>' is replaced with 'in a future version'.
  <function> will include the class name if it is a method.

  Values are compared with `==`. A comparison that raises TypeError or
  ValueError (e.g. an elementwise `==` whose result has no single truth
  value) counts as "not the deprecated value" rather than failing the call.

  It also edits the docstring of the function: ' (deprecated argument values)'
  is appended to the first line of the docstring and a deprecation notice is
  prepended to the rest of the docstring.

  Args:
    date: String or None. The date the function is scheduled to be removed.
      Must be ISO 8601 (YYYY-MM-DD), or None
    instructions: String. Instructions on how to update code using the
      deprecated function.
    warn_once: If `True`, warn only the first time this function is called with
      deprecated argument values. Otherwise, every call (with a deprecated
      argument value) will log a warning.
    **deprecated_kwargs: The deprecated argument values.

  Returns:
    Decorated function or method.

  Raises:
    ValueError: If date is neither None nor in ISO 8601 format, instructions
      are empty, or no deprecated argument values are given.
  """
  _validate_deprecation_args(date, instructions)
  if not deprecated_kwargs:
    raise ValueError('Specify which argument values are deprecated.')

  def _matches_deprecated_value(value, deprecated_value):
    """Returns the truthiness of `value == deprecated_value`.

    A comparison whose evaluation or truth test raises TypeError/ValueError is
    treated as a non-match, so the deprecation check never breaks the call.
    """
    try:
      return bool(value == deprecated_value)
    except (TypeError, ValueError):
      return False

  def deprecated_wrapper(func):
    """Deprecation decorator."""
    decorator_utils.validate_callable(func, 'deprecated_arg_values')
    @functools.wraps(func)
    def new_func(*args, **kwargs):
      """Deprecation wrapper."""
      if _PRINT_DEPRECATION_WARNINGS:
        named_args = tf_inspect.getcallargs(func, *args, **kwargs)
        for arg_name, arg_value in deprecated_kwargs.items():
          if (arg_name in named_args and
              _matches_deprecated_value(named_args[arg_name], arg_value)):
            if (func, arg_name) not in _PRINTED_WARNING:
              if warn_once:
                _PRINTED_WARNING[(func, arg_name)] = True
              logging.warning(
                  'From %s: calling %s (from %s) with %s=%s is deprecated and '
                  'will be removed %s.\nInstructions for updating:\n%s',
                  _call_location(), decorator_utils.get_qualified_name(func),
                  func.__module__, arg_name, arg_value, 'in a future version'
                  if date is None else ('after %s' % date), instructions)
      return func(*args, **kwargs)

    doc = _add_deprecated_arg_value_notice_to_docstring(
        func.__doc__, date, instructions, deprecated_kwargs)
    return tf_decorator.make_decorator(func, new_func, 'deprecated', doc)

  return deprecated_wrapper
581
582
def deprecated_argument_lookup(new_name, new_value, old_name, old_value):
  """Resolves an argument that has both a new and a deprecated old name.

  Args:
    new_name: new name of argument (used in the error message).
    new_value: value of new argument (or None if not used)
    old_name: old name of argument (used in the error message).
    old_value: value of old argument (or None if not used)
  Returns:
    `old_value` if it is not None, otherwise `new_value`.
  Raises:
    ValueError: if new_value and old_value are both non-None.
  """
  if old_value is None:
    return new_value
  if new_value is not None:
    raise ValueError("Cannot specify both '%s' and '%s'" %
                     (old_name, new_name))
  return old_value
602
603
def rewrite_argument_docstring(old_doc, old_argument, new_argument):
  """Renames an argument inside a docstring.

  Replaces backtick-quoted mentions (`old_argument`) and `old_argument:`
  entries (as used in an Args section) with `new_argument`. Both are plain
  substring replacements, so `old_argument:` is also rewritten where it is
  only the tail of a longer name.

  Args:
    old_doc: The docstring to rewrite, or None (e.g. when docstrings are
      stripped by running Python with -OO).
    old_argument: The argument name to replace.
    new_argument: The replacement argument name.

  Returns:
    The rewritten docstring, or None if `old_doc` is None.
  """
  # A stripped/absent docstring has nothing to rewrite; previously this
  # raised AttributeError on `None.replace`.
  if old_doc is None:
    return None
  return old_doc.replace('`%s`' % old_argument, '`%s`' % new_argument).replace(
      '%s:' % old_argument, '%s:' % new_argument)
607
608
@tf_contextlib.contextmanager
def silence():
  """Temporarily silence deprecation warnings.

  Within the `with` block, the deprecation wrappers log nothing. The previous
  setting is restored on exit, including when the block raises. Without the
  `finally`, an exception would leave warnings disabled for the rest of the
  process.

  Yields:
    None.
  """
  global _PRINT_DEPRECATION_WARNINGS
  print_deprecation_warnings = _PRINT_DEPRECATION_WARNINGS
  _PRINT_DEPRECATION_WARNINGS = False
  try:
    yield
  finally:
    _PRINT_DEPRECATION_WARNINGS = print_deprecation_warnings
617