1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import math
23
24from tensorflow.python.framework import device as device_lib
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import nccl_ops
29
30
31def _flatten_tensors(tensors):
32  """Check tensors for isomorphism and flatten.
33
34  Args:
35    tensors: list of `tf.Tensor` which must all have the same shape.
36
37  Returns:
38    tensors: a list of `tf.Tensor` which are flattened (1D) views of tensors
39    shape: the original shape of each element of input tensors
40
41  Raises:
42    ValueError: tensors are empty or non-isomorphic or have unknown shape.
43  """
44  if not tensors:
45    raise ValueError("tensors cannot be empty")
46  shape = tensors[0].shape
47  for tensor in tensors:
48    shape = shape.merge_with(tensor.shape)
49  if not shape.is_fully_defined():
50    raise ValueError("Tensors must have statically known shape.")
51  if len(shape) != 1:
52    reshaped = []
53    for t in tensors:
54      with ops.colocate_with(t):
55        reshaped.append(array_ops.reshape(t, [-1]))
56    tensors = reshaped
57  return tensors, shape
58
59
60def _reshape_tensors(tensors, shape):
61  """Reshape tensors flattened by _flatten_tensors.
62
63  Args:
64    tensors: list of `tf.Tensor` of identical length 1D tensors.
65    shape: list of integers describing the desired shape.  Product of
66      the elements must equal the length of each tensor.
67
68  Returns:
69    list of `tf.Tensor` which are the reshaped inputs.
70  """
71  reshaped = []
72  for t in tensors:
73    with ops.colocate_with(t):
74      reshaped.append(array_ops.reshape(t, shape))
75  return reshaped
76
77
78def _padded_split(tensor, pieces):
79  """Like split for 1D tensors but pads-out case where len % pieces != 0.
80
81  Args:
82    tensor: `tf.Tensor` that must be 1D.
83    pieces: a positive integer specifying the number of pieces into which
84      tensor should be split.
85
86  Returns:
87    list of `tf.Tensor` of length pieces, which hold the values of
88      thin input tensor, in order. The final tensor may
89      be zero-padded on the end to make its size equal to those of all
90      of the other tensors.
91
92  Raises:
93    ValueError: The input tensor is not 1D.
94  """
95  shape = tensor.shape
96  if 1 != len(shape):
97    raise ValueError("input tensor must be 1D")
98  tensor_len = shape.dims[0].value
99  with ops.colocate_with(tensor):
100    if tensor_len % pieces != 0:
101      # pad to an even length
102      chunk_size = 1 + tensor_len // pieces
103      if pieces > tensor_len:
104        # This is an edge case that should not come up in practice,
105        # i.e. a different reduction algorithm would be better,
106        # but we'll make it work just for completeness.
107        pad_len = pieces - tensor_len
108        extended_whole = array_ops.concat(
109            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
110        parts = array_ops.split(extended_whole, pieces)
111        return parts, pad_len
112      elif (pieces - 1) * chunk_size >= tensor_len:
113        # Another edge case of limited real interest.
114        pad_len = (pieces * chunk_size) % tensor_len
115        extended_whole = array_ops.concat(
116            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
117        parts = array_ops.split(extended_whole, pieces)
118        return parts, pad_len
119      else:
120        last_chunk_size = tensor_len - (pieces - 1) * chunk_size
121        pad_len = chunk_size - last_chunk_size
122        piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
123        parts = array_ops.split(tensor, piece_lens)
124        parts[-1] = array_ops.concat(
125            [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
126        return parts, pad_len
127    else:
128      return array_ops.split(tensor, pieces), 0
129
130
131def _strip_padding(tensors, pad_len):
132  """Strip the suffix padding added by _padded_split.
133
134  Args:
135    tensors: list of `tf.Tensor` of identical length 1D tensors.
136    pad_len: number of elements to be stripped from the end of each tensor.
137
138  Returns:
139    list of `tf.Tensor` which are the stripped inputs.
140
141  Raises:
142    ValueError: tensors must be a non-empty list of 1D tensors, and
143      each must be longer than pad_len.
144  """
145  if not tensors:
146    raise ValueError("tensors cannot be empty")
147  shape = tensors[0].shape
148  if len(shape) > 1:
149    raise ValueError("tensors must be 1D")
150  prefix_len = int(shape[0] - pad_len)
151  if prefix_len < 0:
152    raise ValueError("pad_len longer than tensor")
153  stripped = []
154  for t in tensors:
155    with ops.colocate_with(t):
156      stripped.append(array_ops.slice(t, [0], [prefix_len]))
157  return stripped
158
159
160def _ragged_split(tensor, pieces):
161  """Like split for 1D tensors but allows case where len % pieces != 0.
162
163  Args:
164    tensor: `tf.Tensor` that must be 1D.
165    pieces: a positive integer specifying the number of pieces into which
166      tensor should be split.
167
168  Returns:
169    list of `tf.Tensor` of length pieces, which hold the values of
170      the input tensor, in order. The final tensor may be shorter
171      than the others, which will all be of equal length.
172
173  Raises:
174    ValueError: input tensor must be 1D.
175  """
176  shape = tensor.shape
177  if 1 != len(shape):
178    raise ValueError("input tensor must be 1D")
179  tensor_len = shape.dims[0].value
180  chunk_size = tensor_len // pieces
181  with ops.colocate_with(tensor):
182    if tensor_len != (pieces * chunk_size):
183      # last piece will be short
184      assert pieces > 1
185      last_chunk_size = tensor_len - ((pieces - 1) * chunk_size)
186      assert last_chunk_size > 0
187      piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
188      return array_ops.split(tensor, piece_lens)
189    else:
190      return array_ops.split(tensor, pieces)
191
192
193def _ring_permutations(num_workers, num_subchunks, gpu_perm):
194  """"Generate an array of device index arrays, one for each subchunk.
195
196  In the basic ring reduction algorithm there are size(T)/num_devices
197  data chunks and each device process one chunk per tick, i.e. sending
198  one chunk and receiving one chunk.  The idea of subchunking is that
199  each device processes num_subchunks smaller data regions per tick,
200  and the ring rank permutation is different for each subchunk index
201  so that a device is potentially sending to and receiving from
202  num_subchunks different other devices at each tick.  Where multiple
203  independent data channels exist between devices, this strategy
204  supplies a method of using them in parallel.
205
206  Args:
207    num_workers: number of worker tasks
208    num_subchunks: number of subchunks into which to divide each per-GPU chunk.
209    gpu_perm: an array of integers in [0, num_gpus-1] giving the default
210      ring order of GPUs at each worker.  Other permutations will be generated
211      by rotating this array and splicing together per-worker instances.
212
213  Raises:
214    ValueError: the number of subchunks may not exceed the number of GPUs.
215
216  Returns:
217    pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
218        preceding device in the permutation for that subchunk.  The
219        device index of GPU i at worker j is i + (j * num_gpus).
220    rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
221       local rank of device d in the permutation for that subchunk.
222  """
223  num_gpus = len(gpu_perm)
224  devices = num_workers * num_gpus
225  if devices == 0:
226    return [], []
227  if num_subchunks > num_gpus:
228    raise ValueError(
229        "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus))
230  rotation_interval = max(1, int(num_gpus / num_subchunks))
231  perms_by_s = []
232  for s in range(0, num_subchunks):
233    full_order = []
234    offset = s * rotation_interval
235    for w in range(0, num_workers):
236      default_order = [(w * num_gpus) + i for i in gpu_perm]
237      dev_order = default_order[offset:] + default_order[:offset]
238      full_order += dev_order
239    perms_by_s.append(full_order)
240  pred_by_s_d = [[-1 for d in range(0, devices)]
241                 for s in range(0, num_subchunks)]
242  rank_by_s_d = [[-1 for d in range(0, devices)]
243                 for s in range(0, num_subchunks)]
244  for s in range(0, num_subchunks):
245    for d in range(0, devices):
246      for t in range(0, devices):
247        if d == perms_by_s[s][t]:
248          rank_by_s_d[s][d] = t
249          pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices]
250          break
251  return (pred_by_s_d, rank_by_s_d)
252
253
def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
                          gpu_perm, red_op, un_op=None):
  """Construct a subgraph performing a ring-style all-reduce of input_tensors.

  Args:
    input_tensors: a list of `tf.Tensor` objects, which must all
      have the same shape and type.
    num_workers: number of worker tasks spanned by input_tensors.
    num_subchunks: number of subchunks each device should process in one tick.
    gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at
      each worker.  All workers must have the same number of
      GPUs with the same rank ordering.  If NVLINK is available, this should
      be a ring order supported by NVLINK edges.
    red_op: a binary operator for elementwise reduction.
    un_op: an optional unary operator to apply to fully reduced values.

  Raises:
    ValueError: fewer than two input_tensors, input_tensors that don't all
      have the same shape, or num_workers * len(gpu_perm) differing from
      len(input_tensors).

  Returns:
    a list of `tf.Tensor` identical red_op-reductions of input_tensors (with
    un_op applied when given), one per input, each with the input shape.
  """
  if len(input_tensors) < 2:
    raise ValueError("input_tensors must be length 2 or longer")
  # The ring permutations index devices 0..num_workers*len(gpu_perm)-1, so
  # that count must line up with the inputs or the graph wiring is invalid.
  if num_workers * len(gpu_perm) != len(input_tensors):
    raise ValueError(
        "num_workers * len(gpu_perm) must equal len(input_tensors): "
        "%d * %d != %d" % (num_workers, len(gpu_perm), len(input_tensors)))
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  (pred_by_s_d, rank_by_s_d) = _ring_permutations(
      num_workers, num_subchunks, gpu_perm)
  chunks_by_dev, pad_len = _build_ring_gather(
      input_tensors, devices,
      num_subchunks, pred_by_s_d, rank_by_s_d, red_op)
  if un_op:
    chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev)
  output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                                       chunks_by_dev)
  if pad_len > 0:
    output_tensors = _strip_padding(output_tensors, pad_len)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
295
296
297def _build_ring_gather(input_tensors, devices, num_subchunks,
298                       pred_by_s_d, rank_by_s_d, red_op):
299  """Construct a subgraph for the first (reduction) pass of ring all-reduce.
300
301  Args:
302    input_tensors: a list of `tf.Tensor` 1D input tensors of same
303      shape and type.
304    devices: array of device name strings
305    num_subchunks: number of subchunks each device should process in one tick.
306    pred_by_s_d: as produced by _ring_permutations
307    rank_by_s_d: as produced by _ring_permutations
308    red_op: a binary operator for elementwise reduction
309
310  Raises:
311    ValueError: tensors must all be one dimensional.
312
313  Returns:
314    list of list of `tf.Tensor` of (partially) reduced values where
315    exactly num_subchunks chunks at each device are fully reduced.
316  """
317  num_devices = len(input_tensors)
318  if num_devices == 0:
319    return []
320  if num_devices == 1:
321    return input_tensors
322  shape = input_tensors[0].shape
323  if 1 != len(shape):
324    raise ValueError("input tensors must be 1D")
325  num_chunks = num_devices * num_subchunks
326  num_ticks = num_devices - 1
327  # Initialize chunks_by_dev with splits of the input tensors.
328  chunks_by_dev = []
329  split_pad_len = 0
330  for d in range(0, num_devices):
331    with ops.device(devices[d]):
332      splits, split_pad_len = _padded_split(input_tensors[d], num_chunks)
333      chunks_by_dev.append(splits)
334  # Reduction phase
335  for tick in range(0, num_ticks):
336    # One new partial reduction for every chunk
337    new_partial_reductions = [None for _ in range(0, num_chunks)]
338    # Compute reductions with respect to last tick's values
339    for d in range(0, num_devices):
340      with ops.device(devices[d]):
341        for s in range(0, num_subchunks):
342          rank = rank_by_s_d[s][d]
343          seg_index = (rank + num_devices - (2 + tick)) % num_devices
344          pred_dev = pred_by_s_d[s][d]
345          chunk_index = (seg_index * num_subchunks) + s
346          new_partial_reductions[chunk_index] = red_op(
347              chunks_by_dev[pred_dev][chunk_index],
348              chunks_by_dev[d][chunk_index])
349    # Update chunks_by_dev with the new values at the end of the tick.
350    for d in range(0, num_devices):
351      for s in range(0, num_subchunks):
352        rank = rank_by_s_d[s][d]
353        seg_index = (rank + num_devices - (2 + tick)) % num_devices
354        chunk_index = (seg_index * num_subchunks) + s
355        chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index]
356  return chunks_by_dev, split_pad_len
357
358
359def _apply_unary_to_chunks(f, chunks_by_dev):
360  """Apply a unary op to each tensor in chunks_by_dev, on same device.
361
362  Args:
363    f: a unary function over `tf.Tensor`.
364    chunks_by_dev: list of lists of `tf.Tensor`.
365
366  Returns:
367    new list of lists of `tf.Tensor` with the same structure as
368    chunks_by_dev containing the derived tensors.
369  """
370  output = []
371  for x in chunks_by_dev:
372    with ops.colocate_with(x[0]):
373      output.append([f(t) for t in x])
374  return output
375
376
377def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
378                        chunks_by_dev):
379  """Construct subgraph for second (scatter) pass of ring all-reduce.
380
381  Args:
382    pred_by_s_d: as produced by _ring_permutations
383    rank_by_s_d: as produced by _ring_permutations
384    chunks_by_dev: list of list of `tf.Tensor` indexed by ints
385      (device, chunk)
386
387  Raises:
388    ValueError: chunks_by_dev is not well-formed
389
390  Returns:
391    list of `tf.Tensor` which are the fully reduced tensors, one
392    at each device corresponding to the outer dimension of chunks_by_dev.
393  """
394  num_devices = len(chunks_by_dev)
395  num_chunks = len(chunks_by_dev[0])
396  if 0 != num_chunks % num_devices:
397    raise ValueError(
398        "Expect number of chunks per device to be divisible by num_devices")
399  num_subchunks = int(num_chunks / num_devices)
400  num_ticks = num_devices - 1
401  for tick in range(0, num_ticks):
402    passed_values = [None for _ in range(0, num_chunks)]
403    for d in range(0, num_devices):
404      with ops.colocate_with(chunks_by_dev[d][0]):
405        for s in range(0, num_subchunks):
406          rank = rank_by_s_d[s][d]
407          seg_index = (rank + num_devices - (1 + tick)) % num_devices
408          pred_dev = pred_by_s_d[s][d]
409          chunk_index = (seg_index * num_subchunks) + s
410          passed_values[chunk_index] = array_ops.identity(
411              chunks_by_dev[pred_dev][chunk_index])
412    for d in range(0, num_devices):
413      for s in range(0, num_subchunks):
414        rank = rank_by_s_d[s][d]
415        seg_index = (rank + num_devices - (1 + tick)) % num_devices
416        chunk_index = (seg_index * num_subchunks) + s
417        chunks_by_dev[d][chunk_index] = passed_values[chunk_index]
418  # Join chunks at each device.
419  output = []
420  for x in chunks_by_dev:
421    with ops.colocate_with(x[0]):
422      output.append(array_ops.concat(x, 0))
423  return output
424
425
def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
  """Construct a subgraph for recursive halving-doubling all-reduce.

  The recursive halving-doubling algorithm is described in
  (Thakur et al., 2005).

  The concept is to arrange the participating n devices in
  a linear sequence where devices exchange data pairwise
  with one other device in each round.  During the gather
  phase there are lg(n) rounds where devices exchange
  increasingly smaller sub-tensors with another device
  at increasingly greater distances, until at the top
  each device has 1/n of the fully reduced values.  During the
  scatter phase each device exchanges its fully reduced
  sub-tensor (which doubles in length at each round)
  with one other device at increasingly smaller distances
  until each device has all of the fully reduced values.

  Note: this preliminary version requires that len(input_tensors) be a
    power of 2.  TODO(tucker): relax this restriction.  Also, the
    number of elements in each tensor must be divisible by 2^h where h
    is the number of hops in each phase.  This will also be relaxed in
    the future with edge-case specific logic.

  Args:
    input_tensors: list of `tf.Tensor` to be elementwise reduced.
    red_op: a binary elementwise reduction Op.
    un_op: an optional unary elementwise Op to apply to reduced values.
      It is applied to each reduced shard, colocated with that shard.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one
    at each device of input_tensors.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
    by 2 the proper number of times.

  References:
    Optimization of Collective Communication Operations in MPICH:
      [Thakur et al., 2005]
      (https://journals.sagepub.com/doi/abs/10.1177/1094342005051521)
      ([pdf](http://wwwi10.lrr.in.tum.de/~gerndt/home/Teaching/HPCSeminar/mpich_multi_coll.pdf))
  """
  devices = [t.device for t in input_tensors]
  input_tensors, shape = _flatten_tensors(input_tensors)
  reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op)
  if un_op:
    # Apply un_op on the device that owns each shard, as the other builders
    # in this module do, instead of relying on the caller's device scope.
    unary_shards = []
    for shard in reduced_shards:
      with ops.colocate_with(shard):
        unary_shards.append(un_op(shard))
    reduced_shards = unary_shards
  output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
