1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various classes representing distributed inputs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.data.experimental.ops import batching
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.ops import multi_device_iterator_ops
24from tensorflow.python.data.util import structure
25from tensorflow.python.distribute import device_util
26from tensorflow.python.distribute import distribution_strategy_context
27from tensorflow.python.distribute import input_ops
28from tensorflow.python.distribute import values
29from tensorflow.python.eager import context
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import device as tf_device
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_util
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.util import nest
39
40
class InputWorkers(object):
  """Maps each input worker device to the compute devices it feeds."""

  def __init__(self, device_map, worker_device_pairs=None, logical_device=0):
    """Build the worker-to-compute-device mapping.

    Args:
      device_map: A `DeviceMap` describing the compute devices that the input
        workers feed.
      worker_device_pairs: A sequence of
        `(input device, tuple of compute devices fed by that input device)`
        pairs. When `None`, a single pair is used: the canonicalized
        "/device:CPU:0" feeding
        `device_map.logical_to_actual_devices(logical_device)`.
      logical_device: The logical device of `device_map` to feed.
    """
    self._device_map = device_map
    self._logical_device = logical_device

    if worker_device_pairs is None:
      default_worker = device_util.canonicalize("/device:CPU:0")
      default_fed = device_map.logical_to_actual_devices(logical_device)
      worker_device_pairs = ((default_worker, default_fed),)

    input_devices = []
    for input_device, _ in worker_device_pairs:
      input_devices.append(input_device)
    self._input_worker_devices = tuple(input_devices)

    fed_devices = []
    for _, compute_devices in worker_device_pairs:
      fed_devices.append(
          tuple(device_util.canonicalize(c) for c in compute_devices))
    self._fed_devices = tuple(fed_devices)

    # The compute devices, concatenated in worker order, must be exactly the
    # devices that `device_map` reports for `logical_device`.
    flattened = tuple(
        device for per_worker in self._fed_devices for device in per_worker)
    assert flattened == device_map.logical_to_actual_devices(logical_device), (
        "flattened: %s logical device %d: %s" % (
            flattened, logical_device,
            device_map.logical_to_actual_devices(logical_device)))

  @property
  def device_map(self):
    """The `DeviceMap` passed to the constructor."""
    return self._device_map

  @property
  def logical_device(self):
    """The logical device of `device_map` that is being fed."""
    return self._logical_device

  @property
  def num_workers(self):
    """Number of input worker devices."""
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    """Tuple of input worker devices, in the order they were given."""
    return self._input_worker_devices

  def compute_devices_for_worker(self, worker_index):
    """Returns the tuple of compute devices fed by worker `worker_index`."""
    return self._fed_devices[worker_index]

  def __repr__(self):
    """Lists every worker with its fed devices, followed by the device map."""
    lines = []
    for index, worker in enumerate(self.worker_devices):
      lines.append("  %d %s: %s" % (index, worker, self._fed_devices[index]))
    class_name = self.__class__.__name__
    return "%s:{\n%s\n  device_map: %s}" % (
        class_name, ",\n".join(lines), self._device_map)
96
97
class InputIterator(object):
  """Interface of input iterators passed to `DistributionStrategy.run`."""

  def get_next(self):
    """Returns the next inputs for all replicas.

    Raises:
      NotImplementedError: Always; descendants must override this method.
    """
    message = "must be implemented in descendants"
    raise NotImplementedError(message)

  def initialize(self):
    """Initializes the underlying input dataset, when applicable.

    In eager mode this creates a new iterator and returns it; in graph mode
    it initializes the same underlying iterator(s).

    Callers must invoke this when the iterator came from either
    - `make_input_fn_iterator` with an input function that returns a
      dataset, or
    - `make_dataset_iterator`.

    Returns:
      A list of initialization ops to be executed.

    Raises:
      NotImplementedError: Always; descendants must override this method.
    """
    message = "must be implemented in descendants"
    raise NotImplementedError(message)
