1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for CrossDeviceOps in v1 graph mode."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22import os
23import threading
24import time
25
26from absl.testing import parameterized
27import numpy as np
28from tensorflow.core.protobuf import config_pb2
29from tensorflow.python.distribute import cluster_resolver
30from tensorflow.python.distribute import collective_all_reduce_strategy as mwms_lib
31from tensorflow.python.distribute import collective_util
32from tensorflow.python.distribute import combinations
33from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
34from tensorflow.python.distribute import cross_device_utils
35from tensorflow.python.distribute import device_util
36from tensorflow.python.distribute import distribute_utils
37from tensorflow.python.distribute import multi_worker_test_base
38from tensorflow.python.distribute import multi_worker_util
39from tensorflow.python.distribute import reduce_util
40from tensorflow.python.distribute import strategy_combinations
41from tensorflow.python.distribute import values as value_lib
42from tensorflow.python.eager import context
43from tensorflow.python.eager import test
44from tensorflow.python.framework import constant_op
45from tensorflow.python.framework import kernels
46from tensorflow.python.framework import ops
47from tensorflow.python.ops import array_ops
48from tensorflow.python.ops import collective_ops
49from tensorflow.python.ops import math_ops
50from tensorflow.python.ops import variables
51
52
def _get_devices(devices):
  """Normalizes `devices` into a tuple of resolved device strings.

  A list/tuple is resolved element-wise, a `DistributedValues` yields its own
  device tuple unchanged, a `Tensor` contributes the device it is placed on,
  and anything else is treated as a single device spec.
  """
  if isinstance(devices, (tuple, list)):
    return tuple(map(device_util.resolve, devices))
  if isinstance(devices, value_lib.DistributedValues):
    return devices._devices  # pylint: disable=protected-access
  single_spec = devices.device if isinstance(devices, ops.Tensor) else devices
  return (device_util.resolve(single_spec),)
61
62
def _make_per_replica(values, devices, regroup=False):
  """Places `values[i]` on `devices[i]` and regroups the placed tensors.

  Args:
    values: sequence of tensors/values, one per device.
    devices: anything accepted by `_get_devices`; must yield as many devices
      as there are `values`.
    regroup: if True and there is exactly one value, return the placed tensor
      itself, mimicking regroup stripping a single-element PerReplica.

  Returns:
    The result of `distribute_utils.regroup` over the placed tensors, or the
    single placed tensor in the `regroup` special case.
  """
  devices = _get_devices(devices)
  assert len(values) == len(devices)

  def _place(device, value):
    with ops.device(device):
      return array_ops.identity(value)

  if regroup and len(values) == 1:
    return _place(devices[0], values[0])
  return distribute_utils.regroup(
      [_place(device, value) for device, value in zip(devices, values)])
80
81
82# pylint: disable=g-doc-args,g-doc-return-or-yield
def _fake_mirrored(value, devices):
  """Builds a fake `Mirrored` holding `value` on each of `devices`.

  Every component is an identity of the very same `value`, which real mirrored
  values do not guarantee; this is only meant as a test expectation.
  """
  placed = []
  for device in _get_devices(devices):
    with ops.device(device):
      placed.append(array_ops.identity(value))
  return distribute_utils.regroup(placed, wrap_class=value_lib.Mirrored)
97
98
def _make_indexed_slices(values, indices, dense_shape, device):
  """Returns an `IndexedSlices` built from constants placed on `device`."""
  with ops.device(device):
    return ops.IndexedSlices(
        values=constant_op.constant(values),
        indices=constant_op.constant(indices),
        dense_shape=constant_op.constant(dense_shape))
106
107
def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
  """Wraps identical per-device `IndexedSlices` into a `Mirrored` value."""
  per_device = []
  for device in devices:
    per_device.append(
        _make_indexed_slices(values, indices, dense_shape, device))
  return distribute_utils.regroup(per_device, wrap_class=value_lib.Mirrored)
