# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset snapshot and related functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

COMPRESSION_GZIP = "GZIP"
COMPRESSION_SNAPPY = "SNAPPY"
COMPRESSION_NONE = None


class _LegacySnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A Dataset that captures a snapshot or reads from a snapshot."""

  def __init__(self,
               input_dataset,
               path,
               compression=None,
               reader_path_prefix=None,
               writer_path_prefix=None,
               shard_size_bytes=None,
               pending_snapshot_expiry_seconds=None,
               num_reader_threads=None,
               reader_buffer_size=None,
               num_writer_threads=None,
               writer_buffer_size=None,
               shuffle_on_read=None,
               shuffle_seed=None,
               mode=None,
               snapshot_name=None):

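    # Replace unset (None) arguments with the concrete default values that are
    # passed to the underlying op below.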
    self._compression = compression if compression is not None else ""
    self._reader_path_prefix = (
        reader_path_prefix if reader_path_prefix is not None else "")
    self._writer_path_prefix = (
        writer_path_prefix if writer_path_prefix is not None else "")
    self._shard_size_bytes = (
        shard_size_bytes if shard_size_bytes is not None else -1)
    self._pending_snapshot_expiry_seconds = (
        pending_snapshot_expiry_seconds
        if pending_snapshot_expiry_seconds is not None else -1)
    self._num_reader_threads = (
        num_reader_threads if num_reader_threads is not None else -1)
    self._reader_buffer_size = (
        reader_buffer_size if reader_buffer_size is not None else -1)
    self._num_writer_threads = (
        num_writer_threads if num_writer_threads is not None else -1)
    self._writer_buffer_size = (
        writer_buffer_size if writer_buffer_size is not None else -1)
    self._shuffle_on_read = (
        shuffle_on_read if shuffle_on_read is not None else False)
    self._mode = (mode if mode is not None else "auto")
    self._snapshot_name = (snapshot_name if snapshot_name is not None else "")

    self._seed, self._seed2 = random_seed.get_seed(shuffle_seed)

    self._input_dataset = input_dataset
    self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")

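    # Build the legacy SnapshotDataset op from the resolved attributes.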
    variant_tensor = ged_ops.snapshot_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        path=self._path,
        compression=self._compression,
        reader_path_prefix=self._reader_path_prefix,
        writer_path_prefix=self._writer_path_prefix,
        shard_size_bytes=self._shard_size_bytes,
        pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
        num_reader_threads=self._num_reader_threads,
        reader_buffer_size=self._reader_buffer_size,
        num_writer_threads=self._num_writer_threads,
        writer_buffer_size=self._writer_buffer_size,
        shuffle_on_read=self._shuffle_on_read,
        seed=self._seed,
        seed2=self._seed2,
        mode=self._mode,
        snapshot_name=self._snapshot_name,
        **self._flat_structure)

    super(_LegacySnapshotDataset, self).__init__(input_dataset, variant_tensor)


@deprecation.deprecated(
    None, "Use `tf.data.experimental.snapshot(...)` instead.")
def legacy_snapshot(path,
                    compression=None,
                    reader_path_prefix=None,
                    writer_path_prefix=None,
                    shard_size_bytes=None,
                    pending_snapshot_expiry_seconds=None,
                    num_reader_threads=None,
                    reader_buffer_size=None,
                    num_writer_threads=None,
                    writer_buffer_size=None,
                    shuffle_on_read=None,
                    shuffle_seed=None,
                    mode=None,
                    snapshot_name=None):
  """Writes to/reads from a snapshot of a dataset.

  This function attempts to determine whether a valid snapshot exists at
  `path`, and reads from the snapshot if so. If not, it runs the preprocessing
  pipeline as usual and writes out a snapshot of the processed data for future
  use.

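  For example, a typical use might look like the following, where the path is
  illustrative and `legacy_snapshot` refers to this function:

  ```python
  dataset = ...
  dataset = dataset.apply(legacy_snapshot("/path/to/snapshot/dir"))
  ```
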
  Args:
    path: A directory where we want to save our snapshots and/or read from a
      previously saved snapshot.
    compression: The type of compression to apply to the Dataset. Currently
      supports "GZIP" or None. Defaults to None (no compression).
    reader_path_prefix: A prefix to add to the path when reading from
      snapshots. Defaults to None.
    writer_path_prefix: A prefix to add to the path when writing to snapshots.
      Defaults to None.
    shard_size_bytes: The size of each shard to be written by the snapshot
      dataset op. Defaults to 10 GiB.
    pending_snapshot_expiry_seconds: How long to wait (in seconds) before the
      snapshot op considers a previously unfinished snapshot to be stale.
    num_reader_threads: Number of threads used to read from the snapshot in
      parallel. Especially useful if compression is turned on, since the
      decompression operation tends to be intensive. Defaults to 1. If > 1,
      this might introduce non-determinism, i.e. the order in which elements
      are read from the snapshot may differ from the order in which they were
      written.
    reader_buffer_size: Maximum number of elements to prefetch when reading
      from the snapshot. Defaults to 1. Increasing this might improve
      performance but will increase memory consumption.
    num_writer_threads: Number of threads used to write the snapshot in
      parallel. `num_writer_threads` files are opened and written to in
      parallel. Especially useful if compression is turned on, since the
      compression operation tends to be intensive. Defaults to 1. If > 1, this
      might introduce non-determinism, i.e. the order in which elements are
      written may differ from the order in which they are read from the
      upstream iterator.
    writer_buffer_size: Maximum number of pipeline elements to buffer before
      writing them out using `num_writer_threads`.
    shuffle_on_read: If this is True, then the order in which examples are
      produced when reading from a snapshot will be random. Defaults to False.
    shuffle_seed: Optional. If `shuffle_seed` is set, the random number
      generator used for shuffling (when `shuffle_on_read` is turned on) is
      seeded by the given seed. Otherwise, it is seeded by a random seed that
      differs for every run.
    mode: The mode at which snapshot should operate. Valid options are "auto",
      "read", "write", and "passthrough". The default mode is "auto", where the
      snapshot op will automatically determine what mode to operate in.
    snapshot_name: If set, use the supplied string as the snapshot name,
      instead of introspecting the data pipeline and automatically generating
      a unique identifier for the snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _LegacySnapshotDataset(
        input_dataset=dataset,
        path=path,
        compression=compression,
        reader_path_prefix=reader_path_prefix,
        writer_path_prefix=writer_path_prefix,
        shard_size_bytes=shard_size_bytes,
        pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
        num_reader_threads=num_reader_threads,
        reader_buffer_size=reader_buffer_size,
        num_writer_threads=num_writer_threads,
        writer_buffer_size=writer_buffer_size,
        shuffle_on_read=shuffle_on_read,
        shuffle_seed=shuffle_seed,
        mode=mode,
        snapshot_name=snapshot_name)

  return _apply_fn


class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A dataset that allows saving and re-use of already processed data."""

  def __init__(self,
               input_dataset,
               path,
               shard_func,
               compression=None,
               reader_func=None,
               pending_snapshot_expiry_seconds=None,
               use_legacy_function=False):

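    # By default, read the snapshot shards by interleaving them in parallel
    # across all available CPU cores, without shuffling.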
    if reader_func is None:
      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
          lambda x: x,
          cycle_length=multiprocessing.cpu_count(),
          num_parallel_calls=dataset_ops.AUTOTUNE)

    self._input_dataset = input_dataset
    self._path = path
    self._compression = compression

    self._reader_func = dataset_ops.StructuredFunctionWrapper(
        reader_func,
        self._transformation_name() + ".reader_func",
        # Dataset of datasets of input elements
        input_structure=dataset_ops.DatasetSpec(
            dataset_ops.DatasetSpec(input_dataset.element_spec)),
        use_legacy_function=use_legacy_function)
    self._shard_func = dataset_ops.StructuredFunctionWrapper(
        shard_func,
        self._transformation_name() + ".shard_func",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)

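    # `shard_func` must return a scalar (0-dimensional) int32 or int64 tensor
    # that identifies the shard an element is written to.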
    if ((not self._shard_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.int32))) and
        (not self._shard_func.output_structure.is_compatible_with(
            tensor_spec.TensorSpec([], dtypes.int64)))):
      raise TypeError(
          "shard_func must return a 0-dimension tensor containing an int.")

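    # Create the SnapshotDatasetV2 op. The captured inputs of the traced
    # reader and shard functions are passed alongside the functions themselves.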
    variant_tensor = ged_ops.snapshot_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        path,
        self._reader_func.function.captured_inputs,
        self._shard_func.function.captured_inputs,
        compression=compression,
        reader_func=self._reader_func.function,
        shard_func=self._shard_func.function,
        **self._flat_structure)
    super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    return [self._reader_func, self._shard_func]

  def _transformation_name(self):
    return "Dataset.snapshot"


@tf_export("data.experimental.snapshot")
def snapshot(path, compression="AUTO", reader_func=None, shard_func=None):
  """API to persist the output of the input dataset.

  The snapshot API allows users to transparently persist the output of their
  preprocessing pipeline to disk, and materialize the pre-processed data on a
  different training run.

  This API enables repeated preprocessing steps to be consolidated, and allows
  re-use of already processed data, trading disk storage and network bandwidth
  for more valuable CPU resources and accelerator compute time.

  https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md
  has detailed design documentation of this feature.
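
  For example, a minimal use that writes out a snapshot on the first run and
  re-uses it on subsequent runs might look like this (the path below is
  illustrative):

  ```python
  dataset = ...
  dataset = dataset.apply(
      tf.data.experimental.snapshot("/path/to/snapshot/dir"))
  ```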

  Users can specify various options to control the behavior of snapshot,
  including how snapshots are read and written, by passing user-defined
  functions to the `reader_func` and `shard_func` parameters.

  `shard_func` is a user-specified function that maps input elements to
  snapshot shards.

  Users may want to specify this function to control how snapshot files should
  be written to disk. Below is an example of how a potential `shard_func`
  could be written.

  ```python
  dataset = ...
  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.snapshot("/path/to/snapshot/dir",
      shard_func=lambda x, y: x % NUM_SHARDS, ...))
  dataset = dataset.map(lambda x, y: y)
  ```

  `reader_func` is a user-specified function that accepts a single argument:
  a Dataset of Datasets, each representing a "split" of elements of the
  original dataset. The cardinality of the input dataset matches the number of
  shards specified by the `shard_func` (see above). The function should return
  a Dataset of elements of the original dataset.

  Users may want to specify this function to control how snapshot files should
  be read from disk, including the amount of shuffling and parallelism.

  Here is an example of a standard reader function a user can define. This
  function enables both dataset shuffling and parallel reading of datasets:

  ```python
  def user_reader_func(datasets):
    # shuffle the dataset splits
    datasets = datasets.shuffle(NUM_CORES)
    # read datasets in parallel and interleave their elements
    return datasets.interleave(lambda x: x, num_parallel_calls=AUTOTUNE)

  dataset = dataset.apply(tf.data.experimental.snapshot("/path/to/snapshot/dir",
      reader_func=user_reader_func))
  ```

  By default, snapshot parallelizes reads by the number of cores available on
  the system, but will not attempt to shuffle the data.

  Args:
    path: Required. A directory to use for storing and loading the snapshot.
    compression: Optional. The type of compression to apply to the snapshot
      written to disk. Supported options are `GZIP`, `SNAPPY`, `AUTO` or None.
      Defaults to `AUTO`, which attempts to pick an appropriate compression
      algorithm for the dataset.
    reader_func: Optional. A function to control how to read data from
      snapshot shards.
    shard_func: Optional. A function to control how to shard data when writing
      a snapshot.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Actual dataset transformation."""
    project_func = None
    if shard_func is None:
      dataset = dataset.enumerate()
      # This sets the amount of parallelism based on the number of CPU cores on
      # the machine where this Python code is executed, which may differ from
      # the number of CPU cores where the input pipeline graph is actually
      # executed (e.g. remote Cloud TPU workers).
      local_shard_func = lambda index, _: index % multiprocessing.cpu_count()
      project_func = lambda _, elem: elem
    else:
      local_shard_func = shard_func
    dataset = _SnapshotDataset(
        input_dataset=dataset,
        path=path,
        compression=compression,
        reader_func=reader_func,
        # This will not do the right thing where the graph is built on a
        # different machine than the executor (e.g. Cloud TPUs).
        shard_func=local_shard_func)
    if project_func is not None:
      dataset = dataset.map(project_func)
    return dataset

  return _apply_fn