1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for cross_device_ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections as pycoll
22import threading
23
24from tensorflow.python.distribute import all_reduce
25from tensorflow.python.distribute import values as value_lib
26from tensorflow.python.eager import context
27from tensorflow.python.eager import def_function
28from tensorflow.python.framework import device as pydev
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import collective_ops
33from tensorflow.python.ops import gradients_util
34from tensorflow.python.ops import math_ops
35from tensorflow.python.ops import nccl_ops
36
37
def aggregate_gradients_using_nccl(replica_grads):
  """Sum gradients across replicas with an NCCL all-reduce.

  Args:
    replica_grads: List of lists of (gradient, variable) tuples. The outer list
      is over replicas; the inner lists are aligned per variable.

  Returns:
    A list (one entry per replica) of tuples of (summed_gradient, variable)
    pairs. Each replica keeps its own variable.
  """
  per_variable = []
  for grads_and_vars in zip(*replica_grads):
    summed = nccl_ops.all_sum([grad for grad, _ in grads_and_vars])
    per_variable.append(
        [(grad, var) for grad, (_, var) in zip(summed, grads_and_vars)])
  # Transpose from per-variable lists back to per-replica tuples.
  return list(zip(*per_variable))
50
51
def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads):
  """Aggregate gradients using hierarchical copies.

  The devices are split into two halves of `len(avail_devices) // 2` devices.
  For each variable:
    1. Gradients within each half are summed on that half's main device.
    2. The two partial sums are added on the first half's main device.
    3. The total is copied (via `identity`) to both main devices.
    4. From the main devices it is copied to every replica's device.
  The main device rotates with the variable index to spread the work.

  Args:
    avail_devices: available GPU devices. NOTE(review): assumed to have an
      even length and one entry per replica. With an odd count, the gradient
      at index `2 * (len(avail_devices) // 2)` is never included in the sum.
    replica_grads: List of lists of (gradient, variable) tuples. The outer list
      is over replicas. The inner list is over individual gradients.

  Returns:
    A list (one entry per replica) of tuples of (aggregated_gradient, variable)
    pairs. Each gradient has been summed across all replicas. Each variable is
    the replica's own variable (taken from that replica's input pair).
  """
  # This only works for DGX-1 type of machine topology
  # Device peer to peer matrix
  # DMA: 0 1 2 3 4 5 6 7
  # 0:   Y Y Y Y Y N N N
  # 1:   Y Y Y Y N Y N N
  # 2:   Y Y Y Y N N Y N
  # 3:   Y Y Y Y N N N Y
  # 4:   Y N N N Y Y Y Y
  # 5:   N Y N N Y Y Y Y
  # 6:   N N Y N Y Y Y Y
  # 7:   N N N Y Y Y Y Y
  agg_grads = []
  num_devices = len(avail_devices)
  # In the special case of DGX-1 machine topology, the two groups have equal
  # size.
  group_size = num_devices // 2
  # Iterate per variable: single_grads holds that variable's (grad, var) pair
  # from every replica.
  for i, single_grads in enumerate(zip(*replica_grads)):
    # Rotate the main device with the variable index; the other group's main
    # device is the peer that sits group_size positions away.
    group_0_main_device = i % num_devices
    group_1_main_device = (group_0_main_device + group_size) % num_devices
    # "Group 0" is whichever half contains group_0_main_device.
    if group_0_main_device < group_size:
      group_0_begin = 0
      group_1_begin = group_size
    else:
      group_0_begin = group_size
      group_1_begin = 0

    # Aggregate the first group.
    group_0_device_grads = single_grads[group_0_begin:
                                        group_0_begin + group_size]
    with ops.device(avail_devices[group_0_main_device]):
      # Returns ((summed_grad, var), None); keep the (grad, var) pair.
      group_0_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_0_device_grads, False, False)

    # Aggregate the second group.
    group_1_device_grads = single_grads[group_1_begin:
                                        group_1_begin + group_size]
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads, _ = aggregate_single_gradient_using_copy(
          group_1_device_grads, False, False)

    # Aggregate between the groups.
    with ops.device(avail_devices[group_0_main_device]):
      (agg_total_grads, _), _ = aggregate_single_gradient_using_copy(
          [group_0_agg_grads, group_1_agg_grads], False, False)

    # Broadcast the result back into the root of each group.
    with ops.device(avail_devices[group_0_main_device]):
      group_0_agg_grads_bcast = array_ops.identity(agg_total_grads)
    with ops.device(avail_devices[group_1_main_device]):
      group_1_agg_grads_bcast = array_ops.identity(agg_total_grads)

    agg_grads_bcast = []
    for j in range(len(single_grads)):
      with ops.device(avail_devices[j]):
        # Broadcast the result back to each member in the group from the root.
        # Device j copies from group 0's root when it is in the same half as
        # group_0_main_device, otherwise from group 1's root.
        if (group_0_main_device < group_size) == (j < group_size):
          src_device_grad = group_0_agg_grads_bcast
        else:
          src_device_grad = group_1_agg_grads_bcast
        agg_grads_bcast.append(array_ops.identity(src_device_grad))

    # Pair each replica's copy of the total with that replica's own variable.
    agg_grads.append(
        [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)])

  # Transpose from per-variable lists to per-replica tuples.
  agg_grads = list(zip(*agg_grads))

  return agg_grads
