1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes for different algorithms of reduction and broadcasting."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import multiprocessing.dummy
24import multiprocessing.pool
25import threading
26
27import six
28
29from tensorflow.python.client import device_lib
30from tensorflow.python.distribute import collective_util
31from tensorflow.python.distribute import cross_device_utils
32from tensorflow.python.distribute import device_util
33from tensorflow.python.distribute import distribute_utils
34from tensorflow.python.distribute import ps_values
35from tensorflow.python.distribute import reduce_util
36from tensorflow.python.distribute import tpu_values
37from tensorflow.python.distribute import values as value_lib
38from tensorflow.python.distribute import values_util
39from tensorflow.python.eager import context
40from tensorflow.python.eager import def_function
41from tensorflow.python.framework import kernels
42from tensorflow.python.framework import ops
43from tensorflow.python.framework import tensor_util
44from tensorflow.python.ops import array_ops
45from tensorflow.python.ops import math_ops
46from tensorflow.python.ops import resource_variable_ops
47from tensorflow.python.platform import tf_logging as logging
48from tensorflow.python.util import nest
49from tensorflow.python.util.tf_export import tf_export
50from tensorflow.tools.docs import doc_controls
51
52
def check_destinations(destinations):
  """Returns whether `destinations` refers to at least one device.

  Args:
    destinations: a `DistributedValues`, variable, or string object.

  Returns:
    True if `destinations` is non-empty. Resource variables and tensors are
    judged by their `device` string; any other object by its truthiness.
  """
  # bool() is not allowed on a ResourceVariable, so for variables and tensors
  # look at the device string rather than the object itself.
  judged_by_device = isinstance(
      destinations, (resource_variable_ops.BaseResourceVariable, ops.Tensor))
  subject = destinations.device if judged_by_device else destinations
  return bool(subject)
67
68
def validate_destinations(destinations):
  """Raises if `destinations` has an unexpected type or is empty.

  Args:
    destinations: the candidate destinations of a cross-device operation.

  Raises:
    ValueError: if `destinations` is not a `DistributedValues`, tensor,
      `AggregatingVariable`, device string, `TPUMirroredVariable` or resource
      variable, or if `check_destinations` reports it as empty.
  """
  accepted_types = (value_lib.DistributedValues, ops.Tensor,
                    ps_values.AggregatingVariable, six.string_types,
                    tpu_values.TPUMirroredVariable)
  # The resource-variable probe only runs when the isinstance check fails.
  type_ok = (isinstance(destinations, accepted_types) or
             resource_variable_ops.is_resource_variable(destinations))
  if not type_ok:
    raise ValueError("destinations must be one of a `DistributedValues` object,"
                     " a tf.Variable object, or a device string.")

  if not check_destinations(destinations):
    raise ValueError("destinations can not be empty")
81
82
def reduce_non_distributed_value(
    reduce_op, value, destinations, num_replicas_in_graph):
  """Reduces a value that is not a `DistributedValues` to `destinations`.

  Args:
    reduce_op: a `reduce_util.ReduceOp`.
    value: a single (non-distributed) value, e.g. a tensor or Python number.
    destinations: where the result should be placed; only consulted when a
      broadcast is needed.
    num_replicas_in_graph: number of replicas in the current graph.

  Returns:
    `0` for a non-tensor value equal to 0; `value` itself for MEAN; otherwise
    the result of broadcasting `value` to `destinations`.

  Raises:
    ValueError: if `value` is a `DistributedValues`, or if a non-MEAN reduce
      is requested while there is more than one replica in the graph.
  """
  if isinstance(value, value_lib.DistributedValues):
    raise ValueError("You are passing a `DistributedValues` to "
                     "`reduce_non_distributed_value`, which is not allowed.")

  # When every replica holds the same value the PerReplica collapses to a
  # single value; a plain (non-tensor) zero short-circuits here.
  # TODO:(b/138823479): handle the tensor value properly.
  if not tensor_util.is_tf_type(value) and value == 0:
    return 0

  # The mean of identical values is the value itself on every destination.
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return value

  # Summing identical values across several replicas (as MirroredVariable
  # assign functions would) is not clearly defined, so refuse it.
  if num_replicas_in_graph != 1:
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))

  validate_destinations(destinations)
  return simple_broadcast(value, destinations)
110
111
def _make_tensor_into_per_replica(input_tensor):
  """Wraps a single tensor-like object into a `PerReplica`.

  Args:
    input_tensor: a `PerReplica` (returned unchanged) or any non-sequence
      object exposing a `device` attribute.

  Returns:
    A `value_lib.PerReplica`.

  Raises:
    ValueError: if `input_tensor` is a tuple or list, or lacks a `device`
      attribute.
  """
  if isinstance(input_tensor, (tuple, list)):
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, "
                     "got %r but expected a object that is not a tuple or list."
                     % (input_tensor,))
  if isinstance(input_tensor, value_lib.PerReplica):
    return input_tensor
  if not hasattr(input_tensor, "device"):
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
                     "because it doesn't have device set.")
  return value_lib.PerReplica((input_tensor,))
125
126
def _normalize_value_destination_pairs(value_destination_pairs):
  """Converts the value of each (value, destinations) pair into a PerReplica.

  Args:
    value_destination_pairs: an iterable of `(value, destinations)` tuples.
      Lists, tuples and one-shot iterators are all accepted.

  Returns:
    A list of `(PerReplica, destinations)` tuples in input order.

  Raises:
    ValueError: if `value_destination_pairs` is not iterable, if any element
      is not a tuple of size 2, or if a value cannot be converted into a
      `PerReplica` (see `_make_tensor_into_per_replica`).
  """
  # Check iterability up front. A type check placed after `list(...)` could
  # never fail, and a non-iterable would escape as the TypeError raised by
  # `list()` rather than the documented ValueError. Only `iter()` sits inside
  # the try, so TypeErrors raised while a generator runs are not masked.
  try:
    pairs_iter = iter(value_destination_pairs)
  except TypeError:
    raise ValueError("`value_destination_pairs` should be a list or tuple")

  result = []
  for pair in pairs_iter:
    if not isinstance(pair, tuple):
      raise ValueError(
          "Each element of `value_destination_pairs` should be a tuple.")
    if len(pair) != 2:
      raise ValueError("Each element of `value_destination_pairs` should be a "
                       "tuple of size 2.")
    value, destinations = pair
    result.append((_make_tensor_into_per_replica(value), destinations))
  return result