114
115
# CPU device spec used by the tests as a reduction target and as a
# non-replica ("different") destination for reduce/broadcast.
_cpu_device = "/device:CPU:0"
117
118
class CrossDeviceOpsTestBase(test.TestCase, parameterized.TestCase):
  """Shared assertions and test bodies for `CrossDeviceOps` implementations."""

  def _assert_indexed_slices_equal(self, left, right):
    """Asserts two `IndexedSlices` share a device and densify to equal values."""
    self.assertIsInstance(left, ops.IndexedSlices)
    self.assertIsInstance(right, ops.IndexedSlices)
    self.assertEqual(
        device_util.resolve(left.device), device_util.resolve(right.device))
    self.assertAllEqual(
        self.evaluate(ops.convert_to_tensor(left)),
        self.evaluate(ops.convert_to_tensor(right)))

  def _assert_mirrored_equal(self,
                             left_list,
                             right_list,
                             sess=None,
                             run_options=None):
    """Asserts that two mirrored results (or lists of them) are equal.

    For each pair, the Python types, the per-component devices, and the
    densified component values must all match.

    Args:
      left_list: a single value or a list of values. Each value is a
        `Mirrored` or a bare tensor / `IndexedSlices` (a single-replica
        `Mirrored` is unwrapped by regroup).
      right_list: the value(s) to compare against, same structure.
      sess: session used to evaluate the values when not executing eagerly.
      run_options: optional `RunOptions` passed to `sess.run` in graph mode.
    """
    if not isinstance(left_list, list):
      left_list, right_list = [left_list], [right_list]
    # `zip` silently stops at the shorter input; fail loudly instead if, e.g.,
    # batch_reduce returned a different number of results than expected.
    self.assertEqual(len(left_list), len(right_list))

    for left, right in zip(left_list, right_list):
      self.assertEqual(type(left), type(right))

      # Convert Mirrored to a list since sess.run(Mirrored) only returns one
      # value.
      if isinstance(left, value_lib.Mirrored):
        left, right = left.values, right.values
      else:
        # When there's only one replica Mirrored is automatically unwrapped.
        left, right = [left], [right]
      # Guard the per-component zips below against a replica-count mismatch.
      self.assertEqual(len(left), len(right))

      for left_value, right_value in zip(left, right):
        self.assertEqual(
            device_util.resolve(left_value.device),
            device_util.resolve(right_value.device))

      # Densify IndexedSlices.
      left = [ops.convert_to_tensor(v) for v in left]
      right = [ops.convert_to_tensor(v) for v in right]
      if not context.executing_eagerly():
        left, right = sess.run((left, right), options=run_options)
      for left_value, right_value in zip(left, right):
        self.assertAllEqual(left_value, right_value)

  def _testReductionAndBroadcast(self, cross_device_ops, devices):
    """Checks reduce, batch_reduce and broadcast of scalars over `devices`.

    Skips when `devices` names more GPUs than are available. Destinations
    cover a Mirrored on `devices`, a Mirrored on the CPU, and a CPU string.
    """
    if context.num_gpus() < sum(1 for d in devices if "GPU" in d.upper()):
      self.skipTest("Not enough GPUs")

    with self.cached_session() as sess:
      values = [constant_op.constant(float(d)) for d in range(len(devices))]
      per_replica = _make_per_replica(values, devices)
      mean = (len(devices) - 1.) / 2.

      values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))]
      per_replica_2 = _make_per_replica(values_2, devices)
      mean_2 = mean + 1.

      destination_mirrored = _fake_mirrored(1., devices)
      destination_different = _fake_mirrored(1.,
                                             device_util.resolve(_cpu_device))
      destination_str = device_util.resolve(_cpu_device)

      all_destinations = [
          destination_mirrored,
          destination_different,
          destination_str,
      ]

      # test reduce()
      for destinations in all_destinations:
        self._assert_mirrored_equal(
            cross_device_ops.reduce(
                reduce_util.ReduceOp.MEAN,
                per_replica,
                destinations=destinations), _fake_mirrored(mean, destinations),
            sess)
        self._assert_mirrored_equal(
            cross_device_ops.reduce(
                reduce_util.ReduceOp.MEAN,
                per_replica_2,
                destinations=destinations),
            _fake_mirrored(mean_2, destinations), sess)
        self._assert_mirrored_equal(
            cross_device_ops.reduce(
                reduce_util.ReduceOp.SUM,
                per_replica,
                destinations=destinations),
            _fake_mirrored(mean * len(devices), destinations), sess)
        self._assert_mirrored_equal(
            cross_device_ops.reduce(
                reduce_util.ReduceOp.SUM,
                per_replica_2,
                destinations=destinations),
            _fake_mirrored(mean_2 * len(devices), destinations), sess)

      # test batch_reduce()
      for d1, d2 in itertools.product(all_destinations, all_destinations):
        self._assert_mirrored_equal(
            cross_device_ops.batch_reduce(reduce_util.ReduceOp.MEAN,
                                          [(per_replica, d1),
                                           (per_replica_2, d2)]),
            [_fake_mirrored(mean, d1),
             _fake_mirrored(mean_2, d2)], sess)
        self._assert_mirrored_equal(
            cross_device_ops.batch_reduce(reduce_util.ReduceOp.SUM,
                                          [(per_replica, d1),
                                           (per_replica_2, d2)]),
            [
                _fake_mirrored(mean * len(devices), d1),
                _fake_mirrored(mean_2 * len(devices), d2)
            ], sess)

      # test broadcast()
      for destinations in all_destinations:
        self._assert_mirrored_equal(
            cross_device_ops.broadcast(constant_op.constant(1.), destinations),
            _fake_mirrored(1., destinations), sess)

  def _testIndexedSlicesAllReduce(self, devices, cross_device_ops_instance,
                                  reduce_op, batch_reduce):
    """Checks reducing two `IndexedSlices` on `devices[0]`/`devices[1]`.

    The result must equal both the concatenated slices (with duplicate
    indices) and the slices with duplicate indices summed, once densified.
    """
    with self.cached_session() as sess:
      dense_shape = [5, 2]
      t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
      t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], dense_shape,
                                devices[1])
      per_replica = value_lib.PerReplica((t0, t1))

      if batch_reduce:
        result = cross_device_ops_instance.batch_reduce(
            reduce_op, [(per_replica, per_replica)])
      else:
        result = cross_device_ops_instance.reduce(reduce_op, per_replica,
                                                  per_replica)

      total_indices_with_dups = [1, 1, 3]
      total_indices_without_dups = [1, 3]

      if reduce_op == reduce_util.ReduceOp.SUM:
        total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
        total_values_without_dups = [[4., 6.], [5., 6.]]
      else:
        assert reduce_op == reduce_util.ReduceOp.MEAN
        total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
        total_values_without_dups = [[2., 3.], [2.5, 3.]]

      total_mirrored_with_dups = _make_mirrored_indexed_slices(
          devices, total_values_with_dups, total_indices_with_dups, dense_shape)
      total_mirrored_without_dups = _make_mirrored_indexed_slices(
          devices, total_values_without_dups, total_indices_without_dups,
          dense_shape)

      # Test that the result is semantically equal to both the concatenated
      # IndexedSlices, as well as when the duplicate indices are summed up.
      if batch_reduce:
        total_mirrored_with_dups = [total_mirrored_with_dups]
        total_mirrored_without_dups = [total_mirrored_without_dups]

      self._assert_mirrored_equal(total_mirrored_with_dups, result, sess)
      self._assert_mirrored_equal(total_mirrored_without_dups, result, sess)
