1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Dtypes and dtype utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.ops.numpy_ops import np_export
25
26
# We use numpy's dtypes instead of TF's, because the user expects to use them
# with numpy facilities such as `np.dtype(np.int64)` and
# `if x.dtype.type is np.int64`.
#
# `complex_`, `float_`, `string_` and `unicode_` are bound through the
# canonical NumPy names (`np.complex128`, `np.float64`, `np.bytes_`, `np.str_`).
# The legacy `np.complex_`, `np.float_`, `np.string_` and `np.unicode_` aliases
# were removed in NumPy 2.0; on older NumPy releases (under Python 3) the
# canonical names refer to the very same type objects, so the exported values
# are unchanged.
bool_ = np_export.np_export_constant(__name__, 'bool_', np.bool_)
complex_ = np_export.np_export_constant(__name__, 'complex_', np.complex128)
complex128 = np_export.np_export_constant(__name__, 'complex128', np.complex128)
complex64 = np_export.np_export_constant(__name__, 'complex64', np.complex64)
float_ = np_export.np_export_constant(__name__, 'float_', np.float64)
float16 = np_export.np_export_constant(__name__, 'float16', np.float16)
float32 = np_export.np_export_constant(__name__, 'float32', np.float32)
float64 = np_export.np_export_constant(__name__, 'float64', np.float64)
inexact = np_export.np_export_constant(__name__, 'inexact', np.inexact)
int_ = np_export.np_export_constant(__name__, 'int_', np.int_)
int16 = np_export.np_export_constant(__name__, 'int16', np.int16)
int32 = np_export.np_export_constant(__name__, 'int32', np.int32)
int64 = np_export.np_export_constant(__name__, 'int64', np.int64)
int8 = np_export.np_export_constant(__name__, 'int8', np.int8)
object_ = np_export.np_export_constant(__name__, 'object_', np.object_)
string_ = np_export.np_export_constant(__name__, 'string_', np.bytes_)
uint16 = np_export.np_export_constant(__name__, 'uint16', np.uint16)
uint32 = np_export.np_export_constant(__name__, 'uint32', np.uint32)
uint64 = np_export.np_export_constant(__name__, 'uint64', np.uint64)
uint8 = np_export.np_export_constant(__name__, 'uint8', np.uint8)
unicode_ = np_export.np_export_constant(__name__, 'unicode_', np.str_)
51
52
# Exposes `np.iinfo` from this module under the name 'iinfo', registered
# through `np_export.np_export_constant`.
iinfo = np_export.np_export_constant(__name__, 'iinfo', np.iinfo)


# Exposes `np.issubdtype` under the name 'issubdtype' by applying the
# `np_export.np_export` decorator to it directly (presumably the decorator
# returns the function it wraps -- confirm against `np_export`).
issubdtype = np_export.np_export('issubdtype')(np.issubdtype)
57
58
# Lookup table used by `canonicalize_dtype` when float64 is disallowed: each
# 64-bit-float-based dtype maps to its 32-bit-float-based counterpart.
_to_float32 = {
    np.dtype(np.float64): np.dtype(np.float32),
    np.dtype(np.complex128): np.dtype(np.complex64),
}
63
64
# Memoization table filled lazily by `_get_cached_dtype`: maps a dtype object
# to the `np.dtype` built from its `as_numpy_dtype` attribute.
_cached_np_dtypes = {}
66
67
# Two separate knobs control how float64 is treated in this module:
#   * `_prefer_float32` (read via `is_prefer_float32`) only decides which dtype
#     to use for Python floats.
#   * `_allow_float64` (read via `is_allow_float64`) decides whether float64
#     dtypes can ever appear in programs, which makes it the more restrictive
#     of the two.
# Off by default, i.e. Python floats are not converted to float32.
_prefer_float32 = False