132
133
def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
                                         check_inf_nan):
  """Sum (or average) one variable's gradient across all replicas.

  Note that this function provides a synchronization point across all replicas.

  Args:
    grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
      (gradient, variable) pair within the outer list represents the gradient
      of the variable calculated for a single replica, and the number of pairs
      equals the number of replicas.
    use_mean: if True, mean is taken, else sum of gradients is taken.
    check_inf_nan: check grads for nans and infs.

  Returns:
    The tuple ((aggregated_gradient, variable), has_nan_or_inf). The gradient
      is summed (or averaged when `use_mean`) across all replicas. The variable
      is taken from the first replica. has_nan_or_inf is a boolean tensor that
      is True when any input gradient contains a NaN or Inf, or None when
      `check_inf_nan` is False.
  """
  grads = [g for g, _ in grad_and_vars]
  grad = math_ops.add_n(grads)

  if use_mean and len(grads) > 1:
    # `multiply` lives in math_ops; array_ops has no such function.
    grad = math_ops.multiply(grad, 1.0 / len(grads))

  v = grad_and_vars[0][1]
  if check_inf_nan:
    # NOTE(review): is_finite on the list packs the per-replica gradients into
    # one tensor, which assumes they all share a shape (they are gradients of
    # the same variable).
    has_nan_or_inf = math_ops.logical_not(
        math_ops.reduce_all(math_ops.is_finite(grads)))
    return (grad, v), has_nan_or_inf
  else:
    return (grad, v), None
167
168
def group_device_names(devices, group_size):
  """Partition device names into groups of exactly `group_size` devices.

  Devices are dealt round-robin into ceil(len(devices) / group_size) groups.
  When there are more slots than devices, dealing wraps around to the start
  of `devices`, so some devices can appear in two groups.

  Args:
    devices: a list of canonical device strings.
    group_size: integer which is equal to or greater than 1.

  Returns:
    A list of lists of devices. Every inner list holds `group_size` entries and
    every device appears at least once. When len(devices) % group_size == 0,
    each device appears exactly once.

  Raises:
    ValueError: if group_size > len(devices)
  """
  device_count = len(devices)
  if group_size > device_count:
    raise ValueError(
        'only %d devices, but group_size=%d' % (device_count, group_size))
  quotient, remainder = divmod(device_count, group_size)
  group_count = quotient + 1 if remainder else quotient
  groups = [[] for _ in range(group_count)]
  for slot in range(group_count * group_size):
    target = groups[slot % group_count]
    target.append(devices[slot % device_count])
  return groups