277
278
class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
  """Tests single-worker CrossDeviceOps implementations on local devices."""

  reduction_to_one_combinations = combinations.combine(
      cross_device_ops=[
          combinations.NamedObject("DefaultReductionToOneDevice",
                                   cross_device_ops_lib.ReductionToOneDevice()),
          combinations.NamedObject(
              "ReductionToCPUDeviceCrossDeviceOps",
              cross_device_ops_lib.ReductionToOneDevice(
                  reduce_to_device=_cpu_device)),
          combinations.NamedObject(
              "AccumulateNCrossDeviceOp",
              cross_device_ops_lib.ReductionToOneDevice(
                  accumulation_fn=math_ops.add_n)),
      ],
      devices=[
          ["/cpu:0"],
          ["/cpu:0", "/gpu:0"],
          ["/gpu:0", "/gpu:1"],
      ],
      mode=["graph", "eager"])
  allreduce_combinations = combinations.combine(
      cross_device_ops=[
          combinations.NamedObject(
              "AllReduce",
              cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 1)),
          combinations.NamedObject(
              "AllReduceNoGradientRepacking",
              cross_device_ops_lib.AllReduceCrossDeviceOps("nccl", 0)),
          combinations.NamedObject("NcclAllReduce",
                                   cross_device_ops_lib.NcclAllReduce()),
          combinations.NamedObject(
              "HierarchicalCopy",
              cross_device_ops_lib.HierarchicalCopyAllReduce(8)),
      ],
      devices=[
          ["/gpu:0", "/gpu:1"],
      ],
      mode=["graph", "eager"])

  @combinations.generate(reduction_to_one_combinations + allreduce_combinations)
  def testReductionAndBroadcast(self, cross_device_ops, devices):
    """Runs the shared reduce/broadcast checks for each op and device set."""
    if isinstance(
        cross_device_ops._obj,  # pylint: disable=protected-access
        cross_device_ops_lib.AllReduceCrossDeviceOps
    ) and context.executing_eagerly():
      self.skipTest("b/149881884")
    self._testReductionAndBroadcast(cross_device_ops, devices)

  def testChooseAlgorithm(self):
    """Checks `select_cross_device_ops` picks NCCL only when it is usable."""
    # Not use nccl if there is any cpu device.
    self.assertIsInstance(
        cross_device_ops_lib.select_cross_device_ops(["/cpu:0"]),
        cross_device_ops_lib.ReductionToOneDevice)

    # Not use nccl if requested device is not visible to TensorFlow.
    # TODO(yuefengz): make `select_cross_device_ops` work with device strings
    # self.assertIsInstance(
    #     cross_device_ops_lib.select_cross_device_ops(["/gpu:100"]),
    #     cross_device_ops_lib.ReductionToOneDevice)

    # The remaining checks need a GPU; the CPU check above has already run,
    # so return (rather than skip) without one.
    if context.num_gpus() < 1:
      return

    devices = ["/gpu:0"]

    def mock_get_registered_kernels_for_op(op):
      if op == "NcclAllReduce":
        return [object]
      else:
        return []

    # Use nccl if nccl kernel is found.
    with test.mock.patch.object(kernels, "get_registered_kernels_for_op",
                                mock_get_registered_kernels_for_op):
      self.assertIsInstance(
          cross_device_ops_lib.select_cross_device_ops(devices),
          cross_device_ops_lib.NcclAllReduce)

    # Not use nccl if nccl kernel is not found.
    with test.mock.patch.object(kernels,
                                "get_registered_kernels_for_op", lambda _: []):
      self.assertIsInstance(
          cross_device_ops_lib.select_cross_device_ops(devices),
          cross_device_ops_lib.ReductionToOneDevice)

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      required_gpus=1))
  def testSimpleReduceWithIndexedSlices(self):
    """Checks `_simple_reduce` SUM of IndexedSlices from CPU and GPU."""
    devices = ["/cpu:0", "/gpu:0"]
    t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
    t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
    per_replica = value_lib.PerReplica((t0, t1))
    result = cross_device_ops_lib._simple_reduce(
        per_replica, devices[0], math_ops.add_n, reduce_util.ReduceOp.SUM)

    # Test that the result is semantically equal to both the concatenated
    # IndexedSlices with and without duplicate indices.
    total_with_dups = _make_indexed_slices(
        [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0])
    total_without_dups = _make_indexed_slices(
        [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0])
    self._assert_indexed_slices_equal(total_with_dups, result)
    self._assert_indexed_slices_equal(total_without_dups, result)

  @combinations.generate(
      combinations.combine(
          cross_device_ops_instance=[
              combinations.NamedObject(
                  "ReductionToOneDevice",
                  cross_device_ops_lib.ReductionToOneDevice()),
              combinations.NamedObject(
                  "AllReduceCrossDeviceOps",
                  cross_device_ops_lib.AllReduceCrossDeviceOps())
          ],
          reduce_op=[reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN],
          batch_reduce=[True, False],
          mode=["graph", "eager"],
          required_gpus=1))
  def testIndexedSlicesAllReduce(self, cross_device_ops_instance, reduce_op,
                                 batch_reduce):
    """Runs the shared IndexedSlices reduction checks on CPU + GPU."""
    devices = ["/cpu:0", "/gpu:0"]
    self._testIndexedSlicesAllReduce(devices, cross_device_ops_instance,
                                     reduce_op, batch_reduce)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          cross_device_ops_instance=[
              combinations.NamedObject(
                  "ReductionToOneDevice",
                  cross_device_ops_lib.ReductionToOneDevice()),
              combinations.NamedObject(
                  "AllReduceCrossDeviceOps",
                  cross_device_ops_lib.AllReduceCrossDeviceOps("ring"))
          ],
          batch_reduce=[True, False],
          mode=["graph", "eager"]))
  def testReduceDistributedVariable(self, distribution,
                                    cross_device_ops_instance, batch_reduce):
    """Checks MEAN-reducing a distributed variable yields plain tensors."""
    with distribution.scope():
      v = variables.Variable(1.)
    if batch_reduce:
      result = cross_device_ops_instance.batch_reduce(reduce_util.ReduceOp.MEAN,
                                                      [(v, v)])[0]
    else:
      result = cross_device_ops_instance.reduce(reduce_util.ReduceOp.MEAN, v, v)
    # Use a distinct loop name so the distributed variable `v` is not shadowed.
    for replica_value in result.values:
      self.assertIsInstance(replica_value, ops.Tensor)
    self.evaluate(variables.global_variables_initializer())
    self.assertAllEqual(self.evaluate(result.values), [1.0, 1.0])