# TODO(b/178862061): Consider removing this knob
# On by default. When false, `canonicalize_dtype` maps float64 to float32 and
# complex128 to complex64.
_allow_float64 = True
77
78
def is_prefer_float32():
  """Returns the current value of the module-level `_prefer_float32` flag.

  When the flag is true, `_result_type` converts Python `float` arguments to
  `np.float32` before inferring a dtype, and `default_float_type` returns
  float32.
  """
  return _prefer_float32
81
82
def set_prefer_float32(b):
  """Sets the module-level `_prefer_float32` flag.

  Args:
    b: New value of the flag. It is stored as given (no conversion to `bool`)
      and is what `is_prefer_float32` returns afterwards.
  """
  global _prefer_float32
  _prefer_float32 = b
86
87
def is_allow_float64():
  """Returns the current value of the module-level `_allow_float64` flag.

  When the flag is false, `canonicalize_dtype` remaps float64 to float32 and
  complex128 to complex64, and `default_float_type` returns float32.
  """
  return _allow_float64
90
91
def set_allow_float64(b):
  """Sets the module-level `_allow_float64` flag.

  Args:
    b: New value of the flag. It is stored as given (no conversion to `bool`)
      and is what `is_allow_float64` returns afterwards.
  """
  global _allow_float64
  _allow_float64 = b
95
96
def canonicalize_dtype(dtype):
  """Downgrades 64-bit float/complex dtypes when float64 is disallowed.

  Args:
    dtype: The dtype to canonicalize. Only values that are keys of
      `_to_float32` (the `np.dtype` objects for float64 and complex128) are
      ever remapped.

  Returns:
    `dtype` itself if `_allow_float64` is true or `dtype` has no entry in
    `_to_float32`; otherwise the corresponding entry of `_to_float32`
    (float32 for float64, complex64 for complex128).
  """
  if _allow_float64:
    return dtype
  # Fall back to the input when there is nothing to downgrade.
  return _to_float32.get(dtype, dtype)
104
105
def _result_type(*arrays_and_dtypes):
  """Infers the TF dtype obtained by combining the given operands.

  Args:
    *arrays_and_dtypes: Values forwarded to `np.result_type`. When
      `is_prefer_float32()` is true, every argument that is an instance of
      Python `float` is first converted to `np.float32`.

  Returns:
    The result of `dtypes.as_dtype` applied to the canonicalized
    (`canonicalize_dtype`) dtype computed by `np.result_type`.
  """
  if is_prefer_float32():
    # NOTE: `np.float64` scalars subclass Python `float`, so they are matched
    # by this `isinstance` check as well.
    operands = [
        np.float32(value) if isinstance(value, float) else value
        for value in arrays_and_dtypes
    ]
  else:
    operands = list(arrays_and_dtypes)
  inferred = np.result_type(*operands)
  return dtypes.as_dtype(canonicalize_dtype(inferred))
114
115
def _get_cached_dtype(dtype):
  """Returns an np.dtype for the TensorFlow DType.

  The result is memoized in `_cached_np_dtypes`, keyed by `dtype`, so
  `np.dtype(dtype.as_numpy_dtype)` is evaluated at most once per key.

  Args:
    dtype: Hashable object exposing an `as_numpy_dtype` attribute.

  Returns:
    The `np.dtype` built from `dtype.as_numpy_dtype`.
  """
  np_dtype = _cached_np_dtypes.get(dtype)
  if np_dtype is None:
    # First request for this key: build the np.dtype and remember it.
    np_dtype = np.dtype(dtype.as_numpy_dtype)
    _cached_np_dtypes[dtype] = np_dtype
  return np_dtype
126
127
def default_float_type():
  """Gets the default float type.

  Returns:
    float32 if `is_prefer_float32()` is true or `is_allow_float64()` is false;
    float64 otherwise.
  """
  # Either knob is enough to rule out float64 as the default.
  if is_prefer_float32() or not is_allow_float64():
    return float32
  return float64
139