1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various classes representing distributed inputs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import sys
23
24import six
25
26from tensorflow.python import tf2
27from tensorflow.python.data.experimental.ops import batching
28from tensorflow.python.data.experimental.ops import cardinality
29from tensorflow.python.data.experimental.ops import distribute
30from tensorflow.python.data.ops import dataset_ops
31from tensorflow.python.data.ops import iterator_ops
32from tensorflow.python.data.ops import multi_device_iterator_ops
33from tensorflow.python.data.ops import optional_ops
34from tensorflow.python.distribute import device_util
35from tensorflow.python.distribute import distribute_utils
36from tensorflow.python.distribute import distribution_strategy_context
37from tensorflow.python.distribute import input_ops
38from tensorflow.python.distribute import reduce_util
39from tensorflow.python.distribute import values
40from tensorflow.python.distribute.distribute_lib import InputReplicationMode
41from tensorflow.python.eager import context
42from tensorflow.python.framework import composite_tensor
43from tensorflow.python.framework import constant_op
44from tensorflow.python.framework import device as tf_device
45from tensorflow.python.framework import dtypes
46from tensorflow.python.framework import errors
47from tensorflow.python.framework import ops
48from tensorflow.python.framework import sparse_tensor
49from tensorflow.python.framework import tensor_shape
50from tensorflow.python.framework import tensor_util
51from tensorflow.python.framework import type_spec
52from tensorflow.python.ops import array_ops
53from tensorflow.python.ops import control_flow_ops
54from tensorflow.python.ops import math_ops
55from tensorflow.python.ops.ragged import ragged_tensor
56from tensorflow.python.types import distribute as distribute_types
57from tensorflow.python.util import nest
58from tensorflow.python.util.compat import collections_abc
59from tensorflow.python.util.deprecation import deprecated
60from tensorflow.python.util.tf_export import tf_export
61from tensorflow.tools.docs import doc_controls
62
63
def get_distributed_dataset(dataset,
                            input_workers,
                            strategy,
                            num_replicas_in_sync=None,
                            input_context=None):
  """Wraps a `tf.data.Dataset` in the distributed dataset for the TF mode.

  Every strategy funnels through this helper. When TF2 behavior is enabled the
  result is a `DistributedDataset`; otherwise it is a `DistributedDatasetV1`.
  The two classes expose different APIs.

  Args:
    dataset: the `tf.data.Dataset` to distribute.
    input_workers: an `InputWorkers` object naming the devices on which
      iterators are created.
    strategy: the `tf.distribute.Strategy` whose all-reduce is used to handle
      the last partial batch.
    num_replicas_in_sync: Optional integer. When not None it determines how
      `dataset` is rebatched into smaller batches, so that the per-step batch
      size summed over all workers and replicas equals `dataset`'s batch size.
    input_context: Optional `InputContext` used for sharding. Pass it only in
      between-graph multi-worker setups that have exactly one input worker;
      sharding then follows the context's `input_pipeline_id` and
      `num_input_pipelines`.

  Returns:
    A `DistributedDataset` (TF2) or `DistributedDatasetV1` (TF1) instance.
  """
  dataset_cls = DistributedDataset if tf2.enabled() else DistributedDatasetV1
  return dataset_cls(
      dataset,
      input_workers,
      strategy,
      num_replicas_in_sync=num_replicas_in_sync,
      input_context=input_context)
109
110
def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy,
                                           options=None):
  """Builds the distributed dataset produced by calling `dataset_fn`.

  Every strategy funnels through this helper. When TF2 behavior is enabled the
  result is a `DistributedDatasetsFromFunction`; otherwise it is a
  `DistributedDatasetsFromFunctionV1`. The two classes expose different APIs.

  Args:
    dataset_fn: a callable returning a `tf.data.Dataset`.
    input_workers: an `InputWorkers` object naming the devices on which
      iterators are created.
    input_contexts: a list of `InputContext` objects, one per call to
      `dataset_fn`. Its length and order must match the worker order of
      `input_workers`.
    strategy: the `tf.distribute.Strategy` whose all-reduce is used to handle
      the last partial batch.
    options: Optional `tf.distribute.InputOptions` controlling how the dataset
      is distributed. Defaults to None.

  Returns:
    A distributed dataset instance.

  Raises:
    ValueError: if `options.experimental_place_dataset_on_device` is set while
      the replication mode is not `PER_REPLICA`, or if it is set together with
      `options.experimental_prefetch_to_device` under `PER_REPLICA`.
  """
  if options is not None:
    replication_mode = options.experimental_replication_mode
    if (replication_mode != InputReplicationMode.PER_REPLICA and
        options.experimental_place_dataset_on_device):
      raise ValueError(
          "When `experimental_place_dataset_on_device` is set for dataset "
          "placement, you must also specify `PER_REPLICA` for the "
          "replication mode")
    if (replication_mode == InputReplicationMode.PER_REPLICA and
        options.experimental_prefetch_to_device and
        options.experimental_place_dataset_on_device):
      raise ValueError(
          "`experimental_place_dataset_on_device` can not be set to True "
          "when experimental_prefetch_to_device is True and "
          "replication mode is set to `PER_REPLICA`")

  dataset_cls = (DistributedDatasetsFromFunction if tf2.enabled()
                 else DistributedDatasetsFromFunctionV1)
  return dataset_cls(dataset_fn, input_workers, input_contexts, strategy,
                     options)
170
171
@tf_export("distribute.DistributedIterator", v1=[])
class DistributedIteratorInterface(collections_abc.Iterator,
                                   distribute_types.Iterator):
  """An iterator over `tf.distribute.DistributedDataset`.

  `tf.distribute.DistributedIterator` is the primary mechanism for enumerating
  elements of a `tf.distribute.DistributedDataset`. It supports the Python
  Iterator protocol, which means it can be iterated over using a for-loop or by
  fetching individual elements explicitly via `get_next()`.

  You can create a `tf.distribute.DistributedIterator` by calling `iter` on
  a `tf.distribute.DistributedDataset` or by iterating over a
  `tf.distribute.DistributedDataset` with a Python for-loop.

  Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
  on distributed input for more examples and caveats.
  """

  def get_next(self):
    """Returns the next input from the iterator for all replicas.

    Example use:

    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> dataset = tf.data.Dataset.range(100).batch(2)
    >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
    >>> dist_dataset_iterator = iter(dist_dataset)
    >>> @tf.function
    ... def one_step(input):
    ...   return input
    >>> step_num = 5
    >>> for _ in range(step_num):
    ...   strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))
    >>> strategy.experimental_local_results(dist_dataset_iterator.get_next())
    (<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>,
     <tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>)

    Returns:
      A single `tf.Tensor` or a `tf.distribute.DistributedValues` which contains
      the next input for all replicas.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
    """
    raise NotImplementedError(
        "DistributedIterator.get_next() must be implemented in descendants.")

  @property
  def element_spec(self):
    # pylint: disable=line-too-long
    """The type specification of an element of `tf.distribute.DistributedIterator`.

    Example usage:

    >>> global_batch_size = 16
    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
    >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
    >>> distributed_iterator.element_spec
    (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
     PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))

    Returns:
      A nested structure of `tf.TypeSpec` objects matching the structure of an
      element of this `tf.distribute.DistributedIterator`. This returned value
      is typically a `tf.distribute.DistributedValues` object and specifies the
      `tf.TensorSpec` of individual components.
    """
    # `element_spec` is a property, so the message does not use call syntax.
    raise NotImplementedError(
        "DistributedIterator.element_spec must be implemented in descendants.")

  def get_next_as_optional(self):
    # pylint: disable=line-too-long
    """Returns a `tf.experimental.Optional` that contains the next value for all replicas.

    If the `tf.distribute.DistributedIterator` has reached the end of the
    sequence, the returned `tf.experimental.Optional` will have no value.

    Example usage:

    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> global_batch_size = 2
    >>> steps_per_loop = 2
    >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size)
    >>> distributed_iterator = iter(
    ...     strategy.experimental_distribute_dataset(dataset))
    >>> def step_fn(x):
    ...   # train the model with inputs
    ...   return x
    >>> @tf.function
    ... def train_fn(distributed_iterator):
    ...   for _ in tf.range(steps_per_loop):
    ...     optional_data = distributed_iterator.get_next_as_optional()
    ...     if not optional_data.has_value():
    ...       break
    ...     per_replica_results = strategy.run(step_fn, args=(optional_data.get_value(),))
    ...     tf.print(strategy.experimental_local_results(per_replica_results))
    >>> train_fn(distributed_iterator)
    ... # ([0 1], [2 3])
    ... # ([4], [])

    Returns:
      An `tf.experimental.Optional` object representing the next value from the
      `tf.distribute.DistributedIterator` (if it has one) or no value.
    """
    # pylint: enable=line-too-long
    raise NotImplementedError(
        "DistributedIterator.get_next_as_optional() must be implemented in "
        "descendants.")