431
432
# Number of workers in the in-process cluster created for the collective
# tests; also used to size the cross-worker collective group.
NUM_WORKERS = 3

# Short alias for `collective_util.CollectiveCommunication` used by the tests.
CollectiveCommunication = collective_util.CollectiveCommunication
436
437
438class CollectiveAllReduceTest(multi_worker_test_base.MultiWorkerTestBase,
439                              CrossDeviceOpsTestBase):
440
  # Starting collective key base; `setUp` advances it before every test so
  # that tests do not reuse collective keys.
  collective_key_base = 100000
442
443  @classmethod
444  def setUpClass(cls):
445    """Create a local cluster with 3 workers."""
446    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
447        num_workers=NUM_WORKERS, num_ps=0)
448
449  def setUp(self):
450    super(CollectiveAllReduceTest, self).setUp()
451    # Reusing keys is not supported well. So we have to give a different
452    # collective key base for different tests.
453    CollectiveAllReduceTest.collective_key_base += 100000
454    mwms_lib.CollectiveAllReduceStrategy._collective_key_base = (
455        CollectiveAllReduceTest.collective_key_base)
456
457  def _get_test_objects(self,
458                        task_type,
459                        task_id,
460                        num_gpus=0,
461                        communication=CollectiveCommunication.AUTO,
462                        use_strategy_object=False,
463                        local_mode=False):
464    collective_keys = cross_device_utils.CollectiveKeys(
465        group_key_start=10 + CollectiveAllReduceTest.collective_key_base)
466    if local_mode:
467      if num_gpus:
468        devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
469      else:
470        devices = ["/device:CPU:0"]
471
472      if use_strategy_object:
473        comm_options = collective_util.Options(implementation=communication)
474        strategy = (mwms_lib.CollectiveAllReduceStrategy
475                    ._from_local_devices(devices, comm_options))  # pylint: disable=protected-access
476        return strategy, devices, ""
477      else:
478        collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
479            devices=devices,
480            group_size=len(devices),
481            collective_keys=collective_keys)
482        return collective_all_reduce_ops, devices, ""
483    else:
484      # NCCL requires physical GPUs for every replica, which we can't do with
485      # simulated multi host set up now.
486      assert communication != CollectiveCommunication.NCCL
487      if num_gpus:
488        devices = [
489            "/job:%s/task:%d/replica:0/device:GPU:%d" % (task_type, task_id, i)
490            for i in range(num_gpus)
491        ]
492      else:
493        devices = [
494            "/job:%s/task:%d/replica:0/device:CPU:0" % (task_type, task_id)
495        ]
496
497      if use_strategy_object:
498        resolver = cluster_resolver.SimpleClusterResolver(
499            cluster_spec=multi_worker_util.normalize_cluster_spec(
500                self._cluster_spec),
501            task_type=task_type,
502            task_id=task_id,
503            num_accelerators={"GPU": num_gpus})
504        comm_options = collective_util.Options(implementation=communication)
505        strategy = mwms_lib.CollectiveAllReduceStrategy(
506            communication_options=comm_options, cluster_resolver=resolver)
507        return (strategy, devices,
508                "grpc://" + self._cluster_spec[task_type][task_id])
509      else:
510        collective_all_reduce_ops = cross_device_ops_lib.CollectiveAllReduce(
511            devices=devices,
512            group_size=len(devices) * NUM_WORKERS,
513            collective_keys=collective_keys)
514        return (collective_all_reduce_ops, devices,
515                "grpc://" + self._cluster_spec[task_type][task_id])
516
517  def _assert_mirrored_equal(self, left_list, right_list, sess=None):
518    if context.executing_eagerly():
519      run_options = None
520    else:
521      # TODO(b/151025792): figure out why missing run options would make the
522      # test flaky and whether this is a problem in TF 2.
523      run_options = config_pb2.RunOptions()
524      run_options.experimental.collective_graph_key = 5
525    super(CollectiveAllReduceTest, self)._assert_mirrored_equal(
526        left_list, right_list, sess, run_options=run_options)
527
  def _test_reduction(self,
                      task_type,
                      task_id,
                      num_gpus,
                      communication,
                      use_strategy_object=False,
                      local_mode=False,
                      hints=None):
    """Checks reduce() and batch_reduce() of the collective object under test.

    Builds the object via `_get_test_objects`, reduces two per-replica 1-D
    tensors with MEAN and SUM to three kinds of destinations, and compares
    the results against `_fake_mirrored` expectations.

    Args:
      task_type: job name of this task, or None in local mode.
      task_id: index of this task, or None in local mode.
      num_gpus: GPUs per worker; 0 means a single CPU device per worker.
      communication: `CollectiveCommunication` implementation to request.
      use_strategy_object: if True, reduce through a
        `CollectiveAllReduceStrategy` (`extended.reduce_to` /
        `extended.batch_reduce_to`) instead of the `CollectiveAllReduce` op.
      local_mode: if True, use local devices only and a single worker.
      hints: optional `collective_util.Hints` forwarded to every reduction.
    """
    collective_all_reduce, devices, master_target = self._get_test_objects(
        task_type,
        task_id,
        num_gpus,
        communication=communication,
        use_strategy_object=use_strategy_object,
        local_mode=local_mode)
    # SUM expectations below are scaled by the number of workers taking part.
    if local_mode:
      num_workers = 1
      worker_device = None
    else:
      num_workers = len(self._cluster_spec.get("chief", [])) + len(
          self._cluster_spec.get("worker", []))
      worker_device = "/job:%s/task:%d" % (task_type, task_id)

    # Dispatches a single reduction to the strategy API or the raw op.
    def _reduce(test_object, reduce_op, per_replica, destinations):
      if use_strategy_object:
        with test_object.scope():
          return test_object.extended.reduce_to(reduce_op, per_replica,
                                                destinations, hints)
      else:
        return test_object.reduce(reduce_op, per_replica, destinations, hints)

    # Dispatches a batched reduction to the strategy API or the raw op.
    def _batch_reduce(test_object, reduce_op, value_destination_pairs):
      if use_strategy_object:
        with test_object.scope():
          return test_object.extended.batch_reduce_to(reduce_op,
                                                      value_destination_pairs,
                                                      hints)
      else:
        return test_object.batch_reduce(reduce_op, value_destination_pairs,
                                        hints)

    # NOTE(review): in distributed mode every worker runs this same sequence;
    # presumably the collectives must be created in the same order on all
    # workers, so keep the order of reductions below stable.
    with ops.Graph().as_default(), \
         ops.device(worker_device), \
         self.cached_session(target=master_target) as sess:
      # Collective ops doesn't support scalar tensors, so we have to construct
      # 1-d tensors.
      values = [constant_op.constant([float(d)]) for d in range(len(devices))]
      per_replica = _make_per_replica(values, devices)
      mean = np.array([(len(devices) - 1.) / 2.])

      values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))]
      per_replica_2 = _make_per_replica(values_2, devices)
      mean_2 = np.array([mean[0] + 1.])

      destination_mirrored = _fake_mirrored(1., devices)
      destination_different = _fake_mirrored(1., _cpu_device)
      destination_str = _cpu_device

      all_destinations = [
          destination_different, destination_mirrored, destination_str
      ]

      # test reduce()
      for destinations in all_destinations:
        self._assert_mirrored_equal(
            _reduce(
                collective_all_reduce,
                reduce_util.ReduceOp.MEAN,
                per_replica,
                destinations=destinations), _fake_mirrored(mean, destinations),
            sess)
        self._assert_mirrored_equal(
            _reduce(
                collective_all_reduce,
                reduce_util.ReduceOp.MEAN,
                per_replica_2,
                destinations=destinations),
            _fake_mirrored(mean_2, destinations), sess)
        self._assert_mirrored_equal(
            _reduce(
                collective_all_reduce,
                reduce_util.ReduceOp.SUM,
                per_replica,
                destinations=destinations),
            _fake_mirrored(mean * len(devices) * num_workers, destinations),
            sess)
        self._assert_mirrored_equal(
            _reduce(
                collective_all_reduce,
                reduce_util.ReduceOp.SUM,
                per_replica_2,
                destinations=destinations),
            _fake_mirrored(mean_2 * len(devices) * num_workers, destinations),
            sess)

      # test batch_reduce()
      for d1, d2 in itertools.product(all_destinations, all_destinations):
        self._assert_mirrored_equal(
            _batch_reduce(collective_all_reduce, reduce_util.ReduceOp.MEAN,
                          [(per_replica, d1), (per_replica_2, d2)]),
            [_fake_mirrored(mean, d1),
             _fake_mirrored(mean_2, d2)], sess)
        self._assert_mirrored_equal(
            _batch_reduce(collective_all_reduce, reduce_util.ReduceOp.SUM,
                          [(per_replica, d1), (per_replica_2, d2)]),
            [
                _fake_mirrored(mean * len(devices) * num_workers, d1),
                _fake_mirrored(mean_2 * len(devices) * num_workers, d2)
            ], sess)