194
195
def split_grads_by_size(threshold_size, device_grads):
  """Break gradients into two sets according to tensor size.

  Args:
    threshold_size: int size cutoff for small vs large tensor.
    device_grads: List of lists of (gradient, variable) tuples.  The outer
        list is over devices. The inner list is over individual gradients.

  Returns:
    small_grads: Subset of device_grads where shape is <= threshold_size
       elements.
    large_grads: Subset of device_grads where shape is > threshold_size
       elements, or where the element count is unknown because the shape is
       not fully defined.
    Pairs are returned as [gradient, variable] lists. Devices with no entries
    in a category are omitted from that category.
  """
  small_grads = []
  large_grads = []
  for dl in device_grads:
    small_dl = []
    large_dl = []
    for (g, v) in dl:
      tensor_size = g.get_shape().num_elements()
      # num_elements() is None for partially-unknown shapes. Comparing None
      # with an int raises TypeError on Python 3, and such tensors cannot be
      # sized, so treat them as large.
      if tensor_size is not None and tensor_size <= threshold_size:
        small_dl.append([g, v])
      else:
        large_dl.append([g, v])
    if small_dl:
      small_grads.append(small_dl)
    if large_dl:
      large_grads.append(large_dl)
  return small_grads, large_grads
226
227
# threading.Lock() and threading.local() objects cannot be pickled, so they
# cannot be fields of CollectiveKeys; they live at module level instead.
# _lock guards CollectiveKeys' shared group-key and id-keyed instance-key
# tables and counters. _thread_local holds the per-thread counter for instance
# keys requested without an id. That counter does not currently need to be
# per-CollectiveKeys-instance because a new thread is always created for each
# replica.
_lock = threading.Lock()
_thread_local = threading.local()
234
235
236# TODO(yuefengz): use random key starts to avoid reusing keys?
class CollectiveKeys(object):
  """Hands out the integer keys that collective ops need.

  Two kinds of keys are produced here:

  *Group key*: an integer identifying a set of cooperating devices. All
  collective ops that run over the same set of devices must use the same
  group key.

  *Instance key*: an integer identifying one collective operation, i.e. the
  set of counterpart tensors on the devices of a group that are all-reduced
  together.

  Collective ops also rely on a *graph key* (a non-zero integer set in the
  `config` argument of each `session.run` call, used to support multiple
  graphs per client session). That key is not generated by this class.
  """

  def __init__(self,
               group_key_start=1,
               instance_key_start=100,
               instance_key_with_id_start=10000):
    """Initializes the object.

    Args:
      group_key_start: the first group key to hand out.
      instance_key_start: seed for the thread-local counter of instance keys
        requested without an id. It is only used if the current thread's
        counter has not been initialised yet.
      instance_key_with_id_start: base for instance keys recorded under an id.
        The counter is incremented before use, so the first such key is
        `instance_key_with_id_start + 1`.
    """
    # Next group key to assign, and the memo from device-set string to key.
    self._group_key = group_key_start
    self._group_key_table = {}

    # Instance keys requested with an id: memo table plus its counter.
    self._instance_key_id_to_key_table = {}
    self._instance_key_with_id_counter = instance_key_with_id_start

    # Seed for the per-thread counter of instance keys without an id.
    self._instance_key_start = instance_key_start

  def _get_thread_local_object(self):
    """Returns the module-level thread-local holder, seeding it if needed.

    Instance keys without ids are counted per thread so that this works with
    MirroredStrategy and distribute coordinator. The holder is shared by all
    CollectiveKeys objects, so the first one to touch it in a thread picks
    the starting value.
    """
    holder = _thread_local
    if not hasattr(holder, 'instance_key'):
      holder.instance_key = self._instance_key_start
    return holder

  def get_group_key(self, devices):
    """Returns a group key for the set of devices.

    Args:
      devices: list of strings naming devices in a collective group.

    Returns:
      int key uniquely identifying the set of device names.
    """
    # In between-graph replicated training, every worker must derive the same
    # key, so only device type and index go into it (task_type and task_id
    # are dropped).
    # TODO(yuefengz): in the in-graph replicated training, we need to include
    # task_type and task_id.
    specs = [pydev.DeviceSpec.from_string(d) for d in devices]
    key_id = ','.join(
        sorted('%s:%d' % (spec.device_type, spec.device_index)
               for spec in specs))
    with _lock:
      table = self._group_key_table
      if key_id not in table:
        table[key_id] = self._group_key
        self._group_key += 1
    return self._group_key_table[key_id]

  def get_instance_key(self, key_id=None):
    """Returns a new instance key for use in defining a collective op.

    Args:
      key_id: optional string. If truthy, the key is memoized, and the same key
        is returned whenever the same key_id is provided again. Otherwise the
        next value of this thread's increasing counter is returned.
    """
    if not key_id:
      holder = self._get_thread_local_object()
      current = holder.instance_key
      holder.instance_key = current + 1
      return current
    with _lock:
      table = self._instance_key_id_to_key_table
      if key_id not in table:
        self._instance_key_with_id_counter += 1
        table[key_id] = self._instance_key_with_id_counter
    return self._instance_key_id_to_key_table[key_id]
