1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library of dtypes (Tensor element types)."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21from six.moves import builtins
22
23from tensorflow.core.framework import types_pb2
24from tensorflow.python import pywrap_tensorflow
25from tensorflow.python.util.tf_export import tf_export
26
# Numpy scalar type used to represent bfloat16 values. It is obtained from the
# TensorFlow C extension (`pywrap_tensorflow`) rather than from numpy itself.
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
28
29
@tf_export("dtypes.DType", "DType")
class DType(object):
  """Represents the type of the elements in a `Tensor`.

  The following `DType` objects are defined:

  * `tf.float16`: 16-bit half-precision floating-point.
  * `tf.float32`: 32-bit single-precision floating-point.
  * `tf.float64`: 64-bit double-precision floating-point.
  * `tf.bfloat16`: 16-bit truncated floating-point.
  * `tf.complex64`: 64-bit single-precision complex.
  * `tf.complex128`: 128-bit double-precision complex.
  * `tf.int8`: 8-bit signed integer.
  * `tf.uint8`: 8-bit unsigned integer.
  * `tf.uint16`: 16-bit unsigned integer.
  * `tf.uint32`: 32-bit unsigned integer.
  * `tf.uint64`: 64-bit unsigned integer.
  * `tf.int16`: 16-bit signed integer.
  * `tf.int32`: 32-bit signed integer.
  * `tf.int64`: 64-bit signed integer.
  * `tf.bool`: Boolean.
  * `tf.string`: String.
  * `tf.qint8`: Quantized 8-bit signed integer.
  * `tf.quint8`: Quantized 8-bit unsigned integer.
  * `tf.qint16`: Quantized 16-bit signed integer.
  * `tf.quint16`: Quantized 16-bit unsigned integer.
  * `tf.qint32`: Quantized 32-bit signed integer.
  * `tf.resource`: Handle to a mutable resource.
  * `tf.variant`: Values of arbitrary types.

  In addition, variants of these types with the `_ref` suffix are
  defined for reference-typed tensors.

  The `tf.as_dtype()` function converts numpy types and string type
  names to a `DType` object.
  """

  def __init__(self, type_enum):
    """Creates a new `DType`.

    NOTE(mrry): In normal circumstances, you should not need to
    construct a `DType` object directly. Instead, use the
    `tf.as_dtype()` function.

    Args:
      type_enum: A `types_pb2.DataType` enum value.

    Raises:
      TypeError: If `type_enum` is not a valid `types_pb2.DataType` value, or
        is `DT_INVALID`.
    """
    # TODO(mrry): Make the necessary changes (using __new__) to ensure
    # that calling this returns one of the interned values.
    type_enum = int(type_enum)
    if (type_enum not in types_pb2.DataType.values() or
        type_enum == types_pb2.DT_INVALID):
      raise TypeError(
          "type_enum is not a valid types_pb2.DataType: %s" % type_enum)
    self._type_enum = type_enum

  @property
  def _is_ref_dtype(self):
    """Returns `True` if this `DType` represents a reference type."""
    # Reference enum values are their base enum value plus 100 (see `_as_ref`
    # and `base_dtype`).
    return self._type_enum > 100

  @property
  def _as_ref(self):
    """Returns a reference `DType` based on this `DType`."""
    if self._is_ref_dtype:
      return self
    else:
      return _INTERN_TABLE[self._type_enum + 100]

  @property
  def base_dtype(self):
    """Returns a non-reference `DType` based on this `DType`."""
    if self._is_ref_dtype:
      return _INTERN_TABLE[self._type_enum - 100]
    else:
      return self

  @property
  def real_dtype(self):
    """Returns the dtype corresponding to this dtype's real part.

    `complex64` (or its reference type) yields `float32` and `complex128` (or
    its reference type) yields `float64`. Every other dtype is returned
    unchanged.
    """
    base = self.base_dtype
    if base == complex64:
      return float32
    elif base == complex128:
      return float64
    else:
      return self

  @property
  def is_numpy_compatible(self):
    """Returns whether this type has a numpy equivalent.

    Resource and variant types (and their reference types) are the types
    without one.
    """
    return self._type_enum not in _NUMPY_INCOMPATIBLE

  @property
  def as_numpy_dtype(self):
    """Returns a `numpy.dtype` based on this `DType`.

    Raises:
      KeyError: if this type has no numpy equivalent (e.g. resource, variant).
    """
    return _TF_TO_NP[self._type_enum]

  @property
  def as_datatype_enum(self):
    """Returns a `types_pb2.DataType` enum value based on this `DType`."""
    return self._type_enum

  @property
  def is_bool(self):
    """Returns whether this is a boolean data type."""
    return self.base_dtype == bool

  @property
  def is_integer(self):
    """Returns whether this is a (non-quantized) integer type."""
    return (self.is_numpy_compatible and not self.is_quantized and
            np.issubdtype(self.as_numpy_dtype, np.integer))

  @property
  def is_floating(self):
    """Returns whether this is a (non-quantized, real) floating point type."""
    return ((self.is_numpy_compatible and
             np.issubdtype(self.as_numpy_dtype, np.floating)) or
            self.base_dtype == bfloat16)

  @property
  def is_complex(self):
    """Returns whether this is a complex floating point type."""
    return self.base_dtype in (complex64, complex128)

  @property
  def is_quantized(self):
    """Returns whether this is a quantized data type."""
    return self.base_dtype in _QUANTIZED_DTYPES_NO_REF

  @property
  def is_unsigned(self):
    """Returns whether this type is unsigned.

    Non-numeric, unordered, and quantized types are not considered unsigned, and
    this function returns `False`.

    Returns:
      Whether a `DType` is unsigned.
    """
    try:
      return self.min == 0
    except TypeError:
      return False

  @property
  def min(self):
    """Returns the minimum representable value in this data type.

    Raises:
      TypeError: if this is a non-numeric, unordered, or quantized type.
    """
    if (self.is_quantized or
        self.base_dtype in (bool, string, complex64, complex128)):
      raise TypeError("Cannot find minimum value of %s." % self)

    # There is no simple way to get the min value of a dtype, so floating and
    # integer types are tried separately. The exceptions finfo/iinfo raise for
    # unsupported types are not documented, so any `Exception` is caught; a
    # bare `except` would also swallow KeyboardInterrupt/SystemExit.
    try:
      return np.finfo(self.as_numpy_dtype()).min
    except Exception:  # pylint: disable=broad-except
      try:
        return np.iinfo(self.as_numpy_dtype()).min
      except Exception:  # pylint: disable=broad-except
        if self.base_dtype == bfloat16:
          # Most negative finite bfloat16 value.
          return _np_bfloat16(float.fromhex("-0x1.FEp127"))
        raise TypeError("Cannot find minimum value of %s." % self)

  @property
  def max(self):
    """Returns the maximum representable value in this data type.

    Raises:
      TypeError: if this is a non-numeric, unordered, or quantized type.
    """
    if (self.is_quantized or
        self.base_dtype in (bool, string, complex64, complex128)):
      raise TypeError("Cannot find maximum value of %s." % self)

    # Same strategy as `min`: try the floating-point info first, then the
    # integer info, then the bfloat16 special case.
    try:
      return np.finfo(self.as_numpy_dtype()).max
    except Exception:  # pylint: disable=broad-except
      try:
        return np.iinfo(self.as_numpy_dtype()).max
      except Exception:  # pylint: disable=broad-except
        if self.base_dtype == bfloat16:
          # Largest finite bfloat16 value.
          return _np_bfloat16(float.fromhex("0x1.FEp127"))
        raise TypeError("Cannot find maximum value of %s." % self)

  @property
  def limits(self, clip_negative=True):
    """Return intensity limits, i.e. (min, max) tuple, of the dtype.

    NOTE: Because this is a property, it is always evaluated with the default
    `clip_negative=True`; callers cannot pass the argument.

    Args:
      clip_negative: bool, optional. If True, clip the negative range (i.e.
        return 0 for min intensity) even if the image dtype allows negative
        values.

    Returns:
      A `(min, max)` tuple of the lower and upper intensity limits, taken
      from `dtype_range`.

    Raises:
      KeyError: if this type has no numpy equivalent or no `dtype_range`
        entry.
    """
    # Normalize to numpy's canonical scalar type. `as_numpy_dtype` returns the
    # Python builtin `bool` for tf.bool, which would otherwise miss the
    # `np.bool_` entry in `dtype_range`.
    np_type = np.dtype(self.as_numpy_dtype).type
    low, high = dtype_range[np_type]
    if clip_negative:
      low = 0
    return low, high

  def is_compatible_with(self, other):
    """Returns True if the `other` DType will be converted to this DType.

    The conversion rules are as follows:

    ```python
    DType(T)        .is_compatible_with(DType(T))         == True
    DType(T)        .is_compatible_with(DType(T)._as_ref) == True
    DType(T)._as_ref.is_compatible_with(DType(T))         == False
    DType(T)._as_ref.is_compatible_with(DType(T)._as_ref) == True
    ```

    Args:
      other: A `DType` (or object that may be converted to a `DType`).

    Returns:
      True if a Tensor of the `other` `DType` will be implicitly converted to
      this `DType`.
    """
    other = as_dtype(other)
    return self._type_enum in (other.as_datatype_enum,
                               other.base_dtype.as_datatype_enum)

  def __eq__(self, other):
    """Returns True iff this DType refers to the same type as `other`.

    `other` may be anything `as_dtype` accepts; values it cannot convert
    compare unequal instead of raising.
    """
    if other is None:
      return False
    try:
      dtype = as_dtype(other).as_datatype_enum
      return self._type_enum == dtype  # pylint: disable=protected-access
    except TypeError:
      return False

  def __ne__(self, other):
    """Returns True iff self != other."""
    return not self.__eq__(other)

  @property
  def name(self):
    """Returns the string name for this `DType`."""
    return _TYPE_TO_STRING[self._type_enum]

  def __str__(self):
    return "<dtype: %r>" % self.name

  def __repr__(self):
    return "tf." + self.name

  def __hash__(self):
    # Consistent with `__eq__`, which compares enum values.
    return self._type_enum

  def __reduce__(self):
    # Pickle by name; unpickling goes through `as_dtype`.
    return as_dtype, (self.name,)

  @property
  def size(self):
    """Returns the size in bytes of one element of this type.

    Resource and variant types report 1; every other type reports the numpy
    itemsize of `as_numpy_dtype`.
    """
    if (self._type_enum == types_pb2.DT_VARIANT or
        self._type_enum == types_pb2.DT_RESOURCE):
      return 1
    return np.dtype(self.as_numpy_dtype).itemsize