478
479
480def _build_recursive_hd_gather(input_tensors, devices, red_op):
481  """Construct the gather phase of recursive halving-doubling all-reduce.
482
483  Args:
484    input_tensors: list of `tf.Tensor` to be elementwise reduced.
485    devices: a list of strings naming the devices hosting input_tensors,
486      which will also be used to host the (partial) reduction values.
487    red_op: a binary elementwise reduction Op.
488
489  Returns:
490    list of `tf.Tensor` which are the fully reduced tensor shards.
491
492  Raises:
493    ValueError: num_devices not a power of 2, or tensor len not divisible
494    by 2 the proper number of times.
495  """
496  num_devices = len(devices)
497  num_hops = int(math.log(num_devices, 2))
498  if num_devices != (2 ** num_hops):
499    raise ValueError("num_devices must be a power of 2")
500  chunks = input_tensors
501  for h in range(0, num_hops):
502    span = 2 ** h
503    group_size = span * 2
504    new_chunks = [[] for _ in devices]
505    for d in range(0, num_devices):
506      if (d % group_size) >= (group_size / 2):
507        # skip right half of a pair
508        continue
509      left_dev = devices[d]
510      right_dev = devices[d + span]
511      left_split = array_ops.split(chunks[d], 2)
512      right_split = array_ops.split(chunks[d+span], 2)
513      with ops.device(left_dev):
514        new_chunks[d] = red_op(left_split[0], right_split[0])
515      with ops.device(right_dev):
516        new_chunks[d + span] = red_op(left_split[1], right_split[1])
517    chunks = new_chunks
518  return chunks
519
520
521def _build_recursive_hd_scatter(input_tensors, devices):
522  """Construct the scatter phase of recursive halving-doubling all-reduce.
523
524  Args:
525    input_tensors: list of `tf.Tensor` that are fully-reduced shards.
526    devices: a list of strings naming the devices on which the reconstituted
527      full tensors should be placed.
528
529  Returns:
530    list of `tf.Tensor` which are the fully reduced tensors.
531  """
532  num_devices = len(devices)
533  num_hops = int(math.log(num_devices, 2))
534  assert num_devices == (2 ** num_hops), "num_devices must be a power of 2"
535  chunks = input_tensors
536  for h in reversed(range(0, num_hops)):
537    span = 2 ** h
538    group_size = span * 2
539    new_chunks = [[] for _ in devices]
540    for d in range(0, num_devices):
541      if (d % group_size) >= (group_size / 2):
542        # skip right half of a pair
543        continue
544      left_idx = d
545      right_idx = d + span
546      left_dev = devices[left_idx]
547      right_dev = devices[right_idx]
548      with ops.device(left_dev):
549        new_chunks[left_idx] = array_ops.concat([chunks[left_idx],
550                                                 chunks[right_idx]], 0)
551      with ops.device(right_dev):
552        new_chunks[right_idx] = array_ops.concat([chunks[left_idx],
553                                                  chunks[right_idx]], 0)
554    chunks = new_chunks
555  return chunks
556
557
def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
  """Construct a subgraph for shuffle all-reduce.

  Shuffle reduce is essentially the algorithm implemented when using
  parameter servers.  Suppose tensor length is n, there are d devices
  and g gather shards.  Each device sends a n/g length sub-tensor to
  each gather shard.  The gather shards perform a reduction across d
  fragments, then broadcast the result back to each device.  The
  devices then join the g fully reduced fragments they receive from
  the shards.  The gather shards could perform d-1 pairwise
  reductions, or one d-way reduction.  The first is better where
  reduction Op time is low compared to transmission time, the second
  better in the other case.

  Args:
    input_tensors: list of `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-ary elementwise reduction Op, called with a list of tensors.
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one per input
    device, in the original input shape.
  """
  flat_inputs, original_shape = _flatten_tensors(input_tensors)
  # Each result is reconstituted on the device of the matching input.
  destinations = [flat.device for flat in flat_inputs]
  shards = _build_shuffle_gather(flat_inputs, gather_devices, red_op, un_op)
  results = _build_shuffle_scatter(shards, destinations)
  if len(original_shape) == 1:
    return results
  return _reshape_tensors(results, original_shape)
