# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Non-deterministic dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.experimental.ops import random_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@deprecation.deprecated(
    None,
    "Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, "
    "num_parallel_calls=tf.data.experimental.AUTOTUNE)` instead. If sloppy "
    "execution is desired, use `tf.data.Options.experimental_deterministic`.")
@tf_export("data.experimental.parallel_interleave")
def parallel_interleave(map_func,
                        cycle_length,
                        block_length=1,
                        sloppy=False,
                        buffer_output_elements=None,
                        prefetch_input_elements=None):
  """A parallel version of the `Dataset.interleave()` transformation.

  `parallel_interleave()` maps `map_func` across its input to produce nested
  datasets, and outputs their elements interleaved. Unlike
  `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested
  datasets in parallel, which increases the throughput, especially in the
  presence of stragglers. Furthermore, the `sloppy` argument can be used to
  improve performance, by relaxing the requirement that the outputs are
  produced in a deterministic order, and allowing the implementation to skip
  over nested datasets whose elements are not readily available when
  requested.

  Example usage:

  ```python
  # Preprocess 4 files concurrently.
  filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
  dataset = filenames.apply(
      tf.data.experimental.parallel_interleave(
          lambda filename: tf.data.TFRecordDataset(filename),
          cycle_length=4))
  ```
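
  Since this transformation is deprecated, the same pipeline can be expressed
  with the core `tf.data.Dataset.interleave` transformation. The following is
  a rough sketch based on the deprecation notice above, not a guaranteed
  drop-in replacement:

  ```python
  dataset = filenames.interleave(
      lambda filename: tf.data.TFRecordDataset(filename),
      cycle_length=4,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # For sloppy (non-deterministic) ordering, relax determinism via options.
  options = tf.data.Options()
  options.experimental_deterministic = False
  dataset = dataset.with_options(options)
  ```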

  WARNING: If `sloppy` is `True`, the order of produced elements is not
  deterministic.

  Args:
    map_func: A function mapping a nested structure of tensors to a `Dataset`.
    cycle_length: The number of input `Dataset`s to interleave from in
      parallel.
    block_length: The number of consecutive elements to pull from an input
      `Dataset` before advancing to the next input `Dataset`.
    sloppy: If `False`, elements are produced in deterministic order.
      Otherwise, the implementation is allowed, for the sake of expediency, to
      produce elements in a non-deterministic order.
    buffer_output_elements: The number of elements each iterator being
      interleaved should buffer (similar to the `.prefetch()` transformation
      for each interleaved iterator).
    prefetch_input_elements: The number of input elements to transform to
      iterators before they are needed for interleaving.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  def _apply_fn(dataset):
    return readers.ParallelInterleaveDataset(
        dataset, map_func, cycle_length, block_length, sloppy,
        buffer_output_elements, prefetch_input_elements)

  return _apply_fn


class _DirectedInterleaveDataset(dataset_ops.Dataset):
  """A substitute for `Dataset.interleave()` on a fixed list of datasets."""

  def __init__(self, selector_input, data_inputs):
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)

    first_output_types = dataset_ops.get_legacy_output_types(data_inputs[0])
    first_output_classes = dataset_ops.get_legacy_output_classes(
        data_inputs[0])

    for data_input in data_inputs[1:]:
      if (dataset_ops.get_legacy_output_types(data_input) != first_output_types
          or dataset_ops.get_legacy_output_classes(data_input)
          != first_output_classes):
        raise TypeError("All datasets must have the same type and class.")

    output_shapes = dataset_ops.get_legacy_output_shapes(self._data_inputs[0])
    for data_input in self._data_inputs[1:]:
      output_shapes = nest.pack_sequence_as(output_shapes, [
          ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
              nest.flatten(output_shapes),
              nest.flatten(dataset_ops.get_legacy_output_shapes(data_input)))
      ])

    self._structure = structure.convert_legacy_structure(
        first_output_types, output_shapes, first_output_classes)
    super(_DirectedInterleaveDataset, self).__init__()

  def _as_variant_tensor(self):
    # pylint: disable=protected-access
    return (
        gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
            self._selector_input._variant_tensor,
            [data_input._variant_tensor for data_input in self._data_inputs],
            **dataset_ops.flat_structure(self)))
    # pylint: enable=protected-access

  def _inputs(self):
    return [self._selector_input] + self._data_inputs

  @property
  def _element_structure(self):
    return self._structure


@tf_export("data.experimental.sample_from_datasets", v1=[])
def sample_from_datasets_v2(datasets, weights=None, seed=None):
  """Samples elements at random from the datasets in `datasets`.
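
  For example, the following is an illustrative sketch (the literal datasets
  and weights here are hypothetical):

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat()]

  # Draw from the first dataset with probability 0.8 and from the second
  # with probability 0.2.
  sampled = tf.data.experimental.sample_from_datasets(
      datasets, weights=[0.8, 0.2], seed=42)
  ```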

  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    weights: (Optional.) A list of `len(datasets)` floating-point values where
      `weights[i]` represents the probability with which an element should be
      sampled from `datasets[i]`, or a `tf.data.Dataset` object where each
      element is such a list. Defaults to a uniform distribution across
      `datasets`.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
      seed that will be used to create the distribution. See
      `tf.set_random_seed` for behavior.

  Returns:
    A dataset that interleaves elements from `datasets` at random, according
    to `weights` if provided, otherwise with uniform probability.

  Raises:
    TypeError: If the `datasets` or `weights` arguments have the wrong type.
    ValueError: If the `weights` argument is specified and does not match the
      length of `datasets`.
  """
  num_datasets = len(datasets)
  if not isinstance(weights, dataset_ops.DatasetV2):
    if weights is None:
      # Select inputs with uniform probability.
      logits = [[1.0] * num_datasets]

    else:
      # Use the given `weights` as the probability of choosing the respective
      # input.
      weights = ops.convert_to_tensor(weights, name="weights")
      if weights.dtype not in (dtypes.float32, dtypes.float64):
        raise TypeError("`weights` must be convertible to a tensor of "
                        "`tf.float32` or `tf.float64` elements.")
      if not weights.shape.is_compatible_with([num_datasets]):
        raise ValueError(
            "`weights` must be a vector of length `len(datasets)`.")

      # The `stateless_multinomial()` op expects log-probabilities, as opposed
      # to weights.
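      # For example (illustrative values): weights of [0.5, 0.25, 0.25] become
      # logits of [log(0.5), log(0.25), log(0.25)], from which
      # `stateless_multinomial()` samples dataset indices 0, 1, and 2 with
      # probabilities 0.5, 0.25, and 0.25 respectively.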
      logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)

    # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When
    # it is a `Dataset`, it is possible that evaluating it has a side effect
    # the user depends on.
    if len(datasets) == 1:
      return datasets[0]

    def select_dataset_constant_logits(seed):
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

    selector_input = dataset_ops.MapDataset(
        random_ops.RandomDataset(seed).batch(2),
        select_dataset_constant_logits,
        use_inter_op_parallelism=False)

  else:
    # Use each element of the given `weights` dataset as the probability of
    # choosing the respective input.

    # The `stateless_multinomial()` op expects log-probabilities, as opposed
    # to weights.
    logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))

    def select_dataset_varying_logits(logits, seed):
      return array_ops.squeeze(
          gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
          axis=[0, 1])

    logits_and_seeds = dataset_ops.Dataset.zip(
        (logits_ds, random_ops.RandomDataset(seed).batch(2)))
    selector_input = dataset_ops.MapDataset(
        logits_and_seeds,
        select_dataset_varying_logits,
        use_inter_op_parallelism=False)

  return _DirectedInterleaveDataset(selector_input, datasets)


@tf_export(v1=["data.experimental.sample_from_datasets"])
def sample_from_datasets_v1(datasets, weights=None, seed=None):
  return dataset_ops.DatasetV1Adapter(
      sample_from_datasets_v2(datasets, weights, seed))
sample_from_datasets_v1.__doc__ = sample_from_datasets_v2.__doc__


@tf_export("data.experimental.choose_from_datasets", v1=[])
def choose_from_datasets_v2(datasets, choice_dataset):
  """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.

  Returns:
    A dataset that interleaves elements from `datasets` according to the
    values of `choice_dataset`.

  Raises:
    TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
      type.
  """
  if not dataset_ops.get_structure(choice_dataset).is_compatible_with(
      structure.TensorStructure(dtypes.int64, [])):
    raise TypeError("`choice_dataset` must be a dataset of scalar "
                    "`tf.int64` tensors.")
  return _DirectedInterleaveDataset(choice_dataset, datasets)


@tf_export(v1=["data.experimental.choose_from_datasets"])
def choose_from_datasets_v1(datasets, choice_dataset):
  return dataset_ops.DatasetV1Adapter(
      choose_from_datasets_v2(datasets, choice_dataset))
choose_from_datasets_v1.__doc__ = choose_from_datasets_v2.__doc__


# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
# these aliases in place.
choose_from_datasets = choose_from_datasets_v1
sample_from_datasets = sample_from_datasets_v1