1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes for different algorithms of reduction and broadcasting."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import enum
23import six
24
25from tensorflow.python.client import device_lib
26from tensorflow.python.distribute import cross_device_utils
27from tensorflow.python.distribute import device_util
28from tensorflow.python.distribute import reduce_util
29from tensorflow.python.distribute import values as value_lib
30from tensorflow.python.eager import context
31from tensorflow.python.framework import ops
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import resource_variable_ops
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.util.tf_export import tf_export
37from tensorflow.tools.docs import doc_controls
38
39
40def check_destinations(destinations):
41  """Checks whether `destinations` is not empty.
42
43  Args:
44    destinations: a `DistributedValues`, variable, or string object.
45
46  Returns:
47    Boolean which is True if `destinations` is not empty.
48  """
49  # Calling bool() on a ResourceVariable is not allowed.
50  if isinstance(destinations, resource_variable_ops.ResourceVariable):
51    return bool(destinations.device)
52  return bool(destinations)
53
54
55def validate_destinations(destinations):
56  if not isinstance(destinations,
57                    (value_lib.DistributedValues,
58                     resource_variable_ops.ResourceVariable,
59                     value_lib.AggregatingVariable,
60                     six.string_types,
61                     value_lib.TPUMirroredVariable,
62                     # LogicalDeviceSpec is only used internally, e.g. as a
63                     # broadcast destination, never supplied by a user.
64                     value_lib.LogicalDeviceSpec)):
65    raise ValueError("destinations must be one of a `DistributedValues` object,"
66                     " a tf.Variable object, or a device string.")
67
68  if not check_destinations(destinations):
69    raise ValueError("destinations can not be empty")
70
71
72def reduce_non_distributed_value(reduce_op, device_map, value, destinations):
73  """Reduce a non-DistributedValue `value` to `destinations`."""
74  if isinstance(value, value_lib.DistributedValues):
75    raise ValueError("You are passing a `DistributedValue` to "
76                     "`reduce_non_distributed_value`, which is not allowed.")
77
78  # If the same value is present on all replicas then the PerReplica value will
79  # be a single value. We also handle the case when `value` is a single value
80  # and equal to 0.
81  if value == 0:
82    return 0
83  # If there is only a single value and the reduce op is MEAN,
84  # that value should be on all destinations.
85  if reduce_op == reduce_util.ReduceOp.MEAN:
86    return value
87
88  validate_destinations(destinations)
89  # We do not support a reduce op of SUM if the value is the same across
90  # all replicas. We call this as part of assign functions for MirroredVariables
91  # and summing up identical values across replicas is not clearly defined.
92  if device_map.num_replicas_in_graph != 1:
93    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
94                     "the given reduce op %s." % (value, reduce_op))
95  return simple_broadcast(value, destinations)
96
97
98def _make_tensor_into_per_replica(input_tensor):
99  """Converts a single tensor into a PerReplica object."""
100  if isinstance(input_tensor, (tuple, list)):
101    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, "
102                     "got %r but expected a object that is not a tuple or list."
103                     % (input_tensor,))
104  if isinstance(input_tensor, value_lib.PerReplica):
105    return input_tensor
106
107  try:
108    device = input_tensor.device
109  except AttributeError:
110    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
111                     "because it doesn't have device set.")
112
113  device_map = value_lib.SingleDeviceMap(device)
114  return value_lib.PerReplica(device_map, (input_tensor,))
115
116
117def _normalize_value_destination_pairs(value_destination_pairs):
118  """Converts each tensor into a PerReplica object in the input list."""
119  result = []
120
121  value_destination_pairs = list(value_destination_pairs)
122
123  if not isinstance(value_destination_pairs, (list, tuple)):
124    raise ValueError("`value_destination_pairs` should be a list or tuple")
125  for pair in value_destination_pairs:
126    if not isinstance(pair, tuple):
127      raise ValueError(
128          "Each element of `value_destination_pairs` should be a tuple.")
129    if len(pair) != 2:
130      raise ValueError("Each element of `value_destination_pairs` should be a "
131                       "tuple of size 2.")
132
133    per_replica = _make_tensor_into_per_replica(pair[0])
134    result.append((per_replica, pair[1]))
135  return result
136
137
138def _validate_value_destination_pairs(value_destination_pairs):
139  # TODO(yuefengz): raise exceptions instead of returning False.
140  # pylint: disable=g-missing-docstring
141  if not value_destination_pairs: return False
142  if not isinstance(value_destination_pairs, (list, tuple)): return False
143  if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
144    return False
145  if not all(isinstance(v[0], value_lib.PerReplica)
146             for v in value_destination_pairs):
147    return False
148  return True
149
150
151# TODO(yuefengz): consider calling this function in the caller of
152# CrossDeviceOps.
153def get_devices_from(destinations):
154  if isinstance(destinations, value_lib.DistributedValues):
155    return destinations.devices
156  elif isinstance(destinations, value_lib.LogicalDeviceSpec):
157    return destinations.device_map.logical_to_actual_devices(
158        destinations.logical_device)
159  elif isinstance(destinations, six.string_types):
160    return (device_util.resolve(destinations),)
161  return (destinations.device,)
162
163
164def get_device_map_from(destinations):
165  if isinstance(destinations, (value_lib.DistributedValues,
166                               value_lib.LogicalDeviceSpec)):
167    return destinations.device_map, destinations.logical_device
168  if isinstance(destinations, six.string_types):
169    device = device_util.resolve(destinations)
170  else:
171    device = destinations.device
172  return value_lib.SingleDeviceMap(device), 0
173
174
175def _devices_match(left, right):
176  return set(get_devices_from(left)) == set(get_devices_from(right))
177
178
179def _all_devices_match(value_destination_pairs):
180  if not all(_devices_match(v, d) for v, d in value_destination_pairs):
181    return False
182  if not all(_devices_match(v, value_destination_pairs[0][0])
183             for v, _ in value_destination_pairs[1:]):
184    return False
185  return True
186
187
188def simple_broadcast(value, destinations, always_mirrored=False):
189  """Broadcast `value` to `destinations` using simple copies."""
190  device_map, logical_device = get_device_map_from(destinations)
191  devices = device_map.logical_to_actual_devices(logical_device)
192  if len(devices) == 1 and not always_mirrored:
193    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
194        value, devices[0])
195  else:
196    value_updates = []
197    for d in devices:
198      value_updates.append(
199          cross_device_utils.copy_tensor_or_indexed_slices_to_device(
200              value, d))
201    return value_lib.Mirrored(device_map, value_updates, logical_device)
202
203
204def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
205                   reduce_op):
206  # pylint: disable=g-missing-docstring
207  all_values = per_replica_value.values
208  if not all_values:
209    raise ValueError("`per_replica_value` must be non-empty")
210  count = len(all_values)
211
212  with ops.device(reduce_to_device):
213    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
214      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
215          all_values, accumulation_fn)
216      if reduce_op == reduce_util.ReduceOp.MEAN:
217        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
218            reduced, count)
219      elif reduce_op != reduce_util.ReduceOp.SUM:
220        raise ValueError("`reduce_op` must be Reduce.SUM or Reduce.MEAN.")
221  return reduced
222
223
224@tf_export("distribute.CrossDeviceOps")
225class CrossDeviceOps(object):
226  """Base class for cross-device reduction and broadcasting algorithms."""
227
228  def __init__(self):
229    pass
230
231  def reduce(self, reduce_op, per_replica_value, destinations):
232    """Reduce `per_replica_value` to `destinations`.
233
234    It runs the reduction operation defined by `reduce_op` and put the
235    result on `destinations`.
236
237    Args:
238      reduce_op: Indicates how per_replica_value will be reduced. Accepted
239        values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
240      per_replica_value: a PerReplica object or a tensor with device set.
241      destinations: the reduction destinations.
242
243    Returns:
244      a Mirrored object.
245
246    Raises:
247      ValueError: if per_replica_value can't be converted to a PerReplica
248        object.
249    """
250    if not isinstance(per_replica_value, value_lib.PerReplica):
251      per_replica_value = _make_tensor_into_per_replica(per_replica_value)
252
253    validate_destinations(destinations)
254    return self.reduce_implementation(reduce_op, per_replica_value,
255                                      destinations)
256
257  def batch_reduce(self, reduce_op, value_destination_pairs):
258    """Reduce PerReplica objects in a batch.
259
260    Reduce each first element in `value_destination_pairs` to each second
261    element which indicates the destinations.
262
263    Args:
264      reduce_op: Indicates how per_replica_value will be reduced. Accepted
265        values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
266      value_destination_pairs: a list or a tuple of tuples of PerReplica objects
267        (or tensors with device set if there is one device) and destinations.
268
269    Returns:
270      a list of Mirrored objects.
271
272    Raises:
273      ValueError: if `value_destination_pairs` is not a list or a tuple of
274        tuples of PerReplica objects and destinations
275    """
276    # TODO(yuefengz): if destinations are different, split into several
277    # `_batch_reduce` invocations.
278    if not _validate_value_destination_pairs(value_destination_pairs):
279      # If the first element of each pair is a tensor, we try to turn it into a
280      # PerReplica object.
281      value_destination_pairs = _normalize_value_destination_pairs(
282          value_destination_pairs)
283
284    for _, d in value_destination_pairs:
285      validate_destinations(d)
286
287    return self.batch_reduce_implementation(reduce_op, value_destination_pairs)
288
289  def broadcast(self, tensor, destinations):
290    """Broadcast the `tensor` to destinations.
291
292    Args:
293      tensor: the tensor to broadcast.
294      destinations: the broadcast destinations.
295
296    Returns:
297      a Mirrored object.
298    """
299    validate_destinations(destinations)
300    return self.broadcast_implementation(tensor, destinations)
301
302  @doc_controls.for_subclass_implementers
303  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
304    """The implementation of reduce of `per_replica_value` to `destinations`.
305
306    It runs the reduction operation defined by `reduce_op` and put the
307    result on `destinations`.
308
309    Args:
310      reduce_op: Indicates how per_replica_value will be reduced. Accepted
311        values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
312      per_replica_value: a PerReplica object or a tensor with device set.
313      destinations: the reduction destinations.
314
315    Returns:
316      a Mirrored object.
317
318    Raises:
319      ValueError: if per_replica_value can't be converted to a PerReplica
320        object.
321    """
322    raise NotImplementedError(
323        "_reduce method must be implemented in descendants.")
324
325  @doc_controls.for_subclass_implementers
326  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
327    """Implementation of reduce PerReplica objects in a batch.
328
329    Reduce each first element in `value_destination_pairs` to each second
330    element which indicates the destinations.
331
332    Args:
333      reduce_op: Indicates how per_replica_value will be reduced. Accepted
334        values are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
335      value_destination_pairs: a list or a tuple of tuples of PerReplica objects
336        (or tensors with device set if there is one device) and destinations.
337
338    Returns:
339      a list of Mirrored objects.
340
341    Raises:
342      ValueError: if `value_destination_pairs` is not a list or a tuple of
343        tuples of PerReplica objects and destinations
344    """
345    raise NotImplementedError(
346        "_batch_reduce method must be implemented in descendants.")
347
348  @doc_controls.for_subclass_implementers
349  def broadcast_implementation(self, tensor, destinations):
350    """Implementation of broadcast the `tensor` to destinations.
351
352    Args:
353      tensor: the tensor to broadcast.
354      destinations: the broadcast destinations.
355
356    Returns:
357      a Mirrored object.
358    """
359    return simple_broadcast(tensor, destinations, always_mirrored=True)
360
361
362@tf_export("distribute.ReductionToOneDevice")
363class ReductionToOneDevice(CrossDeviceOps):
364  """Always do reduction to one device first and then do broadcasting.
365
366    Batch reduction is done by reduction on each element one by one.
367  """
368
369  def __init__(self, reduce_to_device=None, accumulation_fn=None):
370    """Constructor.
371
372    Args:
373      reduce_to_device: the intermediate device to reduce to. If None, reduce
374        to the first device in `destinations` of the reduce() method.
375      accumulation_fn: a function that does accumulation.  If None, then
376        `tf.math.add_n` is used.
377    """
378    self.reduce_to_device = reduce_to_device
379    self.accumulation_fn = accumulation_fn or math_ops.add_n
380    super(ReductionToOneDevice, self).__init__()
381
382  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
383    if check_destinations(destinations):
384      devices = get_devices_from(destinations)
385    else:
386      devices = get_devices_from(per_replica_value)
387    reduce_to_device = self.reduce_to_device or devices[0]
388    logging.log_first_n(
389        logging.INFO,
390        "Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10)
391    reduced = _simple_reduce(per_replica_value, reduce_to_device,
392                             self.accumulation_fn, reduce_op)
393    return self.broadcast(reduced, destinations)
394
395  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
396    return [
397        self.reduce_implementation(reduce_op, t, destinations=v)
398        for t, v in value_destination_pairs
399    ]
400
401
402def _group_value_by_device(per_replica_values):
403  """Group values into sublists by their devices.
404
405  This grouping is needed to call the all-reduce library because it expects a
406  list of the following form:
407    [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
408     [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
409     [(grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...],
410     ...
411    ]
412
413  Args:
414    per_replica_values: a list of PerReplica obejcts.
415
416  Returns:
417    a list of lists, each sublist has components for its corresponding device of
418      PerReplica objects, paired with a None.
419  """
420  destinations = per_replica_values[0].devices
421  grouped = [[] for _ in range(len(destinations))]
422  for per_replica_value in per_replica_values:
423    # pylint: disable=protected-access
424    for i, v in enumerate(per_replica_value.values):
425      assert per_replica_value.devices == destinations
426      grouped[i].append((v, None))
427  return grouped
428
429
430def _ungroup_and_make_mirrored(grouped_reduced,
431                               destinations,
432                               reduce_op,
433                               num_between_graph_workers=1):
434  """Ungroup results from all-reduce and make Mirrored objects.
435
436  Each all-reduce result will be divided by the number of destinations before
437  Mirrored objects are created if reduce_op is "mean".
438
439  Args:
440    grouped_reduced: a list of lists, each sublist has components for each
441      device, paired with a None. It is the result from
442      cross_device_utils.aggregate_gradients_using*.
443    destinations: a value to colocate the result with.
444    reduce_op: Indicates how values will be aggregated. Accepted values
445      are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
446    num_between_graph_workers: number of workers in the between-graph
447      replication.
448
449  Returns:
450    a list of Mirrored objects.
451  """
452  device_map, logical_device = get_device_map_from(destinations)
453  num_replicas = device_map.num_replicas_in_graph * num_between_graph_workers
454  index = [[] for _ in range(len(grouped_reduced[0]))]
455  for per_replica_reduced in grouped_reduced:
456    for i, (v, _) in enumerate(per_replica_reduced):
457      if reduce_op == reduce_util.ReduceOp.MEAN:
458        index[i].append(v / num_replicas)
459      else:
460        index[i].append(v)
461  return [value_lib.Mirrored(device_map, v, logical_device) for v in index]
462
463
464class _ConcatAndSplitPacker(object):
465  """Concatenate and split tensors for reduction."""
466
467  def __init__(self, num_packs=1):
468    """Initialize the _ConcatAndSplitPacker object.
469
470    Args:
471      num_packs: specifies the number of split packs that will be
472        formed.
473
474    Raises:
475      ValueError: if num_packs is not greater than 0.
476    """
477    if num_packs <= 0:
478      raise ValueError("num_packs must be greater than zero.")
479    self.num_packs = num_packs
480
481  def pack(self, grouped_grads_and_vars):
482    """Pack tensors."""
483    self.grouped_grads_and_vars = grouped_grads_and_vars
484    self.all_device_shapes = []
485    self.all_device_sizes = []
486
487    device_grad_packs = []
488    for device_grads_and_vars in grouped_grads_and_vars:
489      with ops.colocate_with(device_grads_and_vars[0][0]):
490        # Flatten all the grads.
491        flat_grads = [
492            array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars
493        ]
494        # Remember the original shape of all the grads.
495        device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars]
496        # Remember the original sizes of all the grads.
497        device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars]
498        # Concat all the flat grads into a big flat tensor.
499        concat_grads = array_ops.concat(flat_grads, 0)
500
501        # Split the big tensor into num_splits packs. In cases where the
502        # total size is not divisible num_splits, the last pack gets
503        # more elements.
504        # TODO(zhengxq): it is also possible to optimize away all the concat
505        # as well.
506        num_splits = self.num_packs
507
508        # The array_ops.size function will sometimes remove static shapes. So if
509        # all gradient shapes are defined, we use another method to get the
510        # total size.
511        # TODO(yuefengz): move this logic to array_ops.size.
512        if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
513          total_grad_size = sum(
514              [g.shape.num_elements() for g, _ in device_grads_and_vars])
515        else:
516          total_grad_size = array_ops.size(concat_grads)
517
518        split_size = total_grad_size // num_splits
519        split_size_last = total_grad_size - split_size * (num_splits - 1)
520        split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
521        grad_packs = array_ops.split(concat_grads, split_sizes)
522
523        # Ready to aggregate the repacked gradients, with fake variables.
524        # TODO(zhengxq): It is hacky to have to use fake variables.
525        # We should remove the need for variables in
526        # aggregate_gradients_using*.
527        device_grad_packs.append(zip(grad_packs, [None] * num_splits))
528        self.all_device_shapes.append(device_shapes)
529        self.all_device_sizes.append(device_sizes)
530
531    return device_grad_packs
532
533  def unpack(self, summed_device_grad_packs):
534    """Reverse the pack."""
535    aggregated_device_grads = []
536    for (summed_device_grad_packs,
537         device_grads_and_vars, device_shapes, device_sizes) in zip(
538             summed_device_grad_packs, self.grouped_grads_and_vars,
539             self.all_device_shapes, self.all_device_sizes):
540      # pylint: enable=line-too-long
541      # Reverse the packing operations in the previous steps. Form the
542      # summed gradients back into their original shapes.
543      with ops.colocate_with(summed_device_grad_packs[0][0]):
544        # Form a list of the summed grad packs.
545        device_grad_packs = [g for g, _ in summed_device_grad_packs]
546
547        # Concat them back into a big flat tensor.
548        device_grads_concat = array_ops.concat(device_grad_packs, 0)
549
550        # Split the tensors back into their original sizes.
551        grads_with_sizes = array_ops.split(device_grads_concat, device_sizes)
552
553        # Reshape the tensors back into their original shapes.
554        grads_with_shapes = [
555            array_ops.reshape(grad, shape)
556            for shape, grad in zip(device_shapes, grads_with_sizes)
557        ]
558
559        # Form the list with the original list of variables.
560        summed_device_grads = [
561            (g, v) for g, (_, v) in zip(grads_with_shapes,
562                                        device_grads_and_vars)
563        ]
564        aggregated_device_grads.append(summed_device_grads)
565    return aggregated_device_grads
566
567
568class _AggregateSmallTensorPacker(object):
569  """Concatenate small gradient tensors together for reduction."""
570
571  def __init__(self,
572               agg_small_grads_max_bytes=1048576,
573               agg_small_grads_max_group=16):
574    """Initialize the _AggregateSmallTensorPacker object.
575
576    Args:
577      agg_small_grads_max_bytes: largest tensor eligible for aggregation,
578        in number of bytes.
579      agg_small_grads_max_group: largest permitted aggregation of small
580        tensors.
581
582    Raises:
583      ValueError: if `agg_small_grads_max_bytes` or `agg_small_grads_max_group`
584        is not greater than 0.
585    """
586    if agg_small_grads_max_bytes <= 0 or agg_small_grads_max_group <= 0:
587      raise ValueError("agg_small_grads_max_bytes and agg_small_grads_max_group"
588                       " should both be greater than zero.")
589    self.agg_small_grads_max_bytes = agg_small_grads_max_bytes
590    self.agg_small_grads_max_group = agg_small_grads_max_group
591
592  def pack(self, grouped_grads_and_vars):
593    """Aggregate small tensors."""
594    if (self.agg_small_grads_max_bytes > 0 and
595        self.agg_small_grads_max_group > 0):
596      device_grads, self.packing = cross_device_utils.pack_small_tensors(
597          grouped_grads_and_vars,
598          max_bytes=self.agg_small_grads_max_bytes,
599          max_group=self.agg_small_grads_max_group)
600    return device_grads
601
602  def unpack(self, summed_device_grad_packs):
603    """Reverse the aggregation process."""
604    return cross_device_utils.unpack_small_tensors(summed_device_grad_packs,
605                                                   self.packing)
606
607
608def _pack_tensors(device_grads,
609                  num_packs=0,
610                  agg_small_grads_max_bytes=0,
611                  agg_small_grads_max_group=0):
612  """Pack tensors if specified."""
613  if num_packs > 0:
614    tensor_packer = _ConcatAndSplitPacker(num_packs)
615    device_grad_packs = tensor_packer.pack(device_grads)
616  elif agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0:
617    tensor_packer = _AggregateSmallTensorPacker(agg_small_grads_max_bytes,
618                                                agg_small_grads_max_group)
619    device_grad_packs = tensor_packer.pack(device_grads)
620  else:
621    tensor_packer = None
622    device_grad_packs = device_grads
623  return device_grad_packs, tensor_packer
624
625
626def _unpack_tensors(reduced, tensor_packer=None):
627  """Unpack tensors if they are packed before all-reduce."""
628  if tensor_packer:
629    return tensor_packer.unpack(reduced)
630  return reduced
631
632
633class AllReduceCrossDeviceOps(CrossDeviceOps):
634  """Reduction using all-reduce."""
635
636  def __init__(self,
637               all_reduce_alg="nccl",
638               num_packs=1,
639               agg_small_grads_max_bytes=0,
640               agg_small_grads_max_group=10):
641    """All-reduce implementation of CrossDeviceOps.
642
643    Before performing all-reduce, tensors will be repacked or aggregated for
644    more efficient cross-device transportation:
645      1) If `num_packs` is non-zero, pack values into
646        `num_packs` splits.
647      2) Otherwise, if `agg_small_grads_max_bytes` > 0 and
648        `agg_small_grads_max_group` > 0, aggregate values smaller than
649        `agg_small_grads_max_bytes` into groups with at most
650        `agg_small_grads_max_group` values.
651      3) Otherwise, no repacking or grouping will happen.
652
653    Args:
654      all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
655        "hierarchical_copy" are supported.
656      num_packs: see above.
657      agg_small_grads_max_bytes: see above.
658      agg_small_grads_max_group: see above.
659    """
660    self._all_reduce_alg = all_reduce_alg
661    self._num_packs = num_packs
662    self._agg_small_grads_max_bytes = agg_small_grads_max_bytes
663    self._agg_small_grads_max_group = agg_small_grads_max_group
664    self._simple_cross_replica_ops = ReductionToOneDevice()
665    super(AllReduceCrossDeviceOps, self).__init__()
666
667  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
668    if _devices_match(per_replica_value, destinations):
669      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
670    else:
671      return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
672                                                   destinations)
673
674  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
675    all_devices_match = _all_devices_match(value_destination_pairs)
676    contains_indexed_slices = cross_device_utils.contains_indexed_slices(
677        value_destination_pairs)
678    if (all_devices_match and not context.executing_eagerly()
679        and not contains_indexed_slices):
680      return self._batch_all_reduce(reduce_op,
681                                    [v[0] for v in value_destination_pairs])
682    else:
683      if not all_devices_match:
684        logging.log_first_n(logging.WARN,
685                            "Efficient batch_reduce is not supported if "
686                            "destinations are different.",
687                            10)
688
689      return [
690          self.reduce_implementation(reduce_op, t, destinations=v)
691          for t, v in value_destination_pairs
692      ]
693
694  def _batch_all_reduce(self, reduce_op, per_replica_values):
695    """All-reduce algorithm in a batch."""
696    dense_values, dense_indices, sparse_values, sparse_indices = (
697        cross_device_utils.split_by_sparsity(per_replica_values))
698    if dense_values:
699      dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
700    else:
701      dense_results = []
702    if sparse_values:
703      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
704                                                        sparse_values)
705    else:
706      sparse_results = []
707    return cross_device_utils.stitch_values(((dense_results, dense_indices),
708                                             (sparse_results, sparse_indices)))
709
710  def _do_batch_all_reduce(self, reduce_op, dense_values):
711    """Run batch all-reduces."""
712    logging.log_first_n(
713        logging.INFO, "batch_all_reduce invoked for batches size = %d with "
714        "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and "
715        "agg_small_grads_max_group = %d" %
716        (len(dense_values), self._all_reduce_alg, self._num_packs,
717         self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10)
718
719    destinations = dense_values[0].devices
720    grouped = _group_value_by_device(dense_values)
721
722    device_grad_packs, tensor_packer = _pack_tensors(
723        grouped, self._num_packs, self._agg_small_grads_max_bytes,
724        self._agg_small_grads_max_group)
725
726    # The actual aggregation of the repacked gradients. Note that they are
727    # sharded among different aggregation trees. So it is important to strike
728    # the balance on num_splits.
729    if self._all_reduce_alg == "nccl":
730      # TODO(yuefengz): merge this into the all-reduce library.
731      reduced = cross_device_utils.aggregate_gradients_using_nccl(
732          device_grad_packs)
733    else:
734      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
735      # order.
736      reduced = (
737          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
738              destinations, device_grad_packs))
739
740    reduced = _unpack_tensors(reduced, tensor_packer)
741    return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)
742
743  def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
744    """Run batch all-reduce for sparse values."""
745    logging.log_first_n(
746        logging.WARN,
747        "Efficient allreduce is not supported for %d IndexedSlices" %
748        len(sparse_values), 10)
749    # Use `sparse_values` as destinations to do all-reduces. It is effectively
750    # an allgather under the hood but not an efficient one.
751    return self._simple_cross_replica_ops.batch_reduce(
752        reduce_op, zip(sparse_values, sparse_values))
753
754
755# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
756AllReduceCrossTowerOps = AllReduceCrossDeviceOps
757
758
759AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
760                                            "alg shards limit")
761
762
763@tf_export("distribute.NcclAllReduce")
764class NcclAllReduce(AllReduceCrossDeviceOps):
765  """Reduction using NCCL all-reduce."""
766
767  def __init__(self, num_packs=1):
768    """NCCL all-reduce implementation of CrossDeviceOps.
769
770    Before performing all-reduce, tensors will be repacked or aggregated for
771    more efficient cross-device transportation.
772
773    Args:
774      num_packs: values will be packed in this many splits.  `num_packs` should
775        be greater than 0.
776    """
777    assert num_packs > 0, (
778        "NCLL all-reduce requires num_packs > 0, but {} is specified".format(
779            num_packs))
780    super(NcclAllReduce, self).__init__(
781        all_reduce_alg="nccl", num_packs=num_packs)
782
783
784@tf_export("distribute.HierarchicalCopyAllReduce")
785class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
786  """Reduction using hierarchical copy all-reduce.
787
788  This is a good reduction for configurations like Nvidia DGX-1.
789  """
790
791  def __init__(self, num_packs=1):
792    """Hierarchical copy all-reduce implementation of CrossDeviceOps.
793
794    Before performing all-reduce, tensors will be repacked or aggregated for
795    more efficient cross-device transportation.
796
797    Args:
798      num_packs: values will be packed in this many splits.  `num_packs` should
799        be greater than 0.
800    """
801    super(HierarchicalCopyAllReduce, self).__init__(
802        all_reduce_alg="hierarchical_copy",
803        num_packs=num_packs)
804
805
806class MultiWorkerAllReduce(AllReduceCrossDeviceOps):
807  """All-reduce algorithms for distributed TensorFlow."""
808
809  def __init__(self,
810               worker_devices,
811               num_gpus_per_worker,
812               all_reduce_spec=("pscpu/pscpu", 2, -1),
813               num_packs=0,
814               agg_small_grads_max_bytes=0,
815               agg_small_grads_max_group=10):
816    """Initialize the all-reduce algorithm.
817
818    Args:
819      worker_devices: a list of device strings for workers participating in
820        all-reduce.
821      num_gpus_per_worker: number of GPU devices per worker.
822      all_reduce_spec: a tuple or a named tuple or a list of tuples specifying
823        the all-reduce algorithm.
824        1. The first element of a tuple is the name of the all-reduce algorithm.
825        Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd",
826        "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with
827        a "/" are hierarchical, so two all-reduces are executed, the first one
828        aggregates tensors within a worker and the second aggregates across
829        workers.
830        2. The second element of a tuple is the number of shards when doing
831        all-reduce. Let's say its values is M, each tensor after packing will be
832        split into M shards and then M parallel all-reduces would be performed
833        before finally they are concatenated backed into a complete tensor.
834        3. The third element is the maximum size of tensors that will be
835        applicable for the algorithm specified by the first element. For
836        example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)],
837        tensors with size not larger than 1024 bytes will be applied a 2-shard
838        "nccl" all-reduce and other tensors will be applied a 2-shard
839        "pscpu/pscpu" algorithm. The third elements should be in increasing
840        order across tuples and end with -1 which indicates infinity.
841      num_packs: see AllReduceCrossDeviceOps.
842      agg_small_grads_max_bytes: see AllReduceCrossDeviceOps.
843      agg_small_grads_max_group: see AllReduceCrossDeviceOps.
844    """
845    self._worker_devices = worker_devices
846    self._num_gpus_per_worker = num_gpus_per_worker
847    super(MultiWorkerAllReduce, self).__init__(
848        num_packs=num_packs,
849        agg_small_grads_max_bytes=agg_small_grads_max_bytes,
850        agg_small_grads_max_group=agg_small_grads_max_group)
851
852    def validate_and_complete_spec(spec):
853      """Validate and complete the all-reduce spec."""
854      # TODO(yuefengz): support namedtuple.
855      if not isinstance(spec, tuple):
856        raise ValueError(
857            "A tuple is expected for all-reduce spec: %r" % all_reduce_spec)
858      if not spec or len(spec) > 3:
859        raise ValueError(
860            "Too many elements in the all-reduce spec tuple: %r" % spec)
861      if len(spec) == 1:
862        return AllReduceSpecTuple(spec[0], 1, -1)
863      elif len(spec) == 2:
864        return AllReduceSpecTuple(spec[0], spec[1], -1)
865      else:
866        return AllReduceSpecTuple(*spec)
867
868    self._all_reduce_spec = []
869    if isinstance(all_reduce_spec, six.string_types):
870      self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1))
871    elif isinstance(all_reduce_spec, tuple):
872      self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec))
873    elif isinstance(all_reduce_spec, list):
874      self._all_reduce_spec = [
875          validate_and_complete_spec(spec) for spec in all_reduce_spec
876      ]
877
878  def _batch_all_reduce(self, reduce_op, per_replica_values):
879    """All-reduce algorithm in a batch."""
880    logging.log_first_n(
881        logging.INFO,
882        "distributed batch_all_reduce invoked for batches size = %d with "
883        "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d "
884        "and agg_small_grads_max_group = %d" %
885        (len(per_replica_values), self._all_reduce_spec, self._num_packs,
886         self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10)
887
888    device_grads = _group_value_by_device(per_replica_values)
889
890    # The all-reduce library requires fully defined shapes.
891    # TODO(yuefengz): when tensor sharding is not needed, static shapes are not
892    # required as well.
893    for device_grad in device_grads:
894      for grad, _ in device_grad:
895        if not grad.shape.is_fully_defined():
896          raise ValueError("Shape is unknown for node %r" % grad)
897
898    remaining_grads = device_grads
899    aggregated_grads = []
900    for spec_tuple in self._all_reduce_spec:
901      if spec_tuple.limit < 0:
902        this_grads = remaining_grads
903        remaining_grads = []
904      else:
905        (this_grads, remaining_grads) = cross_device_utils.split_grads_by_size(
906            spec_tuple.limit, remaining_grads)
907      if this_grads:
908        device_grad_packs, tensor_packer = _pack_tensors(
909            this_grads, self._num_packs, self._agg_small_grads_max_bytes,
910            self._agg_small_grads_max_group)
911        range_agg_grads = cross_device_utils.sum_gradients_all_reduce(
912            self._worker_devices, device_grad_packs, len(self._worker_devices),
913            spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker))
914        range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer)
915
916        if not aggregated_grads:
917          aggregated_grads = range_agg_grads
918        else:
919          assert len(aggregated_grads) == len(range_agg_grads)
920          for i in range(len(aggregated_grads)):
921            aggregated_grads[i] += range_agg_grads[i]
922    assert not remaining_grads
923
924    return _ungroup_and_make_mirrored(aggregated_grads, per_replica_values[0],
925                                      reduce_op)
926
927
928@tf_export("distribute.experimental.CollectiveCommunication")
929class CollectiveCommunication(enum.Enum):
930  """Communication choices for CollectiveOps.
931
932  * `AUTO`: Default to runtime's automatic choices.
933  * `RING`: TensorFlow's ring algorithms for all-reduce and
934    all-gather.
935  * `NCCL`: Use ncclAllReduce for all-reduce, and ring algorithms for
936    all-gather.  TODO(ayushd): add ncclAllGather implementation.
937  """
938  AUTO = "AUTO"
939  RING = "RING"
940  NCCL = "NCCL"
941
942
943# TODO(yuefengz): support in-graph collective all-reduce.
944class CollectiveAllReduce(CrossDeviceOps):
945  """All-reduce cross device ops using collective ops.
946
947  In the between-graph replicated training, it will still do all-reduces across
948  all workers and then put results on the right destinations.
949  """
950
951  def __init__(self,
952               num_workers=1,
953               num_gpus_per_worker=0,
954               all_reduce_merge_scope=32,
955               collective_keys=None):
956    """Initializes the object.
957
958    Args:
959      num_workers: number of workers in the between-graph replicated training.
960      num_gpus_per_worker: number of GPUs per worker.
961      all_reduce_merge_scope: size of groups into which to partition consecutive
962        gradients grouped under a common 'allreduce' name scope. This is useful
963        for some optimization of collective ops.
964      collective_keys: an optional CollectiveKey object.
965    """
966    self._num_workers = num_workers
967    self._num_gpus_per_worker = num_gpus_per_worker
968    self._all_reduce_merge_scope = all_reduce_merge_scope
969    self._collective_keys = (collective_keys or
970                             cross_device_utils.CollectiveKeys())
971    super(CollectiveAllReduce, self).__init__()
972
973  # TODO(yuefengz, tucker): is indexed slices supported by collective ops?
974  def reduce_implementation(self, reduce_op, per_replica_value, destinations):
975    if cross_device_utils.contains_indexed_slices(per_replica_value):
976      raise ValueError(
977          "`IndexSlices` is not supported for Collective All-Reduce.")
978
979    all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value])[0]
980    device_map, logical_device = get_device_map_from(destinations)
981    if (all_reduced.device_map is device_map and
982        all_reduced.logical_device == logical_device):
983      return all_reduced
984    devices = device_map.logical_to_actual_devices(logical_device)
985    index = []
986    for d in devices:
987      if d in all_reduced.devices:
988        index.append(all_reduced.get(d))
989      else:
990        # TODO(josh11b): Once we add support for model parallelism, get the
991        # copy from the corresponding replica instead of the primary.
992        with ops.control_dependencies(all_reduced.values), ops.device(d):
993          index.append(array_ops.identity(all_reduced.primary))
994
995    return value_lib.Mirrored(device_map, index, logical_device)
996
997  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
998    if cross_device_utils.contains_indexed_slices(value_destination_pairs):
999      raise ValueError(
1000          "`IndexSlices` is not supported for Collective All-Reduce.")
1001
1002    all_devices_match = _all_devices_match(value_destination_pairs)
1003    if all_devices_match:
1004      return self._batch_all_reduce(reduce_op,
1005                                    [v[0] for v in value_destination_pairs])
1006    else:
1007      if not all_devices_match:
1008        logging.log_first_n(
1009            logging.WARN, "Efficient batch_reduce is not supported if "
1010            "destinations are different.", 10)
1011
1012      return [
1013          self.reduce_implementation(reduce_op, t, destinations=v)
1014          for t, v in value_destination_pairs
1015      ]
1016
1017  def _batch_all_reduce(self, reduce_op, per_replica_values):
1018    """All-reduce across all workers in a batch."""
1019
1020    logging.log_first_n(
1021        logging.INFO, "Collective All-reduce invoked with batches size = %d, "
1022        "num_workers = %d" % (len(per_replica_values), self._num_workers), 10)
1023
1024    grouped_by_device = _group_value_by_device(per_replica_values)
1025
1026    grouped_by_var = list(zip(*grouped_by_device))
1027    # grouped_by_var is grouped by variables and takes the following format:
1028    # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..),
1029    #  ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu0, v1_gpu2) ..),
1030    #  ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu0, v2_gpu2) ..),
1031    #  ...
1032    # ]
1033    chunked_gv = [
1034        grouped_by_var[x:x + self._all_reduce_merge_scope]
1035        for x in range(0, len(grouped_by_var), self._all_reduce_merge_scope)
1036    ]
1037
1038    reduced_gv_list = []
1039    for chunk in chunked_gv:
1040      with ops.name_scope("allreduce"):
1041        for grad_and_vars in chunk:
1042          scaled_grads = [g for g, _ in grad_and_vars]
1043          collective_reduced = cross_device_utils.build_collective_reduce(
1044              scaled_grads, self._num_workers, self._collective_keys, "Add",
1045              "Id")
1046          result = []
1047          for (_, v), g in zip(grad_and_vars, collective_reduced):
1048            result.append([g, v])
1049          reduced_gv_list.append(result)
1050
1051    new_device_grads = [list(x) for x in zip(*reduced_gv_list)]
1052    return _ungroup_and_make_mirrored(
1053        new_device_grads,
1054        per_replica_values[0],
1055        reduce_op,
1056        num_between_graph_workers=self._num_workers)
1057
1058
1059_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
1060               [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
1061
1062
1063def _has_dgx1_like_links(gpu_links):
1064  if not gpu_links:
1065    return False
1066  # TODO(yuefengz): figure out the right topology for hierarchical copy if
1067  # number of gpus are less than 8.
1068  if len(gpu_links) < 8:
1069    return False
1070  for i, (gpu_link, dgx1_link) in enumerate(zip(gpu_links, _dgx1_links)):
1071    if (set(gpu_link) != set(dgx1_link) and
1072        set(gpu_link) != set(dgx1_link + [i])):
1073      return False
1074  return True
1075
1076
1077def _choose_all_reduce_algorithm(device_links):
1078  if _has_dgx1_like_links(device_links):
1079    return HierarchicalCopyAllReduce(num_packs=len(device_links))
1080  else:
1081    return NcclAllReduce(num_packs=1)
1082
1083
1084def choose_the_best(devices, session_config=None):
1085  """Find the best subclass of CrossDeviceOps given a session config.
1086
1087  Args:
1088    devices: a list of devices passed to `tf.distribute.Strategy`.
1089    session_config: a `tf.ConfigProto` or `None`. If `None`, it will make
1090      decision based on all local devices.
1091
1092  Returns:
1093    A subclass of `CrossDeviceOps`.
1094  """
1095  requested_devices = set([device_util.canonicalize(d) for d in devices])
1096  machine_devices = device_lib.list_local_devices(session_config=session_config)
1097  using_devices = []
1098  for d in machine_devices:
1099    if device_util.canonicalize(d.name) in requested_devices:
1100      using_devices.append(d)
1101    else:
1102      logging.info(
1103          "Device is available but not used by distribute strategy: %s", d.name)
1104
1105  if len(using_devices) != len(requested_devices):
1106    logging.warning("Not all devices in `tf.distribute.Strategy` are visible "
1107                    "to TensorFlow.")
1108    return ReductionToOneDevice()
1109
1110  if any(d.device_type.lower() != "gpu" for d in using_devices):
1111    logging.warning("Not all devices in `tf.distribute.Strategy` are visible "
1112                    "to TensorFlow.")
1113    return ReductionToOneDevice()
1114
1115  device_links = [[] for _ in range(len(using_devices))]
1116  for i, device in enumerate(using_devices):
1117    for link in device.locality.links.link:
1118      device_links[i].append(link.device_id)
1119
1120  return _choose_all_reduce_algorithm(device_links)
1121