326
327
def build_collective_reduce(input_tensors,
                            num_workers,
                            collective_keys,
                            reduction_op='Add',
                            unary_op='Id'):
  """Build a subgraph that performs one full all-reduce with collective ops.

  Args:
    input_tensors: tensors within a single worker graph that are to be reduced
      together; must be one per device.
    num_workers: total number of workers with identical independent graphs that
      will be doing this same reduction.  The reduction will actually include
      the corresponding tensors at all these workers.
    collective_keys: a CollectiveKeys object.
    reduction_op: string naming the reduction op.
    unary_op: string naming the unary final op.

  Returns:
    A list of final tensors, one per device, computed by the full reduction.
    If fewer than two tensors take part across all workers, `input_tensors`
    itself is returned unchanged and no keys are allocated.
  """
  group_size = len(input_tensors) * num_workers
  if group_size < 2:
    return input_tensors
  devices = [tensor.device for tensor in input_tensors]
  # The group key must be requested before the instance key.
  group_key = collective_keys.get_group_key(devices)
  instance_key = collective_keys.get_instance_key()
  subdiv_offsets = [0]  # TODO(tucker): maybe support non-default subdiv spec

  def collective_all_reduce():
    """Emit one collective all_reduce op per device and return the results."""
    assert not context.executing_eagerly()
    reduced = []
    for device, tensor in zip(devices, input_tensors):
      with ops.device(device):
        reduced.append(
            collective_ops.all_reduce(tensor, group_size, group_key,
                                      instance_key, reduction_op, unary_op,
                                      subdiv_offsets))
    return reduced

  runner = collective_all_reduce
  if context.executing_eagerly():
    # Run eagerly-issued collectives inside a function: executed one by one
    # they would block waiting for their peers.
    runner = def_function.function(collective_all_reduce)
  return runner()