120
121
class InputIteratorImpl(InputIterator):
  """Shared implementation used by all the concrete input iterators."""

  def __init__(self, input_workers, iterators):
    """Stores the per-worker iterators.

    Args:
      input_workers: An `InputWorkers` object; it must have at least one
        worker device.
      iterators: Per-worker iterators, indexed in the same order as
        `input_workers.worker_devices`.

    Raises:
      ValueError: If `input_workers` has no worker devices.
    """
    assert isinstance(input_workers, InputWorkers)
    if not input_workers.worker_devices:
      raise ValueError("Should have at least one worker for input iterator.")

    self._iterators = iterators
    self._input_workers = input_workers

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas.

    Args:
      name: Optional name prefix. When given, each worker's iterator is asked
        for its next element under the name `<name>_<job>_<task>`, where job
        and task are parsed from the worker device string.

    Returns:
      The per-replica elements, regrouped with `values.regroup` using the
      device map of the input workers.
    """
    input_workers = self._input_workers

    # Fetch, per worker, a "has value" flag and the list of per-replica
    # elements.
    has_value_per_worker = []
    elements_per_worker = []
    for worker_index, worker in enumerate(input_workers.worker_devices):
      element_name = None
      if name is not None:
        spec = tf_device.DeviceSpec.from_string(worker)
        element_name = "%s_%s_%d" % (name, spec.job, spec.task)
      with ops.device(worker):
        has_value, elements = (
            self._iterators[worker_index].get_next_as_list(element_name))
      has_value_per_worker.append(has_value)
      elements_per_worker.append(elements)

    out_of_range_replicas = []

    def out_of_range_fn(worker_index, device):
      """This function will throw an OutOfRange error."""
      # This is only called when there is no data left, so calling get_next()
      # will trigger an OutOfRange error.
      data = self._iterators[worker_index].get_next(device)
      out_of_range_replicas.append(data)
      return data

    # `global_has_value` indicates whether there is data in this global batch.
    # We do a all-reduce across all the workers in the multi-worker case.
    # TODO(b/126259107): Do strategy.reduce for CollectiveAllReduceStrategy.
    if len(has_value_per_worker) > 1:
      first_compute_device = input_workers.compute_devices_for_worker(0)[0]
      with ops.device(first_compute_device):
        # Place the tf.reduce_any op in device 0 to minimize communication
        # cost.
        # TODO(b/128545270): Investigate why placing it on worker 0 will cause
        # the entire data to copy back from device to host.
        global_has_value = math_ops.reduce_any(has_value_per_worker)
    else:
      global_has_value = has_value_per_worker[0]

    # Flat list of values across all replicas: the real element when the
    # global batch has data, otherwise the result of `out_of_range_fn`.
    results = []
    for worker_index, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        compute_devices = input_workers.compute_devices_for_worker(
            worker_index)
        for replica_index, device in enumerate(compute_devices):
          with ops.device(device):
            # pylint: disable=undefined-loop-variable
            # pylint: disable=cell-var-from-loop
            # It is fine for the lambdas to capture variables from the loop as
            # the lambdas are executed in the loop as well.
            results.append(
                control_flow_ops.cond(
                    global_has_value,
                    lambda: elements_per_worker[worker_index][replica_index],
                    lambda: out_of_range_fn(worker_index, device)))
            # pylint: enable=cell-var-from-loop
            # pylint: enable=undefined-loop-variable

    # Some dimensions in `results` will become unknown after we conditionally
    # return the real tensors or the dummy tensors. We fix the input shapes by
    # using the shapes from `out_of_range_replicas` because it is calling
    # get_next() inside.
    flat_results = nest.flatten(results)
    flat_out_of_range = nest.flatten(out_of_range_replicas)
    for position, out_of_range_tensor in enumerate(flat_out_of_range):
      flat_results[position].set_shape(out_of_range_tensor.get_shape())
    reshaped = nest.pack_sequence_as(results, flat_results)

    return values.regroup(input_workers.device_map, reshaped)

  def initialize(self):
    """Initializes the underlying iterators.

    Returns:
      A list with the initializer ops of every per-worker iterator, in worker
      order.
    """
    return [
        init_op
        for iterator in self._iterators
        for init_op in iterator.initialize()
    ]

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    """Output classes reported by the first worker's iterator."""
    first_iterator = self._iterators[0]
    return first_iterator.output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    """Output shapes reported by the first worker's iterator."""
    first_iterator = self._iterators[0]
    return first_iterator.output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    """Output types reported by the first worker's iterator."""
    first_iterator = self._iterators[0]
    return first_iterator.output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    """Returns the iterator of the first worker equal to `worker`, or None."""
    for index, candidate in enumerate(self._input_workers.worker_devices):
      if not worker == candidate:
        continue
      return self._iterators[index]
    return None
