1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Distribution Strategy-related dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.data.util import nest
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
32
33
class _AutoShardDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that shards the `Dataset` automatically.

  This dataset takes in an existing dataset and tries to automatically figure
  out how to shard the dataset in a multi-worker scenario using graph rewrites.

  If the AutoShardPolicy is set to FILE, it walks up the dataset graph until
  it finds a reader dataset, then inserts a ShardDataset op before that node
  so that each worker only sees some files.

  If the AutoShardPolicy is set to DATA, it inserts a ShardDataset op at the
  end of the input pipeline, before any terminal PrefetchDataset if there is
  one. Additionally, if there is a RebatchDatasetV2 in the input pipeline, it
  is rewritten to legacy RebatchDataset for correctness reasons, since
  RebatchDatasetV2 is incompatible with data sharding.

  If the AutoShardPolicy is set to AUTO, it tries to do file-based sharding.
  If it cannot find a reader dataset, it falls back to doing data-based
  sharding.

  If the AutoShardPolicy is set to OFF, it does nothing.

  The policy is read from
  `input_dataset.options().experimental_distribute.auto_shard_policy`.

  Args:
    input_dataset: The `Dataset` to shard.
    num_workers: Total number of workers to shard this dataset across.
    index: The current worker index (out of the total number of workers) this
      dataset is for.
    num_replicas: The total number of replicas across all workers. This is used
      only when sharding by data (either DATA or AUTO) in order to rewrite
      RebatchDatasetV2 to RebatchDataset.

  Raises:
    NotFoundError: If we cannot find a suitable reader dataset to begin
      automatically sharding the dataset.
  """

  def __init__(self, input_dataset, num_workers, index, num_replicas=None):
    self._input_dataset = input_dataset

    # The element spec is taken unchanged from the input dataset; it must be
    # set before `self._flat_structure` is read below.
    self._element_spec = input_dataset.element_spec
    variant_tensor = ged_ops.auto_shard_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        index=index,
        # The policy enum is handed to the kernel as its integer value.
        auto_shard_policy=int(
            input_dataset.options().experimental_distribute.auto_shard_policy),
        num_replicas=num_replicas,
        **self._flat_structure)
    super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec
86
87
def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None):  # pylint: disable=invalid-name
  # Build the V2 auto-sharded dataset first, then wrap it in the adapter that
  # exposes the TF1 `DatasetV1` interface. (The function's `__doc__` is copied
  # from `_AutoShardDataset` at module level.)
  sharded_v2 = _AutoShardDataset(
      input_dataset, num_workers, index, num_replicas=num_replicas)
  return dataset_ops.DatasetV1Adapter(sharded_v2)
91
92
class _RebatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that rebatches elements from its input into new batch sizes.

  `_RebatchDataset(input_dataset, batch_sizes)` is functionally equivalent to
  `input_dataset.unbatch().batch(N)`, where the value of N cycles through the
  `batch_sizes` input list. The elements produced by this dataset have the same
  rank as the elements of the input dataset.

  For example:

  ```python
  ds = tf.data.Dataset.range(8)
  ds = ds.batch(4)
  ds = _RebatchDataset(ds, batch_sizes=[2, 1, 1])
  for elem in ds:
    print(elem)
  >> [0, 1], [2], [3], [4, 5], [6], [7]

  ds = tf.data.Dataset.range(16)
  ds = ds.batch(4)
  ds = _RebatchDataset(ds, batch_sizes=[6])
  for elem in ds:
    print(elem)
  >> [0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11], [12, 13, 14, 15]
  ```
  """

  def __init__(self, input_dataset, batch_sizes, drop_remainder=False):
    """Creates a _RebatchDataset.

    Args:
      input_dataset: `Dataset` to rebatch.
      batch_sizes: A `tf.int64` scalar or vector, representing the size of
        batches to produce. If this argument is a vector, these values are
        cycled through in order.
      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
        whether the last batch should be dropped in the case it has fewer than
        `batch_sizes[cycle_index]` elements; the default behavior is not to
        drop the smaller batch.

    Raises:
      ValueError: Propagated from static batch-dimension inference when
        `batch_sizes` is malformed (rank > 1 or an empty vector), or when the
        input dataset's static shapes are incompatible with rebatching.
    """
    self._input_dataset = input_dataset
    self._batch_sizes = ops.convert_to_tensor(
        batch_sizes, dtype=dtypes.int64, name="batch_sizes")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    new_batch_dim = self._compute_static_batch_dim()

    # Every component keeps its inner shape; only the leading (batch)
    # dimension is replaced by the statically inferred one (or None).
    # pylint: disable=protected-access
    self._element_spec = nest.map_structure(
        lambda ts: ts._unbatch()._batch(new_batch_dim),
        dataset_ops.get_structure(input_dataset))
    # pylint: enable=protected-access

    input_dataset = dataset_ops.normalize_to_dense(input_dataset)
    # Feed the kernel the same converted tensors that were used for static
    # shape inference above, rather than converting the raw arguments again.
    variant_tensor = ged_ops.rebatch_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        batch_sizes=self._batch_sizes,
        drop_remainder=self._drop_remainder,
        **self._flat_structure)
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)

  def _compute_static_batch_dim(self):
    """Computes the static batch dimension of a dataset if it can be determined.

    Given the _RebatchDataset parameters, determines the batch dimension of this
    dataset statically. Returns None if this cannot be determined or is
    variable.

    Returns:
      An integer representing the batch dimension of the dataset. If it cannot
      be determined statically, returns None.

    Raises:
      ValueError: The batch_sizes parameter is malformed, input_dataset is
      not batched, or input_dataset batch sizes are incompatible with each
      other.
    """
    new_batch_dim = tensor_util.constant_value(self._batch_sizes)
    if new_batch_dim is None:
      return None

    if isinstance(new_batch_dim, np.ndarray):
      if new_batch_dim.ndim > 1:
        raise ValueError("Expected batch_sizes to be a scalar or vector.")
      if new_batch_dim.ndim == 1:
        # Guard before indexing: an empty vector would otherwise surface as an
        # opaque IndexError.
        if new_batch_dim.size == 0:
          raise ValueError("Expected batch_sizes to be a non-empty vector.")
        # A vector of sizes only yields a static batch dimension when all of
        # its entries are identical.
        if not np.all(new_batch_dim == new_batch_dim[0]):
          return None
        new_batch_dim = new_batch_dim[0]

    if self._may_form_partial_batches(new_batch_dim):
      return None

    return new_batch_dim

  def _may_form_partial_batches(self, desired_batch_size):
    """Returns whether this dataset may form partial batches.

    Args:
      desired_batch_size: The statically known output batch size.

    Returns:
      False if `drop_remainder` is statically True, or if the (common) static
      input batch dimension is a multiple of `desired_batch_size`. True if no
      input batch dimension is statically known or it is not a multiple.

    Raises:
      ValueError: If a component has a known rank of 0, or if the known input
        batch dimensions disagree with each other.
    """
    if tensor_util.constant_value(self._drop_remainder):
      return False

    def get_batch_dim(type_spec):
      shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
      if not isinstance(shape, tensor_shape.TensorShape):
        return None
      if shape.rank is None:
        return None
      if len(shape) < 1:
        raise ValueError("Expected a dataset whose elements have rank >= 1 "
                         "but found a dataset whose elements are scalars. "
                         "You can fix the issue by adding the `batch` "
                         "transformation to the dataset.")
      return shape.dims[0].value

    input_batch_dims = [
        get_batch_dim(ts)
        for ts in nest.flatten(dataset_ops.get_structure(self._input_dataset))
    ]
    known_input_batch_dims = [d for d in input_batch_dims if d is not None]

    if not known_input_batch_dims:
      return True

    known_input_batch_dims = np.asarray(known_input_batch_dims)
    if not np.all(known_input_batch_dims == known_input_batch_dims[0]):
      raise ValueError("Batch dimensions of input dataset are not compatible.")

    return known_input_batch_dims[0] % desired_batch_size != 0

  @property
  def element_spec(self):
    return self._element_spec
