1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Grouping dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.util import nest
24from tensorflow.python.data.util import structure
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_spec
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import check_ops
32from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.util.tf_export import tf_export
35
36
@tf_export("data.experimental.group_by_reducer")
def group_by_reducer(key_func, reducer):
  """Creates a transformation that groups elements by key and reduces them.

  Each element of the input dataset is assigned a key by `key_func`, and
  elements that share a key form a group. Every group is processed by
  `reducer`: its `init_func` produces the initial state when the group is
  created, its `reduce_func` updates that state each time an element is mapped
  to the group, and its `finalize_func` converts the final state into the
  output value.

  Args:
    key_func: A function mapping a nested structure of tensors (with the shapes
      and types given by `self.output_shapes` and `self.output_types`) to a
      scalar `tf.int64` tensor.
    reducer: A `Reducer` instance bundling the `init_func`, `reduce_func`, and
      `finalize_func` that define the reduction.

  Returns:
    A `Dataset` transformation function suitable for passing to
    `tf.data.Dataset.apply`.
  """
  # The grouped dataset is built lazily, only once the returned callable is
  # handed the input dataset.
  return lambda dataset: _GroupByReducerDataset(dataset, key_func, reducer)
65
66
@tf_export("data.experimental.group_by_window")
def group_by_window(key_func,
                    reduce_func,
                    window_size=None,
                    window_size_func=None):
  """Creates a transformation that reduces per-key windows of elements.

  Consecutive elements of the input dataset are mapped to a key by `key_func`
  and grouped by that key. `reduce_func` is then applied to windows of at most
  `window_size_func(key)` elements sharing the same key. For every key, all
  windows except possibly the last contain exactly `window_size_func(key)`
  elements; the last one may hold fewer.

  The window size is given either as a constant (`window_size`) or as a
  function of the key (`window_size_func`), never both.

  Args:
    key_func: A function mapping a nested structure of tensors (with the shapes
      and types given by `self.output_shapes` and `self.output_types`) to a
      scalar `tf.int64` tensor.
    reduce_func: A function mapping a key and a dataset of up to `window_size`
      consecutive elements matching that key to another dataset.
    window_size: A `tf.int64` scalar `tf.Tensor` giving the number of
      consecutive elements with the same key that are combined into a single
      batch and passed to `reduce_func`. Mutually exclusive with
      `window_size_func`.
    window_size_func: A function mapping a key to a `tf.int64` scalar
      `tf.Tensor` giving the number of consecutive elements with the same key
      that are combined into a single batch and passed to `reduce_func`.
      Mutually exclusive with `window_size`.

  Returns:
    A `Dataset` transformation function suitable for passing to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if neither or both of {`window_size`, `window_size_func`} are
      passed.
  """
  has_constant_size = window_size is not None
  has_size_func = bool(window_size_func)
  # Exactly one of the two sizing options must be supplied, so the two flags
  # have to disagree.
  if has_constant_size == has_size_func:
    raise ValueError("Must pass either window_size or window_size_func.")

  if has_constant_size:

    def _fixed_window_size(unused_key):
      # Adapts the constant size to the key -> size function interface.
      return ops.convert_to_tensor(window_size, dtype=dtypes.int64)

    window_size_func = _fixed_window_size

  assert window_size_func is not None

  def _group_into_windows(dataset):
    """Wraps `dataset` in the windowed grouping dataset."""
    return _GroupByWindowDataset(dataset, key_func, reduce_func,
                                 window_size_func)

  return _group_into_windows