234
235
class InputFunctionIterator(InputIteratorImpl):
  """Input iterator built by calling an input function on every worker."""

  def __init__(self, input_fn, input_workers, input_contexts):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function called with an `InputContext`. It must return
        either a `tf.data.Dataset` or a callable.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `input_fn`. Length and order should match worker order in
        `worker_device_pairs`.

    Raises:
      ValueError: If the number of input contexts differs from the number of
        input workers, or if `input_fn` returns something that is neither a
        dataset nor a callable.
    """
    assert isinstance(input_workers, InputWorkers)
    num_contexts = len(input_contexts)
    if input_workers.num_workers != num_contexts:
      raise ValueError(
          "Number of input workers (%d) is not same as number of "
          "input_contexts (%d)" %
          (input_workers.num_workers, num_contexts))

    iterators = []
    for index, ctx in enumerate(input_contexts):
      worker = input_workers.worker_devices[index]
      with ops.device(worker):
        fn_output = input_fn(ctx)
        compute_devices = input_workers.compute_devices_for_worker(index)
        # Pick the single-worker iterator matching what `input_fn` produced.
        if isinstance(fn_output, dataset_ops.DatasetV2):
          iterator_cls = _SingleWorkerDatasetIterator
        elif callable(fn_output):
          iterator_cls = _SingleWorkerCallableIterator
        else:
          raise ValueError(
              "input_fn must return a tf.data.Dataset or a callable.")
        iterators.append(iterator_cls(fn_output, worker, compute_devices))

    super(InputFunctionIterator, self).__init__(input_workers, iterators)
277
278
class DatasetIterator(InputIteratorImpl):
  """Input iterator built from a single input dataset."""

  def __init__(self, dataset, input_workers, split_batch_by=None):
    """Make an iterator for the dataset on given devices.

    When `split_batch_by` is given, every batch of the dataset is "split" by
    that value: the input dataset is unbatched and then rebatched with the
    per replica batch size `global_batch_size // split_batch_by`.
    The currently supported datasets are those where
    `dataset.batch()` is the last operation on the dataset OR
    `dataset.apply(map_and_batch)` is the last operation on the dataset OR
    `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
    `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

    TODO(priyag): Support multi worker / host cases properly by cloning
    and sharding the dataset on each worker. Current setup will only work in
    some cases, such as in-graph multi worker GPU case. If the input pipeline
    has random shuffling (with a different seed on each worker), each worker
    will see random input from the same overall dataset in each step. Otherwise,
    each worker will see the same input in each step.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
    """
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      # pylint: disable=protected-access
      dataset = batching._RebatchDataset(dataset, split_batch_by)
      # pylint: enable=protected-access

    iterators = []
    for index, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        compute_devices = input_workers.compute_devices_for_worker(index)
        if context.executing_eagerly():
          # In eager mode the dataset is used as is.
          per_worker_dataset = dataset
        else:
          # In graph mode every worker gets its own clone of the dataset,
          # carrying the options of the original.
          # pylint: disable=protected-access
          per_worker_dataset = input_ops._clone_dataset(dataset)
          # pylint: enable=protected-access
          per_worker_dataset = per_worker_dataset.with_options(
              dataset.options())
        iterators.append(
            _SingleWorkerDatasetIterator(per_worker_dataset, worker,
                                         compute_devices))

    self._element_structure = dataset._element_structure  # pylint: disable=protected-access

    super(DatasetIterator, self).__init__(input_workers, iterators)
327
328
def _dummy_tensor_fn(value_structure):
  """Creates dummy (zero) tensors matching `value_structure`.

  Args:
    value_structure: A structure object exposing `_flat_shapes` and
      `_flat_types`; when it is a `structure.NestedStructure`, its
      `_nested_structure` is used to pack the result.

  Returns:
    Dummy tensors packed like `value_structure._nested_structure` when
    `value_structure` is a `NestedStructure`, otherwise the single dummy
    tensor built from the first flat shape/type.
  """

  def create_dummy_tensor(feature_shape, feature_type):
    """Create a dummy tensor with possible batch dimensions set to 0."""

    # Ideally we should set the batch dimension to 0, however as in
    # DistributionStrategy we don't know the batch dimension, we try to
    # guess it as much as possible. If the feature has unknown dimensions, we
    # will set them to 0. If the feature shape is already static, we guess the
    # first dimension as batch dimension and set it to 0.
    #
    # A shape of unknown rank reports `dims` as None; treat it as having no
    # known dimensions instead of failing while iterating over None. The
    # dummy tensor is then a scalar.
    if feature_shape.dims is None:
      known_dims = []
    else:
      known_dims = feature_shape.dims

    dims = [
        tensor_shape.Dimension(0) if dim.value is None else dim
        for dim in known_dims
    ]
    if feature_shape.is_fully_defined() and dims:
      dims[0] = tensor_shape.Dimension(0)

    # Create the dummy tensor.
    return array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)

  # pylint: disable=protected-access
  result = [
      create_dummy_tensor(feature_shape, feature_type)
      for feature_shape, feature_type in zip(value_structure._flat_shapes,
                                             value_structure._flat_types)
  ]

  if isinstance(value_structure, structure.NestedStructure):
    result = nest.pack_sequence_as(value_structure._nested_structure, result)
  else:
    result = result[0]
  # pylint: enable=protected-access

  return result
366
367
class _SingleWorkerDatasetIterator(object):
  """Iterates over one `tf.data.Dataset` on behalf of a single worker."""

  def __init__(self, dataset, worker, devices):
    """Create iterator for the `dataset` to fetch data to worker's `devices`.

    A `MultiDeviceIterator` is created over `dataset` for `devices`, under the
    `worker` device scope.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._make_iterator()

  def _make_iterator(self):
    """Creates the `MultiDeviceIterator` under the worker device scope."""
    with ops.device(self._worker):
      iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)
      self._iterator = iterator

  def get_next(self, device, name=None):
    """Get next element for the given device.

    Args:
      device: Device whose next element is requested from the underlying
        iterator.
      name: not used.

    Returns:
      Whatever the underlying iterator's `get_next(device)` returns.
    """
    del name
    with ops.device(self._worker):
      next_element = self._iterator.get_next(device)
    return next_element

  def get_next_as_list(self, name=None):
    """Get next element from underlying iterator.

    If there is no data left, a list of dummy tensors with possible batch
    dimensions set to 0 will be returned.

    Args:
      name: not used.

    Returns:
      A `(worker_has_value, elements)` pair: a boolean tensor indicating
      whether there is any data in the next element, and a list holding, per
      device, the real data or dummy tensors if no data is left.
    """
    del name
    with ops.device(self._worker):
      optionals = self._iterator.get_next_as_optional()
      elements = []
      for index, optional in enumerate(optionals):
        # Place the condition op in the same device as the data so the data
        # doesn't need to be sent back to the worker.
        with ops.device(self._devices[index]):
          # MultiDeviceIterator fetches data in order, so checking whether the
          # first replica has a value is enough to know whether there is data
          # left for this single worker.
          if index == 0:
            worker_has_value = optional.has_value()

          # pylint: disable=unnecessary-lambda
          # pylint: disable=cell-var-from-loop
          elements.append(
              control_flow_ops.cond(
                  optional.has_value(),
                  lambda: optional.get_value(),
                  lambda: _dummy_tensor_fn(optional.value_structure)))
          # pylint: enable=cell-var-from-loop
          # pylint: enable=unnecessary-lambda

      return worker_has_value, elements

  def initialize(self):
    """Initializes the underlying iterator.

    In eager execution the underlying iterator is reset through its
    `_eager_reset()` and there is nothing left to run. In graph execution the
    initializer op of the underlying iterator is returned.

    Returns:
      A list of any initializer ops that should be run.
    """
    if not context.executing_eagerly():
      return [self._iterator.initializer]

    self._iterator._eager_reset()  # pylint: disable=protected-access
    return []

  @property
  def output_classes(self):
    """Legacy output classes of the underlying iterator."""
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    """Legacy output shapes of the underlying iterator."""
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    """Legacy output types of the underlying iterator."""
    return dataset_ops.get_legacy_output_types(self._iterator)