377
378
def sum_grad_and_var_all_reduce(grad_and_vars,
                                num_workers,
                                alg,
                                gpu_indices,
                                aux_devices=None,
                                num_shards=1):
  """Apply an all-reduce algorithm to one variable's per-replica gradients.

  Args:
    grad_and_vars: sequence of (gradient, variable) pairs, one per replica, all
      for the same variable.
    num_workers: number of worker processes (used by 'xring').
    alg: algorithm name: 'nccl', 'xring', 'nccl/xring', 'nccl/rechd',
      'nccl/pscpu', 'pscpu/pscpu', 'pscpu' or 'psgpu'.
    gpu_indices: indices of local GPUs (used by 'xring').
    aux_devices: auxiliary devices for the shuffle-based algorithms.
    num_shards: algorithm-specific sharding factor.

  Returns:
    A list of [summed_gradient, variable] lists, one per replica, each paired
    with that replica's own variable.

  Raises:
    ValueError: if `alg` is not one of the supported names.
  """
  with ops.name_scope('allreduce'):
    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    scaled_grads = [g for g, _ in grad_and_vars]
    if alg == 'nccl':
      summed_grads = nccl_ops.all_sum(scaled_grads)
    elif alg == 'xring':
      summed_grads = all_reduce.build_ring_all_reduce(
          scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add)
    elif alg == 'nccl/xring':
      summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards,
                                                     math_ops.add)
    elif alg == 'nccl/rechd':
      summed_grads = all_reduce.build_nccl_then_recursive_hd(
          scaled_grads, math_ops.add)
    elif alg == 'nccl/pscpu':
      summed_grads = all_reduce.build_nccl_then_shuffle(
          scaled_grads, aux_devices, math_ops.add, math_ops.add_n)
    elif alg == 'pscpu/pscpu':
      second_gather_devices = aux_devices[:num_shards]
      summed_grads = all_reduce.build_shuffle_then_shuffle(
          scaled_grads, aux_devices, second_gather_devices, math_ops.add_n)
    elif alg in ['pscpu', 'psgpu']:
      summed_grads = all_reduce.build_shuffle_all_reduce(
          scaled_grads, aux_devices, math_ops.add_n)
    else:
      # Format the name into the message; passing it as a second argument made
      # str(error) render as a tuple.
      raise ValueError('unsupported all_reduce alg: %s' % (alg,))

  return [[g, v] for (_, v), g in zip(grad_and_vars, summed_grads)]
418
419
def sum_gradients_all_reduce(dev_prefixes, replica_grads, num_workers, alg,
                             num_shards, gpu_indices):
  """Apply an all-reduce algorithm to every variable's gradients.

  Args:
    dev_prefixes: list of prefix strings to use to generate PS device names.
    replica_grads: the gradients to reduce, as a list (one per replica) of
      lists of (gradient, variable) pairs.
    num_workers: number of worker processes across entire job.
    alg: the all-reduce algorithm to apply.
    num_shards: alg-specific sharding factor.
    gpu_indices: indices of local GPUs in order usable for ring-reduce.

  Returns:
    A list (one per replica) of lists of [reduced_gradient, variable] pairs.
  """
  uses_shuffle = 'pscpu' in alg or 'psgpu' in alg
  hierarchical = '/' in alg
  if 'pscpu' in alg:
    aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes]
  elif 'psgpu' in alg:
    aux_devices = [
        prefix + '/gpu:%d' % gpu
        for gpu in range(len(gpu_indices))
        for prefix in dev_prefixes
    ]
  else:
    aux_devices = ['/job:localhost/cpu:0']
  # Non-hierarchical algorithms cycle through these groups, one variable at a
  # time; hierarchical ones always receive the full aux_devices list.
  aux_device_groups = group_device_names(
      aux_devices, num_shards if uses_shuffle else 1)
  cursor = 0
  reduced_per_variable = []
  for grad_and_vars in zip(*replica_grads):
    if hierarchical:
      devices_for_alg = aux_devices
    else:
      devices_for_alg = aux_device_groups[cursor]
    reduced_per_variable.append(
        sum_grad_and_var_all_reduce(grad_and_vars, num_workers, alg,
                                    gpu_indices, devices_for_alg, num_shards))
    cursor = (cursor + 1) % len(aux_device_groups)
  # Transpose from per-variable results back to per-replica lists.
  return [list(per_replica) for per_replica in zip(*reduced_per_variable)]
