1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for describing the structure of a `tf.data` type."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21
22import six
23
24from tensorflow.python.data.util import nest
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.ops import sparse_ops
31from tensorflow.python.util.tf_export import tf_export
32
33
# Maps a Python `type` object to a function that converts instances of that
# type into a `Structure`. Entries are added by
# `Structure._register_custom_converter()` and are consulted (via `isinstance`)
# by `Structure.from_value()` before it falls back to tensor conversion.
_STRUCTURE_CONVERSION_FUNCTION_REGISTRY = {}
35
36
@tf_export("data.experimental.Structure")
@six.add_metaclass(abc.ABCMeta)
class Structure(object):
  """Represents structural information, such as type and shape, about a value.

  A `Structure` generalizes the `tf.Tensor.dtype` and `tf.Tensor.shape`
  properties, so that we can define generic containers of objects including:

  * `tf.Tensor`
  * `tf.SparseTensor`
  * Nested structures of the above.

  TODO(b/110122868): In the future, a single `Structure` will replace the
  `tf.data.Dataset.output_types`, `tf.data.Dataset.output_shapes`,
  and `tf.data.Dataset.output_classes`, and similar properties and arguments in
  the `tf.data.Iterator` and `Optional` classes.
  """

  @abc.abstractproperty
  def _flat_shapes(self):
    """A list of shapes matching the shapes of `self._to_tensor_list()`.

    Returns:
      A list of `tf.TensorShape` objects.
    """
    raise NotImplementedError("Structure._flat_shapes")

  @abc.abstractproperty
  def _flat_types(self):
    """A list of types matching the types of `self._to_tensor_list()`.

    Returns:
      A list of `tf.DType` objects.
    """
    raise NotImplementedError("Structure._flat_types")

  @abc.abstractmethod
  def is_compatible_with(self, other):
    """Returns `True` if `other` is compatible with this structure.

    A structure `t` is a "subtype" of `s` if:

    * `s` and `t` are instances of the same `Structure` subclass.
    * The nested structures (if any) of `s` and `t` are the same, according to
      `tf.contrib.framework.nest.assert_same_structure`, and each nested
      structure of `t` is a "subtype" of the corresponding nested structure of
      `s`.
    * Any `tf.DType` components of `t` are the same as the corresponding
      components in `s`.
    * Any `tf.TensorShape` components of `t` are compatible with the
      corresponding components in `s`, according to
      `tf.TensorShape.is_compatible_with`.

    Args:
      other: A `Structure`.

    Returns:
      `True` if `other` is a subtype of this structure, otherwise `False`.
    """
    raise NotImplementedError("Structure.is_compatible_with()")

  @abc.abstractmethod
  def _to_tensor_list(self, value):
    """Returns a flat list of `tf.Tensor` representing `value`.

    This method can be used, along with `self._flat_shapes` and
    `self._flat_types` to represent structured values in lower level APIs
    (such as plain TensorFlow operations) that do not understand structure.

    Requires: `self.is_compatible_with(Structure.from_value(value))`.

    Args:
      value: A value with compatible structure.

    Returns:
      A flat list of `tf.Tensor` representing `value`.
    """
    raise NotImplementedError("Structure._to_tensor_list()")

  @abc.abstractmethod
  def _to_batched_tensor_list(self, value):
    """Returns a flat list of rank >= 1 `tf.Tensor` representing `value`.

    This method can be used, along with `self._flat_shapes` and
    `self._flat_types` to represent structured values in lower level APIs
    (such as plain TensorFlow operations) that do not understand structure,
    *and* that require that the plain tensors have a rank of at least one
    (e.g. for the purpose of slicing the tensors).

    Requires: `self.is_compatible_with(Structure.from_value(value))`.

    Args:
      value: A value with compatible structure.

    Returns:
      A flat list of `tf.Tensor` representing `value`.
    """
    raise NotImplementedError("Structure._to_batched_tensor_list()")

  @abc.abstractmethod
  def _from_tensor_list(self, flat_value):
    """Builds a flat list of `tf.Tensor` into a value matching this structure.

    Args:
      flat_value: A list of `tf.Tensor` with compatible flat structure.

    Returns:
      A structured object matching this structure.

    Raises:
      ValueError: If the shapes and types of the tensors in `flat_value` are not
        compatible with `self._flat_shapes` and `self._flat_types` respectively.
    """
    raise NotImplementedError("Structure._from_tensor_list()")

  def _from_compatible_tensor_list(self, flat_value):
    """A version of `_from_tensor_list()` that may avoid performing checks.

    NOTE: This method should be used to avoid checks for performance reasons,
    when the validity of `flat_value` has been validated by other means.
    The shapes and types of the tensors in `flat_value` must be compatible with
    `self._flat_shapes` and `self._flat_types` respectively. The behavior is
    undefined if this requirement is not met.

    Args:
      flat_value: A list of `tf.Tensor` with compatible flat structure.

    Returns:
      A structured object matching this structure.
    """
    # Default implementation: fall back to the checked conversion. Subclasses
    # may override this to skip validation.
    return self._from_tensor_list(flat_value)

  @abc.abstractmethod
  def _batch(self, batch_size):
    """Returns a structure representing a batch of objects with this structure.

    Args:
      batch_size: An `int` representing the number of elements in a batch,
        or `None` if the batch size may vary.

    Returns:
      A `Structure` representing a batch of objects with this structure.
    """
    raise NotImplementedError("Structure._batch()")

  @abc.abstractmethod
  def _unbatch(self):
    """Returns a structure for a single element of a batch with this structure.

    This is the counterpart of `_batch()`: the returned structure describes the
    objects obtained by removing the leading (batch) dimension.

    Returns:
      A `Structure` representing one element of a batch with this structure.
    """
    raise NotImplementedError("Structure._unbatch()")

  @staticmethod
  def from_value(value):
    """Returns a `Structure` that represents the given `value`.

    Args:
      value: A potentially structured value.

    Returns:
      A `Structure` that is compatible with `value`.

    Raises:
      TypeError: If a structure cannot be built for `value`, because its type
        or one of its component types is not supported.
    """
    # TODO(b/110122868): Add support for custom types and Dataset to this
    # method.
    if isinstance(
        value,
        (sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)):
      return SparseTensorStructure.from_value(value)
    elif isinstance(value, (tuple, dict)):
      return NestedStructure.from_value(value)
    else:
      # Give registered custom converters a chance before attempting a plain
      # tensor conversion.
      for converter_type, converter_fn in (
          _STRUCTURE_CONVERSION_FUNCTION_REGISTRY.items()):
        if isinstance(value, converter_type):
          return converter_fn(value)
      try:
        tensor = ops.convert_to_tensor(value)
      except (ValueError, TypeError):
        raise TypeError("Could not build a structure for %r" % value)
      return TensorStructure.from_value(tensor)

  @staticmethod
  def _register_custom_converter(type_object, converter_fn):
    """Registers `converter_fn` for converting values of the given type.

    Args:
      type_object: A Python `type` object representing the type of values
        accepted by `converter_fn`.
      converter_fn: A function that takes one argument (an instance of the
        type represented by `type_object`) and returns a `Structure`.
    """
    _STRUCTURE_CONVERSION_FUNCTION_REGISTRY[type_object] = converter_fn

  @abc.abstractmethod
  def _to_legacy_output_types(self):
    """Returns this structure in the legacy `output_types` representation.

    Returns:
      A (potentially nested) structure of `tf.DType` objects.
    """
    raise NotImplementedError("Structure._to_legacy_output_types()")

  @abc.abstractmethod
  def _to_legacy_output_shapes(self):
    """Returns this structure in the legacy `output_shapes` representation.

    Returns:
      A (potentially nested) structure of `tf.TensorShape` objects.
    """
    raise NotImplementedError("Structure._to_legacy_output_shapes()")

  @abc.abstractmethod
  def _to_legacy_output_classes(self):
    """Returns this structure in the legacy `output_classes` representation.

    Returns:
      A (potentially nested) structure of Python `type` objects.
    """
    raise NotImplementedError("Structure._to_legacy_output_classes()")