303
304
# Intensity range for each numpy scalar type, consumed by `DType.limits`.
# Integer entries span the full representable range. Floating-point entries
# use the conventional [-1, 1] image-intensity range rather than the numeric
# limits of the type. (`np.bool8` is not listed separately: it was only a
# deprecated alias of `np.bool_`, so it produced a duplicate key.)
dtype_range = {
    np.bool_: (False, True),
    np.uint8: (0, 255),
    np.uint16: (0, 65535),
    np.int8: (-128, 127),
    np.int16: (-32768, 32767),
    np.int64: (-2**63, 2**63 - 1),
    np.uint64: (0, 2**64 - 1),
    np.int32: (-2**31, 2**31 - 1),
    np.uint32: (0, 2**32 - 1),
    np.float32: (-1, 1),
    np.float64: (-1, 1),
}
320
# Define standard wrappers for the types_pb2.DataType enum.
#
# One `DType` instance per `types_pb2.DataType` value. Every non-reference type
# is also registered with `tf_export` under both the "dtypes." namespace and the
# top-level namespace (e.g. "dtypes.float32" and "float32"). Aliases such as
# `half` and `double` are the very same objects as `float16` and `float64`.
resource = DType(types_pb2.DT_RESOURCE)
tf_export("dtypes.resource", "resource").export_constant(__name__, "resource")
variant = DType(types_pb2.DT_VARIANT)
tf_export("dtypes.variant", "variant").export_constant(__name__, "variant")
float16 = DType(types_pb2.DT_HALF)
tf_export("dtypes.float16", "float16").export_constant(__name__, "float16")
# Alias of float16 (same object).
half = float16
tf_export("dtypes.half", "half").export_constant(__name__, "half")
float32 = DType(types_pb2.DT_FLOAT)
tf_export("dtypes.float32", "float32").export_constant(__name__, "float32")
float64 = DType(types_pb2.DT_DOUBLE)
tf_export("dtypes.float64", "float64").export_constant(__name__, "float64")
# Alias of float64 (same object).
double = float64
tf_export("dtypes.double", "double").export_constant(__name__, "double")
int32 = DType(types_pb2.DT_INT32)
tf_export("dtypes.int32", "int32").export_constant(__name__, "int32")
uint8 = DType(types_pb2.DT_UINT8)
tf_export("dtypes.uint8", "uint8").export_constant(__name__, "uint8")
uint16 = DType(types_pb2.DT_UINT16)
tf_export("dtypes.uint16", "uint16").export_constant(__name__, "uint16")
uint32 = DType(types_pb2.DT_UINT32)
tf_export("dtypes.uint32", "uint32").export_constant(__name__, "uint32")
uint64 = DType(types_pb2.DT_UINT64)
tf_export("dtypes.uint64", "uint64").export_constant(__name__, "uint64")
int16 = DType(types_pb2.DT_INT16)
tf_export("dtypes.int16", "int16").export_constant(__name__, "int16")
int8 = DType(types_pb2.DT_INT8)
tf_export("dtypes.int8", "int8").export_constant(__name__, "int8")
string = DType(types_pb2.DT_STRING)
tf_export("dtypes.string", "string").export_constant(__name__, "string")
complex64 = DType(types_pb2.DT_COMPLEX64)
tf_export("dtypes.complex64", "complex64").export_constant(
    __name__, "complex64")
