1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for describing the structure of a `tf.data` type."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import functools
22
23import six
24import wrapt
25
26from tensorflow.python.data.util import nest
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import sparse_tensor
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.framework import type_spec
33from tensorflow.python.ops import tensor_array_ops
34from tensorflow.python.ops.ragged import ragged_tensor
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.util import deprecation
37from tensorflow.python.util.compat import collections_abc
38from tensorflow.python.util.tf_export import tf_export
39
40
41# pylint: disable=invalid-name
@tf_export(v1=["data.experimental.TensorStructure"])
@deprecation.deprecated(None, "Use `tf.TensorSpec` instead.")
def _TensorStructure(dtype, shape):
  """Builds a `TensorSpec` from legacy `(dtype, shape)` ordered arguments.

  Args:
    dtype: The element type forwarded to `TensorSpec`.
    shape: The shape forwarded to `TensorSpec`.

  Returns:
    The `TensorSpec` constructed from `shape` and `dtype`.
  """
  return tensor_spec.TensorSpec(shape=shape, dtype=dtype)
46
47
@tf_export(v1=["data.experimental.SparseTensorStructure"])
@deprecation.deprecated(None, "Use `tf.SparseTensorSpec` instead.")
def _SparseTensorStructure(dtype, shape):
  """Builds a `SparseTensorSpec` from legacy `(dtype, shape)` arguments.

  Args:
    dtype: The element type forwarded to `SparseTensorSpec`.
    shape: The dense shape forwarded to `SparseTensorSpec`.

  Returns:
    The `SparseTensorSpec` constructed from `shape` and `dtype`.
  """
  return sparse_tensor.SparseTensorSpec(shape=shape, dtype=dtype)
52
53
@tf_export(v1=["data.experimental.TensorArrayStructure"])
@deprecation.deprecated(None, "Use `tf.TensorArraySpec` instead.")
def _TensorArrayStructure(dtype, element_shape, dynamic_size, infer_shape):
  """Builds a `TensorArraySpec` from legacy-ordered arguments.

  Args:
    dtype: The element type forwarded to `TensorArraySpec`.
    element_shape: The per-element shape forwarded to `TensorArraySpec`.
    dynamic_size: Forwarded unchanged as `dynamic_size`.
    infer_shape: Forwarded unchanged as `infer_shape`.

  Returns:
    The `TensorArraySpec` constructed from the given arguments.
  """
  return tensor_array_ops.TensorArraySpec(
      element_shape=element_shape,
      dtype=dtype,
      dynamic_size=dynamic_size,
      infer_shape=infer_shape)
59
60
@tf_export(v1=["data.experimental.RaggedTensorStructure"])
@deprecation.deprecated(None, "Use `tf.RaggedTensorSpec` instead.")
def _RaggedTensorStructure(dtype, shape, ragged_rank):
  """Builds a `RaggedTensorSpec` from legacy-ordered arguments.

  Args:
    dtype: The element type forwarded to `RaggedTensorSpec`.
    shape: The shape forwarded to `RaggedTensorSpec`.
    ragged_rank: Forwarded unchanged as `ragged_rank`.

  Returns:
    The `RaggedTensorSpec` constructed from the given arguments.
  """
  return ragged_tensor.RaggedTensorSpec(
      shape=shape, dtype=dtype, ragged_rank=ragged_rank)