637
638  def _get_indexed_slices(self,
639                          devices,
640                          start_i,
641                          variable_length,
642                          as_per_replica=True):
643    dense_shape = [10, 2]
644    values = ([[1., 2.]], [[3., 4.]], [[2., 1.]], [[0., 0.]], [[3., 1.]],
645              [[2., 1.]])
646    indices = ([1], [2], [3], [4], [5], [6])
647
648    # values and indices that have variable lengths.
649    vl_values = ([[1., 2.], [3., 4.]], [[3., 4.]], [[2., 1.]], [[0., 0.]],
650                 [[3., 1.], [2., 1.]], [[2., 1.]])
651    vl_indices = ([1, 2], [2], [3], [4], [5, 6], [6])
652
653    indexed_slices = []
654    for i, d in enumerate(devices):
655      idx = i + start_i
656      indexed_slices.append(
657          _make_indexed_slices(
658              vl_values[idx] if variable_length else values[idx],
659              vl_indices[idx] if variable_length else indices[idx], dense_shape,
660              d))
661    if as_per_replica:
662      per_replica = value_lib.PerReplica(indexed_slices)
663      return per_replica
664    else:
665      return indexed_slices
666
667  def _test_reduce_indexed_slices(self,
668                                  task_type,
669                                  task_id,
670                                  num_gpus,
671                                  communication,
672                                  batch_reduce,
673                                  variable_length,
674                                  local_mode=False):
675    collective_all_reduce, devices, master_target = self._get_test_objects(
676        task_type,
677        task_id,
678        num_gpus,
679        communication=communication,
680        local_mode=local_mode)
681    if local_mode:
682      num_workers = 1
683      worker_device = None
684    else:
685      num_workers = len(self._cluster_spec.get("chief", [])) + len(
686          self._cluster_spec.get("worker", []))
687      worker_device = "/job:%s/task:%d" % (task_type, task_id)
688    with ops.Graph().as_default(), \
689         ops.device(worker_device), \
690         self.cached_session(target=master_target) as sess:
691      per_replica = self._get_indexed_slices(devices,
692                                             (task_id or 0) * max(num_gpus, 1),
693                                             variable_length)
694
695      if batch_reduce:
696        result = collective_all_reduce.batch_reduce(
697            reduce_util.ReduceOp.SUM, [(per_replica, per_replica)])[0]
698      else:
699        result = collective_all_reduce.reduce(reduce_util.ReduceOp.SUM,
700                                              per_replica, per_replica)
701      if num_gpus > 1:
702        self.assertIsInstance(result, value_lib.Mirrored)
703
704      run_options = config_pb2.RunOptions()
705      run_options.experimental.collective_graph_key = 7
706      if num_gpus > 1:
707        result = sess.run([ops.convert_to_tensor(v) for v in result.values],
708                          options=run_options)[0]
709      else:
710        result = sess.run(ops.convert_to_tensor(result), options=run_options)
711
712      # Reduce the same indexed slices on CPU locally as our expected results.
713      devices_cpu = [(worker_device or "") + "/device:CPU:0"] * (
714          max(num_gpus, 1) * num_workers)
715      per_replica_on_cpu = self._get_indexed_slices(
716          devices_cpu, 0, variable_length, as_per_replica=False)
717      expected_result = cross_device_utils.aggregate_tensors_or_indexed_slices(
718          per_replica_on_cpu)
719      expected_result = sess.run(ops.convert_to_tensor(expected_result))
720
721      self.assertAllEqual(expected_result, result)
722
723  @combinations.generate(
724      combinations.combine(
725          mode=["graph"],
726          required_gpus=[0, 1, 2],
727          use_strategy_object=[True, False],
728          bytes_per_pack=[0, 1, 4]))
729  def testReductionDistributed(self, required_gpus, use_strategy_object,
730                               bytes_per_pack):
731    hints = collective_util.Hints(bytes_per_pack=bytes_per_pack)
732    self._run_between_graph_clients(
733        self._test_reduction,
734        self._cluster_spec,
735        required_gpus,
736        communication=CollectiveCommunication.RING,
737        use_strategy_object=use_strategy_object,
738        hints=hints)
739
740  @combinations.generate(
741      combinations.combine(
742          mode=["graph"],
743          required_gpus=[0, 1, 2],
744          variable_length=[True, False]))
745  def testReduceIndexedSlicesDistributed(self, required_gpus, variable_length):
746    self._run_between_graph_clients(
747        self._test_reduce_indexed_slices,
748        self._cluster_spec,
749        required_gpus,
750        communication=CollectiveCommunication.RING,
751        batch_reduce=True,
752        variable_length=variable_length)
753
  # Collective ops don't support a strategy with only one device.