complex128 = DType(types_pb2.DT_COMPLEX128)
tf_export("dtypes.complex128", "complex128").export_constant(
    __name__, "complex128")
int64 = DType(types_pb2.DT_INT64)
tf_export("dtypes.int64", "int64").export_constant(__name__, "int64")
# From here on, the module-level name `bool` is this `DType`, not the Python
# builtin; code in this module uses `builtins.bool` for the builtin type.
bool = DType(types_pb2.DT_BOOL)  # pylint: disable=redefined-builtin
tf_export("dtypes.bool", "bool").export_constant(__name__, "bool")
qint8 = DType(types_pb2.DT_QINT8)
tf_export("dtypes.qint8", "qint8").export_constant(__name__, "qint8")
quint8 = DType(types_pb2.DT_QUINT8)
tf_export("dtypes.quint8", "quint8").export_constant(__name__, "quint8")
qint16 = DType(types_pb2.DT_QINT16)
tf_export("dtypes.qint16", "qint16").export_constant(__name__, "qint16")
quint16 = DType(types_pb2.DT_QUINT16)
tf_export("dtypes.quint16", "quint16").export_constant(__name__, "quint16")
qint32 = DType(types_pb2.DT_QINT32)
tf_export("dtypes.qint32", "qint32").export_constant(__name__, "qint32")
# Reference variants of resource and variant; not exported via `tf_export`.
resource_ref = DType(types_pb2.DT_RESOURCE_REF)
variant_ref = DType(types_pb2.DT_VARIANT_REF)
bfloat16 = DType(types_pb2.DT_BFLOAT16)
tf_export("dtypes.bfloat16", "bfloat16").export_constant(__name__, "bfloat16")
# Reference ("_ref") variants of the remaining types. None of these are
# exported via `tf_export`; `half_ref` and `double_ref` are aliases.
float16_ref = DType(types_pb2.DT_HALF_REF)
half_ref = float16_ref
float32_ref = DType(types_pb2.DT_FLOAT_REF)
float64_ref = DType(types_pb2.DT_DOUBLE_REF)
double_ref = float64_ref
int32_ref = DType(types_pb2.DT_INT32_REF)
uint32_ref = DType(types_pb2.DT_UINT32_REF)
uint8_ref = DType(types_pb2.DT_UINT8_REF)
uint16_ref = DType(types_pb2.DT_UINT16_REF)
int16_ref = DType(types_pb2.DT_INT16_REF)
int8_ref = DType(types_pb2.DT_INT8_REF)
string_ref = DType(types_pb2.DT_STRING_REF)
complex64_ref = DType(types_pb2.DT_COMPLEX64_REF)
complex128_ref = DType(types_pb2.DT_COMPLEX128_REF)
int64_ref = DType(types_pb2.DT_INT64_REF)
uint64_ref = DType(types_pb2.DT_UINT64_REF)
bool_ref = DType(types_pb2.DT_BOOL_REF)
qint8_ref = DType(types_pb2.DT_QINT8_REF)
quint8_ref = DType(types_pb2.DT_QUINT8_REF)
qint16_ref = DType(types_pb2.DT_QINT16_REF)
quint16_ref = DType(types_pb2.DT_QUINT16_REF)
qint32_ref = DType(types_pb2.DT_QINT32_REF)
bfloat16_ref = DType(types_pb2.DT_BFLOAT16_REF)
399
# Enum values of the types that have no numpy counterpart; consulted by
# `DType.is_numpy_compatible`.
_NUMPY_INCOMPATIBLE = frozenset((
    types_pb2.DT_RESOURCE,
    types_pb2.DT_RESOURCE_REF,
    types_pb2.DT_VARIANT,
    types_pb2.DT_VARIANT_REF,
))
404
# Intern table: each `types_pb2.DataType` enum value maps to its singleton
# `DType`. Lookups (e.g. `base_dtype`, `_as_ref`) therefore return shared
# instances instead of creating many small objects. Keys are taken from each
# instance's own `as_datatype_enum`, so a key can never disagree with its value.
_INTERN_TABLE = {
    interned.as_datatype_enum: interned for interned in (
        # Non-reference types.
        float16, float32, float64, int32, uint8, uint16, uint32, uint64,
        int16, int8, string, complex64, complex128, int64, bool, qint8,
        quint8, qint16, quint16, qint32, bfloat16, resource, variant,
        # Reference types.
        float16_ref, float32_ref, float64_ref, int32_ref, uint32_ref,
        uint8_ref, uint16_ref, int16_ref, int8_ref, string_ref,
        complex64_ref, complex128_ref, int64_ref, uint64_ref, bool_ref,
        qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref,
        bfloat16_ref, resource_ref, variant_ref,
    )
}
455
# Standard mappings between types_pb2.DataType values and string names.
#
# Canonical name for each enum value. It backs `DType.name`, and therefore
# `str()`, `repr()` and pickling (`__reduce__`). Reference types use the base
# name with a "_ref" suffix.
_TYPE_TO_STRING = {
    types_pb2.DT_HALF: "float16",
    types_pb2.DT_FLOAT: "float32",
    types_pb2.DT_DOUBLE: "float64",
    types_pb2.DT_INT32: "int32",
    types_pb2.DT_UINT8: "uint8",
    types_pb2.DT_UINT16: "uint16",
    types_pb2.DT_UINT32: "uint32",
    types_pb2.DT_UINT64: "uint64",
    types_pb2.DT_INT16: "int16",
    types_pb2.DT_INT8: "int8",
    types_pb2.DT_STRING: "string",
    types_pb2.DT_COMPLEX64: "complex64",
    types_pb2.DT_COMPLEX128: "complex128",
    types_pb2.DT_INT64: "int64",
    types_pb2.DT_BOOL: "bool",
    types_pb2.DT_QINT8: "qint8",
    types_pb2.DT_QUINT8: "quint8",
    types_pb2.DT_QINT16: "qint16",
    types_pb2.DT_QUINT16: "quint16",
    types_pb2.DT_QINT32: "qint32",
    types_pb2.DT_BFLOAT16: "bfloat16",
    types_pb2.DT_RESOURCE: "resource",
    types_pb2.DT_VARIANT: "variant",
    types_pb2.DT_HALF_REF: "float16_ref",
    types_pb2.DT_FLOAT_REF: "float32_ref",
    types_pb2.DT_DOUBLE_REF: "float64_ref",
    types_pb2.DT_INT32_REF: "int32_ref",
    types_pb2.DT_UINT32_REF: "uint32_ref",
    types_pb2.DT_UINT8_REF: "uint8_ref",
    types_pb2.DT_UINT16_REF: "uint16_ref",
    types_pb2.DT_INT16_REF: "int16_ref",
    types_pb2.DT_INT8_REF: "int8_ref",
    types_pb2.DT_STRING_REF: "string_ref",
    types_pb2.DT_COMPLEX64_REF: "complex64_ref",
    types_pb2.DT_COMPLEX128_REF: "complex128_ref",
    types_pb2.DT_INT64_REF: "int64_ref",
    types_pb2.DT_UINT64_REF: "uint64_ref",
    types_pb2.DT_BOOL_REF: "bool_ref",
    types_pb2.DT_QINT8_REF: "qint8_ref",
    types_pb2.DT_QUINT8_REF: "quint8_ref",
    types_pb2.DT_QINT16_REF: "qint16_ref",
    types_pb2.DT_QUINT16_REF: "quint16_ref",
    types_pb2.DT_QINT32_REF: "qint32_ref",
    types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
    types_pb2.DT_RESOURCE_REF: "resource_ref",
    types_pb2.DT_VARIANT_REF: "variant_ref",
}
# Reverse of `_TYPE_TO_STRING`: canonical type name -> interned `DType`.
_STRING_TO_TF = dict(
    (type_name, _INTERN_TABLE[type_enum])
    for type_enum, type_name in _TYPE_TO_STRING.items())