282
283
@tf_export("distribute.DistributedDataset", v1=[])
class DistributedDatasetInterface(collections_abc.Iterable,
                                  distribute_types.Iterable):
  # pylint: disable=line-too-long
  """Represents a dataset distributed among devices and machines.

  A `tf.distribute.DistributedDataset` could be thought of as a "distributed"
  dataset. When you use `tf.distribute` API to scale training to multiple
  devices or machines, you also need to distribute the input data, which leads
  to a `tf.distribute.DistributedDataset` instance, instead of a
  `tf.data.Dataset` instance in the non-distributed case. In TF 2.x,
  `tf.distribute.DistributedDataset` objects are Python iterables.

  Note: `tf.distribute.DistributedDataset` instances are *not* of type
  `tf.data.Dataset`. They only support the two usages described below:
  iteration and `element_spec`. We don't support any other APIs to transform or
  inspect the dataset.

  There are two APIs to create a `tf.distribute.DistributedDataset` object:
  `tf.distribute.Strategy.experimental_distribute_dataset(dataset)` and
  `tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn)`.
  *When to use which?* When you have a `tf.data.Dataset` instance, and the
  regular batch splitting (i.e. re-batch the input `tf.data.Dataset` instance
  with a new batch size that is equal to the global batch size divided by the
  number of replicas in sync) and autosharding (i.e. the
  `tf.data.experimental.AutoShardPolicy` options) work for you, use the former
  API. Otherwise, if you are *not* using a canonical `tf.data.Dataset` instance,
  or you would like to customize the batch splitting or sharding, you can wrap
  this logic in a `dataset_fn` and use the latter API. Both APIs handle
  prefetching to the device for the user. For more details and examples,
  follow the links to the APIs.


  There are two main usages of a `DistributedDataset` object:

  1. Iterate over it to generate the input for a single device or multiple
  devices, which is a `tf.distribute.DistributedValues` instance. To do this,
  you can:

    * use a pythonic for-loop construct:

      >>> global_batch_size = 4
      >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
      >>> dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(4).batch(global_batch_size)
      >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
      >>> @tf.function
      ... def train_step(input):
      ...   features, labels = input
      ...   return labels - 0.3 * features
      >>> for x in dist_dataset:
      ...   # train_step trains the model using the dataset elements
      ...   loss = strategy.run(train_step, args=(x,))
      ...   print("Loss is", loss)
      Loss is PerReplica:{
        0: tf.Tensor(
      [[0.7]
       [0.7]], shape=(2, 1), dtype=float32),
        1: tf.Tensor(
      [[0.7]
       [0.7]], shape=(2, 1), dtype=float32)
      }

      Placing the loop inside a `tf.function` will give a performance boost.
      However `break` and `return` are currently not supported if the loop is
      placed inside a `tf.function`. We also don't support placing the loop
      inside a `tf.function` when using
      `tf.distribute.experimental.MultiWorkerMirroredStrategy` or
      `tf.distribute.experimental.TPUStrategy` with multiple workers.

    * use `__iter__` to create an explicit iterator, which is of type
      `tf.distribute.DistributedIterator`

      >>> global_batch_size = 4
      >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
      >>> train_dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(50).batch(global_batch_size)
      >>> train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
      >>> @tf.function
      ... def distributed_train_step(dataset_inputs):
      ...   def train_step(input):
      ...     loss = tf.constant(0.1)
      ...     return loss
      ...   per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
      ...   return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,axis=None)
      >>> EPOCHS = 2
      >>> STEPS = 3
      >>> for epoch in range(EPOCHS):
      ...   total_loss = 0.0
      ...   num_batches = 0
      ...   dist_dataset_iterator = iter(train_dist_dataset)
      ...   for _ in range(STEPS):
      ...     total_loss += distributed_train_step(next(dist_dataset_iterator))
      ...     num_batches += 1
      ...   average_train_loss = total_loss / num_batches
      ...   template = ("Epoch {}, Loss: {:.4f}")
      ...   print (template.format(epoch+1, average_train_loss))
      Epoch 1, Loss: 0.2000
      Epoch 2, Loss: 0.2000


    To achieve a performance improvement, you can also wrap the `strategy.run`
    call with a `tf.range` inside a `tf.function`. This runs multiple steps in a
    `tf.function`. Autograph will convert it to a `tf.while_loop` on the worker.
    However, it is less flexible compared with running a single step inside
    `tf.function`. For example, you cannot run things eagerly or arbitrary
    python code within the steps.


  2. Inspect the `tf.TypeSpec` of the data generated by `DistributedDataset`.

    `tf.distribute.DistributedDataset` generates
    `tf.distribute.DistributedValues` as input to the devices. If you pass the
    input to a `tf.function` and would like to specify the shape and type of
    each Tensor argument to the function, you can pass a `tf.TypeSpec` object to
    the `input_signature` argument of the `tf.function`. To get the
    `tf.TypeSpec` of the input, you can use the `element_spec` property of the
    `tf.distribute.DistributedDataset` or `tf.distribute.DistributedIterator`
    object.

    For example:

    >>> global_batch_size = 4
    >>> epochs = 1
    >>> steps_per_epoch = 1
    >>> mirrored_strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size)
    >>> dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
    >>> @tf.function(input_signature=[dist_dataset.element_spec])
    ... def train_step(per_replica_inputs):
    ...   def step_fn(inputs):
    ...     return tf.square(inputs)
    ...   return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))
    >>> for _ in range(epochs):
    ...   iterator = iter(dist_dataset)
    ...   for _ in range(steps_per_epoch):
    ...     output = train_step(next(iterator))
    ...     print(output)
    PerReplica:{
      0: tf.Tensor(
    [[4.]
     [4.]], shape=(2, 1), dtype=float32),
      1: tf.Tensor(
    [[4.]
     [4.]], shape=(2, 1), dtype=float32)
    }


  Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
  on distributed input for more examples and caveats.
  """

  def __iter__(self):
    """Creates an iterator for the `tf.distribute.DistributedDataset`.

    The returned iterator implements the Python Iterator protocol.

    Example usage:

    >>> global_batch_size = 4
    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size)
    >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
    >>> print(next(distributed_iterator))
    PerReplica:{
      0: tf.Tensor([1 2], shape=(2,), dtype=int32),
      1: tf.Tensor([3 4], shape=(2,), dtype=int32)
    }

    Returns:
      An `tf.distribute.DistributedIterator` instance for the given
      `tf.distribute.DistributedDataset` object to enumerate over the
      distributed data.
    """
    # Name the method in the message, matching the sibling stubs below.
    raise NotImplementedError(
        "DistributedDataset.__iter__ must be implemented in descendants.")

  @property
  def element_spec(self):
    """The type specification of an element of this `tf.distribute.DistributedDataset`.

    Example usage:

    >>> global_batch_size = 16
    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
    >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
    >>> dist_dataset.element_spec
    (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
     PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))

    Returns:
      A nested structure of `tf.TypeSpec` objects matching the structure of an
      element of this `tf.distribute.DistributedDataset`. This returned value is
      typically a `tf.distribute.DistributedValues` object and specifies the
      `tf.TensorSpec` of individual components.
    """
    raise NotImplementedError(
        "DistributedDataset.element_spec must be implemented in descendants.")

  @doc_controls.do_not_generate_docs
  def reduce(self, initial_state, reduce_func):
    """Abstract hook; concrete distributed datasets must override it.

    Args:
      initial_state: the starting state of the reduction.
      reduce_func: the reduction function (NOTE(review): presumably mirrors
        `tf.data.Dataset.reduce`; confirm against the implementations).

    Raises:
      NotImplementedError: always, in this interface.
    """
    raise NotImplementedError(
        "DistributedDataset.reduce must be implemented in descendants.")
487
488
class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

  def __init__(self, worker_device_pairs):
    """Builds the mapping from `(input_device, compute_devices)` pairs.

    Args:
      worker_device_pairs: A sequence of pairs
        `(input device, a tuple of compute devices fed by that input device)`.
        It is iterated twice, so it must be re-iterable (not a generator).
    """
    self._worker_device_pairs = worker_device_pairs
    self._input_worker_devices = tuple(
        input_device for input_device, _ in worker_device_pairs)
    # Compute devices are stored in canonical form; input devices are not.
    self._fed_devices = tuple(
        tuple(device_util.canonicalize(device) for device in compute_devices)
        for _, compute_devices in worker_device_pairs)

  @property
  def num_workers(self):
    """The number of input worker devices."""
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    """The input worker devices, in the order they were given."""
    return self._input_worker_devices

  def compute_devices_for_worker(self, worker_index):
    """Returns the canonicalized compute devices fed by worker `worker_index`."""
    return self._fed_devices[worker_index]

  def __repr__(self):
    rows = [
        "  %d %s: %s" % (index, worker, fed)
        for index, (worker, fed) in enumerate(
            zip(self.worker_devices, self._fed_devices))
    ]
    return "%s:{\n%s}" % (type(self).__name__, ",\n".join(rows))

  def serialize(self):
    """Returns the original `worker_device_pairs` this object was built from."""
    return self._worker_device_pairs

  def deserialize(self, worker_device_pairs):
    """Builds a new `InputWorkers` from previously serialized pairs."""
    return InputWorkers(worker_device_pairs)
