1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library of dtypes (Tensor element types)."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21from six.moves import builtins
22
23from tensorflow.core.framework import types_pb2
24# We need to import pywrap_tensorflow prior to the bfloat wrapper to avoid
25# protobuf errors where a file is defined twice on MacOS.
26# pylint: disable=invalid-import-order,g-bad-import-order
27from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
28from tensorflow.python import _dtypes
29from tensorflow.python.lib.core import _pywrap_bfloat16
30from tensorflow.python.util.tf_export import tf_export
31
# bfloat16 type object obtained from the `_pywrap_bfloat16` extension module.
# It is used below as the NumPy-side counterpart of `bfloat16` (as a key of
# `_NP_TO_TF` and a value of `_TF_TO_NP`) and is called like a scalar
# constructor in `DType.min` / `DType.max`.
_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
33
34
35# pylint: disable=slots-on-old-class
@tf_export("dtypes.DType", "DType")
class DType(_dtypes.DType):
  """Represents the type of the elements in a `Tensor`.

  The following `DType` objects are defined:

  * `tf.float16`: 16-bit half-precision floating-point.
  * `tf.float32`: 32-bit single-precision floating-point.
  * `tf.float64`: 64-bit double-precision floating-point.
  * `tf.bfloat16`: 16-bit truncated floating-point.
  * `tf.complex64`: 64-bit single-precision complex.
  * `tf.complex128`: 128-bit double-precision complex.
  * `tf.int8`: 8-bit signed integer.
  * `tf.uint8`: 8-bit unsigned integer.
  * `tf.uint16`: 16-bit unsigned integer.
  * `tf.uint32`: 32-bit unsigned integer.
  * `tf.uint64`: 64-bit unsigned integer.
  * `tf.int16`: 16-bit signed integer.
  * `tf.int32`: 32-bit signed integer.
  * `tf.int64`: 64-bit signed integer.
  * `tf.bool`: Boolean.
  * `tf.string`: String.
  * `tf.qint8`: Quantized 8-bit signed integer.
  * `tf.quint8`: Quantized 8-bit unsigned integer.
  * `tf.qint16`: Quantized 16-bit signed integer.
  * `tf.quint16`: Quantized 16-bit unsigned integer.
  * `tf.qint32`: Quantized 32-bit signed integer.
  * `tf.resource`: Handle to a mutable resource.
  * `tf.variant`: Values of arbitrary types.

  The `tf.as_dtype()` function converts numpy types and string type
  names to a `DType` object.
  """
  __slots__ = ()

  @property
  def _is_ref_dtype(self):
    """Returns `True` if this `DType` represents a reference type."""
    # Reference enum values are the base enum value plus 100 (see `_as_ref`).
    return self._type_enum > 100

  @property
  def _as_ref(self):
    """Returns a reference `DType` based on this `DType`."""
    if self._is_ref_dtype:
      return self
    else:
      return _INTERN_TABLE[self._type_enum + 100]

  @property
  def base_dtype(self):
    """Returns a non-reference `DType` based on this `DType`."""
    if self._is_ref_dtype:
      return _INTERN_TABLE[self._type_enum - 100]
    else:
      return self

  @property
  def real_dtype(self):
    """Returns the `DType` corresponding to this `DType`'s real part.

    Complex dtypes (including their reference variants) map to the matching
    float dtype; every other dtype is returned unchanged.
    """
    base = self.base_dtype
    if base == complex64:
      return float32
    elif base == complex128:
      return float64
    else:
      return self

  @property
  def as_numpy_dtype(self):
    """Returns a Python `type` object based on this `DType`.

    Raises:
      KeyError: if this dtype has no entry in `_TF_TO_NP`.
    """
    return _TF_TO_NP[self._type_enum]

  @property
  def min(self):
    """Returns the minimum representable value in this data type.

    Raises:
      TypeError: if this is a non-numeric, unordered, or quantized type.
    """
    if (self.is_quantized or
        self.base_dtype in (bool, string, complex64, complex128)):
      raise TypeError("Cannot find minimum value of %s." % self)

    # NumPy has no single call that covers both families, so try the
    # floating-point info first and fall back to the integer info.
    try:
      return np.finfo(self.as_numpy_dtype).min
    except Exception:  # pylint: disable=broad-except
      # The exceptions finfo may raise are not documented, so catch broadly,
      # but not BaseException: KeyboardInterrupt / SystemExit must propagate.
      try:
        return np.iinfo(self.as_numpy_dtype).min
      except Exception:  # pylint: disable=broad-except
        # Neither finfo nor iinfo handled the type; bfloat16 gets a
        # hard-coded limit here.
        if self.base_dtype == bfloat16:
          return _np_bfloat16(float.fromhex("-0x1.FEp127"))
        raise TypeError("Cannot find minimum value of %s." % self)

  @property
  def max(self):
    """Returns the maximum representable value in this data type.

    Raises:
      TypeError: if this is a non-numeric, unordered, or quantized type.
    """
    if (self.is_quantized or
        self.base_dtype in (bool, string, complex64, complex128)):
      raise TypeError("Cannot find maximum value of %s." % self)

    # NumPy has no single call that covers both families, so try the
    # floating-point info first and fall back to the integer info.
    try:
      return np.finfo(self.as_numpy_dtype).max
    except Exception:  # pylint: disable=broad-except
      # The exceptions finfo may raise are not documented, so catch broadly,
      # but not BaseException: KeyboardInterrupt / SystemExit must propagate.
      try:
        return np.iinfo(self.as_numpy_dtype).max
      except Exception:  # pylint: disable=broad-except
        # Neither finfo nor iinfo handled the type; bfloat16 gets a
        # hard-coded limit here.
        if self.base_dtype == bfloat16:
          return _np_bfloat16(float.fromhex("0x1.FEp127"))
        raise TypeError("Cannot find maximum value of %s." % self)

  @property
  def limits(self, clip_negative=True):
    """Returns the intensity limits, i.e. the `(min, max)` tuple, of the dtype.

    Because this is exposed as a property, it is always evaluated with the
    default `clip_negative=True`: attribute access cannot pass an argument.

    Args:
      clip_negative: bool, optional. If True, clip the negative range (i.e.
        return 0 for min intensity) even if the image dtype allows negative
        values.

    Returns:
      min, max: tuple. Lower and upper intensity limits.

    Raises:
      KeyError: if this dtype has no NumPy type, or that NumPy type has no
        entry in `dtype_range`.
    """
    lower, upper = dtype_range[self.as_numpy_dtype]
    if clip_negative:
      lower = 0
    return lower, upper

  def is_compatible_with(self, other):
    """Returns True if the `other` DType will be converted to this DType.

    The conversion rules are as follows:

    ```python
    DType(T)       .is_compatible_with(DType(T))        == True
    ```

    Args:
      other: A `DType` (or object that may be converted to a `DType`).

    Returns:
      True if a Tensor of the `other` `DType` will be implicitly converted to
      this `DType`.
    """
    other = as_dtype(other)
    # Matches either `other` itself or, when `other` is a reference type, its
    # base type.
    return self._type_enum in (other.as_datatype_enum,
                               other.base_dtype.as_datatype_enum)

  def __eq__(self, other):
    """Returns True iff this DType refers to the same type as `other`."""
    if other is None:
      return False

    # Exact type check on purpose: anything that is not precisely a `DType`
    # is first run through `as_dtype`; values that cannot be converted
    # compare unequal instead of raising.
    if type(other) != DType:  # pylint: disable=unidiomatic-typecheck
      try:
        other = as_dtype(other)
      except TypeError:
        return False

    return self._type_enum == other._type_enum  # pylint: disable=protected-access

  def __ne__(self, other):
    """Returns True iff self != other."""
    return not self.__eq__(other)

  # "If a class that overrides __eq__() needs to retain the implementation
  #  of __hash__() from a parent class, the interpreter must be told this
  #  explicitly by setting __hash__ = <ParentClass>.__hash__."
  # TODO(slebedev): Remove once __eq__ and __ne__ are implemented in C++.
  __hash__ = _dtypes.DType.__hash__

  def __reduce__(self):
    """Supports pickling: rebuilds the DType by calling `as_dtype` on its name."""
    return as_dtype, (self.name,)
