1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Batching dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.data.util import convert
22from tensorflow.python.data.util import nest
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.framework import tensor_util
29from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
30from tensorflow.python.ops.ragged import ragged_tensor
31from tensorflow.python.util import deprecation
32from tensorflow.python.util.tf_export import tf_export
33
34
@tf_export("data.experimental.dense_to_ragged_batch")
def dense_to_ragged_batch(batch_size,
                          drop_remainder=False,
                          row_splits_dtype=dtypes.int64):
  """A transformation that batches ragged elements into `tf.RaggedTensor`s.

  Multiple consecutive elements of the input dataset are combined into a
  single element.

  As with `tf.data.Dataset.batch`, every component of the resulting element
  gains an additional outer dimension of size `batch_size` (or
  `N % batch_size` for the last element when `batch_size` does not evenly
  divide the number of input elements `N` and `drop_remainder` is `False`).
  Set `drop_remainder` to `True` if your program depends on all batches having
  the same outer dimension; the smaller trailing batch is then not produced.

  In contrast to `tf.data.Dataset.batch`, the input elements being batched may
  have different shapes:

  *  A `tf.Tensor` element whose static `tf.TensorShape` is fully defined is
     batched as normal.
  *  A `tf.Tensor` element whose static `tf.TensorShape` contains one or more
     axes with unknown size (i.e., `shape[i]=None`) yields an output that
     contains a `tf.RaggedTensor` that is ragged up to any of such
     dimensions.
  *  A `tf.RaggedTensor` element, or an element of any other type, is batched
     as normal.

  Example:

  >>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
  >>> dataset = dataset.map(lambda x: tf.range(x))
  >>> dataset.element_spec.shape
  TensorShape([None])
  >>> dataset = dataset.apply(
  ...     tf.data.experimental.dense_to_ragged_batch(batch_size=2))
  >>> for batch in dataset:
  ...   print(batch)
  <tf.RaggedTensor [[], [0]]>
  <tf.RaggedTensor [[0, 1], [0, 1, 2]]>
  <tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in the case it has fewer than
      `batch_size` elements; the default behavior is not to drop the smaller
      batch.
    row_splits_dtype: The dtype that should be used for the `row_splits` of any
      new ragged tensors.  Existing `tf.RaggedTensor` elements do not have their
      row_splits dtype changed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _ragged_batch_fn(dataset):
    """Re-encodes `dataset` via `_DenseToRaggedDataset`, then batches it."""
    return dataset_ops.BatchDataset(
        _DenseToRaggedDataset(dataset, row_splits_dtype),
        batch_size=batch_size,
        drop_remainder=drop_remainder)

  return _ragged_batch_fn
99
100
@tf_export("data.experimental.dense_to_sparse_batch")
def dense_to_sparse_batch(batch_size, row_shape):
  """A transformation that batches ragged elements into `tf.sparse.SparseTensor`s.

  Similarly to `Dataset.padded_batch()`, this transformation takes multiple
  consecutive elements of the dataset, whose shapes may differ, and combines
  them into a single element. That element has three components (`indices`,
  `values`, and `dense_shape`), which together form a
  `tf.sparse.SparseTensor` representing the same data. `row_shape` gives the
  dense shape of each row of the resulting `tf.sparse.SparseTensor`; the
  effective batch size is prepended to it. For example:

  ```python
  # NOTE: The following examples use `{ ... }` to represent the
  # contents of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.apply(tf.data.experimental.dense_to_sparse_batch(
      batch_size=2, row_shape=[6])) ==
  {
      ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # indices
       ['a', 'b', 'c', 'a', 'b'],                 # values
       [2, 6]),                                   # dense_shape
      ([[0, 0], [0, 1], [0, 2], [0, 3]],
       ['a', 'b', 'c', 'd'],
       [1, 6])
  }
  ```

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like object
      representing the equivalent dense shape of a row in the resulting
      `tf.sparse.SparseTensor`. Each element of this dataset must have the same
      rank as `row_shape`, and must have size less than or equal to `row_shape`
      in each dimension.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _sparse_batch_fn(dataset):
    """Wraps `dataset` in a `_DenseToSparseBatchDataset`."""
    return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)

  return _sparse_batch_fn
149
150
@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch()`")
@tf_export(v1=["data.experimental.map_and_batch_with_legacy_function"])
def map_and_batch_with_legacy_function(map_func,
                                       batch_size,
                                       num_parallel_batches=None,
                                       drop_remainder=False,
                                       num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  NOTE: This is an escape hatch for existing uses of `map_and_batch` that do not
  work with V2 functions. New uses are strongly discouraged and existing uses
  should migrate to `map_and_batch`, because this function is exported only
  under the V1 API.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel (or `batch_size` elements when `num_parallel_batches` is not
      specified either). If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """

  if num_parallel_batches is not None and num_parallel_calls is not None:
    raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
                     "arguments are mutually exclusive.")

  if num_parallel_calls is None:
    # Derive the parallelism from the batch parameters; one batch's worth of
    # elements when `num_parallel_batches` was not given either.
    if num_parallel_batches is None:
      num_parallel_calls = batch_size
    else:
      num_parallel_calls = batch_size * num_parallel_batches

  def _apply_fn(dataset):
    """Wraps `dataset` in a `_MapAndBatchDataset` using a legacy function."""
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder,
                               use_legacy_function=True)

  return _apply_fn
205
206
@deprecation.deprecated(
    None,
    "Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by "
    "`tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data "
    "optimizations will take care of using the fused implementation.")