590
591
592def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
593  """Construct the gather (concentrate and reduce) phase of shuffle all-reduce.
594
595  Args:
596    input_tensors: list of `tf.Tensor` values to be reduced.
597    gather_devices: list of names of devices on which reduction shards
598      should be placed.
599    red_op: the binary reduction Op
600    un_op: optional elementwise unary Op to be applied to fully-reduced values.
601
602  Returns:
603    list of `tf.Tensor` which are the fully reduced shards.
604
605  Raises:
606    ValueError: inputs not well-formed.
607  """
608  num_source_devices = len(input_tensors)
609  num_gather_devices = len(gather_devices)
610  shape = input_tensors[0].shape
611  if len(shape) != 1:
612    raise ValueError("input_tensors must be 1D")
613  shards_by_source = []
614  for d in range(0, num_source_devices):
615    with ops.colocate_with(input_tensors[d]):
616      shards_by_source.append(
617          _ragged_split(input_tensors[d], num_gather_devices))
618  reduced_shards = []
619  for d in range(0, num_gather_devices):
620    with ops.device(gather_devices[d]):
621      values = [s[d] for s in shards_by_source]
622      red_shard = red_op(values)
623      if un_op:
624        red_shard = un_op(red_shard)
625      reduced_shards.append(red_shard)
626  return reduced_shards
627
628
629def _build_shuffle_scatter(reduced_shards, dst_devices):
630  """Build the scatter phase of shuffle all-reduce.
631
632  Args:
633    reduced_shards:  list of `tf.Tensor` fully reduced shards
634    dst_devices: list of names of devices at which the fully-reduced value
635      should be reconstituted.
636
637  Returns:
638    list of `tf.Tensor` scattered tensors.
639  """
640  num_devices = len(dst_devices)
641  out_tensors = []
642  for d in range(0, num_devices):
643    with ops.device(dst_devices[d]):
644      out_tensors.append(array_ops.concat(reduced_shards, 0))
645  return out_tensors
646
647
648def _split_by_task(devices, values):
649  """Partition devices and values by common task.
650
651  Args:
652    devices: list of device name strings
653    values: list of `tf.Tensor` of same length as devices.
654
655  Returns:
656    (per_task_devices, per_task_values) where both values are
657    lists of lists with isomorphic structure: the outer list is
658    indexed by task, and the inner list has length of the number
659    of values belonging to that task.  per_task_devices contains
660    the specific devices to which the values are local, and
661    per_task_values contains the corresponding values.
662
663  Raises:
664    ValueError: devices must be same length as values.
665  """
666  num_devices = len(devices)
667  if num_devices != len(values):
668    raise ValueError("len(devices) must equal len(values)")
669  per_task_devices = collections.OrderedDict()
670  per_task_values = collections.OrderedDict()
671  for d in range(num_devices):
672    d_spec = device_lib.DeviceSpec.from_string(devices[d])
673    if not hasattr(d_spec, "task") or d_spec.task is None:
674      assert False, "failed to parse device %s" % devices[d]
675    index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task)
676    if index not in per_task_devices:
677      per_task_devices[index] = []
678      per_task_values[index] = []
679    per_task_devices[index].append(devices[d])
680    per_task_values[index].append(values[d])
681
682  return (list(per_task_devices.values()), list(per_task_values.values()))
683
684
def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
  """Build a subgraph that does one full all-reduce, using NCCL.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator. Must be one of
      {tf.add}
    un_op: optional unary elementwise Op to apply to fully-reduce values.

  Returns:
    list of `tf.Tensor` of reduced values, with un_op applied (colocated
    with each reduced value) when given.

  Raises:
    ValueError: red_op not supported.
  """
  if red_op != math_ops.add:
    # Format the message explicitly; passing red_op as a second exception
    # argument produced a tuple-style message.
    raise ValueError(
        "red_op not supported by NCCL all-reduce: %s" % (red_op,))
  output_tensors = nccl_ops.all_sum(input_tensors)
  if un_op:
    un_op_wrapped = []
    for t in output_tensors:
      with ops.colocate_with(t):
        un_op_wrapped.append(un_op(t))
    output_tensors = un_op_wrapped
  return output_tensors