217# pylint: enable=slots-on-old-class
218
219
# Intensity range, as a (min, max) pair, for each supported numpy scalar type.
# Integer types span their full representable range; float types use the
# conventional [-1, 1] range. Consumed by `DType.limits`.
#
# `np.bool8` is intentionally not listed: it is only an alias of `np.bool_`
# (so it never added a distinct key) and no longer exists in NumPy 2.0, where
# referencing it would make this module fail to import.
dtype_range = {
    np.bool_: (False, True),
    np.uint8: (0, 255),
    np.uint16: (0, 65535),
    np.int8: (-128, 127),
    np.int16: (-32768, 32767),
    np.int64: (-2**63, 2**63 - 1),
    np.uint64: (0, 2**64 - 1),
    np.int32: (-2**31, 2**31 - 1),
    np.uint32: (0, 2**32 - 1),
    np.float32: (-1, 1),
    np.float64: (-1, 1)
}
235
# Define standard wrappers for the types_pb2.DataType enum.
#
# Each public constant below is a `DType` built from one enum value and is
# then registered with `tf_export(...).export_constant` under the API names
# given in that call.
resource = DType(types_pb2.DT_RESOURCE)
tf_export("dtypes.resource", "resource").export_constant(__name__, "resource")
variant = DType(types_pb2.DT_VARIANT)
tf_export("dtypes.variant", "variant").export_constant(__name__, "variant")
float16 = DType(types_pb2.DT_HALF)
tf_export("dtypes.float16", "float16").export_constant(__name__, "float16")
# `half` is an alias: it is the same object as `float16`.
half = float16
tf_export("dtypes.half", "half").export_constant(__name__, "half")
float32 = DType(types_pb2.DT_FLOAT)
tf_export("dtypes.float32", "float32").export_constant(__name__, "float32")
float64 = DType(types_pb2.DT_DOUBLE)
tf_export("dtypes.float64", "float64").export_constant(__name__, "float64")
# `double` is an alias: it is the same object as `float64`.
double = float64
tf_export("dtypes.double", "double").export_constant(__name__, "double")
int32 = DType(types_pb2.DT_INT32)
tf_export("dtypes.int32", "int32").export_constant(__name__, "int32")
uint8 = DType(types_pb2.DT_UINT8)
tf_export("dtypes.uint8", "uint8").export_constant(__name__, "uint8")
uint16 = DType(types_pb2.DT_UINT16)
tf_export("dtypes.uint16", "uint16").export_constant(__name__, "uint16")
uint32 = DType(types_pb2.DT_UINT32)
tf_export("dtypes.uint32", "uint32").export_constant(__name__, "uint32")
uint64 = DType(types_pb2.DT_UINT64)
tf_export("dtypes.uint64", "uint64").export_constant(__name__, "uint64")
int16 = DType(types_pb2.DT_INT16)
tf_export("dtypes.int16", "int16").export_constant(__name__, "int16")
int8 = DType(types_pb2.DT_INT8)
tf_export("dtypes.int8", "int8").export_constant(__name__, "int8")
string = DType(types_pb2.DT_STRING)
tf_export("dtypes.string", "string").export_constant(__name__, "string")
complex64 = DType(types_pb2.DT_COMPLEX64)
tf_export("dtypes.complex64",
          "complex64").export_constant(__name__, "complex64")