460
461
def extract_ranges(index_list, range_size_limit=32):
  """Extract consecutive ranges and singles from index_list.

  Args:
    index_list: List of monotone increasing non-negative integers.
    range_size_limit: Caps the length of a consecutive run. A run keeps
      growing while `last - first <= range_size_limit`, so one returned range
      can span up to `range_size_limit + 2` elements. Longer runs are split
      into several ranges.

  Returns:
    (ranges, singles) where ranges is a list of [first, last] pairs of
      consecutive elements in index_list, and singles is all of the
      other elements, in original order.
  """
  if not index_list:
    return [], []
  ranges = []
  singles = []

  def close_run(start, end):
    # A run of length one is reported as a single, not a range.
    if end > start:
      ranges.append([start, end])
    else:
      singles.append(start)

  start = end = index_list[0]
  for idx in index_list[1:]:
    can_extend = idx == end + 1 and (end - start) <= range_size_limit
    if can_extend:
      end = idx
      continue
    close_run(start, end)
    start = end = idx
  close_run(start, end)
  return ranges, singles
497
498
# Metadata recorded for each packed range: the original positions (`indices`),
# the variables, and the gradient shapes needed to split the packed tensor
# back apart.
GradPackTuple = pycoll.namedtuple(
    'GradPackTuple', ['indices', 'vars', 'shapes'])
500
501
def pack_range(key, packing, grad_vars, rng):
  """Concatenate a consecutive range of gradient tensors into one flat tensor.

  Args:
    key: Value under which to store meta-data in packing that will be used
      later to restore the grad_var list structure.
    packing: Dict holding data describing packed ranges of small tensors.
      It is updated in place with a GradPackTuple under `key`.
    grad_vars: List of (grad, var) pairs for one replica.
    rng: A pair of integers giving the first, last indices of a consecutive
      range of tensors to be packed (both inclusive).

  Returns:
    A 1-D tensor that concatenates the flattened gradients in the range,
    placed on the device of the first gradient.
  """
  start, stop = rng[0], rng[1] + 1
  selected = grad_vars[start:stop]
  with ops.name_scope('pack'):
    flattened = []
    for grad, _ in selected:
      # Flatten each gradient on its own device.
      with ops.device(grad.device):
        flattened.append(array_ops.reshape(grad, [-1]))
    packing[key] = GradPackTuple(
        indices=range(start, stop),
        vars=[var for _, var in selected],
        shapes=[grad.shape for grad, _ in selected])
    with ops.device(flattened[0].device):
      return array_ops.concat(flattened, 0)
532
533
def unpack_grad_tuple(gv, gpt):
  """Unpack a previously packed collection of gradient tensors.

  Args:
    gv: A (grad, var) pair to be unpacked; grad is the flat packed tensor.
    gpt: A GradPackTuple describing the packing operation that produced gv.

  Returns:
    A list of (grad, var) pairs corresponding to the values that were
     originally packed into gv, maybe following subsequent operations like
     reduction. Each grad is reshaped back to its recorded shape.
  """
  packed_grad = gv[0]
  elt_widths = [x.num_elements() for x in gpt.shapes]
  # Use the packed tensor's own device. The previous `gv[0][0].device` created
  # an extra slice op whose device comes from the ambient device scope, not
  # from the packed tensor.
  with ops.device(packed_grad.device):
    with ops.name_scope('unpack'):
      splits = array_ops.split(packed_grad, elt_widths)
      unpacked_gv = [(array_ops.reshape(s, shape), var)
                     for s, shape, var in zip(splits, gpt.shapes, gpt.vars)]
  return unpacked_gv
