1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Compatibility functions.
16
17The `tf.compat` module contains two sets of compatibility functions.
18
## TensorFlow 1.x and 2.x APIs
20
21The `compat.v1` and `compat.v2` submodules provide a complete copy of both the
22`v1` and `v2` APIs for backwards and forwards compatibility across TensorFlow
23versions 1.x and 2.x. See the
24[migration guide](https://www.tensorflow.org/guide/migrate) for details.
25
26## Utilities for writing compatible code
27
28Aside from the `compat.v1` and `compat.v2` submodules, `tf.compat` also contains
29a set of helper functions for writing code that works in both:
30
31* TensorFlow 1.x and 2.x
32* Python 2 and 3
33
34
35## Type collections
36
The compatibility module also provides the following aliases for common
sets of Python types:
39
40* `bytes_or_text_types`
41* `complex_types`
42* `integral_types`
43* `real_types`
44"""
45
46from __future__ import absolute_import
47from __future__ import division
48from __future__ import print_function
49
50import numbers as _numbers
51
52import numpy as _np
53import six as _six
54
55from tensorflow.python.util.tf_export import tf_export
56
57try:
58  # This import only works on python 3.3 and above.
59  import collections.abc as collections_abc  # pylint: disable=unused-import
60except ImportError:
61  import collections as collections_abc  # pylint: disable=unused-import
62
63
def as_bytes(bytes_or_text, encoding='utf-8'):
  """Returns `bytes_or_text` as a `bytes` object.

  How the input is converted depends on its type:

  * `bytearray`: copied into a new, immutable `bytes` object.
  * unicode text (`unicode` on Python 2, `str` on Python 3): encoded with
    `encoding`, which defaults to utf-8.
  * `bytes`: returned unchanged (the very same object comes back).

  Args:
    bytes_or_text: A `bytearray`, `bytes`, `str`, or `unicode` object.
    encoding: Name of the charset used to encode unicode input. It is not
      used for binary inputs.

  Returns:
    A `bytes` object.

  Raises:
    TypeError: If `bytes_or_text` is not a binary or unicode string.
  """
  if isinstance(bytes_or_text, bytearray):
    return bytes(bytes_or_text)
  if isinstance(bytes_or_text, _six.text_type):
    return bytes_or_text.encode(encoding)
  if isinstance(bytes_or_text, bytes):
    return bytes_or_text
  # The offending value is wrapped in a 1-tuple so that `%` never unpacks it,
  # even when the caller passed a tuple.
  raise TypeError('Expected binary or unicode string, got %r' %
                  (bytes_or_text,))
88
89
def as_text(bytes_or_text, encoding='utf-8'):
  """Converts any string-like python input types to unicode.

  Returns the input as a unicode string. Uses utf-8 encoding for text
  by default.

  Args:
    bytes_or_text: A `bytes`, `str`, or `unicode` object.
    encoding: A string indicating the charset for decoding unicode.

  Returns:
    A `unicode` (Python 2) or `str` (Python 3) object. Text input is returned
    unchanged; binary input is decoded with `encoding`.

  Raises:
    TypeError: If `bytes_or_text` is not a binary or unicode string.
    UnicodeDecodeError: If `bytes_or_text` is binary and is not valid in
      `encoding`.
    LookupError: If `bytes_or_text` is binary and `encoding` is not a known
      codec.
  """
  if isinstance(bytes_or_text, _six.text_type):
    return bytes_or_text
  elif isinstance(bytes_or_text, bytes):
    return bytes_or_text.decode(encoding)
  else:
    # Wrap the value in a 1-tuple so `%` treats it as a single argument. A
    # bare tuple input would otherwise be unpacked, and the formatting itself
    # would fail with an unrelated TypeError instead of this message.
    raise TypeError('Expected binary or unicode string, got %r' %
                    (bytes_or_text,))
112
113
def as_str(bytes_or_text, encoding='utf-8'):
  """Converts `bytes_or_text` to the interpreter's native `str` type.

  On Python 2 `str` is a byte string, so the input goes through `as_bytes`.
  On Python 3 `str` is unicode text, so the input goes through `as_text`.

  Args:
    bytes_or_text: A `bytes`, `str`, or `unicode` object.
    encoding: Charset used when converting between text and bytes.

  Returns:
    A `str` object.

  Raises:
    TypeError: If `bytes_or_text` is not a type accepted by the underlying
      conversion function.
  """
  converter = as_bytes if _six.PY2 else as_text
  return converter(bytes_or_text, encoding)
119
# Export the string-conversion helpers under their `compat.*` API names. These
# are registered by calling `tf_export` directly on the functions defined
# above rather than by decorating the definitions.
tf_export('compat.as_text')(as_text)
tf_export('compat.as_bytes')(as_bytes)
tf_export('compat.as_str')(as_str)
123
124
@tf_export('compat.as_str_any')
def as_str_any(value):
  """Converts `value` to a `str`.

  `bytes` inputs are converted with `as_str`, using its default utf-8
  encoding. Every other input is converted with the builtin `str(value)`.

  Args:
    value: An object that can be converted to `str`.

  Returns:
    A `str` object.
  """
  return as_str(value) if isinstance(value, bytes) else str(value)
142
143
@tf_export('compat.path_to_str')
def path_to_str(path):
  r"""Converts a `PathLike` object to `str`, passing other input through.

  If `path` implements the `os.PathLike` protocol (it has an `__fspath__`
  method), the value returned by `path.__fspath__()` is converted with
  `as_str_any`. Any other input, including a plain string, is returned
  unchanged.

  Args:
    path: An object that can be converted to path representation.

  Returns:
    A `str` object when `path` is `PathLike`; otherwise `path` itself.

  Usage:
    When a plain `str` version of an `os.PathLike` path is needed.

  Examples:
  ```python
  $ tf.compat.path_to_str('C:\XYZ\tensorflow\./.././tensorflow')
  'C:\XYZ\tensorflow\./.././tensorflow' # Windows OS
  $ tf.compat.path_to_str(Path('C:\XYZ\tensorflow\./.././tensorflow'))
  'C:\XYZ\tensorflow\..\tensorflow' # Windows OS
  $ tf.compat.path_to_str(Path('./corpus'))
  'corpus' # Linux OS
  $ tf.compat.path_to_str('./.././Corpus')
  './.././Corpus' # Linux OS
  $ tf.compat.path_to_str(Path('./.././Corpus'))
  '../Corpus' # Linux OS
  $ tf.compat.path_to_str(Path('./..////../'))
  '../..' # Linux OS

  ```
  """
  if not hasattr(path, '__fspath__'):
    return path
  return as_str_any(path.__fspath__())
181
182
def path_to_bytes(path):
  """Converts a `PathLike` object, or a string, to `bytes`.

  If `path` has an `__fspath__` method (the `os.PathLike` protocol), its
  return value is converted. Otherwise `path` is converted directly. The
  conversion is done by `as_bytes` with its default utf-8 encoding.

  Args:
    path: An object that can be converted to path representation.

  Returns:
    A `bytes` object.

  Raises:
    TypeError: If the value being converted is not a binary or unicode
      string.

  Usage:
    When a `bytes` version of an `os.PathLike` path is needed.
  """
  fs_path = path.__fspath__() if hasattr(path, '__fspath__') else path
  return as_bytes(fs_path)
202
203
# Numpy 1.8 scalars don't inherit from numbers.Integral in Python 3, so we
# need to check them specifically.  The same goes for Real and Complex.

# Types to treat as integers: the `numbers.Integral` ABC plus NumPy integer
# scalars. Exported as the `compat.integral_types` constant.
integral_types = (_numbers.Integral, _np.integer)
tf_export('compat.integral_types').export_constant(__name__, 'integral_types')
# Types to treat as real numbers: `numbers.Real` plus NumPy integer and
# floating scalars.
real_types = (_numbers.Real, _np.integer, _np.floating)
tf_export('compat.real_types').export_constant(__name__, 'real_types')
# Types to treat as complex numbers: `numbers.Complex` plus every NumPy
# numeric scalar (`np.number`).
complex_types = (_numbers.Complex, _np.number)
tf_export('compat.complex_types').export_constant(__name__, 'complex_types')

# Either bytes or text.
# `_six.text_type` is `unicode` on Python 2 and `str` on Python 3.
bytes_or_text_types = (bytes, _six.text_type)
tf_export('compat.bytes_or_text_types').export_constant(__name__,
                                                        'bytes_or_text_types')
217