1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Non-deterministic dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python import tf2
21from tensorflow.python.data.experimental.ops import random_ops
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.ops import readers
24from tensorflow.python.data.util import nest
25from tensorflow.python.data.util import structure
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_spec
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import gen_experimental_dataset_ops
31from tensorflow.python.ops import gen_stateless_random_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.util import deprecation
34from tensorflow.python.util.tf_export import tf_export
35
36
@deprecation.deprecated(
    None,
    "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
    "num_parallel_calls=tf.data.AUTOTUNE)` instead. If sloppy "
    "execution is desired, use `tf.data.Options.experimental_deterministic`.")
@tf_export("data.experimental.parallel_interleave")
def parallel_interleave(map_func,
                        cycle_length,
                        block_length=1,
                        sloppy=False,
                        buffer_output_elements=None,
                        prefetch_input_elements=None):
  """A parallel version of the `Dataset.interleave()` transformation.

  `parallel_interleave()` applies `map_func` to every element of its input,
  producing one nested dataset per element, and emits the elements of those
  nested datasets interleaved with each other. In contrast with
  `tf.data.Dataset.interleave`, it reads from `cycle_length` nested datasets
  in parallel, which raises throughput, particularly when some nested datasets
  are slow (stragglers). The `sloppy` argument can additionally be used to
  improve performance by dropping the requirement that outputs appear in a
  deterministic order, letting the implementation skip nested datasets whose
  elements are not yet available when requested.

  Example usage:

  ```python
  # Preprocess 4 files concurrently.
  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
  dataset = filenames.apply(
      tf.data.experimental.parallel_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4))
  ```

  WARNING: If `sloppy` is `True`, the order of produced elements is not
  deterministic.

  Args:
    map_func: A function mapping a nested structure of tensors to a `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`.
    sloppy: A boolean controlling whether determinism should be traded for
      performance by allowing elements to be produced out of order.  If
      `sloppy` is `None`, the `tf.data.Options.experimental_deterministic`
      dataset option (`True` by default) is used to decide whether to enforce a
      deterministic order.
    buffer_output_elements: The number of elements each iterator being
      interleaved should buffer (similar to the `.prefetch()` transformation for
      each interleaved iterator).
    prefetch_input_elements: The number of input elements to transform to
      iterators before they are needed for interleaving.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  # Everything `ParallelInterleaveDataset` needs after the input dataset, in
  # the positional order it accepts them.
  interleave_args = (map_func, cycle_length, block_length, sloppy,
                     buffer_output_elements, prefetch_input_elements)

  def _transformation(input_dataset):
    return readers.ParallelInterleaveDataset(input_dataset, *interleave_args)

  return _transformation
102
103
class _DirectedInterleaveDataset(dataset_ops.DatasetV2):
  """A substitute for `Dataset.interleave()` on a fixed list of datasets.

  Elements of `selector_input` pick which of the `data_inputs` datasets the
  next output element is read from (see `choose_from_datasets`, whose
  selector is a dataset of scalar `tf.int64` indices into the input list).
  """

  def __init__(self, selector_input, data_inputs):
    """Creates a `_DirectedInterleaveDataset`.

    Args:
      selector_input: A `Dataset` whose elements select the input dataset to
        read each output element from.
      data_inputs: A non-empty iterable of `Dataset`s that all have the same
        legacy output types and classes. Their shapes are relaxed to the most
        specific shape compatible with all of them.

    Raises:
      ValueError: If `data_inputs` is empty.
      TypeError: If the datasets in `data_inputs` do not all have the same
        output types and classes.
    """
    self._selector_input = selector_input
    # Materialize once. Everything below indexes and re-traverses this list,
    # so any iterable (including a one-shot generator) is accepted.
    self._data_inputs = list(data_inputs)
    if not self._data_inputs:
      raise ValueError("Invalid `datasets`. `datasets` should not be empty.")

    first_input = self._data_inputs[0]
    first_output_types = dataset_ops.get_legacy_output_types(first_input)
    first_output_classes = dataset_ops.get_legacy_output_classes(first_input)

    # `start=1` keeps `i` equal to each dataset's position in the full list,
    # which is what the error message reports.
    for i, data_input in enumerate(self._data_inputs[1:], start=1):
      output_types = dataset_ops.get_legacy_output_types(data_input)
      output_classes = dataset_ops.get_legacy_output_classes(data_input)
      if (output_types != first_output_types or
          output_classes != first_output_classes):
        raise TypeError("All datasets must have the same type and class.\n"
                        "dataset 0 vs dataset %s types: %s ; %s\n"
                        "classes: %s ; %s" %
                        (i, first_output_types, output_types,
                         first_output_classes, output_classes))

    # Fold every input's shapes into the most specific shape compatible with
    # all of them, component by component.
    output_shapes = dataset_ops.get_legacy_output_shapes(first_input)
    for data_input in self._data_inputs[1:]:
      output_shapes = nest.pack_sequence_as(output_shapes, [
          ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
              nest.flatten(output_shapes),
              nest.flatten(dataset_ops.get_legacy_output_shapes(data_input)))
      ])

    self._element_spec = structure.convert_legacy_structure(
        first_output_types, output_shapes, first_output_classes)
    # pylint: disable=protected-access
    variant_tensor = gen_experimental_dataset_ops.directed_interleave_dataset(
        self._selector_input._variant_tensor,
        [data_input._variant_tensor for data_input in self._data_inputs],
        **self._flat_structure)
    super(_DirectedInterleaveDataset, self).__init__(variant_tensor)

  def _inputs(self):
    # The selector comes first, followed by the data inputs in order.
    return [self._selector_input] + self._data_inputs

  @property
  def element_spec(self):
    return self._element_spec
149
150
@tf_export("data.experimental.sample_from_datasets", v1=[])
def sample_from_datasets_v2(datasets, weights=None, seed=None):
  """Samples elements at random from the datasets in `datasets`.

  Args:
    datasets: A non-empty list of `tf.data.Dataset` objects with compatible
      structure.
    weights: (Optional.) A list of `len(datasets)` floating-point values where
      `weights[i]` represents the probability with which an element should be
      sampled from `datasets[i]`, or a `tf.data.Dataset` object where each
      element is such a list. Defaults to a uniform distribution across
      `datasets`.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      random seed that will be used to create the distribution. See
      `tf.random.set_seed` for behavior.

  Returns:
    A dataset that interleaves elements from `datasets` at random, according to
    `weights` if provided, otherwise with uniform probability.

  Raises:
    TypeError: If the `datasets` or `weights` arguments have the wrong type.
    ValueError: If `datasets` is empty, or if the `weights` argument is
      specified and does not match the length of `datasets`.
  """
  num_datasets = len(datasets)
  # Fail fast with a clear message instead of an `IndexError` raised later
  # while building the interleave dataset.
  if num_datasets == 0:
    raise ValueError("Invalid `datasets`. `datasets` should not be empty.")

  if not isinstance(weights, dataset_ops.DatasetV2):
    if weights is None:
      # Select inputs with uniform probability: one row of equal logits.
      logits = [[1.0] * num_datasets]

    else:
      # Use the given `weights` as the probability of choosing the respective
      # input.
      weights = ops.convert_to_tensor(weights, name="weights")
      if weights.dtype not in (dtypes.float32, dtypes.float64):
        raise TypeError("`weights` must be convertible to a tensor of "
                        "`tf.float32` or `tf.float64` elements.")
      if not weights.shape.is_compatible_with([num_datasets]):
        raise ValueError(
            "`weights` must be a vector of length `len(datasets)`.")

      # The `stateless_multinomial()` op expects log-probabilities, as opposed
      # to weights. `expand_dims` adds the leading batch dimension of size 1,
      # matching the single-row layout of the uniform case above.
      logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)

    # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
    # is a `Dataset`, it is possible that evaluating it has a side effect the
    # user depends on.
    if num_datasets == 1:
      return datasets[0]

    def select_dataset_constant_logits(seed):
      # One sample from a single row of logits has shape [1, 1]; squeezing
      # both axes yields the scalar index of the dataset to read from.
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

    # Consecutive random values are batched in pairs to form the 2-element
    # seed consumed by the stateless sampler for each selection.
    selector_input = dataset_ops.MapDataset(
        random_ops.RandomDataset(seed).batch(2),
        select_dataset_constant_logits,
        use_inter_op_parallelism=False)

  else:
    # Use each element of the given `weights` dataset as the probability of
    # choosing the respective input.

    # The `stateless_multinomial()` op expects log-probabilities, as opposed to
    # weights. `*p` packs the element's components into a tuple, which adds a
    # leading dimension before `log` (presumably yielding [1, num_datasets]
    # logits for single-vector elements -- TODO confirm for other structures).
    logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))

    def select_dataset_varying_logits(logits, seed):
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

    logits_and_seeds = dataset_ops.Dataset.zip(
        (logits_ds, random_ops.RandomDataset(seed).batch(2)))
    selector_input = dataset_ops.MapDataset(
        logits_and_seeds,
        select_dataset_varying_logits,
        use_inter_op_parallelism=False)

  return _DirectedInterleaveDataset(selector_input, datasets)
233
234
@tf_export(v1=["data.experimental.sample_from_datasets"])
def sample_from_datasets_v1(datasets, weights=None, seed=None):
  # V1 endpoint: build the V2 dataset, then adapt it to the V1 `Dataset` API.
  # The public docstring is copied from the V2 function by the assignment
  # that follows this definition.
  v2_dataset = sample_from_datasets_v2(datasets, weights, seed)
  return dataset_ops.DatasetV1Adapter(v2_dataset)


sample_from_datasets_v1.__doc__ = sample_from_datasets_v2.__doc__
240
241
@tf_export("data.experimental.choose_from_datasets", v1=[])
def choose_from_datasets_v2(datasets, choice_dataset):
  """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A non-empty list of `tf.data.Dataset` objects with compatible
      structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.

  Returns:
    A dataset that interleaves elements from `datasets` according to the values
    of `choice_dataset`.

  Raises:
    TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
      type.
    ValueError: If `datasets` is empty.
  """
  # Reject an empty input upfront rather than failing later with an
  # `IndexError` while building the interleave dataset.
  if not datasets:
    raise ValueError("Invalid `datasets`. `datasets` should not be empty.")
  # Each choice must be a scalar int64 index into `datasets`.
  if not structure.are_compatible(choice_dataset.element_spec,
                                  tensor_spec.TensorSpec([], dtypes.int64)):
    raise TypeError("`choice_dataset` must be a dataset of scalar "
                    "`tf.int64` tensors.")
  return _DirectedInterleaveDataset(choice_dataset, datasets)
283
284
@tf_export(v1=["data.experimental.choose_from_datasets"])
def choose_from_datasets_v1(datasets, choice_dataset):
  # V1 endpoint: build the V2 dataset, then adapt it to the V1 `Dataset` API.
  # The public docstring is copied from the V2 function by the assignment
  # that follows this definition.
  v2_dataset = choose_from_datasets_v2(datasets, choice_dataset)
  return dataset_ops.DatasetV1Adapter(v2_dataset)


choose_from_datasets_v1.__doc__ = choose_from_datasets_v2.__doc__
290
291
# Bind the unversioned module-level names to the endpoints that match the
# active TensorFlow API version (V2 when `tf2.enabled()`, otherwise V1).
choose_from_datasets, sample_from_datasets = (
    (choose_from_datasets_v2, sample_from_datasets_v2) if tf2.enabled() else
    (choose_from_datasets_v1, sample_from_datasets_v1))
298