1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.framework.dtypes."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.core.framework import types_pb2
24from tensorflow.python import _dtypes
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import test_util
27from tensorflow.python.platform import googletest
28
29
def _is_numeric_dtype_enum(datatype_enum):
  """Reports whether a `DataType` enum value should be treated as numeric.

  Args:
    datatype_enum: A `types_pb2.DataType` enum value.

  Returns:
    False for DT_INVALID and the variant/resource enums (including their
    `_REF` forms); True for every other enum value.
  """
  excluded = (
      types_pb2.DT_VARIANT,
      types_pb2.DT_VARIANT_REF,
      types_pb2.DT_INVALID,
      types_pb2.DT_RESOURCE,
      types_pb2.DT_RESOURCE_REF,
  )
  is_excluded = datatype_enum in excluded
  return not is_excluded
37
38
class TypesTest(test_util.TensorFlowTestCase):
  """Tests construction, conversion and properties of `dtypes.DType`."""

  def _assertDTypeProperty(self, prop, true_names, false_names):
    """Asserts that a boolean `DType` attribute has the expected values.

    Args:
      prop: Name of the `DType` attribute to check, e.g. "is_integer".
      true_names: dtype names (resolved with `dtypes.as_dtype`) for which the
        attribute must equal `True`.
      false_names: dtype names for which the attribute must equal `False`.
    """
    expectations = [(name, True) for name in true_names]
    expectations += [(name, False) for name in false_names]
    for name, expected in expectations:
      # `msg` names the offending dtype; a bare `True != False` failure
      # would not say which dtype is wrong.
      self.assertEqual(
          getattr(dtypes.as_dtype(name), prop), expected,
          msg="%s.%s" % (name, prop))

  def testAllTypesConstructible(self):
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      self.assertEqual(datatype_enum,
                       dtypes.DType(datatype_enum).as_datatype_enum)

  def testAllTypesConvertibleToDType(self):
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dt = dtypes.as_dtype(datatype_enum)
      self.assertEqual(datatype_enum, dt.as_datatype_enum)

  def testAllTypesConvertibleToNumpyDtype(self):
    for datatype_enum in types_pb2.DataType.values():
      if not _is_numeric_dtype_enum(datatype_enum):
        continue
      dtype = dtypes.as_dtype(datatype_enum)
      numpy_dtype = dtype.as_numpy_dtype
      # The numpy type must be usable to allocate an array.
      _ = np.empty((1, 1, 1, 1), dtype=numpy_dtype)
      if dtype.base_dtype != dtypes.bfloat16:
        # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
        self.assertEqual(dtype.base_dtype, dtypes.as_dtype(numpy_dtype))

  def testAllPybind11DTypeConvertibleToDType(self):
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dtype = _dtypes.DType(datatype_enum)
      self.assertEqual(dtypes.as_dtype(datatype_enum), dtype)

  def testInvalid(self):
    with self.assertRaises(TypeError):
      dtypes.DType(types_pb2.DT_INVALID)
    with self.assertRaises(TypeError):
      dtypes.as_dtype(types_pb2.DT_INVALID)

  def testNumpyConversion(self):
    self.assertIs(dtypes.float32, dtypes.as_dtype(np.float32))
    self.assertIs(dtypes.float64, dtypes.as_dtype(np.float64))
    self.assertIs(dtypes.int32, dtypes.as_dtype(np.int32))
    self.assertIs(dtypes.int64, dtypes.as_dtype(np.int64))
    self.assertIs(dtypes.uint8, dtypes.as_dtype(np.uint8))
    self.assertIs(dtypes.uint16, dtypes.as_dtype(np.uint16))
    self.assertIs(dtypes.int16, dtypes.as_dtype(np.int16))
    self.assertIs(dtypes.int8, dtypes.as_dtype(np.int8))
    self.assertIs(dtypes.complex64, dtypes.as_dtype(np.complex64))
    self.assertIs(dtypes.complex128, dtypes.as_dtype(np.complex128))
    self.assertIs(dtypes.string, dtypes.as_dtype(np.object_))
    self.assertIs(dtypes.string,
                  dtypes.as_dtype(np.array(["foo", "bar"]).dtype))
    self.assertIs(dtypes.bool, dtypes.as_dtype(np.bool_))
    # Structured numpy dtypes have no DType equivalent.
    with self.assertRaises(TypeError):
      dtypes.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)]))

    # Objects exposing a `dtype` attribute are converted through it.
    class AnObject(object):
      dtype = "f4"

    self.assertIs(dtypes.float32, dtypes.as_dtype(AnObject))

    class AnotherObject(object):
      dtype = np.dtype(np.complex64)

    self.assertIs(dtypes.complex64, dtypes.as_dtype(AnotherObject))

  def testRealDtype(self):
    for dtype in [
        dtypes.float32, dtypes.float64, dtypes.bool, dtypes.uint8, dtypes.int8,
        dtypes.int16, dtypes.int32, dtypes.int64
    ]:
      self.assertIs(dtype.real_dtype, dtype)
    self.assertIs(dtypes.complex64.real_dtype, dtypes.float32)
    self.assertIs(dtypes.complex128.real_dtype, dtypes.float64)

  def testStringConversion(self):
    self.assertIs(dtypes.float32, dtypes.as_dtype("float32"))
    self.assertIs(dtypes.float64, dtypes.as_dtype("float64"))
    self.assertIs(dtypes.int32, dtypes.as_dtype("int32"))
    self.assertIs(dtypes.uint8, dtypes.as_dtype("uint8"))
    self.assertIs(dtypes.uint16, dtypes.as_dtype("uint16"))
    self.assertIs(dtypes.int16, dtypes.as_dtype("int16"))
    self.assertIs(dtypes.int8, dtypes.as_dtype("int8"))
    self.assertIs(dtypes.string, dtypes.as_dtype("string"))
    self.assertIs(dtypes.complex64, dtypes.as_dtype("complex64"))
    self.assertIs(dtypes.complex128, dtypes.as_dtype("complex128"))
    self.assertIs(dtypes.int64, dtypes.as_dtype("int64"))
    self.assertIs(dtypes.bool, dtypes.as_dtype("bool"))
    self.assertIs(dtypes.qint8, dtypes.as_dtype("qint8"))
    self.assertIs(dtypes.quint8, dtypes.as_dtype("quint8"))
    self.assertIs(dtypes.qint32, dtypes.as_dtype("qint32"))
    self.assertIs(dtypes.bfloat16, dtypes.as_dtype("bfloat16"))
    self.assertIs(dtypes.float32_ref, dtypes.as_dtype("float32_ref"))
    self.assertIs(dtypes.float64_ref, dtypes.as_dtype("float64_ref"))
    self.assertIs(dtypes.int32_ref, dtypes.as_dtype("int32_ref"))
    self.assertIs(dtypes.uint8_ref, dtypes.as_dtype("uint8_ref"))
    self.assertIs(dtypes.int16_ref, dtypes.as_dtype("int16_ref"))
    self.assertIs(dtypes.int8_ref, dtypes.as_dtype("int8_ref"))
    self.assertIs(dtypes.string_ref, dtypes.as_dtype("string_ref"))
    self.assertIs(dtypes.complex64_ref, dtypes.as_dtype("complex64_ref"))
    self.assertIs(dtypes.complex128_ref, dtypes.as_dtype("complex128_ref"))
    self.assertIs(dtypes.int64_ref, dtypes.as_dtype("int64_ref"))
    self.assertIs(dtypes.bool_ref, dtypes.as_dtype("bool_ref"))
    self.assertIs(dtypes.qint8_ref, dtypes.as_dtype("qint8_ref"))
    self.assertIs(dtypes.quint8_ref, dtypes.as_dtype("quint8_ref"))
    self.assertIs(dtypes.qint32_ref, dtypes.as_dtype("qint32_ref"))
    self.assertIs(dtypes.bfloat16_ref, dtypes.as_dtype("bfloat16_ref"))
    with self.assertRaises(TypeError):
      dtypes.as_dtype("not_a_type")

  def testDTypesHaveUniqueNames(self):
    dtypez = []
    names = set()
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dtype = dtypes.as_dtype(datatype_enum)
      dtypez.append(dtype)
      names.add(dtype.name)
    # Any duplicate name would collapse in the set and shrink its size.
    self.assertEqual(len(dtypez), len(names))

  def testIsInteger(self):
    self._assertDTypeProperty(
        "is_integer",
        true_names=["int8", "int16", "int32", "int64", "uint8", "uint16"],
        false_names=["complex64", "complex128", "float", "double", "string",
                     "bool", "bfloat16", "qint8", "qint16", "qint32",
                     "quint8", "quint16"])

  def testIsFloating(self):
    self._assertDTypeProperty(
        "is_floating",
        true_names=["float32", "float64", "bfloat16"],
        false_names=["int8", "int16", "int32", "int64", "uint8", "uint16",
                     "complex64", "complex128", "string", "bool", "qint8",
                     "qint16", "qint32", "quint8", "quint16"])

  def testIsComplex(self):
    self._assertDTypeProperty(
        "is_complex",
        true_names=["complex64", "complex128"],
        false_names=["int8", "int16", "int32", "int64", "uint8", "uint16",
                     "float32", "float64", "string", "bool", "bfloat16",
                     "qint8", "qint16", "qint32", "quint8", "quint16"])

  def testIsUnsigned(self):
    self._assertDTypeProperty(
        "is_unsigned",
        true_names=["uint8", "uint16"],
        false_names=["int8", "int16", "int32", "int64", "float32", "float64",
                     "bool", "string", "complex64", "complex128", "bfloat16",
                     "qint8", "qint16", "qint32", "quint8", "quint16"])

  def testMinMax(self):
    # Exact (min, max) for integer types, keyed by numpy scalar type. This
    # covers both the base dtypes and their `_ref` variants.
    known_int_ranges = [
        (np.int8, -128, 127),
        (np.int16, -32768, 32767),
        (np.int32, -2147483648, 2147483647),
        (np.int64, -9223372036854775808, 9223372036854775807),
        (np.uint8, 0, 255),
        (np.uint16, 0, 65535),
        (np.uint32, 0, 4294967295),
        (np.uint64, 0, 18446744073709551615),
    ]
    # Make sure min/max evaluates for all data types that have min/max.
    for datatype_enum in types_pb2.DataType.values():
      if not _is_numeric_dtype_enum(datatype_enum):
        continue
      dtype = dtypes.as_dtype(datatype_enum)
      numpy_dtype = dtype.as_numpy_dtype

      # Ignore types for which there is no minimum/maximum (or we cannot
      # compute it, such as for the q* types).
      if (dtype.is_quantized or dtype.base_dtype == dtypes.bool or
          dtype.base_dtype == dtypes.string or
          dtype.base_dtype == dtypes.complex64 or
          dtype.base_dtype == dtypes.complex128):
        continue

      # Always evaluate both limits, even for types without a known value
      # below, so that a failure to compute them surfaces as an error.
      dtype_min = dtype.min
      dtype_max = dtype.max

      # Check the values that are known exactly.
      for np_type, expected_min, expected_max in known_int_ranges:
        if numpy_dtype == np_type:
          self.assertEqual(dtype_min, expected_min, msg=str(dtype))
          self.assertEqual(dtype_max, expected_max, msg=str(dtype))
      if numpy_dtype in (np.float16, np.float32, np.float64):
        self.assertEqual(dtype_min, np.finfo(numpy_dtype).min)
        self.assertEqual(dtype_max, np.finfo(numpy_dtype).max)
      if numpy_dtype == dtypes.bfloat16.as_numpy_dtype:
        self.assertEqual(dtype_min, float.fromhex("-0x1.FEp127"))
        self.assertEqual(dtype_max, float.fromhex("0x1.FEp127"))

  def testRepr(self):
    # Disabled, see b/142725777; the body is kept so it can be re-enabled.
    self.skipTest("b/142725777")
    for enum, name in dtypes._TYPE_TO_STRING.items():
      # NOTE(review): enum values above 100 are presumably the `_ref`
      # variants -- confirm against types.proto.
      if enum > 100:
        continue
      dtype = dtypes.DType(enum)
      self.assertEqual(repr(dtype), "tf." + name)
      # The repr is "tf.<name>", so `tf` must be in scope for `eval`.
      import tensorflow as tf  # pylint: disable=g-import-not-at-top,unused-variable
      dtype2 = eval(repr(dtype))  # pylint: disable=eval-used
      self.assertEqual(type(dtype2), dtypes.DType)
      self.assertEqual(dtype, dtype2)

  def testEqWithNonTFTypes(self):
    self.assertNotEqual(dtypes.int32, int)
    self.assertNotEqual(dtypes.float64, 2.1)

  def testPythonLongConversion(self):
    self.assertIs(dtypes.int64, dtypes.as_dtype(np.array(2**32).dtype))

  def testPythonTypesConversion(self):
    self.assertIs(dtypes.float32, dtypes.as_dtype(float))
    self.assertIs(dtypes.bool, dtypes.as_dtype(bool))

  def testReduce(self):
    # Pickling support: `__reduce__` must round-trip through `as_dtype(name)`.
    for enum in dtypes._TYPE_TO_STRING:
      dtype = dtypes.DType(enum)
      ctor, args = dtype.__reduce__()
      self.assertEqual(ctor, dtypes.as_dtype)
      self.assertEqual(args, (dtype.name,))
      reconstructed = ctor(*args)
      self.assertEqual(reconstructed, dtype)

  def testAsDtypeInvalidArgument(self):
    with self.assertRaises(TypeError):
      dtypes.as_dtype((dtypes.int32, dtypes.float32))

  def testAsDtypeReturnsInternedVersion(self):
    dt = dtypes.DType(types_pb2.DT_VARIANT)
    self.assertIs(dtypes.as_dtype(dt), dtypes.variant)
340
341
# Run the test cases in this module when it is executed directly.
if __name__ == "__main__":
  googletest.main()
344