# Non-canonical aliases, accepted in addition to the canonical names above.
_STRING_TO_TF.update({
    "half": float16,
    "half_ref": float16_ref,
    "float": float32,
    "float_ref": float32_ref,
    "double": float64,
    "double_ref": float64_ref,
})
516
# Numpy representation for quantized dtypes.
#
# These are magic strings that are used in the swig wrapper to identify
# quantized types.
# TODO(mrry,keveman): Investigate Numpy type registration to replace this
# hard-coding of names.
#
# Each is a structured dtype with a single scalar field whose name identifies
# the quantized type. The field shape is intentionally omitted. NumPy
# deprecated `(name, type, 1)` as a synonym for a scalar field (FutureWarning;
# it is slated to mean a `(1,)`-shaped subarray). The `(name, type)` form
# yields the same dtype without relying on that deprecated spelling.
_np_qint8 = np.dtype([("qint8", np.int8)])
_np_quint8 = np.dtype([("quint8", np.uint8)])
_np_qint16 = np.dtype([("qint16", np.int16)])
_np_quint16 = np.dtype([("quint16", np.uint16)])
_np_qint32 = np.dtype([("qint32", np.int32)])

# _np_bfloat16 is defined near the top of this module, from pywrap_tensorflow.

# Custom struct dtype for directly-fed ResourceHandles of supported type(s).
np_resource = np.dtype([("resource", np.ubyte)])
533
# Mapping from numpy scalar types (plus the custom quantized struct dtypes and
# the bfloat16 type) to the corresponding `DType`.
#
# `np.bytes_` and `np.str_` are the canonical names of the objects that the
# deprecated aliases `np.string_` and `np.unicode_` referred to (those aliases
# were removed in NumPy 2.0), so the keys are unchanged.
_NP_TO_TF = {
    np.float16: float16,
    np.float32: float32,
    np.float64: float64,
    np.int32: int32,
    np.int64: int64,
    np.uint8: uint8,
    np.uint16: uint16,
    np.uint32: uint32,
    np.uint64: uint64,
    np.int16: int16,
    np.int8: int8,
    np.complex64: complex64,
    np.complex128: complex128,
    np.object_: string,
    np.bytes_: string,
    np.str_: string,
    np.bool_: bool,
    _np_qint8: qint8,
    _np_quint8: quint8,
    _np_qint16: qint16,
    _np_quint16: quint16,
    _np_qint32: qint32,
    _np_bfloat16: bfloat16,
}

