1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for exporting TensorFlow symbols to the API.
16
17Exporting a function or a class:
18
To export a function or a class, use the tf_export decorator. For example:
20```python
21@tf_export('foo', 'bar.foo')
22def foo(...):
23  ...
24```
25
26If a function is assigned to a variable, you can export it by calling
tf_export explicitly. For example:
28```python
29foo = get_foo(...)
30tf_export('foo', 'bar.foo')(foo)
31```
32
33
Exporting a constant:
35```python
36foo = 1
37tf_export('consts.foo').export_constant(__name__, 'foo')
38```
39"""
40from __future__ import absolute_import
41from __future__ import division
42from __future__ import print_function
43
44import collections
45import functools
46import sys
47
48from tensorflow.python.util import tf_decorator
49from tensorflow.python.util import tf_inspect
50
# Names of the APIs that symbols can be exported to.
ESTIMATOR_API_NAME = 'estimator'
KERAS_API_NAME = 'keras'
TENSORFLOW_API_NAME = 'tensorflow'

# Top-level subpackages owned by other components. Core TensorFlow must not
# export any symbol under these names.
SUBPACKAGE_NAMESPACES = [ESTIMATOR_API_NAME]

# Pair of attribute names used by one API: `names` is the attribute set on an
# exported symbol, `constants` the attribute set on a module that exports
# constants.
_Attributes = collections.namedtuple('ExportedApiAttributes',
                                     'names constants')

# V2 attribute names keyed by API name. Attribute values must be unique to
# each API so that exports for different APIs never collide.
API_ATTRS = {
    TENSORFLOW_API_NAME: _Attributes(
        names='_tf_api_names',
        constants='_tf_api_constants'),
    ESTIMATOR_API_NAME: _Attributes(
        names='_estimator_api_names',
        constants='_estimator_api_constants'),
    KERAS_API_NAME: _Attributes(
        names='_keras_api_names',
        constants='_keras_api_constants'),
}

# V1 counterparts of API_ATTRS: same keys, `_v1`-suffixed attribute names.
API_ATTRS_V1 = {
    TENSORFLOW_API_NAME: _Attributes(
        names='_tf_api_names_v1',
        constants='_tf_api_constants_v1'),
    ESTIMATOR_API_NAME: _Attributes(
        names='_estimator_api_names_v1',
        constants='_estimator_api_constants_v1'),
    KERAS_API_NAME: _Attributes(
        names='_keras_api_names_v1',
        constants='_keras_api_constants_v1'),
}
86
87
class SymbolAlreadyExposedError(Exception):
  """Raised when adding API names to a symbol that already has API names."""


class InvalidSymbolNameError(Exception):
  """Raised when exporting a symbol under an invalid or disallowed name."""
96
# Registry from exported API name to the exported symbol. V1 names are stored
# with a 'compat.v1.' prefix; entries are added by api_export.__call__.
_NAME_TO_SYMBOL_MAPPING = {}


def get_symbol_from_name(name):
  """Looks up the symbol that was exported under `name`.

  Args:
    name: API name in dot-delimited format. V1 names must carry the
      'compat.v1.' prefix.

  Returns:
    The exported symbol, or None if nothing has been exported under `name`.
  """
  try:
    return _NAME_TO_SYMBOL_MAPPING[name]
  except KeyError:
    return None
102
103
def get_canonical_name_for_symbol(
    symbol, api_name=TENSORFLOW_API_NAME,
    add_prefix_to_v1_names=False):
  """Get canonical name for the API symbol.

  V2 names are consulted first; V1 names are used only when no V2 canonical
  name can be determined.

  Args:
    symbol: API function or class.
    api_name: API name, one of the keys of API_ATTRS (tensorflow, estimator
      or keras).
    add_prefix_to_v1_names: Specifies whether a name available only in V1
      should be prefixed with compat.v1.

  Returns:
    Canonical name for the API symbol (e.g. initializers.zeros) if a
    canonical name could be determined. Otherwise, returns None (never a
    prefixed placeholder such as 'compat.v1.None').
  """
  if not hasattr(symbol, '__dict__'):
    return None
  api_names_attr = API_ATTRS[api_name].names
  _, undecorated_symbol = tf_decorator.unwrap(symbol)
  # Look in __dict__ rather than using hasattr so that a subclass does not
  # report API names it merely inherits from an exported base class.
  if api_names_attr not in undecorated_symbol.__dict__:
    return None
  api_names = getattr(undecorated_symbol, api_names_attr)
  deprecated_api_names = undecorated_symbol.__dict__.get(
      '_tf_deprecated_api_names', [])

  canonical_name = get_canonical_name(api_names, deprecated_api_names)
  if canonical_name:
    return canonical_name

  # If there is no V2 canonical name, get V1 canonical name. Tolerate a
  # missing V1 attribute instead of raising AttributeError.
  api_names_attr_v1 = API_ATTRS_V1[api_name].names
  api_names_v1 = getattr(undecorated_symbol, api_names_attr_v1, ())
  v1_canonical_name = get_canonical_name(api_names_v1, deprecated_api_names)
  # Only prefix an actual name; otherwise 'compat.v1.None' would be returned.
  if add_prefix_to_v1_names and v1_canonical_name:
    return 'compat.v1.%s' % v1_canonical_name
  return v1_canonical_name
140
141
def get_canonical_name(api_names, deprecated_api_names):
  """Get preferred endpoint name.

  Args:
    api_names: API names iterable. Any iterable (list, tuple, generator, ...)
      is accepted; None is treated as empty.
    deprecated_api_names: Deprecated API names iterable. Same conventions as
      `api_names`.
  Returns:
    Returns one of the following in decreasing preference:
    - first non-deprecated endpoint
    - first endpoint
    - None
  """
  # Materialize both inputs once. `api_names` is scanned and may then be
  # indexed for the fallback, and `deprecated_api_names` is tested for
  # membership once per name; a one-shot iterator would be exhausted (or
  # not indexable) in either case.
  api_names = tuple(api_names) if api_names is not None else ()
  deprecated_api_names = (
      tuple(deprecated_api_names) if deprecated_api_names is not None
      else ())

  non_deprecated_name = next(
      (name for name in api_names if name not in deprecated_api_names),
      None)
  if non_deprecated_name:
    return non_deprecated_name
  if api_names:
    return api_names[0]
  return None
162
163
def get_v1_names(symbol):
  """Get a list of TF 1.* names for this symbol.

  Only API-name attributes present in the symbol's own `__dict__` are
  consulted, so a subclass does not report names it merely inherits.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all V1 API names for this symbol: TensorFlow names first, then
    Estimator names, then Keras names. Empty if `symbol` has no `__dict__`.
  """
  if not hasattr(symbol, '__dict__'):
    return []
  own_attrs = symbol.__dict__
  collected = []
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME):
    names_attr = API_ATTRS_V1[api].names
    if names_attr in own_attrs:
      collected.extend(getattr(symbol, names_attr))
  return collected
188
189
def get_v2_names(symbol):
  """Get a list of TF 2.0 names for this symbol.

  Only API-name attributes present in the symbol's own `__dict__` are
  consulted, so a subclass does not report names it merely inherits.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all V2 API names for this symbol: TensorFlow names first, then
    Estimator names, then Keras names. Empty if `symbol` has no `__dict__`.
  """
  if not hasattr(symbol, '__dict__'):
    return []
  own_attrs = symbol.__dict__
  result = []
  for api_key in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME):
    attr_name = API_ATTRS[api_key].names
    if attr_name in own_attrs:
      result.extend(getattr(symbol, attr_name))
  return result
214
215
def get_v1_constants(module):
  """Get a list of TF 1.* constants in this module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow,
    Estimator and Keras constants, in that order. Entries are whatever was
    recorded on the module; api_export.export_constant appends
    `(api_names, constant_name)` tuples.
  """
  constants_v1 = []
  # Keras is included so constants exported with keras_export are reported,
  # consistent with get_v1_names. A fixed tuple (rather than iterating
  # API_ATTRS_V1) keeps the order deterministic on every Python version.
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME):
    constants_attr_v1 = API_ATTRS_V1[api].constants
    if hasattr(module, constants_attr_v1):
      constants_v1.extend(getattr(module, constants_attr_v1))
  return constants_v1
235
236
def get_v2_constants(module):
  """Get a list of TF 2.0 constants in this module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all API constants under the given module including TensorFlow,
    Estimator and Keras constants, in that order. Entries are whatever was
    recorded on the module; api_export.export_constant appends
    `(api_names, constant_name)` tuples.
  """
  constants_v2 = []
  # Keras is included so constants exported with keras_export are reported,
  # consistent with get_v2_names. A fixed tuple (rather than iterating
  # API_ATTRS) keeps the order deterministic on every Python version.
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME):
    constants_attr = API_ATTRS[api].constants
    if hasattr(module, constants_attr):
      constants_v2.extend(getattr(module, constants_attr))
  return constants_v2
256
257
class api_export(object):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow API."""

  def __init__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Export under the names *args (first one is considered canonical).

    Args:
      *args: API names in dot delimited format.
      **kwargs: Optional keyed arguments.
        v1: Names for the TensorFlow V1 API. If not set, we will use V2 API
          names both for TensorFlow V1 and V2 APIs.
        overrides: List of symbols that this is overriding
          (those overridden api exports will be removed). Note: passing
          overrides has no effect on exporting a constant.
        api_name: Name of the API you want to generate (e.g. `tensorflow` or
          `estimator`). Default is `tensorflow`.
        allow_multiple_exports: Allow symbol to be exported multiple times
          under different names. Each later export replaces the names
          recorded on the symbol's API-name attributes.

    Raises:
      ValueError: If a `v2` keyword argument is passed.
      InvalidSymbolNameError: If any name lies outside the packages allowed
        for `api_name`.
    """
    self._names = args
    self._names_v1 = kwargs.get('v1', args)
    if 'v2' in kwargs:
      raise ValueError('You passed a "v2" argument to tf_export. This is not '
                       'what you want. Pass v2 names directly as positional '
                       'arguments instead.')
    self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
    self._overrides = kwargs.get('overrides', [])
    self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)

    self._validate_symbol_names()

  def _validate_symbol_names(self):
    """Validate you are exporting symbols under an allowed package.

    We need to ensure things exported by tf_export, estimator_export, etc.
    export symbols under disjoint top-level package names.

    For TensorFlow, we check that it does not export anything under the
    subpackage names listed in SUBPACKAGE_NAMESPACES.

    For each component, we check that it exports everything under its own
    subpackage.

    A name is under a package when it equals the package name or continues it
    with a '.'; e.g. 'estimator_utils' is NOT under 'estimator'.

    Raises:
      InvalidSymbolNameError: If you try to export symbol under disallowed name.
    """
    def _is_under(name, package):
      # Compare whole dotted components, not raw string prefixes.
      return name == package or name.startswith(package + '.')

    all_symbol_names = set(self._names) | set(self._names_v1)
    if self._api_name == TENSORFLOW_API_NAME:
      for subpackage in SUBPACKAGE_NAMESPACES:
        if any(_is_under(n, subpackage) for n in all_symbol_names):
          raise InvalidSymbolNameError(
              '@tf_export is not allowed to export symbols under %s.*' % (
                  subpackage))
    elif not all(_is_under(n, self._api_name) for n in all_symbol_names):
      raise InvalidSymbolNameError(
          'Can only export symbols under package name of component. '
          'e.g. tensorflow_estimator must export all symbols under '
          'tf.estimator')

  def __call__(self, func):
    """Calls this decorator.

    Args:
      func: decorated symbol (function or class).

    Returns:
      `func` itself. The V2 and V1 API-name attributes for this API are set
      on its undecorated target, and every name (V1 names prefixed with
      'compat.v1.') is registered to map to `func`.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and kwarg `allow_multiple_exports` not set.
    """
    api_names_attr = API_ATTRS[self._api_name].names
    api_names_attr_v1 = API_ATTRS_V1[self._api_name].names
    # Undecorate overridden names. Only attributes the overridden symbol owns
    # are removed, so an override that was never exported under this API (or
    # only inherits the attributes) does not fail with AttributeError.
    for f in self._overrides:
      _, undecorated_f = tf_decorator.unwrap(f)
      own_attrs = getattr(undecorated_f, '__dict__', {})
      for attr in (api_names_attr, api_names_attr_v1):
        if attr in own_attrs:
          delattr(undecorated_f, attr)

    _, undecorated_func = tf_decorator.unwrap(func)
    self.set_attr(undecorated_func, api_names_attr, self._names)
    self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)

    for name in self._names:
      _NAME_TO_SYMBOL_MAPPING[name] = func
    for name_v1 in self._names_v1:
      _NAME_TO_SYMBOL_MAPPING['compat.v1.%s' % name_v1] = func
    return func

  def set_attr(self, func, api_names_attr, names):
    """Stores `names` on `func` as the attribute `api_names_attr`.

    Args:
      func: undecorated symbol (function or class) to annotate.
      api_names_attr: attribute name, one of the `names` entries of
        API_ATTRS or API_ATTRS_V1.
      names: API names to record.

    Raises:
      SymbolAlreadyExposedError: If `func` already has its own
        `api_names_attr` and multiple exports are not allowed.
    """
    # Check for an existing api. We check if attribute name is in
    # __dict__ instead of using hasattr to verify that subclasses have
    # their own _tf_api_names as opposed to just inheriting it.
    if api_names_attr in func.__dict__:
      if not self._allow_multiple_exports:
        # Callables such as functools.partial objects have no __name__; fall
        # back to the object itself so the intended error is still raised.
        raise SymbolAlreadyExposedError(
            'Symbol %s is already exposed as %s.' %
            (getattr(func, '__name__', func), getattr(func, api_names_attr)))
    # With allow_multiple_exports the new names replace the recorded ones.
    setattr(func, api_names_attr, names)

  def export_constant(self, module_name, name):
    """Store export information for constants/string literals.

    Export information is stored in the module where constants/string literals
    are defined: a `(names, name)` tuple is appended to the module's V2
    constants list and another (with the V1 names) to its V1 constants list,
    creating either list on first use.

    e.g.
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module to store constant at.
      name: (string) Current constant name.
    """
    module = sys.modules[module_name]
    api_constants_attr = API_ATTRS[self._api_name].constants
    api_constants_attr_v1 = API_ATTRS_V1[self._api_name].constants

    if not hasattr(module, api_constants_attr):
      setattr(module, api_constants_attr, [])
    getattr(module, api_constants_attr).append(
        (self._names, name))

    if not hasattr(module, api_constants_attr_v1):
      setattr(module, api_constants_attr_v1, [])
    getattr(module, api_constants_attr_v1).append(
        (self._names_v1, name))
392
393
def kwarg_only(f):
  """Wraps `f` so that it can only be called with keyword arguments.

  Args:
    f: function to wrap.

  Returns:
    A decorated version of `f` built by tf_decorator.make_decorator (which is
    handed `f`'s argspec). Calling it with any positional argument raises
    TypeError; otherwise only the keyword arguments are forwarded to `f`.
  """
  argspec = tf_inspect.getargspec(f)

  def wrapper(*args, **kwargs):
    if not args:
      return f(**kwargs)
    raise TypeError(
        '{f} only takes keyword args (possible keys: {kwargs}). '
        'Please pass these args as kwargs instead.'
        .format(f=f.__name__, kwargs=argspec.args))

  return tf_decorator.make_decorator(
      f, wrapper, decorator_argspec=argspec)
407
408
# Per-API decorator factories: api_export with `api_name` pre-bound, so
# `tf_export('foo')` is equivalent to
# `api_export('foo', api_name=TENSORFLOW_API_NAME)`.
tf_export = functools.partial(
    api_export, api_name=TENSORFLOW_API_NAME)
estimator_export = functools.partial(
    api_export, api_name=ESTIMATOR_API_NAME)
keras_export = functools.partial(
    api_export, api_name=KERAS_API_NAME)
412