65# pylint: enable=invalid-name
66
67
68# TODO(jsimsa): Remove the special-case for `TensorArray` pass-through once
69# it is a subclass of `CompositeTensor`.
def normalize_element(element, element_signature=None):
  """Converts every component of a nested element to its canonical form.

  Each flattened component is handled according to its type specification
  (taken from `element_signature` when given, otherwise computed from the
  component itself):

  * `SparseTensorSpec` components are converted to `SparseTensor`.
  * `RaggedTensorSpec` components are converted to `Tensor` or `RaggedTensor`.
  * `TensorArraySpec` and `DatasetSpec` components are passed through.
  * `NoneTensorSpec` components are replaced by a fresh `NoneTensor`.
  * Remaining `CompositeTensor` components are passed through.
  * Everything else is converted to `Tensor`, using the spec's `dtype`
    attribute when it has one.

  A component for which no type specification can be computed is converted
  to a `Tensor` directly.

  Args:
    element: A nested structure of individual components.
    element_signature: (Optional.) A nested structure of type specifications
      corresponding to the components of `element`. When given, `element` is
      flattened up to this structure and the result is packed like it; the
      specs also determine how each component is converted (for example the
      `dtype` used for values that are not tensors themselves, such as numpy
      arrays or native python types).

  Returns:
    A nested structure holding the normalized components, packed like
    `element_signature` if it was given and like `element` otherwise.
  """
  if element_signature is not None:
    flat_specs = nest.flatten(element_signature)
    flat_components = nest.flatten_up_to(element_signature, element)
    structure = element_signature
  else:
    flat_components = nest.flatten(element)
    flat_specs = [None] * len(flat_components)
    structure = element

  normalized = []
  with ops.name_scope("normalize_element"):
    # Imported here to avoid circular dependency.
    from tensorflow.python.data.ops import dataset_ops  # pylint: disable=g-import-not-at-top
    passthrough_specs = (tensor_array_ops.TensorArraySpec,
                         dataset_ops.DatasetSpec)
    for index, (component, spec) in enumerate(
        zip(flat_components, flat_specs)):
      component_name = "component_%d" % index
      if spec is None:
        try:
          spec = type_spec_from_value(component, use_fallback=False)
        except TypeError:
          # No `TypeSpec` could be computed for the value, so fall back to a
          # plain tensor conversion and move on to the next component.
          normalized.append(
              ops.convert_to_tensor(component, name=component_name))
          continue

      if isinstance(spec, sparse_tensor.SparseTensorSpec):
        result = sparse_tensor.SparseTensor.from_value(component)
      elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
        result = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            component, name=component_name)
      elif isinstance(spec, passthrough_specs):
        result = component
      elif isinstance(spec, NoneTensorSpec):
        result = NoneTensor()
      elif isinstance(component, composite_tensor.CompositeTensor):
        result = component
      else:
        result = ops.convert_to_tensor(
            component, name=component_name, dtype=getattr(spec, "dtype", None))
      normalized.append(result)
  return nest.pack_sequence_as(structure, normalized)