# Map (some) NumPy platform dtypes to TF ones using their fixed-width
# synonyms. Note that platform dtypes are not always simple aliases,
# i.e. reference equality is not guaranteed. See e.g. numpy/numpy#9799.
# Each missing platform type is matched by dtype equality against the keys
# above; `next` raises StopIteration at import time if no key matches.
for pdt in [
    np.intc,
    np.uintc,
    np.int_,
    np.uint,
    np.longlong,
    np.ulonglong,
]:
  if pdt not in _NP_TO_TF:
    _NP_TO_TF[pdt] = next(
        _NP_TO_TF[dt] for dt in _NP_TO_TF if dt == pdt().dtype)
575
# Mapping from `types_pb2.DataType` enum values (including reference types) to
# the numpy type that represents them; backs `DType.as_numpy_dtype`. Resource
# and variant types have no entry.
#
# `builtins.object` and `builtins.bool` are exactly the objects that the
# deprecated numpy aliases `np.object` and `np.bool` referred to (those aliases
# were removed in NumPy 1.24), so the mapped values are unchanged. `builtins`
# is required because the module-level name `bool` is the TensorFlow `DType`.
_TF_TO_NP = {
    types_pb2.DT_HALF: np.float16,
    types_pb2.DT_FLOAT: np.float32,
    types_pb2.DT_DOUBLE: np.float64,
    types_pb2.DT_INT32: np.int32,
    types_pb2.DT_UINT8: np.uint8,
    types_pb2.DT_UINT16: np.uint16,
    types_pb2.DT_UINT32: np.uint32,
    types_pb2.DT_UINT64: np.uint64,
    types_pb2.DT_INT16: np.int16,
    types_pb2.DT_INT8: np.int8,
    # NOTE(touts): For strings we use object as it supports variable length
    # strings.
    types_pb2.DT_STRING: builtins.object,
    types_pb2.DT_COMPLEX64: np.complex64,
    types_pb2.DT_COMPLEX128: np.complex128,
    types_pb2.DT_INT64: np.int64,
    types_pb2.DT_BOOL: builtins.bool,
    types_pb2.DT_QINT8: _np_qint8,
    types_pb2.DT_QUINT8: _np_quint8,
    types_pb2.DT_QINT16: _np_qint16,
    types_pb2.DT_QUINT16: _np_quint16,
    types_pb2.DT_QINT32: _np_qint32,
    types_pb2.DT_BFLOAT16: _np_bfloat16,

    # Ref types
    types_pb2.DT_HALF_REF: np.float16,
    types_pb2.DT_FLOAT_REF: np.float32,
    types_pb2.DT_DOUBLE_REF: np.float64,
    types_pb2.DT_INT32_REF: np.int32,
    types_pb2.DT_UINT32_REF: np.uint32,
    types_pb2.DT_UINT8_REF: np.uint8,
    types_pb2.DT_UINT16_REF: np.uint16,
    types_pb2.DT_INT16_REF: np.int16,
    types_pb2.DT_INT8_REF: np.int8,
    types_pb2.DT_STRING_REF: builtins.object,
    types_pb2.DT_COMPLEX64_REF: np.complex64,
    types_pb2.DT_COMPLEX128_REF: np.complex128,
    types_pb2.DT_INT64_REF: np.int64,
    types_pb2.DT_UINT64_REF: np.uint64,
    types_pb2.DT_BOOL_REF: builtins.bool,
    types_pb2.DT_QINT8_REF: _np_qint8,
    types_pb2.DT_QUINT8_REF: _np_quint8,
    types_pb2.DT_QINT16_REF: _np_qint16,
    types_pb2.DT_QUINT16_REF: _np_quint16,
    types_pb2.DT_QINT32_REF: _np_qint32,
    types_pb2.DT_BFLOAT16_REF: _np_bfloat16,
}
666
# Quantized types: the non-reference set (used by `DType.is_quantized`), the
# reference set, and the exported union of both.
_QUANTIZED_DTYPES_NO_REF = frozenset((qint8, quint8, qint16, quint16, qint32))
_QUANTIZED_DTYPES_REF = frozenset(
    (qint8_ref, quint8_ref, qint16_ref, quint16_ref, qint32_ref))