@tf_export("data.experimental.map_and_batch")
def map_and_batch(map_func,
                  batch_size,
                  num_parallel_batches=None,
                  drop_remainder=False,
                  num_parallel_calls=None):
  """Fused implementation of `map` and `batch`.

  Applies `map_func` to `batch_size` consecutive elements of this dataset and
  combines the results into a batch. Functionally this is the same as `map`
  followed by `batch`. This API is temporary and deprecated, since input
  pipeline optimization now fuses consecutive `map` and `batch` operations
  automatically.

  Args:
    map_func: A function mapping a nested structure of tensors to another
      nested structure of tensors.
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
      representing the number of batches to create in parallel. On one hand,
      higher values can help mitigate the effect of stragglers. On the other
      hand, higher values can increase contention if CPU is scarce.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in case its size is smaller than
      desired; the default behavior is not to drop the smaller batch.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, `batch_size * num_parallel_batches` elements will be processed
      in parallel. If the value `tf.data.AUTOTUNE` is used, then
      the number of parallel calls is set dynamically based on available CPU.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
      specified.
  """

  if num_parallel_calls is None:
    # No explicit element-level parallelism: derive it from the batch
    # parameters, falling back to a single batch's worth of elements.
    if num_parallel_batches is None:
      num_parallel_calls = batch_size
    else:
      num_parallel_calls = batch_size * num_parallel_batches
  elif num_parallel_batches is not None:
    raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
                     "arguments are mutually exclusive.")

  def _fused_map_and_batch_fn(dataset):
    """Wraps `dataset` in a `_MapAndBatchDataset`."""
    return _MapAndBatchDataset(dataset, map_func, batch_size,
                               num_parallel_calls, drop_remainder)

  return _fused_map_and_batch_fn
265
266
@deprecation.deprecated(None, "Use `tf.data.Dataset.unbatch()`.")
@tf_export("data.experimental.unbatch")
def unbatch():
  """Splits elements of a dataset into multiple elements on the batch dimension.

  If the elements of the dataset have shape `[B, a0, a1, ...]`, where `B` may
  differ from one input element to the next, then the unbatched dataset
  contains, for each such element, `B` consecutive elements of shape
  `[a0, a1, ...]`.

  ```python
  # NOTE: The following example uses `{ ... }` to represent the contents
  # of a dataset.
  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

  a.unbatch() == {
      'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _unbatch_fn(dataset):
    """Delegates to the dataset's own `unbatch()` method."""
    return dataset.unbatch()

  return _unbatch_fn
295
296
class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that batches ragged dense elements into `tf.sparse.SparseTensor`s."""

  def __init__(self, input_dataset, batch_size, row_shape):
    """See `tf.data.experimental.dense_to_sparse_batch()` for more details.

    Args:
      input_dataset: The dataset to batch. Its elements must have a single
        component, i.e. its legacy output types must be a single `DType`.
      batch_size: The number of consecutive elements to combine in a single
        batch; forwarded unchanged to the underlying dataset op.
      row_shape: The dense shape of one row of the resulting sparse tensor.
        An unknown leading batch dimension is prepended to it to form the
        shape of the output `SparseTensorSpec`.

    Raises:
      TypeError: If the elements of `input_dataset` do not consist of a single
        component.
    """
    output_types = dataset_ops.get_legacy_output_types(input_dataset)
    if not isinstance(output_types, dtypes.DType):
      # A tuple operand to `%` is unpacked into separate format arguments, so
      # the (possibly tuple-valued) structure is wrapped in a 1-tuple to make
      # `%r` render it as a whole instead of failing or dropping the nesting.
      raise TypeError("DenseToSparseDataset requires an input whose elements "
                      "have a single component, whereas the input has %r." %
                      (output_types,))
    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._row_shape = row_shape
    # NOTE(review): `_flat_structure` below is presumably derived from
    # `element_spec`, so `_element_spec` is assigned before the op is built.
    self._element_spec = sparse_tensor.SparseTensorSpec(
        tensor_shape.TensorShape([None]).concatenate(self._row_shape),
        output_types)

    variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._batch_size,
        row_shape=convert.partial_shape_to_tensor(self._row_shape),
        **self._flat_structure)
    super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                     variant_tensor)

  @property
  def element_spec(self):
    """The `SparseTensorSpec` describing each batched element."""
    return self._element_spec
325
326
class _MapAndBatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over a batch of elements."""

  def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
               drop_remainder, use_legacy_function=False):
    """Builds the fused map-and-batch dataset op.

    Args:
      input_dataset: The dataset whose elements are mapped and batched.
      map_func: The function applied to each element of `input_dataset`.
      batch_size: Number of mapped elements per batch; converted to a
        `tf.int64` tensor.
      num_parallel_calls: Number of elements to process in parallel; converted
        to a `tf.int64` tensor.
      drop_remainder: Whether a smaller final batch is dropped; converted to a
        `tf.bool` tensor.
      use_legacy_function: Forwarded to `StructuredFunctionWrapper` when
        wrapping `map_func`.
    """
    self._input_dataset = input_dataset

    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        "tf.data.experimental.map_and_batch()",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._batch_size_t = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._num_parallel_calls_t = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    self._drop_remainder_t = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    # The static batch dimension is only recorded when the remainder is known
    # to be dropped. A `drop_remainder` constant value of `None` (unknown
    # statically) or `False` (remainder explicitly retained) leaves the batch
    # dimension unknown.
    if tensor_util.constant_value(self._drop_remainder_t):
      static_batch_size = tensor_util.constant_value(self._batch_size_t)
    else:
      static_batch_size = None

    self._element_spec = nest.map_structure(
        lambda component_spec: component_spec._batch(static_batch_size),  # pylint: disable=protected-access
        self._map_func.output_structure)

    fused_variant = ged_ops.map_and_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        batch_size=self._batch_size_t,
        num_parallel_calls=self._num_parallel_calls_t,
        drop_remainder=self._drop_remainder_t,
        preserve_cardinality=True,
        **self._flat_structure)
    super(_MapAndBatchDataset, self).__init__(input_dataset, fused_variant)

  def _functions(self):
    """Returns the wrapped map function as a single-item list."""
    return [self._map_func]

  @property
  def element_spec(self):
    """The batched structure of the map function's outputs."""
    return self._element_spec