complex128 = DType(types_pb2.DT_COMPLEX128)
tf_export("dtypes.complex128",
          "complex128").export_constant(__name__, "complex128")
int64 = DType(types_pb2.DT_INT64)
tf_export("dtypes.int64", "int64").export_constant(__name__, "int64")
# NOTE: from this line on, the module-level name `bool` is this DType and no
# longer the Python builtin; the builtin stays reachable as `builtins.bool`.
bool = DType(types_pb2.DT_BOOL)  # pylint: disable=redefined-builtin
tf_export("dtypes.bool", "bool").export_constant(__name__, "bool")
qint8 = DType(types_pb2.DT_QINT8)
tf_export("dtypes.qint8", "qint8").export_constant(__name__, "qint8")
quint8 = DType(types_pb2.DT_QUINT8)
tf_export("dtypes.quint8", "quint8").export_constant(__name__, "quint8")
qint16 = DType(types_pb2.DT_QINT16)
tf_export("dtypes.qint16", "qint16").export_constant(__name__, "qint16")
quint16 = DType(types_pb2.DT_QUINT16)
tf_export("dtypes.quint16", "quint16").export_constant(__name__, "quint16")
qint32 = DType(types_pb2.DT_QINT32)
tf_export("dtypes.qint32", "qint32").export_constant(__name__, "qint32")
# Reference ("*_ref") dtypes are defined as module attributes only; none of
# them is passed to `tf_export`.
resource_ref = DType(types_pb2.DT_RESOURCE_REF)
variant_ref = DType(types_pb2.DT_VARIANT_REF)
bfloat16 = DType(types_pb2.DT_BFLOAT16)
tf_export("dtypes.bfloat16", "bfloat16").export_constant(__name__, "bfloat16")
float16_ref = DType(types_pb2.DT_HALF_REF)
half_ref = float16_ref
float32_ref = DType(types_pb2.DT_FLOAT_REF)
float64_ref = DType(types_pb2.DT_DOUBLE_REF)
double_ref = float64_ref
int32_ref = DType(types_pb2.DT_INT32_REF)
uint32_ref = DType(types_pb2.DT_UINT32_REF)
uint8_ref = DType(types_pb2.DT_UINT8_REF)
uint16_ref = DType(types_pb2.DT_UINT16_REF)
int16_ref = DType(types_pb2.DT_INT16_REF)
int8_ref = DType(types_pb2.DT_INT8_REF)
string_ref = DType(types_pb2.DT_STRING_REF)
complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
complex128_ref = DType(types_pb2.DT_COMPLEX128_REF)
int64_ref = DType(types_pb2.DT_INT64_REF)
uint64_ref = DType(types_pb2.DT_UINT64_REF)
bool_ref = DType(types_pb2.DT_BOOL_REF)
qint8_ref = DType(types_pb2.DT_QINT8_REF)
quint8_ref = DType(types_pb2.DT_QUINT8_REF)
qint16_ref = DType(types_pb2.DT_QINT16_REF)
quint16_ref = DType(types_pb2.DT_QUINT16_REF)
qint32_ref = DType(types_pb2.DT_QINT32_REF)
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
314
# Maintain an intern table so that we don't have to create a large
# number of small objects.
#
# Maps every types_pb2.DataType enum value handled by this module to the one
# module-level `DType` constant that wraps it. `as_dtype` returns objects from
# this table, and `DType._as_ref` / `DType.base_dtype` index it with the enum
# value shifted by +100 / -100, so both the base and the `*_REF` value of each
# type need an entry here.
_INTERN_TABLE = {
    types_pb2.DT_HALF: float16,
    types_pb2.DT_FLOAT: float32,
    types_pb2.DT_DOUBLE: float64,
    types_pb2.DT_INT32: int32,
    types_pb2.DT_UINT8: uint8,
    types_pb2.DT_UINT16: uint16,
    types_pb2.DT_UINT32: uint32,
    types_pb2.DT_UINT64: uint64,
    types_pb2.DT_INT16: int16,
    types_pb2.DT_INT8: int8,
    types_pb2.DT_STRING: string,
    types_pb2.DT_COMPLEX64: complex64,
    types_pb2.DT_COMPLEX128: complex128,
    types_pb2.DT_INT64: int64,
    types_pb2.DT_BOOL: bool,
    types_pb2.DT_QINT8: qint8,
    types_pb2.DT_QUINT8: quint8,
    types_pb2.DT_QINT16: qint16,
    types_pb2.DT_QUINT16: quint16,
    types_pb2.DT_QINT32: qint32,
    types_pb2.DT_BFLOAT16: bfloat16,
    types_pb2.DT_RESOURCE: resource,
    types_pb2.DT_VARIANT: variant,
    # Reference types.
    types_pb2.DT_HALF_REF: float16_ref,
    types_pb2.DT_FLOAT_REF: float32_ref,
    types_pb2.DT_DOUBLE_REF: float64_ref,
    types_pb2.DT_INT32_REF: int32_ref,
    types_pb2.DT_UINT32_REF: uint32_ref,
    types_pb2.DT_UINT8_REF: uint8_ref,
    types_pb2.DT_UINT16_REF: uint16_ref,
    types_pb2.DT_INT16_REF: int16_ref,
    types_pb2.DT_INT8_REF: int8_ref,
    types_pb2.DT_STRING_REF: string_ref,
    types_pb2.DT_COMPLEX64_REF: complex64_ref,
    types_pb2.DT_COMPLEX128_REF: complex128_ref,
    types_pb2.DT_INT64_REF: int64_ref,
    types_pb2.DT_UINT64_REF: uint64_ref,
    types_pb2.DT_BOOL_REF: bool_ref,
    types_pb2.DT_QINT8_REF: qint8_ref,
    types_pb2.DT_QUINT8_REF: quint8_ref,
    types_pb2.DT_QINT16_REF: qint16_ref,
    types_pb2.DT_QUINT16_REF: quint16_ref,
    types_pb2.DT_QINT32_REF: qint32_ref,
    types_pb2.DT_BFLOAT16_REF: bfloat16_ref,
    types_pb2.DT_RESOURCE_REF: resource_ref,
    types_pb2.DT_VARIANT_REF: variant_ref,
}
365
# Standard mappings between types_pb2.DataType values and string names.
#
# One canonical name per enum value; reference types use the base name with a
# "_ref" suffix. The keys must also be keys of `_INTERN_TABLE`, because the
# string-to-DType table is derived from this dict by looking each key up
# there.
_TYPE_TO_STRING = {
    types_pb2.DT_HALF: "float16",
    types_pb2.DT_FLOAT: "float32",
    types_pb2.DT_DOUBLE: "float64",
    types_pb2.DT_INT32: "int32",
    types_pb2.DT_UINT8: "uint8",
    types_pb2.DT_UINT16: "uint16",
    types_pb2.DT_UINT32: "uint32",
    types_pb2.DT_UINT64: "uint64",
    types_pb2.DT_INT16: "int16",
    types_pb2.DT_INT8: "int8",
    types_pb2.DT_STRING: "string",
    types_pb2.DT_COMPLEX64: "complex64",
    types_pb2.DT_COMPLEX128: "complex128",
    types_pb2.DT_INT64: "int64",
    types_pb2.DT_BOOL: "bool",
    types_pb2.DT_QINT8: "qint8",
    types_pb2.DT_QUINT8: "quint8",
    types_pb2.DT_QINT16: "qint16",
    types_pb2.DT_QUINT16: "quint16",
    types_pb2.DT_QINT32: "qint32",
    types_pb2.DT_BFLOAT16: "bfloat16",
    types_pb2.DT_RESOURCE: "resource",
    types_pb2.DT_VARIANT: "variant",
    # Reference types.
    types_pb2.DT_HALF_REF: "float16_ref",
    types_pb2.DT_FLOAT_REF: "float32_ref",
    types_pb2.DT_DOUBLE_REF: "float64_ref",
    types_pb2.DT_INT32_REF: "int32_ref",
    types_pb2.DT_UINT32_REF: "uint32_ref",
    types_pb2.DT_UINT8_REF: "uint8_ref",
    types_pb2.DT_UINT16_REF: "uint16_ref",
    types_pb2.DT_INT16_REF: "int16_ref",
    types_pb2.DT_INT8_REF: "int8_ref",
    types_pb2.DT_STRING_REF: "string_ref",
    types_pb2.DT_COMPLEX64_REF: "complex64_ref",
    types_pb2.DT_COMPLEX128_REF: "complex128_ref",
    types_pb2.DT_INT64_REF: "int64_ref",
    types_pb2.DT_UINT64_REF: "uint64_ref",
    types_pb2.DT_BOOL_REF: "bool_ref",
    types_pb2.DT_QINT8_REF: "qint8_ref",
    types_pb2.DT_QUINT8_REF: "quint8_ref",
    types_pb2.DT_QINT16_REF: "qint16_ref",
    types_pb2.DT_QUINT16_REF: "quint16_ref",
    types_pb2.DT_QINT32_REF: "qint32_ref",
    types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
    types_pb2.DT_RESOURCE_REF: "resource_ref",
    types_pb2.DT_VARIANT_REF: "variant_ref",
}
# Inverse of `_TYPE_TO_STRING`: canonical type name -> interned `DType`.
_STRING_TO_TF = {
    type_name: _INTERN_TABLE[type_enum]
    for type_enum, type_name in _TYPE_TO_STRING.items()
}
# Non-canonical spellings that are accepted in addition to the names above.
_STRING_TO_TF.update({
    "half": float16,
    "half_ref": float16_ref,
    "float": float32,
    "float_ref": float32_ref,
    "double": float64,
    "double_ref": float64_ref,
})
425
# Numpy representation for quantized dtypes.
#
# Each quantized type is represented as a structured numpy dtype with a single
# field: the field name is the TF type name and the field type is the integer
# type that stores the value.
#
# These are magic strings that are used in the swig wrapper to identify
# quantized types.
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
# hard-coding of names.
_np_qint8 = np.dtype([("qint8", np.int8)])
_np_quint8 = np.dtype([("quint8", np.uint8)])
_np_qint16 = np.dtype([("qint16", np.int16)])
_np_quint16 = np.dtype([("quint16", np.uint16)])
_np_qint32 = np.dtype([("qint32", np.int32)])