131
132
def convert_legacy_structure(output_types, output_shapes, output_classes):
  """Builds a nested structure of type specs from legacy properties.

  A "legacy" structure is the triple exposed by the
  `tf.data.Dataset.output_types`, `tf.data.Dataset.output_shapes`, and
  `tf.data.Dataset.output_classes` properties of `Dataset` and `Iterator`
  objects. This function combines the three, component by component, into a
  single structure of type specifications.

  TODO(b/110122868): Remove this function once `Structure` is used throughout
  `tf.data`.

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of a structured value.
    output_shapes: A nested structure of `tf.TensorShape` objects
      corresponding to each component of a structured value.
    output_classes: A nested structure of Python `type` objects (or `TypeSpec`
      instances, which are used as-is) corresponding to each component of a
      structured value.

  Returns:
    A structure nested like `output_classes` whose leaves are type
    specifications.

  Raises:
    TypeError: If a structure cannot be built from the arguments, because one of
      the component classes in `output_classes` is not supported.
  """

  def component_spec(dtype, shape, cls):
    """Returns the type spec for one flattened legacy component."""
    if isinstance(cls, type_spec.TypeSpec):
      return cls
    if issubclass(cls, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensorSpec(shape, dtype)
    if issubclass(cls, ops.Tensor):
      return tensor_spec.TensorSpec(shape, dtype)
    if issubclass(cls, tensor_array_ops.TensorArray):
      # We sneaked the dynamic_size and infer_shape into the legacy shape.
      dynamic_size = tensor_shape.dimension_value(shape[0])
      infer_shape = tensor_shape.dimension_value(shape[1])
      return tensor_array_ops.TensorArraySpec(
          shape[2:], dtype, dynamic_size=dynamic_size, infer_shape=infer_shape)
    # NOTE(mrry): Since legacy structures produced by iterators only
    # comprise Tensors, SparseTensors, and nests, we do not need to
    # support all structure types here.
    raise TypeError(
        "Could not build a structure for output class %r" % (cls,))

  legacy_components = zip(
      nest.flatten(output_types),
      nest.flatten(output_shapes),
      nest.flatten(output_classes))
  flat_specs = [
      component_spec(dtype, shape, cls)
      for dtype, shape, cls in legacy_components
  ]
  return nest.pack_sequence_as(output_classes, flat_specs)
187
188
def _from_tensor_list_helper(decode_fn, element_spec, tensor_list):
  """Rebuilds an element from its flat tensor list using `decode_fn`.

  The tensor list is split into consecutive slices, one per flattened
  component of `element_spec`; each slice is as long as that component's
  `_flat_tensor_specs`.

  Args:
    decode_fn: Callable invoked as `decode_fn(component_spec, tensors)` that
      constructs one element component from its spec and its slice of the
      tensor list.
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list, nested like
    `element_spec`.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors.
  """

  # pylint: disable=protected-access

  component_specs = nest.flatten(element_spec)
  tensor_counts = [len(spec._flat_tensor_specs) for spec in component_specs]
  expected_count = sum(tensor_counts)
  if expected_count != len(tensor_list):
    raise ValueError("Expected %d tensors but got %d." %
                     (expected_count, len(tensor_list)))

  decoded = []
  start = 0
  for spec, count in zip(component_specs, tensor_counts):
    stop = start + count
    decoded.append(decode_fn(spec, tensor_list[start:stop]))
    start = stop
  return nest.pack_sequence_as(element_spec, decoded)
222
223
def from_compatible_tensor_list(element_spec, tensor_list):
  """Rebuilds an element from a compatible flat tensor list.

  Each component is decoded with its spec's `_from_compatible_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors.
  """

  def decode_component(spec, tensors):
    return spec._from_compatible_tensor_list(tensors)  # pylint: disable=protected-access

  return _from_tensor_list_helper(decode_component, element_spec, tensor_list)
245
246
def from_tensor_list(element_spec, tensor_list):
  """Rebuilds an element from its flat tensor list.

  Each component is decoded with its spec's `_from_tensor_list`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    tensor_list: A list of tensors to use for constructing the value.

  Returns:
    An element constructed from the given spec and tensor list.

  Raises:
    ValueError: If the number of tensors needed to construct an element for
      the given spec does not match the given number of tensors or the given
      spec is not compatible with the tensor list.
  """

  def decode_component(spec, tensors):
    return spec._from_tensor_list(tensors)  # pylint: disable=protected-access

  return _from_tensor_list_helper(decode_component, element_spec, tensor_list)
269
270
def get_flat_tensor_specs(element_spec):
  """Returns a list of `tf.TypeSpec`s for the element tensor representation.

  The result is the concatenation, in flattening order, of the
  `_flat_tensor_specs` of every component of `element_spec`.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list of `tf.TypeSpec`s for the element tensor representation.
  """

  # pylint: disable=protected-access
  flat_tensor_specs = []
  for component_spec in nest.flatten(element_spec):
    flat_tensor_specs = flat_tensor_specs + component_spec._flat_tensor_specs
  return flat_tensor_specs
285
286
def get_flat_tensor_shapes(element_spec):
  """Returns a list of `tf.TensorShape`s for the element tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list with the `shape` of every flat tensor spec of `element_spec`.
  """
  shapes = []
  for flat_spec in get_flat_tensor_specs(element_spec):
    shapes.append(flat_spec.shape)
  return shapes
298
299
def get_flat_tensor_types(element_spec):
  """Returns a list of `tf.DType`s for the element tensor representation.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.

  Returns:
    A list with the `dtype` of every flat tensor spec of `element_spec`.
  """
  dtypes = []
  for flat_spec in get_flat_tensor_specs(element_spec):
    dtypes.append(flat_spec.dtype)
  return dtypes
311
312
def _to_tensor_list_helper(encode_fn, element_spec, element):
  """Folds `encode_fn` over an element to build its tensor list.

  Args:
    encode_fn: Callable invoked as `encode_fn(state, spec, component)` for each
      flattened `(spec, component)` pair, where `state` is the value returned
      by the previous invocation (an empty list for the first one). Its final
      return value is the result.
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  nest.assert_same_structure(element_spec, element)

  flat_pairs = zip(nest.flatten(element_spec), nest.flatten(element))
  tensor_list = []
  for spec, component in flat_pairs:
    tensor_list = encode_fn(tensor_list, spec, component)
  return tensor_list
341
342
def to_batched_tensor_list(element_spec, element):
  """Returns the batched tensor list representation of the element.

  Each component is encoded with its spec's `_to_batched_tensor_list` and the
  per-component lists are concatenated in flattening order.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way or the
      rank of any of the tensors in the tensor list representation is 0.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  def append_component(tensors, spec, component):
    return tensors + spec._to_batched_tensor_list(component)  # pylint: disable=protected-access

  return _to_tensor_list_helper(append_component, element_spec, element)
367
368
def to_tensor_list(element_spec, element):
  """Returns the tensor list representation of the element.

  Each component is encoded with its spec's `_to_tensor_list` and the
  per-component lists are concatenated in flattening order.

  Args:
    element_spec: A nested structure of `tf.TypeSpec` objects representing the
      element type specification.
    element: The element to convert to tensor list representation.

  Returns:
    A tensor list representation of `element`.

  Raises:
    ValueError: If `element_spec` and `element` do not have the same number of
      elements or if the two structures are not nested in the same way.
    TypeError: If `element_spec` and `element` differ in the type of sequence
      in any of their substructures.
  """

  def append_component(tensors, spec, component):
    return tensors + spec._to_tensor_list(component)  # pylint: disable=protected-access

  return _to_tensor_list_helper(append_component, element_spec, element)
392
393
def are_compatible(spec1, spec2):
  """Indicates whether two type specifications are compatible.

  Two type specifications are compatible if they have the same nested structure
  and their individual components are pair-wise compatible in both
  directions.

  Args:
    spec1: A (possibly nested) `tf.TypeSpec` object to compare.
    spec2: A (possibly nested) `tf.TypeSpec` object to compare.

  Returns:
    `True` if the two type specifications are compatible and `False` otherwise.
  """

  try:
    nest.assert_same_structure(spec1, spec2)
  except (TypeError, ValueError):
    # Differently nested structures can never be compatible.
    return False

  return all(
      left.is_compatible_with(right) and right.is_compatible_with(left)
      for left, right in zip(nest.flatten(spec1), nest.flatten(spec2)))
419
420
def type_spec_from_value(element, use_fallback=True):
  """Computes the (possibly nested) type specification of a value.

  Mappings and tuples (including namedtuples) are processed recursively and
  yield a container of the same kind whose leaves are type specifications.

  Args:
    element: The element to create the type specification for.
    use_fallback: Whether to fall back to converting the element to a tensor
      in order to compute its `TypeSpec`. Values nested inside mappings and
      tuples are always processed with the default setting.

  Returns:
    A nested structure of `TypeSpec`s that represents the type specification
    of `element`.

  Raises:
    TypeError: If a `TypeSpec` cannot be built for `element`, because its type
      is not supported.
  """
  direct_spec = type_spec._type_spec_from_value(element)  # pylint: disable=protected-access
  if direct_spec is not None:
    return direct_spec

  if isinstance(element, collections_abc.Mapping):
    # We rebuild a mapping of the same type in an attempt to preserve the key
    # order.
    #
    # Note that we do not guarantee that the key order is preserved. As a
    # consequence, callers of `type_spec_from_value` should not assume that
    # the key order of a `dict` in the returned nested structure matches the
    # key order of the corresponding `dict` in the input value.
    mapping_type = type(element)
    spec_items = [
        (key, type_spec_from_value(value)) for key, value in element.items()
    ]
    if isinstance(element, collections.defaultdict):
      # A `defaultdict` takes its default factory as first argument.
      return mapping_type(element.default_factory, spec_items)
    return mapping_type(spec_items)

  if isinstance(element, tuple):
    # A tuple whose `_fields` attribute is a sequence of strings is treated as
    # a namedtuple.
    is_namedtuple = (
        hasattr(element, "_fields") and
        isinstance(element._fields, collections_abc.Sequence) and
        all(isinstance(name, six.string_types) for name in element._fields))
    if not is_namedtuple:
      return tuple([type_spec_from_value(item) for item in element])
    # For a proxied namedtuple, rebuild using the type of the wrapped object.
    if isinstance(element, wrapt.ObjectProxy):
      namedtuple_type = type(element.__wrapped__)
    else:
      namedtuple_type = type(element)
    return namedtuple_type(*[type_spec_from_value(item) for item in element])

  if use_fallback:
    # As a fallback try converting the element to a tensor.
    try:
      fallback_spec = type_spec_from_value(ops.convert_to_tensor(element))
    except (ValueError, TypeError) as error:
      logging.vlog(
          3, "Failed to convert %r to tensor: %s" %
          (type(element).__name__, error))
    else:
      if fallback_spec is not None:
        return fallback_spec

  raise TypeError("Could not build a TypeSpec for %r with type %s" %
                  (element, type(element).__name__))
481
482
483# TODO(b/149584798): Move this to framework and add tests for non-tf.data
484# functionality.
class NoneTensor(composite_tensor.CompositeTensor):
  """Composite tensor standing in for a `None` value."""

  @property
  def _type_spec(self):
    """A newly created `NoneTensorSpec` describing this value."""
    spec = NoneTensorSpec()
    return spec
491
492
493# TODO(b/149584798): Move this to framework and add tests for non-tf.data
494# functionality.
@type_spec.register("tf.NoneTensorSpec")
class NoneTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `None` value.

  The specification carries no state: it serializes to an empty tuple and has
  no tensor components.
  """

  @property
  def value_type(self):
    """The Python type described by this spec (`NoneTensor`)."""
    return NoneTensor

  def _serialize(self):
    """Returns an empty tuple, as there is nothing to serialize."""
    return tuple()

  @property
  def _component_specs(self):
    """An empty list, as the value has no components."""
    return []

  def _to_components(self, value):
    """Returns an empty component list regardless of `value`."""
    return []

  def _from_components(self, components):
    """Returns `None` regardless of `components`."""
    return None

  def _to_tensor_list(self, value):
    """Returns an empty tensor list regardless of `value`."""
    return []

  @staticmethod
  def from_value(value):
    """Returns a new `NoneTensorSpec`; `value` is not inspected."""
    return NoneTensorSpec()

  def _batch(self, batch_size):
    """Returns a new `NoneTensorSpec`; `batch_size` is ignored."""
    return NoneTensorSpec()

  def _unbatch(self):
    """Returns a new `NoneTensorSpec`."""
    return NoneTensorSpec()

  def _to_batched_tensor_list(self, value):
    """Returns an empty tensor list regardless of `value`."""
    return []

  def _to_legacy_output_types(self):
    """Returns this spec itself as the legacy output type."""
    return self

  def _to_legacy_output_shapes(self):
    """Returns this spec itself as the legacy output shape."""
    return self

  def _to_legacy_output_classes(self):
    """Returns this spec itself as the legacy output class."""
    return self

  def most_specific_compatible_shape(self, other):
    """Returns this spec if `other` has exactly the same type.

    Args:
      other: The object to compare against.

    Returns:
      This spec.

    Raises:
      ValueError: If `other` is not of exactly the same type as this spec.
    """
    if type(self) is type(other):
      return self
    raise ValueError("No TypeSpec is compatible with both %s and %s" %
                     (self, other))
546
547
# Register `NoneTensorSpec.from_value` as the value-to-spec converter for
# `NoneType`. NOTE(review): presumably this is what lets
# `type_spec._type_spec_from_value(None)` resolve to a `NoneTensorSpec` --
# confirm against `type_spec.register_type_spec_from_value_converter`.
type_spec.register_type_spec_from_value_converter(type(None),
                                                  NoneTensorSpec.from_value)
550