1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Type-based dispatch for TensorFlow ops.
16
17"Operation dispatchers" can be used to override the behavior for TensorFlow ops
18when they are called with otherwise unsupported argument types.  In particular,
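when an operation is called with arguments that would cause it to raise a
TypeError, it falls back on its registered operation dispatchers.  (Ops wrapped
with `add_dispatch_support` also fall back on a ValueError, because some tensor
conversion routines raise ValueError for unexpected types.)  If any registered
dispatcher can handle the arguments, then its result is returned.  Otherwise,
the original exception is re-raised.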
23
By default, dispatch support is added to the generated op wrappers for any
visible ops.  Ops that are implemented in Python can opt in to dispatch support
using the `add_dispatch_support` decorator.
27"""
28
29from __future__ import absolute_import
30from __future__ import division
31from __future__ import print_function
32
33import itertools
34
35from tensorflow.python.util import tf_decorator
36from tensorflow.python.util import tf_inspect
37from tensorflow.python.util.tf_export import tf_export
38
39
# Name of the private function attribute that holds an op's list of
# `OpDispatcher`s.  The list is created by `add_dispatch_list`, appended to by
# `OpDispatcher.register`, and read by `dispatch`.
DISPATCH_ATTR = "_tf_dispatchers"


# `GlobalOpDispatcher`s that are consulted for every operation, after that
# op's own dispatchers have declined.  Appended to by
# `GlobalOpDispatcher.register`.
_GLOBAL_DISPATCHERS = []
46
47
@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  An operation dispatcher overrides a single TensorFlow operation.  Its
  `handle` method is consulted with the operation's arguments, and any return
  value other than `OpDispatcher.NOT_SUPPORTED` is used as the result of the
  operation.
  """

  # Sentinel returned by `handle` to signal that this dispatcher declines the
  # given arguments.  Callers compare against it by identity (`is`).
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handle this dispatcher's operation with the specified arguments.

    Subclasses override this.  When the arguments are supported, return the
    operation's result (or raise an appropriate exception); otherwise return
    `OpDispatcher.NOT_SUPPORTED`.  The base implementation supports nothing.

    Args:
      args: The positional arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher can not handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self, op):
    """Register this dispatcher as a handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled.
        It must already carry a dispatch list (generated ops get one
        automatically; Python ops get one from the `add_dispatch_support`
        decorator).

    Raises:
      AssertionError: If `op` has no dispatch list.
    """
    has_dispatch_list = hasattr(op, DISPATCH_ATTR)
    if not has_dispatch_list:
      raise AssertionError("Dispatching not enabled for %s" % op)
    dispatchers = getattr(op, DISPATCH_ATTR)
    dispatchers.append(self)
90
91
@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers.

  A global dispatcher is consulted for every operation, after that
  operation's own `OpDispatcher`s have declined.  Subclasses override
  `handle`.
  """

  # Shares the sentinel object with `OpDispatcher`, so results from both kinds
  # of dispatcher can be compared against a single value.
  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):  # pylint: disable=unused-argument
    """Handle the specified operation with the specified arguments.

    Args:
      op: Python function: the operation being dispatched.
      args: The positional arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `NOT_SUPPORTED` if this dispatcher can
      not handle the given arguments.  The base implementation supports
      nothing.
    """
    # Return the sentinel explicitly: `dispatch` treats any other value --
    # including an implicit `None` from falling off the end -- as a
    # successful result, which would mask the op's original error.
    return self.NOT_SUPPORTED

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)
104
105
def dispatch(op, args, kwargs):
  """Returns the result from the first successful dispatcher for a given op.

  The op's own `OpDispatcher`s (stored under `DISPATCH_ATTR`) are tried
  first, in registration order, followed by every registered
  `GlobalOpDispatcher`.  The first `handle` call that returns something other
  than `NOT_SUPPORTED` wins, and later dispatchers are not called.

  Args:
    op: Python function: the operation to dispatch for.  It must have a
      dispatch list attribute.
    args: The positional arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if no
    registered dispatcher can handle the given arguments.
  """
  unsupported = OpDispatcher.NOT_SUPPORTED
  # Both generators are consumed lazily, so dispatchers after the first
  # successful one are never invoked.
  candidate_results = itertools.chain(
      (local.handle(args, kwargs) for local in getattr(op, DISPATCH_ATTR)),
      (glob.handle(op, args, kwargs) for glob in _GLOBAL_DISPATCHERS))
  for result in candidate_results:
    if result is not unsupported:
      return result
  return unsupported
130
131
class _TypeBasedDispatcher(OpDispatcher):
  """Dispatcher that handles op if any arguments have a specified type.

  An argument matches when it is an instance of one of the configured types,
  or when it is a list or tuple with such an instance as a direct element
  (deeper nesting is not inspected).  On a match, the op's arguments are
  forwarded unchanged to the override function.
  """

  def __init__(self, override_func, types):
    """Creates the dispatcher.

    Args:
      override_func: Callable invoked as `override_func(*args, **kwargs)` in
        place of the op when the arguments match.
      types: A type or tuple of types, as accepted by `isinstance`.
    """
    self._override_func = override_func
    self._types = types

  def _handles(self, args, kwargs):
    """Returns True if any positional or keyword argument matches."""
    types = self._types

    def matches(value):
      if isinstance(value, types):
        return True
      if isinstance(value, (list, tuple)):
        return any(isinstance(element, types) for element in value)
      return False

    all_values = itertools.chain(args, kwargs.values())
    return any(matches(value) for value in all_values)

  def handle(self, args, kwargs):
    if not self._handles(args, kwargs):
      return self.NOT_SUPPORTED
    return self._override_func(*args, **kwargs)
157
158
159# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator to declare that a Python function overrides an op for a type.

  The decorated function is used in place of `op` whenever any positional or
  keyword argument (or a direct element of a list/tuple argument) is an
  instance of one of `types`.  The decorated function is returned unchanged.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.

  Raises:
    AssertionError: (when decorating) if the decorated function's argspec
      differs from `op`'s, or if `op` has no dispatch list.
  """

  def decorator(func):
    func_spec = tf_inspect.getargspec(func)
    op_spec = tf_inspect.getargspec(op)
    if func_spec != op_spec:
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    dispatcher = _TypeBasedDispatcher(func, types)
    dispatcher.register(op)
    return func

  return decorator
187
188
189# pylint: enable=g-doc-return-or-yield
190
191
def add_dispatch_list(target):
  """Decorator that adds a dispatch_list attribute to an op.

  Attaches an empty list under `DISPATCH_ATTR`, which `OpDispatcher.register`
  appends to, and returns `target` itself.

  Raises:
    AssertionError: If `target` already has a dispatch list.
  """
  if not hasattr(target, DISPATCH_ATTR):
    setattr(target, DISPATCH_ATTR, [])
    return target
  raise AssertionError("%s already has a dispatch list" % target)
198
199
@tf_export("__internal__.dispatch.add_dispatch_support", v1=[])
def add_dispatch_support(target):
  """Decorator that adds a dispatch handling wrapper to an op.

  The returned wrapper calls `target`.  If that raises TypeError or
  ValueError, the wrapper's registered dispatchers (and the global ones) are
  consulted.  The first supported result is returned; if none applies, the
  original exception is re-raised.
  """

  def wrapper(*args, **kwargs):
    """Call target, and fall back on dispatchers if there is a TypeError."""
    try:
      return target(*args, **kwargs)
    except (TypeError, ValueError):
      # ValueError is caught as well because convert_to_eager_tensor raises
      # ValueError, not TypeError, when given unexpected types.
      outcome = dispatch(wrapper, args, kwargs)
      if outcome is OpDispatcher.NOT_SUPPORTED:
        # Bare `raise` re-raises the exception from `target`.
        raise
      return outcome

  # The dispatch list lives on `wrapper`, which is the object `dispatch`
  # looks up above and the one dispatchers register against.
  add_dispatch_list(wrapper)
  return tf_decorator.make_decorator(target, wrapper)
218