# _np_bfloat16 is not defined here: it is created near the top of this module
# from the imported `_pywrap_bfloat16` extension.

# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
np_resource = np.dtype([("resource", np.ubyte)])
442
# Standard mappings from numpy scalar types (and the structured quantized
# dtypes defined above) to the corresponding TF `DType`.
#
# Several numpy types map to `string`: object arrays as well as the numpy
# byte-string and unicode-string scalar types.
# NOTE(review): `np.string_` and `np.unicode_` are legacy spellings that newer
# NumPy releases no longer provide (`np.bytes_` / `np.str_` are the Python 3
# names) -- confirm the supported NumPy and Python range before changing them.
_NP_TO_TF = {
    np.float16: float16,
    np.float32: float32,
    np.float64: float64,
    np.int32: int32,
    np.int64: int64,
    np.uint8: uint8,
    np.uint16: uint16,
    np.uint32: uint32,
    np.uint64: uint64,
    np.int16: int16,
    np.int8: int8,
    np.complex64: complex64,
    np.complex128: complex128,
    np.object_: string,
    np.string_: string,
    np.unicode_: string,
    np.bool_: bool,
    _np_qint8: qint8,
    _np_quint8: quint8,
    _np_qint16: qint16,
    _np_quint16: quint16,
    _np_qint32: qint32,
    _np_bfloat16: bfloat16,
}
469
# Register (some) NumPy platform dtypes under the TF dtype of their
# fixed-width synonym. Platform dtypes are not always simple aliases of the
# fixed-width ones, i.e. reference equality is not guaranteed (see e.g.
# numpy/numpy#9799), so the synonym is located by dtype equality instead.
for pdt in (np.intc, np.uintc, np.int_, np.uint, np.longlong, np.ulonglong):
  if pdt in _NP_TO_TF:
    # Already registered because it is the very same object as a key above.
    continue
  # Take the first registered key that compares equal to the platform dtype.
  _NP_TO_TF[pdt] = next(
      tf_type for np_type, tf_type in _NP_TO_TF.items()
      if np_type == pdt().dtype)


