# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Non-deterministic dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.experimental.ops import random_ops
21from tensorflow.python.data.ops import dataset_ops
22from tensorflow.python.data.ops import readers
23from tensorflow.python.data.util import nest
24from tensorflow.python.data.util import structure
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import gen_experimental_dataset_ops
29from tensorflow.python.ops import gen_stateless_random_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.util import deprecation
32from tensorflow.python.util.tf_export import tf_export
33
34
35@deprecation.deprecated(
36    None,
37    "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
38    "num_parallel_calls=tf.data.experimental.AUTOTUNE)` instead. If sloppy "
39    "execution is desired, use `tf.data.Options.experimental_determinstic`.")
@tf_export("data.experimental.parallel_interleave")
def parallel_interleave(map_func,
                        cycle_length,
                        block_length=1,
                        sloppy=False,
                        buffer_output_elements=None,
                        prefetch_input_elements=None):
  """A parallel version of the `Dataset.interleave()` transformation.

  `parallel_interleave()` maps `map_func` across its input to produce nested
  datasets, and outputs their elements interleaved. Unlike
  `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested
  datasets in parallel, which increases throughput, especially in the
  presence of stragglers. Furthermore, the `sloppy` argument can be used to
  improve performance by relaxing the requirement that outputs are produced
  in a deterministic order, and by allowing the implementation to skip over
  nested datasets whose elements are not readily available when requested.

  Example usage:

  ```python
  # Preprocess 4 files concurrently.
  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
  dataset = filenames.apply(
      tf.data.experimental.parallel_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4))
  ```
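
  The deprecation notice above points to the core `interleave` transformation
  as the replacement; a rough, non-authoritative sketch of the equivalent call
  would be:

  ```python
  dataset = filenames.interleave(
      tf.data.TFRecordDataset,
      cycle_length=4,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ```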

  WARNING: If `sloppy` is `True`, the order of produced elements is not
  deterministic.

  Args:
    map_func: A function mapping a nested structure of tensors to a `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`.
    sloppy: If false, elements are produced in deterministic order. Otherwise,
      the implementation is allowed, for the sake of expediency, to produce
      elements in a non-deterministic order.
    buffer_output_elements: The number of elements each iterator being
      interleaved should buffer (similar to the `.prefetch()` transformation for
      each interleaved iterator).
    prefetch_input_elements: The number of input elements to transform to
      iterators before they are needed for interleaving.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  def _apply_fn(dataset):
    return readers.ParallelInterleaveDataset(
        dataset, map_func, cycle_length, block_length, sloppy,
        buffer_output_elements, prefetch_input_elements)

  return _apply_fn


class _DirectedInterleaveDataset(dataset_ops.Dataset):
  """A substitute for `Dataset.interleave()` on a fixed list of datasets.

  `selector_input` must be a dataset of scalar `tf.int64` values; each value
  selects the element of `data_inputs` from which the next output element is
  drawn.
  """

  def __init__(self, selector_input, data_inputs):
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)

    first_output_types = dataset_ops.get_legacy_output_types(data_inputs[0])
    first_output_classes = dataset_ops.get_legacy_output_classes(data_inputs[0])

    for data_input in data_inputs[1:]:
      if (dataset_ops.get_legacy_output_types(data_input) != first_output_types
          or dataset_ops.get_legacy_output_classes(data_input)
          != first_output_classes):
        raise TypeError("All datasets must have the same type and class.")

    output_shapes = dataset_ops.get_legacy_output_shapes(self._data_inputs[0])
    for data_input in self._data_inputs[1:]:
      output_shapes = nest.pack_sequence_as(output_shapes, [
          ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
              nest.flatten(output_shapes),
              nest.flatten(dataset_ops.get_legacy_output_shapes(data_input)))
      ])

    self._structure = structure.convert_legacy_structure(
        first_output_types, output_shapes, first_output_classes)
    super(_DirectedInterleaveDataset, self).__init__()

  def _as_variant_tensor(self):
    # pylint: disable=protected-access
    return (
        gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
            self._selector_input._variant_tensor,
            [data_input._variant_tensor for data_input in self._data_inputs],
            **dataset_ops.flat_structure(self)))
    # pylint: enable=protected-access

  def _inputs(self):
    return [self._selector_input] + self._data_inputs

  @property
  def _element_structure(self):
    return self._structure


@tf_export("data.experimental.sample_from_datasets", v1=[])
def sample_from_datasets_v2(datasets, weights=None, seed=None):
  """Samples elements at random from the datasets in `datasets`.

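  Example usage (a minimal sketch; the dataset contents and weights shown are
  illustrative only):

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat()]

  # Emit "foo" with probability 0.8 and "bar" with probability 0.2.
  dataset = tf.data.experimental.sample_from_datasets(
      datasets, weights=[0.8, 0.2])
  ```
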
  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    weights: (Optional.) A list of `len(datasets)` floating-point values where
      `weights[i]` represents the probability with which an element should be
      sampled from `datasets[i]`, or a `tf.data.Dataset` object where each
      element is such a list. Defaults to a uniform distribution across
      `datasets`.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      random seed that will be used to create the distribution. See
      `tf.set_random_seed` for behavior.

  Returns:
    A dataset that interleaves elements from `datasets` at random, according to
    `weights` if provided, otherwise with uniform probability.

  Raises:
    TypeError: If the `datasets` or `weights` arguments have the wrong type.
    ValueError: If the `weights` argument is specified and does not match the
      length of `datasets`.
  """
  num_datasets = len(datasets)
  if not isinstance(weights, dataset_ops.DatasetV2):
    if weights is None:
      # Select inputs with uniform probability.
      logits = [[1.0] * num_datasets]
    else:
      # Use the given `weights` as the probability of choosing the respective
      # input.
      weights = ops.convert_to_tensor(weights, name="weights")
      if weights.dtype not in (dtypes.float32, dtypes.float64):
        raise TypeError("`weights` must be convertible to a tensor of "
                        "`tf.float32` or `tf.float64` elements.")
      if not weights.shape.is_compatible_with([num_datasets]):
        raise ValueError(
            "`weights` must be a vector of length `len(datasets)`.")

      # The `stateless_multinomial()` op expects log-probabilities, as opposed
      # to weights.
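      # For example, `weights = [0.8, 0.2]` would yield
      # `logits = [[log(0.8), log(0.2)]]`.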
      logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)

    # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
    # is a `Dataset`, it is possible that evaluating it has a side effect the
    # user depends on.
    if len(datasets) == 1:
      return datasets[0]

    def select_dataset_constant_logits(seed):
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

    selector_input = dataset_ops.MapDataset(
        random_ops.RandomDataset(seed).batch(2),
        select_dataset_constant_logits,
        use_inter_op_parallelism=False)
  else:
    # Use each element of the given `weights` dataset as the probability of
    # choosing the respective input.

    # The `stateless_multinomial()` op expects log-probabilities, as opposed to
    # weights.
    logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))

    def select_dataset_varying_logits(logits, seed):
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

    logits_and_seeds = dataset_ops.Dataset.zip(
        (logits_ds, random_ops.RandomDataset(seed).batch(2)))
    selector_input = dataset_ops.MapDataset(
        logits_and_seeds,
        select_dataset_varying_logits,
        use_inter_op_parallelism=False)

  return _DirectedInterleaveDataset(selector_input, datasets)


@tf_export(v1=["data.experimental.sample_from_datasets"])
def sample_from_datasets_v1(datasets, weights=None, seed=None):
  return dataset_ops.DatasetV1Adapter(
      sample_from_datasets_v2(datasets, weights, seed))
sample_from_datasets_v1.__doc__ = sample_from_datasets_v2.__doc__


@tf_export("data.experimental.choose_from_datasets", v1=[])
def choose_from_datasets_v2(datasets, choice_dataset):
  """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.

  Returns:
    A dataset that interleaves elements from `datasets` according to the values
    of `choice_dataset`.

  Raises:
    TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
      type.
  """
  if not dataset_ops.get_structure(choice_dataset).is_compatible_with(
      structure.TensorStructure(dtypes.int64, [])):
    raise TypeError("`choice_dataset` must be a dataset of scalar "
                    "`tf.int64` tensors.")
  return _DirectedInterleaveDataset(choice_dataset, datasets)


@tf_export(v1=["data.experimental.choose_from_datasets"])
def choose_from_datasets_v1(datasets, choice_dataset):
  return dataset_ops.DatasetV1Adapter(
      choose_from_datasets_v2(datasets, choice_dataset))
choose_from_datasets_v1.__doc__ = choose_from_datasets_v2.__doc__


# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
# these aliases in place.
choose_from_datasets = choose_from_datasets_v1
sample_from_datasets = sample_from_datasets_v1