125
126
@tf_export("data.experimental.bucket_by_sequence_length")
def bucket_by_sequence_length(element_length_func,
                              bucket_boundaries,
                              bucket_batch_sizes,
                              padded_shapes=None,
                              padding_values=None,
                              pad_to_bucket_boundary=False,
                              no_padding=False,
                              drop_remainder=False):
  """Creates a transformation that batches `Dataset` elements of similar length.

  Elements are assigned to buckets according to their length, and the elements
  of each bucket are then batched together (with padding, unless `no_padding`
  is set).

  This helps with sequence tasks whose elements have variable length: batching
  elements of similar length lowers the fraction of padding in each batch,
  which makes training steps more efficient.

  Below is an example to bucketize the input data to the 3 buckets
  "[0, 3), [3, 5), [5, inf)" based on sequence length, with batch size 2.

  >>> elements = [
  ...   [0], [1, 2, 3, 4], [5, 6, 7],
  ...   [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]

  >>> dataset = tf.data.Dataset.from_generator(
  ...     lambda: elements, tf.int64, output_shapes=[None])

  >>> dataset = dataset.apply(
  ...     tf.data.experimental.bucket_by_sequence_length(
  ...         element_length_func=lambda elem: tf.shape(elem)[0],
  ...         bucket_boundaries=[3, 5],
  ...         bucket_batch_sizes=[2, 2, 2]))

  >>> for elem in dataset.as_numpy_iterator():
  ...   print(elem)
  [[1 2 3 4]
   [5 6 7 0]]
  [[ 7  8  9 10 11  0]
   [13 14 15 16 19 20]]
  [[ 0  0]
   [21 22]]

  Args:
    element_length_func: function from element in `Dataset` to `tf.int32`; the
      returned length decides which bucket the element is placed in.
    bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
    bucket_batch_sizes: `list<int>`, batch size per bucket. Its length must be
      `len(bucket_boundaries) + 1`.
    padded_shapes: Nested structure of `tf.TensorShape` forwarded to
      `tf.data.Dataset.padded_batch`. When omitted, the shapes of the grouped
      dataset are used, so dimensions of unknown size are padded to the maximum
      length found in each batch.
    padding_values: Values to pad with, forwarded to
      `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
    pad_to_bucket_boundary: bool. If `False`, dimensions of unknown size are
      padded to the maximum length in the batch. If `True`, they are padded to
      the bucket boundary minus 1 (the maximum length in the bucket); in that
      case every element must have length less than `max(bucket_boundaries)`,
      which is checked with an assertion.
    no_padding: `bool`. If `True`, batches are formed without padding (features
      need to be either of type `tf.sparse.SparseTensor` or of same shape).
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in the case it has fewer than
      `batch_size` elements; by default the smaller batch is kept.

  Returns:
    A `Dataset` transformation function suitable for passing to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
  """
  with ops.name_scope("bucket_by_seq_length"):
    # N boundaries delimit N + 1 buckets, each needing its own batch size.
    if len(bucket_batch_sizes) - 1 != len(bucket_boundaries):
      raise ValueError(
          "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")

    batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)

    def element_to_bucket_id(*args):
      """Maps an element to the int64 id of the bucket holding its length."""
      seq_length = element_length_func(*args)

      int32_info = np.iinfo(np.int32)
      edges = list(bucket_boundaries)
      lower_edges = [int32_info.min] + edges
      upper_edges = edges + [int32_info.max]
      # Bucket i covers the half-open range [lower_edges[i], upper_edges[i]).
      at_or_above_lower = math_ops.less_equal(lower_edges, seq_length)
      below_upper = math_ops.less(seq_length, upper_edges)
      in_bucket = math_ops.logical_and(at_or_above_lower, below_upper)
      return math_ops.reduce_min(array_ops.where(in_bucket))

    def window_size_fn(bucket_id):
      # A window for a bucket holds exactly one batch of that bucket.
      return batch_sizes[bucket_id]

    def make_padded_shapes(shapes, none_filler=None):
      """Replaces unknown dimensions in `shapes` with `none_filler`."""

      def fill_unknown_dims(shape):
        return [
            none_filler if tensor_shape.dimension_value(dim) is None else dim
            for dim in tensor_shape.TensorShape(shape)
        ]

      filled = [fill_unknown_dims(shape) for shape in nest.flatten(shapes)]
      return nest.pack_sequence_as(shapes, filled)

    def batching_fn(bucket_id, grouped_dataset):
      """Batches the elements of one bucket window."""
      batch_size = window_size_fn(bucket_id)
      if no_padding:
        return grouped_dataset.batch(batch_size, drop_remainder=drop_remainder)

      none_filler = None
      if pad_to_bucket_boundary:
        # The last bucket has no upper boundary to pad to, so elements landing
        # in it are rejected.
        last_bucket_id = constant_op.constant(
            len(bucket_batch_sizes) - 1, dtype=dtypes.int64)
        in_bounded_bucket = check_ops.assert_less(
            bucket_id,
            last_bucket_id,
            message=("When pad_to_bucket_boundary=True, elements must have "
                     "length < max(bucket_boundaries)."))
        with ops.control_dependencies([in_bounded_bucket]):
          boundaries = constant_op.constant(
              bucket_boundaries, dtype=dtypes.int64)
          none_filler = boundaries[bucket_id] - 1

      input_shapes = dataset_ops.get_legacy_output_shapes(grouped_dataset)
      shapes = make_padded_shapes(
          padded_shapes or input_shapes, none_filler=none_filler)
      return grouped_dataset.padded_batch(
          batch_size, shapes, padding_values, drop_remainder=drop_remainder)

    def _apply_fn(dataset):
      """Groups `dataset` by bucket id and batches each bucket window."""
      bucketing = group_by_window(
          element_to_bucket_id, batching_fn, window_size_func=window_size_fn)
      return dataset.apply(bucketing)

    return _apply_fn