146
147
def _validate_value_destination_pairs(value_destination_pairs):
  """Returns whether `value_destination_pairs` is already in normalized form.

  Normalized form is a non-empty list or tuple whose elements are all tuples,
  and whose first element in each tuple is a `PerReplica`.

  Args:
    value_destination_pairs: the candidate sequence of pairs.

  Returns:
    True if the input is already normalized, False otherwise.
  """
  # TODO(yuefengz): raise exceptions instead of returning False.
  if not value_destination_pairs:
    return False
  if not isinstance(value_destination_pairs, (list, tuple)):
    return False
  # Every element must be a tuple before any of them is inspected further.
  for pair in value_destination_pairs:
    if not isinstance(pair, tuple):
      return False
  return all(isinstance(pair[0], value_lib.PerReplica)
             for pair in value_destination_pairs)
159
160
161# TODO(yuefengz): consider calling this function in the caller of
162# CrossDeviceOps.
def get_devices_from(destinations):
  """Returns the device strings that `destinations` refers to.

  Args:
    destinations: a `DistributedValues`, a device string, or any object with
      a `device` attribute (such as a tensor or variable).

  Returns:
    The `_devices` of a `DistributedValues`; otherwise a 1-tuple holding the
    device string resolved by `device_util.resolve`.
  """
  if isinstance(destinations, value_lib.DistributedValues):
    return destinations._devices  # pylint: disable=protected-access
  if isinstance(destinations, six.string_types):
    device_spec = destinations
  else:
    device_spec = destinations.device
  return (device_util.resolve(device_spec),)
169
170
def _devices_match(left, right):
  """Returns True if `left` and `right` span the same set of devices."""
  if left is right:
    return True
  left_devices = set(get_devices_from(left))
  right_devices = set(get_devices_from(right))
  return left_devices == right_devices
174
175
def _all_devices_match(value_destination_pairs):
  """Returns True if every value and destination share one device set.

  Two conditions must hold: each value lives on the same devices as its own
  destinations, and every value lives on the same devices as the first value.
  An empty sequence trivially matches.
  """
  if not value_destination_pairs:
    return True
  each_pair_matches = all(
      _devices_match(value, destination)
      for value, destination in value_destination_pairs)
  if not each_pair_matches:
    return False
  first_value = value_destination_pairs[0][0]
  return all(_devices_match(value, first_value)
             for value, _ in value_destination_pairs[1:])
183
184
def simple_broadcast(value, destinations, always_mirrored=False):
  """Broadcasts `value` to the devices of `destinations` by plain copies.

  Args:
    value: the tensor or IndexedSlices to copy.
    destinations: anything accepted by `get_devices_from`.
    always_mirrored: if True, wrap the result as `Mirrored` even when there is
      only one destination device.

  Returns:
    A single copy when there is exactly one device and `always_mirrored` is
    False; otherwise the per-device copies regrouped as `value_lib.Mirrored`.
  """
  devices = get_devices_from(destinations)
  if always_mirrored or len(devices) != 1:
    copies = [
        cross_device_utils.copy_tensor_or_indexed_slices_to_device(
            value, device) for device in devices
    ]
    return distribute_utils.regroup(copies, wrap_class=value_lib.Mirrored)
  return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
      value, devices[0])