QUANTIZED_DTYPES = _QUANTIZED_DTYPES_REF | _QUANTIZED_DTYPES_NO_REF
tf_export(
    "dtypes.QUANTIZED_DTYPES",
    v1=["dtypes.QUANTIZED_DTYPES", "QUANTIZED_DTYPES"]).export_constant(
        __name__, "QUANTIZED_DTYPES")
675
# Python builtin types accepted by `as_dtype`. The builtins are spelled via
# `builtins` because the module-level `bool` is the TensorFlow `DType`.
_PYTHON_TO_TF = {
    builtins.float: float32,
    builtins.bool: bool,
    builtins.object: string,
}

# Combined lookup table used by `as_dtype`: enum values, type names, Python
# types and numpy types. Tables are merged in order, later ones overriding
# earlier ones on a shared key.
_ANY_TO_TF = {
    key: value
    for source in (_INTERN_TABLE, _STRING_TO_TF, _PYTHON_TO_TF, _NP_TO_TF)
    for key, value in source.items()
}

# Ensure no collisions: the merged table must be exactly as large as the sum
# of its parts, i.e. no table silently overrode another.
assert len(_ANY_TO_TF) == sum(
    map(len, (_INTERN_TABLE, _STRING_TO_TF, _PYTHON_TO_TF, _NP_TO_TF)))