242
243
def convert_legacy_structure(output_types, output_shapes, output_classes):
  """Returns a `Structure` that represents the given legacy structure.

  This method provides a way to convert from the existing `Dataset` and
  `Iterator` structure-related properties to a `Structure` object. A "legacy"
  structure is represented by the `tf.data.Dataset.output_types`,
  `tf.data.Dataset.output_shapes`, and `tf.data.Dataset.output_classes`
  properties.

  TODO(b/110122868): Remove this function once `Structure` is used throughout
  `tf.data`.

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of a structured value.
    output_shapes: A nested structure of `tf.TensorShape` objects
      corresponding to each component of a structured value.
    output_classes: A nested structure of Python `type` objects corresponding
      to each component of a structured value.

  Returns:
    A `Structure`.

  Raises:
    TypeError: If a structure cannot be built from the arguments, because one of
      the component classes in `output_classes` is not supported.
  """
  flat_types = nest.flatten(output_types)
  flat_shapes = nest.flatten(output_shapes)
  flat_classes = nest.flatten(output_classes)
  flat_ret = []
  for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
                                               flat_classes):
    if isinstance(flat_class, Structure):
      # A `Structure` instance in `output_classes` already describes the
      # component, so it is used as-is.
      flat_ret.append(flat_class)
    elif issubclass(flat_class, sparse_tensor_lib.SparseTensor):
      flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
    elif issubclass(flat_class, ops.Tensor):
      flat_ret.append(TensorStructure(flat_type, flat_shape))
    else:
      # NOTE(mrry): Since legacy structures produced by iterators only
      # comprise Tensors, SparseTensors, and nests, we do not need to
      # support all structure types here.
      #
      # Report the offending *class* (not the dtype); the 1-tuple keeps the
      # `%` formatting well-defined for any object.
      raise TypeError(
          "Could not build a structure for output class %r" % (flat_class,))

  ret = nest.pack_sequence_as(output_classes, flat_ret)
  if isinstance(ret, Structure):
    # A single (non-nested) component does not need a `NestedStructure`.
    return ret
  else:
    return NestedStructure(ret)