527
528
def _get_next_as_optional(iterator, strategy, return_per_replica=False):
  """Returns a global has-value indicator and the next input from `iterator`.

  Each input worker's iterator is asked for its next element via
  `get_next_as_list()`. The per-worker "has value" flags are cast to int64
  and summed across workers, so the returned indicator is True when at least
  one worker produced data.

  Args:
    iterator: a DistributedIterator object.
    strategy: the `tf.distribute.Strategy` instance.
    return_per_replica: a boolean. If True, the returned data will be wrapped
      with `PerReplica` structure. Otherwise it is a 2D
      num_input_workers*num_replicas_per_worker list.

  Returns:
    A tuple `(global_has_value, replicas)`: a scalar boolean tensor that is
    True if any worker has a next batch, and the data from all replicas
    (shaped as described under `return_per_replica`).
  """
  replicas = []
  worker_has_values = []
  for i, worker in enumerate(iterator._input_workers.worker_devices):  # pylint: disable=protected-access
    with ops.device(worker):
      worker_has_value, next_element = (
          iterator._iterators[i].get_next_as_list())  # pylint: disable=protected-access
      # Collective all-reduce requires explicit devices for inputs.
      with ops.device("/cpu:0"):
        # Converting to integers for all-reduce.
        worker_has_value = math_ops.cast(worker_has_value, dtypes.int64)
        worker_has_values.append(worker_has_value)
      # One entry per worker (that worker's list of replica values); the
      # nesting is flattened below only when `return_per_replica` is set.
      replicas.append(next_element)

  if return_per_replica:
    flattened_data = []
    for per_worker_data in replicas:
      flattened_data.extend(per_worker_data)
    replicas = _create_per_replica(flattened_data, strategy)

  # Run an all-reduce to see whether any worker has values.
  # TODO(b/131423105): we should be able to short-cut the all-reduce in some
  # cases.
  if getattr(strategy.extended, "_support_per_replica_values", True):
    # `reduce` expects a `PerReplica`, so we pass it one, even
    # though it doesn't actually have a value per replica
    worker_has_values = values.PerReplica(worker_has_values)
    global_has_value = strategy.reduce(
        reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
  else:
    # Without PerReplica support a single worker is expected, so its flag is
    # already the global answer.
    assert len(worker_has_values) == 1
    global_has_value = worker_has_values[0]
  global_has_value = array_ops.reshape(
      math_ops.cast(global_has_value, dtypes.bool), [])
  return global_has_value, replicas
580
581
def _is_statically_shaped(element_spec):
  """Test if an iterator output is statically shaped.

  For sparse and ragged tensors this only tests the batch dimension; a sparse
  or ragged spec whose rank is unknown is reported as not statically shaped.

  Args:
    element_spec: a nest structure of `tf.TypeSpec`. The element spec of the
      dataset of the iterator.

  Returns:
    True if the shape is static, false otherwise.
  """

  for spec in nest.flatten(element_spec):
    if isinstance(
        spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)):
      # For sparse or ragged tensor, we should only check the first
      # dimension in order to get_next_as_optional. This is because
      # when these tensors get batched by dataset only the batch dimension
      # is set.
      rank = spec.shape.rank
      # An unknown rank (None) implies an unknown batch dimension. Comparing
      # None with 0 would raise TypeError on Python 3, so check it first.
      if rank is None:
        return False
      if rank > 0 and spec.shape.as_list()[0] is None:
        return False
    else:
      for component in nest.flatten(spec._component_specs):  # pylint: disable=protected-access
        if not component.shape.is_fully_defined():
          return False
  return True
609
610
class DistributedIteratorBase(DistributedIteratorInterface):
  """Common implementation for all input iterators.

  `get_next_as_optional` reads `self._element_spec`, which this base class
  never assigns; subclasses are responsible for setting it.
  """

  # pylint: disable=super-init-not-called
  def __init__(self, input_workers, iterators, strategy,
               enable_get_next_as_optional):
    """Stores the per-worker iterators and configuration.

    Args:
      input_workers: an `InputWorkers` object.
      iterators: per-worker iterators, indexed in the same order as
        `input_workers.worker_devices`.
      strategy: the `tf.distribute.Strategy` instance.
      enable_get_next_as_optional: a boolean selecting how `get_next` fetches
        data (see `get_next`).

    Raises:
      ValueError: if `input_workers` has no worker devices.
    """
    assert isinstance(input_workers, InputWorkers)
    if not input_workers.worker_devices:
      raise ValueError("Should have at least one worker for input iterator.")

    self._iterators = iterators
    self._input_workers = input_workers
    self._strategy = strategy
    self._enable_get_next_as_optional = enable_get_next_as_optional

  def next(self):
    """Alias of `__next__` under the Python 2 iterator method name."""
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      # Translate TensorFlow's end-of-data error into the Python iterator
      # protocol's termination signal.
      raise StopIteration

  def __iter__(self):
    return self

  def get_next_as_optional(self):
    """Returns an `Optional` holding the next `PerReplica` input, or empty."""
    global_has_value, replicas = _get_next_as_optional(
        self, self._strategy, return_per_replica=True)

    def return_none():
      return optional_ops.Optional.empty(self._element_spec)

    return control_flow_ops.cond(
        global_has_value, lambda: optional_ops.Optional.from_value(replicas),
        return_none)

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas.

    If `enable_get_next_as_optional` is False, each worker's iterator is read
    with `get_next_as_list_static_shapes`. When `name` is given, it is suffixed
    with the worker's job and task. Otherwise a global has-value flag is
    computed first. Each replica's value is then selected with `tf.cond`, and
    the false branch calls `get_next` on the exhausted iterator to raise
    `OutOfRangeError`.

    Args:
      name: optional op name prefix, used only on the static-shape path.

    Returns:
      The per-replica values built by `_create_per_replica`.
    """
    if not self._enable_get_next_as_optional:
      replicas = []
      for i, worker in enumerate(self._input_workers.worker_devices):
        if name is not None:
          d = tf_device.DeviceSpec.from_string(worker)
          new_name = "%s_%s_%d" % (name, d.job, d.task)
        else:
          new_name = None
        with ops.device(worker):
          # Make `replicas` a flat list of values across all replicas.
          replicas.extend(
              self._iterators[i].get_next_as_list_static_shapes(new_name))
      return _create_per_replica(replicas, self._strategy)

    def out_of_range_fn(worker_index, device):
      """Calls `get_next` on an exhausted iterator to raise OutOfRangeError."""
      # This is only taken when no worker has data left, so `get_next()`
      # triggers an OutOfRange error.
      return self._iterators[worker_index].get_next(device)

    global_has_value, replicas = _get_next_as_optional(
        self, self._strategy, return_per_replica=False)
    results = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        devices = self._input_workers.compute_devices_for_worker(i)
        for j, device in enumerate(devices):
          with ops.device(device):
            # Bind the loop variables as lambda defaults so each branch uses
            # this iteration's values regardless of when `cond` invokes it.
            result = control_flow_ops.cond(
                global_has_value,
                lambda i=i, j=j: replicas[i][j],
                lambda i=i, device=device: out_of_range_fn(i, device),
                strict=True,
            )
            results.append(result)

    return _create_per_replica(results, self._strategy)
698
699
class DistributedIteratorV1(DistributedIteratorBase):
  """Input Iterator for a distributed dataset."""

  # We need a private initializer method for re-initializing multidevice
  # iterators when used with Keras training loops. If we don't reinitialize the
  # iterator we run into memory leak issues (b/123315763).
  @property
  def _initializer(self):
    """The result of `control_flow_ops.group` over every iterator's init ops."""
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return control_flow_ops.group(init_ops)

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    """Initialize underlying iterators.

    Returns:
      The grouped initializer (from `control_flow_ops.group`) covering all
      underlying iterators.
    """
    return self._initializer

  @property
  def initializer(self):
    """Returns the grouped op that initializes all underlying iterators."""
    # Read `_initializer` directly: going through the `@deprecated`
    # `initialize()` would emit a deprecation warning telling callers to use
    # this very property.
    return self._initializer

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    """Delegates to the first per-worker iterator's `output_classes`."""
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    """Delegates to the first per-worker iterator's `output_shapes`."""
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    """Delegates to the first per-worker iterator's `output_types`."""
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    """Returns the iterator for input device `worker`, or None if unknown."""
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    # NOTE(review): `_element_spec` is not set in this class or its base;
    # presumably a subclass constructor assigns it — confirm.
    return self._element_spec