712
713
def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
  """Construct a subgraph for NCCL hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values, where element i lives on the
    device of input_tensors[i] (even when inputs of different tasks are
    interleaved).

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  # Group input positions (not just values) by task, so the outputs can be
  # written back in input order.
  per_worker_devices, per_worker_indices = _split_by_task(
      devices, list(range(len(input_tensors))))
  num_workers = len(per_worker_devices)
  up_values = []
  # First stage: reduce within each worker using NCCL
  for w in range(num_workers):
    worker_values = build_nccl_all_reduce(
        [input_tensors[i] for i in per_worker_indices[w]], red_op)
    # NOTE: these reductions will not run to completion unless
    # every output value is used.  Since we only need one, we
    # need to put control dependencies on the rest.
    with ops.control_dependencies(worker_values):
      with ops.device(worker_values[0].device):
        up_values.append(array_ops.identity(worker_values[0]))
  # Second stage: Apply upper_level_f to reduce across first device at
  # each worker
  level_2_output = upper_level_f(up_values)
  # Third stage: propagate within each worker using NCCL Broadcast
  output_tensors = [None] * len(input_tensors)
  for w in range(num_workers):
    with ops.device(per_worker_devices[w][0]):
      broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w]))
    for i, d in zip(per_worker_indices[w], per_worker_devices[w]):
      with ops.device(d):
        output_tensors[i] = array_ops.identity(broadcast_src)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
763
764
765def _reduce_non_singleton(input_tensors, red_f, un_op):
766  """If len(input_tensors) > 1, apply red_f, else apply un_op."""
767  if len(input_tensors) > 1:
768    return red_f(input_tensors)
769  else:
770    if not un_op:
771      return input_tensors
772    output_tensors = []
773    for t in input_tensors:
774      with ops.colocate_with(t):
775        output_tensors.append(un_op(t))
776    return output_tensors
777
778
def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Ring across workers."""
  def ring_across_workers(per_worker_values):
    # One value per worker, so every worker contributes a single "GPU" 0.
    return build_ring_all_reduce(per_worker_values, len(per_worker_values),
                                 subdiv, [0], red_op, un_op)

  def reduce_across_workers(per_worker_values):
    return _reduce_non_singleton(per_worker_values, ring_across_workers, un_op)

  return _build_nccl_hybrid(input_tensors, red_op, reduce_across_workers)