378
379
class _DenseToRaggedDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that re-encodes partially-shaped dense inputs as ragged.

  In particular:

  * Any tf.Tensor element whose static shape has a known rank but is not fully
    defined is encoded as a ragged tensor whose ragged_rank is the index of the
    last axis of unknown size.  This allows tensors with varying shape to be
    batched together.
  * Any other elements are left as-is.
  """

  def __init__(self, input_dataset, row_splits_dtype):
    """Constructs a new _DenseToRaggedDataset.

    Args:
      input_dataset: The dataset whose tf.Tensor elements should be made ragged.
      row_splits_dtype: The dtype that should be used for the `row_splits` of
        any new ragged tensors.  Existing `tf.RaggedTensor` elements do *not*
        have their row_splits dtype changed.
    """

    def to_ragged_spec(spec):
      """Maps a partially-defined TensorSpec to a RaggedTensorSpec."""
      if not isinstance(spec, tensor_spec.TensorSpec):
        return spec
      shape = spec.shape
      if shape.rank is None or shape.is_fully_defined():
        return spec
      # Ragged up to (and including) the last axis whose size is unknown.
      last_unknown_axis = max(
          axis for axis, size in enumerate(shape.as_list()) if size is None)
      return ragged_tensor.RaggedTensorSpec(
          shape=shape,
          dtype=spec.dtype,
          ragged_rank=last_unknown_axis,
          row_splits_dtype=row_splits_dtype)

    # Every TensorSpec in the input structure that qualifies is swapped for the
    # matching RaggedTensorSpec.
    self._structure = nest.map_structure(to_ragged_spec,
                                         input_dataset.element_spec)

    def to_ragged_variant(value):
      """Re-encodes a partially-shaped Tensor as a variant-encoded ragged one."""
      if not isinstance(value, ops.Tensor):
        return value
      if value.shape.rank is None or value.shape.is_fully_defined():
        return value
      ragged_spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
      ragged_rank = ragged_spec._ragged_rank  # pylint: disable=protected-access
      if ragged_rank > 0:
        value = ragged_tensor.RaggedTensor.from_tensor(
            value, ragged_rank=ragged_rank)
      # Because the structure above advertises a RaggedTensorSpec, this
      # variant-encoded tensor is later decoded with
      # RaggedTensorSpec._from_tensor_list.
      return ragged_spec._to_tensor_list(value)[0]  # pylint: disable=protected-access

    # `dataset.map` unpacks tuple elements into separate arguments, so they are
    # gathered back into a single structure before being re-encoded.
    unpacks_args = dataset_ops._should_unpack_args(input_dataset.element_spec)  # pylint: disable=protected-access
    if unpacks_args:
      map_fn = lambda *value: nest.map_structure(to_ragged_variant, value)
    else:
      map_fn = lambda value: nest.map_structure(to_ragged_variant, value)

    self._mapped_dataset = input_dataset.map(map_fn)

    mapped_variant = self._mapped_dataset._variant_tensor  # pylint: disable=protected-access
    super(_DenseToRaggedDataset, self).__init__(input_dataset, mapped_variant)

  @property
  def element_spec(self):
    """The input structure with qualifying TensorSpecs made ragged."""
    return self._structure
453