1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.python.framework.dtypes.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.core.framework import types_pb2 24from tensorflow.python import _dtypes 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import test_util 27from tensorflow.python.platform import googletest 28 29 30def _is_numeric_dtype_enum(datatype_enum): 31 non_numeric_dtypes = [types_pb2.DT_VARIANT, 32 types_pb2.DT_VARIANT_REF, 33 types_pb2.DT_INVALID, 34 types_pb2.DT_RESOURCE, 35 types_pb2.DT_RESOURCE_REF] 36 return datatype_enum not in non_numeric_dtypes 37 38 39class TypesTest(test_util.TensorFlowTestCase): 40 41 def testAllTypesConstructible(self): 42 for datatype_enum in types_pb2.DataType.values(): 43 if datatype_enum == types_pb2.DT_INVALID: 44 continue 45 self.assertEqual(datatype_enum, 46 dtypes.DType(datatype_enum).as_datatype_enum) 47 48 def testAllTypesConvertibleToDType(self): 49 for datatype_enum in types_pb2.DataType.values(): 50 if datatype_enum == types_pb2.DT_INVALID: 51 continue 52 dt = dtypes.as_dtype(datatype_enum) 53 self.assertEqual(datatype_enum, dt.as_datatype_enum) 54 55 def testAllTypesConvertibleToNumpyDtype(self): 56 for datatype_enum in types_pb2.DataType.values(): 57 if not _is_numeric_dtype_enum(datatype_enum): 58 continue 59 dtype = dtypes.as_dtype(datatype_enum) 60 numpy_dtype = dtype.as_numpy_dtype 61 _ = np.empty((1, 1, 1, 1), dtype=numpy_dtype) 62 if dtype.base_dtype != dtypes.bfloat16: 63 # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16. 64 self.assertEqual( 65 dtypes.as_dtype(datatype_enum).base_dtype, 66 dtypes.as_dtype(numpy_dtype)) 67 68 def testAllPybind11DTypeConvertibleToDType(self): 69 for datatype_enum in types_pb2.DataType.values(): 70 if datatype_enum == types_pb2.DT_INVALID: 71 continue 72 dtype = _dtypes.DType(datatype_enum) 73 self.assertEqual(dtypes.as_dtype(datatype_enum), dtype) 74 75 def testInvalid(self): 76 with self.assertRaises(TypeError): 77 dtypes.DType(types_pb2.DT_INVALID) 78 with self.assertRaises(TypeError): 79 dtypes.as_dtype(types_pb2.DT_INVALID) 80 81 def testNumpyConversion(self): 82 self.assertIs(dtypes.float32, dtypes.as_dtype(np.float32)) 83 self.assertIs(dtypes.float64, dtypes.as_dtype(np.float64)) 84 self.assertIs(dtypes.int32, dtypes.as_dtype(np.int32)) 85 self.assertIs(dtypes.int64, dtypes.as_dtype(np.int64)) 86 self.assertIs(dtypes.uint8, dtypes.as_dtype(np.uint8)) 87 self.assertIs(dtypes.uint16, dtypes.as_dtype(np.uint16)) 88 self.assertIs(dtypes.int16, dtypes.as_dtype(np.int16)) 89 self.assertIs(dtypes.int8, dtypes.as_dtype(np.int8)) 90 self.assertIs(dtypes.complex64, dtypes.as_dtype(np.complex64)) 91 self.assertIs(dtypes.complex128, dtypes.as_dtype(np.complex128)) 92 self.assertIs(dtypes.string, dtypes.as_dtype(np.object_)) 93 self.assertIs(dtypes.string, 94 dtypes.as_dtype(np.array(["foo", "bar"]).dtype)) 95 self.assertIs(dtypes.bool, dtypes.as_dtype(np.bool_)) 96 with self.assertRaises(TypeError): 97 dtypes.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)])) 98 99 class AnObject(object): 100 dtype = "f4" 101 102 self.assertIs(dtypes.float32, dtypes.as_dtype(AnObject)) 103 104 class AnotherObject(object): 105 dtype = np.dtype(np.complex64) 106 107 self.assertIs(dtypes.complex64, dtypes.as_dtype(AnotherObject)) 108 109 def testRealDtype(self): 110 for dtype in [ 111 dtypes.float32, dtypes.float64, dtypes.bool, dtypes.uint8, dtypes.int8, 112 dtypes.int16, dtypes.int32, dtypes.int64 113 ]: 114 self.assertIs(dtype.real_dtype, dtype) 115 self.assertIs(dtypes.complex64.real_dtype, dtypes.float32) 116 self.assertIs(dtypes.complex128.real_dtype, dtypes.float64) 117 118 def testStringConversion(self): 119 self.assertIs(dtypes.float32, dtypes.as_dtype("float32")) 120 self.assertIs(dtypes.float64, dtypes.as_dtype("float64")) 121 self.assertIs(dtypes.int32, dtypes.as_dtype("int32")) 122 self.assertIs(dtypes.uint8, dtypes.as_dtype("uint8")) 123 self.assertIs(dtypes.uint16, dtypes.as_dtype("uint16")) 124 self.assertIs(dtypes.int16, dtypes.as_dtype("int16")) 125 self.assertIs(dtypes.int8, dtypes.as_dtype("int8")) 126 self.assertIs(dtypes.string, dtypes.as_dtype("string")) 127 self.assertIs(dtypes.complex64, dtypes.as_dtype("complex64")) 128 self.assertIs(dtypes.complex128, dtypes.as_dtype("complex128")) 129 self.assertIs(dtypes.int64, dtypes.as_dtype("int64")) 130 self.assertIs(dtypes.bool, dtypes.as_dtype("bool")) 131 self.assertIs(dtypes.qint8, dtypes.as_dtype("qint8")) 132 self.assertIs(dtypes.quint8, dtypes.as_dtype("quint8")) 133 self.assertIs(dtypes.qint32, dtypes.as_dtype("qint32")) 134 self.assertIs(dtypes.bfloat16, dtypes.as_dtype("bfloat16")) 135 self.assertIs(dtypes.float32_ref, dtypes.as_dtype("float32_ref")) 136 self.assertIs(dtypes.float64_ref, dtypes.as_dtype("float64_ref")) 137 self.assertIs(dtypes.int32_ref, dtypes.as_dtype("int32_ref")) 138 self.assertIs(dtypes.uint8_ref, dtypes.as_dtype("uint8_ref")) 139 self.assertIs(dtypes.int16_ref, dtypes.as_dtype("int16_ref")) 140 self.assertIs(dtypes.int8_ref, dtypes.as_dtype("int8_ref")) 141 self.assertIs(dtypes.string_ref, dtypes.as_dtype("string_ref")) 142 self.assertIs(dtypes.complex64_ref, dtypes.as_dtype("complex64_ref")) 143 self.assertIs(dtypes.complex128_ref, dtypes.as_dtype("complex128_ref")) 144 self.assertIs(dtypes.int64_ref, dtypes.as_dtype("int64_ref")) 145 self.assertIs(dtypes.bool_ref, dtypes.as_dtype("bool_ref")) 146 self.assertIs(dtypes.qint8_ref, dtypes.as_dtype("qint8_ref")) 147 self.assertIs(dtypes.quint8_ref, dtypes.as_dtype("quint8_ref")) 148 self.assertIs(dtypes.qint32_ref, dtypes.as_dtype("qint32_ref")) 149 self.assertIs(dtypes.bfloat16_ref, dtypes.as_dtype("bfloat16_ref")) 150 with self.assertRaises(TypeError): 151 dtypes.as_dtype("not_a_type") 152 153 def testDTypesHaveUniqueNames(self): 154 dtypez = [] 155 names = set() 156 for datatype_enum in types_pb2.DataType.values(): 157 if datatype_enum == types_pb2.DT_INVALID: 158 continue 159 dtype = dtypes.as_dtype(datatype_enum) 160 dtypez.append(dtype) 161 names.add(dtype.name) 162 self.assertEqual(len(dtypez), len(names)) 163 164 def testIsInteger(self): 165 self.assertEqual(dtypes.as_dtype("int8").is_integer, True) 166 self.assertEqual(dtypes.as_dtype("int16").is_integer, True) 167 self.assertEqual(dtypes.as_dtype("int32").is_integer, True) 168 self.assertEqual(dtypes.as_dtype("int64").is_integer, True) 169 self.assertEqual(dtypes.as_dtype("uint8").is_integer, True) 170 self.assertEqual(dtypes.as_dtype("uint16").is_integer, True) 171 self.assertEqual(dtypes.as_dtype("complex64").is_integer, False) 172 self.assertEqual(dtypes.as_dtype("complex128").is_integer, False) 173 self.assertEqual(dtypes.as_dtype("float").is_integer, False) 174 self.assertEqual(dtypes.as_dtype("double").is_integer, False) 175 self.assertEqual(dtypes.as_dtype("string").is_integer, False) 176 self.assertEqual(dtypes.as_dtype("bool").is_integer, False) 177 self.assertEqual(dtypes.as_dtype("bfloat16").is_integer, False) 178 self.assertEqual(dtypes.as_dtype("qint8").is_integer, False) 179 self.assertEqual(dtypes.as_dtype("qint16").is_integer, False) 180 self.assertEqual(dtypes.as_dtype("qint32").is_integer, False) 181 self.assertEqual(dtypes.as_dtype("quint8").is_integer, False) 182 self.assertEqual(dtypes.as_dtype("quint16").is_integer, False) 183 184 def testIsFloating(self): 185 self.assertEqual(dtypes.as_dtype("int8").is_floating, False) 186 self.assertEqual(dtypes.as_dtype("int16").is_floating, False) 187 self.assertEqual(dtypes.as_dtype("int32").is_floating, False) 188 self.assertEqual(dtypes.as_dtype("int64").is_floating, False) 189 self.assertEqual(dtypes.as_dtype("uint8").is_floating, False) 190 self.assertEqual(dtypes.as_dtype("uint16").is_floating, False) 191 self.assertEqual(dtypes.as_dtype("complex64").is_floating, False) 192 self.assertEqual(dtypes.as_dtype("complex128").is_floating, False) 193 self.assertEqual(dtypes.as_dtype("float32").is_floating, True) 194 self.assertEqual(dtypes.as_dtype("float64").is_floating, True) 195 self.assertEqual(dtypes.as_dtype("string").is_floating, False) 196 self.assertEqual(dtypes.as_dtype("bool").is_floating, False) 197 self.assertEqual(dtypes.as_dtype("bfloat16").is_floating, True) 198 self.assertEqual(dtypes.as_dtype("qint8").is_floating, False) 199 self.assertEqual(dtypes.as_dtype("qint16").is_floating, False) 200 self.assertEqual(dtypes.as_dtype("qint32").is_floating, False) 201 self.assertEqual(dtypes.as_dtype("quint8").is_floating, False) 202 self.assertEqual(dtypes.as_dtype("quint16").is_floating, False) 203 204 def testIsComplex(self): 205 self.assertEqual(dtypes.as_dtype("int8").is_complex, False) 206 self.assertEqual(dtypes.as_dtype("int16").is_complex, False) 207 self.assertEqual(dtypes.as_dtype("int32").is_complex, False) 208 self.assertEqual(dtypes.as_dtype("int64").is_complex, False) 209 self.assertEqual(dtypes.as_dtype("uint8").is_complex, False) 210 self.assertEqual(dtypes.as_dtype("uint16").is_complex, False) 211 self.assertEqual(dtypes.as_dtype("complex64").is_complex, True) 212 self.assertEqual(dtypes.as_dtype("complex128").is_complex, True) 213 self.assertEqual(dtypes.as_dtype("float32").is_complex, False) 214 self.assertEqual(dtypes.as_dtype("float64").is_complex, False) 215 self.assertEqual(dtypes.as_dtype("string").is_complex, False) 216 self.assertEqual(dtypes.as_dtype("bool").is_complex, False) 217 self.assertEqual(dtypes.as_dtype("bfloat16").is_complex, False) 218 self.assertEqual(dtypes.as_dtype("qint8").is_complex, False) 219 self.assertEqual(dtypes.as_dtype("qint16").is_complex, False) 220 self.assertEqual(dtypes.as_dtype("qint32").is_complex, False) 221 self.assertEqual(dtypes.as_dtype("quint8").is_complex, False) 222 self.assertEqual(dtypes.as_dtype("quint16").is_complex, False) 223 224 def testIsUnsigned(self): 225 self.assertEqual(dtypes.as_dtype("int8").is_unsigned, False) 226 self.assertEqual(dtypes.as_dtype("int16").is_unsigned, False) 227 self.assertEqual(dtypes.as_dtype("int32").is_unsigned, False) 228 self.assertEqual(dtypes.as_dtype("int64").is_unsigned, False) 229 self.assertEqual(dtypes.as_dtype("uint8").is_unsigned, True) 230 self.assertEqual(dtypes.as_dtype("uint16").is_unsigned, True) 231 self.assertEqual(dtypes.as_dtype("float32").is_unsigned, False) 232 self.assertEqual(dtypes.as_dtype("float64").is_unsigned, False) 233 self.assertEqual(dtypes.as_dtype("bool").is_unsigned, False) 234 self.assertEqual(dtypes.as_dtype("string").is_unsigned, False) 235 self.assertEqual(dtypes.as_dtype("complex64").is_unsigned, False) 236 self.assertEqual(dtypes.as_dtype("complex128").is_unsigned, False) 237 self.assertEqual(dtypes.as_dtype("bfloat16").is_unsigned, False) 238 self.assertEqual(dtypes.as_dtype("qint8").is_unsigned, False) 239 self.assertEqual(dtypes.as_dtype("qint16").is_unsigned, False) 240 self.assertEqual(dtypes.as_dtype("qint32").is_unsigned, False) 241 self.assertEqual(dtypes.as_dtype("quint8").is_unsigned, False) 242 self.assertEqual(dtypes.as_dtype("quint16").is_unsigned, False) 243 244 def testMinMax(self): 245 # make sure min/max evaluates for all data types that have min/max 246 for datatype_enum in types_pb2.DataType.values(): 247 if not _is_numeric_dtype_enum(datatype_enum): 248 continue 249 dtype = dtypes.as_dtype(datatype_enum) 250 numpy_dtype = dtype.as_numpy_dtype 251 252 # ignore types for which there are no minimum/maximum (or we cannot 253 # compute it, such as for the q* types) 254 if (dtype.is_quantized or dtype.base_dtype == dtypes.bool or 255 dtype.base_dtype == dtypes.string or 256 dtype.base_dtype == dtypes.complex64 or 257 dtype.base_dtype == dtypes.complex128): 258 continue 259 260 print("%s: %s - %s" % (dtype, dtype.min, dtype.max)) 261 262 # check some values that are known 263 if numpy_dtype == np.bool_: 264 self.assertEqual(dtype.min, 0) 265 self.assertEqual(dtype.max, 1) 266 if numpy_dtype == np.int8: 267 self.assertEqual(dtype.min, -128) 268 self.assertEqual(dtype.max, 127) 269 if numpy_dtype == np.int16: 270 self.assertEqual(dtype.min, -32768) 271 self.assertEqual(dtype.max, 32767) 272 if numpy_dtype == np.int32: 273 self.assertEqual(dtype.min, -2147483648) 274 self.assertEqual(dtype.max, 2147483647) 275 if numpy_dtype == np.int64: 276 self.assertEqual(dtype.min, -9223372036854775808) 277 self.assertEqual(dtype.max, 9223372036854775807) 278 if numpy_dtype == np.uint8: 279 self.assertEqual(dtype.min, 0) 280 self.assertEqual(dtype.max, 255) 281 if numpy_dtype == np.uint16: 282 if dtype == dtypes.uint16: 283 self.assertEqual(dtype.min, 0) 284 self.assertEqual(dtype.max, 65535) 285 elif dtype == dtypes.bfloat16: 286 self.assertEqual(dtype.min, 0) 287 self.assertEqual(dtype.max, 4294967295) 288 if numpy_dtype == np.uint32: 289 self.assertEqual(dtype.min, 0) 290 self.assertEqual(dtype.max, 4294967295) 291 if numpy_dtype == np.uint64: 292 self.assertEqual(dtype.min, 0) 293 self.assertEqual(dtype.max, 18446744073709551615) 294 if numpy_dtype in (np.float16, np.float32, np.float64): 295 self.assertEqual(dtype.min, np.finfo(numpy_dtype).min) 296 self.assertEqual(dtype.max, np.finfo(numpy_dtype).max) 297 if numpy_dtype == dtypes.bfloat16.as_numpy_dtype: 298 self.assertEqual(dtype.min, float.fromhex("-0x1.FEp127")) 299 self.assertEqual(dtype.max, float.fromhex("0x1.FEp127")) 300 301 def testRepr(self): 302 self.skipTest("b/142725777") 303 for enum, name in dtypes._TYPE_TO_STRING.items(): 304 if enum > 100: 305 continue 306 dtype = dtypes.DType(enum) 307 self.assertEqual(repr(dtype), "tf." + name) 308 import tensorflow as tf 309 dtype2 = eval(repr(dtype)) 310 self.assertEqual(type(dtype2), dtypes.DType) 311 self.assertEqual(dtype, dtype2) 312 313 def testEqWithNonTFTypes(self): 314 self.assertNotEqual(dtypes.int32, int) 315 self.assertNotEqual(dtypes.float64, 2.1) 316 317 def testPythonLongConversion(self): 318 self.assertIs(dtypes.int64, dtypes.as_dtype(np.array(2**32).dtype)) 319 320 def testPythonTypesConversion(self): 321 self.assertIs(dtypes.float32, dtypes.as_dtype(float)) 322 self.assertIs(dtypes.bool, dtypes.as_dtype(bool)) 323 324 def testReduce(self): 325 for enum in dtypes._TYPE_TO_STRING: 326 dtype = dtypes.DType(enum) 327 ctor, args = dtype.__reduce__() 328 self.assertEqual(ctor, dtypes.as_dtype) 329 self.assertEqual(args, (dtype.name,)) 330 reconstructed = ctor(*args) 331 self.assertEqual(reconstructed, dtype) 332 333 def testAsDtypeInvalidArgument(self): 334 with self.assertRaises(TypeError): 335 dtypes.as_dtype((dtypes.int32, dtypes.float32)) 336 337 def testAsDtypeReturnsInternedVersion(self): 338 dt = dtypes.DType(types_pb2.DT_VARIANT) 339 self.assertIs(dtypes.as_dtype(dt), dtypes.variant) 340 341 342if __name__ == "__main__": 343 googletest.main() 344