224
225
class _LegacyRebatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that divides its input batches into `num_replicas` sub-batches.

  For each batch in the input dataset, _LegacyRebatchDataset will produce
  `num_replicas` smaller batches whose sizes add up to the original batch size.

  For example:

  ```python
  ds = tf.data.Dataset.range(8)
  ds = ds.batch(4)
  ds = _LegacyRebatchDataset(ds, num_replicas=3)
  for elem in ds:
    print(elem)
  >> [0, 1], [2, 3], [], [4, 5], [6, 7], []
  ```
  """

  def __init__(self, input_dataset, num_replicas):
    """Creates a _LegacyRebatchDataset.

    Args:
      input_dataset: `Dataset` to rebatch.
      num_replicas: A `tf.int64` scalar (or Python integer), representing the
        number of sub-batches to split each batch from `input_dataset` into.

    Raises:
      ValueError: If a component of `input_dataset` has a known rank of 0.
    """
    # `num_replicas` may be a Python integer or a tensor. Only a statically
    # known value can drive shape inference; applying `%` to a symbolic tensor
    # would build a tensor that cannot be used as a Python bool in graph mode.
    static_num_replicas = tensor_util.constant_value(num_replicas)

    def recalculate_batch_size(type_spec):
      """Recalculates the output_shape after dividing it by num_replicas."""
      output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
      if not isinstance(output_shape, tensor_shape.TensorShape):
        return None

      # If the output shape is unknown, we set the batch dimension to unknown.
      if output_shape.rank is None:
        return None

      if len(output_shape) < 1:
        raise ValueError("Expected a dataset whose elements have rank >= 1 "
                         "but found a dataset whose elements are scalars. "
                         "You can fix the issue by adding the `batch` "
                         "transformation to the dataset.")
      input_batch_dim = output_shape.dims[0].value

      if (input_batch_dim is not None and static_num_replicas is not None and
          input_batch_dim % static_num_replicas == 0):
        return int(input_batch_dim // static_num_replicas)

      # Set the batch dimension to unknown. If the global batch size does not
      # divide num_replicas evenly (or num_replicas is not statically known),
      # the minibatches may have different sizes.
      return None

    def rebatch(type_spec):
      # pylint: disable=protected-access
      batch_size = recalculate_batch_size(type_spec)
      return type_spec._unbatch()._batch(batch_size)
      # pylint: enable=protected-access

    self._element_spec = nest.map_structure(
        rebatch, dataset_ops.get_structure(input_dataset))
    input_dataset = dataset_ops.normalize_to_dense(input_dataset)
    variant_tensor = ged_ops.rebatch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_replicas=num_replicas,
        **self._flat_structure)
    super(_LegacyRebatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec
295
296
class _RemoteDataset(dataset_ops.DatasetSource):
  """A dataset source rebuilt from a serialized graph on a chosen device.

  The `graph_def` is turned back into a dataset by the `dataset_from_graph` op,
  which is created under `ops.device(device)`. The caller supplies the
  `element_spec`, since it is stored as given rather than recovered from the
  graph.
  """

  def __init__(self, graph_def, device, element_spec):
    self._elem_spec = element_spec
    with ops.device(device):
      rebuilt_variant = ged_ops.dataset_from_graph(graph_def)
    super(_RemoteDataset, self).__init__(rebuilt_variant)

  @property
  def element_spec(self):
    return self._elem_spec
309
310
def replicate(dataset, devices):
  """Replicates `dataset` onto every device in `devices`.

  If `devices` holds exactly the device the dataset already lives on, the
  dataset is returned as-is. Otherwise the dataset (with its options applied)
  is serialized once and a fresh copy is rebuilt on each requested device.

  Args:
    dataset: A `tf.data.Dataset` object.
    devices: A list of devices to replicate the dataset on.

  Returns:
    A dictionary mapping device name to a dataset on that device.

  Raises:
    TypeError: If `dataset` is not a `tf.data.Dataset`.
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

  # pylint: disable=protected-access
  source_device = dataset._variant_tensor.device

  # Nothing to copy when the only target is the dataset's own device.
  if len(devices) == 1 and devices[0] == source_device:
    return {devices[0]: dataset}

  with ops.colocate_with(dataset._variant_tensor):
    # Options are not automatically carried through dataset serialization, so
    # they are applied explicitly here; otherwise the replicas would lose the
    # options that were set on `dataset`.
    #
    # TODO(b/147325552): Propagating options to C++ upon their setting would
    # allow us to preserve the options across both variant and GraphDef based
    # serialization, avoiding the need to explicitly apply options here.
    dataset = dataset._apply_options()
    options = dataset.options()
    state_policy = options.experimental_external_state_policy
    if state_policy is None:
      state_policy = ExternalStatePolicy.WARN
    serialized_graph = dataset._as_serialized_graph(
        strip_device_assignment=True, external_state_policy=state_policy)
  # pylint: enable=protected-access

  return {
      device: _RemoteDataset(serialized_graph, device, dataset.element_spec)
      for device in devices
  }