555
556
def pack_small_tensors(replica_grads, max_bytes=0, max_group=0):
  """Concatenate small gradient tensors together for reduction.

  Args:
    replica_grads: List of lists of (gradient, variable) tuples.
    max_bytes: Int giving max number of bytes in a tensor that
      may be considered small.
    max_group: Int passed to extract_ranges as range_size_limit; it bounds how
      many small tensors may be concatenated into one new tensor.

  Returns:
    new_replica_grads, packing where new_replica_grads is identical to
      replica_grads except that all feasible small_tensors have been removed
      from their places and concatenated into larger tensors that are
      now in the front of the list for each replica, and packing contains
      the data necessary to restore the replica_grads structure.
      If nothing is packed (including when replica_grads is empty),
      (replica_grads, None) is returned.

  Look through the first replica for gradients of the same type (float32),
  and small size, that are all sequential.  For each such group,
  replace by a new tensor that is a flattened concatenation.  Note
  that the corresponding variable will be absent, which doesn't matter
  because it isn't used during all-reduce. Gradients whose shape is not
  fully defined are never packed.

  Requires:
    Every gv_list in replicas must have isomorphic structure including identical
      tensor sizes and types.
  """
  if not replica_grads:
    return replica_grads, None
  small_indices = []
  large_indices = []
  for idx, (g, _) in enumerate(replica_grads[0]):
    num_elements = g.shape.num_elements()
    # num_elements() is None for shapes that are not fully defined. Such
    # tensors cannot be sized (4 * None fails) or split back apart, so they
    # are always left unpacked.
    if (g.dtype == dtypes.float32 and num_elements is not None and
        4 * num_elements <= max_bytes):
      small_indices.append(idx)
    else:
      large_indices.append(idx)
  small_ranges, small_singles = extract_ranges(
      small_indices, range_size_limit=max_group)
  large_indices = sorted(large_indices + small_singles)
  num_gv = len(replica_grads[0])
  packing = {}
  if small_ranges:
    new_replica_grads = []
    for dev_idx, gv_list in enumerate(replica_grads):
      assert len(gv_list) == num_gv
      new_gv_list = []
      for r in small_ranges:
        # Key format '<replica>:<packed position>' is what
        # unpack_small_tensors looks up.
        key = '%d:%d' % (dev_idx, len(new_gv_list))
        new_gv_list.append((pack_range(key, packing, gv_list, r),
                            'packing_var_placeholder'))
      for i in large_indices:
        new_gv_list.append(gv_list[i])
      new_replica_grads.append(new_gv_list)
    return new_replica_grads, packing
  else:
    return replica_grads, None
611
612
def unpack_small_tensors(replica_grads, packing):
  """Undo the structure alterations to replica_grads done by pack_small_tensors.

  Args:
    replica_grads: List of List of (grad, var) tuples.
    packing: A dict generated by pack_small_tensors describing the changes
      it made to replica_grads. A falsy value means nothing was packed.

  Returns:
    new_replica_grads: identical to replica_grads except that concatenations
      of small tensors have been split apart and returned to their original
      positions, paired with their original variables. If `packing` is falsy,
      replica_grads itself is returned.
  """
  if not packing:
    return replica_grads
  new_replica_grads = []
  num_devices = len(replica_grads)
  # Each replica contributed the same number of packed tensors.
  num_packed = len(packing) // num_devices
  for dev_idx, gv_list in enumerate(replica_grads):
    gv_list = list(gv_list)
    # Packed tensors sit at the front; everything after them is unpacked.
    new_gv_list = gv_list[num_packed:]
    for i in range(num_packed):
      gpt = packing['%d:%d' % (dev_idx, i)]
      unpacked = unpack_grad_tuple(gv_list[i], gpt)
      # Ranges are processed in ascending order, so inserting each restored
      # pair at its original index rebuilds the original ordering.
      for idx, restored in zip(gpt.indices, unpacked):
        new_gv_list.insert(idx, restored)
    new_replica_grads.append(new_gv_list)
  return new_replica_grads
643
644
def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
  """Aggregate `values`, routing any `IndexedSlices` input to a sparse path.

  If at least one element is an `IndexedSlices`, aggregation is delegated to
  `gradients_util._AggregateIndexedSlicesGradients`. Otherwise the values are
  passed to `accumulation_fn` (default `math_ops.add_n`).
  """
  has_sparse = any(isinstance(v, ops.IndexedSlices) for v in values)
  if not has_sparse:
    return accumulation_fn(values)
  return gradients_util._AggregateIndexedSlicesGradients(values)  # pylint: disable=protected-access