786
787
def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Recursive-HD across workers."""
  def hd_across_workers(per_worker_values):
    return build_recursive_hd_all_reduce(per_worker_values, red_op, un_op)

  return _build_nccl_hybrid(input_tensors, red_op, hd_across_workers)
792
793
def build_nccl_then_shuffle(input_tensors, gather_devices, nccl_red_op,
                            shuffle_red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Shuffle across workers."""
  def shuffle_across_workers(per_worker_values):
    return build_shuffle_all_reduce(
        per_worker_values, gather_devices, shuffle_red_op, un_op)

  return _build_nccl_hybrid(
      input_tensors, nccl_red_op, shuffle_across_workers)
801
802
def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
  """Construct a subgraph for Shuffle hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    gather_devices: list of device names on which to host gather shards,
      exactly one per worker task, in task order of first appearance.
    red_op: n-ary elementwise reduction operator used by the shuffle gather.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values, where element i lives on the
    device of input_tensors[i] (even when inputs of different tasks are
    interleaved).

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  # Group input positions (not just values) by task, so the outputs can be
  # written back in input order.
  per_worker_devices, per_worker_indices = _split_by_task(
      devices, list(range(len(input_tensors))))
  num_workers = len(per_worker_devices)
  if len(gather_devices) != num_workers:
    raise ValueError("For shuffle hybrid, gather_devices must contain one "
                     "device per worker. ")
  # First stage, reduce across each worker using gather_devices.
  up_values = []
  for w in range(num_workers):
    reduced_shards = _build_shuffle_gather(
        [input_tensors[i] for i in per_worker_indices[w]],
        [gather_devices[w]], red_op)
    up_values.append(reduced_shards[0])
  # Second stage, apply upper_level_f.
  level_2_output = upper_level_f(up_values)
  # Third stage, apply shuffle scatter at each worker.
  output_tensors = [None] * len(input_tensors)
  for w in range(num_workers):
    scattered = _build_shuffle_scatter(
        [level_2_output[w]], per_worker_devices[w])
    for i, t in zip(per_worker_indices[w], scattered):
      output_tensors[i] = t
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
843
844
def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
                            red_n_op, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Ring across workers."""
  def ring_across_workers(per_worker_values):
    # One value per worker, so every worker contributes a single "GPU" 0.
    return build_ring_all_reduce(per_worker_values, len(per_worker_values),
                                 subdiv, [0], red_op, un_op)

  def reduce_across_workers(per_worker_values):
    return _reduce_non_singleton(per_worker_values, ring_across_workers, un_op)

  return _build_shuffle_hybrid(
      input_tensors, gather_devices, red_n_op, reduce_across_workers)
855
856
def build_shuffle_then_shuffle(input_tensors, first_gather_devices,
                               second_gather_devices, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Shuffle across workers."""
  def shuffle_across_workers(per_worker_values):
    return build_shuffle_all_reduce(
        per_worker_values, second_gather_devices, red_op, un_op)

  def reduce_across_workers(per_worker_values):
    return _reduce_non_singleton(
        per_worker_values, shuffle_across_workers, un_op)

  return _build_shuffle_hybrid(
      input_tensors, first_gather_devices, red_op, reduce_across_workers)
867