295
296
297# NOTE(mrry): The following classes make extensive use of non-public methods of
298# their base class, so we disable the protected-access lint warning once here.
299# pylint: disable=protected-access
@tf_export("data.experimental.NestedStructure")
class NestedStructure(Structure):
  """Represents a nested structure in which each leaf is a `Structure`."""

  def __init__(self, nested_structure):
    """Creates a `NestedStructure` from a nest of `Structure` objects.

    Args:
      nested_structure: A (potentially nested) tuple or dictionary whose
        leaves are `Structure` objects.

    Raises:
      TypeError: If any leaf of `nested_structure` is not a `Structure`.
    """
    self._nested_structure = nested_structure
    # Flatten once; the flat list is reused by the conversion methods below.
    self._flat_nested_structure = nest.flatten(nested_structure)
    self._flat_shapes_list = []
    self._flat_types_list = []
    for s in self._flat_nested_structure:
      if not isinstance(s, Structure):
        raise TypeError("nested_structure must be a (potentially nested) tuple "
                        "or dictionary of Structure objects.")
      self._flat_shapes_list.extend(s._flat_shapes)
      self._flat_types_list.extend(s._flat_types)

  @property
  def _flat_shapes(self):
    """The concatenated flat shapes of every leaf structure, in flat order."""
    return self._flat_shapes_list

  @property
  def _flat_types(self):
    """The concatenated flat types of every leaf structure, in flat order."""
    return self._flat_types_list

  def is_compatible_with(self, other):
    """Returns `True` if `other` has the same nest and compatible leaves.

    Args:
      other: A `Structure`.

    Returns:
      `True` if `other` is a `NestedStructure` with the same nested layout
      whose leaf structures are each compatible with the corresponding leaf
      of this structure, otherwise `False`.
    """
    if not isinstance(other, NestedStructure):
      return False
    try:
      nest.assert_same_structure(self._nested_structure,
                                 other._nested_structure)
    except (ValueError, TypeError):
      return False

    # Both flat lists were computed at construction time, so there is no need
    # to flatten the nests again here.
    return all(
        substructure.is_compatible_with(other_substructure)
        for substructure, other_substructure in zip(
            self._flat_nested_structure, other._flat_nested_structure))

  def _convert_components(self, value, to_list_fn):
    """Flattens `value` and converts each component to a list of tensors.

    Args:
      value: A value whose nesting matches `self._nested_structure`.
      to_list_fn: A function `(structure, sub_value) -> list of tf.Tensor`
        that is applied to each leaf structure and its matching component.

    Returns:
      The concatenation of the lists returned by `to_list_fn`, in flat order.

    Raises:
      ValueError: If `value` cannot be flattened against this structure, or
        if a component is not compatible with its leaf structure.
    """
    try:
      flat_value = nest.flatten_up_to(self._nested_structure, value)
    except (ValueError, TypeError):
      raise ValueError("The value %r is not compatible with the nested "
                       "structure %r." % (value, self._nested_structure))

    ret = []
    # Each component is checked immediately before it is converted, so an
    # incompatible component fails before any later component is touched.
    for sub_value, structure in zip(flat_value, self._flat_nested_structure):
      if not structure.is_compatible_with(Structure.from_value(sub_value)):
        raise ValueError("Component value %r is not compatible with the nested "
                         "structure %r." % (sub_value, structure))
      ret.extend(to_list_fn(structure, sub_value))
    return ret

  def _build_from_flat(self, flat_value, from_list_fn):
    """Rebuilds a nested value from a flat list of tensors.

    Args:
      flat_value: A list of `tf.Tensor` holding the flat tensors of every leaf
        structure, in flat order.
      from_list_fn: A function `(structure, list of tf.Tensor) -> value` that
        is applied to each leaf structure and its slice of `flat_value`.

    Returns:
      A value with the same nesting as `self._nested_structure`.
    """
    flat_ret = []
    i = 0
    for structure in self._flat_nested_structure:
      # Each leaf consumes as many flat tensors as it has flat types.
      num_flat_values = len(structure._flat_types)
      sub_value = flat_value[i:i + num_flat_values]
      flat_ret.append(from_list_fn(structure, sub_value))
      i += num_flat_values

    return nest.pack_sequence_as(self._nested_structure, flat_ret)

  def _to_tensor_list(self, value):
    """Returns the flat list of tensors representing `value`.

    Args:
      value: A value with compatible structure.

    Returns:
      A flat list of `tf.Tensor` representing `value`.

    Raises:
      ValueError: If `value` is not compatible with this structure.
    """
    return self._convert_components(
        value, lambda s, v: s._to_tensor_list(v))

  def _to_batched_tensor_list(self, value):
    """Returns the flat list of batched tensors representing `value`.

    Args:
      value: A value with compatible structure.

    Returns:
      A flat list of `tf.Tensor` representing `value`.

    Raises:
      ValueError: If `value` is not compatible with this structure.
    """
    return self._convert_components(
        value, lambda s, v: s._to_batched_tensor_list(v))

  def _from_tensor_list(self, flat_value):
    """Builds a nested value from `flat_value`, validating each component.

    Args:
      flat_value: A list of `tf.Tensor` with compatible flat structure.

    Returns:
      A nested value matching this structure.

    Raises:
      ValueError: If `flat_value` has the wrong number of tensors, or if a
        leaf structure rejects its slice of `flat_value`.
    """
    if len(flat_value) != len(self._flat_types):
      raise ValueError("Expected %d flat values in NestedStructure but got %d."
                       % (len(self._flat_types), len(flat_value)))

    return self._build_from_flat(
        flat_value, lambda s, v: s._from_tensor_list(v))

  def _from_compatible_tensor_list(self, flat_value):
    """Builds a nested value from `flat_value` without the length check.

    Args:
      flat_value: A list of `tf.Tensor` with compatible flat structure.

    Returns:
      A nested value matching this structure.
    """
    return self._build_from_flat(
        flat_value, lambda s, v: s._from_compatible_tensor_list(v))

  @staticmethod
  def from_value(value):
    """Returns a `NestedStructure` describing the nested `value`.

    Args:
      value: A (potentially nested) tuple or dictionary of values.

    Returns:
      A `NestedStructure` with the same nesting as `value`, whose leaves are
      the structures returned by `Structure.from_value()` for each component.
    """
    flat_nested_structure = [
        Structure.from_value(sub_value) for sub_value in nest.flatten(value)
    ]
    return NestedStructure(nest.pack_sequence_as(value, flat_nested_structure))

  def _to_legacy_output_types(self):
    """Returns a nest of the legacy output types of each leaf structure."""
    return nest.map_structure(
        lambda s: s._to_legacy_output_types(), self._nested_structure)

  def _to_legacy_output_shapes(self):
    """Returns a nest of the legacy output shapes of each leaf structure."""
    return nest.map_structure(
        lambda s: s._to_legacy_output_shapes(), self._nested_structure)

  def _to_legacy_output_classes(self):
    """Returns a nest of the legacy output classes of each leaf structure."""
    return nest.map_structure(
        lambda s: s._to_legacy_output_classes(), self._nested_structure)

  def _batch(self, batch_size):
    """Returns a `NestedStructure` in which every leaf has been batched.

    Args:
      batch_size: An `int` representing the number of elements in a batch,
        or `None` if the batch size may vary.

    Returns:
      A `NestedStructure` with the same nesting as this structure.
    """
    return NestedStructure(nest.map_structure(
        lambda s: s._batch(batch_size), self._nested_structure))

  def _unbatch(self):
    """Returns a `NestedStructure` in which every leaf has been unbatched."""
    return NestedStructure(nest.map_structure(
        lambda s: s._unbatch(), self._nested_structure))