466
467
class _SingleWorkerCallableIterator(object):
  """Wraps a tensor-returning callable as a single-worker iterator."""

  def __init__(self, fn, worker, devices):
    """Stores the callable, its worker device and the devices it feeds.

    Args:
      fn: Callable taking no arguments; it is invoked to produce each element.
      worker: Worker device under whose scope `fn` is called.
      devices: Devices to produce elements for; `get_next_as_list` calls `fn`
        once per entry.
    """
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def get_next(self, device, name=None):
    """Get next element for the given device from the callable.

    Args:
      device: not used.
      name: not used.

    Returns:
      The result of calling the wrapped callable under the worker device
      scope.
    """
    del device, name
    with ops.device(self._worker):
      element = self._fn()
    return element

  def get_next_as_list(self, name=None):
    """Get next element from the callable.

    Args:
      name: not used.

    Returns:
      A pair of a constant `True` tensor and a list with one result of the
      callable per device.
    """
    del name
    with ops.device(self._worker):
      elements = []
      for _ in self._devices:
        elements.append(self._fn())
      has_value = constant_op.constant(True)
      return has_value, elements

  def initialize(self):
    """Returns an empty list: there are no initializer ops to run."""
    # TODO(petebu) Should this throw an exception instead?
    return []
492
493
# TODO(sourabhbajaj): Remove this in favor of distributed datasets.
def _get_batched_dataset(d):
  """Finds the batched dataset underneath `d`.

  Starting at `d`, `DatasetV1Adapter` wrappers are unwrapped and prefetch /
  options datasets are skipped by following their input dataset, until a
  `BatchDataset` or `_MapAndBatchDataset` is reached.

  Args:
    d: The dataset to inspect.

  Returns:
    The first `BatchDataset` or `_MapAndBatchDataset` found.

  Raises:
    ValueError: If a dataset of any other kind is met before a batched
      dataset is found.
  """
  # pylint: disable=protected-access
  batched_types = (dataset_ops.BatchDataset, batching._MapAndBatchDataset)
  passthrough_types = (dataset_ops.PrefetchDataset,
                       dataset_ops._OptionsDataset)

  current = d
  while True:
    if isinstance(current, dataset_ops.DatasetV1Adapter):
      current = current._dataset

    if isinstance(current, batched_types):
      return current
    if not isinstance(current, passthrough_types):
      break
    current = current._input_dataset

  raise ValueError(
      "Unable to get batched dataset from the input dataset. `batch` "
      "`map_and_batch` need to be the last operations on the dataset. "
      "The batch operations can be followed by a prefetch.")
