1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import math
23
24from tensorflow.python.framework import device as device_lib
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops import nccl_ops
29
30
def _flatten_tensors(tensors):
  """Verify that tensors share one static shape and flatten them to 1D.

  Args:
    tensors: list of T `tf.Tensor` which must all have the same shape.

  Returns:
    tensors: a list of T `tf.Tensor`.  When the common shape is already 1D
      this is the input list itself; otherwise each element is a 1D reshape
      of the corresponding input, colocated with it.
    shape: the common (merged) shape of the input tensors.

  Raises:
    ValueError: tensors are empty or non-isomorphic or have unknown shape.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  # Merge every shape into one; merging is what detects a mismatch.
  common_shape = tensors[0].shape
  for candidate in tensors:
    common_shape = common_shape.merge_with(candidate.shape)
  if not common_shape.is_fully_defined():
    raise ValueError("Tensors must have statically known shape.")
  if len(common_shape) == 1:
    # Already flat, nothing to add to the graph.
    return tensors, common_shape
  flattened = []
  for candidate in tensors:
    with ops.colocate_with(candidate):
      flattened.append(array_ops.reshape(candidate, [-1]))
  return flattened, common_shape
58
59
def _reshape_tensors(tensors, shape):
  """Undo _flatten_tensors by reshaping each 1D tensor to shape.

  Args:
    tensors: list of T `tf.Tensor` of identical length 1D tensors.
    shape: list of integers describing the desired shape.  Product of
      the elements must equal the length of each tensor.

  Returns:
    list of T `tf.Tensor` which are the reshaped inputs, each colocated
    with the tensor it was derived from.
  """
  def _restore(flat):
    with ops.colocate_with(flat):
      return array_ops.reshape(flat, shape)

  return [_restore(flat) for flat in tensors]
76
77
def _padded_split(tensor, pieces):
  """Like split for 1D tensors but pads-out case where len % pieces != 0.

  Args:
    tensor: T `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    A pair (parts, pad_len).  parts is a list of T `tf.Tensor` of length
    pieces holding the values of the input tensor, in order, all of equal
    size; zeros are appended at the end where needed to make the sizes
    equal.  pad_len is the number of zeros that were appended (0 when the
    length divides evenly).

  Raises:
    ValueError: The input tensor is not 1D.
  """
  shape = tensor.shape
  if len(shape) != 1:
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  with ops.colocate_with(tensor):
    if tensor_len % pieces == 0:
      return array_ops.split(tensor, pieces), 0
    # Round the per-piece size up; the shortfall is made up with zeros.
    chunk_size = 1 + tensor_len // pieces
    pad_len = pieces * chunk_size - tensor_len
    if (pieces - 1) * chunk_size >= tensor_len:
      # Edge cases of limited practical interest (this includes
      # pieces > tensor_len): the padding is not confined to the final
      # piece, so pad the whole tensor first and then split it evenly.
      extended_whole = array_ops.concat(
          [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
      return array_ops.split(extended_whole, pieces), pad_len
    # Common case: only the final piece is short, so pad just that piece.
    last_chunk_size = chunk_size - pad_len
    piece_lens = [chunk_size] * (pieces - 1) + [last_chunk_size]
    parts = array_ops.split(tensor, piece_lens)
    parts[-1] = array_ops.concat(
        [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
    return parts, pad_len
129
130
def _strip_padding(tensors, pad_len):
  """Strip the suffix padding added by _padded_split.

  Only the shape of the first tensor is inspected; the others are sliced
  to the same prefix length.

  Args:
    tensors: list of T `tf.Tensor` of identical length 1D tensors.
    pad_len: number of elements to be stripped from the end of each tensor.

  Returns:
    list of T `tf.Tensor` which are the stripped inputs.

  Raises:
    ValueError: tensors must be a non-empty list of 1D tensors, and
      pad_len may not exceed their length.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  # Reject every rank other than 1; a rank-0 shape has no shape[0] to read.
  if len(shape) != 1:
    raise ValueError("tensors must be 1D")
  prefix_len = int(shape[0] - pad_len)
  if prefix_len < 0:
    raise ValueError("pad_len longer than tensor")
  stripped = []
  for t in tensors:
    with ops.colocate_with(t):
      stripped.append(array_ops.slice(t, [0], [prefix_len]))
  return stripped
158
159
def _ragged_split(tensor, pieces):
  """Like split for 1D tensors but allows case where len % pieces != 0.

  Args:
    tensor: T `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of T `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order.  All pieces but the final one have
      length len // pieces; the final one additionally holds the
      len % pieces leftover elements, so it may be longer than the others.

  Raises:
    ValueError: input tensor must be 1D.
  """
  shape = tensor.shape
  if len(shape) != 1:
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  chunk_size, leftover = divmod(tensor_len, pieces)
  with ops.colocate_with(tensor):
    if leftover == 0:
      return array_ops.split(tensor, pieces)
    # Uneven split: the final piece absorbs the leftover elements.
    assert pieces > 1
    last_chunk_size = chunk_size + leftover
    assert last_chunk_size > 0
    piece_lens = [chunk_size] * (pieces - 1) + [last_chunk_size]
    return array_ops.split(tensor, piece_lens)
191
192
def _ring_permutations(num_workers, num_subchunks, gpu_perm):
  """Generate an array of device index arrays, one for each subchunk.

  In the basic ring reduction algorithm there are size(T)/num_devices
  data chunks and each device process one chunk per tick, i.e. sending
  one chunk and receiving one chunk.  The idea of subchunking is that
  each device processes num_subchunks smaller data regions per tick,
  and the ring rank permutation is different for each subchunk index
  so that a device is potentially sending to and receiving from
  num_subchunks different other devices at each tick.  Where multiple
  independent data channels exist between devices, this strategy
  supplies a method of using them in parallel.

  Args:
    num_workers: number of worker tasks
    num_subchunks: number of subchunks into which to divide each per-GPU chunk.
    gpu_perm: an array of integers in [0, num_gpus-1] giving the default
      ring order of GPUs at each worker.  Other permutations will be generated
      by rotating this array and splicing together per-worker instances.

  Raises:
    ValueError: the number of subchunks may not exceed the number of GPUs.

  Returns:
    pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
        preceding device in the permutation for that subchunk.  The
        device index of GPU i at worker j is i + (j * num_gpus).
    rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
       local rank of device d in the permutation for that subchunk.
    Both are empty lists when there are no devices.  An entry is -1 for a
    device index that does not occur in the permutation.
  """
  num_gpus = len(gpu_perm)
  devices = num_workers * num_gpus
  if devices == 0:
    return [], []
  if num_subchunks > num_gpus:
    raise ValueError(
        "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus))
  rotation_interval = max(1, num_gpus // num_subchunks)
  pred_by_s_d = []
  rank_by_s_d = []
  for s in range(num_subchunks):
    # Ring order for this subchunk: the per-worker GPU order rotated by
    # `offset`, repeated for each worker with that worker's index base.
    offset = s * rotation_interval
    local_order = [i for i in gpu_perm]
    local_order = local_order[offset:] + local_order[:offset]
    full_order = [(w * num_gpus) + i
                  for w in range(num_workers) for i in local_order]
    # Index the ring once (first occurrence wins) instead of scanning it
    # again for every device.
    first_pos = {}
    for t, dev in enumerate(full_order):
      first_pos.setdefault(dev, t)
    preds = [-1] * devices
    ranks = [-1] * devices
    for d in range(devices):
      t = first_pos.get(d)
      if t is not None:
        ranks[d] = t
        # The predecessor wraps around from rank 0 to the last rank.
        preds[d] = full_order[(t + devices - 1) % devices]
    pred_by_s_d.append(preds)
    rank_by_s_d.append(ranks)
  return (pred_by_s_d, rank_by_s_d)
252
253
def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
                          gpu_perm, red_op, un_op=None):
  """Construct a subgraph performing a ring-style all-reduce of input_tensors.

  The tensors are flattened, reduced chunk-wise around the ring, optionally
  passed through un_op, circulated so every device holds every chunk, and
  finally stripped of padding and restored to their original shape.

  Args:
    input_tensors: a list of T `tf.Tensor` objects, which must all
      have the same shape and type.
    num_workers: number of worker tasks spanned by input_tensors.
    num_subchunks: number of subchunks each device should process in one tick.
    gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at
      each worker.  All workers must have the same number of
      GPUs with the same rank ordering.  If NVLINK is available, this should
      be a ring order supported by NVLINK edges.
    red_op: a binary operator for elementwise reduction.
    un_op: an optional unary operator to apply to fully reduced values.

  Raises:
    ValueError: fewer than two input_tensors, or they don't all have same
    size.

  Returns:
    a list of T `tf.Tensor` identical red_op-reductions of input_tensors.
  """
  if len(input_tensors) < 2:
    raise ValueError("input_tensors must be length 2 or longer")
  flat_tensors, orig_shape = _flatten_tensors(input_tensors)
  device_names = [t.device for t in flat_tensors]
  ring_preds, ring_ranks = _ring_permutations(
      num_workers, num_subchunks, gpu_perm)
  # Pass 1: reduce, leaving fully reduced chunks spread across devices.
  reduced_chunks, pad_len = _build_ring_gather(
      flat_tensors, device_names, num_subchunks, ring_preds, ring_ranks,
      red_op)
  if un_op:
    reduced_chunks = _apply_unary_to_chunks(un_op, reduced_chunks)
  # Pass 2: circulate the reduced chunks so every device gets all of them.
  results = _build_ring_scatter(ring_preds, ring_ranks, reduced_chunks)
  if pad_len > 0:
    results = _strip_padding(results, pad_len)
  if len(orig_shape) != 1:
    results = _reshape_tensors(results, orig_shape)
  return results
295
296
def _build_ring_gather(input_tensors, devices, num_subchunks,
                       pred_by_s_d, rank_by_s_d, red_op):
  """Construct a subgraph for the first (reduction) pass of ring all-reduce.

  Args:
    input_tensors: a list of T `tf.Tensor` 1D input tensors of same
      shape and type.
    devices: array of device name strings
    num_subchunks: number of subchunks each device should process in one tick.
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    red_op: a binary operator for elementwise reduction

  Raises:
    ValueError: tensors must all be one dimensional.

  Returns:
    A pair (chunks_by_dev, pad_len).  chunks_by_dev is a list of list of
    T `tf.Tensor` of (partially) reduced values, indexed by (device, chunk),
    where exactly num_subchunks chunks at each device are fully reduced.
    pad_len is the padding length reported by _padded_split when splitting
    the inputs.  With no input tensors the result is ([], 0).
  """
  num_devices = len(input_tensors)
  if num_devices == 0:
    # Keep the (chunks, pad_len) form so the result can always be unpacked.
    return [], 0
  shape = input_tensors[0].shape
  if 1 != len(shape):
    raise ValueError("input tensors must be 1D")
  num_chunks = num_devices * num_subchunks
  # A single device needs no exchange: num_ticks is 0 and its chunks are
  # returned exactly as split.
  num_ticks = num_devices - 1

  def _chunk_index(d, s, tick):
    # Chunk that device d reduces for subchunk s at this tick.
    rank = rank_by_s_d[s][d]
    seg_index = (rank + num_devices - (2 + tick)) % num_devices
    return (seg_index * num_subchunks) + s

  # Initialize chunks_by_dev with splits of the input tensors.
  chunks_by_dev = []
  split_pad_len = 0
  for d in range(0, num_devices):
    with ops.device(devices[d]):
      splits, split_pad_len = _padded_split(input_tensors[d], num_chunks)
      chunks_by_dev.append(splits)
  # Reduction phase
  for tick in range(0, num_ticks):
    # One new partial reduction for every chunk
    new_partial_reductions = [None for _ in range(0, num_chunks)]
    # Compute reductions with respect to last tick's values
    for d in range(0, num_devices):
      with ops.device(devices[d]):
        for s in range(0, num_subchunks):
          pred_dev = pred_by_s_d[s][d]
          chunk_index = _chunk_index(d, s, tick)
          new_partial_reductions[chunk_index] = red_op(
              chunks_by_dev[pred_dev][chunk_index],
              chunks_by_dev[d][chunk_index])
    # Update chunks_by_dev with the new values at the end of the tick, so
    # that every reduction above saw only the previous tick's values.
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        chunk_index = _chunk_index(d, s, tick)
        chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index]
  return chunks_by_dev, split_pad_len
357
358
def _apply_unary_to_chunks(f, chunks_by_dev):
  """Apply a unary op to each tensor in chunks_by_dev, on same device.

  Args:
    f: a unary function over T `tf.Tensor`.
    chunks_by_dev: list of lists of T `tf.Tensor`.

  Returns:
    new list of lists of T `tf.Tensor` with the same structure as
    chunks_by_dev containing the derived tensors.  The ops for each inner
    list are colocated with that list's first tensor.
  """
  def _map_chunks(chunks):
    with ops.colocate_with(chunks[0]):
      return [f(chunk) for chunk in chunks]

  return [_map_chunks(chunks) for chunks in chunks_by_dev]
375
376
def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                        chunks_by_dev):
  """Construct subgraph for second (scatter) pass of ring all-reduce.

  Note that the inner lists of chunks_by_dev are updated in place as the
  chunks are passed around the ring.

  Args:
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    chunks_by_dev: list of list of T `tf.Tensor` indexed by ints
      (device, chunk)

  Raises:
    ValueError: chunks_by_dev is not well-formed

  Returns:
    list of T `tf.Tensor` which are the fully reduced tensors, one
    at each device corresponding to the outer dimension of chunks_by_dev.
  """
  num_devices = len(chunks_by_dev)
  num_chunks = len(chunks_by_dev[0])
  if num_chunks % num_devices != 0:
    raise ValueError(
        "Expect number of chunks per device to be divisible by num_devices")
  num_subchunks = num_chunks // num_devices

  def _incoming_chunk(dev, sub, tick):
    # Index of the chunk that device `dev` receives for subchunk `sub`.
    seg_index = (rank_by_s_d[sub][dev] + num_devices - (1 + tick)) % num_devices
    return (seg_index * num_subchunks) + sub

  for tick in range(num_devices - 1):
    # Stage all transfers first so that each one reads the values held at
    # the start of the tick, then commit them together.
    passed_values = [None] * num_chunks
    for dev in range(num_devices):
      with ops.colocate_with(chunks_by_dev[dev][0]):
        for sub in range(num_subchunks):
          chunk_index = _incoming_chunk(dev, sub, tick)
          sender = pred_by_s_d[sub][dev]
          passed_values[chunk_index] = array_ops.identity(
              chunks_by_dev[sender][chunk_index])
    for dev in range(num_devices):
      for sub in range(num_subchunks):
        chunk_index = _incoming_chunk(dev, sub, tick)
        chunks_by_dev[dev][chunk_index] = passed_values[chunk_index]

  # Join chunks at each device.
  def _join(chunks):
    with ops.colocate_with(chunks[0]):
      return array_ops.concat(chunks, 0)

  return [_join(chunks) for chunks in chunks_by_dev]
424
425
def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
  """Construct a subgraph for recursive halving-doubling all-reduce.

  The recursive halving-doubling algorithm is described in
  http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf

  The n participating devices are arranged in a linear sequence and
  exchange data pairwise with one other device in each round.  The
  gather phase has lg(n) rounds in which devices exchange increasingly
  smaller sub-tensors with a partner at increasingly greater distance,
  after which each device holds 1/n of the fully reduced values.  The
  scatter phase runs the rounds in reverse: each device exchanges its
  fully reduced sub-tensor (which doubles in length at each round) with
  a partner at increasingly smaller distance, until each device has all
  of the fully reduced values.

  Note: this preliminary version requires that len(input_tensors) be a
    power of 2.  TODO(tucker): relax this restriction.  Also, the
    number of elements in each tensor must be divisible by 2^h where h
    is the number of hops in each phase.  This will also be relaxed in
    the future with edge-case specific logic.

  Args:
    input_tensors: list of T `tf.Tensor` to be elementwise reduced.
    red_op: a binary elementwise reduction Op.
    un_op: an optional unary elementwise Op to apply to reduced values.

  Returns:
    list of T `tf.Tensor` which are the fully reduced tensors, one
    at each device of input_tensors.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
    by 2 the proper number of times.
  """
  # Device names are read from the tensors as given, before flattening.
  devices = [t.device for t in input_tensors]
  flat_tensors, orig_shape = _flatten_tensors(input_tensors)
  shards = _build_recursive_hd_gather(flat_tensors, devices, red_op)
  if un_op:
    shards = [un_op(shard) for shard in shards]
  results = _build_recursive_hd_scatter(shards, devices)
  if len(orig_shape) == 1:
    return results
  return _reshape_tensors(results, orig_shape)
472
473
def _build_recursive_hd_gather(input_tensors, devices, red_op):
  """Construct the gather phase of recursive halving-doubling all-reduce.

  In every hop each device is paired with the device `span` positions
  away; both tensors are split in half, the left device reduces the two
  first halves and the right device reduces the two second halves.

  Args:
    input_tensors: list of T `tf.Tensor` to be elementwise reduced.
    devices: a list of strings naming the devices hosting input_tensors,
      which will also be used to host the (partial) reduction values.
    red_op: a binary elementwise reduction Op.

  Returns:
    list of T `tf.Tensor` which are the fully reduced tensor shards.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
    by 2 the proper number of times.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  if (2 ** num_hops) != num_devices:
    raise ValueError("num_devices must be a power of 2")
  chunks = input_tensors
  for hop in range(num_hops):
    span = 2 ** hop
    group_size = 2 * span
    next_chunks = [[] for _ in devices]
    # Visit only the left member of each pair: the first `span` devices of
    # every group of `group_size` consecutive devices.
    for group_start in range(0, num_devices, group_size):
      for left in range(group_start, group_start + span):
        right = left + span
        left_dev = devices[left]
        right_dev = devices[right]
        left_halves = array_ops.split(chunks[left], 2)
        right_halves = array_ops.split(chunks[right], 2)
        with ops.device(left_dev):
          next_chunks[left] = red_op(left_halves[0], right_halves[0])
        with ops.device(right_dev):
          next_chunks[right] = red_op(left_halves[1], right_halves[1])
    chunks = next_chunks
  return chunks
513
514
def _build_recursive_hd_scatter(input_tensors, devices):
  """Construct the scatter phase of recursive halving-doubling all-reduce.

  The hops of the gather phase are replayed in reverse order: each pair
  of devices concatenates the left and right shards, so the shard held
  at every device doubles in length per hop.

  Args:
    input_tensors: list of T `tf.Tensor` that are fully-reduced shards.
    devices: a list of strings naming the devices on which the reconstituted
      full tensors should be placed.

  Returns:
    list of T `tf.Tensor` which are the fully reduced tensors.

  Raises:
    ValueError: num_devices not a power of 2.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  # Raise rather than assert: asserts are removed under -O, and the gather
  # phase reports the same condition with ValueError.
  if num_devices != (2 ** num_hops):
    raise ValueError("num_devices must be a power of 2")
  chunks = input_tensors
  for h in reversed(range(0, num_hops)):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= span:
        # skip right half of a pair
        continue
      left_idx = d
      right_idx = d + span
      left_dev = devices[left_idx]
      right_dev = devices[right_idx]
      # Both members of the pair end up with the same concatenation.
      with ops.device(left_dev):
        new_chunks[left_idx] = array_ops.concat([chunks[left_idx],
                                                 chunks[right_idx]], 0)
      with ops.device(right_dev):
        new_chunks[right_idx] = array_ops.concat([chunks[left_idx],
                                                  chunks[right_idx]], 0)
    chunks = new_chunks
  return chunks
550
551
def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
  """Construct a subgraph for shuffle all-reduce.

  Shuffle reduce is essentially the algorithm implemented when using
  parameter servers.  Suppose tensor length is n, there are d devices
  and g gather shards.  Each device sends a n/g length sub-tensor to
  each gather shard.  The gather shards perform a reduction across d
  fragments, then broadcast the result back to each device.  The
  devices then join the g fully reduced fragments they receive from
  the shards.  The gather shards could perform d-1 pairwise
  reductions, or one d-way reduction.  The first is better where
  reduction Op time is low compared to transmission time, the second
  better in the other case.

  Args:
    input_tensors: list of T `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-ary elementwise reduction Op, applied to a list of tensors.
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of T `tf.Tensor` which are the fully reduced tensors.
  """
  flat_tensors, orig_shape = _flatten_tensors(input_tensors)
  # Results are placed back on the devices of the (flattened) inputs.
  destinations = [t.device for t in flat_tensors]
  shards = _build_shuffle_gather(flat_tensors, gather_devices, red_op, un_op)
  results = _build_shuffle_scatter(shards, destinations)
  if len(orig_shape) == 1:
    return results
  return _reshape_tensors(results, orig_shape)
584
585
def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
  """Construct the gather (concentrate and reduce) phase of shuffle all-reduce.

  Every input tensor is split into one shard per gather device; each
  gather device then reduces the shards with its index from all inputs.

  Args:
    input_tensors: list of T `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: the reduction Op; it is called with a single list holding one
      shard from every input tensor.
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of T `tf.Tensor` which are the fully reduced shards.

  Raises:
    ValueError: inputs not well-formed.
  """
  num_gather_devices = len(gather_devices)
  if len(input_tensors[0].shape) != 1:
    raise ValueError("input_tensors must be 1D")
  # Split each source tensor, next to that tensor, into per-shard pieces.
  shards_by_source = []
  for source in input_tensors:
    with ops.colocate_with(source):
      shards_by_source.append(_ragged_split(source, num_gather_devices))
  reduced_shards = []
  for shard_index, gather_device in enumerate(gather_devices):
    with ops.device(gather_device):
      shard = red_op([shards[shard_index] for shards in shards_by_source])
      if un_op:
        shard = un_op(shard)
      reduced_shards.append(shard)
  return reduced_shards
621
622
def _build_shuffle_scatter(reduced_shards, dst_devices):
  """Build the scatter phase of shuffle all-reduce.

  Args:
    reduced_shards:  list of T `tf.Tensor` fully reduced shards
    dst_devices: list of names of devices at which the fully-reduced value
      should be reconstituted.

  Returns:
    list of T `tf.Tensor` scattered tensors, one per destination device,
    each the concatenation of all the reduced shards.
  """
  def _reconstitute(device):
    with ops.device(device):
      return array_ops.concat(reduced_shards, 0)

  return [_reconstitute(device) for device in dst_devices]
640
641
def _split_by_task(devices, values):
  """Partition devices and values by common task.

  Devices are grouped by their (job, replica, task); a missing job is
  treated as "localhost" and a missing replica as 0.  Groups appear in
  order of first occurrence.

  Args:
    devices: list of device name strings
    values: list of T `tf.tensor` of same length as devices.

  Returns:
    (per_task_devices, per_task_values) where both values are
    lists of lists with isomorphic structure: the outer list is
    indexed by task, and the inner list has length of the number
    of values belonging to that task.  per_task_devices contains
    the specific devices to which the values are local, and
    per_task_values contains the corresponding values.

  Raises:
    ValueError: devices must be same length as values, and every device
      name must specify a task.
  """
  num_devices = len(devices)
  if num_devices != len(values):
    raise ValueError("len(devices) must equal len(values)")
  per_task_devices = collections.OrderedDict()
  per_task_values = collections.OrderedDict()
  for device, value in zip(devices, values):
    d_spec = device_lib.DeviceSpec.from_string(device)
    if getattr(d_spec, "task", None) is None:
      # Raise instead of assert so the check is not removed under -O.
      raise ValueError("failed to parse device %s" % device)
    index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task)
    per_task_devices.setdefault(index, []).append(device)
    per_task_values.setdefault(index, []).append(value)

  return (list(per_task_devices.values()), list(per_task_values.values()))
677
678
def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
  """Build a subgraph that does one full all-reduce, using NCCL.

  Args:
    input_tensors: list of T `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.  Must be one of
      {tf.add}
    un_op: optional unary elementwise Op to apply to fully-reduce values.

  Returns:
    list of T `tf.Tensor` of reduced values.

  Raises:
    ValueError: red_op not supported.
  """
  if red_op == math_ops.add:
    output_tensors = nccl_ops.all_sum(input_tensors)
  else:
    # Format the operator into the message; passing it as a second
    # exception argument would render the message as a tuple.
    raise ValueError(
        "red_op not supported by NCCL all-reduce: %s" % (red_op,))
  if un_op:
    un_op_wrapped = []
    for t in output_tensors:
      # Apply un_op next to the reduced value it consumes.
      with ops.colocate_with(t):
        un_op_wrapped.append(un_op(t))
    output_tensors = un_op_wrapped
  return output_tensors
706
707
def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
  """Construct a subgraph for NCCL hybrid all-reduce.

  The reduction runs in three stages: an NCCL all-reduce within each
  worker, upper_level_f across one value per worker, and an NCCL
  broadcast of the result within each worker.

  Args:
    input_tensors: list of T `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of T `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  # First stage: reduce within each worker using NCCL
  up_values = []
  for w in range(0, num_workers):
    worker_values = build_nccl_all_reduce(per_worker_values[w], red_op)
    # NOTE: these reductions will not run to completion unless
    # every output value is used.  Since we only need one, we
    # need to put control dependencies on the rest.
    with ops.control_dependencies(worker_values):
      with ops.device(worker_values[0].device):
        up_values.append(array_ops.identity(worker_values[0]))
  # Second stage: Apply upper_level_f to reduce across first device at
  # each worker
  level_2_output = upper_level_f(up_values)
  # Third stage: propagate within each worker using NCCL Broadcast
  output_tensors = []
  for w in range(0, num_workers):
    with ops.device(per_worker_devices[w][0]):
      broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w]))
    for d in per_worker_devices[w]:
      with ops.device(d):
        output_tensors.append(array_ops.identity(broadcast_src))
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
757
758
def _reduce_non_singleton(input_tensors, red_f, un_op):
  """If len(input_tensors) > 1, apply red_f, else apply un_op.

  Args:
    input_tensors: list of T `tf.Tensor`.
    red_f: function applied to the whole list when it has more than one
      element.
    un_op: optional unary Op applied to each tensor otherwise; when it is
      not given the input list is returned unchanged.

  Returns:
    The value of red_f(input_tensors), the input list itself, or a list
    of un_op results colocated with their inputs.
  """
  if len(input_tensors) > 1:
    return red_f(input_tensors)
  if not un_op:
    return input_tensors

  def _apply_colocated(t):
    with ops.colocate_with(t):
      return un_op(t)

  return [_apply_colocated(t) for t in input_tensors]