351
352
def batch_sizes_for_worker(global_batch_size, num_workers,
                           num_replicas_per_worker, worker_index):
  """Determines how to rebatch a dataset for the given worker.

  Given the global batch size, number of workers, number of replicas per worker,
  and worker index, returns the correct batch sizes for rebatching a dataset
  on worker `worker_index` of `num_workers`, such that each global step (across
  all workers and replicas) will consume global_batch_size elements. The
  returned value should be passed as the `batch_sizes` input parameter to
  `tf.data.experimental.rebatch()`. The returned batch sizes meet the following
  constraints:

  Let G = global_batch_size, W = num_workers, R = num_replicas_per_worker
  (A) for any worker, len(batch_sizes) = W * R
  (B) for any worker, sum(batch_sizes) == G
  (C) for any global step (i.e. R iterations on each worker), the sum of batches
      consumed by replicas across all workers is G.
  (D) any two batch sizes of any two replicas differ by at most one.

  For example, suppose we have G = 7, W = 2, R = 2, and suppose we have two
  text files, `file_a` and `file_b`, which each contain 7 lines (A0..A6 and
  B0..B6 respectively):

  ```python
  # WORKER 0
  batch_sizes_0 = batch_sizes_for_worker(global_batch_size=7,
                                         num_workers=2,
                                         num_replicas_per_worker=2,
                                         worker_index=0)
  print(batch_sizes_0)
  >> [2, 2, 2, 1]

  dataset_0 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"])
  dataset_0 = dataset_0.shard(num_shards=2, index=0)
  dataset_0 = dataset_0.flat_map(tf.data.TextLineDataset)
  dataset_0 = dataset_0.batch(7)
  dataset_0 = dataset_0.apply(tf.data.experimental.rebatch(batch_sizes_0))
  for elem in dataset_0:
    print(elem)
  >> [[A0, A1], [A2, A3], [A4, A5], [A6]]

  # WORKER 1
  batch_sizes_1 = batch_sizes_for_worker(global_batch_size=7,
                                         num_workers=2,
                                         num_replicas_per_worker=2,
                                         worker_index=1)
  print(batch_sizes_1)
  >> [2, 1, 2, 2]

  dataset_1 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"])
  dataset_1 = dataset_1.shard(num_shards=2, index=1)
  dataset_1 = dataset_1.flat_map(tf.data.TextLineDataset)
  dataset_1 = dataset_1.batch(7)
  dataset_1 = dataset_1.apply(tf.data.experimental.rebatch(batch_sizes_1))
  for elem in dataset_1:
    print(elem)
  >> [[B0, B1], [B2], [B3, B4], [B5, B6]]
  ```

  The above example will produce the following elements:

  Step 1:
    Worker 0 Replica 0: [A0, A1]
    Worker 0 Replica 1: [A2, A3]
    Worker 1 Replica 0: [B0, B1]
    Worker 1 Replica 1: [B2]
  Total batch size = 7

  Step 2:
    Worker 0 Replica 0: [A4, A5]
    Worker 0 Replica 1: [A6]
    Worker 1 Replica 0: [B3, B4]
    Worker 1 Replica 1: [B5, B6]
  Total batch size = 7

  Args:
    global_batch_size: A `tf.int64` scalar (or Python integer), representing
      the global batch size.
    num_workers: An integer representing the number of workers the dataset will
      be distributed across.
    num_replicas_per_worker: An integer representing the number of replicas per
      worker. All workers are assumed to have the same number of replicas.
    worker_index: An integer index of the worker to be rebatched.

  Returns:
    A `tf.int64` vector, representing the batch sizes to rebatch the dataset
    into.
  """
  # Constraint (A)
  num_subbatches = num_workers * num_replicas_per_worker

  # Each worker's sizes are worker 0's sizes rotated left by this offset.
  offset = worker_index * num_replicas_per_worker

  const_value = tensor_util.constant_value(global_batch_size)
  if const_value is not None:
    # Use the constant global batch size for further calculations
    global_batch_size = const_value

  # Let N = W * R. Constraint (B) and (D) jointly mean that the iterations
  # should have batch size either floor(B/N) or ceil(B/N). Namely, of the N
  # subbatches a batch is split into, B - N * floor(B/N) of them will have size
  # ceil(B/N), and the rest will have size floor(B/N).
  floor = global_batch_size // num_subbatches
  num_ceil = global_batch_size - (num_subbatches * floor)

  # For worker 0, we assign the first num_ceil subbatches to have size
  # ceil(B/N), and the remainder to have size floor(B/N). The other workers will
  # each be offset by R * worker_index in order to meet constraint (C).
  if const_value is not None:
    # If the global batch size is a known constant value, we return a constant
    # tensor directly instead of manipulating it with TF ops. This allows for
    # better downstream shape inference.
    worker_0 = [floor + 1] * num_ceil + [floor] * (num_subbatches - num_ceil)
    return ops.convert_to_tensor(
        worker_0[offset:] + worker_0[:offset],
        dtype=dtypes.int64,
        name="batch_sizes")

  # Multiplying by a ones vector of static length N (rather than broadcasting
  # a scalar) keeps the result's static shape known as [N].
  worker_0 = array_ops.ones(num_subbatches, dtype=dtypes.int64)
  extra_element = array_ops.concat([
      array_ops.ones(num_ceil, dtype=dtypes.int64),
      array_ops.zeros(num_subbatches - num_ceil, dtype=dtypes.int64)
  ], axis=0)
  worker_0 = floor * worker_0 + extra_element

  return array_ops.concat([worker_0[offset:], worker_0[:offset]], axis=0)