695
696
@tf_export("dtypes.as_dtype", "as_dtype")
def as_dtype(type_value):
  """Converts the given `type_value` to a `DType`.

  Args:
    type_value: A value that can be converted to a `tf.DType` object. This may
      currently be a `tf.DType` object, a [`DataType`
      enum](https://www.tensorflow.org/code/tensorflow/core/framework/types.proto),
      a string type name, a `numpy.dtype` (or numpy scalar type), or one of the
      Python types `float`, `bool` or `object`.

  Returns:
    A `DType` corresponding to `type_value`.

  Raises:
    TypeError: If `type_value` cannot be converted to a `DType`, including when
      it is unhashable (e.g. a list).
  """
  if isinstance(type_value, DType):
    return type_value

  if isinstance(type_value, np.dtype):
    # Plain numpy dtypes are keyed by their scalar type. Structured dtypes
    # (the quantized types) are keyed by the dtype object itself, so they fall
    # through to the general lookup below.
    try:
      return _NP_TO_TF[type_value.type]
    except KeyError:
      pass

  try:
    return _ANY_TO_TF[type_value]
  except (KeyError, TypeError):
    # TypeError here means `type_value` is unhashable; report it with the same
    # descriptive message as any other unconvertible value.
    pass

  raise TypeError(
      "Cannot convert value %r to a TensorFlow DType." % (type_value,))
729