# Every TF dtype that is the target of at least one numpy type.
TF_VALUE_DTYPES = set(_NP_TO_TF.values())
487
488
# Maps types_pb2.DataType enum values to the numpy type returned by
# `DType.as_numpy_dtype`. A reference type maps to the same numpy type as its
# base type. There are no entries for the resource and variant types, so
# `as_numpy_dtype` raises KeyError for them.
#
# The deprecated aliases `np.object` and `np.bool` are deliberately not used:
# they were plain aliases of the Python builtins and newer NumPy releases
# dropped them. Strings map to `object` (exactly what `np.object` resolved
# to), and both boolean entries map to `np.bool_` so that `bool_ref` agrees
# with `bool`.
_TF_TO_NP = {
    types_pb2.DT_HALF: np.float16,
    types_pb2.DT_FLOAT: np.float32,
    types_pb2.DT_DOUBLE: np.float64,
    types_pb2.DT_INT32: np.int32,
    types_pb2.DT_UINT8: np.uint8,
    types_pb2.DT_UINT16: np.uint16,
    types_pb2.DT_UINT32: np.uint32,
    types_pb2.DT_UINT64: np.uint64,
    types_pb2.DT_INT16: np.int16,
    types_pb2.DT_INT8: np.int8,
    # NOTE(touts): For strings we use object as it supports variable length
    # strings.
    types_pb2.DT_STRING: object,
    types_pb2.DT_COMPLEX64: np.complex64,
    types_pb2.DT_COMPLEX128: np.complex128,
    types_pb2.DT_INT64: np.int64,
    types_pb2.DT_BOOL: np.bool_,
    types_pb2.DT_QINT8: _np_qint8,
    types_pb2.DT_QUINT8: _np_quint8,
    types_pb2.DT_QINT16: _np_qint16,
    types_pb2.DT_QUINT16: _np_quint16,
    types_pb2.DT_QINT32: _np_qint32,
    types_pb2.DT_BFLOAT16: _np_bfloat16,

    # Ref types
    types_pb2.DT_HALF_REF: np.float16,
    types_pb2.DT_FLOAT_REF: np.float32,
    types_pb2.DT_DOUBLE_REF: np.float64,
    types_pb2.DT_INT32_REF: np.int32,
    types_pb2.DT_UINT32_REF: np.uint32,
    types_pb2.DT_UINT8_REF: np.uint8,
    types_pb2.DT_UINT16_REF: np.uint16,
    types_pb2.DT_INT16_REF: np.int16,
    types_pb2.DT_INT8_REF: np.int8,
    types_pb2.DT_STRING_REF: object,
    types_pb2.DT_COMPLEX64_REF: np.complex64,
    types_pb2.DT_COMPLEX128_REF: np.complex128,
    types_pb2.DT_INT64_REF: np.int64,
    types_pb2.DT_UINT64_REF: np.uint64,
    types_pb2.DT_BOOL_REF: np.bool_,
    types_pb2.DT_QINT8_REF: _np_qint8,
    types_pb2.DT_QUINT8_REF: _np_quint8,
    types_pb2.DT_QINT16_REF: _np_qint16,
    types_pb2.DT_QUINT16_REF: _np_quint16,
    types_pb2.DT_QINT32_REF: _np_qint32,
    types_pb2.DT_BFLOAT16_REF: _np_bfloat16,
}
579
# Quantized dtypes, kept as separate base and reference sets, and the exported
# union of both.
_QUANTIZED_DTYPES_NO_REF = frozenset((qint8, quint8, qint16, quint16, qint32))
_QUANTIZED_DTYPES_REF = frozenset(
    (qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref))
QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF | _QUANTIZED_DTYPES_NO_REF
tf_export(
    "dtypes.QUANTIZED_DTYPES",
    v1=["dtypes.QUANTIZED_DTYPES", "QUANTIZED_DTYPES"]
).export_constant(__name__, "QUANTIZED_DTYPES")
588
# Python builtin types accepted by `as_dtype`. The builtins are referenced
# through the `builtins` module because the bare name `bool` in this module is
# the TF dtype, not the Python type.
_PYTHON_TO_TF = {
    builtins.float: float32,
    builtins.bool: bool,
    builtins.object: string,
}

# Single lookup table merging every kind of key `as_dtype` accepts: enum
# values, string names, Python builtin types and numpy types.
_ANY_TO_TF = dict(_INTERN_TABLE)
_ANY_TO_TF.update(_STRING_TO_TF)
_ANY_TO_TF.update(_PYTHON_TO_TF)
_ANY_TO_TF.update(_NP_TO_TF)

# Ensure no collisions: if a key appeared in more than one source table, the
# merged table would be smaller than the sum of its parts.
assert len(_ANY_TO_TF) == (
    len(_INTERN_TABLE) + len(_STRING_TO_TF) + len(_PYTHON_TO_TF) +
    len(_NP_TO_TF))
604
605
@tf_export("dtypes.as_dtype", "as_dtype")
def as_dtype(type_value):
  """Converts the given `type_value` to a `DType`.

  Note: `DType` values are interned. When passed a new `DType` object,
  `as_dtype` always returns the interned value.

  The conversions are tried in this order: a `DType` instance, a
  `numpy.dtype` instance (via its scalar type), a direct table lookup (enum
  value, string name, Python or numpy type), an object exposing a `dtype`
  attribute, and finally an instance of the base `_dtypes.DType` class.

  Args:
    type_value: A value that can be converted to a `tf.DType` object. This may
      currently be a `tf.DType` object, a [`DataType`
      enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
      a string type name, or a `numpy.dtype`.

  Returns:
    A `DType` corresponding to `type_value`.

  Raises:
    TypeError: If `type_value` cannot be converted to a `DType`.
  """
  # Already a DType: hand back the interned instance for its enum value.
  if isinstance(type_value, DType):
    return _INTERN_TABLE[type_value.as_datatype_enum]

  # A numpy dtype instance is first resolved through its scalar type. Dtypes
  # whose scalar type is not registered fall through to the lookups below.
  if isinstance(type_value, np.dtype):
    scalar_type = type_value.type
    if scalar_type in _NP_TO_TF:
      return _NP_TO_TF[scalar_type]

  # Direct lookup of the value itself.
  try:
    direct_match = _ANY_TO_TF[type_value]
  except (KeyError, TypeError):
    # KeyError: not a known key. TypeError: type_value is not hashable.
    pass
  else:
    return direct_match

  # Objects that carry a `dtype` attribute are converted through it.
  if hasattr(type_value, "dtype"):
    try:
      attr_scalar_type = np.dtype(type_value.dtype).type
      return _NP_TO_TF[attr_scalar_type]
    except (KeyError, TypeError):
      pass

  # Last resort: an instance of the base DType class that is not a `DType`.
  if not isinstance(type_value, _dtypes.DType):
    raise TypeError("Cannot convert value %r to a TensorFlow DType." %
                    (type_value,))
  return _INTERN_TABLE[type_value.as_datatype_enum]
651