424
425
@tf_export("data.experimental.TensorStructure")
class TensorStructure(Structure):
  """Describes the dtype and static shape of a single `tf.Tensor`."""

  def __init__(self, dtype, shape):
    """Creates a `TensorStructure`.

    Args:
      dtype: A value accepted by `dtypes.as_dtype()`.
      shape: A value accepted by `tensor_shape.as_shape()`.
    """
    self._dtype = dtypes.as_dtype(dtype)
    self._shape = tensor_shape.as_shape(shape)

  @property
  def _flat_shapes(self):
    """A one-element list holding this structure's shape."""
    return [self._shape]

  @property
  def _flat_types(self):
    """A one-element list holding this structure's dtype."""
    return [self._dtype]

  def is_compatible_with(self, other):
    """Returns `True` if `other` is a compatible `TensorStructure`.

    Args:
      other: A `Structure`.

    Returns:
      `True` if `other` is a `TensorStructure` whose dtype and shape are both
      compatible with those of this structure, otherwise `False`.
    """
    if not isinstance(other, TensorStructure):
      return False
    if not self._dtype.is_compatible_with(other._dtype):
      return False
    return self._shape.is_compatible_with(other._shape)

  def _to_tensor_list(self, value):
    """Returns `[value]` after checking it is compatible with this structure.

    Args:
      value: A value with compatible structure.

    Returns:
      A one-element list containing `value`.

    Raises:
      ValueError: If the structure of `value` is not compatible with this one.
    """
    value_structure = Structure.from_value(value)
    if self.is_compatible_with(value_structure):
      return [value]
    raise ValueError("Value %r is not convertible to a tensor with dtype %s "
                     "and shape %s." % (value, self._dtype, self._shape))

  def _to_batched_tensor_list(self, value):
    """Returns `[value]` after checking that it is not a scalar.

    Args:
      value: A tensor with compatible structure.

    Returns:
      A one-element list containing `value`.

    Raises:
      ValueError: If the merged shape of this structure and `value` has rank 0.
    """
    merged_shape = self._shape.merge_with(value.shape)
    if merged_shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return [value]

  def _from_tensor_list(self, flat_value):
    """Returns the single tensor in `flat_value` after validating it.

    Args:
      flat_value: A list containing exactly one `tf.Tensor`.

    Returns:
      The tensor, with its static shape set from this structure.

    Raises:
      ValueError: If `flat_value` does not contain exactly one element, or if
        that element is not compatible with this structure.
    """
    if len(flat_value) != 1:
      raise ValueError("TensorStructure corresponds to a single tf.Tensor.")
    candidate = flat_value[0]
    if not self.is_compatible_with(Structure.from_value(candidate)):
      raise ValueError("Cannot convert %r to a tensor with dtype %s and shape "
                       "%s." % (candidate, self._dtype, self._shape))
    return self._from_compatible_tensor_list(flat_value)

  def _from_compatible_tensor_list(self, flat_value):
    """Returns the single tensor in `flat_value` without validating it.

    Args:
      flat_value: A list whose first element is a compatible `tf.Tensor`.

    Returns:
      The tensor, with its static shape set from this structure.
    """
    # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
    # op here and return that, instead of mutating the input's shape using
    # `Tensor.set_shape()`. However, that would add extra ops on the arguments
    # of each `tf.data` function, which could impact performance. When this
    # bug is resolved, we should be able to add the `ensure_shape()` ops and
    # optimize them away using contextual shape information.
    tensor = flat_value[0]
    tensor.set_shape(self._shape)
    return tensor

  @staticmethod
  def from_value(value):
    """Returns a `TensorStructure` with the dtype and shape of `value`."""
    return TensorStructure(dtype=value.dtype, shape=value.shape)

  def _to_legacy_output_types(self):
    """Returns the dtype of this structure."""
    return self._dtype

  def _to_legacy_output_shapes(self):
    """Returns the shape of this structure."""
    return self._shape

  def _to_legacy_output_classes(self):
    """Returns the `ops.Tensor` class."""
    return ops.Tensor

  def _batch(self, batch_size):
    """Returns a `TensorStructure` with a leading dimension of `batch_size`.

    Args:
      batch_size: An `int` representing the number of elements in a batch,
        or `None` if the batch size may vary.

    Returns:
      A `TensorStructure` whose shape is `[batch_size]` followed by this
      structure's shape.
    """
    leading_dim = tensor_shape.TensorShape([batch_size])
    batched_shape = leading_dim.concatenate(self._shape)
    return TensorStructure(self._dtype, batched_shape)

  def _unbatch(self):
    """Returns a `TensorStructure` with the leading dimension removed.

    Raises:
      ValueError: If this structure's shape has rank 0.
    """
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    element_shape = self._shape[1:]
    return TensorStructure(self._dtype, element_shape)