511
512
def _get_batched_dataset_attributes(d):
  """Reads `batch_size` and `drop_remainder` from a batched dataset.

  Args:
    d: A `BatchDataset` or `_MapAndBatchDataset`.

  Returns:
    A `(batch_size, drop_remainder)` tuple. Each entry that is a tensor is
    replaced by the result of `tensor_util.constant_value` on it.
  """
  # pylint: disable=protected-access
  assert isinstance(
      d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
  if isinstance(d, dataset_ops.BatchDataset):
    attributes = [d._batch_size, d._drop_remainder]
  elif isinstance(d, batching._MapAndBatchDataset):
    attributes = [d._batch_size_t, d._drop_remainder_t]
  # pylint: enable=protected-access

  # Replace tensor-valued attributes by their constant value, batch size
  # first and drop_remainder second.
  for index, attribute in enumerate(attributes):
    if tensor_util.is_tensor(attribute):
      attributes[index] = tensor_util.constant_value(attribute)

  return tuple(attributes)
533
534
# TODO(sourabhbajaj): Remove this in favor of distributed datasets.
def _get_dataset_attributes(dataset):
  """Get the underlying attributes from the dataset object.

  Args:
    dataset: The dataset to inspect.

  Returns:
    A `(batch_size, drop_remainder, prefetch_buffer)` tuple. The first two
    come from the batched dataset found underneath `dataset`;
    `prefetch_buffer` is the buffer size of `dataset` when it is a
    `PrefetchDataset` (directly or wrapped in a `DatasetV1Adapter`), and
    `None` otherwise.
  """
  # pylint: disable=protected-access

  # Walk back the dataset creation process to the batched version, which is
  # where batch_size and drop_remainder live.
  batch_size, drop_remainder = _get_batched_dataset_attributes(
      _get_batched_dataset(dataset))

  # The prefetch buffer, however, is read from the original dataset.
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    return batch_size, drop_remainder, dataset._buffer_size

  if (isinstance(dataset, dataset_ops.DatasetV1Adapter)
      and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
    return batch_size, drop_remainder, dataset._dataset._buffer_size

  return batch_size, drop_remainder, None
555
556
class MultiStepContext(object):
  """Captures outputs while several steps are run at a time.

  This context object is useful when running multiple steps at a time using the
  `experimental_run_steps_on_iterator` API. For e.g. it allows the user's step
  function to specify which outputs to emit at what frequency. Currently it
  supports capturing output from the last step, as well as capturing non tensor
  outputs.  In the future it will be augmented to support other use cases such
  as output each N steps.
  """

  def __init__(self):
    """Creates a context with no captured outputs."""
    self._last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}
    self._non_tensor_outputs = {}

  @property
  def last_step_outputs(self):
    """A dictionary consisting of outputs to be captured on last step.

    Keys are the names given to `set_last_step_output`; values are the
    corresponding tensors. When `set_last_step_output` was called with a
    `reduce_op` for an output, the stored value is the reduced value.

    Returns:
      A dictionary with last step outputs.
    """
    return self._last_step_outputs

  def _set_last_step_outputs(self, outputs):
    """Replace the entire dictionary of last step outputs.

    Args:
      outputs: The dictionary to use from now on.

    Raises:
      ValueError: If `outputs` is not a dictionary.
    """
    if isinstance(outputs, dict):
      self._last_step_outputs = outputs
      return
    raise ValueError("Need a dictionary to set last_step_outputs.")

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For e.g. if using MirroredStrategy and reduction is set, output
        must be a `PerReplica` value.
        The reduce method is also recorded in a dictionary
        `_last_step_outputs_reduce_ops` for later interpreting of the
        outputs as already reduced or not.
    """
    if not distribution_strategy_context.in_cross_replica_context():
      # Replica context: reduce inside a merge call.
      assert reduce_op is not None

      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value)
        # Setting this inside the `merge_fn` because all replicas share the same
        # context object, so it's more robust to set it only once (even if all
        # the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      replica_context = distribution_strategy_context.get_replica_context()
      replica_context.merge_call(merge_fn, args=(output,))
      return

    # Cross-replica context: record the reduce op, then store the output,
    # reduced when a reduce op was given.
    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is not None:
      strategy = distribution_strategy_context.get_strategy()
      self._last_step_outputs[name] = strategy.reduce(reduce_op, output)
    else:
      self._last_step_outputs[name] = output

  @property
  def non_tensor_outputs(self):
    """A dictionary consisting of any non tensor outputs to be captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non tensor output.

    Args:
      name: Key under which the output is stored in `non_tensor_outputs`.
      output: The value to capture. In a replica context the stored value is
        the result of `experimental_local_results` on it.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
      return

    def merge_fn(distribution, value):
      # NOTE(priyag): For non tensor outputs, we simply return all the values
      # in a list as reduction doesn't make sense on non tensors.
      local_results = distribution.experimental_local_results(value)
      self._non_tensor_outputs[name] = local_results

    replica_context = distribution_strategy_context.get_replica_context()
    replica_context.merge_call(merge_fn, args=(output,))
655