1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities to create TensorProtos."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21import six
22
23from tensorflow.core.framework import tensor_pb2
24from tensorflow.core.framework import tensor_shape_pb2
25from tensorflow.python.framework import composite_tensor
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.util import compat
29
30# Fallback in case fast_tensor_util is not properly compiled.
31# pylint: disable=g-import-not-at-top
32try:
33  from tensorflow.python.framework import fast_tensor_util
34  _FAST_TENSOR_UTIL_AVAILABLE = True
35except ImportError:
36  _FAST_TENSOR_UTIL_AVAILABLE = False
37
38from tensorflow.python.framework import dtypes
39from tensorflow.python.framework import ops
40from tensorflow.python.util.tf_export import tf_export
41
42# pylint: enable=g-import-not-at-top
43
44
def ExtractBitsFromFloat16(x):
  """Returns the raw 16-bit pattern of `x` encoded as IEEE float16.

  Args:
    x: A value convertible to a numpy float16 scalar.

  Returns:
    A Python int holding the uint16 bit pattern of the float16 value.
  """
  half = np.asarray(x, dtype=np.float16)
  bits = half.view(np.uint16)
  return bits.item()
47
48
def SlowAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float16 values to `tensor_proto.half_val` one element at a time.

  Each value is stored as its uint16 bit pattern.

  Args:
    tensor_proto: The TensorProto to extend.
    proto_values: An iterable of float16-convertible scalars.
  """
  half_bits = [ExtractBitsFromFloat16(value) for value in proto_values]
  tensor_proto.half_val.extend(half_bits)
52
53
def _MediumAppendFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends float16 values to `tensor_proto` through fast_tensor_util.

  The values are converted to float16 and reinterpreted as uint16 bit
  patterns before being handed to the compiled helper.
  """
  # TODO: Drop this reinterpretation once Cython supports np.float16_t.
  as_bits = np.asarray(proto_values, dtype=np.float16).view(np.uint16)
  fast_tensor_util.AppendFloat16ArrayToTensorProto(tensor_proto, as_bits)
59
60
def ExtractBitsFromBFloat16(x):
  """Returns the raw 16-bit pattern of `x` encoded as bfloat16.

  Args:
    x: A value convertible to TensorFlow's numpy bfloat16 type.

  Returns:
    A Python int holding the uint16 bit pattern of the bfloat16 value.
  """
  bfloat16_type = dtypes.bfloat16.as_numpy_dtype
  as_bits = np.asarray(x, dtype=bfloat16_type).view(np.uint16)
  return as_bits.item()
