# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CrossDeviceOps."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import threading
import time

from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest

CollectiveReplicaLauncher = cross_device_utils.CollectiveReplicaLauncher
CommunicationImplementation = collective_util.CommunicationImplementation
ReduceOp = reduce_util.ReduceOp
IndexedSlicesValue = indexed_slices.IndexedSlicesValue
IndexedSlices = indexed_slices.IndexedSlices


def make_per_replica_value(value, devices):
  """Creates a `PerReplica` object whose values reside in `devices`.

  Args:
    value: a tensor-convertible value, an `IndexedSlicesValue`, or a callable
      that takes one argument (`device_idx`) and returns the value to be
      created on devices[device_idx].
    devices: a list of device strings to create `PerReplica` values on.

  Returns:
    A `PerReplica` object.
  """
  values = []
  for device_idx, device in enumerate(devices):
    if callable(value):
      v = value(device_idx)
    elif isinstance(value, list):
      v = value[device_idx]
    else:
      v = value
    if isinstance(v, IndexedSlicesValue):
      with ops.device(device):
        values.append(
            IndexedSlices(
                values=array_ops.identity(v.values),
                indices=array_ops.identity(v.indices),
                dense_shape=array_ops.identity(v.dense_shape)))
    else:
      with ops.device(device):
        values.append(array_ops.identity(v))
  return value_lib.PerReplica(values)


def enable_collective_ops():
  """Enable collectives in the current process."""
  cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
  context.context().configure_collective_ops(
      collective_leader="'/job:worker/replica:0/task:0'")
  config_proto = config_pb2.ConfigProto()
  config_proto.experimental.collective_group_leader = (
      "/job:worker/replica:0/task:0")
  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_resolver.cluster_spec().as_cluster_def(),
      default_session_config=config_proto,
      job_name=cluster_resolver.task_type,
      task_index=cluster_resolver.task_id,
      protocol=cluster_resolver.rpc_layer)
  context.context().enable_collective_ops(server_def)
  # Recover default flag values.
  CollectiveReplicaLauncher._prefer_unique_instance_key = True
  CollectiveReplicaLauncher._prefer_ordering_token = False


class MultiProcessPoolRunner():

  def __init__(self, num_processes):
    cluster_spec_dict = multi_worker_test_base.create_cluster_spec(
        num_workers=num_processes)
    self.runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec_dict)


# Global MultiProcessPoolRunners that can be shared by test cases to avoid
# the expensive initialization cost of TensorFlow in new processes.
#
# Note that they have to be globals and can't be owned by test classes because
# fn usually captures the test class instance, and the test class instance
# can't be pickled if it has mpr as a member (it is not allowed to pickle
# Process objects).
# TODO(crccw): Use `num_workers` combination once it is ready.
global_mpr_2p = MultiProcessPoolRunner(num_processes=2)
global_mpr_1p = MultiProcessPoolRunner(num_processes=1)


def get_global_mpr(num_processes):
  if num_processes == 1:
    return global_mpr_1p.runner
  elif num_processes == 2:
    return global_mpr_2p.runner
  else:
    raise ValueError("get_global_mpr: num_processes must be 1 or 2, got %d" %
                     num_processes)


class CollectiveOpsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    # Enabling collectives can be done in "setUpClass", but requires using
    # different collective_keys in different tests as collectives are reused
    # across tests. Always resetting collective ops before each test offers
    # better test isolation.
    global_mpr_1p.runner.run(enable_collective_ops)
    global_mpr_2p.runner.run(enable_collective_ops)

  def make_collective(self, num_processes, gpu_per_process):
    """Returns a collective and other info to be used in tests.

    Args:
      num_processes: an integer indicating the number of processes that
        participate in the collective.
      gpu_per_process: number of GPUs (0 if no GPUs) used by each process.

    Returns:
      A tuple of (collective, devices, pid) where collective is an instance of
      `CollectiveAllReduce`, devices is a list of local devices (str) attached
      to the current process, and pid is the id of this process among all
      participating processes.
    """

    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
    devices = [
        "/job:worker/replica:0/task:%d/device:CPU:0" % cluster_resolver.task_id
    ]
    if gpu_per_process > 0:
      devices = [
          "/job:worker/replica:0/task:%d/device:GPU:%d" %
          (cluster_resolver.task_id, i) for i in range(gpu_per_process)
      ]
    group_size = num_processes * len(devices)
    collective = cross_device_ops_lib.CollectiveAllReduce(
        devices=devices, group_size=group_size)
    return collective, devices, cluster_resolver.task_id

  def as_list(self, value):
    """A utility to convert a `Mirrored`, `Tensor` or `IndexedSlices` to a list.

    The reason it exists is to provide a uniform view of the returned value of
    "reduce" calls, especially across tf.function boundaries. Returning
    `Mirrored` from a tf.function only evaluates the primary value, which
    causes the collective ops on non-primary devices to be pruned and
    eventually causes the program to hang.

    Args:
      value: the value to convert, can be one of `Mirrored`, `Tensor` and
        `IndexedSlices`.

    Returns:
      A list of `Tensor` or `IndexedSlices`.
    """
    if isinstance(value, ops.Tensor):
      return [value]
    elif isinstance(value, IndexedSlices):
      return [value]
    elif isinstance(value, value_lib.Mirrored):
      return value.values
    else:
      raise ValueError("unwrap: unsupported input type: %s" % type(value))

  RunOptions = collections.namedtuple(  # pylint: disable=invalid-name
      "RunOptions",
      [
          "mode",  # A list of str from ["eager", "func_graph"]
          "num_processes",
          "gpus_per_process",
          "reduce_op",
          "communication_options",
          "prefer_unique_instance_key",
      ])
  RunOptions.__new__.__defaults__ = (["eager",
                                      "func_graph"], 2, 0, ReduceOp.SUM,
                                     collective_util.Options(), True)
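  # For example (values are illustrative), a run that only exercises the
  # tf.function path on 2 processes with 1 GPU each could be described as:
  #   RunOptions(mode=["func_graph"], num_processes=2, gpus_per_process=1)
  # Fields that are not specified keep the defaults above.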

  def reduce_and_verify(self, inputs, expect, options):
    """Reduce the given `inputs` and verify the output matches `expect`.

    Args:
      inputs: a list of `Tensor` or `IndexedSlices`, where the i-th value will
        be fed to the i-th replica.
      expect: a `Tensor` or `IndexedSlices`. This should be the expected value
        for one replica.
      options: a `RunOptions` instance.
    """

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          options.prefer_unique_instance_key)
      collective, devices, pid = self.make_collective(options.num_processes,
                                                      options.gpus_per_process)

      def reduce_fn():
        value_fn = lambda device_idx: inputs[pid * len(devices) + device_idx]
        per_replica_value = make_per_replica_value(value_fn, devices)
        reduced_values = collective.reduce(options.reduce_op, per_replica_value,
                                           per_replica_value,
                                           options.communication_options)
        reduced_values = self.as_list(reduced_values)
        self.assertAllEqual(devices, [v.device for v in reduced_values])
        return [ops.convert_to_tensor(v) for v in reduced_values]

      per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

      if "eager" in options.mode:
        got = reduce_fn()
        self.assertAllClose(got, per_replica_expect)

      if "func_graph" in options.mode:
        got = def_function.function(reduce_fn)()
        self.assertAllClose(got, per_replica_expect)

    get_global_mpr(options.num_processes).run(replica_fn)

  def batch_reduce_and_verify(self, inputs, expect, options):
    """Batch reduce the given `inputs` and verify the output matches `expect`.

    Args:
      inputs: a 2-level nested list of `Tensor` or `IndexedSlices`, where the
        i-th value will be fed to the i-th replica.
      expect: a list of `Tensor` or `IndexedSlices`. This should be the
        expected value for one replica.
      options: a `RunOptions` instance.
    """

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          options.prefer_unique_instance_key)
      collective, devices, pid = self.make_collective(options.num_processes,
                                                      options.gpus_per_process)

      def batch_reduce_fn():
        batch_size = len(inputs[0])
        value_dst_pairs = []
        for i in range(batch_size):

          def value_fn(device_idx, idx=i):
            return inputs[pid * len(devices) + device_idx][idx]

          per_replica_value = make_per_replica_value(value_fn, devices)
          value_dst_pairs.append((per_replica_value, per_replica_value))
        reduced_values = collective.batch_reduce(options.reduce_op,
                                                 value_dst_pairs,
                                                 options.communication_options)
        reduced_values = [self.as_list(v) for v in reduced_values]
        for v in reduced_values:
          self.assertAllEqual(devices, [t.device for t in v])
        return nest.map_structure(ops.convert_to_tensor, reduced_values)

      per_replica_expect = nest.map_structure(
          lambda x: [ops.convert_to_tensor(x)] * len(devices), expect)

      if "eager" in options.mode:
        got = batch_reduce_fn()
        self.assertAllClose(got, per_replica_expect)

      if "func_graph" in options.mode:
        got = def_function.function(batch_reduce_fn)()
        self.assertAllClose(got, per_replica_expect)

    get_global_mpr(options.num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
          prefer_unique_instance_key=[True, False]))
  def testAllReduceDense(self, num_processes, required_gpus, implementation,
                         reduce_op, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes == 2 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                    "physical GPUs for every process.")
    options = self.RunOptions(
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [1.0, 2.0, 3.0, 4.0]
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = 1.0
    elif group_size == 2:
      expect = 3.0 if reduce_op == ReduceOp.SUM else 1.5
    elif group_size == 4:
      expect = 10.0 if reduce_op == ReduceOp.SUM else 2.5

    self.reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          # TODO(b/166682130): add MEAN reduce once the bug is fixed.
          reduce_op=ReduceOp.SUM,
          prefer_unique_instance_key=[True, False]))
  def testAllReduceSparse(self, num_processes, required_gpus, implementation,
                          reduce_op, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes == 2 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                    "physical GPUs for every process.")
    options = self.RunOptions(
        mode=["func_graph"],  # Sparse reduce is not supported in eager.
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [
        IndexedSlicesValue(
            values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[3.], [4.]], indices=[1, 2], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[5.], [6.]], indices=[7, 8], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[7.], [8.]], indices=[3, 2], dense_shape=[10, 1]),
    ]
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = IndexedSlices(
          values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1])
    elif group_size == 2:
      expect = IndexedSlices(
          values=[[1.], [2.], [3.], [4.]],
          indices=[0, 1, 1, 2],
          dense_shape=[10, 1])
    elif group_size == 4:
      expect = IndexedSlices(
          values=[[1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.]],
          indices=[0, 1, 1, 2, 7, 8, 3, 2],
          dense_shape=[10, 1])

    self.reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(prefer_unique_instance_key=[True, False]))
  def testAllReduceSparseVariableLength(self, prefer_unique_instance_key):
    # One device per process, 2 processes, 2 replicas in total.
    inputs = [
        IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[2.], [3.], [4.]], indices=[0, 1, 2], dense_shape=[10, 1]),
    ]
    expect = IndexedSlices(
        values=[[1.], [2.], [3.], [4.]],
        indices=[0, 0, 1, 2],
        dense_shape=[10, 1])
    self.reduce_and_verify(
        inputs,
        expect,
        self.RunOptions(
            mode=["func_graph"],  # Sparse reduce is not supported in eager.
            num_processes=2,
            reduce_op=ReduceOp.SUM,
            prefer_unique_instance_key=prefer_unique_instance_key))

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
          prefer_unique_instance_key=[True, False]))
  def testBatchAllReduceDense(self, num_processes, required_gpus,
                              implementation, reduce_op,
                              prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes == 2 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                    "physical GPUs for every process.")

    options = self.RunOptions(
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = [1.0, 2.0]
    elif group_size == 2:
      expect = [4.0, 6.0] if reduce_op == ReduceOp.SUM else [2.0, 3.0]
    elif group_size == 4:
      expect = [16.0, 20.0] if reduce_op == ReduceOp.SUM else [4.0, 5.0]

    self.batch_reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          # TODO(b/166682130): add MEAN reduce once the bug is fixed.
          reduce_op=ReduceOp.SUM,
          prefer_unique_instance_key=[True, False]))
  def testBatchAllReduceSparse(self, num_processes, required_gpus,
                               implementation, reduce_op,
                               prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes == 2 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                    "physical GPUs for every process.")

    options = self.RunOptions(
        mode=["func_graph"],  # Sparse reduce is not supported in eager.
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = ([
        IndexedSlicesValue(
            values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[3.], [4.]], indices=[1, 2], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[5.], [6.]], indices=[1, 2], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[7.], [8.]], indices=[0, 1], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[9.], [10.]], indices=[3, 4], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[11.], [12.]], indices=[3, 4], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[13.], [14.]], indices=[8, 9], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[15.], [16.]], indices=[3, 4], dense_shape=[5, 1])
    ])
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = [
          IndexedSlices(
              values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
          IndexedSlicesValue(
              values=[[3.], [4.]], indices=[1, 2], dense_shape=[5, 1])
      ]
    elif group_size == 2:
      expect = [
          IndexedSlices(
              values=[[1.], [2.], [5.], [6.]],
              indices=[0, 1, 1, 2],
              dense_shape=[10, 1]),
          IndexedSlices(
              values=[[3.], [4.], [7.], [8.]],
              indices=[1, 2, 3, 4],
              dense_shape=[5, 1])
      ]
    elif group_size == 4:
      expect = [
          IndexedSlices(
              values=[[1.], [2.], [5.], [6.], [9.], [10.], [13.], [14.]],
              indices=[0, 1, 1, 2, 3, 4, 8, 9],
              dense_shape=[10, 1]),
          IndexedSlices(
              values=[[3.], [4.], [7.], [8.], [11.], [12.], [15.], [16.]],
              indices=[1, 2, 0, 1, 3, 4, 3, 4],
              dense_shape=[5, 1])
      ]

    self.batch_reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
      ))
  def testCollectiveAllReduce(self, num_processes, required_gpus,
                              implementation, reduce_op):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes == 2 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                    "physical GPUs for every process.")

    def replica_fn():
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      group_size = num_processes * (required_gpus or 1)

      @def_function.function
      def collective_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = constant_op.constant(1.0)
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_all_reduce()
      if reduce_op == ReduceOp.SUM:
        expect = [1.0 * group_size] * len(devices)
      elif reduce_op == ReduceOp.MEAN:
        expect = [1.0] * len(devices)
      self.assertAllClose(got, expect)

      @def_function.function
      def collective_batch_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = (constant_op.constant(1.0), constant_op.constant(2.0))
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_batch_all_reduce()
      if reduce_op == ReduceOp.SUM:
        expect = [(1.0 * group_size, 2.0 * group_size)] * len(devices)
      elif reduce_op == ReduceOp.MEAN:
        expect = [(1.0, 2.0)] * len(devices)
      self.assertAllClose(got, expect)

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          axis=[0, 1, 2],
          func_mode=["eager", "func_graph"],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testAllGatherSameShape(self, num_processes, required_gpus, implementation,
                             func_mode, axis, prefer_unique_instance_key):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      value = constant_op.constant([[[1, 2], [1, 2]]], dtype=dtypes.float32)

      def gather_fn():
        per_replica_value = make_per_replica_value(value, devices)
        gathered_values = collective._gather(
            per_replica_value, per_replica_value, axis=axis, options=options)
        gathered_values = self.as_list(gathered_values)
        # Skip checking devices in eager. In eager the device attribute doesn't
        # reflect the actual device of the tensor.
        if not context.executing_eagerly():
          self.assertAllEqual(devices, [v.device for v in gathered_values])
        return [ops.convert_to_tensor(v) for v in gathered_values]

      group_size = num_processes * (required_gpus or 1)
      expect = array_ops.concat([value] * group_size, axis=axis)
      per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

      if func_mode == "eager":
        result = gather_fn()
        self.assertAllClose(result, per_replica_expect)

      if func_mode == "func_graph":
        result = def_function.function(gather_fn)()
        self.assertAllClose(result, per_replica_expect)

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[CommunicationImplementation.RING]))
  def testCollectiveV2ControlFlow(self, num_processes, required_gpus,
                                  implementation):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = True
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      value = make_per_replica_value(constant_op.constant([1.]), devices)

      @def_function.function
      def reduce_fn():

        def cond_body():
          reduced = collective.reduce(reduce_util.ReduceOp.SUM, value, value,
                                      options)
          return math_ops.add_n(self.as_list(reduced)) / len(devices)

        return control_flow_ops.cond(
            array_ops.identity(False), cond_body, cond_body)

      num_replicas = num_processes * len(devices)
      self.assertAllEqual(reduce_fn(), [1. * num_replicas])

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=1,
          required_gpus=2,
          implementation=[
              CommunicationImplementation.NCCL, CommunicationImplementation.RING
          ],
          prefer_unique_instance_key=[True, False]))
  def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes,
                                                    required_gpus,
                                                    implementation,
                                                    prefer_unique_instance_key):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)

      # We would like to simulate the following sequence:
      #   thread-0  device0                 device1
      #   thread-1          device0 device1
      # If the kernel launch sequence is as-is the program will deadlock since
      # NCCL requires the launch order to be same on each device.
      v0 = make_per_replica_value(1.0, devices)
      v1 = make_per_replica_value(2.0, devices)

      # Add a delay to collective_ops.all_reduce according to the input
      # tensor's index in `sequence`.
      sequence = [v0.values[0], v1.values[0], v1.values[1], v0.values[1]]
      all_reduce = collective_ops.all_reduce

      def delayed_all_reduce(input_tensor, *args, **kwargs):
        for idx, v in enumerate(sequence):
          if input_tensor is v:
            time.sleep(idx)
            break
        return all_reduce(input_tensor, *args, **kwargs)

      with test.mock.patch.object(collective_ops, "all_reduce",
                                  delayed_all_reduce):
        # We only use NCCL for batch reduce with two or more values, so we use
        # two values here.

        def thread_fn():
          reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                            [(v0, v0), (v0, v0)], options)
          self.assertAllEqual(reduced[0].values, [2.0, 2.0])
          self.assertAllEqual(reduced[1].values, [2.0, 2.0])

        t = threading.Thread(target=thread_fn)
        t.start()
        reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v1, v1),
                                                                     (v1, v1)],
                                          options)
        self.assertAllEqual(reduced[0].values, [4.0, 4.0])
        self.assertAllEqual(reduced[1].values, [4.0, 4.0])
        t.join()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=1,
          required_gpus=2,
          implementation=[
              CommunicationImplementation.NCCL, CommunicationImplementation.RING
          ],
          prefer_unique_instance_key=[True, False]))
  def testInputsAreFunctionArgs(self, num_processes, required_gpus,
                                implementation, prefer_unique_instance_key):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)

      @def_function.function
      def reduce_fn(v):
        # Function inputs don't have device placement.
        self.assertEqual(v.values[0].device, "")
        self.assertEqual(v.values[1].device, "")
        # We only use NCCL for batch reduce with two or more values, so we use
        # two values here.
        reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v),
                                                                     (v, v)],
                                          options)
        self.assertEqual(reduced[0].values[0].device, devices[0])
        self.assertEqual(reduced[0].values[1].device, devices[1])
        self.assertEqual(reduced[1].values[0].device, devices[0])
        self.assertEqual(reduced[1].values[1].device, devices[1])
        # Returning Mirrored only evaluates the primary value, which causes
        # hanging.
        return [reduced[0].values, reduced[1].values]

      v = make_per_replica_value(1.0, devices)
      reduced = reduce_fn(v)
      self.assertAllClose(reduced, [[2.0, 2.0], [2.0, 2.0]])

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutReduceDense(self, num_processes, implementation, required_gpus,
                             prefer_unique_instance_key):

    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(1.0, devices)
      options = collective_util.Options(
          timeout_seconds=1, implementation=implementation)

      @def_function.function
      def reduce_dense():
        return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)

      # The collective should time out because we only launch it on worker-0,
      # while there're two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        reduce_dense()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutBatchReduceDense(self, num_processes, implementation,
                                  required_gpus, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(1.0, devices)
      options = collective_util.Options(
          timeout_seconds=1, implementation=implementation)

      @def_function.function
      def batch_reduce_dense():
        return collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                       [(v, v), (v, v)], options)

      # The collective should time out because we only launch it on worker-0,
      # while there're two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        batch_reduce_dense()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutReduceSparse(self, num_processes, implementation,
                              required_gpus, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(
          IndexedSlicesValue(
              values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
      options = collective_util.Options(
          timeout_seconds=1, implementation=implementation)

      @def_function.function
      def reduce_sparse():
        return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)

      # The collective should time out because we only launch it on worker-0,
      # while there're two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        reduce_sparse()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
                                   implementation, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(
          IndexedSlicesValue(
              values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
      options = collective_util.Options(
          timeout_seconds=1, implementation=implementation)

      @def_function.function
      def batch_reduce_sparse():
        return collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                       [(v, v), (v, v)], options)

      # The collective should time out because we only launch it on worker-0,
      # while there're two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        batch_reduce_sparse()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(combinations.combine(num_processes=1, required_gpus=2))
  def testNcclOrdering(self, num_processes, required_gpus):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = True
      CollectiveReplicaLauncher._prefer_ordering_token = True
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(
          implementation=CommunicationImplementation.NCCL)

      v_dense = make_per_replica_value([1.0, 1.0], devices)
      v_sparse = make_per_replica_value([
          IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
          IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
      ], devices)

      @def_function.function
      def nested_dense():
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)

      @def_function.function
      def nested_sparse():
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)

      # All collectives, function calls, if clause and while loops should be
      # chained by control dependencies, so that the execution order is
      # deterministic.
      @def_function.function
      def f():
        # pylint: disable=pointless-statement
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # reducing dense value.
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        # reducing sparse value.
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # reduce dense value in nested tf.function.
        nested_dense()
        # reduce sparse value in nested tf.function.
        nested_sparse()
        # reduce dense value in tf.cond.
        if array_ops.identity(1.0) > array_ops.identity(2.0):
          collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        else:
          v_dense
        # reduce sparse value in tf.cond.
        if array_ops.identity(1.0) > array_ops.identity(2.0):
          v_sparse
        else:
          collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                            options)
        # reduce dense value in tf.while_loop.
        i = array_ops.identity(1)
        while i < 3:
          collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
          i += 1
        # reduce sparse value in tf.while_loop.
        i = array_ops.identity(1)
        while i < 3:
          collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                            options)
          i += 1
        # reducing dense and sparse value again.
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # pylint: enable=pointless-statement

      graph = f.get_concrete_function().graph
      should_be_ordered = set([
          "CollectiveReduceV2", "CollectiveGatherV2", "If", "While",
          "StatefulPartitionedCall"
      ])
      nodes_by_device = {}
      for op in graph.get_operations():
        if op.type in should_be_ordered:
          if op.device not in nodes_by_device:
            nodes_by_device[op.device] = []
          nodes_by_device[op.device].append(op)
      order = test_util.topological_sort_operations(graph.get_operations())
      for device in devices:
        device = device_util.canonicalize(device)
        # Those function ops don't have device annotations, but they contain
        # collectives for both devices so we always include them.
        operations = nodes_by_device[device] + nodes_by_device[""]
        # Verify that we get all types of nodes we want.
        self.assertEqual(set(op.type for op in operations), should_be_ordered)
        test_util.assert_sequential_execution(order, operations)

    get_global_mpr(num_processes).run(replica_fn)


if __name__ == "__main__":
  # Set default inter op thread pool size to one to ensure we don't exhaust the
  # thread pool with the additional executors to run collectives in eager.
  os.environ["TF_NUM_INTEROP_THREADS"] = "1"
  # TODO(b/172304955): figure out why logical devices don't work.
  test_util.main(config_logical_devices=False)