498
499
@tf_export("data.experimental.SparseTensorStructure")
class SparseTensorStructure(Structure):
  """Describes the dtype and dense shape of a single `tf.SparseTensor`."""

  def __init__(self, dtype, dense_shape):
    """Creates a `SparseTensorStructure`.

    Args:
      dtype: A value accepted by `dtypes.as_dtype()`.
      dense_shape: A value accepted by `tensor_shape.as_shape()`.
    """
    self._dtype = dtypes.as_dtype(dtype)
    self._dense_shape = tensor_shape.as_shape(dense_shape)

  @property
  def _flat_shapes(self):
    """A one-element list holding a shape of unknown rank."""
    # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
    # but a `SparseTensorStructure` can also represent a batch of boxed
    # `SparseTensor` objects with shape `(?, 3)` (and batches of batches, etc.),
    # so the flat shape must be unknown.
    return [tensor_shape.unknown_shape(None)]

  @property
  def _flat_types(self):
    """A one-element list holding `dtypes.variant`."""
    return [dtypes.variant]

  def is_compatible_with(self, other):
    """Returns `True` if `other` is a compatible `SparseTensorStructure`.

    Args:
      other: A `Structure`.

    Returns:
      `True` if `other` is a `SparseTensorStructure` whose dtype and dense
      shape are both compatible with those of this structure, otherwise
      `False`.
    """
    if not isinstance(other, SparseTensorStructure):
      return False
    if not self._dtype.is_compatible_with(other._dtype):
      return False
    return self._dense_shape.is_compatible_with(other._dense_shape)

  def _to_tensor_list(self, value):
    """Serializes `value` into a one-element list holding a variant tensor."""
    serialized = sparse_ops.serialize_sparse(value, out_type=dtypes.variant)
    return [serialized]

  def _to_batched_tensor_list(self, value):
    """Serializes `value` with `serialize_many_sparse` into a variant tensor.

    Args:
      value: A sparse tensor with compatible structure.

    Returns:
      A one-element list containing the serialized variant tensor.

    Raises:
      ValueError: If the merged dense shape of this structure and `value` has
        rank 0.
    """
    value_dense_shape = tensor_util.constant_value_as_shape(value.dense_shape)
    if self._dense_shape.merge_with(value_dense_shape).ndims == 0:
      raise ValueError(
          "Unbatching a sparse tensor is only supported for rank >= 1")
    serialized = sparse_ops.serialize_many_sparse(
        value, out_type=dtypes.variant)
    return [serialized]

  def _from_tensor_list(self, flat_value):
    """Deserializes the single variant tensor in `flat_value`, with checks.

    Args:
      flat_value: A list containing exactly one `tf.variant` tensor whose
        shape is compatible with a vector of length 3.

    Returns:
      The deserialized sparse tensor.

    Raises:
      ValueError: If `flat_value` does not have exactly one element, or if
        that element has the wrong dtype or an incompatible shape.
    """
    error_message = ("SparseTensorStructure corresponds to a single "
                     "tf.variant vector of length 3.")
    # The checks run in order, so `flat_value[0]` is only read once the list
    # is known to hold exactly one element.
    if len(flat_value) != 1:
      raise ValueError(error_message)
    serialized = flat_value[0]
    if serialized.dtype != dtypes.variant:
      raise ValueError(error_message)
    if not serialized.shape.is_compatible_with(tensor_shape.vector(3)):
      raise ValueError(error_message)
    return self._from_compatible_tensor_list(flat_value)

  def _from_compatible_tensor_list(self, flat_value):
    """Deserializes the first tensor in `flat_value` without validating it.

    Args:
      flat_value: A list whose first element is a serialized sparse tensor.

    Returns:
      The deserialized sparse tensor, with the static shapes of its `indices`
      and `dense_shape` components set from this structure's rank.
    """
    rank = self._dense_shape.ndims
    sparse = sparse_ops.deserialize_sparse(
        flat_value[0], dtype=self._dtype, rank=rank)
    # Attach the static rank information that deserialization does not carry.
    sparse.indices.set_shape([None, rank])
    sparse.dense_shape.set_shape([rank])
    return sparse

  @staticmethod
  def from_value(value):
    """Returns a `SparseTensorStructure` describing the sparse `value`."""
    as_sparse = sparse_tensor_lib.SparseTensor.from_value(value)
    static_dense_shape = tensor_util.constant_value_as_shape(
        as_sparse.dense_shape)
    return SparseTensorStructure(as_sparse.dtype, static_dense_shape)

  def _to_legacy_output_types(self):
    """Returns the dtype of this structure."""
    return self._dtype

  def _to_legacy_output_shapes(self):
    """Returns the dense shape of this structure."""
    return self._dense_shape

  def _to_legacy_output_classes(self):
    """Returns the `SparseTensor` class."""
    return sparse_tensor_lib.SparseTensor

  def _batch(self, batch_size):
    """Returns a `SparseTensorStructure` with a leading `batch_size` dimension.

    Args:
      batch_size: An `int` representing the number of elements in a batch,
        or `None` if the batch size may vary.

    Returns:
      A `SparseTensorStructure` whose dense shape is `[batch_size]` followed
      by this structure's dense shape.
    """
    leading_dim = tensor_shape.TensorShape([batch_size])
    batched_shape = leading_dim.concatenate(self._dense_shape)
    return SparseTensorStructure(self._dtype, batched_shape)

  def _unbatch(self):
    """Returns a `SparseTensorStructure` with the leading dimension removed.

    Raises:
      ValueError: If this structure's dense shape has rank 0.
    """
    if self._dense_shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    element_shape = self._dense_shape[1:]
    return SparseTensorStructure(self._dtype, element_shape)
574