1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Scan dataset transformation."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.data.util import nest
22from tensorflow.python.data.util import structure
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import gen_experimental_dataset_ops
25from tensorflow.python.util.compat import collections_abc
26from tensorflow.python.util.tf_export import tf_export
27
28
class _ScanDataset(dataset_ops.UnaryDataset):
  """A dataset that scans a function across its input.

  The scan function is wrapped against the current state structure and the
  input dataset's element spec. If the shapes of the state it returns are not
  the same as the shapes it was given, the state structure is weakened to the
  most specific shapes compatible with both and the function is wrapped again,
  until the state shapes stop changing.
  """

  def __init__(self,
               input_dataset,
               initial_state,
               scan_func,
               use_default_device=None):
    """See `scan()` for details.

    Args:
      input_dataset: The dataset whose elements are passed to `scan_func`.
      initial_state: The initial value of the state. It is normalized with
        `structure.normalize_element` before use.
      scan_func: A function that maps `(old_state, input_element)` to
        `(new_state, output_element)`.
      use_default_device: (Optional.) If not `None`, it is forwarded to the
        `scan_dataset` op as the `use_default_device` argument. If `None`, the
        argument is not passed to the op at all.

    Raises:
      TypeError: If `scan_func` does not return a pair, or if the classes or
        types of the new state do not match those of the initial state.
    """
    self._input_dataset = input_dataset
    self._initial_state = structure.normalize_element(initial_state)

    # Compute the initial structure of the state from the initial state. The
    # shapes in this structure may be weakened by running `scan_func` one or
    # more times below.
    self._state_structure = structure.type_spec_from_value(self._initial_state)

    # Iteratively rerun the scan function until reaching a fixed point on the
    # shapes in `self._state_structure`.
    need_to_rerun = True
    while need_to_rerun:
      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          scan_func,
          self._transformation_name(),
          input_structure=(self._state_structure,
                           input_dataset.element_spec),
          add_to_graph=False)
      if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
              and len(wrapped_func.output_types) == 2):
        raise TypeError("The scan function must return a pair comprising the "
                        "new state and the output value.")

      # Extract and validate class information from the returned values.
      new_state_classes, output_classes = wrapped_func.output_classes
      # The output classes are also kept on the instance so that any code
      # reading `_output_classes` continues to find it.
      self._output_classes = output_classes
      old_state_classes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
          self._state_structure)
      for new_state_class, old_state_class in zip(
          nest.flatten(new_state_classes),
          nest.flatten(old_state_classes)):
        if not issubclass(new_state_class, old_state_class):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_classes, new_state_classes))

      # Extract and validate type information from the returned values.
      new_state_types, output_types = wrapped_func.output_types
      old_state_types = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
          self._state_structure)
      for new_state_type, old_state_type in zip(
          nest.flatten(new_state_types), nest.flatten(old_state_types)):
        if new_state_type != old_state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_types, new_state_types))

      # Extract shape information from the returned values.
      new_state_shapes, output_shapes = wrapped_func.output_shapes
      old_state_shapes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
          self._state_structure)
      # The element spec is recomputed on every pass; the value from the final
      # pass is the one that remains on the instance.
      self._element_spec = structure.convert_legacy_structure(
          output_types, output_shapes, output_classes)

      # For each state component, find the most specific shape that is
      # compatible with both the shape passed in and the shape returned.
      flat_state_shapes = nest.flatten(old_state_shapes)
      flat_new_state_shapes = nest.flatten(new_state_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # Another pass is needed if any state shape with a known rank differs
      # from its weakened counterpart. Shapes of unknown rank cannot be
      # weakened any further, so they never trigger a rerun.
      need_to_rerun = any(
          original_shape.ndims is not None and (
              weakened_shape.ndims is None or
              original_shape.as_list() != weakened_shape.as_list())
          for original_shape, weakened_shape in zip(flat_state_shapes,
                                                    weakened_state_shapes))

      if need_to_rerun:
        # TODO(b/110122868): Support a "most specific compatible structure"
        # method for combining structures, to avoid using legacy structures
        # in this method.
        self._state_structure = structure.convert_legacy_structure(
            old_state_types,
            nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
            old_state_classes)

    self._scan_func = wrapped_func
    self._scan_func.function.add_to_graph(ops.get_default_graph())

    # Only forward `use_default_device` when the caller supplied a value, so
    # that leaving it as `None` produces an op call without that argument.
    op_kwargs = dict(self._flat_structure)
    if use_default_device is not None:
      op_kwargs["use_default_device"] = use_default_device

    # pylint: disable=protected-access
    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
        self._input_dataset._variant_tensor,
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        **op_kwargs)
    # pylint: enable=protected-access
    super(_ScanDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    """Returns a list containing the wrapped scan function."""
    return [self._scan_func]

  @property
  def element_spec(self):
    """The element spec built from the output component of `scan_func`."""
    return self._element_spec

  def _transformation_name(self):
    """Returns the name passed to the function wrapper for this dataset."""
    return "tf.data.experimental.scan()"
155
156
@tf_export("data.experimental.scan")
def scan(initial_state, scan_func):
  """Builds a transformation that threads state through `scan_func`.

  Like `tf.data.Dataset.map`, the returned transformation applies `scan_func`
  to the elements of the input dataset. Unlike `map`, it is stateful: one or
  more state tensors, whose initial values are `initial_state`, are
  accumulated as the function is applied.

  Args:
    initial_state: A nested structure of tensors, giving the initial state of
      the accumulator.
    scan_func: A two-argument function that maps `(old_state, input_element)`
      to `(new_state, output_element)`. It must return a pair of nested
      structures of tensors, and `new_state` must match the structure of
      `initial_state`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _scan_transformation(dataset):
    """Wraps `dataset` in a `_ScanDataset` using the captured arguments."""
    scanned = _ScanDataset(dataset, initial_state, scan_func)
    return scanned

  return _scan_transformation
182