1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Registry for tensor conversion functions."""
16# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numbers
import threading

import numpy as np
import six

from tensorflow.python.util import lazy_loader
from tensorflow.python.util.tf_export import tf_export
29
# `constant_op` cannot be imported eagerly: doing so would close the import
# cycle ops -> tensor_conversion_registry -> constant_op -> ops. It is therefore
# loaded lazily through a LazyLoader bound into this module's globals.
constant_op = lazy_loader.LazyLoader(
    "constant_op", globals(), "tensorflow.python.framework.constant_op")
35
36
# Registered conversion functions, keyed by priority. Each value is a list of
# (base_type, conversion_func) pairs kept in registration order.
_tensor_conversion_func_registry = collections.defaultdict(list)
# Memoized lookup results of `get`, keyed by the queried type. Emptied every
# time a new conversion function is registered.
_tensor_conversion_func_cache = dict()
# Held while mutating the registry or the cache, and while reading the
# registry to rebuild a cache entry.
_tensor_conversion_func_lock = threading.Lock()
40
# Python integers (which includes `bool`), floats, and NumPy scalars/arrays.
# Instances of these types always go through `_default_conversion_function`,
# and custom conversion functions may not be registered for them.
_UNCONVERTIBLE_TYPES = tuple(six.integer_types) + (
    float, np.generic, np.ndarray)
48
49
def _default_conversion_function(value, dtype, name, as_ref):
  """Converts `value` by building a constant with `constant_op.constant`.

  `get` returns this function for every subclass of `_UNCONVERTIBLE_TYPES`.

  Args:
    value: The object to convert.
    dtype: The requested dtype, forwarded unchanged.
    name: The name for the created constant, forwarded unchanged.
    as_ref: Ignored; accepted only to match the conversion-function signature.

  Returns:
    The constant created by `constant_op.constant`.
  """
  del as_ref  # Unused; present for signature compatibility.
  return constant_op.constant(value, dtype=dtype, name=name)
53
54
# TODO(josh11b): Add ctx argument to conversion_func() signature.
@tf_export("register_tensor_conversion_function")
def register_tensor_conversion_function(base_type,
                                        conversion_func,
                                        priority=100):
  """Registers a function for converting objects of `base_type` to `Tensor`.

  The conversion function must have the following signature:

  ```python
      def conversion_func(value, dtype=None, name=None, as_ref=False):
        # ...
  ```

  It must return a `Tensor` with the given `dtype` if specified. If the
  conversion function creates a new `Tensor`, it should use the given
  `name` if specified. All exceptions will be propagated to the caller.

  The conversion function may return `NotImplemented` for some
  inputs. In this case, the conversion process will continue to try
  subsequent conversion functions.

  If `as_ref` is true, the function must return a `Tensor` reference,
  such as a `Variable`.

  NOTE: The conversion functions will execute in order of priority,
  followed by order of registration. To ensure that a conversion function
  `F` runs before another conversion function `G`, ensure that `F` is
  registered with a smaller priority than `G`.

  Args:
    base_type: The base type or tuple of base types for all objects that
      `conversion_func` accepts.
    conversion_func: A function that converts instances of `base_type` to
      `Tensor`.
    priority: Optional number (conventionally an integer) that indicates the
      priority for applying this conversion function. Conversion functions
      with smaller priority values run earlier than conversion functions with
      larger priority values. Defaults to 100.

  Raises:
    TypeError: If `base_type` is not a type or a tuple of types; if any of
      those types is a Python numeric type or a NumPy scalar/array type; if
      `conversion_func` is not callable; or if `priority` is not a real
      number.
  """
  base_types = base_type if isinstance(base_type, tuple) else (base_type,)
  if any(not isinstance(x, type) for x in base_types):
    raise TypeError("base_type must be a type or a tuple of types.")
  if any(issubclass(x, _UNCONVERTIBLE_TYPES) for x in base_types):
    raise TypeError("Cannot register conversions for Python numeric types and "
                    "NumPy scalars and arrays.")
  del base_types  # Only needed for validation.
  if not callable(conversion_func):
    raise TypeError("conversion_func must be callable.")
  # `get` sorts the registry by priority. A priority that cannot be ordered
  # against the existing numeric ones (e.g. a string or None) would make every
  # later uncached lookup fail inside `sorted`. Reject it here, where the
  # mistake is made.
  if not isinstance(priority, numbers.Real):
    raise TypeError("priority must be a number.")

  with _tensor_conversion_func_lock:
    _tensor_conversion_func_registry[priority].append(
        (base_type, conversion_func))
    # Any memoized lookup may now be missing the new function; drop them all.
    _tensor_conversion_func_cache.clear()
112
113
def get(query):
  """Returns the conversion functions applicable to instances of `query`.

  Args:
    query: The type to query for.

  Returns:
    A list of `(base_type, conversion_func)` pairs for which
    `issubclass(query, base_type)` holds. The list is ordered by ascending
    priority value, then by registration order, which is the order in which
    the functions should be tried. For subclasses of `_UNCONVERTIBLE_TYPES`
    the result is always `[(query, _default_conversion_function)]`.
    Otherwise the returned list is the object stored in the internal cache,
    so callers must not mutate it.

  Raises:
    TypeError: If `query` is not a class (raised by `issubclass`).
  """
  if issubclass(query, _UNCONVERTIBLE_TYPES):
    return [(query, _default_conversion_function)]

  # Fast path: read the cache without taking the lock. All writes to the cache
  # (here and in `register_tensor_conversion_function`) happen under
  # `_tensor_conversion_func_lock`.
  conversion_funcs = _tensor_conversion_func_cache.get(query)
  if conversion_funcs is None:
    with _tensor_conversion_func_lock:
      # Has another thread populated the cache in the meantime?
      conversion_funcs = _tensor_conversion_func_cache.get(query)
      if conversion_funcs is None:
        conversion_funcs = []
        # Ascending priority keys; within a priority, registration order.
        for _, funcs_at_priority in sorted(
            _tensor_conversion_func_registry.items()):
          conversion_funcs.extend(
              (base_type, conversion_func)
              for base_type, conversion_func in funcs_at_priority
              if issubclass(query, base_type))
        _tensor_conversion_func_cache[query] = conversion_funcs
  return conversion_funcs
141