1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for exporting TensorFlow symbols to the API.
16
17Exporting a function or a class:
18
To export a function or a class, use the tf_export decorator. For example:
20```python
21@tf_export('foo', 'bar.foo')
22def foo(...):
23  ...
24```
25
If a function is assigned to a variable, you can export it by calling
tf_export explicitly. For example:
28```python
29foo = get_foo(...)
30tf_export('foo', 'bar.foo')(foo)
31```
32
33
Exporting a constant:
35```python
36foo = 1
37tf_export('consts.foo').export_constant(__name__, 'foo')
38```
39"""
40from __future__ import absolute_import
41from __future__ import division
42from __future__ import print_function
43
44import collections
45import functools
46import sys
47
48from tensorflow.python.util import tf_decorator
49from tensorflow.python.util import tf_inspect
50
# Names of the APIs that symbols can be exported to.
TENSORFLOW_API_NAME = 'tensorflow'
ESTIMATOR_API_NAME = 'estimator'
KERAS_API_NAME = 'keras'

# Subpackage names reserved for TensorFlow components. The TensorFlow core
# repo must not export any symbols under these names.
SUBPACKAGE_NAMESPACES = [ESTIMATOR_API_NAME]

# Pair of attribute names used to record export information:
#   names: attribute that stores the API names of an exported symbol.
#   constants: attribute that stores the constants exported by a module.
_Attributes = collections.namedtuple(
    'ExportedApiAttributes', ('names', 'constants'))

# Attribute names used for each API. Attribute values must be unique to each
# API.
API_ATTRS = {
    TENSORFLOW_API_NAME: _Attributes(
        names='_tf_api_names',
        constants='_tf_api_constants'),
    ESTIMATOR_API_NAME: _Attributes(
        names='_estimator_api_names',
        constants='_estimator_api_constants'),
    KERAS_API_NAME: _Attributes(
        names='_keras_api_names',
        constants='_keras_api_constants'),
}

# Attribute names used for the V1 flavor of each API.
API_ATTRS_V1 = {
    TENSORFLOW_API_NAME: _Attributes(
        names='_tf_api_names_v1',
        constants='_tf_api_constants_v1'),
    ESTIMATOR_API_NAME: _Attributes(
        names='_estimator_api_names_v1',
        constants='_estimator_api_constants_v1'),
    KERAS_API_NAME: _Attributes(
        names='_keras_api_names_v1',
        constants='_keras_api_constants_v1'),
}
86
87
class SymbolAlreadyExposedError(Exception):
  """Error raised when API names are added to a symbol that already has some."""
91
92
class InvalidSymbolNameError(Exception):
  """Error raised when a symbol is exported under an invalid or disallowed name."""
96
97
def get_canonical_name_for_symbol(
    symbol, api_name=TENSORFLOW_API_NAME,
    add_prefix_to_v1_names=False):
  """Get canonical name for the API symbol.

  The V2 API names of the symbol are consulted first; the V1 API names are
  only used when no V2 canonical name exists.

  Args:
    symbol: API function or class.
    api_name: API name; must be a key of `API_ATTRS` (tensorflow, estimator
      or keras).
    add_prefix_to_v1_names: Specifies whether a name available only in V1
      should be prefixed with compat.v1.

  Returns:
    Canonical name for the API symbol (for e.g. initializers.zeros) if
    canonical name could be determined. Otherwise, returns None.
  """
  if not hasattr(symbol, '__dict__'):
    return None
  api_names_attr = API_ATTRS[api_name].names
  _, undecorated_symbol = tf_decorator.unwrap(symbol)
  # Check __dict__ directly (instead of hasattr) so that only names stored on
  # the symbol itself count, not names inherited from a base class.
  if api_names_attr not in undecorated_symbol.__dict__:
    return None
  api_names = getattr(undecorated_symbol, api_names_attr)
  deprecated_api_names = undecorated_symbol.__dict__.get(
      '_tf_deprecated_api_names', [])

  canonical_name = get_canonical_name(api_names, deprecated_api_names)
  if canonical_name:
    return canonical_name

  # If there is no V2 canonical name, get V1 canonical name.
  api_names_attr_v1 = API_ATTRS_V1[api_name].names
  api_names_v1 = getattr(undecorated_symbol, api_names_attr_v1)
  v1_canonical_name = get_canonical_name(api_names_v1, deprecated_api_names)
  if v1_canonical_name is None:
    # No name in either API version. Return None as documented rather than
    # formatting None into the bogus string 'compat.v1.None'.
    return None
  if add_prefix_to_v1_names:
    return 'compat.v1.%s' % v1_canonical_name
  return v1_canonical_name
134
135
def get_canonical_name(api_names, deprecated_api_names):
  """Get preferred endpoint name.

  Args:
    api_names: API names iterable. Any iterable is accepted, including
      one-shot iterators and generators.
    deprecated_api_names: Deprecated API names iterable; it is used for
      membership tests (`name in deprecated_api_names`).
  Returns:
    Returns one of the following in decreasing preference:
    - first non-deprecated endpoint
    - first endpoint
    - None
  """
  # Materialize the names once: the search below would otherwise exhaust a
  # one-shot iterator, and the fallback needs indexing, which plain iterables
  # (iterators, sets) do not support.
  api_names = tuple(api_names)
  non_deprecated_name = next(
      (name for name in api_names if name not in deprecated_api_names),
      None)
  if non_deprecated_name:
    return non_deprecated_name
  if api_names:
    return api_names[0]
  return None
156
157
def get_v1_names(symbol):
  """Collect the TF 1.* API names stored on a symbol.

  Only names stored in the symbol's own `__dict__` are considered.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all V1 API names for this symbol: TensorFlow names first, then
    Estimator names, then Keras names. The list is empty when the symbol has
    no `__dict__` or carries no V1 names.
  """
  v1_attrs = [
      API_ATTRS_V1[api].names
      for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME)]

  collected = []
  if hasattr(symbol, '__dict__'):
    for attr in v1_attrs:
      if attr in symbol.__dict__:
        collected.extend(getattr(symbol, attr))
  return collected
182
183
def get_v2_names(symbol):
  """Collect the TF 2.0 API names stored on a symbol.

  Only names stored in the symbol's own `__dict__` are considered.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of all V2 API names for this symbol: TensorFlow names first, then
    Estimator names, then Keras names. The list is empty when the symbol has
    no `__dict__` or carries no V2 names.
  """
  v2_attrs = [
      API_ATTRS[api].names
      for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME)]

  collected = []
  if hasattr(symbol, '__dict__'):
    for attr in v2_attrs:
      if attr in symbol.__dict__:
        collected.extend(getattr(symbol, attr))
  return collected
208
209
def get_v1_constants(module):
  """Collect the TF 1.* constants registered on a module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all V1 API constants stored on the given module: TensorFlow
    constants first, then Estimator constants. Empty when the module has
    neither attribute.
  """
  # NOTE(review): the Keras constants attribute is not consulted here, unlike
  # in get_v1_names -- confirm whether that is intentional.
  constants_attrs = (
      API_ATTRS_V1[TENSORFLOW_API_NAME].constants,
      API_ATTRS_V1[ESTIMATOR_API_NAME].constants)

  found = []
  for attr in constants_attrs:
    if hasattr(module, attr):
      found.extend(getattr(module, attr))
  return found
229
230
def get_v2_constants(module):
  """Collect the TF 2.0 constants registered on a module.

  Args:
    module: TensorFlow module.

  Returns:
    List of all V2 API constants stored on the given module: TensorFlow
    constants first, then Estimator constants. Empty when the module has
    neither attribute.
  """
  # NOTE(review): the Keras constants attribute is not consulted here, unlike
  # in get_v2_names -- confirm whether that is intentional.
  constants_attrs = (
      API_ATTRS[TENSORFLOW_API_NAME].constants,
      API_ATTRS[ESTIMATOR_API_NAME].constants)

  found = []
  for attr in constants_attrs:
    if hasattr(module, attr):
      found.extend(getattr(module, attr))
  return found
250
251
class api_export(object):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow API.

  An instance is either used as a decorator on a function or class (see
  `__call__`) or used to register a module-level constant (see
  `export_constant`).
  """

  def __init__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Export under the names *args (first one is considered canonical).

    Args:
      *args: API names in dot delimited format.
      **kwargs: Optional keyed arguments.
        v1: Names for the TensorFlow V1 API. If not set, we will use V2 API
          names both for TensorFlow V1 and V2 APIs.
        overrides: List of symbols that this is overriding
          (those overrided api exports will be removed). Note: passing overrides
          has no effect on exporting a constant.
        api_name: Name of the API you want to generate (e.g. `tensorflow` or
          `estimator`). Default is `tensorflow`.
        allow_multiple_exports: Allow symbol to be exported multiple time under
          different names.

    Raises:
      ValueError: If a `v2` keyword argument is passed.
      InvalidSymbolNameError: If a name is not allowed for the selected API.
    """
    self._names = args
    self._names_v1 = kwargs.get('v1', args)
    if 'v2' in kwargs:
      raise ValueError('You passed a "v2" argument to tf_export. This is not '
                       'what you want. Pass v2 names directly as positional '
                       'arguments instead.')
    self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
    self._overrides = kwargs.get('overrides', [])
    self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)

    self._validate_symbol_names()

  @staticmethod
  def _is_under_package(name, package):
    """Returns True if dotted `name` is `package` itself or lies below it.

    Matching is done on whole package components, so `estimator.foo` is under
    `estimator` while `estimator_foo.bar` is not.
    """
    return name == package or name.startswith(package + '.')

  def _validate_symbol_names(self):
    """Validate you are exporting symbols under an allowed package.

    We need to ensure things exported by tf_export, estimator_export, etc.
    export symbols under disjoint top-level package names.

    For TensorFlow, we check that it does not export anything under subpackage
    names used by components (estimator, keras, etc.).

    For each component, we check that it exports everything under its own
    subpackage.

    Raises:
      InvalidSymbolNameError: If you try to export symbol under disallowed name.
    """
    all_symbol_names = set(self._names) | set(self._names_v1)
    if self._api_name == TENSORFLOW_API_NAME:
      for subpackage in SUBPACKAGE_NAMESPACES:
        # Compare package components rather than raw string prefixes so that
        # unrelated names that merely share a prefix are not rejected.
        if any(self._is_under_package(n, subpackage)
               for n in all_symbol_names):
          raise InvalidSymbolNameError(
              '@tf_export is not allowed to export symbols under %s.*' % (
                  subpackage))
    else:
      if not all(self._is_under_package(n, self._api_name)
                 for n in all_symbol_names):
        raise InvalidSymbolNameError(
            'Can only export symbols under package name of component. '
            'e.g. tensorflow_estimator must export all symbols under '
            'tf.estimator')

  def __call__(self, func):
    """Calls this decorator.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input function with _tf_api_names attribute set.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and kwarg `allow_multiple_exports` not set.
    """
    api_names_attr = API_ATTRS[self._api_name].names
    api_names_attr_v1 = API_ATTRS_V1[self._api_name].names
    # Undecorate overridden names. delattr raises AttributeError when an
    # overridden symbol does not carry the attribute itself.
    for f in self._overrides:
      _, undecorated_f = tf_decorator.unwrap(f)
      delattr(undecorated_f, api_names_attr)
      delattr(undecorated_f, api_names_attr_v1)

    # The names are stored on the innermost (undecorated) symbol, while the
    # symbol that was passed in is returned unchanged.
    _, undecorated_func = tf_decorator.unwrap(func)
    self.set_attr(undecorated_func, api_names_attr, self._names)
    self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)
    return func

  def set_attr(self, func, api_names_attr, names):
    """Stores API names on a symbol.

    Args:
      func: symbol (function or class) to store the names on.
      api_names_attr: (string) Name of the attribute to store the names in.
      names: API names to store.

    Raises:
      SymbolAlreadyExposedError: Raised when the symbol already has the
        attribute in its own `__dict__` and `allow_multiple_exports` is not
        set.
    """
    # Check for an existing api. We check if attribute name is in
    # __dict__ instead of using hasattr to verify that subclasses have
    # their own _tf_api_names as opposed to just inheriting it.
    if api_names_attr in func.__dict__:
      if not self._allow_multiple_exports:
        raise SymbolAlreadyExposedError(
            'Symbol %s is already exposed as %s.' %
            (func.__name__, getattr(func, api_names_attr)))  # pylint: disable=protected-access
    setattr(func, api_names_attr, names)

  def export_constant(self, module_name, name):
    """Store export information for constants/string literals.

    Export information is stored in the module where constants/string literals
    are defined.

    e.g.
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module to store constant at.
      name: (string) Current constant name.
    """
    module = sys.modules[module_name]
    api_constants_attr = API_ATTRS[self._api_name].constants
    api_constants_attr_v1 = API_ATTRS_V1[self._api_name].constants

    # Each entry is an (API names, constant name) pair; the list is created on
    # the module the first time a constant is exported from it.
    if not hasattr(module, api_constants_attr):
      setattr(module, api_constants_attr, [])
    # pylint: disable=protected-access
    getattr(module, api_constants_attr).append(
        (self._names, name))

    if not hasattr(module, api_constants_attr_v1):
      setattr(module, api_constants_attr_v1, [])
    getattr(module, api_constants_attr_v1).append(
        (self._names_v1, name))
381
382
def kwarg_only(f):
  """Wraps `f` so that it only accepts keyword arguments.

  Args:
    f: function to wrap.

  Returns:
    The result of `tf_decorator.make_decorator` applied to `f` and a wrapper
    that raises TypeError when called with any positional argument and
    otherwise forwards its keyword arguments to `f`.
  """
  f_argspec = tf_inspect.getargspec(f)

  def wrapper(*args, **kwargs):
    if not args:
      return f(**kwargs)
    # Positional arguments are rejected; the message lists the argument names
    # of `f` so that callers know which keys are available.
    raise TypeError(
        '{f} only takes keyword args (possible keys: {kwargs}). '
        'Please pass these args as kwargs instead.'
        .format(f=f.__name__, kwargs=f_argspec.args))

  return tf_decorator.make_decorator(f, wrapper, decorator_argspec=f_argspec)
396
397
# Export helpers with `api_name` pre-bound for each API: calling one of them
# with API names constructs an `api_export` for the corresponding API.
tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
estimator_export = functools.partial(api_export, api_name=ESTIMATOR_API_NAME)
keras_export = functools.partial(api_export, api_name=KERAS_API_NAME)
401