198
199
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  """Aggregates all values of `per_replica_value` on `reduce_to_device`.

  Args:
    per_replica_value: an object whose `values` are the per-replica parts.
    reduce_to_device: device on which the aggregation ops are placed.
    accumulation_fn: function passed to
      `cross_device_utils.aggregate_tensors_or_indexed_slices`.
    reduce_op: `reduce_util.ReduceOp.SUM` or `reduce_util.ReduceOp.MEAN`.

  Returns:
    The aggregated value; for MEAN it is additionally divided by the number
    of values.

  Raises:
    ValueError: if there are no values, or `reduce_op` is neither SUM nor
      MEAN (checked after aggregation ops are built).
  """
  parts = per_replica_value.values
  if not parts:
    raise ValueError("`per_replica_value` must be non-empty")
  num_parts = len(parts)

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      total = cross_device_utils.aggregate_tensors_or_indexed_slices(
          parts, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        return cross_device_utils.divide_by_n_tensors_or_indexed_slices(
            total, num_parts)
      if reduce_op == reduce_util.ReduceOp.SUM:
        return total
      raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.")
218
219
def _simple_gather(per_replica_value, reduce_to_device, axis):
  """Concatenates every value of `per_replica_value` along `axis`.

  Args:
    per_replica_value: an object whose `values` are the per-replica parts.
    reduce_to_device: device on which the concat op is placed.
    axis: the axis passed to `array_ops.concat`.

  Returns:
    The concatenated tensor.

  Raises:
    ValueError: if `per_replica_value` has no values.
  """
  parts = per_replica_value.values
  if not parts:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      return array_ops.concat(parts, axis)
230
231
@tf_export("distribute.CrossDeviceOps")
class CrossDeviceOps(object):
  """Base class for cross-device reduction and broadcasting algorithms.

  The main purpose of this class is to be passed to
  `tf.distribute.MirroredStrategy` in order to choose among different cross
  device communication implementations. Prefer using the methods of
  `tf.distribute.Strategy` instead of the ones of this class.

  Implementations:
  * `tf.distribute.ReductionToOneDevice`
  * `tf.distribute.NcclAllReduce`
  * `tf.distribute.HierarchicalCopyAllReduce`
  """

  def __init__(self):
    pass

  @property
  def _num_between_graph_workers(self):
    # Returns 1 by default, the value may be overridden by sub classes.
    return 1

  def reduce(self, reduce_op, per_replica_value, destinations, options=None):
    """Reduce `per_replica_value` to `destinations`.

    See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
    the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, and this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    # `options` is defaulted exactly once, before any early return.
    if options is None:
      options = collective_util.Options()
    if not isinstance(per_replica_value, value_lib.DistributedValues):
      per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value that already
    # lives on the devices of `destinations`.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    return self.reduce_implementation(reduce_op, per_replica_value,
                                      destinations, options)

  def _gather(self, per_replica_value, destinations, axis, options=None):
    """Gather `per_replica_value` to `destinations`.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, and this method doesn't update the
        variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`

    Raises:
      NotImplementedError: if `per_replica_value` is an `IndexedSlices`.
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    if isinstance(per_replica_value, ops.IndexedSlices):
      raise NotImplementedError("gather/all_gather does not support "
                                "IndexedSlices")
    if options is None:
      options = collective_util.Options()

    if not isinstance(per_replica_value, value_lib.DistributedValues):
      per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value that already
    # lives on the devices of `destinations`.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    return self._gather_implementation(per_replica_value, destinations, axis,
                                       options)

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    """Implementation of `gather` method of `tf.distribute.CrossDeviceOps`.

    Overriding this method is useful for subclass implementers.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is gathered
        to the devices of that variable, this method doesn't update the
        variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "_gather method must be implemented in descendants.")

  def batch_reduce(self, reduce_op, value_destination_pairs, options=None):
    """Reduce values to destinations in batches.

    See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be
    called in the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `tf.distribute.CrossDeviceOps.reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    # `options` is defaulted exactly once, before any early return.
    if options is None:
      options = collective_util.Options()
    # TODO(yuefengz): if destinations are different, split into several
    # `_batch_reduce` invocations.
    if not _validate_value_destination_pairs(value_destination_pairs):
      # If the first element of each pair is a tensor, we try to turn it into a
      # PerReplica object.
      value_destination_pairs = _normalize_value_destination_pairs(
          value_destination_pairs)

    for _, d in value_destination_pairs:
      validate_destinations(d)

    # Shortcut if all PerReplica objects only contain one value. An empty batch
    # has no first value to inspect, so it skips the shortcut instead of
    # failing with an IndexError on `value_destination_pairs[0]`.
    if (value_destination_pairs and self._num_between_graph_workers == 1 and
        _all_devices_match(value_destination_pairs) and
        len(value_destination_pairs[0][0].values) == 1):
      return [
          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
          for v, _ in value_destination_pairs
      ]

    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
                                            options)

  def broadcast(self, tensor, destinations):
    """Broadcast `tensor` to `destinations`.

    This can only be called in the cross-replica context.

    Args:
      tensor: a `tf.Tensor` like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcasted to the devices of that variable, this method doesn't update
        the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    validate_destinations(destinations)
    return self.broadcast_implementation(tensor, destinations)

  @doc_controls.for_subclass_implementers
  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    """Implementation of `reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same to `value` and
        `destinations`. Note that if it's a `tf.Variable`, the value is reduced
        to the devices of that variable, this method doesn't update the
        variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "_reduce method must be implemented in descendants.")

  @doc_controls.for_subclass_implementers
  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    """Implementation of `batch_reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    raise NotImplementedError(
        "batch_reduce_implementation method must be implemented in descendants."
    )

  @doc_controls.for_subclass_implementers
  def broadcast_implementation(self, tensor, destinations):
    """Implementation of `broadcast`.

    Args:
      tensor: a `tf.Tensor` like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcasted to the devices of that variable, this method doesn't update
        the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    return simple_broadcast(tensor, destinations, always_mirrored=True)

  # ========================== Collective APIs ================================
  #
  # Unlike `reduce`, `batch_reduce` and `broadcast`, which must be called in
  # cross-replica context, collective APIs are to be called in replica
  # context.

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """All-reduce the `value` across all replicas so that all get the result.

    `value` can be a nested structure of tensors. The implementation should
    generally batch the all-reduces when possible. `options` can be set to
    hint the batching behavior.

    This API must be called in a replica context.

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined. Allows using string representation of the enum such as
        "SUM", "MEAN".
      value: Value to be reduced. A tensor or a nested structure of tensors.
      replica_id: An integer indicating the id of the replica where this
        all_reduce is called under. This is the local replica id that ranges
        from 0 to len(local_devices) - 1.
      options: A `tf.distribute.experimental.CommunicationOptions`.

    Returns:
      A tensor or a nested structure of tensors with the reduced values. The
      structure is the same as `value`.
    """
    raise NotImplementedError("_all_reduce must be implemented in descendants.")
556
557
@tf_export("distribute.ReductionToOneDevice")
class ReductionToOneDevice(CrossDeviceOps):
  """A CrossDeviceOps implementation that copies values to one device to reduce.

  This implementation always copies values to one device to reduce them, then
  broadcast reduced values to the destinations. It doesn't support efficient
  batching.

  Here is how you can use `ReductionToOneDevice` in
  `tf.distribute.MirroredStrategy`:

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.ReductionToOneDevice())
  ```
  """

  def __init__(self, reduce_to_device=None, accumulation_fn=None):
    """Initializes with a device to reduce to and a way to accumulate.

    Args:
      reduce_to_device: the intermediate device to reduce to. If None, reduce
        to the first device in `destinations` of the `reduce` method.
      accumulation_fn: a function that does accumulation.  If None,
        `tf.math.add_n` is used.
    """
    self.reduce_to_device = reduce_to_device
    self.accumulation_fn = accumulation_fn or math_ops.add_n
    super(ReductionToOneDevice, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    # Reduce on one device, then broadcast the result to `destinations`.
    del options  # Unused.
    # Empty destinations fall back to the devices of the value itself.
    device_source = (destinations if check_destinations(destinations)
                     else per_replica_value)
    devices = get_devices_from(device_source)
    target_device = self.reduce_to_device or devices[0]
    message = "Reduce to %s then broadcast to %r." % (target_device, devices)
    logging.log_first_n(logging.INFO, message, 10)
    reduced = _simple_reduce(per_replica_value, target_device,
                             self.accumulation_fn, reduce_op)
    return self.broadcast(reduced, destinations)

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    # Concatenate on one device, then broadcast the result to `destinations`.
    del options  # Unused.
    # Empty destinations fall back to the devices of the value itself.
    device_source = (destinations if check_destinations(destinations)
                     else per_replica_value)
    devices = get_devices_from(device_source)
    target_device = self.reduce_to_device or devices[0]
    message = "Gather to %s then broadcast to %r." % (target_device, devices)
    logging.log_first_n(logging.INFO, message, 10)
    gathered = _simple_gather(per_replica_value, target_device, axis)
    return self.broadcast(gathered, destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    # No real batching: each pair is reduced independently, in order.
    results = []
    for value, dest in value_destination_pairs:
      results.append(
          self.reduce_implementation(
              reduce_op, value, destinations=dest, options=options))
    return results
624
625
def _group_value_by_device(per_replica_values):
  """Group values into sublists by their devices.

  This grouping is needed to call the all-reduce library because it expects a
  list of the following form:
    [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
     [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
     [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...],
     ...
    ]

  Args:
    per_replica_values: a non-empty list of PerReplica objects that all share
      the same `_devices` (checked with an `assert`).

  Returns:
    A list with one sublist per device of the first PerReplica. Sublist `i`
    holds the `i`-th component of every PerReplica, in input order, each
    paired with None.
  """
  destinations = per_replica_values[0]._devices  # pylint: disable=protected-access
  grouped = [[] for _ in range(len(destinations))]
  for per_replica_value in per_replica_values:
    # The device layout must be identical (same devices, same order), or
    # components would be grouped under the wrong device. The check does not
    # depend on the individual component, so it runs once per PerReplica.
    # pylint: disable=protected-access
    assert per_replica_value._devices == destinations
    for i, v in enumerate(per_replica_value.values):
      grouped[i].append((v, None))
  return grouped
652
653
def _ungroup_and_make_mirrored(grouped_reduced,
                               destinations,
                               reduce_op,
                               num_between_graph_workers=1):
  """Turns per-device all-reduce results back into per-value Mirrored objects.

  This transposes the device-major layout produced by the all-reduce helpers
  into a value-major layout. When `reduce_op` is MEAN, each result is divided
  by the total replica count first: the number of devices in `destinations`
  times `num_between_graph_workers`.

  Args:
    grouped_reduced: a list of lists, one sublist per device. Each sublist holds
      `(reduced_value, _)` pairs, one per original value. It is the result
      from cross_device_utils.aggregate_gradients_using*.
    destinations: a value whose devices determine the replica count.
    reduce_op: Indicates how values will be aggregated. Accepted values
      are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
    num_between_graph_workers: number of workers in the between-graph
      replication.

  Returns:
    A list of Mirrored objects, one per original value.
  """
  num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers
  per_value_parts = [[] for _ in range(len(grouped_reduced[0]))]
  for device_results in grouped_reduced:
    for value_index, (reduced, _) in enumerate(device_results):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # Keep the division on the device that holds the reduced value.
        with ops.device(reduced.device):
          reduced = reduced / num_replicas
      per_value_parts[value_index].append(reduced)
  return [
      distribute_utils.regroup(parts, wrap_class=value_lib.Mirrored)
      for parts in per_value_parts
  ]
687
688
689class _ConcatAndSplitPacker(object):
690  """Concatenate and split tensors for reduction."""
691
692  def __init__(self, num_packs=1):
693    """Initialize the _ConcatAndSplitPacker object.
694
695    Args:
696      num_packs: specifies the number of split packs that will be
697        formed.
698
699    Raises:
700      ValueError: if num_packs is not greater than 0.
701    """
702    if num_packs <= 0:
703      raise ValueError("num_packs must be greater than zero.")
704    self.num_packs = num_packs
705
706  def pack(self, grouped_grads_and_vars):
707    """Pack tensors."""
708    self.grouped_grads_and_vars = grouped_grads_and_vars
709    self.all_device_shapes = []
710    self.all_device_sizes = []
711
712    device_grad_packs = []
713    for device_grads_and_vars in grouped_grads_and_vars:
714      with ops.colocate_with(device_grads_and_vars[0][0]):
715        # Flatten all the grads.
716        flat_grads = [
717            array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars
718        ]
719        # Remember the original shape of all the grads.
720        device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars]
721        # Remember the original sizes of all the grads.
722        device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars]
723        # Concat all the flat grads into a big flat tensor.
724        concat_grads = array_ops.concat(flat_grads, 0)
725
726        # Split the big tensor into num_splits packs. In cases where the
727        # total size is not divisible num_splits, the last pack gets
728        # more elements.
729        # TODO(zhengxq): it is also possible to optimize away all the concat
730        # as well.
731        num_splits = self.num_packs
732
733        # The array_ops.size function will sometimes remove static shapes. So if
734        # all gradient shapes are defined, we use another method to get the
735        # total size.
736        # TODO(yuefengz): move this logic to array_ops.size.
737        if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
738          total_grad_size = sum(
739              [g.shape.num_elements() for g, _ in device_grads_and_vars])
740        else:
741          total_grad_size = array_ops.size(concat_grads)
742
743        split_size = total_grad_size // num_splits
744        split_size_last = total_grad_size - split_size * (num_splits - 1)
745        split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
746        grad_packs = array_ops.split(concat_grads, split_sizes)
747
748        # Ready to aggregate the repacked gradients, with fake variables.
749        # TODO(zhengxq): It is hacky to have to use fake variables.
750        # We should remove the need for variables in
751        # aggregate_gradients_using*.
752        device_grad_packs.append(zip(grad_packs, [None] * num_splits))
753        self.all_device_shapes.append(device_shapes)
754        self.all_device_sizes.append(device_sizes)
755
756    return device_grad_packs
757
758  def unpack(self, summed_device_grad_packs):
759    """Reverse the pack."""
760    aggregated_device_grads = []
761    for (summed_device_grad_packs,
762         device_grads_and_vars, device_shapes, device_sizes) in zip(
763             summed_device_grad_packs, self.grouped_grads_and_vars,
764             self.all_device_shapes, self.all_device_sizes):
765      # pylint: enable=line-too-long
766      # Reverse the packing operations in the previous steps. Form the
767      # summed gradients back into their original shapes.
768      with ops.colocate_with(summed_device_grad_packs[0][0]):
769        # Form a list of the summed grad packs.
770        device_grad_packs = [g for g, _ in summed_device_grad_packs]
771
772        # Concat them back into a big flat tensor.
773        device_grads_concat = array_ops.concat(device_grad_packs, 0)
774
775        # Split the tensors back into their original sizes.
776        grads_with_sizes = array_ops.split(device_grads_concat, device_sizes)
777
778        # Reshape the tensors back into their original shapes.
779        grads_with_shapes = [
780            array_ops.reshape(grad, shape)
781            for shape, grad in zip(device_shapes, grads_with_sizes)
782        ]
783
784        # Form the list with the original list of variables.
785        summed_device_grads = [
786            (g, v) for g, (_, v) in zip(grads_with_shapes,
787                                        device_grads_and_vars)
788        ]
789        aggregated_device_grads.append(summed_device_grads)
790    return aggregated_device_grads
791
792
793def _pack_tensors(device_grads, num_packs=0):
794  """Pack tensors if specified."""
795  if num_packs > 0:
796    tensor_packer = _ConcatAndSplitPacker(num_packs)
797    device_grad_packs = tensor_packer.pack(device_grads)
798  else:
799    tensor_packer = None
800    device_grad_packs = device_grads
801  return device_grad_packs, tensor_packer
802
803
804def _unpack_tensors(reduced, tensor_packer=None):
805  """Unpack tensors if they are packed before all-reduce."""
806  if tensor_packer:
807    return tensor_packer.unpack(reduced)
808  return reduced
809
810
class AllReduceCrossDeviceOps(CrossDeviceOps):
  """All-reduce implementation of CrossDeviceOps.

  It performs all-reduce when applicable using NCCL or hierarchical copy. For
  the batch API, tensors will be repacked or aggregated for more efficient
  cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.
  """

  def __init__(self, all_reduce_alg="nccl", num_packs=1):
    """Initializes the object.

    Args:
      all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
        "hierarchical_copy" are supported. Any value other than "nccl" takes
        the hierarchical-copy path.
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.
    """
    self._all_reduce_alg = all_reduce_alg
    self._num_packs = num_packs
    # Single fallback instance shared by every non-all-reduce path
    # (mismatched/CPU destinations, IndexedSlices and gather).
    self._simple_cross_replica_ops = ReductionToOneDevice()
    super(AllReduceCrossDeviceOps, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    # To use NCCL or all-reduce, source and destination devices should match,
    # and none of the devices should be CPU.
    if (_devices_match(per_replica_value, destinations) and
        not any("cpu" in d.lower() for d in get_devices_from(destinations))):
      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
    else:
      return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
                                                   destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    # Batch all-reduce only when every pair's destinations match its value's
    # devices; otherwise reduce each pair individually.
    if _all_devices_match(value_destination_pairs):
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs])
    else:
      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values):
    """All-reduce algorithm in a batch.

    Dense values and IndexedSlices are reduced separately and the results are
    stitched back together in the original input order.
    """
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(((dense_results, dense_indices),
                                             (sparse_results, sparse_indices)))

  def _do_batch_all_reduce(self, reduce_op, dense_values):
    """Run batch all-reduces."""
    logging.log_first_n(
        logging.INFO,
        "batch_all_reduce: %d all-reduces with algorithm = %s, num_packs = %d" %
        (len(dense_values), self._all_reduce_alg, self._num_packs), 10)

    destinations = dense_values[0]._devices  # pylint: disable=protected-access
    grouped = _group_value_by_device(dense_values)

    # device_grad_packs:
    # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
    device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)

    # The actual aggregation of the repacked gradients. Note that they are
    # sharded among different aggregation trees. So it is important to strike
    # the balance on num_splits.
    if self._all_reduce_alg == "nccl":
      # TODO(yuefengz): merge this into the all-reduce library.
      reduced = cross_device_utils.aggregate_gradients_using_nccl(
          device_grad_packs)
    else:
      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
      # order.
      reduced = (
          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
              destinations, device_grad_packs))

    reduced = _unpack_tensors(reduced, tensor_packer)
    return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)

  def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
    """Run batch all-reduce for sparse values."""
    logging.log_first_n(
        logging.WARN,
        "Efficient allreduce is not supported for %d IndexedSlices" %
        len(sparse_values), 10)
    # Use `sparse_values` as destinations to do all-reduces. It is effectively
    # an allgather under the hood but not an efficient one. The pairs are
    # materialized as a list so they can safely be iterated more than once.
    return self._simple_cross_replica_ops.batch_reduce(
        reduce_op, list(zip(sparse_values, sparse_values)))

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    logging.warning("gather/all_gather with NCCL or HierarchicalCopy is not "
                    "supported. Falling back to gather on one device and "
                    "then broadcast. We're working on a more efficient "
                    "implementation.")
    # Reuse the same fallback instance as the reduce paths.
    return self._simple_cross_replica_ops._gather(  # pylint: disable=protected-access
        per_replica_value, destinations, axis, options)
925
926
# Legacy alias: `AllReduceCrossTowerOps` is the old name of
# `AllReduceCrossDeviceOps`, kept so that code using the old name still works.
AllReduceCrossTowerOps = AllReduceCrossDeviceOps
929
930
# Immutable record with the fields `alg`, `shards` and `limit`.
# NOTE(review): presumably an all-reduce spec (algorithm name, shard count,
# size limit); it is not used in this part of the file -- confirm with callers.
AllReduceSpecTuple = collections.namedtuple(
    "AllReduceSpecTuple", ["alg", "shards", "limit"])
933
934
@tf_export("distribute.NcclAllReduce")
class NcclAllReduce(AllReduceCrossDeviceOps):
  """NCCL all-reduce implementation of CrossDeviceOps.

  It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be
  repacked or aggregated for more efficient cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.

  Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`:

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.NcclAllReduce())
  ```
  """

  def __init__(self, num_packs=1):
    """Creates an all-reduce that uses the "nccl" algorithm.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError: if `num_packs` is negative.
    """
    if num_packs < 0:
      message = "NCCL all-reduce requires num_packs >= 0, but {} is specified"
      raise ValueError(message.format(num_packs))
    super(NcclAllReduce, self).__init__(
        num_packs=num_packs, all_reduce_alg="nccl")
970
971
@tf_export("distribute.HierarchicalCopyAllReduce")
class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
  """Hierarchical copy all-reduce implementation of CrossDeviceOps.

  It reduces to one GPU along edges in some hierarchy and broadcasts back to
  each GPU along the same path. For the batch API, tensors will be repacked or
  aggregated for more efficient cross-device transportation.

  This reduction was created for the Nvidia DGX-1 and assumes the GPUs are
  interconnected as on a DGX-1 machine. If your GPU interconnect is different,
  it is likely to be slower than `tf.distribute.ReductionToOneDevice`.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.

  Here is how you can use `HierarchicalCopyAllReduce` in
  `tf.distribute.MirroredStrategy`:

  ```
    strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
  ```
  """

  def __init__(self, num_packs=1):
    """Creates an all-reduce that uses the "hierarchical_copy" algorithm.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError: if `num_packs` is negative.
    """
    if num_packs < 0:
      template = ("HierarchicalCopy requires num_packs >= 0, "
                  "but {} is specified")
      raise ValueError(template.format(num_packs))
    super(HierarchicalCopyAllReduce, self).__init__(
        num_packs=num_packs, all_reduce_alg="hierarchical_copy")
1013
1014
# TODO(crccw): remove after migrating all callers.
# Both names re-export `collective_util.CommunicationImplementation` from this
# module. NOTE(review): presumably the TODO refers to the older
# `CollectiveCommunication` name -- confirm before removing either alias.
CollectiveCommunication = collective_util.CommunicationImplementation
CommunicationImplementation = collective_util.CommunicationImplementation
1018
1019
1020# TODO(yuefengz): support in-graph collective all-reduce.
class CollectiveAllReduce(CrossDeviceOps):
  """All-reduce cross device ops using collective ops.

  In the between-graph replicated training, it will still do all-reduces across
  all workers and then put results on the right destinations.
  """

  def __init__(self, devices, group_size, collective_keys=None):
    """Initializes the object.

    Args:
      devices: a list of device strings to run collectives on.
      group_size: the global group size. For between-graph replicated training
        it's the total number of devices across all workers.
      collective_keys: an optional CollectiveKey object.

    Raises:
      ValueError: if `devices` is empty or `group_size` is not divisible by
        the number of devices.
    """
    # Guard first: `len(devices)` is used as a divisor below.
    if not devices:
      raise ValueError("devices must not be empty.")
    if group_size % len(devices) > 0:
      raise ValueError("group_size must be divisible by the number of devices.")

    self._group_size = group_size
    self._collective_keys = (collective_keys or
                             cross_device_utils.CollectiveKeys())
    # This lock guards all collective launches, i.e. calls to
    # cross_device_utils.build_collective_*.
    #
    # In a multi threaded eager program we need to ensure different groups of
    # collectives don't interleave each other, otherwise there could be
    # deadlocks. E.g. if two user threads both are launching collectives:
    #   user-thread-0  device0                 device1
    #   user-thread-1          device0 device1
    # In eager mode, we use one thread per device to launch collective ops, so
    # the above launch sequences end up with the following queues:
    #   device-0  collective-0  collective-1
    #   device-1  collective-1  collective-0
    # This deadlocks since neither collective is able to finish.
    self._lock = threading.Lock()

    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    group_key = self._collective_keys.get_group_key(self._devices)
    self._launchers = []
    # Whether to only use NCCL for batched all-reduce when NCCL is requested.
    # This is because of the lack of mechanism to order NCCL operations
    # deterministically.
    self._limited_nccl = False
    for device in self._devices:
      launcher = cross_device_utils.CollectiveReplicaLauncher(
          group_key, group_size, self._collective_keys, device)
      self._launchers.append(launcher)
      if not launcher.can_order_nccl():
        self._limited_nccl = True

    # One thread per local device, used to launch collectives in eager mode.
    self._pool = multiprocessing.pool.ThreadPool(len(self._devices))

    super(CollectiveAllReduce, self).__init__()

  @property
  def _num_between_graph_workers(self):
    # Currently we only support equal number of devices on each worker.
    # `__init__` guarantees exact divisibility, so integer division is exact
    # and yields an int (true division would yield a float).
    return self._group_size // len(self._devices)

  def _all_reduce(self, reduce_op, value, replica_id, options):
    """Implements CrossDeviceOps.all_reduce.

    Args:
      reduce_op: a `reduce_util.ReduceOp`; MEAN divides by the group size.
      value: a nested structure of values for this replica (flattened with
        `nest`); IndexedSlices are not supported.
      replica_id: index of this replica's launcher in `self._launchers`.
      options: collective options (implementation, bytes_per_pack, timeout).

    Returns:
      The reduced values packed into the same structure as `value`.

    Raises:
      NotImplementedError: if the first flattened value is an IndexedSlices.
    """
    # TODO(b/122840926): reuse this method in _batch_all_reduce.
    flat_values = nest.flatten(value)

    if isinstance(flat_values[0], ops.IndexedSlices):
      raise NotImplementedError("all_reduce doesn't support IndexedSlices.")

    batch_size = len(flat_values)

    implementation = options.implementation.value
    # If NCCL launches can't be ordered (self._limited_nccl == True), we only
    # use NCCL when batch_size > 1, hoping that there's only one batched
    # all-reduce, which is the gradients.
    if (self._limited_nccl and
        options.implementation == CommunicationImplementation.NCCL and
        batch_size == 1):
      implementation = CommunicationImplementation.AUTO.value

    # Reverse the lists so that there's better chance that values follows
    # the order in which they are calculated (e.g. when they're gradients), so
    # as to overlap calculation with communication. However, this may not be
    # optimal for cases like gradients of complicated non-sequential models.
    #
    # Note that we reverse the list before packing so that the first pack won't
    # be too small, since it's more likely for first few packs to have long
    # queuing time due to concurrent intense computation.
    #
    # TODO(b/147393503): explore solutions for optimal ordering.
    flat_values.reverse()
    packs = cross_device_utils.group_by_size(flat_values,
                                             options.bytes_per_pack)

    launcher = self._launchers[replica_id]
    if not context.executing_eagerly() and replica_id == 0:
      logging.info(
          "Collective all_reduce: %d all-reduces, num_devices = %d, "
          "group_size = %d, implementation = %s, num_packs = %d", batch_size,
          len(self._launchers), self._group_size, implementation, len(packs))
    flat_results = launcher.batch_all_reduce(packs, implementation,
                                             options.timeout_seconds)

    if reduce_op == reduce_util.ReduceOp.MEAN:
      for i, v in enumerate(flat_results):
        flat_results[i] = v / self._group_size
    # Undo the reversal above to restore the input order.
    flat_results.reverse()

    return nest.pack_sequence_as(value, flat_results)

  def _copy_to_destinations(self, mirrored, devices):
    """Returns `mirrored` with one copy placed on each device in `devices`.

    A component already living on a destination device is copied from that
    component; otherwise the primary component is copied.

    Args:
      mirrored: a `value_lib.Mirrored` value.
      devices: the destination device strings.

    Returns:
      The copies regrouped via `distribute_utils.regroup` with `Mirrored` as
      the wrap class.
    """
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(mirrored.values):
      for d in devices:
        with ops.device(d):
          for v in mirrored.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            # No component lives on `d` (for-else: the loop did not break).
            # TODO(josh11b): Once we add support for model parallelism, get the
            # copy from the corresponding replica instead of the primary.
            index.append(array_ops.identity(mirrored._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    values_util.mark_as_unsaveable()
    all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
                                         options)[0]
    devices = get_devices_from(destinations)

    if _devices_match(per_replica_value, destinations):
      return all_reduced

    # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_reduced, value_lib.Mirrored):
      all_reduced = value_lib.Mirrored([all_reduced])

    # If we got this far, the destination devices do not match the all-reduce
    # devices, so we must map from one to the other.
    return self._copy_to_destinations(all_reduced, devices)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    values_util.mark_as_unsaveable()
    if _all_devices_match(value_destination_pairs):
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs],
                                    options)

    logging.log_first_n(
        logging.WARN, "Efficient batch_reduce is not supported if "
        "destinations are different.", 10)

    return [
        self.reduce_implementation(reduce_op, value, dest, options)
        for value, dest in value_destination_pairs
    ]

  def _batch_all_reduce(self, reduce_op, per_replica_values, options):
    """All reduce algorithm in a batch.

    Dense values and IndexedSlices are reduced separately and the results are
    stitched back together in the original input order.
    """
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values,
                                                      options)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values, options)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(
        ((dense_results, dense_indices), (sparse_results, sparse_indices)))

  def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values, options):
    """All-reduce across all workers in a batch.

    Returns:
      A list of values regrouped with `Mirrored` as the wrap class, in the same
      order as `per_replica_values`.
    """

    batch_size = len(per_replica_values)
    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1 since we don't have a way to
    # order NCCL launches. We're hoping that there's only one batched
    # all-reduce, which is the gradients.
    # TODO(b/132575814): switch to NCCL for all collectives when communication
    # is NCCL if and only if we can order collectives deterministically.
    if (self._limited_nccl and
        options.implementation == CommunicationImplementation.NCCL and
        batch_size == 1):
      implementation = CommunicationImplementation.AUTO.value

    # Reverse the lists so that there's better chance that values follows
    # the order in which they are calculated (e.g. when they're gradients), so
    # as to overlap calculation with communication. However, this may not be
    # optimal for cases like gradients of complicated non-sequential models.
    #
    # Note that we reverse the list before packing so that the first pack won't
    # be too small, since it's more likely for first few packs to have long
    # queuing time due to concurrent intense computation.
    #
    # TODO(b/147393503): explore solutions for optimal ordering.
    values_by_device = [[] for _ in range(len(self._devices))]
    for per_replica in reversed(per_replica_values):
      for i in range(len(self._devices)):
        values_by_device[i].append(per_replica.values[i])

    if context.executing_eagerly():
      # In eager mode launch each device's collectives from its own pool
      # thread (see the deadlock discussion in `__init__`).
      def thread_fn(device_id):
        with context.eager_mode():
          packs = cross_device_utils.group_by_size(values_by_device[device_id],
                                                   options.bytes_per_pack)
          return self._launchers[device_id].batch_all_reduce(
              packs, implementation, options.timeout_seconds)

      num_devices = len(self._devices)
      with self._lock:
        outputs_by_device = self._pool.map(thread_fn, list(range(num_devices)))
    else:
      outputs_by_device = []
      with self._lock:
        for i in range(len(self._devices)):
          packs = cross_device_utils.group_by_size(
              values_by_device[i], options.bytes_per_pack)
          if i == 0:
            logging.info(
                "Collective batch_all_reduce: %d all-reduces, num_devices = %d,"
                " group_size = %d, implementation = %s, num_packs = %d",
                batch_size, len(self._launchers), self._group_size,
                implementation, len(packs))
          outputs_by_device.append(self._launchers[i].batch_all_reduce(
              packs, implementation, options.timeout_seconds))

    mirrored = []
    for values in zip(*outputs_by_device):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        values = list(values)
        for i, v in enumerate(values):
          with ops.device(v.device):
            values[i] = v / self._group_size
      mirrored.append(
          distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
    # Reverse the order of reduced value to recover the order in the input.
    return list(reversed(mirrored))

  def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values, options):
    """All-reduce IndexedSlices across all workers in a batch."""

    logging.log_first_n(
        logging.INFO, "Collective batch_all_reduce for IndexedSlices: "
        "%d all-reduces, group_size = %d" %
        (len(per_replica_values), self._group_size), 10)

    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when implementation
    # is NCCL.
    if (self._limited_nccl and
        options.implementation == CommunicationImplementation.NCCL and
        len(per_replica_values) == 1):
      implementation = CommunicationImplementation.AUTO.value

    gathered_values = []
    with self._lock:
      for per_replica in per_replica_values:
        outputs = []
        for i in range(len(self._devices)):
          outputs.append(self._launchers[i].all_reduce_indexed_slices(
              per_replica.values[i], implementation, options.timeout_seconds))
        gathered_values.append(outputs)

    mirrored = []
    for value in gathered_values:
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # Assume each worker has the same number of replicas.
        for i, v in enumerate(value):
          with ops.device(v.device):
            # `IndexedSlices.values` is read-only, so build a new object with
            # the scaled values instead of assigning to the attribute.
            value[i] = ops.IndexedSlices(
                values=v.values / self._group_size,
                indices=v.indices,
                dense_shape=v.dense_shape)
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0]
    values_util.mark_as_unsaveable()
    devices = get_devices_from(destinations)

    if _devices_match(per_replica_value, destinations):
      return all_gathered

    # Convert `all_gathered` to a `Mirrored` object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_gathered, value_lib.Mirrored):
      all_gathered = value_lib.Mirrored([all_gathered])

    # If we got this far, the destination devices do not match the all-gather
    # devices, so we must map from one to the other. This places exactly one
    # copy per destination device.
    return self._copy_to_destinations(all_gathered, devices)

  def _batch_all_gather(self, per_replica_values, axis, options):
    """All-gather multiple per-replica values along `axis`.

    Returns:
      A list with one value per input, each regrouped with `Mirrored` as the
      wrap class.
    """
    batch_size = len(per_replica_values)
    # Pass options.implementation to the runtime as a communication
    # implementation hint.
    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when implementation
    # is NCCL.
    if (options.implementation == CommunicationImplementation.NCCL and
        batch_size == 1):
      implementation = CommunicationImplementation.AUTO.value

    logging.log_first_n(
        logging.INFO, "Collective batch_all_gather: %d all-gathers, "
        "num_devices = %d, group_size = %d, implementation = %s" %
        (batch_size, len(self._devices), self._group_size, implementation), 10)

    def compute_gathered_values():
      gathered_values = []
      with self._lock, ops.name_scope("allgather"):
        for per_replica in per_replica_values:
          outputs = []
          for i in range(len(self._devices)):
            outputs.append(self._launchers[i].all_gather(
                per_replica.values[i], axis, implementation,
                options.timeout_seconds))
          gathered_values.append(outputs)
      return gathered_values

    if context.executing_eagerly():
      gathered_values = def_function.function(compute_gathered_values)()
    else:
      gathered_values = compute_gathered_values()

    mirrored = []
    for value in gathered_values:
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveAllReduce needs to support deep copy as well.
    collective_keys = copy.deepcopy(self._collective_keys, memo)
    return CollectiveAllReduce(self._devices, self._group_size, collective_keys)
1378
1379
def select_cross_device_ops(devices, session_config=None):
  """Find the best `CrossDeviceOps` locally given a `tf.compat.v1.ConfigProto`.

  Args:
    devices: a list of devices passed to `tf.distribute.Strategy`.
    session_config: a `tf.compat.v1.ConfigProto` or `None`. If `None`, it will
      make decision based on all logical devices.

  Returns:
    An instance of a `CrossDeviceOps` subclass: `NcclAllReduce` when all
    requested devices are GPUs and an NCCL kernel is registered, otherwise
    `ReductionToOneDevice`.
  """
  requested_devices = set(device_util.canonicalize(d) for d in devices)
  if ops.executing_eagerly_outside_functions():
    logical_gpus = context.context().list_logical_devices(device_type="GPU")
    physical_gpus = context.context().list_physical_devices(device_type="GPU")
    # A mismatch between logical and physical GPU counts is treated as the
    # use of virtual GPUs, which NCCL does not support.
    if len(logical_gpus) != len(physical_gpus):
      logging.warning("NCCL is not supported when using virtual GPUs, falling "
                      "back to reduction to one device")
      return ReductionToOneDevice()

    machine_devices = context.context().list_logical_devices()
  else:
    machine_devices = device_lib.list_local_devices(
        session_config=session_config)

  # Track canonicalized names so the set difference below only reports the
  # requested devices that are genuinely missing (raw names would never match
  # the canonical names in `requested_devices`).
  using_devices = set()
  for d in machine_devices:
    canonical_name = device_util.canonicalize(d.name)
    if canonical_name in requested_devices:
      using_devices.add(canonical_name)

  missing_devices = requested_devices - using_devices
  if missing_devices:
    logging.warning(
        "Some requested devices in `tf.distribute.Strategy` are not visible "
        "to TensorFlow: %s", ",".join(sorted(missing_devices)))

  if any("gpu" not in d.lower() for d in requested_devices):
    logging.warning("There are non-GPU devices in `tf.distribute.Strategy`, "
                    "not using nccl allreduce.")
    return ReductionToOneDevice()

  if kernels.get_registered_kernels_for_op("NcclAllReduce"):
    return NcclAllReduce(num_packs=1)
  else:
    logging.warning("Nccl kernel is not found, not using nccl allreduce.")
    return ReductionToOneDevice()
1424