753
754
class DistributedIteratorSpec(type_spec.TypeSpec):
  """Type specification for `DistributedIterator`.

  Records the `InputWorkers`, element spec, `tf.distribute.Strategy`,
  partial-batch flag (`enable_get_next_as_optional`) and input options needed
  to rebuild a `DistributedIterator` from its per-worker iterator components.
  """

  __slots__ = [
      "_input_workers", "_element_spec", "_strategy",
      "_enable_get_next_as_optional", "_options"
  ]

  def __init__(self, input_workers, element_spec, strategy,
               enable_get_next_as_optional, options):
    # We don't want to allow deserialization of this class because we don't
    # serialize the strategy object. Currently the only places where
    # _deserialize is called is when we save/restore using SavedModels.
    # NOTE(review): a tuple `input_workers` presumably is the output of
    # `InputWorkers.serialize()`, i.e. an attempt to rebuild from a
    # serialization -- confirm against `InputWorkers.serialize`.
    if isinstance(input_workers, tuple):
      raise NotImplementedError("DistributedIteratorSpec does not have support "
                                "for deserialization.")
    else:
      self._input_workers = input_workers
      self._element_spec = element_spec
      self._strategy = strategy
      self._enable_get_next_as_optional = enable_get_next_as_optional
      self._options = options

  @property
  def value_type(self):
    return DistributedIterator

  def _serialize(self):
    # We cannot serialize the strategy or options objects, so we convert them
    # to ids that we can use for comparison. `_enable_get_next_as_optional` is
    # part of the key because `_from_components` bakes it into the rebuilt
    # iterator: specs that differ only in this flag must not compare (or hash)
    # equal, otherwise a traced function could be reused with the wrong
    # partial-batch behavior.
    return (self._input_workers.serialize(), self._element_spec,
            id(self._strategy), self._enable_get_next_as_optional,
            id(self._options))

  @classmethod
  def _deserialize(cls, serialization):
    # Matches the `TypeSpec._deserialize(cls, serialization)` signature so the
    # intended ValueError is raised whether this is called on the class or on
    # an instance.
    del serialization  # Unused.
    raise ValueError("Deserialization is currently unsupported for "
                     "DistributedIteratorSpec.")

  # Overriding this method so that we can merge and reconstruct the spec object
  def most_specific_compatible_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    The element specs are merged structurally; all other fields are taken from
    `self`.

    Args:
      other: A `TypeSpec`.

    Returns:
      A new `DistributedIteratorSpec` with the merged element spec.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # pylint: disable=protected-access
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    if self._input_workers.serialize() != other._input_workers.serialize():
      raise ValueError("_input_workers is not compatible with both %s "
                       "and %s" % (self, other))
    if self._strategy is not other._strategy:
      raise ValueError("tf.distribute strategy is not compatible with both %s "
                       "and %s" % (self, other))
    element_spec = nest.map_structure(
        lambda a, b: a.most_specific_compatible_type(b), self._element_spec,
        other._element_spec)
    return DistributedIteratorSpec(self._input_workers, element_spec,
                                   self._strategy,
                                   self._enable_get_next_as_optional,
                                   self._options)

  @property
  def _component_specs(self):
    """One `_SingleWorkerDatasetIteratorSpec` per (input device, compute devices) pair."""
    specs = []
    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access

    for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
      element_spec = nest.map_structure(
          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
      specs.append(
          _SingleWorkerDatasetIteratorSpec(input_device, compute_devices,
                                           element_spec, self._options))
    return specs

  def _to_components(self, value):
    return value._iterators  # pylint: disable=protected-access

  def _from_components(self, components):
    return DistributedIterator(
        input_workers=self._input_workers,
        iterators=None,
        components=components,
        element_spec=self._element_spec,
        strategy=self._strategy,
        enable_get_next_as_optional=self._enable_get_next_as_optional,
        options=self._options)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return DistributedIteratorSpec(value._input_workers, value._element_spec,
                                   value._strategy,
                                   value._enable_get_next_as_optional,
                                   value._options)

  def _with_tensor_ranks_only(self):
    element_spec = nest.map_structure(
        lambda s: s._with_tensor_ranks_only(),  # pylint: disable=protected-access
        self._element_spec)
    return DistributedIteratorSpec(self._input_workers, element_spec,
                                   self._strategy,
                                   self._enable_get_next_as_optional,
                                   self._options)
863
864
class DistributedIterator(DistributedIteratorBase,
                          composite_tensor.CompositeTensor):
  """Input Iterator for a distributed dataset.

  Can be constructed in two ways:
    * from per-worker `iterators`, in which case initialization is delegated to
      `DistributedIteratorBase`; or
    * from `components` plus `element_spec` (with `iterators=None`), as done by
      `DistributedIteratorSpec._from_components`.
  `input_workers` is required in both cases.
  """

  def __init__(self,
               input_workers=None,
               iterators=None,
               strategy=None,
               components=None,
               element_spec=None,
               enable_get_next_as_optional=False,
               options=None):
    """Creates the iterator.

    Args:
      input_workers: an `InputWorkers` object. Required.
      iterators: per-worker iterators; mutually exclusive with providing both
        `components` and `element_spec`.
      strategy: the `tf.distribute.Strategy` this iterator belongs to.
      components: per-worker iterators used when rebuilding from a spec.
      element_spec: element spec used when rebuilding from a spec.
      enable_get_next_as_optional: whether partial-batch handling is enabled.
      options: `tf.distribute.InputOptions`, or None.

    Raises:
      ValueError: if `input_workers` is missing, or if neither `iterators` nor
        both `components` and `element_spec` are provided, or if `iterators`
        is given together with both `components` and `element_spec`.
    """
    if input_workers is None:
      raise ValueError("`input_workers` should be "
                       "provided.")

    # `input_workers` is already validated above; the real alternatives are
    # `iterators` versus (`components`, `element_spec`).
    error_message = ("Either `iterators` or "
                     "both `components` and `element_spec` need to be "
                     "provided.")
    self._options = options

    if iterators is None:
      # Rebuilding from components: set the state directly instead of running
      # the base-class initializer.
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._input_workers = input_workers
      self._iterators = components
      self._strategy = strategy
      self._enable_get_next_as_optional = enable_get_next_as_optional
    else:
      if (components is not None and element_spec is not None):
        raise ValueError(error_message)

      super(DistributedIterator,
            self).__init__(input_workers, iterators, strategy,
                           enable_get_next_as_optional)

  @property
  def element_spec(self):
    # When partial batch handling is enabled, always set the batch dimension to
    # None, otherwise we just follow element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batching handling we could always produce empty batches.
    if (self._enable_get_next_as_optional and
        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return nest.map_structure(
          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
    return self._element_spec

  @property
  def _type_spec(self):
    # Note that we use actual element_spec instead of the rebatched-as-dynamic
    # one to create DistributedIteratorSpec, to be consistent with the
    # underlying iterators' specs.
    return DistributedIteratorSpec(self._input_workers, self._element_spec,
                                   self._strategy,
                                   self._enable_get_next_as_optional,
                                   self._options)
923
924
class _IterableInput(DistributedDatasetInterface):
  """Base class for iterable inputs for distribution strategies.

  Subclasses implement `__iter__` and are expected to set `self._strategy`,
  which `reduce` reads but this base class never assigns.
  """

  # pylint: disable=super-init-not-called
  def __init__(self, input_workers):
    assert isinstance(input_workers, InputWorkers)
    self._input_workers = input_workers

  def __iter__(self):
    raise NotImplementedError("must be implemented in descendants")

  def reduce(self, initial_state, reduce_fn):
    """Execute a `reduce_fn` over all the elements of the input.

    Args:
      initial_state: the starting loop state handed to `reduce_fn`.
      reduce_fn: callable `(state, data) -> new_state`, where `data` is an
        element fetched with `return_per_replica=True`.

    Returns:
      The state produced after `reduce_fn` has consumed every element.
    """
    iterator = iter(self)

    def _fetch_next():
      """Fetches the next element plus a flag telling whether one existed."""
      return _get_next_as_optional(
          iterator, self._strategy, return_per_replica=True)

    def _has_more(has_data, data, state):
      del data, state  # Unused.
      return has_data

    def _fold_one(has_data, data, state):
      """Executes `reduce_fn` in a loop till the dataset is empty."""
      del has_data  # Unused.
      updated_state = reduce_fn(state, data)
      next_has_data, next_data = _fetch_next()
      return next_has_data, next_data, updated_state

    first_has_data, first_data = _fetch_next()
    _, _, final_state = control_flow_ops.while_loop(
        _has_more,
        _fold_one, [first_has_data, first_data, initial_state],
        parallel_iterations=1)
    return final_state
957
958
class DistributedDataset(_IterableInput):
  """Distributed dataset that supports prefetching to multiple devices.

  At construction the input dataset is either replicated onto every input
  worker device or, when an `input_context` is given, used as this worker's
  single pipeline. It is then optionally rebatched across
  `num_replicas_in_sync` replicas and auto-sharded. Iterating returns a
  `DistributedIterator`, or a `DistributedIteratorV1` when the strategy has a
  truthy `_enable_legacy_iterators` attribute.
  """

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               num_replicas_in_sync=None,
               input_context=None):
    """Distribute the dataset on all workers.

    If `num_replicas_in_sync` is not None, we split each batch of the dataset
    into `num_replicas_in_sync` smaller batches, to be distributed among that
    worker's replicas, so that the batch size for a global step (across all
    workers and replicas) is as expected.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      num_replicas_in_sync: Optional integer. If this is not None, the value
        is used to decide how to rebatch datasets into smaller batches so that
        the total batch size for each step (across all workers and replicas)
        adds up to `dataset`'s batch size.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.

    Raises:
      ValueError: if `num_replicas_in_sync` is not divisible by the number of
        workers, or if rebatching fails because `dataset` was never batched.
    """
    super(DistributedDataset, self).__init__(input_workers=input_workers)
    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, will attempt to shard
    # the final input such that each worker will run the entire preprocessing
    # pipeline and only receive its own shard of the dataset.

    # Additionally, we rebatch the dataset on each worker into
    # `num_replicas_in_sync` smaller batches to be distributed among that
    # worker's replicas, so that the batch size for a global step (across all
    # workers and replicas) adds up to the original dataset's batch size.
    if num_replicas_in_sync is not None:
      # The worker count is the number of input pipelines when an
      # `input_context` is given, otherwise the number of input worker devices.
      num_workers = input_context.num_input_pipelines if input_context else len(
          input_workers.worker_devices)
      rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
                                         num_replicas_in_sync)
    else:
      rebatch_fn = None

    self._cloned_datasets = []
    if input_context:
      # Between-graph where we rely on the input_context for sharding
      assert input_workers.num_workers == 1
      if rebatch_fn is not None:
        dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
      dataset = input_ops.auto_shard_dataset(dataset,
                                             input_context.num_input_pipelines,
                                             input_context.input_pipeline_id,
                                             num_replicas_in_sync)
      self._cloned_datasets.append(dataset)
    else:
      # In-graph: replicate the dataset onto every worker device, then rebatch
      # and shard each copy using the worker's index.
      replicated_ds = distribute.replicate(dataset,
                                           input_workers.worker_devices)
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = replicated_ds[worker]
          # Re-apply the source dataset's options to the replicated copy.
          cloned_dataset = cloned_dataset.with_options(dataset.options())
          if rebatch_fn is not None:
            cloned_dataset = rebatch_fn(cloned_dataset, i)
          cloned_dataset = input_ops.auto_shard_dataset(
              cloned_dataset, len(input_workers.worker_devices), i,
              num_replicas_in_sync)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    self._strategy = strategy
    # NOTE(review): in the `input_context` branch `dataset` has been rebound
    # to the rebatched/sharded dataset above; in the other branch it is still
    # the caller's dataset.
    self._enable_get_next_as_optional = _enable_get_next_as_optional(
        self._strategy, dataset)
    # The element spec is derived from the first per-worker dataset.
    self._element_spec = _create_distributed_tensor_spec(
        self._strategy, self._cloned_datasets[0].element_spec)

  def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
    """Returns a callable that rebatches the input dataset.

    Args:
      dataset: A `tf.data.Dataset` representing the dataset to be distributed.
      num_workers: An integer representing the number of workers to distribute
        `dataset` among.
      num_replicas_in_sync: An integer representing the number of replicas in
        sync across all workers.

    Returns:
      A callable `rebatch_fn(dataset, worker_index)` that returns the
      rebatched dataset for the worker at `worker_index`.

    Raises:
      ValueError: if `num_replicas_in_sync` is not divisible by `num_workers`.
    """
    if num_replicas_in_sync % num_workers:
      raise ValueError(
          "tf.distribute expects every worker to have the same number of "
          "replicas. However, encountered `num_replicas_in_sync` ({}) that "
          "cannot be divided by `num_workers` ({})".format(
              num_replicas_in_sync, num_workers))

    num_replicas_per_worker = num_replicas_in_sync // num_workers
    # The batch size is computed once, colocated with the dataset's variant
    # tensor, and shared by every call to `rebatch_fn`.
    with ops.colocate_with(dataset._variant_tensor):  # pylint: disable=protected-access
      batch_size = distribute.compute_batch_size(dataset)

    def rebatch_fn(dataset, worker_index):
      """Rebatches `dataset` for the worker at `worker_index`.

      Chooses at graph-run time between `_RebatchDataset` (per-worker batch
      sizes) and `_LegacyRebatchDataset`; both results are prefetched
      `num_replicas_per_worker` elements ahead. An `InvalidArgumentError`
      mentioning "without encountering a batch" is re-raised as a
      `ValueError` asking the caller to batch the dataset.
      """
      try:
        # pylint: disable=protected-access
        def apply_rebatch():
          batch_sizes = distribute.batch_sizes_for_worker(
              batch_size, num_workers, num_replicas_per_worker, worker_index)
          return distribute._RebatchDataset(
              dataset, batch_sizes).prefetch(num_replicas_per_worker)

        def apply_legacy_rebatch():
          return distribute._LegacyRebatchDataset(
              dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)

        # NOTE(review): a batch size of -1 presumably means
        # `compute_batch_size` could not determine it -- confirm.
        with ops.colocate_with(dataset._variant_tensor):
          return control_flow_ops.cond(
              math_ops.not_equal(batch_size, -1),
              true_fn=apply_rebatch,
              false_fn=apply_legacy_rebatch)
      except errors.InvalidArgumentError as e:
        if "without encountering a batch" in str(e):
          six.reraise(
              ValueError,
              ValueError(
                  "Call the `batch` method on the input Dataset in order to be "
                  "able to split your input across {} replicas.\n Please see "
                  "the tf.distribute.Strategy guide. {}".format(
                      num_replicas_in_sync, e)),
              sys.exc_info()[2])
        else:
          raise

    return rebatch_fn

  def __iter__(self):
    """Creates a distributed iterator over the per-worker datasets.

    Returns:
      A `DistributedIteratorV1` if the strategy's `_enable_legacy_iterators`
      attribute is truthy, otherwise a `DistributedIterator`.

    Raises:
      RuntimeError: if called while neither executing eagerly nor building a
        function.
    """
    if not (context.executing_eagerly() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    # This is an optional flag that can be used to turn off using
    # OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators
    # as a stop gap solution that will allow us to roll out this change.
    enable_legacy_iterators = getattr(self._strategy,
                                      "_enable_legacy_iterators", False)
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers,
                                                    enable_legacy_iterators)
    if enable_legacy_iterators:
      iterator = DistributedIteratorV1(
          self._input_workers,
          worker_iterators,
          self._strategy,
          enable_get_next_as_optional=self._enable_get_next_as_optional)
    else:
      iterator = DistributedIterator(
          self._input_workers,
          worker_iterators,
          self._strategy,
          enable_get_next_as_optional=self._enable_get_next_as_optional)
    # The iterator receives the raw spec (not the rebatched-as-dynamic one
    # returned by `element_spec`).
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before passing to a multi device function, add a sync point
    # here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    # When partial batch handling is enabled, always set the batch dimension to
    # None, otherwise we just follow element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batching handling we could always produce empty batches.
    if (self._enable_get_next_as_optional and
        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return nest.map_structure(
          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
    return self._element_spec
1142
1143
class DistributedDatasetV1(DistributedDataset):
  """Distributed dataset that supports prefetching to multiple devices.

  TF1-compatible variant: adds one-shot and initializable iterator creation,
  and every iterator it returns is a `DistributedIteratorV1`.
  """

  def __init__(self, dataset, input_workers, strategy,
               num_replicas_in_sync=None, input_context=None):
    self._input_workers = input_workers
    super(DistributedDatasetV1, self).__init__(
        dataset, input_workers, strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context)

  def make_one_shot_iterator(self):
    """Get a one time use iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use `for ... in dataset:` to iterate
    over the dataset or `iter` to create an iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_one_shot_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for DistributedDatasetV1 (eager mode only)."""
    # One-shot iterators are refused in graph mode: there the iterator would
    # need an explicit `initialize` call, which only tf.distribute requires.
    if context.executing_eagerly():
      return self._get_iterator()
    raise ValueError("Cannot create a one shot iterator. Please use "
                     "`make_initializable_iterator()` instead.")

  def make_initializable_iterator(self):
    """Get an initializable iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use
    `tf.compat.v1.data.make_initializable_iterator(dataset)` to create an
    initializable iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_initializable_iterator()

  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=unused-argument
    """Get an initializable iterator for DistributedDatasetV1 (graph mode only)."""
    # Iterators created eagerly are already initialized, so an initializable
    # iterator makes no sense there.
    if not context.executing_eagerly():
      return self._get_iterator()
    raise ValueError("Cannot create initializable iterator in Eager mode. "
                     "Please use `iter()` instead.")

  def _get_iterator(self):
    """Builds a `DistributedIteratorV1` backed by legacy per-worker iterators."""
    per_worker_iterators = _create_iterators_per_worker(
        self._cloned_datasets, self._input_workers, True)
    dist_iterator = DistributedIteratorV1(
        self._input_workers, per_worker_iterators, self._strategy,
        self._enable_get_next_as_optional)
    dist_iterator._element_spec = self.element_spec  # pylint: disable=protected-access

    # Under async eager the per-worker iterators might still be initializing
    # when handed to a multi-device function; block until they are ready.
    if context.executing_eagerly():
      context.async_wait()

    return dist_iterator

  def __iter__(self):
    supported = (ops.executing_eagerly_outside_functions() or
                 ops.get_default_graph().building_function)
    if not supported:
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")
    return self._get_iterator()
1227
1228
1229# TODO(priyag): Add other replication modes.
class DistributedDatasetsFromFunction(_IterableInput):
  """Inputs created from dataset function."""

  def __init__(self, dataset_fn, input_workers, input_contexts, strategy,
               options):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset` given an `InputContext`.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `dataset_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Raises:
      ValueError: if the number of input workers differs from the number of
        input contexts.
    """
    super(DistributedDatasetsFromFunction, self).__init__(
        input_workers=input_workers)

    num_workers = input_workers.num_workers
    num_contexts = len(input_contexts)
    if num_workers != num_contexts:
      raise ValueError(
          "Number of input workers (%d) is not same as number of "
          "input_contexts (%d)" % (num_workers, num_contexts))

    self._input_workers = input_workers
    self._input_contexts = input_contexts
    self._strategy = strategy
    self._options = options
    per_worker_datasets, per_worker_spec = (
        _create_datasets_from_function_with_input_context(
            input_contexts, input_workers, dataset_fn))
    self._datasets = per_worker_datasets
    self._enable_get_next_as_optional = _enable_get_next_as_optional(
        strategy, per_worker_datasets[0])
    self._element_spec = _create_distributed_tensor_spec(
        strategy, per_worker_spec)

  def __iter__(self):
    """Returns a distributed iterator over the per-worker datasets.

    Raises:
      RuntimeError: when not executing eagerly (outside functions) and not
        building a function.
    """
    if not (ops.executing_eagerly_outside_functions() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    # Optional stop-gap flag: lets a strategy fall back from
    # OwnedMultiDeviceIterators to the legacy MultiDeviceIterators while this
    # change rolls out.
    use_legacy = getattr(self._strategy, "_enable_legacy_iterators", False)
    per_worker_iterators = _create_iterators_per_worker(
        self._datasets, self._input_workers, use_legacy, self._options)
    if use_legacy:
      dist_iterator = DistributedIteratorV1(
          self._input_workers,
          per_worker_iterators,
          self._strategy,
          enable_get_next_as_optional=self._enable_get_next_as_optional)
    else:
      dist_iterator = DistributedIterator(
          input_workers=self._input_workers,
          iterators=per_worker_iterators,
          strategy=self._strategy,
          enable_get_next_as_optional=self._enable_get_next_as_optional,
          options=self._options)
    dist_iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # Under async eager the per-worker iterators might still be initializing
    # when handed to a multi-device function; block until they are ready.
    if context.executing_eagerly():
      context.async_wait()

    return dist_iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    # With partial-batch handling in multi-worker mode an empty batch can
    # always be produced, so the batch dimension is reported as None; in every
    # other case the underlying dataset's spec is returned unchanged.
    dynamic_batch = (
        self._enable_get_next_as_optional and
        self._strategy.extended._in_multi_worker_mode())  # pylint: disable=protected-access
    if not dynamic_batch:
      return self._element_spec
    return nest.map_structure(
        _rebatch_as_dynamic, self._element_spec, expand_composites=False)
1320
1321
class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
  """Inputs created from dataset function.

  TF1-compatible variant: every iterator it creates is a
  `DistributedIteratorV1` built on legacy per-worker iterators.
  """

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetsFromFunctionV1.

    Args:
      shared_name: Unused.

    Returns:
      A `DistributedIteratorV1`.

    Raises:
      ValueError: if called while executing eagerly.
    """
    del shared_name  # Unused
    # Eager mode generates already initialized iterators. Hence we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1.

    Returns:
      A `DistributedIteratorV1`.

    Raises:
      ValueError: if called while not executing eagerly.
    """
    # Graph mode with one shot iterator is disabled because we have to call
    # `initialize` on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    """Creates a `DistributedIteratorV1` over the per-worker datasets."""
    # `True` selects legacy per-worker iterators. `self._options` is forwarded
    # exactly as `DistributedDatasetsFromFunction.__iter__` does on its legacy
    # path, so the user's `tf.distribute.InputOptions` are not silently
    # dropped for V1 iterators.
    iterators = _create_iterators_per_worker(self._datasets,
                                             self._input_workers, True,
                                             self._options)
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before passing to a multi device function, add a sync point
    # here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  def __iter__(self):
    """Returns a `DistributedIteratorV1`.

    Raises:
      RuntimeError: when not executing eagerly (outside functions) and not
        building a function.
    """
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")
1368
1369
1370# TODO(anjalisridhar): This class will be soon removed in favor of newer
1371# APIs.
class InputFunctionIterator(DistributedIteratorV1):
  """Iterator created from input function."""

  def __init__(self, input_fn, input_workers, input_contexts, strategy):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `input_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.

    Raises:
      ValueError: if the number of workers and input contexts differ, or if
        `input_fn` returns something that is neither a dataset nor a callable.
    """
    assert isinstance(input_workers, InputWorkers)
    num_workers = input_workers.num_workers
    num_contexts = len(input_contexts)
    if num_workers != num_contexts:
      raise ValueError(
          "Number of input workers (%d) is not same as number of "
          "input_contexts (%d)" % (num_workers, num_contexts))

    per_worker_iterators = []
    for worker_index, input_context in enumerate(input_contexts):
      worker_device = input_workers.worker_devices[worker_index]
      with ops.device(worker_device):
        # `input_fn` runs under the worker's device scope.
        produced = input_fn(input_context)
        compute_devices = input_workers.compute_devices_for_worker(worker_index)
        if isinstance(produced, dataset_ops.DatasetV2):
          worker_iterator = _SingleWorkerDatasetIterator(
              produced, worker_device, compute_devices)
        elif callable(produced):
          worker_iterator = _SingleWorkerCallableIterator(
              produced, worker_device, compute_devices)
        else:
          raise ValueError(
              "input_fn must return a tf.data.Dataset or a callable.")
        per_worker_iterators.append(worker_iterator)

    super(InputFunctionIterator, self).__init__(
        input_workers, per_worker_iterators, strategy,
        enable_get_next_as_optional=False)
    self._enable_get_next_as_optional = False
1417
1418
1419# TODO(anjalisridhar): This class will soon be removed and users should move
1420# to using DistributedIterator.
class DatasetIterator(DistributedIteratorV1):
  """Iterator created from input dataset."""

  def __init__(self, dataset, input_workers, strategy,
               num_replicas_in_sync=None, input_context=None):
    """Make an iterator for the dataset on given devices.

    If `num_replicas_in_sync` is not None, we split each batch of the dataset
    into `num_replicas_in_sync` smaller batches, to be distributed among that
    worker's replicas, so that the batch size for a global step (across all
    workers and replicas) is as expected.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      num_replicas_in_sync: Optional integer. If this is not None, the value is
        used to decide how to rebatch datasets into smaller batches so that the
        total batch size for each step (across all workers and replicas) adds up
        to `dataset`'s batch size.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
    """
    # Delegate cloning, rebatching and sharding to DistributedDatasetV1, then
    # build legacy per-worker iterators over the datasets it produced.
    source = DistributedDatasetV1(
        dataset, input_workers, strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context)
    # pylint: disable=protected-access
    per_worker_iterators = _create_iterators_per_worker(
        source._cloned_datasets, input_workers, True)
    super(DatasetIterator, self).__init__(
        input_workers, per_worker_iterators, strategy,
        source._enable_get_next_as_optional)
    # pylint: enable=protected-access
    self._element_spec = source.element_spec
1463
1464
def _dummy_tensor_fn(value_structure):
  """Builds empty stand-in values that mirror the specs in `value_structure`.

  Args:
    value_structure: A nested structure of `TypeSpec`s (dense, sparse or
      ragged tensor specs).

  Returns:
    A structure like `value_structure` where each spec is replaced by an empty
    value of the same dtype, with its (guessed) batch dimension set to 0.
  """

  def create_dummy_tensor(spec):
    """Returns one empty value for `spec` with its likely batch dim zeroed."""
    is_ragged = isinstance(spec, ragged_tensor.RaggedTensorSpec)
    if is_ragged:
      # pylint: disable=protected-access
      ragged_rank = spec._ragged_rank
      # Remove the ragged dimensions; they are added back below with size 0.
      feature_shape = spec._shape[:1].concatenate(
          spec._shape[(1 + ragged_rank):])
      feature_type = spec._dtype
      # pylint: enable=protected-access
    else:
      feature_shape, feature_type = spec.shape, spec.dtype

    # The batch dimension is not known here, so guess: unknown dimensions
    # become 0, and for fully defined (or ragged) shapes the leading dimension
    # is assumed to be the batch dimension and is forced to 0.
    dims = []
    if feature_shape:
      dims = [0 if dim is None else dim for dim in feature_shape.as_list()]
    if dims and (is_ragged or feature_shape.is_fully_defined()):
      dims[0] = tensor_shape.Dimension(0)

    if isinstance(spec, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensor(
          values=array_ops.zeros(0, feature_type),
          indices=array_ops.zeros((0, len(dims)), dtypes.int64),
          dense_shape=dims)

    dense_dummy = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
    if not is_ragged:
      return dense_dummy

    # Put the ragged dimensions back, each with a single zero row split.
    # pylint: disable=protected-access
    empty_splits = array_ops.zeros(1, spec._row_splits_dtype)
    nested_splits = (empty_splits,) * spec._ragged_rank
    # pylint: enable=protected-access
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        dense_dummy, nested_splits, validate=False)

  return nest.map_structure(create_dummy_tensor, value_structure)
1509
1510
def _recover_shape_fn(data, value_structure):
  """Restores the static shapes of `data` from `value_structure`.

  Each component tensor gets its shape set from the matching component spec.
  `SparseTensor`s are additionally rebuilt so that their `dense_shape` agrees
  with the spec's (possibly partially known) shape.

  Args:
    data: A nested structure of values matching `value_structure`.
    value_structure: A nested structure of `TypeSpec`s describing `data`.

  Returns:
    `data`, repacked into its original structure with shapes restored.
  """
  flat_values = nest.flatten(data)
  for index, spec in enumerate(nest.flatten(value_structure)):
    value = flat_values[index]
    components = nest.flatten(value, expand_composites=True)
    component_specs = nest.flatten(spec, expand_composites=True)
    for component, component_spec in zip(components, component_specs):
      component.set_shape(component_spec.shape)

    # A `SparseTensor`'s shape is determined by the values in its
    # `dense_shape`, not by the shapes of its component tensors.
    if not (isinstance(spec, sparse_tensor.SparseTensorSpec) and spec.shape):
      continue
    dense_shape = spec.shape
    with ops.device(value.op.device):
      if not dense_shape.is_fully_defined():
        # Take any dimension missing from the spec from the runtime tensor.
        dense_shape = array_ops.stack([
            value.dense_shape[j] if dim is None else dim
            for j, dim in enumerate(dense_shape.as_list())
        ])
      flat_values[index] = sparse_tensor.SparseTensor(
          indices=value.indices, values=value.values, dense_shape=dense_shape)
  return nest.pack_sequence_as(data, flat_values)
1537
1538
class _SingleWorkerDatasetIteratorBase(object):
  """Iterator for a single `tf.data.Dataset`.

  Holds the per-worker state shared by the concrete single-worker iterators.
  Subclasses must implement `_make_iterator`, which is expected to set
  `self._iterator`.
  """

  def __init__(self, dataset, worker, devices, options=None):
    """Create iterator for the `dataset` to fetch data to worker's `devices` .

    Stores the arguments and then calls `self._make_iterator()`, which
    subclasses implement to build the underlying iterator (for example a
    `MultiDeviceIterator` or `OwnedMultiDeviceIterator`).

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
      options: Optional input options object (or None). Its
        `experimental_replication_mode` and `experimental_prefetch_to_device`
        fields decide which iterator API is used and how results are shaped.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._element_spec = dataset.element_spec
    self._options = options
    self._make_iterator()

  def _make_iterator(self):
    # Subclasses must create the underlying iterator and store it on
    # `self._iterator`.
    raise NotImplementedError("must be implemented in descendants")

  def _format_data_list_with_options(self, data_list):
    """Wrap the data in a list if required.

    The OwnedMultiDeviceIterator returns its data as a list, while the
    PER_REPLICA iterator (when used with prefetch disabled) returns the data
    without an enclosing list. This method wraps the latter so that callers
    always get a list.

    Args:
      data_list: The value returned by the underlying iterator.

    Returns:
      `[data_list]` in PER_REPLICA mode without prefetch to device, otherwise
      `data_list` unchanged.
    """
    if (self._options and self._options.experimental_replication_mode ==
        InputReplicationMode.PER_REPLICA and
        not self._options.experimental_prefetch_to_device):
      return [data_list]
    else:
      return data_list

  def get_next(self, device, name=None):
    """Get next element for the given device.

    Args:
      device: The compute device to fetch for. It is forwarded to the
        underlying iterator only when a multi-device iterator is in use.
      name: Not used.

    Returns:
      The next element from the underlying iterator.
    """
    del name
    with ops.device(self._worker):
      if _should_use_multi_device_iterator(self._options):
        return self._iterator.get_next(device)
      else:
        return self._iterator.get_next()

  def get_next_as_list_static_shapes(self, name=None):
    """Get next element from the underlying iterator.

    Runs the iterator get_next() within a device scope. Since this doesn't use
    get_next_as_optional(), it is considerably faster than get_next_as_list()
    (but can only be used when the shapes are static).

    Args:
      name: not used.

    Returns:
      A list consisting of the next data from each device.
    """
    del name
    with ops.device(self._worker):
      return self._format_data_list_with_options(self._iterator.get_next())

  def get_next_as_list(self, name=None):
    """Get next element from underlying iterator.

    If there is no data left, a list of dummy tensors with possible batch
    dimensions set to 0 will be returned. Use of get_next_as_optional() and
    extra logic adds overhead compared to get_next_as_list_static_shapes(), but
    allows us to handle non-static shapes.

    Args:
      name: not used.

    Returns:
      A tuple `(worker_has_value, result)`: a boolean tensor indicating
      whether the first device's optional holds a value, and a list with, for
      each device, either the real data or dummy tensors if no data is left.
    """
    del name
    with ops.device(self._worker):
      data_list = self._format_data_list_with_options(
          self._iterator.get_next_as_optional())
      result = []
      for i, data in enumerate(data_list):
        # Place the condition op in the same device as the data so the data
        # doesn't need to be sent back to the worker.
        with ops.device(self._devices[i]):
          # Data will be fetched in order, so we only need to check if the first
          # replica has value to see whether there is data left for this single
          # worker.
          if i == 0:
            worker_has_value = data.has_value()

          # pylint: disable=unnecessary-lambda
          # pylint: disable=cell-var-from-loop
          # The lambdas capture the loop variable `data`; this is safe because
          # `cond` builds both branches immediately, within this iteration.
          real_data = control_flow_ops.cond(
              data.has_value(),
              lambda: data.get_value(),
              lambda: _dummy_tensor_fn(data.element_spec),
              strict=True,
          )
          # Some dimensions in `real_data` will become unknown after we
          # conditionally return the real tensors or the dummy tensors. Recover
          # the shapes from `data.element_spec`. We only need to do this in
          # non eager mode because we always know the runtime shape of the
          # tensors in eager mode.
          if not context.executing_eagerly():
            real_data = _recover_shape_fn(real_data, data.element_spec)
          result.append(real_data)
          # pylint: enable=cell-var-from-loop
          # pylint: enable=unnecessary-lambda

      return worker_has_value, result
1659
1660
class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
  """`TypeSpec` describing a `_SingleWorkerOwnedDatasetIterator`.

  The worker device, the canonicalized compute devices, the element spec and
  the input options are enough to rebuild the iterator from its single
  component: the underlying iterator.
  """

  __slots__ = ["_worker", "_devices", "_element_spec", "_options"]

  def __init__(self, worker, devices, element_spec, options):
    self._worker = worker
    self._devices = tuple(map(device_util.canonicalize, devices))
    self._element_spec = element_spec
    self._options = options

  @property
  def value_type(self):
    return _SingleWorkerOwnedDatasetIterator

  def _serialize(self):
    return (self._worker, self._devices, self._element_spec, self._options)

  @property
  def _component_specs(self):
    # The single component's spec depends on which iterator type the options
    # select (see `_should_use_multi_device_iterator`).
    if _should_use_multi_device_iterator(self._options):
      component_spec = multi_device_iterator_ops.MultiDeviceIteratorSpec(
          self._devices, self._worker, element_spec=self._element_spec)
    else:
      component_spec = iterator_ops.IteratorSpec(
          element_spec=self._element_spec)
    return [component_spec]

  def _to_components(self, value):
    underlying_iterator = value._iterator  # pylint: disable=protected-access
    return [underlying_iterator]

  def _from_components(self, components):
    return _SingleWorkerOwnedDatasetIterator(
        dataset=None,
        worker=self._worker,
        devices=self._devices,
        components=components,
        element_spec=self._element_spec,
        options=self._options)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    worker, devices = value._worker, value._devices
    element_spec, options = value._element_spec, value._options
    # pylint: enable=protected-access
    return _SingleWorkerDatasetIteratorSpec(worker, devices, element_spec,
                                            options)
1707
1708
class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
                                        composite_tensor.CompositeTensor):
  """Iterator for a DistributedDataset instance."""

  def __init__(self,
               dataset=None,
               worker=None,
               devices=None,
               components=None,
               element_spec=None,
               options=None):
    """Create iterator for the `dataset` to fetch data to worker's `devices` .

    `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
    given worker. The lifetime of this iterator is tied to the encompassing
    python object. Once we go out of scope of the python object or return from
    a tf.function the underlying iterator resource is deleted.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
      components: Tensor components to construct the
        _SingleWorkerOwnedDatasetIterator from.
      element_spec: A nested structure of `TypeSpec` objects that represents
        the type specification of elements of the iterator.
      options: `tf.distribute.InputOptions` used to control options on how
        this dataset is distributed.

    Raises:
      ValueError: if `worker` or `devices` is missing, or if neither `dataset`
        nor both `components` and `element_spec` are provided (or both forms
        are provided at once).
    """
    if worker is None or devices is None:
      raise ValueError("Both `worker` and `devices` should be provided")

    error_message = ("Either `dataset` or both `components` and `element_spec` "
                     "need to be provided.")

    self._options = options
    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      # Rebuilding from existing components (e.g. through the type spec): the
      # underlying iterator already exists, so `_make_iterator` is not called.
      # `_dataset` is still set so instances carry the same attributes as
      # those created by the base-class constructor.
      self._dataset = None
      self._element_spec = element_spec
      self._worker = worker
      self._devices = devices
      self._iterator = components[0]
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      super(_SingleWorkerOwnedDatasetIterator,
            self).__init__(dataset, worker, devices, options)

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    if not self._worker:
      raise ValueError("Worker device must be specified when creating an "
                       "owned iterator.")
    if _should_use_multi_device_iterator(self._options):
      # Input is read on the worker's host and prefetched onto each device.
      host_device = device_util.get_host_for_device(self._worker)
      with ops.device(self._worker):
        self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
            self._dataset, self._devices, source_device=host_device)
    else:
      with ops.device(self._worker):
        self._iterator = iter(self._dataset)

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
                                            self._element_spec, self._options)

  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)
1818
1819
class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase):
  """Non-owned iterator over one worker's dataset, used for `DistributedDatasetV1`.

  It always wraps a `MultiDeviceIterator` feeding the worker's devices.
  """

  def _make_iterator(self):
    """Builds a `MultiDeviceIterator` under the worker's device scope."""
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)

  def initialize(self):
    """(Re)initializes the underlying iterator.

    When executing eagerly outside functions, the iterator is reset right away
    and nothing is returned for the caller to run. Otherwise the iterator's
    initializer op is returned.

    Returns:
      A list (possibly empty) of initializer ops the caller should run.
    """
    if not ops.executing_eagerly_outside_functions():
      return [self._iterator.initializer]
    self._iterator._eager_reset()  # pylint: disable=protected-access
    return []

  @property
  def output_classes(self):
    """Legacy output classes of the underlying iterator."""
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    """Legacy output shapes of the underlying iterator."""
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    """Legacy output types of the underlying iterator."""
    return dataset_ops.get_legacy_output_types(self._iterator)
1856
1857
class _SingleWorkerCallableIterator(object):
  """Single-worker "iterator" that calls a tensor-returning function."""

  def __init__(self, fn, worker, devices):
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def get_next(self, device, name=None):
    """Calls the function once on the worker; `device` and `name` are unused."""
    del device, name
    with ops.device(self._worker):
      return self._fn()

  def get_next_as_list_static_shapes(self, name=None):
    """Calls the function once per device on the worker; returns the list."""
    del name
    with ops.device(self._worker):
      return [self._fn() for _ in self._devices]

  def get_next_as_list(self, name=None):
    """Like `get_next_as_list_static_shapes`, plus an always-true flag.

    Returns:
      A tuple of a constant `True` tensor (a callable never runs out of data)
      and the list of values, one per device.
    """
    del name
    with ops.device(self._worker):
      per_device_values = self.get_next_as_list_static_shapes()
      return constant_op.constant(True), per_device_values

  def initialize(self):
    # TODO(petebu) Should this throw an exception instead?
    return []
1889
1890
def _create_iterators_per_worker(worker_datasets,
                                 input_workers,
                                 enable_legacy_iterators,
                                 options=None):
  """Creates one single-worker iterator for each input worker.

  Args:
    worker_datasets: One dataset per worker, in worker order.
    input_workers: An `InputWorkers` object.
    enable_legacy_iterators: If true, always build the non-owned
      `_SingleWorkerDatasetIterator`, even when TF2 behavior is enabled.
    options: Optional input options forwarded to each iterator.

  Returns:
    A list of iterators, one per worker.
  """
  assert isinstance(input_workers, InputWorkers)
  assert len(worker_datasets) == len(input_workers.worker_devices)
  use_owned_iterators = tf2.enabled() and not enable_legacy_iterators
  iterators = []
  for index, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      compute_devices = input_workers.compute_devices_for_worker(index)
      if use_owned_iterators:
        iterator = _SingleWorkerOwnedDatasetIterator(
            dataset=worker_datasets[index],
            worker=worker,
            devices=compute_devices,
            options=options)
      else:
        iterator = _SingleWorkerDatasetIterator(worker_datasets[index], worker,
                                                compute_devices, options)
      iterators.append(iterator)
  return iterators
1913
1914
1915def _create_datasets_from_function_with_input_context(input_contexts,
1916                                                      input_workers,
1917                                                      dataset_fn):
1918  """Create device datasets per worker given a dataset function."""
1919  datasets = []
1920  for i, ctx in enumerate(input_contexts):
1921    worker = input_workers.worker_devices[i]
1922    with ops.device(worker):
1923      dataset = dataset_fn(ctx)
1924      datasets.append(dataset)
1925  return datasets, dataset.element_spec
1926
1927
# TODO(sourabhbajaj): Remove this in favor of distributed datasets
def _get_batched_dataset(d):
  """Walks back from `d` to its final `batch` / `map_and_batch` dataset.

  At each step a `DatasetV1Adapter` wrapper is unwrapped, and prefetch or
  options datasets are skipped by moving to their input dataset.

  Args:
    d: The dataset to inspect.

  Returns:
    The `BatchDataset` or `_MapAndBatchDataset` that was found.

  Raises:
    ValueError: if any other transformation comes before a batch is found.
  """
  # pylint: disable=protected-access
  current = d
  while True:
    if isinstance(current, dataset_ops.DatasetV1Adapter):
      current = current._dataset
    if isinstance(current,
                  (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
      return current
    if not isinstance(current, (dataset_ops.PrefetchDataset,
                                dataset_ops._OptionsDataset)):
      break
    current = current._input_dataset
  # pylint: enable=protected-access

  raise ValueError(
      "Unable to get batched dataset from the input dataset. `batch` "
      "`map_and_batch` need to be the last operations on the dataset. "
      "The batch operations can be followed by a prefetch.")
1945
1946
def _get_batched_dataset_attributes(d):
  """Returns `(batch_size, drop_remainder)` for the batched dataset `d`.

  When either attribute is a TF tensor, its constant value is returned
  instead (`None` if it cannot be determined statically).
  """
  # pylint: disable=protected-access
  assert isinstance(d,
                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
  if isinstance(d, dataset_ops.BatchDataset):
    batch_size, drop_remainder = d._batch_size, d._drop_remainder
  elif isinstance(d, batching._MapAndBatchDataset):
    batch_size, drop_remainder = d._batch_size_t, d._drop_remainder_t
  # pylint: enable=protected-access

  def _as_static(attribute):
    if tensor_util.is_tf_type(attribute):
      return tensor_util.constant_value(attribute)
    return attribute

  return _as_static(batch_size), _as_static(drop_remainder)
1967
1968
# TODO(sourabhbajaj): Remove this in favor of distributed datasets
def _get_dataset_attributes(dataset):
  """Returns `(batch_size, drop_remainder, prefetch_buffer)` of `dataset`.

  The batch attributes come from the batched dataset found by walking back
  through `dataset`. The prefetch buffer size is read only from the outermost
  dataset (looking through one `DatasetV1Adapter`); otherwise it is `None`.
  """
  # pylint: disable=protected-access
  batch_size, drop_remainder = _get_batched_dataset_attributes(
      _get_batched_dataset(dataset))

  prefetch_buffer = None
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    prefetch_buffer = dataset._buffer_size
  elif isinstance(dataset, dataset_ops.DatasetV1Adapter):
    wrapped = dataset._dataset
    if isinstance(wrapped, dataset_ops.PrefetchDataset):
      prefetch_buffer = wrapped._buffer_size
  # pylint: enable=protected-access

  return batch_size, drop_remainder, prefetch_buffer
1989
1990
1991def _should_use_multi_device_iterator(options):
1992  """Determine whether to use multi_device_iterator_ops."""
1993  if (options is None or
1994      options.experimental_replication_mode == InputReplicationMode.PER_WORKER
1995      or
1996      (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
1997       and options.experimental_prefetch_to_device)):
1998    return True
1999  return False
2000
2001
class MultiStepContext(object):
  """Collects values produced while several steps run in one call.

  Intended for the `experimental_run_steps_on_iterator` API: the user's step
  function uses it to say which outputs to emit. It currently supports
  capturing outputs of the last step (optionally reduced across replicas) and
  capturing non-tensor outputs. Other uses, such as emitting outputs every N
  steps, may be added later.
  """

  def __init__(self):
    """Creates an empty output context."""
    self._last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}
    self._non_tensor_outputs = {}

  @property
  def last_step_outputs(self):
    """Dictionary of outputs captured on the last step.

    Keys are the names given to `set_last_step_output`; values are the
    captured tensors, or their reduced values when a `reduce_op` was given.

    Returns:
      The dictionary of last step outputs.
    """
    return self._last_step_outputs

  def _set_last_step_outputs(self, outputs):
    """Replaces the whole dictionary of last step outputs with `outputs`."""
    if isinstance(outputs, dict):
      self._last_step_outputs = outputs
      return
    raise ValueError("Need a dictionary to set last_step_outputs.")

  def set_last_step_output(self, name, output, reduce_op=None):
    """Records `output` under `name` as an output of the last step.

    Args:
      name: String key for the output; it need not match the tensor's name.
      output: The value(s) to record. When `reduce_op` is set, this must be a
        value the current strategy's `reduce` method accepts (for example a
        `PerReplica` value under MirroredStrategy).
      reduce_op: How to reduce `output` across replicas. Required in a replica
        context, optional in a cross-replica context. The op is also stored in
        `_last_step_outputs_reduce_ops` so later code can tell whether an
        output is already reduced.
    """
    if not distribution_strategy_context.in_cross_replica_context():
      assert reduce_op is not None

      def merge_fn(distribution, value):
        reduced = distribution.reduce(reduce_op, value, axis=None)
        self._last_step_outputs[name] = reduced
        # All replicas share this one context object, so recording the reduce
        # op here, in the merge step, sets it once instead of per replica.
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))
      return

    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is not None:
      strategy = distribution_strategy_context.get_strategy()
      output = strategy.reduce(reduce_op, output, axis=None)
    self._last_step_outputs[name] = output

  @property
  def non_tensor_outputs(self):
    """Dictionary of non-tensor outputs captured so far."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Records `output` under `name` as a non-tensor output."""
    if distribution_strategy_context.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
      return

    def merge_fn(distribution, value):
      # Reduction is meaningless for non-tensor values, so all per-replica
      # values are kept as the strategy's local results.
      self._non_tensor_outputs[name] = (
          distribution.experimental_local_results(value))

    distribution_strategy_context.get_replica_context().merge_call(
        merge_fn, args=(output,))
2102
2103
def _create_distributed_tensor_spec(strategy, tensor_spec):
  """Builds the `tf.TypeSpec` of values a strategy produces for `tensor_spec`.

  Args:
    strategy: The `tf.distribute` strategy.
    tensor_spec: `tf.TensorSpec` of a value. Its batch dimension should be
      None if partial batches can occur.

  Returns:
    `tensor_spec` itself when the strategy does not wrap per-replica values,
    otherwise a matching structure of `PerReplicaSpec`s with one copy of the
    spec per replica.
  """
  # Single-device strategies that are not multi-worker don't wrap outputs in
  # `PerReplica`, so the spec is passed through unchanged.
  # TODO(b/166464552): remove after we always wrap for all strategies.
  if not _always_wrap(strategy):
    return tensor_spec

  replica_count = len(strategy.extended.worker_devices)

  def _to_per_replica_spec(spec):
    return values.PerReplicaSpec(*([spec] * replica_count))

  return nest.map_structure(_to_per_replica_spec, tensor_spec)
2131
2132
def _replace_per_replica_spec(spec, i):
  """Returns the `i`-th value spec of a `PerReplicaSpec`, else `spec` itself."""
  if not isinstance(spec, values.PerReplicaSpec):
    return spec
  return spec._value_specs[i]  # pylint: disable=protected-access
2139
2140
2141def _enable_get_next_as_optional(strategy, dataset):
2142  """Returns whether to enable using partial batch handling."""
2143  # TODO(b/133073708): we currently need a flag to control the usage because
2144  # there is a performance difference between get_next() and
2145  # get_next_as_optional(). And we only enable get_next_as_optional when the
2146  # output shapes are not static.
2147  #
2148  # TODO(rxsang): We want to always enable the get_next_as_optional behavior
2149  # when user passed input_fn instead of dataset.
2150  if not getattr(strategy.extended, "experimental_enable_get_next_as_optional",
2151                 False):
2152    return False
2153
2154  if context.executing_eagerly():
2155    # If the dataset is infinite, we don't need to enable last partial batch
2156    # support. Currently the logic only applies to the case that distributed
2157    # dataset is created in eager mode, as we need to evaluate the dataset
2158    # cardinality.
2159    with ops.device(dataset._variant_tensor.device):  # pylint: disable=protected-access
2160      if dataset.cardinality().numpy() == cardinality.INFINITE:
2161        return False
2162
2163  return not _is_statically_shaped(
2164      dataset.element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
2165
2166
def _create_per_replica(value_list, strategy):
  """Groups a list of per-replica values into a `PerReplica` structure.

  For strategies where `_always_wrap` is true, the values are wrapped even
  with a single replica, so the resulting type spec stays stable. This avoids
  retracing on partial batches, which matters for multi-client setups: a
  retrace changes the collective keys in the `tf.function` and clients that
  retrace at different times would then disagree.

  For the remaining single-client strategies this is plain
  `distribute_utils.regroup()`.

  Args:
    value_list: A list with one value per replica.
    strategy: The `tf.distribute.Strategy`.

  Returns:
    A structure of `PerReplica` values.
  """
  # TODO(b/166464552): always wrap for all one device strategies as well.
  return distribute_utils.regroup(
      value_list, always_wrap=_always_wrap(strategy))
2190
2191
2192def _always_wrap(strategy):
2193  """Returns whether to always wrap the values in a DistributedValues."""
2194  return strategy.extended._in_multi_worker_mode() or len(  # pylint: disable=protected-access
2195      strategy.extended.worker_devices) > 1
2196
2197
def _rebatch_as_dynamic(per_replica_spec):
  """Returns a copy of `per_replica_spec` whose batch dimensions are unknown.

  Each value spec is unbatched and rebatched with size `None`. Specs that
  cannot be rebatched (those raising `ValueError`) are kept unchanged.
  """
  assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec

  # pylint: disable=protected-access
  def _rebatch(spec):
    try:
      return spec._unbatch()._batch(None)
    except ValueError:
      return spec

  rebatched_specs = nest.map_structure(_rebatch,
                                       per_replica_spec._value_specs)
  # pylint: enable=protected-access
  return values.PerReplicaSpec(*rebatched_specs)
2214