# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ParallelDevice: eager op replication across component devices."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import threading

from absl.testing import parameterized

from tensorflow.python.distribute.parallel_device import parallel_device
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as tracking
from tensorflow.python.util import nest

# When running collectives asynchronously, we need to give each parallel device
# execution a unique ID so the collectives don't interfere. Since the op is
# replicated with group/instance key intact, the replicated nodes will
# communicate.
# TODO(allenl): Switch to using a collective manager.
_COUNTER_LOCK = threading.Lock()
_COUNTER = 100


def _collective_reduce(inputs, operation, num_replicas):
  """All-reduces each tensor in `inputs` across `num_replicas` replicas."""

  def _reduce_tensor(tensor):
    with _COUNTER_LOCK:
      global _COUNTER
      keys = _COUNTER
      _COUNTER += 1
    return collective_ops.all_reduce(
        t=tensor,
        group_size=num_replicas,
        merge_op=operation,
        group_key=keys,
        instance_key=keys,
        final_op="Id")

  return nest.map_structure(_reduce_tensor, inputs)


def _collective_sum(inputs, num_replicas):
  return _collective_reduce(
      inputs=inputs, operation="Add", num_replicas=num_replicas)


class _Dense(module.Module):
  """A minimal dense layer which creates its variables on first call."""

  def __init__(self, output_size):
    self.output_size = output_size
    self.kernel = None
    self.bias = None

  def __call__(self, x):
    if self.kernel is None:
      self.kernel = variables.Variable(
          array_ops.ones(
              array_ops.stack([self.output_size,
                               array_ops.shape(x)[-1]])))
      self.bias = variables.Variable(array_ops.ones([self.output_size]))
    return math_ops.matmul(x, self.kernel, transpose_b=True) + self.bias


class _VirtualDeviceTestCase(test.TestCase):
  """Configures two component devices and a ParallelDevice spanning them."""

  def setUp(self):
    super(_VirtualDeviceTestCase, self).setUp()
    ctx = context.context()
    if ctx.list_physical_devices("TPU"):
      self.device_type = "TPU"
    elif ctx.list_physical_devices("GPU"):
      self.device_type = "GPU"
      gpus = ctx.list_physical_devices(self.device_type)
      ctx.set_logical_device_configuration(gpus[0], [
          context.LogicalDeviceConfiguration(memory_limit=100),
          context.LogicalDeviceConfiguration(memory_limit=100),
      ])
    else:
      self.device_type = "CPU"
      cpus = ctx.list_physical_devices("CPU")
      ctx.set_logical_device_configuration(cpus[0], [
          context.LogicalDeviceConfiguration(),
          context.LogicalDeviceConfiguration(),
      ])

    self.device = parallel_device.ParallelDevice(components=[
        "/job:localhost/device:{}:0".format(self.device_type),
        self.device_type + ":1"
    ])
    self.assertIn(self.device_type + ":0", self.device.components[0])
    self.assertIn(self.device_type + ":1", self.device.components[1])


