# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Scan dataset transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


class _ScanDataset(dataset_ops.UnaryDataset):
  """A dataset that scans a function across its input."""

  def __init__(self,
               input_dataset,
               initial_state,
               scan_func,
               use_default_device=None):
    """See `scan()` for details.

    Args:
      input_dataset: The dataset whose elements are fed to `scan_func`.
      initial_state: A nested structure of tensors used as the initial value
        of the accumulated state.
      scan_func: A function mapping `(old_state, input_element)` to
        `(new_state, output_element)`.
      use_default_device: Optional. When not `None`, it is forwarded as the
        `use_default_device` attribute of the `ScanDataset` op; when `None`,
        the attribute is omitted and the op's own default applies.

    Raises:
      TypeError: If `scan_func` does not return a pair, or if the new state it
        returns differs from the initial state in number of components,
        element classes, or element types.
    """
    self._input_dataset = input_dataset
    self._initial_state = structure.normalize_element(initial_state)

    # Compute initial values for the state classes, shapes and types based on
    # the initial state. The shapes may be refined by running `scan_func` one
    # or more times below.
    self._state_structure = structure.type_spec_from_value(self._initial_state)

    # Iteratively rerun the scan function until the state shapes recorded in
    # `self._state_structure` reach a fixed point: whenever the new state's
    # shapes are less specific than the current state shapes, the state shapes
    # are weakened and the function is traced again with the weaker shapes.
    need_to_rerun = True
    while need_to_rerun:

      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          scan_func,
          self._transformation_name(),
          input_structure=(self._state_structure,
                           input_dataset.element_spec),
          add_to_graph=False)
      if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
              and len(wrapped_func.output_types) == 2):
        raise TypeError("The scan function must return a pair comprising the "
                        "new state and the output value.")

      # Extract and validate class information from the returned values.
      new_state_classes, output_classes = wrapped_func.output_classes
      old_state_classes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
          self._state_structure)
      flat_new_state_classes = nest.flatten(new_state_classes)
      flat_old_state_classes = nest.flatten(old_state_classes)
      # The per-component checks below pair components with `zip`, which stops
      # at the shorter input; reject a component-count mismatch explicitly so
      # it cannot slip through validation unnoticed.
      if len(flat_new_state_classes) != len(flat_old_state_classes):
        raise TypeError(
            "The new state must have the same number of components as the "
            "initial state. Expected %s; got %s." %
            (old_state_classes, new_state_classes))
      for new_state_class, old_state_class in zip(flat_new_state_classes,
                                                  flat_old_state_classes):
        if not issubclass(new_state_class, old_state_class):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_classes, new_state_classes))

      # Extract and validate type information from the returned values.
      new_state_types, output_types = wrapped_func.output_types
      old_state_types = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
          self._state_structure)
      for new_state_type, old_state_type in zip(
          nest.flatten(new_state_types), nest.flatten(old_state_types)):
        if new_state_type != old_state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_types, new_state_types))

      # Extract shape information from the returned values.
      new_state_shapes, output_shapes = wrapped_func.output_shapes
      old_state_shapes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
          self._state_structure)
      # Recomputed on every pass; the value from the final (fixed-point) pass
      # is the one that is kept.
      self._element_spec = structure.convert_legacy_structure(
          output_types, output_shapes, output_classes)

      flat_state_shapes = nest.flatten(old_state_shapes)
      flat_new_state_shapes = nest.flatten(new_state_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # Another pass is needed only if some known-rank state shape got weaker.
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # TODO(b/110122868): Support a "most specific compatible structure"
        # method for combining structures, to avoid using legacy structures
        # in this method.
        self._state_structure = structure.convert_legacy_structure(
            old_state_types,
            nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
            old_state_classes)

    self._scan_func = wrapped_func
    self._scan_func.function.add_to_graph(ops.get_default_graph())

    # Only pass `use_default_device` when the caller supplied it, so that the
    # op's own default applies otherwise.
    op_attrs = dict(self._flat_structure)
    if use_default_device is not None:
      op_attrs["use_default_device"] = use_default_device
    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        **op_attrs)
    super(_ScanDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._scan_func]

  @property
  def element_spec(self):
    return self._element_spec

  def _transformation_name(self):
    return "tf.data.experimental.scan()"


@tf_export("data.experimental.scan")
def scan(initial_state, scan_func):
  """A transformation that scans a function across an input dataset.

  This transformation is a stateful relative of `tf.data.Dataset.map`.
  In addition to mapping `scan_func` across the elements of the input dataset,
  `scan()` accumulates one or more state tensors, whose initial values are
  `initial_state`.

  For example, a running sum over a dataset of integers:

  ```python
  dataset = tf.data.Dataset.range(5)
  dataset = dataset.apply(
      tf.data.experimental.scan(
          initial_state=tf.constant(0, dtype=tf.int64),
          scan_func=lambda state, x: (state + x, state + x)))
  # Produces 0, 1, 3, 6, 10.
  ```

  Args:
    initial_state: A nested structure of tensors, representing the initial state
      of the accumulator.
    scan_func: A function that maps `(old_state, input_element)` to
      `(new_state, output_element)`. It must take two arguments and return a
      pair of nested structures of tensors. The `new_state` must match the
      structure of `initial_state`, with the same number of components and the
      same element classes and dtypes. If its static shapes differ from those
      of `initial_state`, the state shapes are relaxed to the most specific
      shapes compatible with both.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`. Applying it raises `TypeError` if `scan_func`
    does not return a `(new_state, output_element)` pair or if `new_state`
    does not match `initial_state` as described above.
  """
  def _apply_fn(dataset):
    return _ScanDataset(dataset, initial_state, scan_func)

  return _apply_fn