# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset snapshot and related functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "SNAPPY"
COMPRESSION_NONE = None


class _LegacySnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A Dataset that captures a snapshot or reads from a snapshot."""

  def __init__(self,
               input_dataset,
               path,
               compression=None,
               reader_path_prefix=None,
               writer_path_prefix=None,
               shard_size_bytes=None,
               pending_snapshot_expiry_seconds=None,
               num_reader_threads=None,
               reader_buffer_size=None,
               num_writer_threads=None,
               writer_buffer_size=None,
               shuffle_on_read=None,
               shuffle_seed=None,
               mode=None,
               snapshot_name=None):
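    # Arguments left as `None` are mapped below to the concrete sentinel
    # values (an empty string, -1, "auto", or False) that the underlying
    # snapshot kernel interprets as "use the default".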
    self._compression = compression if compression is not None else ""
    self._reader_path_prefix = (
        reader_path_prefix if reader_path_prefix is not None else "")
    self._writer_path_prefix = (
        writer_path_prefix if writer_path_prefix is not None else "")
    self._shard_size_bytes = (
        shard_size_bytes if shard_size_bytes is not None else -1)
    self._pending_snapshot_expiry_seconds = (
        pending_snapshot_expiry_seconds
        if pending_snapshot_expiry_seconds is not None else -1)
    self._num_reader_threads = (
        num_reader_threads if num_reader_threads is not None else -1)
    self._reader_buffer_size = (
        reader_buffer_size if reader_buffer_size is not None else -1)
    self._num_writer_threads = (
        num_writer_threads if num_writer_threads is not None else -1)
    self._writer_buffer_size = (
        writer_buffer_size if writer_buffer_size is not None else -1)
    self._shuffle_on_read = (
        shuffle_on_read if shuffle_on_read is not None else False)
    self._mode = (mode if mode is not None else "auto")
    self._snapshot_name = (snapshot_name if snapshot_name is not None else "")

    self._seed, self._seed2 = random_seed.get_seed(shuffle_seed)

    self._input_dataset = input_dataset
    self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")

    variant_tensor = ged_ops.snapshot_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        path=self._path,
        compression=self._compression,
        reader_path_prefix=self._reader_path_prefix,
        writer_path_prefix=self._writer_path_prefix,
        shard_size_bytes=self._shard_size_bytes,
        pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
        num_reader_threads=self._num_reader_threads,
        reader_buffer_size=self._reader_buffer_size,
        num_writer_threads=self._num_writer_threads,
        writer_buffer_size=self._writer_buffer_size,
        shuffle_on_read=self._shuffle_on_read,
        seed=self._seed,
        seed2=self._seed2,
        mode=self._mode,
        snapshot_name=self._snapshot_name,
        **self._flat_structure)

    super(_LegacySnapshotDataset, self).__init__(input_dataset, variant_tensor)


@deprecation.deprecated(
    None, "Use `tf.data.experimental.snapshot(...)` instead.")
def legacy_snapshot(path,
                    compression=None,
                    reader_path_prefix=None,
                    writer_path_prefix=None,
                    shard_size_bytes=None,
                    pending_snapshot_expiry_seconds=None,
                    num_reader_threads=None,
                    reader_buffer_size=None,
                    num_writer_threads=None,
                    writer_buffer_size=None,
                    shuffle_on_read=None,
                    shuffle_seed=None,
                    mode=None,
                    snapshot_name=None):
  """Writes to/reads from a snapshot of a dataset.

  This function attempts to determine whether a valid snapshot exists at
  `path`, and reads from the snapshot if so. Otherwise, it runs the
  preprocessing pipeline as usual and writes out a snapshot of the processed
  data for future use.
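  For example (the input pipeline and snapshot path below are illustrative):

  ```python
  dataset = ...
  dataset = dataset.apply(legacy_snapshot("/path/to/snapshot/dir",
                                          compression="GZIP"))
  ```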
  Args:
    path: A directory where we want to save our snapshots and/or read from a
      previously saved snapshot.
    compression: The type of compression to apply to the snapshot. Currently
      supports "GZIP" or None. Defaults to None (no compression).
    reader_path_prefix: A prefix to add to the path when reading from
      snapshots. Defaults to None.
    writer_path_prefix: A prefix to add to the path when writing to snapshots.
      Defaults to None.
    shard_size_bytes: The size of each shard to be written by the snapshot
      dataset op. Defaults to 10 GiB.
    pending_snapshot_expiry_seconds: How long to wait (in seconds) before the
      snapshot op considers a previously unfinished snapshot to be stale.
    num_reader_threads: Number of threads to parallelize reading from the
      snapshot. Especially useful if compression is turned on, since the
      decompression operation tends to be intensive. Defaults to 1. If > 1,
      reads may be non-deterministic: the order in which elements are read
      from the snapshot may differ from the order in which they were written.
    reader_buffer_size: Maximum number of elements we can prefetch reading
      from the snapshot. Defaults to 1. Increasing this might improve
      performance but will increase memory consumption.
    num_writer_threads: Number of threads to parallelize writing to the
      snapshot. We'll open up `num_writer_threads` files and write to them in
      parallel. Especially useful if compression is turned on, since the
      compression operation tends to be intensive. Defaults to 1. If > 1,
      writes may be non-deterministic: the order in which elements are written
      may differ from the order in which they arrive from the upstream
      iterator.
    writer_buffer_size: Maximum number of pipeline elements to fill up the
      buffer before writing them out using `num_writer_threads`.
    shuffle_on_read: If this is True, then the order in which examples are
      produced when reading from a snapshot will be random. Defaults to False.
    shuffle_seed: Optional. If shuffle_seed is set, the random number generator
      used for shuffling (when shuffle_on_read is turned on) is seeded by the
      given seed. Otherwise, it is seeded by a random seed that differs for
      every run.
    mode: The mode in which snapshot should operate. Valid options are "auto",
      "read", "write", and "passthrough". The default mode is "auto", where
      the snapshot op will automatically determine what mode to operate in.
    snapshot_name: If set, use the supplied string as a named snapshot name
      instead of introspecting the data pipeline and automatically generating
      a unique identifier for the snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _LegacySnapshotDataset(
        input_dataset=dataset,
        path=path,
        compression=compression,
        reader_path_prefix=reader_path_prefix,
        writer_path_prefix=writer_path_prefix,
        shard_size_bytes=shard_size_bytes,
        pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
        num_reader_threads=num_reader_threads,
        reader_buffer_size=reader_buffer_size,
        num_writer_threads=num_writer_threads,
        writer_buffer_size=writer_buffer_size,
        shuffle_on_read=shuffle_on_read,
        shuffle_seed=shuffle_seed,
        mode=mode,
        snapshot_name=snapshot_name)

  return _apply_fn


class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A dataset that allows saving and re-use of already processed data."""

  def __init__(self,
               input_dataset,
               path,
               shard_func,
               compression=None,
               reader_func=None,
               pending_snapshot_expiry_seconds=None,
               use_legacy_function=False):

    if reader_func is None:
      # Default: read from all snapshot shards, `cpu_count()` datasets at a
      # time, letting tf.data auto-tune the interleave parallelism.
      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
          lambda x: x,
          cycle_length=multiprocessing.cpu_count(),
          num_parallel_calls=dataset_ops.AUTOTUNE)

    self._input_dataset = input_dataset
    self._path = path
    self._compression = compression

    self._reader_func = dataset_ops.StructuredFunctionWrapper(
        reader_func,
        self._transformation_name() + ".reader_func",
        # Dataset of datasets of input elements
        input_structure=dataset_ops.DatasetSpec(
            dataset_ops.DatasetSpec(input_dataset.element_spec)),
        use_legacy_function=use_legacy_function)
    self._shard_func = dataset_ops.StructuredFunctionWrapper(
        shard_func,
        self._transformation_name() + ".shard_func",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)

    if ((not self._shard_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.int32))) and
        (not self._shard_func.output_structure.is_compatible_with(
            tensor_spec.TensorSpec([], dtypes.int64)))):
      raise TypeError(
          "shard_func must return a 0-dimensional tensor containing an int.")

    variant_tensor = ged_ops.snapshot_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        path,
        self._reader_func.function.captured_inputs,
        self._shard_func.function.captured_inputs,
        compression=compression,
        reader_func=self._reader_func.function,
        shard_func=self._shard_func.function,
        **self._flat_structure)
    super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._reader_func, self._shard_func]

  def _transformation_name(self):
    return "Dataset.snapshot"


@tf_export("data.experimental.snapshot")
def snapshot(path, compression="AUTO", reader_func=None, shard_func=None):
  """API to persist the output of the input dataset.

  The snapshot API allows users to transparently persist the output of their
  preprocessing pipeline to disk, and materialize the pre-processed data on a
  different training run.

  This API enables repeated preprocessing steps to be consolidated, and allows
  re-use of already processed data, trading off disk storage and network
  bandwidth for freeing up more valuable CPU resources and accelerator compute
  time.
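  In its simplest form, the transformation only needs a directory to write the
  snapshot to and read it back from (the input pipeline and path below are
  illustrative):

  ```python
  dataset = ...
  dataset = dataset.apply(
      tf.data.experimental.snapshot("/path/to/snapshot/dir"))
  ```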
  https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md
  has detailed design documentation of this feature.

  Users can specify various options to control the behavior of snapshot,
  including how snapshots are read and written, by passing in user-defined
  functions to the `reader_func` and `shard_func` parameters.

  `shard_func` is a user-specified function that maps input elements to
  snapshot shards.

  Users may want to specify this function to control how snapshot files should
  be written to disk. Below is an example of how a potential `shard_func`
  could be written.

  ```python
  dataset = ...
  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.snapshot(
      "/path/to/snapshot/dir", shard_func=lambda x, y: x % NUM_SHARDS, ...))
  dataset = dataset.map(lambda x, y: y)
  ```

  `reader_func` is a user-specified function that accepts a single argument: a
  Dataset of Datasets, each representing a "split" of elements of the original
  dataset. The cardinality of the input dataset matches the number of shards
  specified in the `shard_func` (see above). The function should return a
  Dataset of elements of the original dataset.

  Users may want to specify this function to control how snapshot files should
  be read from disk, including the amount of shuffling and parallelism.

  Here is an example of a standard reader function a user can define. This
  function enables both dataset shuffling and parallel reading of datasets:

  ```python
  def user_reader_func(datasets):
    # Shuffle the dataset splits.
    datasets = datasets.shuffle(NUM_CORES)
    # Read the datasets in parallel and interleave their elements.
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = dataset.apply(tf.data.experimental.snapshot(
      "/path/to/snapshot/dir", reader_func=user_reader_func))
  ```

  By default, snapshot parallelizes reads by the number of cores available on
  the system, but will not attempt to shuffle the data.

  Args:
    path: Required. A directory to use for saving and loading the snapshot.
    compression: Optional. The type of compression to apply to the snapshot
      written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO`, or
      `None`. Defaults to `AUTO`, which attempts to pick an appropriate
      compression algorithm for the dataset.
    reader_func: Optional. A function to control how to read data from
      snapshot shards.
    shard_func: Optional. A function to control how to shard data when writing
      a snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
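  # When no `shard_func` is supplied, `_apply_fn` enumerates the input so that
  # a default round-robin shard function can key on the element index, then
  # projects the index away again after the snapshot transformation.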
  def _apply_fn(dataset):
    """Actual dataset transformation."""
    project_func = None
    if shard_func is None:
      dataset = dataset.enumerate()
      # This sets the amount of parallelism based on the number of CPU cores
      # on the machine where this Python code is executed, which may differ
      # from the number of CPU cores where the input pipeline graph is
      # actually executed (e.g. remote Cloud TPU workers).
      local_shard_func = lambda index, _: index % multiprocessing.cpu_count()
      project_func = lambda _, elem: elem
    else:
      local_shard_func = shard_func
    dataset = _SnapshotDataset(
        input_dataset=dataset,
        path=path,
        compression=compression,
        reader_func=reader_func,
        # This will not do the right thing when the graph is built on a
        # different machine than the executor (e.g. Cloud TPUs).
        shard_func=local_shard_func)
    if project_func is not None:
      dataset = dataset.map(project_func)
    return dataset

  return _apply_fn