651
652
def divide_by_n_tensors_or_indexed_slices(value, n):
  """Divide `value` by `n`; for `IndexedSlices` only the values are divided.

  An `IndexedSlices` input is first passed through
  `gradients_util._HandleNestedIndexedSlices` (presumably flattening nested
  slices; see that helper). Its indices and dense_shape are then kept as-is.
  """
  if not isinstance(value, ops.IndexedSlices):
    return value / n
  flat = gradients_util._HandleNestedIndexedSlices(value)  # pylint: disable=protected-access
  return ops.IndexedSlices(flat.values / n, flat.indices, flat.dense_shape)
660
661
def copy_tensor_or_indexed_slices_to_device(value, device):
  """Copy a tensor or `IndexedSlices` onto `device` via identity ops.

  Args:
    value: a tensor or an `IndexedSlices`.
    device: device string or spec to place the copies on.

  Returns:
    An object of the same kind as `value`, produced by identity ops placed on
    `device`.
  """
  with ops.device(device):
    if isinstance(value, ops.IndexedSlices):
      copied_values = array_ops.identity(value.values)
      copied_indices = array_ops.identity(value.indices)
      # dense_shape is optional on IndexedSlices; identity(None) would fail.
      if value.dense_shape is not None:
        copied_shape = array_ops.identity(value.dense_shape)
      else:
        copied_shape = None
      result = ops.IndexedSlices(copied_values, copied_indices, copied_shape)
    else:
      result = array_ops.identity(value)
  return result
672
673
def contains_indexed_slices(value):
  """Return True if `value` is, or (recursively) holds, an `IndexedSlices`.

  Non-empty lists/tuples are searched element by element. For
  `DistributedValues`, the search covers its `.values`. Anything else yields
  False.
  """
  if isinstance(value, ops.IndexedSlices):
    return True
  if isinstance(value, (list, tuple)) and value:
    return any(contains_indexed_slices(item) for item in value)
  if isinstance(value, value_lib.DistributedValues):
    return contains_indexed_slices(value.values)
  return False
684
685
def is_indexed_slices(value):
  """Return True if `value` is sparse.

  A value counts as sparse if it is an `IndexedSlices`, or a
  `DistributedValues` whose components are all `IndexedSlices`. Any other
  value (e.g. a plain dense tensor) returns False. Such values used to trip
  an assertion, even though split_by_sparsity accepts plain tensors.
  """
  if isinstance(value, ops.IndexedSlices):
    return True
  if isinstance(value, value_lib.DistributedValues):
    return all(isinstance(v, ops.IndexedSlices) for v in value.values)
  return False
691
692
def split_by_sparsity(values):
  """Partition values into dense and sparse groups, remembering positions.

  Args:
    values: a list of tensors or `PerReplica`s.

  Returns:
    Four lists:
      the dense values and their indices in `values`, followed by
      the sparse values and their indices in `values`.
  """
  dense_values, dense_indices = [], []
  sparse_values, sparse_indices = [], []
  for position, value in enumerate(values):
    if is_indexed_slices(value):
      bucket_values, bucket_indices = sparse_values, sparse_indices
    else:
      bucket_values, bucket_indices = dense_values, dense_indices
    bucket_values.append(value)
    bucket_indices.append(position)
  return dense_values, dense_indices, sparse_values, sparse_indices
716
717
def stitch_values(values_and_indices_list):
  """Merge several (values, indices) pairs back into one ordered list.

  Args:
    values_and_indices_list: a list of (values, indices) tuples. values[k] is
      written to position indices[k] of the result.

  Returns:
    A list whose length is the total number of values across all pairs, with
    each value at its given position. Positions that are never written
    stay None.
  """
  total = sum(len(pair[0]) for pair in values_and_indices_list)
  stitched = [None] * total
  for pair in values_and_indices_list:
    if not (pair and pair[0]):
      continue
    for value, position in zip(*pair):
      # Each position must be filled at most once.
      assert stitched[position] is None
      stitched[position] = value
  return stitched
739