64
65
def SlowAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends bfloat16 values to `tensor_proto.half_val` element by element.

  bfloat16 shares the `half_val` field with float16; each entry is the uint16
  bit pattern of one value.
  """
  bit_patterns = [ExtractBitsFromBFloat16(value) for value in proto_values]
  tensor_proto.half_val.extend(bit_patterns)
69
70
def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values):
  """Appends bfloat16 values to `tensor_proto` via fast_tensor_util.

  The values are cast to bfloat16 and reinterpreted as uint16 bit patterns
  before being passed to the compiled helper.
  """
  bfloat16_type = dtypes.bfloat16.as_numpy_dtype
  as_bits = np.asarray(proto_values, dtype=bfloat16_type).view(np.uint16)
  fast_tensor_util.AppendBFloat16ArrayToTensorProto(tensor_proto, as_bits)
75
76
if _FAST_TENSOR_UTIL_AVAILABLE:
  # Compiled appenders from fast_tensor_util, keyed by numpy element type.
  # Lookups go through GetFromNumpyDTypeDict, which matches keys by equality.
  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype:
          FastAppendBFloat16ArrayToTensorProto,
      np.float16:
          _MediumAppendFloat16ArrayToTensorProto,
      np.float32:
          fast_tensor_util.AppendFloat32ArrayToTensorProto,
      np.float64:
          fast_tensor_util.AppendFloat64ArrayToTensorProto,
      np.int32:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
      np.int64:
          fast_tensor_util.AppendInt64ArrayToTensorProto,
      np.uint8:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      np.uint16:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      np.uint32:
          fast_tensor_util.AppendUInt32ArrayToTensorProto,
      np.uint64:
          fast_tensor_util.AppendUInt64ArrayToTensorProto,
      np.int8:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      np.int16:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      np.complex64:
          fast_tensor_util.AppendComplex64ArrayToTensorProto,
      np.complex128:
          fast_tensor_util.AppendComplex128ArrayToTensorProto,
      # Builtin `object`/`bool` are exactly what the `np.object`/`np.bool`
      # aliases referred to; those aliases were removed in NumPy 1.24.
      object:
          fast_tensor_util.AppendObjectArrayToTensorProto,
      bool:
          fast_tensor_util.AppendBoolArrayToTensorProto,
      dtypes.qint8.as_numpy_dtype:
          fast_tensor_util.AppendInt8ArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype:
          fast_tensor_util.AppendUInt8ArrayToTensorProto,
      # 16-bit quantized types need the 16-bit appenders so the element width
      # matches (the 8-bit appenders do not).
      dtypes.qint16.as_numpy_dtype:
          fast_tensor_util.AppendInt16ArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype:
          fast_tensor_util.AppendUInt16ArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype:
          fast_tensor_util.AppendInt32ArrayToTensorProto,
  }
else:
  # Pure-Python fallbacks. Each one extends the matching repeated field of the
  # TensorProto element by element.

  def SlowAppendFloat32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.float_val.extend([x.item() for x in proto_values])

  def SlowAppendFloat64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.double_val.extend([x.item() for x in proto_values])

  def SlowAppendIntArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int_val.extend([x.item() for x in proto_values])

  def SlowAppendInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.int64_val.extend([x.item() for x in proto_values])

  def SlowAppendQIntArrayToTensorProto(tensor_proto, proto_values):
    # `.item()` on a quantized element returns a sequence; its first entry is
    # the integer value that goes into `int_val`.
    tensor_proto.int_val.extend([x.item()[0] for x in proto_values])

  def SlowAppendUInt32ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint32_val.extend([x.item() for x in proto_values])

  def SlowAppendUInt64ArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.uint64_val.extend([x.item() for x in proto_values])

  def SlowAppendComplex64ArrayToTensorProto(tensor_proto, proto_values):
    # Complex values are stored as interleaved (real, imag) pairs.
    tensor_proto.scomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendComplex128ArrayToTensorProto(tensor_proto, proto_values):
    # Complex values are stored as interleaved (real, imag) pairs.
    tensor_proto.dcomplex_val.extend(
        [v.item() for x in proto_values for v in [x.real, x.imag]])

  def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values])

  def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values):
    tensor_proto.bool_val.extend([x.item() for x in proto_values])

  _NP_TO_APPEND_FN = {
      dtypes.bfloat16.as_numpy_dtype: SlowAppendBFloat16ArrayToTensorProto,
      np.float16: SlowAppendFloat16ArrayToTensorProto,
      np.float32: SlowAppendFloat32ArrayToTensorProto,
      np.float64: SlowAppendFloat64ArrayToTensorProto,
      np.int32: SlowAppendIntArrayToTensorProto,
      np.int64: SlowAppendInt64ArrayToTensorProto,
      np.uint8: SlowAppendIntArrayToTensorProto,
      np.uint16: SlowAppendIntArrayToTensorProto,
      np.uint32: SlowAppendUInt32ArrayToTensorProto,
      np.uint64: SlowAppendUInt64ArrayToTensorProto,
      np.int8: SlowAppendIntArrayToTensorProto,
      np.int16: SlowAppendIntArrayToTensorProto,
      np.complex64: SlowAppendComplex64ArrayToTensorProto,
      np.complex128: SlowAppendComplex128ArrayToTensorProto,
      # Builtin `object`/`bool` replace the removed `np.object`/`np.bool`
      # aliases (they are the same objects).
      object: SlowAppendObjectArrayToTensorProto,
      bool: SlowAppendBoolArrayToTensorProto,
      dtypes.qint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.quint8.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.qint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.quint16.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
      dtypes.qint32.as_numpy_dtype: SlowAppendQIntArrayToTensorProto,
  }
184
185
def GetFromNumpyDTypeDict(dtype_dict, dtype):
  """Looks up `dtype` in `dtype_dict` by equality rather than by hashing.

  A numpy dtype instance compares equal to its scalar type (for example
  np.dtype("float32") == np.float32), but `dtype_dict.get(dtype)` does not
  find such keys. The keys are therefore scanned linearly.

  Args:
    dtype_dict: A dict keyed by numpy scalar types.
    dtype: A numpy dtype to look up.

  Returns:
    The value of the first key that compares equal to `dtype`, or None.
  """
  matches = (value for key, value in dtype_dict.items() if key == dtype)
  return next(matches, None)
192
193
def GetNumpyAppendFn(dtype):
  """Returns the function that appends a flat array of `dtype` to a proto.

  Args:
    dtype: A numpy dtype instance.

  Returns:
    A callable `fn(tensor_proto, proto_values)`, or None if no appender is
    registered for `dtype`.
  """
  # numpy string dtypes are variable length (e.g. "S5", "<U3"), so they cannot
  # be matched against a single key of _NP_TO_APPEND_FN; compare the scalar
  # type instead. `np.bytes_`/`np.str_` are the canonical names of the
  # `np.string_`/`np.unicode_` aliases, which NumPy 2.0 removed.
  if dtype.type == np.bytes_ or dtype.type == np.str_:
    if _FAST_TENSOR_UTIL_AVAILABLE:
      return fast_tensor_util.AppendObjectArrayToTensorProto
    else:
      return SlowAppendObjectArrayToTensorProto
  return GetFromNumpyDTypeDict(_NP_TO_APPEND_FN, dtype)
205
206
def TensorShapeProtoToList(shape):
  """Converts a TensorShapeProto into a plain list of dimension sizes.

  Args:
    shape: A TensorShapeProto.

  Returns:
    A list with one integer size per dimension, in dimension order.
  """
  sizes = [dimension.size for dimension in shape.dim]
  return sizes
217
218
219def _GetDenseDimensions(list_of_lists):
220  """Returns the inferred dense dimensions of a list of lists."""
221  if not isinstance(list_of_lists, (list, tuple)):
222    return []
223  elif not list_of_lists:
224    return [0]
225  else:
226    return [len(list_of_lists)] + _GetDenseDimensions(list_of_lists[0])
227
228
229def _FlattenToStrings(nested_strings):
230  if isinstance(nested_strings, (list, tuple)):
231    for inner in nested_strings:
232      for flattened_string in _FlattenToStrings(inner):
233        yield flattened_string
234  else:
235    yield nested_strings
236
237
# Dtypes whose numpy buffer make_tensor_proto writes directly into
# `TensorProto.tensor_content`, when the value has more than one element and
# its size matches the requested shape.
_TENSOR_CONTENT_TYPES = frozenset((
    # Floating point.
    dtypes.float32, dtypes.float64,
    # Signed integers.
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    # Unsigned integers.
    dtypes.uint8, dtypes.uint32, dtypes.uint64,
    # Quantized integers.
    dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16, dtypes.qint32,
))
243
244
245class _Message(object):
246
247  def __init__(self, message):
248    self._message = message
249
250  def __repr__(self):
251    return self._message
252
253
def _FirstNotNone(l):
  """Returns the first element of `l` that is not None, or None if none is.

  If that element is an `ops.Tensor`, a `_Message("list containing Tensors")`
  is returned in its place, so mismatch reports describe it generically.
  """
  first = next((x for x in l if x is not None), None)
  if isinstance(first, ops.Tensor):
    return _Message("list containing Tensors")
  return first
262
263
def _NotNone(v):
  """Returns `v`, substituting a printable `_Message("None")` for None."""
  # The filters use None to mean "no mismatch". A genuine None value must
  # therefore be wrapped so it is still reported as a mismatch.
  return _Message("None") if v is None else v
269
270
def _FilterTuple(v):
  """Checks that `v` uses the tuple form expected for quantized values.

  Used together with `_FilterInt` for quantized dtypes. Like the other
  filters in this module, it returns None when `v` is acceptable and an
  offending value otherwise.

  Returns:
    None if `v` is accepted; otherwise a non-None value describing the first
    mismatch found.
  """
  # A non-sequence at this level is itself the mismatch (a bare scalar where a
  # tuple was expected). Note that it is returned as-is.
  if not isinstance(v, (list, tuple)):
    return v
  # A flat tuple of non-sequences is the accepted quantized representation.
  if isinstance(v, tuple):
    if not any(isinstance(x, (list, tuple)) for x in v):
      return None
  # A flat list of non-sequences is rejected: report its first non-None
  # element.
  if isinstance(v, list):
    if not any(isinstance(x, (list, tuple)) for x in v):
      return _FirstNotNone(
          [None if isinstance(x, (list, tuple)) else x for x in v])
  # Nested containers: check each element and report the first mismatch.
  return _FirstNotNone([_FilterTuple(x) for x in v])
282
283
def _FilterInt(v):
  """Returns None if `v` (possibly nested) holds only integer values.

  `compat.integral_types` and `tensor_shape.Dimension` values are accepted.
  Otherwise the first offending element is returned; a None element is
  reported as `_Message("None")`.
  """
  if not isinstance(v, (list, tuple)):
    if isinstance(v, (compat.integral_types, tensor_shape.Dimension)):
      return None
    return _NotNone(v)
  return _FirstNotNone([_FilterInt(item) for item in v])
289
290
def _FilterFloat(v):
  """Returns None if `v` (possibly nested) holds only real numbers.

  Otherwise the first element not in `compat.real_types` is returned; a None
  element is reported as `_Message("None")`.
  """
  if isinstance(v, (list, tuple)):
    nested = [_FilterFloat(item) for item in v]
    return _FirstNotNone(nested)
  if isinstance(v, compat.real_types):
    return None
  return _NotNone(v)
295
296
def _FilterComplex(v):
  """Returns None if `v` (possibly nested) holds only complex-compatible values.

  Otherwise the first element not in `compat.complex_types` is returned; a
  None element is reported as `_Message("None")`.
  """
  if not isinstance(v, (list, tuple)):
    return None if isinstance(v, compat.complex_types) else _NotNone(v)
  return _FirstNotNone(list(map(_FilterComplex, v)))
301
302
def _FilterStr(v):
  """Returns None if `v` (possibly nested) holds only bytes or text strings.

  Otherwise the first offending element is returned; a None element is
  reported as `_Message("None")`.
  """
  if not isinstance(v, (list, tuple)):
    return None if isinstance(v, compat.bytes_or_text_types) else _NotNone(v)
  return _FirstNotNone([_FilterStr(item) for item in v])
310
311
def _FilterBool(v):
  """Returns None if `v` (possibly nested) holds only Python bools.

  Otherwise the first non-bool element is returned; a None element is
  reported as `_Message("None")`.
  """
  if not isinstance(v, (list, tuple)):
    if isinstance(v, bool):
      return None
    return _NotNone(v)
  return _FirstNotNone([_FilterBool(item) for item in v])
316
317
def _FilterNotTensor(v):
  """Returns a description of the first Tensor nested in `v`, or None.

  A bare Tensor is described by `str(v)`. For lists/tuples the first Tensor
  found is reported through `_FirstNotNone`.
  """
  if not isinstance(v, (list, tuple)):
    return str(v) if isinstance(v, ops.Tensor) else None
  return _FirstNotNone([_FilterNotTensor(item) for item in v])
322
323
# Filters that _AssertCompatible runs over Python values destined for each
# DType. Each filter returns None when the values are acceptable, or the
# first offending value otherwise.
_TF_TO_IS_OK = {
    # Boolean and string types.
    dtypes.bool: [_FilterBool],
    dtypes.string: [_FilterStr],
    # Floating point and complex types.
    dtypes.float16: [_FilterFloat],
    dtypes.float32: [_FilterFloat],
    dtypes.float64: [_FilterFloat],
    dtypes.complex64: [_FilterComplex],
    dtypes.complex128: [_FilterComplex],
    # Plain integer types.
    dtypes.int8: [_FilterInt],
    dtypes.int16: [_FilterInt],
    dtypes.int32: [_FilterInt],
    dtypes.int64: [_FilterInt],
    dtypes.uint8: [_FilterInt],
    dtypes.uint16: [_FilterInt],
    dtypes.uint32: [_FilterInt],
    dtypes.uint64: [_FilterInt],
    # Quantized types: integers, additionally passed in tuple form.
    dtypes.qint8: [_FilterInt, _FilterTuple],
    dtypes.quint8: [_FilterInt, _FilterTuple],
    dtypes.qint16: [_FilterInt, _FilterTuple],
    dtypes.quint16: [_FilterInt, _FilterTuple],
    dtypes.qint32: [_FilterInt, _FilterTuple],
}
346
347
def _AssertCompatible(values, dtype):
  """Raises TypeError if `values` is not compatible with `dtype`.

  Args:
    values: A Python scalar or a (nested) list/tuple of values.
    dtype: A `DType`, or None to only check that `values` holds no Tensors.

  Raises:
    TypeError: If a filter chosen for `dtype` (from `_TF_TO_IS_OK`, or a
      fallback based on the dtype's category) reports a mismatch.
  """
  if dtype is None:
    filters = [_FilterNotTensor]
  elif dtype in _TF_TO_IS_OK:
    filters = _TF_TO_IS_OK[dtype]
  # No dedicated entry: choose filters from the dtype's broad category.
  elif dtype.is_integer:
    filters = [_FilterInt]
  elif dtype.is_floating:
    filters = [_FilterFloat]
  elif dtype.is_complex:
    filters = [_FilterComplex]
  elif dtype.is_quantized:
    filters = [_FilterInt, _FilterTuple]
  else:
    filters = [_FilterNotTensor]

  mismatch = _FirstNotNone([check(values) for check in filters])
  if mismatch is None:
    return
  if dtype is None:
    raise TypeError("List of Tensors when single Tensor expected")
  raise TypeError("Expected %s, got %s of type '%s' instead." %
                  (dtype.name, repr(mismatch), type(mismatch).__name__))
373
374
375# pylint: disable=invalid-name
@tf_export(v1=["make_tensor_proto"])
def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False,
                      allow_broadcast=False):
  """Create a TensorProto.

  Args:
    values:         Values to put in the TensorProto.
    dtype:          Optional tensor_pb2 DataType value.
    shape:          List of integers representing the dimensions of tensor.
    verify_shape:   Boolean that enables verification of a shape of values.
    allow_broadcast:Boolean that enables allowing scalars and 1 length vector
        broadcasting. Cannot be true when verify_shape is true.

  Returns:
    A `TensorProto`. Depending on the type, it may contain data in the
    "tensor_content" attribute, which is not directly useful to Python programs.
    To access the values you should convert the proto back to a numpy ndarray
    with `tf.make_ndarray(proto)`.

    If `values` is a `TensorProto`, it is immediately returned; `dtype` and
    `shape` are ignored.

  Raises:
    TypeError:  if unsupported or incompatible types are provided; if
      `verify_shape` is True and the shape of `values` is not equal to
      `shape`; or if `allow_broadcast` is True and `values` is neither a
      scalar nor a length-1 vector and has a different number of elements
      than `shape`.
    ValueError: if arguments have inappropriate values. Examples: both
      `verify_shape` and `allow_broadcast` are True; `values` is None or not
      dense; `values` has more elements than `shape` (without
      `allow_broadcast`); or the serialized content would exceed 2GB.

  make_tensor_proto accepts "values" of a python scalar, a python list, a
  numpy ndarray, or a numpy scalar.

  If "values" is a python scalar or a python list, make_tensor_proto
  first converts it to a numpy ndarray. If dtype is None, the
  conversion tries its best to infer the right numpy data
  type. Otherwise, the resulting numpy array has a compatible data
  type with the given dtype.

  In either case above, the numpy ndarray (either the caller provided
  or the auto converted) must have a type compatible with dtype.

  make_tensor_proto then converts the numpy array to a tensor proto.

  If "shape" is None, the resulting tensor proto represents the numpy
  array precisely.

  Otherwise, "shape" specifies the tensor's shape and the numpy array
  can not have more elements than what "shape" specifies.
  """
  if allow_broadcast and verify_shape:
    raise ValueError("allow_broadcast and verify_shape are not both allowed.")
  if isinstance(values, tensor_pb2.TensorProto):
    return values

  if dtype:
    dtype = dtypes.as_dtype(dtype)

  is_quantized = (
      dtype in [
          dtypes.qint8, dtypes.quint8, dtypes.qint16, dtypes.quint16,
          dtypes.qint32
      ])

  # We first convert value to a numpy array or scalar.
  if isinstance(values, (np.ndarray, np.generic)):
    if dtype:
      nparray = values.astype(dtype.as_numpy_dtype)
    else:
      nparray = values
  elif callable(getattr(values, "__array__", None)) or isinstance(
      getattr(values, "__array_interface__", None), dict):
    # If a class has the __array__ method, or __array_interface__ dict, then it
    # is possible to convert to numpy array.
    nparray = np.asarray(values, dtype=dtype)

    # This is the preferred way to create an array from the object, so replace
    # the `values` with the array so that _FlattenToStrings is not run.
    values = nparray
  else:
    if values is None:
      raise ValueError("None values not supported.")
    # if dtype is provided, forces numpy array to be the type
    # provided if possible.
    if dtype and dtype.is_numpy_compatible:
      np_dt = dtype.as_numpy_dtype
    else:
      np_dt = None
    # Test `shape is not None` first: np.prod(None) returns None, but
    # np.prod(None, dtype=np.int64) raises.
    if shape is not None and np.prod(shape, dtype=np.int64) == 0:
      nparray = np.empty(shape, dtype=np_dt)
    else:
      _AssertCompatible(values, dtype)
      nparray = np.array(values, dtype=np_dt)
      # Check that the values form a dense tensor. Quantized values are passed
      # in as tuples, so their numpy shape need not match
      # _GetDenseDimensions(values); skip the check for them.
      if (list(nparray.shape) != _GetDenseDimensions(values) and
          not is_quantized):
        raise ValueError("""Argument must be a dense tensor: %s"""
                         """ - got shape %s, but wanted %s.""" %
                         (values, list(nparray.shape),
                          _GetDenseDimensions(values)))

    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
      nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
      downcasted_array = nparray.astype(np.int32)
      # Do not down cast if it leads to precision loss.
      if np.array_equal(downcasted_array, nparray):
        nparray = downcasted_array

  # if dtype is provided, it must be compatible with what numpy
  # conversion says.
  numpy_dtype = dtypes.as_dtype(nparray.dtype)
  if numpy_dtype is None:
    raise TypeError("Unrecognized data type: %s" % nparray.dtype)

  # If dtype was specified and is a quantized type, we convert
  # numpy_dtype back into the quantized version.
  if is_quantized:
    numpy_dtype = dtype

  if dtype is not None and (not hasattr(dtype, "base_dtype") or
                            dtype.base_dtype != numpy_dtype.base_dtype):
    raise TypeError("Incompatible types: %s vs. %s. Value is %s" %
                    (dtype, nparray.dtype, values))

  # If shape is not given, get the shape from the numpy array.
  if shape is None:
    shape = nparray.shape
    is_same_size = True
    shape_size = nparray.size
  else:
    shape = [int(dim) for dim in shape]
    shape_size = np.prod(shape, dtype=np.int64)
    is_same_size = shape_size == nparray.size

    if allow_broadcast:
      # Scalars and length-1 vectors may be broadcast to any shape.
      if nparray.shape == (1,) or nparray.shape == tuple():
        pass
      elif nparray.size != shape_size:
        raise TypeError("Expected Tensor's shape: %s, got %s." %
                        (tuple(shape), nparray.shape))

    else:
      if verify_shape and nparray.shape != tuple(shape):
        raise TypeError("Expected Tensor's shape: %s, got %s." %
                        (tuple(shape), nparray.shape))

      if nparray.size > shape_size:
        raise ValueError(
            "Too many elements provided. Needed at most %d, but received %d" %
            (shape_size, nparray.size))

  tensor_proto = tensor_pb2.TensorProto(
      dtype=numpy_dtype.as_datatype_enum,
      tensor_shape=tensor_shape.as_shape(shape).as_proto())

  if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
    if nparray.size * nparray.itemsize >= (1 << 31):
      raise ValueError(
          "Cannot create a tensor proto whose content is larger than 2GB.")
    # `tobytes()` returns the same C-order buffer as the deprecated
    # `tostring()`, which newer NumPy releases no longer provide.
    tensor_proto.tensor_content = nparray.tobytes()
    return tensor_proto

  # If we were not given values as a numpy array, compute the proto_values
  # from the given values directly, to avoid numpy trimming nulls from the
  # strings. Since values could be a list of strings, or a multi-dimensional
  # list of lists that might or might not correspond to the given shape,
  # we flatten it conservatively.
  if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray):
    proto_values = _FlattenToStrings(values)

    # At this point, values may be a list of objects that we could not
    # identify a common type for (hence it was inferred as
    # np.object/dtypes.string).  If we are unable to convert it to a
    # string, we raise a more helpful error message.
    #
    # Ideally, we'd be able to convert the elements of the list to a
    # common type, but this type inference requires some thinking and
    # so we defer it for now.
    try:
      str_values = [compat.as_bytes(x) for x in proto_values]
    except TypeError:
      raise TypeError("Failed to convert object of type %s to Tensor. "
                      "Contents: %s. Consider casting elements to a "
                      "supported type." % (type(values), values))
    tensor_proto.string_val.extend(str_values)
    return tensor_proto

  # TensorFlow expects C order (a.k.a., eigen row major).
  proto_values = nparray.ravel()

  append_fn = GetNumpyAppendFn(proto_values.dtype)
  if append_fn is None:
    raise TypeError(
        "Element type not supported in TensorProto: %s" % numpy_dtype.name)
  append_fn(tensor_proto, proto_values)

  return tensor_proto
577# pylint: enable=invalid-name
578
579
@tf_export("make_ndarray")
def MakeNdarray(tensor):
  """Create a numpy ndarray from a tensor.

  Create a numpy ndarray with the same shape and data as the tensor.

  If the proto stores fewer values than the shape requires, the last stored
  value is repeated to fill the array. If it stores no values at all, the
  result is all zeros (all empty strings for string tensors).

  Args:
    tensor: A TensorProto.

  Returns:
    A numpy array with the tensor contents.

  Raises:
    TypeError: if tensor has unsupported type.

  """
  shape = [d.size for d in tensor.tensor_shape.dim]
  num_elements = np.prod(shape, dtype=np.int64)
  tensor_dtype = dtypes.as_dtype(tensor.dtype)
  dtype = tensor_dtype.as_numpy_dtype

  # Dense binary encoding: the buffer already holds every element.
  if tensor.tensor_content:
    return (np.frombuffer(tensor.tensor_content,
                          dtype=dtype).copy().reshape(shape))

  if tensor_dtype == dtypes.string:
    # np.pad throws on object arrays, so pad the Python list by hand.
    values = list(tensor.string_val)
    padding = num_elements - len(values)
    if padding > 0:
      last = values[-1] if values else ""
      values.extend([last] * padding)
    return np.array(values, dtype=dtype).reshape(shape)

  if tensor_dtype == dtypes.float16 or tensor_dtype == dtypes.bfloat16:
    # the half_val field of the TensorProto stores the binary representation
    # of the fp16: we need to reinterpret this as a proper float16
    values = np.fromiter(tensor.half_val, dtype=np.uint16)
    values.dtype = tensor_dtype.as_numpy_dtype
  elif tensor_dtype == dtypes.float32:
    values = np.fromiter(tensor.float_val, dtype=dtype)
  elif tensor_dtype == dtypes.float64:
    values = np.fromiter(tensor.double_val, dtype=dtype)
  elif tensor_dtype in [
      dtypes.int32, dtypes.uint8, dtypes.uint16, dtypes.int16, dtypes.int8,
      dtypes.qint32, dtypes.quint8, dtypes.qint8, dtypes.qint16, dtypes.quint16
  ]:
    values = np.fromiter(tensor.int_val, dtype=dtype)
  elif tensor_dtype == dtypes.int64:
    values = np.fromiter(tensor.int64_val, dtype=dtype)
  elif tensor_dtype == dtypes.uint32:
    # Written by the uint32 appenders into `uint32_val`.
    values = np.fromiter(tensor.uint32_val, dtype=dtype)
  elif tensor_dtype == dtypes.uint64:
    # Written by the uint64 appenders into `uint64_val`.
    values = np.fromiter(tensor.uint64_val, dtype=dtype)
  elif tensor_dtype == dtypes.complex64:
    # Values are stored as interleaved (real, imag) pairs.
    it = iter(tensor.scomplex_val)
    values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
  elif tensor_dtype == dtypes.complex128:
    # Values are stored as interleaved (real, imag) pairs.
    it = iter(tensor.dcomplex_val)
    values = np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype)
  elif tensor_dtype == dtypes.bool:
    values = np.fromiter(tensor.bool_val, dtype=dtype)
  else:
    raise TypeError("Unsupported tensor type: %s" % tensor.dtype)

  if values.size == 0:
    return np.zeros(shape, dtype)

  # Repeat the last stored value to fill the remaining elements.
  if values.size != num_elements:
    values = np.pad(values, (0, num_elements - values.size), "edge")

  return values.reshape(shape)
648
649
def ShapeEquals(tensor_proto, shape):
  """Returns True if "tensor_proto" has the given "shape".

  Args:
    tensor_proto: A TensorProto.
    shape: A tensor shape, expressed as a TensorShapeProto, list, or tuple.

  Returns:
    True if "tensor_proto" has the given "shape" (same rank and the same size
    in every dimension), otherwise False.

  Raises:
    TypeError: If "tensor_proto" is not a TensorProto, or shape is not a
      TensorShapeProto, list, or tuple.
  """
  if not isinstance(tensor_proto, tensor_pb2.TensorProto):
    raise TypeError("tensor_proto is not a tensor_pb2.TensorProto object")
  if isinstance(shape, tensor_shape_pb2.TensorShapeProto):
    shape = [d.size for d in shape.dim]
  elif not isinstance(shape, (list, tuple)):
    raise TypeError("shape is not a list or tuple")
  tensor_shape_list = [d.size for d in tensor_proto.tensor_shape.dim]
  # zip() stops at the shorter sequence, so compare ranks explicitly;
  # otherwise a [2, 3] tensor would be reported as having shape [2].
  if len(tensor_shape_list) != len(shape):
    return False
  return all(x == y for x, y in zip(tensor_shape_list, shape))
672
673
def _ConstantValue(tensor, partial):
  """Statically evaluates `tensor` for a fixed set of op types.

  Args:
    tensor: An `ops.Tensor`.
    partial: Forwarded to the inputs of a "Pack" op. If True, Pack inputs that
      cannot be evaluated are kept as None entries instead of making the
      whole result None.

  Returns:
    A numpy value for `tensor`, or None if its op type is not handled or an
    input it depends on is not statically known.

  Raises:
    TypeError: If `tensor` is not an `ops.Tensor`.
  """
  # TODO(touts): Support Variables?
  if not isinstance(tensor, ops.Tensor):
    raise TypeError("%r is not a Tensor, has type %s" % (tensor, type(tensor)))
  if tensor.op.type == "Const":
    return MakeNdarray(tensor.op.get_attr("value"))
  elif tensor.op.type == "Shape":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      return np.array(
          [dim.value for dim in input_shape.dims],
          dtype=tensor.dtype.as_numpy_dtype)
    else:
      return None
  elif tensor.op.type == "Size":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.is_fully_defined():
      # Use the tensor's own dtype, as the "Shape" branch does, rather than
      # always producing int32 (which could also overflow for large shapes).
      return np.prod([dim.value for dim in input_shape.dims],
                     dtype=tensor.dtype.as_numpy_dtype)
    else:
      return None
  elif tensor.op.type == "Rank":
    input_shape = tensor.op.inputs[0].get_shape()
    if input_shape.ndims is not None:
      # A 0-d int32 array holding the rank.
      return np.ndarray(
          shape=(),
          buffer=np.array([input_shape.ndims], dtype=np.int32),
          dtype=np.int32)
    else:
      return None
  elif tensor.op.type == "Range":
    start = constant_value(tensor.op.inputs[0])
    if start is None:
      return None
    limit = constant_value(tensor.op.inputs[1])
    if limit is None:
      return None
    delta = constant_value(tensor.op.inputs[2])
    if delta is None:
      return None
    return np.arange(start, limit, delta, dtype=tensor.dtype.as_numpy_dtype)
  elif tensor.op.type == "Cast":
    pre_cast = constant_value(tensor.op.inputs[0])
    if pre_cast is None:
      return None
    cast_dtype = dtypes.as_dtype(tensor.op.get_attr("DstT"))
    return pre_cast.astype(cast_dtype.as_numpy_dtype)
  elif tensor.op.type == "Concat":
    # "Concat" takes the axis as its first input.
    dim = constant_value(tensor.op.inputs[0])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[1:]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "ConcatV2":
    # "ConcatV2" takes the axis as its last input.
    dim = constant_value(tensor.op.inputs[-1])
    if dim is None:
      return None
    values = []
    for x in tensor.op.inputs[:-1]:
      value = constant_value(x)
      if value is None:
        return None
      values.append(value)
    return np.concatenate(values, axis=dim)
  elif tensor.op.type == "Pack":
    values = []
    # Some imported GraphDefs have Pack ops with zero inputs. Those are invalid
    # and shouldn't be produced, but to deal sensibly with them here we check
    # and return None.
    if not tensor.op.inputs:
      return None
    # We can't handle axis != 0 Packs at the moment.
    if tensor.op.get_attr("axis") != 0:
      return None
    for x in tensor.op.inputs:
      value = constant_value(x, partial)
      # With `partial`, an unknown input stays in the result as None.
      if value is None and not partial:
        return None
      values.append(value)
    return np.array(values)
  elif tensor.op.type == "Fill":
    # The output shape comes from the result's static shape; the fill value
    # is input 1.
    fill_shape = tensor.shape
    fill_value = constant_value(tensor.op.inputs[1])
    if fill_shape.is_fully_defined() and fill_value is not None:
      return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype)
    else:
      return None
  elif tensor.op.type == "Equal":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.equal(value1, value2)
  elif tensor.op.type == "NotEqual":
    value1 = constant_value(tensor.op.inputs[0])
    if value1 is None:
      return None
    value2 = constant_value(tensor.op.inputs[1])
    if value2 is None:
      return None
    return np.not_equal(value1, value2)
  else:
    return None
783
784
@tf_export('get_static_value')
def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
  """Returns the constant value of the given tensor, if efficiently calculable.

  This function attempts to partially evaluate the given tensor, and
  returns its value as a numpy ndarray if this succeeds.

  Compatibility(V1): If `constant_value(tensor)` returns a non-`None` result, it
  will no longer be possible to feed a different value for `tensor`. This allows
  the result of this function to influence the graph that is constructed, and
  permits static shape optimizations.

  Args:
    tensor: The Tensor to be evaluated.
    partial: If True, the returned numpy array is allowed to have partially
      evaluated values. Values that can't be evaluated will be None.

  Returns:
    A numpy ndarray containing the constant value of the given `tensor`,
    or None if it cannot be calculated. An eager tensor yields
    `tensor.numpy()`. An argument for which `is_tensor` is False is returned
    unchanged. A tensor-like object that is not an `ops.Tensor` yields None.
  """
  if isinstance(tensor, ops.EagerTensor):
    return tensor.numpy()
  # Non-tensor inputs (e.g. Python or numpy values) are already "static".
  if not is_tensor(tensor):
    return tensor
  if not isinstance(tensor, ops.Tensor):
    return None
  ret = _ConstantValue(tensor, partial)
  if ret is not None:
    # The caller may now depend on the constant value of `tensor`, so we
    # conservatively prevent it from being fed.
    tensor.graph.prevent_feeding(tensor)
  return ret
821
822
def _shape_from_scalar_value(value):
  """Interprets the evaluated value of a rank-0 shape tensor as a `TensorShape`.

  A scalar is only accepted as a shape description when it is the sentinel
  value `-1`, which describes an unknown shape.

  Args:
    value: The statically evaluated scalar, or None if it could not be
      evaluated.

  Returns:
    `tensor_shape.unknown_shape()` when `value` is -1.

  Raises:
    ValueError: If `value` is None, or is any scalar other than -1.
  """
  if value is None:
    raise ValueError(
        "Received a scalar with unknown value as shape; require a statically "
        "known scalar with value '-1' to describe an unknown shape.")
  if value != -1:
    raise ValueError(
        "Received a scalar value '%s' as shape; require a statically known "
        "scalar with value '-1' to describe an unknown shape." % value)
  return tensor_shape.unknown_shape()


def constant_value_as_shape(tensor):  # pylint: disable=invalid-name
  """A version of `constant_value()` that returns a `TensorShape`.

  This version should be used when a constant tensor value is
  interpreted as a (possibly partial) shape, e.g. in the shape
  function for `tf.reshape()`. By explicitly requesting a
  `TensorShape` as the return value, it is possible to represent
  unknown dimensions; by contrast, `constant_value()` is
  all-or-nothing.

  Args:
    tensor: The rank-0 or rank-1 Tensor to be evaluated.

  Returns:
    A `TensorShape` based on the constant value of the given `tensor`.
    Dimensions whose value is -1 or cannot be determined statically are
    unknown. A scalar `-1` describes a fully unknown shape.

  Raises:
    ValueError: If the tensor is rank-0 and is not statically known to be -1
      (in both eager and graph mode).
  """
  if isinstance(tensor, ops.EagerTensor):
    value = tensor.numpy()
    if np.ndim(value) == 0:
      # A 0-d value cannot be iterated; handle it exactly like the graph-mode
      # rank-0 branch below so that eager and graph execution agree.
      return _shape_from_scalar_value(value)
    return tensor_shape.as_shape(
        [dim if dim != -1 else None for dim in value])

  if tensor.get_shape().ndims == 0:
    return _shape_from_scalar_value(constant_value(tensor))

  shape = tensor.get_shape().with_rank(1)
  if shape == [0]:
    return tensor_shape.scalar()

  op_type = tensor.op.type
  if op_type == "Shape":
    # The value of `tf.shape(x)` is the static shape of `x` itself.
    return tensor.op.inputs[0].get_shape()
  elif op_type == "Pack":
    ret = tensor_shape.scalar()  # Empty list.
    # Since we expect rank 1 inputs, Pack's axis must be zero, otherwise it
    # would not be rank 1.
    assert tensor.op.get_attr("axis") == 0
    for pack_input in tensor.op.inputs:
      # `pack_input` must be a scalar. Attempt to evaluate it, and append it
      # to `ret`; unknown or negative values become unknown dimensions.
      pack_input_val = constant_value(pack_input)
      if pack_input_val is None or pack_input_val < 0:
        new_dim = tensor_shape.Dimension(None)
      else:
        new_dim = tensor_shape.Dimension(pack_input_val)
      ret = ret.concatenate([new_dim])
    return ret
  elif op_type in ("Concat", "ConcatV2"):
    # The concatenation axis is the first input of "Concat" and the last input
    # of "ConcatV2". We assume it evaluates to 0, as this is the only legal
    # value when concatenating vectors, and it will have been checked by a
    # previous shape function.
    if op_type == "Concat":
      concat_inputs = tensor.op.inputs[1:]
    else:
      concat_inputs = tensor.op.inputs[:-1]
    ret = tensor_shape.scalar()  # Empty list.
    for concat_input in concat_inputs:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif op_type == "StridedSlice":
    try:
      begin = constant_value(tensor.op.inputs[1])
      end = constant_value(tensor.op.inputs[2])
      strides = constant_value(tensor.op.inputs[3])
      if begin is not None and end is not None and strides is not None:
        begin = begin[0]
        end = end[0]
        strides = strides[0]
        begin_mask = tensor.op.get_attr("begin_mask")
        if begin_mask == 1:
          begin = None
        end_mask = tensor.op.get_attr("end_mask")
        if end_mask == 1:
          end = None

        ellipsis_mask = tensor.op.get_attr("ellipsis_mask")
        new_axis_mask = tensor.op.get_attr("new_axis_mask")
        shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask")
        # Only a plain `[begin:end:strides]` slice, optionally with an open
        # begin and/or end, can be mirrored on the static shape.
        valid_attributes = (not ellipsis_mask and not new_axis_mask and
                            not shrink_axis_mask and (not begin_mask or
                                                      (begin_mask == 1)) and
                            (not end_mask or (end_mask == 1)))
        if valid_attributes:  # additional inputs not supported
          prev = constant_value_as_shape(tensor.op.inputs[0])
          prev = prev[begin:end:strides]
          ret = tensor_shape.TensorShape(prev)
          return ret
    # ValueError can come from get_attr or slicing `prev`, TypeError from
    # slicing `prev`, and IndexError from zero-length begin/end/strides
    # values. In all cases fall back to the generic evaluation below.
    except (ValueError, TypeError, IndexError):
      pass

  # Generic fallback: the number of dimensions is the vector's static length;
  # merge in any statically known entries (negative entries mean unknown).
  ret = tensor_shape.unknown_shape(shape.dims[0].value)
  value = constant_value(tensor)
  if value is not None:
    ret = ret.merge_with(
        tensor_shape.TensorShape([d if d >= 0 else None for d in value]))
  return ret
938
939
@tf_export("is_tensor")
def is_tensor(x):  # pylint: disable=invalid-name
  """Returns whether `x` is a tensor or a tensor-like object.

  This is meant to recognize `tf.Tensor`, `tf.SparseTensor`,
  `tf.RaggedTensor` and `tf.Variable` objects. Concretely, `x` is accepted
  when any of the following checks succeeds, evaluated in this order:

  1. `x` is an instance of the internal tensor-like base class;
  2. `x` is dense-tensor-like according to `ops.is_dense_tensor_like`;
  3. `x` is a `CompositeTensor`;
  4. `x` exposes a truthy `is_tensor_like` attribute, which lets other
     objects opt in.

  Args:
    x: A python object to check.

  Returns:
    `True` if `x` is a tensor, `False` if not.
  """
  if isinstance(x, ops._TensorLike):  # pylint: disable=protected-access
    return True

  dense_like = ops.is_dense_tensor_like(x)
  if dense_like:
    return dense_like

  if isinstance(x, composite_tensor.CompositeTensor):
    return True

  # Opt-in hook for objects that are not tensors by type.
  return hasattr(x, "is_tensor_like") and x.is_tensor_like
959