475
476
def compute_batch_size(dataset):
  """An operation that returns the batch size of the dataset.

  This op tries to infer the batch size statically by walking up the dataset
  tree from the final dataset node and returning the batch size of the first
  batching dataset (such as from .batch() and .padded_batch()) that it
  encounters. This differs from using the `element_spec` of a dataset in that it
  does not account for partial batches.

  This operation may fail if it encounters contradictory batch sizes (for
  example, if the dataset is created by zipping together two datasets with
  different batch sizes), if there are no explicit batching transformations, or
  if there are operations downstream from the batching transformation that may
  modify its batch size. In these cases, it returns a -1.

  If every component of `dataset` has a statically known leading dimension,
  the result is a constant computed in Python; otherwise (including for
  components of unknown or zero rank, or an empty structure) the
  `compute_batch_size` op is used.

  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.int64` Tensor representing the batch size of the dataset sans partial
    batches. If this cannot be inferred statically, the value of this tensor
    will be -1.
  """

  def get_static_batch_dim(type_spec):
    """Returns the static leading dimension of a component, or None."""
    output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    # Only plain `TensorShape`s of known, non-zero rank carry a readable
    # leading dimension; scalars previously crashed here with an IndexError.
    if not isinstance(output_shape, tensor_shape.TensorShape):
      return None
    if output_shape.rank is None or output_shape.rank == 0:
      return None
    return output_shape.dims[0].value

  batch_dims = [
      get_static_batch_dim(ts)
      for ts in nest.flatten(dataset_ops.get_structure(dataset))
  ]

  # `batch_dims` may be empty for an element structure with no components;
  # in that case there is nothing to infer statically.
  if batch_dims and all(d is not None for d in batch_dims):

    if all(d == batch_dims[0] for d in batch_dims):
      # If all batch dimensions are known and equal, return that directly.
      batch_dim = batch_dims[0]
    else:
      # If all batch dimensions are known but not all equal, return -1.
      batch_dim = -1

    return constant_op.constant(
        batch_dim, dtype=dtypes.int64, name="static_batch_size")

  # If any batch dimensions are unknown, use compute_batch_size op.
  return ged_ops.compute_batch_size(dataset._variant_tensor)  # pylint: disable=protected-access
525
526
# Give the V1 factory function the same documentation as the V2 class it wraps.
_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__
528