270
271
class _GroupByReducerDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that groups its input and performs a reduction."""

  def __init__(self, input_dataset, key_func, reducer):
    """See `group_by_reducer()` for details.

    Args:
      input_dataset: The dataset whose elements are grouped and reduced.
      key_func: Function mapping an input element to a scalar `tf.int64` key.
      reducer: Object exposing `init_func`, `reduce_func` and `finalize_func`.
    """
    self._input_dataset = input_dataset
    # The wrappers are built in dependency order: the reduce function needs the
    # state structure produced by the init function, and the finalize function
    # needs the (possibly relaxed) state structure settled by the reduce
    # function.
    self._make_key_func(key_func, input_dataset)
    self._make_init_func(reducer.init_func)
    self._make_reduce_func(reducer.reduce_func, input_dataset)
    self._make_finalize_func(reducer.finalize_func)
    variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._key_func.function.captured_inputs,
        self._init_func.function.captured_inputs,
        self._reduce_func.function.captured_inputs,
        self._finalize_func.function.captured_inputs,
        key_func=self._key_func.function,
        init_func=self._init_func.function,
        reduce_func=self._reduce_func.function,
        finalize_func=self._finalize_func.function,
        **self._flat_structure)
    super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor)

  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping defun for key_func.

    Args:
      key_func: Function mapping an input element to a key.
      input_dataset: Dataset whose element structure `key_func` receives.

    Raises:
      ValueError: if the output of `key_func` is not compatible with a scalar
        `tf.int64` tensor.
    """
    self._key_func = dataset_ops.StructuredFunctionWrapper(
        key_func, self._transformation_name(), dataset=input_dataset)
    if not self._key_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.int64)):
      raise ValueError(
          "`key_func` must return a single tf.int64 tensor. "
          "Got type=%s and shape=%s"
          % (self._key_func.output_types, self._key_func.output_shapes))

  def _make_init_func(self, init_func):
    """Make wrapping defun for init_func.

    Args:
      init_func: Function mapping a scalar `tf.int64` key to the initial state.
    """
    self._init_func = dataset_ops.StructuredFunctionWrapper(
        init_func,
        self._transformation_name(),
        input_structure=tensor_spec.TensorSpec([], dtypes.int64))

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping defun for reduce_func.

    The reduce function is traced repeatedly, each time with a state structure
    whose shapes are relaxed just enough to also cover the shapes it returned,
    until the state shapes stop changing. The final trace is stored in
    `self._reduce_func` and added to the default graph, and the settled state
    structure is stored in `self._state_structure`.

    Args:
      reduce_func: Function mapping `(state, input element)` to the new state.
      input_dataset: Dataset whose element structure `reduce_func` receives.

    Raises:
      TypeError: if the classes or dtypes of the new state differ from those
        of the initial state returned by the init function.
    """

    # Iteratively rerun the reduce function until reaching a fixed point on
    # `self._state_structure`.
    self._state_structure = self._init_func.output_structure
    state_types = self._init_func.output_types
    state_shapes = self._init_func.output_shapes
    state_classes = self._init_func.output_classes
    need_to_rerun = True
    while need_to_rerun:

      # Intermediate traces are kept out of the graph; only the final one is
      # added after the loop.
      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          reduce_func,
          self._transformation_name(),
          input_structure=(self._state_structure, input_dataset.element_spec),
          add_to_graph=False)

      # Extract and validate class information from the returned values.
      for new_state_class, state_class in zip(
          nest.flatten(wrapped_func.output_classes),
          nest.flatten(state_classes)):
        if not issubclass(new_state_class, state_class):
          # Report the initial state's classes. This message previously read
          # `self._state_classes`, which is never assigned on this class and
          # would have raised AttributeError in place of this TypeError.
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._init_func.output_classes, wrapped_func.output_classes))

      # Extract and validate type information from the returned values.
      for new_state_type, state_type in zip(
          nest.flatten(wrapped_func.output_types), nest.flatten(state_types)):
        if new_state_type != state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._init_func.output_types, wrapped_func.output_types))

      # Extract shape information from the returned values.
      flat_state_shapes = nest.flatten(state_shapes)
      flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # A rerun is needed as soon as any state shape had to be relaxed; a state
      # shape of unknown rank cannot be relaxed any further.
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        state_shapes = nest.pack_sequence_as(
            self._init_func.output_shapes, weakened_state_shapes)
        self._state_structure = structure.convert_legacy_structure(
            state_types, state_shapes, state_classes)

    self._reduce_func = wrapped_func
    self._reduce_func.function.add_to_graph(ops.get_default_graph())

  def _make_finalize_func(self, finalize_func):
    """Make wrapping defun for finalize_func.

    Args:
      finalize_func: Function mapping the final state to the output value. It
        is traced against the state structure settled by `_make_reduce_func`.
    """
    self._finalize_func = dataset_ops.StructuredFunctionWrapper(
        finalize_func, self._transformation_name(),
        input_structure=self._state_structure)

  @property
  def element_spec(self):
    """The element structure, i.e. the output structure of the finalizer."""
    return self._finalize_func.output_structure

  def _functions(self):
    """Returns the wrapped key, init, reduce and finalize functions."""
    return [
        self._key_func, self._init_func, self._reduce_func, self._finalize_func
    ]

  def _transformation_name(self):
    """Returns the public API name passed to the function wrappers."""
    return "tf.data.experimental.group_by_reducer()"
393
394
class _GroupByWindowDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that groups its input by key and reduces per-key windows."""

  def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
    """See `group_by_window()` for details."""
    self._input_dataset = input_dataset
    self._make_key_func(key_func, input_dataset)
    self._make_reduce_func(reduce_func, input_dataset)
    self._make_window_size_func(window_size_func)

    key_fn = self._key_func.function
    reduce_fn = self._reduce_func.function
    size_fn = self._window_size_func.function
    variant_tensor = ged_ops.group_by_window_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        key_fn.captured_inputs,
        reduce_fn.captured_inputs,
        size_fn.captured_inputs,
        key_func=key_fn,
        reduce_func=reduce_fn,
        window_size_func=size_fn,
        **self._flat_structure)
    super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)

  def _make_key_func(self, key_func, input_dataset):
    """Wraps `key_func`, converting its result to a `tf.int64` tensor."""

    def key_func_wrapper(*args):
      return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)

    wrapped = dataset_ops.StructuredFunctionWrapper(
        key_func_wrapper, self._transformation_name(), dataset=input_dataset)
    self._key_func = wrapped
    scalar_int64 = tensor_spec.TensorSpec([], dtypes.int64)
    if not wrapped.output_structure.is_compatible_with(scalar_int64):
      raise ValueError(
          "`key_func` must return a single tf.int64 scalar tensor.")

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Wraps `reduce_func` and records the element spec of its result."""
    # The reduce function receives the scalar key together with a dataset
    # holding the window of input elements for that key.
    window_spec = dataset_ops.DatasetSpec(input_dataset.element_spec)
    key_spec = tensor_spec.TensorSpec([], dtypes.int64)
    wrapped = dataset_ops.StructuredFunctionWrapper(
        reduce_func,
        self._transformation_name(),
        input_structure=(key_spec, window_spec))
    self._reduce_func = wrapped
    result_spec = wrapped.output_structure
    if not isinstance(result_spec, dataset_ops.DatasetSpec):
      raise TypeError("`reduce_func` must return a `Dataset` object.")
    # Elements of this dataset are the elements of the datasets returned by
    # `reduce_func`.
    self._element_spec = result_spec._element_spec  # pylint: disable=protected-access

  def _make_window_size_func(self, window_size_func):
    """Wraps `window_size_func`, converting its result to `tf.int64`."""

    def window_size_func_wrapper(key):
      return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)

    scalar_int64 = tensor_spec.TensorSpec([], dtypes.int64)
    wrapped = dataset_ops.StructuredFunctionWrapper(
        window_size_func_wrapper,
        self._transformation_name(),
        input_structure=scalar_int64)
    self._window_size_func = wrapped
    if not wrapped.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.int64)):
      raise ValueError(
          "`window_size_func` must return a single tf.int64 scalar tensor.")

  @property
  def element_spec(self):
    """The element structure of the datasets returned by `reduce_func`."""
    return self._element_spec

  def _functions(self):
    """Returns the wrapped key, reduce and window-size functions."""
    return [self._key_func, self._reduce_func, self._window_size_func]

  def _transformation_name(self):
    """Returns the public API name passed to the function wrappers."""
    return "tf.data.experimental.group_by_window()"
465
466
@tf_export("data.experimental.Reducer")
class Reducer(object):
  """Bundles the three functions that define a reduction over elements.

  The functions are:
    1) initialization function: key => initial state
    2) reduce function: (old state, input) => new state
    3) finalization function: state => result
  """

  def __init__(self, init_func, reduce_func, finalize_func):
    # The functions are stored as given; they are only read back through the
    # properties below.
    self._init_func, self._reduce_func, self._finalize_func = (
        init_func, reduce_func, finalize_func)

  @property
  def init_func(self):
    """The initialization function passed to the constructor."""
    return self._init_func

  @property
  def reduce_func(self):
    """The reduce function passed to the constructor."""
    return self._reduce_func

  @property
  def finalize_func(self):
    """The finalization function passed to the constructor."""
    return self._finalize_func
493