771
772
def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Ring across workers.

  Args:
    input_tensors: list of T `tf.Tensor` to be reduced.
    subdiv: number of subchunks passed to the ring all-reduce.
    red_op: binary elementwise reduction operator, used at both levels.
    un_op: optional unary elementwise Op applied at the upper level.

  Returns:
    list of T `tf.Tensor` of reduced values.
  """
  # The upper level sees one tensor per worker, so the ring is built over
  # len(tensors) workers with the single-entry GPU permutation [0].
  ring_across_workers = lambda tensors: build_ring_all_reduce(
      tensors, len(tensors), subdiv, [0], red_op, un_op)
  return _build_nccl_hybrid(
      input_tensors, red_op,
      lambda tensors: _reduce_non_singleton(
          tensors, ring_across_workers, un_op))
780
781
def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Recursive-HD across workers.

  Args:
    input_tensors: list of T `tf.Tensor` to be reduced.
    red_op: binary elementwise reduction operator, used at both levels.
    un_op: optional unary elementwise Op applied at the upper level.

  Returns:
    list of T `tf.Tensor` of reduced values.
  """
  def upper_level_f(worker_tensors):
    return build_recursive_hd_all_reduce(worker_tensors, red_op, un_op)

  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)
786
787
def build_nccl_then_shuffle(input_tensors, gather_devices, nccl_red_op,
                            shuffle_red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Shuffle across workers.

  Args:
    input_tensors: list of T `tf.Tensor` to be reduced.
    gather_devices: devices hosting the shuffle reduction shards.
    nccl_red_op: reduction operator for the NCCL stage within workers.
    shuffle_red_op: reduction operator for the shuffle stage across workers.
    un_op: optional unary elementwise Op applied by the shuffle stage.

  Returns:
    list of T `tf.Tensor` of reduced values.
  """
  shuffle_across_workers = lambda worker_tensors: build_shuffle_all_reduce(
      worker_tensors, gather_devices, shuffle_red_op, un_op)
  return _build_nccl_hybrid(input_tensors, nccl_red_op, shuffle_across_workers)
795
796
def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
  """Construct a subgraph for Shuffle hybrid all-reduce.

  The reduction runs in three stages: a shuffle gather within each worker
  onto that worker's gather device, upper_level_f across one value per
  worker, and a shuffle scatter of the result within each worker.

  Args:
    input_tensors: list of T `tf.Tensor` of same-shape and type values to
      be reduced.
    gather_devices: list of device names on which to host gather shards.
    red_op: reduction operator handed to the shuffle gather stage.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of T `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  flat_tensors, orig_shape = _flatten_tensors(input_tensors)
  # First stage, reduce across each worker using gather_devices.
  device_names = [t.device for t in flat_tensors]
  per_worker_devices, per_worker_values = _split_by_task(
      device_names, flat_tensors)
  num_workers = len(per_worker_devices)
  if len(gather_devices) != num_workers:
    raise ValueError("For shuffle hybrid, gather_devices must contain one "
                     "device per worker. ")
  up_values = []
  for worker_values, gather_device in zip(per_worker_values, gather_devices):
    # A single gather device yields a single, fully reduced shard.
    worker_shards = _build_shuffle_gather(
        worker_values, [gather_device], red_op)
    up_values.append(worker_shards[0])
  # Second stage, apply upper_level_f.
  level_2_output = upper_level_f(up_values)
  # Third stage, apply shuffle scatter at each worker.
  results = []
  for w, worker_devices in enumerate(per_worker_devices):
    results.extend(
        _build_shuffle_scatter([level_2_output[w]], worker_devices))
  if len(orig_shape) != 1:
    results = _reshape_tensors(results, orig_shape)
  return results
837
838
def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
                            red_n_op, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Ring across workers.

  Args:
    input_tensors: list of T `tf.Tensor` to be reduced.
    gather_devices: devices hosting the per-worker shuffle gather shards.
    subdiv: number of subchunks passed to the ring all-reduce.
    red_n_op: reduction operator for the shuffle stage within workers.
    red_op: binary reduction operator for the ring stage across workers.
    un_op: optional unary elementwise Op applied at the upper level.

  Returns:
    list of T `tf.Tensor` of reduced values.
  """
  # The upper level sees one tensor per worker, so the ring is built over
  # len(tensors) workers with the single-entry GPU permutation [0].
  ring_across_workers = lambda tensors: build_ring_all_reduce(
      tensors, len(tensors), subdiv, [0], red_op, un_op)
  return _build_shuffle_hybrid(
      input_tensors, gather_devices, red_n_op,
      lambda tensors: _reduce_non_singleton(
          tensors, ring_across_workers, un_op))
849
850
def build_shuffle_then_shuffle(input_tensors, first_gather_devices,
                               second_gather_devices, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Shuffle across workers.

  Args:
    input_tensors: list of T `tf.Tensor` to be reduced.
    first_gather_devices: devices hosting the gather shards of the
      within-worker shuffle.
    second_gather_devices: devices hosting the gather shards of the
      across-worker shuffle.
    red_op: reduction operator used by both shuffle stages.
    un_op: optional unary elementwise Op applied at the upper level.

  Returns:
    list of T `tf.Tensor` of reduced values.
  """
  shuffle_across_workers = lambda tensors: build_shuffle_all_reduce(
      tensors, second_gather_devices, red_op, un_op)
  return _build_shuffle_hybrid(
      input_tensors, first_gather_devices, red_op,
      lambda tensors: _reduce_non_singleton(
          tensors, shuffle_across_workers, un_op))
861