755  @combinations.generate(
756      combinations.combine(
757          mode=["graph"],
758          required_gpus=2,
759          communication=[
760              CollectiveCommunication.NCCL, CollectiveCommunication.RING
761          ],
762          use_strategy_object=[True, False]))
763  def testReductionLocal(self, required_gpus, communication,
764                         use_strategy_object):
765    self._test_reduction(
766        None,
767        None,
768        required_gpus,
769        communication=communication,
770        use_strategy_object=use_strategy_object,
771        local_mode=True)
772
773  @combinations.generate(
774      combinations.combine(
775          mode=["graph"],
776          required_gpus=2,
777          batch_reduce=[True, False],
778          variable_length=[True, False],
779          communication=[
780              CollectiveCommunication.NCCL, CollectiveCommunication.RING
781          ]))
782  def testReduceIndexedSlicesLocal(self, required_gpus, batch_reduce,
783                                   variable_length, communication):
784    self._test_reduce_indexed_slices(
785        None,
786        None,
787        required_gpus,
788        communication=communication,
789        batch_reduce=batch_reduce,
790        variable_length=variable_length,
791        local_mode=True)
792
793  @combinations.generate(
794      combinations.combine(
795          required_gpus=2,
796          mode="eager",
797          communication=[
798              CollectiveCommunication.NCCL, CollectiveCommunication.RING
799          ]))
800  def testEagerMultiThread(self, communication):
801    collective, devices, _ = self._get_test_objects(
802        None,
803        None,
804        num_gpus=2,
805        communication=communication,
806        use_strategy_object=False,
807        local_mode=True)
808
809    # We would like to simulate the following sequence:
810    #   thread-0  device0                 device1
811    #   thread-1          device0 device1
812    # If the kernel launch sequence is as-is the program will deadlock since
813    # NCCL requires the launch order to be same on each device.
814    v0 = _make_per_replica([1.0 for _ in devices], devices)
815    v1 = _make_per_replica([2.0 for _ in devices], devices)
816
817    # Add a delay to collective_ops.all_reduce according to the input tensors
818    # index in `sequence.`
819    sequence = [v0.values[0], v1.values[0], v1.values[1], v0.values[1]]
820    all_reduce = collective_ops.all_reduce
821
822    def delayed_all_reduce(input_tensor, *args, **kwargs):
823      for idx, v in enumerate(sequence):
824        if input_tensor is v:
825          time.sleep(idx)
826          break
827      return all_reduce(input_tensor, *args, **kwargs)
828
829    with test.mock.patch.object(collective_ops, "all_reduce",
830                                delayed_all_reduce):
831      # We only use NCCL for batch reduce with two or more values, so we use two
832      # values here.
833
834      def thread_fn():
835        reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v0, v0),
836                                                                     (v0, v0)])
837        self.assertAllEqual(reduced[0].values, [2.0, 2.0])
838        self.assertAllEqual(reduced[1].values, [2.0, 2.0])
839
840      t = threading.Thread(target=thread_fn)
841      t.start()
842      reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v1, v1),
843                                                                   (v1, v1)])
844      self.assertAllEqual(reduced[0].values, [4.0, 4.0])
845      self.assertAllEqual(reduced[1].values, [4.0, 4.0])
846      t.join()
847
if __name__ == "__main__":
  # Limit the default inter-op thread pool to a single thread. This keeps the
  # additional executors used to run collectives in eager mode from
  # exhausting the pool.
  os.environ.update(TF_NUM_INTEROP_THREADS="1")
  test.main()
853