# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CrossDeviceOps."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import threading
import time

from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest

CollectiveReplicaLauncher = cross_device_utils.CollectiveReplicaLauncher
CommunicationImplementation = collective_util.CommunicationImplementation
ReduceOp = reduce_util.ReduceOp
IndexedSlicesValue = indexed_slices.IndexedSlicesValue
IndexedSlices = indexed_slices.IndexedSlices


def make_per_replica_value(value, devices):
  """Creates a `PerReplica` object whose values reside in `devices`.

  Args:
    value: a tensor-convertible value or an `IndexedSlicesValue`, or a callable
      that takes one argument (`device_idx`) and should return the value that
      is going to be created on devices[device_idx].
    devices: a list of device strings to create `PerReplica` values on.

  Returns:
    A `PerReplica` object.
  """
  values = []
  for device_idx, device in enumerate(devices):
    if callable(value):
      v = value(device_idx)
    elif isinstance(value, list):
      v = value[device_idx]
    else:
      v = value
    if isinstance(v, IndexedSlicesValue):
      with ops.device(device):
        values.append(
            IndexedSlices(
                values=array_ops.identity(v.values),
                indices=array_ops.identity(v.indices),
                dense_shape=array_ops.identity(v.dense_shape)))
    else:
      with ops.device(device):
        values.append(array_ops.identity(v))
  return value_lib.PerReplica(values)


def enable_collective_ops():
  """Enable collectives in the current process."""
  cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
  context.context().configure_collective_ops(
      collective_leader="'/job:worker/replica:0/task:0'")
  config_proto = config_pb2.ConfigProto()
  config_proto.experimental.collective_group_leader = (
      "/job:worker/replica:0/task:0")
  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_resolver.cluster_spec().as_cluster_def(),
      default_session_config=config_proto,
      job_name=cluster_resolver.task_type,
      task_index=cluster_resolver.task_id,
      protocol=cluster_resolver.rpc_layer)
  context.context().enable_collective_ops(server_def)
  # Recover default flag values.
  CollectiveReplicaLauncher._prefer_unique_instance_key = True
  CollectiveReplicaLauncher._prefer_ordering_token = False


class MultiProcessPoolRunner():

  def __init__(self, num_processes):
    cluster_spec_dict = multi_worker_test_base.create_cluster_spec(
        num_workers=num_processes)
    self.runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec_dict)


# Global MultiProcessPoolRunners that can be shared by test cases to avoid the
# expensive initialization cost of TensorFlow in new processes.
#
# Note that they have to be globals and can't be owned by test classes because
# fn usually captures the test class instance, and the test class instance
# can't be pickled if it has mpr as a member (it is not allowed to pickle
# Process objects).
# TODO(crccw): Use `num_workers` combination once it is ready.
global_mpr_2p = MultiProcessPoolRunner(num_processes=2)
global_mpr_1p = MultiProcessPoolRunner(num_processes=1)


def get_global_mpr(num_processes):
  if num_processes == 1:
    return global_mpr_1p.runner
  elif num_processes == 2:
    return global_mpr_2p.runner
  else:
    raise ValueError("get_global_mpr: num_processes must be 1 or 2, got %d" %
                     num_processes)


class CollectiveOpsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    # Enabling collectives can be done in "setUpClass", but requires using
    # different collective_keys in different tests as collectives are reused
    # across tests. Always resetting collective ops before each test offers
    # better test isolation.
    global_mpr_1p.runner.run(enable_collective_ops)
    global_mpr_2p.runner.run(enable_collective_ops)

  def make_collective(self, num_processes, gpu_per_process):
    """Returns collectives and other info to be used in tests.

    Args:
      num_processes: an integer indicating the number of processes that
        participate in the collective.
      gpu_per_process: number of GPUs (0 if no GPUs) used by each process.

    Returns:
      A tuple of (collective, devices, task_id) where collective is an instance
      of `CollectiveAllReduce`, devices is a list of local devices (str)
      attached to the current process, and task_id is the id of the current
      task within the cluster.
    """

    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
    devices = [
        "/job:worker/replica:0/task:%d/device:CPU:0" % cluster_resolver.task_id
    ]
    if gpu_per_process > 0:
      devices = [
          "/job:worker/replica:0/task:%d/device:GPU:%d" %
          (cluster_resolver.task_id, i) for i in range(gpu_per_process)
      ]
    group_size = num_processes * len(devices)
    collective = cross_device_ops_lib.CollectiveAllReduce(
        devices=devices, group_size=group_size)
    return collective, devices, cluster_resolver.task_id

  def as_list(self, value):
    """A utility to convert a `Mirrored`, `Tensor` or `IndexedSlices` to a list.

    The reason it exists is to provide a uniform view of the returned value of
    "reduce" calls, especially across tf.function boundaries. Returning
    `Mirrored` from a tf.function will only evaluate the primary value, which
    causes collective ops on the non-primary devices to be pruned and
    eventually makes the program hang.

    Args:
      value: the value to convert, can be one of `Mirrored`, `Tensor` and
        `IndexedSlices`.

    Returns:
      A list of `Tensor` or `IndexedSlices`.
    """
    if isinstance(value, ops.Tensor):
      return [value]
    elif isinstance(value, IndexedSlices):
      return [value]
    elif isinstance(value, value_lib.Mirrored):
      return value.values
    else:
      raise ValueError("unwrap: unsupported input type: %s" % type(value))

  RunOptions = collections.namedtuple(  # pylint: disable=invalid-name
      "RunOptions",
      [
          "mode",  # A list of str from ["eager", "func_graph"]
          "num_processes",
          "gpus_per_process",
          "reduce_op",
          "communication_options",
          "prefer_unique_instance_key",
      ])
  RunOptions.__new__.__defaults__ = (["eager", "func_graph"], 2, 0,
                                     ReduceOp.SUM, collective_util.Options(),
                                     True)
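
  # For reference, the tests below typically construct `RunOptions` by
  # overriding only the fields under test and relying on the defaults above,
  # e.g. (illustrative only):
  #   options = self.RunOptions(num_processes=2, gpus_per_process=1,
  #                             reduce_op=ReduceOp.MEAN)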

  def reduce_and_verify(self, inputs, expect, options):
    """Reduce the given `inputs` and verify the output matches `expect`.

    Args:
      inputs: a list of `Tensor` or `IndexedSlices`, where the i-th value will
        be fed to the i-th replica.
      expect: a `Tensor` or `IndexedSlices`. This should be the expected value
        for one replica.
      options: a `RunOptions` instance.
    """

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          options.prefer_unique_instance_key)
      collective, devices, pid = self.make_collective(options.num_processes,
                                                      options.gpus_per_process)

      def reduce_fn():
        value_fn = lambda device_idx: inputs[pid * len(devices) + device_idx]
        per_replica_value = make_per_replica_value(value_fn, devices)
        reduced_values = collective.reduce(options.reduce_op, per_replica_value,
                                           per_replica_value,
                                           options.communication_options)
        reduced_values = self.as_list(reduced_values)
        self.assertAllEqual(devices, [v.device for v in reduced_values])
        return [ops.convert_to_tensor(v) for v in reduced_values]

      per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

      if "eager" in options.mode:
        got = reduce_fn()
        self.assertAllClose(got, per_replica_expect)

      if "func_graph" in options.mode:
        got = def_function.function(reduce_fn)()
        self.assertAllClose(got, per_replica_expect)

    get_global_mpr(options.num_processes).run(replica_fn)

  def batch_reduce_and_verify(self, inputs, expect, options):
    """Batch reduce the given `inputs` and verify the output matches `expect`.

    Args:
      inputs: a 2-level nested list of `Tensor` or `IndexedSlices`, where the
        i-th value will be fed to the i-th replica.
      expect: a list of `Tensor` or `IndexedSlices`. This should be the
        expected value for one replica.
      options: a `RunOptions` instance.
    """

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          options.prefer_unique_instance_key)
      collective, devices, pid = self.make_collective(options.num_processes,
                                                      options.gpus_per_process)

      def batch_reduce_fn():
        batch_size = len(inputs[0])
        value_dst_pairs = []
        for i in range(batch_size):

          def value_fn(device_idx, idx=i):
            return inputs[pid * len(devices) + device_idx][idx]

          per_replica_value = make_per_replica_value(value_fn, devices)
          value_dst_pairs.append((per_replica_value, per_replica_value))
        reduced_values = collective.batch_reduce(options.reduce_op,
                                                 value_dst_pairs,
                                                 options.communication_options)
        reduced_values = [self.as_list(v) for v in reduced_values]
        for v in reduced_values:
          self.assertAllEqual(devices, [t.device for t in v])
        return nest.map_structure(ops.convert_to_tensor, reduced_values)

      per_replica_expect = nest.map_structure(
          lambda x: [ops.convert_to_tensor(x)] * len(devices), expect)

      if "eager" in options.mode:
        got = batch_reduce_fn()
        self.assertAllClose(got, per_replica_expect)

      if "func_graph" in options.mode:
        got = def_function.function(batch_reduce_fn)()
        self.assertAllClose(got, per_replica_expect)

    get_global_mpr(options.num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
          prefer_unique_instance_key=[True, False]))
  def testAllReduceDense(self, num_processes, required_gpus, implementation,
                         reduce_op, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes == 2 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                    "physical GPUs for every process.")
    options = self.RunOptions(
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [1.0, 2.0, 3.0, 4.0]
    inputs = inputs_data[0:group_size]

    # Each replica contributes inputs_data[i], so the expected result is the
    # sum (or mean) of the first `group_size` values.
    if group_size == 1:
      expect = 1.0
    elif group_size == 2:
      expect = 3.0 if reduce_op == ReduceOp.SUM else 1.5
    elif group_size == 4:
      expect = 10.0 if reduce_op == ReduceOp.SUM else 2.5

    self.reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          # TODO(b/166682130): add MEAN reduce once the bug is fixed.
          reduce_op=ReduceOp.SUM,
          prefer_unique_instance_key=[True, False]))
  def testAllReduceSparse(self, num_processes, required_gpus, implementation,
                          reduce_op, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes == 2 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                    "physical GPUs for every process.")
    options = self.RunOptions(
        mode=["func_graph"],  # Sparse reduce is not supported in eager.
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [
        IndexedSlicesValue(
            values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[3.], [4.]], indices=[1, 2], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[5.], [6.]], indices=[7, 8], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[7.], [8.]], indices=[3, 2], dense_shape=[10, 1]),
    ]
    inputs = inputs_data[0:group_size]

    # The reduced `IndexedSlices` concatenates the values and indices from all
    # replicas; duplicate indices are not merged.
    if group_size == 1:
      expect = IndexedSlices(
          values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1])
    elif group_size == 2:
      expect = IndexedSlices(
          values=[[1.], [2.], [3.], [4.]],
          indices=[0, 1, 1, 2],
          dense_shape=[10, 1])
    elif group_size == 4:
      expect = IndexedSlices(
          values=[[1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.]],
          indices=[0, 1, 1, 2, 7, 8, 3, 2],
          dense_shape=[10, 1])

    self.reduce_and_verify(inputs, expect, options)
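
  # In the following test the two replicas contribute `IndexedSlices` with a
  # different number of rows (1 vs. 3), so the expected reduced value is simply
  # the concatenation of both contributions.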

  @combinations.generate(
      combinations.combine(prefer_unique_instance_key=[True, False]))
  def testAllReduceSparseVariableLength(self, prefer_unique_instance_key):
    # One device per process, 2 processes, 2 replicas in total.
    inputs = [
        IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[2.], [3.], [4.]], indices=[0, 1, 2], dense_shape=[10, 1]),
    ]
    expect = IndexedSlices(
        values=[[1.], [2.], [3.], [4.]],
        indices=[0, 0, 1, 2],
        dense_shape=[10, 1])
    self.reduce_and_verify(
        inputs,
        expect,
        self.RunOptions(
            mode=["func_graph"],  # Sparse reduce is not supported in eager.
            num_processes=2,
            reduce_op=ReduceOp.SUM,
            prefer_unique_instance_key=prefer_unique_instance_key))

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
          prefer_unique_instance_key=[True, False]))
  def testBatchAllReduceDense(self, num_processes, required_gpus,
                              implementation, reduce_op,
                              prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes == 2 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                    "physical GPUs for every process.")

    options = self.RunOptions(
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = [1.0, 2.0]
    elif group_size == 2:
      expect = [4.0, 6.0] if reduce_op == ReduceOp.SUM else [2.0, 3.0]
    elif group_size == 4:
      expect = [16.0, 20.0] if reduce_op == ReduceOp.SUM else [4.0, 5.0]

    self.batch_reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          # TODO(b/166682130): add MEAN reduce once the bug is fixed.
          reduce_op=ReduceOp.SUM,
          prefer_unique_instance_key=[True, False]))
  def testBatchAllReduceSparse(self, num_processes, required_gpus,
                               implementation, reduce_op,
                               prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes == 2 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                    "physical GPUs for every process.")

    options = self.RunOptions(
        mode=["func_graph"],  # Sparse reduce is not supported in eager.
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = ([
        IndexedSlicesValue(
            values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[3.], [4.]], indices=[1, 2], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[5.], [6.]], indices=[1, 2], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[7.], [8.]], indices=[0, 1], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[9.], [10.]], indices=[3, 4], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[11.], [12.]], indices=[3, 4], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[13.], [14.]], indices=[8, 9], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[15.], [16.]], indices=[3, 4], dense_shape=[5, 1])
    ])
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = [
          IndexedSlices(
              values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
          IndexedSlices(
              values=[[3.], [4.]], indices=[1, 2], dense_shape=[5, 1])
      ]
    elif group_size == 2:
      expect = [
          IndexedSlices(
              values=[[1.], [2.], [5.], [6.]],
              indices=[0, 1, 1, 2],
              dense_shape=[10, 1]),
          IndexedSlices(
              values=[[3.], [4.], [7.], [8.]],
              indices=[1, 2, 0, 1],
              dense_shape=[5, 1])
      ]
    elif group_size == 4:
      expect = [
          IndexedSlices(
              values=[[1.], [2.], [5.], [6.], [9.], [10.], [13.], [14.]],
              indices=[0, 1, 1, 2, 3, 4, 8, 9],
              dense_shape=[10, 1]),
          IndexedSlices(
              values=[[3.], [4.], [7.], [8.], [11.], [12.], [15.], [16.]],
              indices=[1, 2, 0, 1, 3, 4, 3, 4],
              dense_shape=[5, 1])
      ]
    self.batch_reduce_and_verify(inputs, expect, options)
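
  # The following test calls the lower-level `_all_reduce` helper of
  # `CollectiveAllReduce` directly, once per replica, both with a single tensor
  # and with a tuple of tensors per replica.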

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
      ))
  def testCollectiveAllReduce(self, num_processes, required_gpus,
                              implementation, reduce_op):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes == 2 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
                    "physical GPUs for every process.")

    def replica_fn():
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      group_size = num_processes * (required_gpus or 1)

      @def_function.function
      def collective_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = constant_op.constant(1.0)
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_all_reduce()
      if reduce_op == ReduceOp.SUM:
        expect = [1.0 * group_size] * len(devices)
      elif reduce_op == ReduceOp.MEAN:
        expect = [1.0] * len(devices)
      self.assertAllClose(got, expect)

      @def_function.function
      def collective_batch_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = (constant_op.constant(1.0), constant_op.constant(2.0))
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_batch_all_reduce()
      if reduce_op == ReduceOp.SUM:
        expect = [(1.0 * group_size, 2.0 * group_size)] * len(devices)
      elif reduce_op == ReduceOp.MEAN:
        expect = [(1.0, 2.0)] * len(devices)
      self.assertAllClose(got, expect)

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          axis=[0, 1, 2],
          func_mode=["eager", "func_graph"],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testAllGatherSameShape(self, num_processes, required_gpus, implementation,
                             func_mode, axis, prefer_unique_instance_key):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      value = constant_op.constant([[[1, 2], [1, 2]]], dtype=dtypes.float32)

      def gather_fn():
        per_replica_value = make_per_replica_value(value, devices)
        gathered_values = collective._gather(
            per_replica_value, per_replica_value, axis=axis, options=options)
        gathered_values = self.as_list(gathered_values)
        # Skip checking devices in eager. In eager the device attribute doesn't
        # reflect the actual device of the tensor.
        if not context.executing_eagerly():
          self.assertAllEqual(devices, [v.device for v in gathered_values])
        return [ops.convert_to_tensor(v) for v in gathered_values]

      group_size = num_processes * (required_gpus or 1)
      expect = array_ops.concat([value] * group_size, axis=axis)
      per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

      if func_mode == "eager":
        result = gather_fn()
        self.assertAllClose(result, per_replica_expect)

      if func_mode == "func_graph":
        result = def_function.function(gather_fn)()
        self.assertAllClose(result, per_replica_expect)

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[CommunicationImplementation.RING]))
  def testCollectiveV2ControlFlow(self, num_processes, required_gpus,
                                  implementation):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = True
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      value = make_per_replica_value(constant_op.constant([1.]), devices)

      @def_function.function
      def reduce_fn():

        def cond_body():
          reduced = collective.reduce(reduce_util.ReduceOp.SUM, value, value,
                                      options)
          return math_ops.add_n(self.as_list(reduced)) / len(devices)

        return control_flow_ops.cond(
            array_ops.identity(False), cond_body, cond_body)

      num_replicas = num_processes * len(devices)
      self.assertAllEqual(reduce_fn(), [1. * num_replicas])

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=1,
          required_gpus=2,
          implementation=[
              CommunicationImplementation.NCCL, CommunicationImplementation.RING
          ],
          prefer_unique_instance_key=[True, False]))
  def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes,
                                                    required_gpus,
                                                    implementation,
                                                    prefer_unique_instance_key):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)

      # We would like to simulate the following sequence:
      #   thread-0  device0  device1
      #   thread-1  device0  device1
      # If the kernel launch sequence is as-is the program will deadlock since
      # NCCL requires the launch order to be the same on each device.
      v0 = make_per_replica_value(1.0, devices)
      v1 = make_per_replica_value(2.0, devices)

      # Add a delay to collective_ops.all_reduce according to the input
      # tensor's index in `sequence`.
      sequence = [v0.values[0], v1.values[0], v1.values[1], v0.values[1]]
      all_reduce = collective_ops.all_reduce

      def delayed_all_reduce(input_tensor, *args, **kwargs):
        for idx, v in enumerate(sequence):
          if input_tensor is v:
            time.sleep(idx)
            break
        return all_reduce(input_tensor, *args, **kwargs)

      with test.mock.patch.object(collective_ops, "all_reduce",
                                  delayed_all_reduce):
        # We only use NCCL for batch reduce with two or more values, so we use
        # two values here.
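        # The main thread reduces (v1, v1) while `thread_fn` below reduces
        # (v0, v0); given the delays injected above, the two batches would be
        # launched in opposite orders on the two devices unless collective
        # launches are serialized.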

        def thread_fn():
          reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                            [(v0, v0), (v0, v0)], options)
          self.assertAllEqual(reduced[0].values, [2.0, 2.0])
          self.assertAllEqual(reduced[1].values, [2.0, 2.0])

        t = threading.Thread(target=thread_fn)
        t.start()
        reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v1, v1),
                                                                     (v1, v1)],
                                          options)
        self.assertAllEqual(reduced[0].values, [4.0, 4.0])
        self.assertAllEqual(reduced[1].values, [4.0, 4.0])
        t.join()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=1,
          required_gpus=2,
          implementation=[
              CommunicationImplementation.NCCL, CommunicationImplementation.RING
          ],
          prefer_unique_instance_key=[True, False]))
  def testInputsAreFunctionArgs(self, num_processes, required_gpus,
                                implementation, prefer_unique_instance_key):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)

      @def_function.function
      def reduce_fn(v):
        # Function inputs don't have device placement.
        self.assertEqual(v.values[0].device, "")
        self.assertEqual(v.values[1].device, "")
        # We only use NCCL for batch reduce with two or more values, so we use
        # two values here.
        reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v),
                                                                     (v, v)],
                                          options)
        self.assertEqual(reduced[0].values[0].device, devices[0])
        self.assertEqual(reduced[0].values[1].device, devices[1])
        self.assertEqual(reduced[1].values[0].device, devices[0])
        self.assertEqual(reduced[1].values[1].device, devices[1])
        # Returning Mirrored only evaluates the primary value, which causes
        # hanging.
        return [reduced[0].values, reduced[1].values]

      v = make_per_replica_value(1.0, devices)
      reduced = reduce_fn(v)
      self.assertAllClose(reduced, [[2.0, 2.0], [2.0, 2.0]])

    get_global_mpr(num_processes).run(replica_fn)
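
  # The timeout tests below launch the collective only on worker-0 (the other
  # task returns early), so with `timeout_seconds=1` each collective is
  # expected to raise `errors.DeadlineExceededError` instead of hanging.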

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutReduceDense(self, num_processes, implementation, required_gpus,
                             prefer_unique_instance_key):

    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(1.0, devices)
      options = collective_util.Options(
          timeout_seconds=1, implementation=implementation)

      @def_function.function
      def reduce_dense():
        return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)

      # The collective should time out because we only launch it on worker-0,
      # while there are two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        reduce_dense()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutBatchReduceDense(self, num_processes, implementation,
                                  required_gpus, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(1.0, devices)
      options = collective_util.Options(
          timeout_seconds=1, implementation=implementation)

      @def_function.function
      def batch_reduce_dense():
        return collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                       [(v, v), (v, v)], options)

      # The collective should time out because we only launch it on worker-0,
      # while there are two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        batch_reduce_dense()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutReduceSparse(self, num_processes, implementation,
                              required_gpus, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(
          IndexedSlicesValue(
              values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
      options = collective_util.Options(
          timeout_seconds=1, implementation=implementation)

      @def_function.function
      def reduce_sparse():
        return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)

      # The collective should time out because we only launch it on worker-0,
      # while there are two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        reduce_sparse()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
                                   implementation, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(
          IndexedSlicesValue(
              values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
      options = collective_util.Options(
          timeout_seconds=1, implementation=implementation)

      @def_function.function
      def batch_reduce_sparse():
        return collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                       [(v, v), (v, v)], options)

      # The collective should time out because we only launch it on worker-0,
      # while there are two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        batch_reduce_sparse()

    get_global_mpr(num_processes).run(replica_fn)
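
  # With ordering tokens enabled, all collective launches (including those
  # inside tf.function calls, tf.cond and tf.while_loop) have to be chained by
  # control dependencies so that the NCCL launch order is deterministic and
  # identical on every device. The test below builds a graph mixing all of
  # these cases and checks the ordering via a topological sort of the ops.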

  @combinations.generate(combinations.combine(num_processes=1, required_gpus=2))
  def testNcclOrdering(self, num_processes, required_gpus):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = True
      CollectiveReplicaLauncher._prefer_ordering_token = True
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(
          implementation=CommunicationImplementation.NCCL)

      v_dense = make_per_replica_value([1.0, 1.0], devices)
      v_sparse = make_per_replica_value([
          IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
          IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
      ], devices)

      @def_function.function
      def nested_dense():
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)

      @def_function.function
      def nested_sparse():
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)

      # All collectives, function calls, if clauses and while loops should be
      # chained by control dependencies, so that the execution order is
      # deterministic.
      @def_function.function
      def f():
        # pylint: disable=pointless-statement
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # reduce dense value.
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        # reduce sparse value.
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # reduce dense value in nested tf.function.
        nested_dense()
        # reduce sparse value in nested tf.function.
        nested_sparse()
        # reduce dense value in tf.cond.
        if array_ops.identity(1.0) > array_ops.identity(2.0):
          collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        else:
          v_dense
        # reduce sparse value in tf.cond.
        if array_ops.identity(1.0) > array_ops.identity(2.0):
          v_sparse
        else:
          collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                            options)
        # reduce dense value in tf.while_loop.
        i = array_ops.identity(1)
        while i < 3:
          collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
          i += 1
        # reduce sparse value in tf.while_loop.
        i = array_ops.identity(1)
        while i < 3:
          collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                            options)
          i += 1
        # reduce dense and sparse values again.
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # pylint: enable=pointless-statement

      graph = f.get_concrete_function().graph
      should_be_ordered = set([
          "CollectiveReduceV2", "CollectiveGatherV2", "If", "While",
          "StatefulPartitionedCall"
      ])
      nodes_by_device = {}
      for op in graph.get_operations():
        if op.type in should_be_ordered:
          if op.device not in nodes_by_device:
            nodes_by_device[op.device] = []
          nodes_by_device[op.device].append(op)
      order = test_util.topological_sort_operations(graph.get_operations())
      for device in devices:
        device = device_util.canonicalize(device)
        # These function ops don't have device annotations, but they contain
        # collectives for both devices so we always include them.
        operations = nodes_by_device[device] + nodes_by_device[""]
        # Verify that we get all types of nodes we want.
        self.assertEqual(set(op.type for op in operations), should_be_ordered)
        test_util.assert_sequential_execution(order, operations)

    get_global_mpr(num_processes).run(replica_fn)


if __name__ == "__main__":
  # Set the default inter-op thread pool size to one to ensure we don't exhaust
  # the thread pool with the additional executors that run collectives in
  # eager.
  os.environ["TF_NUM_INTEROP_THREADS"] = "1"
  # TODO(b/172304955): figure out why logical devices don't work.
  test_util.main(config_logical_devices=False)