class ParallelDeviceTests(_VirtualDeviceTestCase, parameterized.TestCase):

  def test_register_parallel_device(self):
    with self.device:
      c = constant_op.constant(1.)
      d = constant_op.constant(2.)
      e = c + d
      outputs = self.device.unpack(e)
    self.assertAllClose([3., 3.], outputs)

    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_device_id(self):
    device_ids = self.device.unpack(self.device.device_ids)
    self.assertAllClose([0, 1], device_ids)
    # TODO(allenl): Should device IDs be int64 so they can be placed on GPUs?
    # Currently backing_device is CPU.
    self.assertIn(self.device.components[0], device_ids[0].device)
    self.assertIn(self.device.components[1], device_ids[1].device)

  def test_collective_reduce(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    with self.device:
      x = self.device.pack(
          [constant_op.constant(-1.5),
           constant_op.constant(3.5)])
      reduced = _collective_sum(x, num_replicas=2)
      outputs = self.device.unpack(reduced)
    self.assertAllClose([2., 2.], outputs)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_collective_reduce_async_scope(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    # Note that ops on the parallel device currently don't execute
    # asynchronously. The test is just that we don't get deadlocks.
    with context.async_scope(), self.device:
      x = self.device.pack(
          [constant_op.constant(-1.5),
           constant_op.constant(3.5)])
      reduced = _collective_sum(x, num_replicas=2)
      outputs = self.device.unpack(reduced)
    self.assertAllClose([2., 2.], outputs)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_collective_reduce_async_context(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    previous = config.get_synchronous_execution()
    try:
      context._reset_context()
      config.set_synchronous_execution(False)
      self.setUp()
      # Note that ops on the parallel device currently don't execute
      # asynchronously. The test is just that we don't get deadlocks.
      with self.device:
        x = self.device.pack(
            [constant_op.constant(-1.5),
             constant_op.constant(3.5)])
        reduced = _collective_sum(x, num_replicas=2)
        outputs = self.device.unpack(reduced)
      self.assertAllClose([2., 2.], outputs)
      self.assertIn(self.device.components[0], outputs[0].backing_device)
      self.assertIn(self.device.components[1], outputs[1].backing_device)
    finally:
      context._reset_context()
      config.set_synchronous_execution(previous)

  @parameterized.named_parameters(
      [("RunFunctionsEagerly", True),
       ("", False)])
  def test_cond(self, run_functions_eagerly):
    try:
      def_function.run_functions_eagerly(run_functions_eagerly)
      with self.device:
        pred = self.device.pack([True, False])
        capture = self.device.pack([[1.], [2.]])
        result = control_flow_ops.cond(
            pred,
            def_function.function(lambda: capture * 2.),
            def_function.function(lambda: capture * 4.))
      self.assertAllClose(
          [[2.], [8.]], self.device.unpack(result))
    finally:
      def_function.run_functions_eagerly(False)

  def test_cond_with_variable(self):
    with self.device:
      pred = self.device.pack([True, False])
      capture = self.device.pack([[1.], [2.]])
      v = None
      @def_function.function
      def true_branch():
        nonlocal v
        if v is None:
          v = variables.Variable(constant_op.constant(2.))
        return v * capture
      result = control_flow_ops.cond(
          pred, true_branch, def_function.function(lambda: capture * 4.))
    self.assertAllClose(
        [[2.], [8.]], self.device.unpack(result))
    self.assertAllClose(
        [2., 2.], self.device.unpack(v))
    # There are two unique variable handles with separate storage.
    h1, _ = self.device.unpack(v.handle)
    gen_resource_variable_ops.assign_variable_op(h1, constant_op.constant(3.))
    self.assertAllClose(
        [3., 2.], self.device.unpack(v))

  def test_collective_in_function(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    c = constant_op.constant([2])

    @def_function.function
    def broadcast_send_recv(device_id):

      @def_function.function
      def send():
        s0 = collective_ops.broadcast_send(
            c * 3, c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
        with ops.control_dependencies([s0.op]):
          return array_ops.identity(c)

      @def_function.function
      def recv():
        r0 = collective_ops.broadcast_recv(
            c.shape, c.dtype, group_size=2, group_key=1, instance_key=1)
        return r0

      return control_flow_ops.switch_case(
          device_id, branch_fns={0: send, 1: recv})

    with self.device:
      result = broadcast_send_recv(self.device.device_ids)
    self.assertAllClose([[2], [6]], self.device.unpack(result))

  def test_use_in_graph_error_is_informative(self):
    @def_function.function
    def uses_parallel():
      with self.device:
        return self.device.unpack(array_ops.ones([]))

    with self.assertRaisesRegex(NotImplementedError, "inside `tf.function`"):
      uses_parallel()

  def test_checkpointing(self):
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.device:
      different_values = self.device.pack(
          [constant_op.constant(-1.),
           constant_op.constant(3.)])
      v = variables.Variable(different_values)
      checkpoint = tracking.Checkpoint(v=v)
    save_path = checkpoint.save(prefix)
    with self.device:
      v.assign(constant_op.constant(0.))
    checkpoint.restore(save_path).assert_consumed()
    with self.device:
      outputs = self.device.unpack(v)
    self.assertAllClose([-1., 3.], outputs)

    with self.device:
      restore_on_create = tracking.Checkpoint()
      restore_on_create.restore(save_path)
      restore_on_create.v = variables.Variable(0.)
      outputs = self.device.unpack(restore_on_create.v)
    self.assertAllClose([-1., 3.], outputs)

    # Changing the number of devices / restoring into a single-device copy is OK
    single_device = tracking.Checkpoint(v=variables.Variable(0.))
    status = single_device.restore(save_path)
    status.assert_existing_objects_matched()
    self.assertAllClose(-1., single_device.v)
    with self.assertRaisesRegex(AssertionError, "parallel_component_1"):
      # There are parts of the variable that aren't restored into a
      # single-device copy.
      status.assert_consumed()

  def test_saved_model(self):
    with self.device:
      different_values = self.device.pack(
          [constant_op.constant(-1.),
           constant_op.constant(3.)])
      m = module.Module()
      m.v = variables.Variable(different_values)
      m.f = def_function.function(lambda: m.v * 2.)
      self.assertAllClose([-2., 6.], self.device.unpack(m.f()))
      saved_model_path = os.path.join(self.get_temp_dir(), "saved_model")
      save.save(m, saved_model_path)

    context._reset_context()
    self.setUp()

    single_device_loaded = load.load(saved_model_path)
    self.assertAllClose(-2., single_device_loaded.f())
    with self.device:
      parallel_loaded = load.load(saved_model_path)
      self.assertAllClose([-2., 6.], self.device.unpack(parallel_loaded.f()))
      self.assertAllClose([-1., 3.], self.device.unpack(parallel_loaded.v))
      parallel_loaded.v.assign(self.device.pack([.1, .2]))
      self.assertAllClose([.2, .4], self.device.unpack(parallel_loaded.f()))

  def _assert_close_to_non_parallel(self, computation):
    """Asserts that replication of `computation` works and is equivalent."""
    with self.device:
      parallel_result = computation()
    non_parallel_result = computation()
    # The computations should have the same number and structure of Tensor
    # objects, even though the tensors themselves will be on different devices
    # and represent different numbers of values.
    nest.assert_same_structure(parallel_result, non_parallel_result)
    non_parallel_flat = nest.flatten(non_parallel_result)
    parallel_flat = nest.flatten(parallel_result)
    self.assertGreater(len(parallel_flat), 0)
    for non_parallel, parallel in zip(non_parallel_flat, parallel_flat):
      self.assertEqual(self.device._name, parallel.device)
      self.assertNotEqual(self.device._name, non_parallel.device)
      for parallel_component in self.device.unpack(parallel):
        self.assertAllClose(non_parallel, parallel_component)

  def test_capturing(self):
    with self.device:
      x = constant_op.constant([1., 2.])
      x = array_ops.identity(x)

      @def_function.function
      def f(y):
        return x + y

      y = array_ops.ones([2])
      parallel_result = f(y)
    self.assertAllClose([[2., 3.]] * 2, self.device.unpack(parallel_result))

  def test_euclidean_norm(self):
    def _test_fn():
      with backprop.GradientTape() as tape:
        x = array_ops.ones([5, 5])
        tape.watch(x)
        y = math_ops.reduce_euclidean_norm(x, axis=constant_op.constant(1))
      return y, tape.gradient(y, x)
    self._assert_close_to_non_parallel(_test_fn)

  def test_reduce_sum(self):
    def _test_fn():
      with backprop.GradientTape() as tape:
        x = array_ops.ones([5, 5])
        tape.watch(x)
        y = math_ops.reduce_sum(x, axis=constant_op.constant(1))
      return y, tape.gradient(y, x)
    self._assert_close_to_non_parallel(_test_fn)

  def test_variable_created_in_function(self):

    class M(module.Module):

      def __init__(self):
        self.v = None
        self.w = None
        self.x = None
        self.z = None

      @def_function.function(autograph=False)
      def __call__(self, x):
        if self.v is None:
          with ops.init_scope():
            initial_value = constant_op.constant(2.)
            self.z = variables.Variable(initial_value)
          self.x = variables.Variable(initial_value)
          self.w = variables.Variable(lambda: constant_op.constant(2.))
          self.v = variables.Variable(constant_op.constant(2.))
        return x * self.v * self.w * self.x * self.z

    with self.device:
      m = M()
      packed_outputs = m(array_ops.ones([]))
    outputs = self.device.unpack(packed_outputs)
    self.assertAllClose([16., 16.], outputs)

  def test_different_shapes(self):
    with self.device:
      x = self.device.pack(
          [constant_op.constant([1., 2.]),
           constant_op.constant([5.])])
      y = x * 2.
      with self.assertRaisesRegex(Exception,
                                  "components do not all have the same shape"):
        y.shape  # pylint: disable=pointless-statement
      self.assertAllClose([[2., 4.], [10.]], self.device.unpack(y))

      different_axes = self.device.pack(
          [constant_op.constant([1., 2.]),
           constant_op.constant([[5.]])])
      with self.assertRaisesRegex(Exception,
                                  "components do not all have the same shape"):
        different_axes.shape  # pylint: disable=pointless-statement


class LayerTests(_VirtualDeviceTestCase):

  def test_layer_forward(self):
    with self.device:
      layer = _Dense(5)
      x = constant_op.constant([[2.]])
      y = layer(x)
      outputs = self.device.unpack(y)
    self.assertAllClose([[3.] * 5], outputs[0])
    self.assertAllClose([[3.] * 5], outputs[1])
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

    # With different Layer inputs we get different outputs
    with self.device:
      x = self.device.pack(
          [constant_op.constant([[-0.5]]),
           constant_op.constant([[0.5]])])
      y = layer(x)
      outputs = self.device.unpack(y)
    self.assertGreater(
        math_ops.reduce_max(math_ops.abs(outputs[0] - outputs[1])), 1e-5)
    self.assertIn(self.device.components[0], outputs[0].backing_device)
    self.assertIn(self.device.components[1], outputs[1].backing_device)

  def test_layer_sync_training(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    with self.device:
      layer = _Dense(5)

      with backprop.GradientTape() as tape:
        x = self.device.pack(
            [constant_op.constant([[-0.5]]),
             constant_op.constant([[0.5]])])
        y = layer(x)
        loss = (y - math_ops.range(5.))**2.
      parameters = layer.trainable_variables
      unreduced_gradients = tape.gradient(loss, parameters)
      reduced_gradients = _collective_sum(unreduced_gradients, num_replicas=2)
      for grad, param in zip(reduced_gradients, parameters):
        param.assign_sub(0.01 * grad)
    final_kernels = self.device.unpack(layer.kernel)
    self.assertAllClose(final_kernels[0], final_kernels[1])
    final_bias = self.device.unpack(layer.bias)
    expected_bias = (1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)) -
                     0.01 * 2. * (1. - .5 - math_ops.range(5.)))
    self.assertAllClose(expected_bias, final_bias[0])
    self.assertAllClose(expected_bias, final_bias[1])
    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
    self.assertIn(self.device.components[1], final_kernels[1].backing_device)

  def test_layer_divergent_buffer_training(self):
    with self.device:
      layer = _Dense(5)

      with backprop.GradientTape() as tape:
        x = self.device.pack(
            [constant_op.constant([[-0.5]]),
             constant_op.constant([[0.5]])])
        y = layer(x)
        loss = (y - math_ops.range(5.))**2.
      parameters = layer.trainable_variables
      unreduced_gradients = tape.gradient(loss, parameters)
      for grad, param in zip(unreduced_gradients, parameters):
        param.assign_sub(0.01 * grad)
    final_kernels = self.device.unpack(layer.kernel)
    self.assertNotAllClose(final_kernels[0], final_kernels[1])
    final_bias = self.device.unpack(layer.bias)
    self.assertAllClose(1. - 0.01 * 2. * (1. - .5 - math_ops.range(5.)),
                        final_bias[0])
    self.assertAllClose(1. - 0.01 * 2. * (1. + .5 - math_ops.range(5.)),
                        final_bias[1])
    self.assertIn(self.device.components[0], final_kernels[0].backing_device)
    self.assertIn(self.device.components[1], final_kernels[1].backing_device)

  def test_training_loop(self):
    if self.device_type == "TPU":
      self.skipTest("ParallelDevice collectives on TPUs need work")
    for _ in range(5):
      layer = _Dense(5)
      checkpoint = tracking.Checkpoint(layer=layer)
      manager = checkpoint_management.CheckpointManager(
          checkpoint, directory=self.get_temp_dir(), max_to_keep=5)
      manager.restore_or_initialize()

      for _ in range(10):
        with self.device:
          with backprop.GradientTape() as tape:
            x = self.device.pack(
                [constant_op.constant([[-0.5]]),
                 constant_op.constant([[0.5]])])
            y = layer(x)
            loss = (y - math_ops.range(5.))**2.
          parameters = layer.trainable_variables
          unreduced_gradients = tape.gradient(loss, parameters)
          reduced_gradients = _collective_sum(
              unreduced_gradients, num_replicas=len(self.device.components))
          for grad, param in zip(reduced_gradients, parameters):
            param.assign_sub(0.01 * grad)

      manager.save()


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()