# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for remote execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import time

from absl.testing import parameterized
import numpy as np
import six

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util import compat


class SingleWorkerTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(SingleWorkerTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(SingleWorkerTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testMultiDeviceFunctionBasic(self):

    @def_function.function
    def basic(i):
      with ops.device('/job:localhost/replica:0/task:0/cpu:0'):
        a = constant_op.constant([2]) + i
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        b = constant_op.constant([1])

      return a + b

    self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5])
    self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4])

  def testMultiDeviceFunctionVariable(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    # Add a sync point to avoid the out-of-order issue of eager async execution
    # (b/155789951).
    context.async_wait()

    @def_function.function
    def with_variable(i):
      return i + variable_b

    self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])

  def testMultiDeviceFunctionRemoteOutput(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def remote_output(i):
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        c = variable_b + 1
      return i + variable_b, c

    rets = remote_output(constant_op.constant([1]))
    self.assertAllEqual(rets[0].numpy(), [2])
    self.assertAllEqual(rets[1].numpy(), 2)
    self.assertEqual(rets[0].backing_device,
                     '/job:localhost/replica:0/task:0/device:CPU:0')
    self.assertEqual(rets[1].backing_device,
                     '/job:worker/replica:0/task:0/device:CPU:0')

  def testStreaming(self):
    """A mini stress test for streaming - issuing many RPCs back to back."""
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 2])
      y = array_ops.zeros([2, 2])
      num_iters = 200
      for _ in range(num_iters):
        y = x + y
        # Ask for y's shape after every 10 additions on average.
        # This exercises waiting for remote shape logic in TensorHandle.
        if random.randint(1, 10) == 1:
          _ = y.shape
    np.testing.assert_array_equal(
        [[num_iters, num_iters], [num_iters, num_iters]], y.numpy())

  def testShapeError_OpByOp(self):
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 3])
      y = array_ops.zeros([2, 2])
      with self.assertRaises(errors.InvalidArgumentError) as cm:
        math_ops.matmul(x, y)

    self.assertIn('Dimensions must be equal', cm.exception.message)

  def testShapeError_Function(self):

    @def_function.function
    def matmul_func(x, y):
      return math_ops.matmul(x, y)

    x = array_ops.ones([2, 3])
    y = array_ops.zeros([2, 2])

    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      with self.assertRaises(ValueError) as cm:
        matmul_func(x, y)

    if six.PY2:
      self.assertIn('Dimensions must be equal', cm.exception.message)
    else:
      self.assertIn('Dimensions must be equal', cm.exception.args[0])

  def testClientVariable(self):
    var = variables.Variable(initial_value=0)

    @def_function.function
    def func():
      with ops.device('/job:localhost/task:0'):
        read = var.read_value()
      return read + 1

    with ops.device('/job:worker/task:0'):
      self.assertAllEqual(func(), 1)

  def testRemoteCall(self):

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def _remote_fn(x):
      return constant_op.constant(1) + x

    remote_fn = _remote_fn.get_concrete_function()

    @def_function.function
    def func(x):
      return functional_ops.remote_call(
          args=[x],
          Tout=[dtypes.int32],
          f=remote_fn,
          target='/job:worker/task:0')

    with ops.device('/job:localhost/task:0'):
      self.assertAllEqual(func(constant_op.constant(1)), [2])


class RemoteAsyncTest(test.TestCase):

  def setUp(self):
    super(RemoteAsyncTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(RemoteAsyncTest, self).tearDown()

    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def test_out_of_range_with_while_loop(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    while True:
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
      except (errors.OutOfRangeError, errors.InternalError):
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

  def test_out_of_range_with_for_loop(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    num_steps = 3
    for i in range(num_steps):
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
          if i == num_steps - 1:
            context.async_wait()
      except errors.OutOfRangeError:
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

  def test_out_of_range_with_async_scope(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    num_steps = 3
    try:
      with context.async_scope():
        for _ in range(num_steps):
          with ops.device('/job:worker/task:0'):
            train_step(iterator)
    except errors.OutOfRangeError:
      context.async_clear_error()

    self.assertAllEqual(v.numpy(), 4.0)


class MultiWorkersTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiWorkersTest, self).setUp()

    workers, _ = test_util.create_local_cluster(3, 0)
    remote.connect_to_remote_host(
        [workers[0].target, workers[1].target, workers[2].target])

  def tearDown(self):
    super(MultiWorkersTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testReturnRemoteArgument(self):

    @def_function.function
    def local_func(i):
      return i

    with ops.device('/job:worker/replica:0/task:0'):
      x = constant_op.constant([2, 1])

    with ops.device('/job:worker/replica:0/task:1'):
      self.assertAllEqual(local_func(x), [2, 1])

  def testMultiDeviceFunctionAmbiguousDevice(self):

    @def_function.function
    def ambiguous_device(i):
      with ops.device('/job:worker'):
        # There are multiple worker tasks, so an ambiguous-device error will
        # be raised.
        return i + constant_op.constant([2])

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      ambiguous_device(constant_op.constant([2])).numpy()

    self.assertIn('the output node must match exactly one device',
                  cm.exception.message)

  # Note that the following tests for remote function cancellation only work
  # when using non-streaming RPCs. We need to disable streaming explicitly and
  # restore this config to its initial value at the end of each test case.
  def testCancelRemoteFunctionBeforeExecution(self):
    remote_async_env_var = 'TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE'
    default_streaming = os.environ.get(remote_async_env_var)
    os.environ[remote_async_env_var] = str(False)

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

    c_mgr.start_cancel()
    with self.assertRaises(errors.CancelledError):
      with ops.device('/job:worker/replica:0/task:1'):
        cancelable_func()

    if default_streaming is None:
      del os.environ[remote_async_env_var]
    else:
      os.environ[remote_async_env_var] = default_streaming

  def testCancelRemoteFunctionDuringExecution(self):
    remote_async_env_var = 'TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE'
    default_streaming = os.environ.get(remote_async_env_var)
    os.environ[remote_async_env_var] = str(False)

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

    def cancel_thread():
      time.sleep(0.5)
      c_mgr.start_cancel()

    t = self.checkedThread(cancel_thread)
    t.start()
    with self.assertRaises(errors.CancelledError):
      with ops.device('/job:worker/replica:0/task:1'):
        cancelable_func()
    t.join()

    if default_streaming is None:
      del os.environ[remote_async_env_var]
    else:
      os.environ[remote_async_env_var] = default_streaming

  def testMultiDeviceFunctionOnLocalDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testMultiDeviceFunctionExecutionOrderingWithPackedInput(self):
    shape = [2]
    with ops.device('/job:worker/replica:0/task:2/device:CPU:0'):
      # Send 20 remote requests to simulate heavy load on worker:2.
      unused_values = []
      for _ in range(20):
        unused_values.append(array_ops.zeros(shape))
      func_input = array_ops.zeros(shape)

    packed_input = ops.pack_eager_tensors([func_input])

    @def_function.function
    def func(packed_input):
      # When worker:2 receives the component function request, packed_input
      # should be ready on worker:2.
      with ops.device('/job:worker/replica:0/task:2/device:CPU:0'):
        ret = packed_input + constant_op.constant(1.0)
      return ret + constant_op.constant(1.0)

    # Run the function on worker:1.
    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
      self.assertAllEqual(func(packed_input).numpy(),
                          array_ops.ones(shape).numpy() * 2)

  def testMultiDeviceFunctionWithPackedVariable(self):
    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
      var0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
      var1 = resource_variable_ops.ResourceVariable(2.0)

    packed_var = ops.pack_eager_tensors([var0.handle, var1.handle])
    self.assertEqual(packed_var.device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')
    self.assertEqual(packed_var.backing_device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')

    @def_function.function
    def add_variables():
      with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
        read0 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)
      with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
        read1 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)

      return read0 + read1

    # Run the function on a remote device.
    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(add_variables().numpy(), 3.0)

    # Run the function on the local worker.
    self.assertAllEqual(add_variables().numpy(), 3.0)

  def testMultiDeviceFunctionOnRemoteDeviceWithWait(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable([1.0])

    @def_function.function
    def remote_function(i):
      x = array_ops.ones([1000, 1000])
      for _ in range(1, 1000):
        x = x * x
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    @def_function.function
    def remote_function2(i):
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    # Runs the first function:
    # - on a remote device
    # - needs remote input
    # - has side effects
    # - runs much slower
    with ops.device('/job:worker/replica:0/task:0'):
      remote_function(constant_op.constant([2.0]))

    # Runs the second function:
    # - on a remote device
    # - has side effects
    # There should be a sync point here and the next function will be executed
    # only after the first function has completed.
    with ops.device('/job:worker/replica:0/task:2'):
      self.assertAllEqual(remote_function2(constant_op.constant([3.0])), [7.0])

  def testMultiDeviceFunctionOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testMultiDeviceFunctionRemoteOutput(self):
    with ops.device('/job:worker/replica:0/task:1/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def remote_output(i):
      with ops.device('/job:worker/replica:0/task:1/cpu:0'):
        c = variable_b + 1
      return i + variable_b, c

    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      rets = remote_output(constant_op.constant([1]))
    self.assertEqual(rets[0].backing_device,
                     '/job:worker/replica:0/task:0/device:CPU:0')
    self.assertEqual(rets[1].backing_device,
                     '/job:worker/replica:0/task:1/device:CPU:0')
    self.assertAllEqual(rets[0].numpy(), [2])
    self.assertAllEqual(rets[1].numpy(), 2)

  def testMultiDeviceWhileLoopOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):

      def body(i, _):
        with ops.device('/job:worker/replica:0/task:0'):
          a = i + variable_b
        return a + 1.0, 1

      return control_flow_ops.while_loop_v2(lambda _, d: d < 1, body, [i, 0])[0]

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testSimpleParameterServer(self):

    with ops.device('/job:worker/task:2/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)


_GRPC_PREFIX = 'grpc://'


class MultiJobsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiJobsTest, self).setUp()

    workers, ps = test_util.create_local_cluster(num_workers=2, num_ps=2)
    cluster = {
        'my_worker': [_strip_prefix(t.target, _GRPC_PREFIX) for t in workers],
        'my_ps': [_strip_prefix(t.target, _GRPC_PREFIX) for t in ps],
    }
    self._cluster = server_lib.ClusterSpec(cluster)
    self._cluster_resolver = SimpleClusterResolver(
        cluster_spec=self._cluster, master=ps[0].target)

  def tearDown(self):
    super(MultiJobsTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testMultipleDeviceFoundCheck(self):
    remote.connect_to_cluster(self._cluster)

    @def_function.function
    def func():
      with ops.device('cpu:0'):
        # Multiple matching CPU:0 devices would be found, but the CPU:0 from
        # the parent device scope should be picked.
        x = test_ops.device_placement_op()
        y = string_ops.string_upper(x)
        packed_var_0 = array_ops.stack([x, y], 0)
        return packed_var_0

    with ops.device('/job:my_worker/task:1'):
      output = self.evaluate(func())
      self.assertEqual(
          compat.as_bytes('/job:my_worker/replica:0/task:1/device:CPU:0'),
          output[0])
      self.assertIn(compat.as_bytes('/JOB:MY_WORKER'), output[1])
    with ops.device('/job:my_ps/task:1'):
      output = self.evaluate(func())
      self.assertEqual(
          compat.as_bytes('/job:my_ps/replica:0/task:1/device:CPU:0'),
          output[0])
      self.assertIn(compat.as_bytes('/JOB:MY_PS'), output[1])

  def testSimpleParameterServer(self):
    remote.connect_to_cluster(self._cluster)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

  # TODO(b/152224115): Re-enable this test.
  def DISABLED_testSimpleParameterServerWithDeviceFilters(self):
    cluster_device_filters = server_lib.ClusterDeviceFilters()
    for i in range(2):
      cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps'])
      cluster_device_filters.set_device_filters('my_ps', i, ['/job:my_worker'])
    remote.connect_to_cluster(
        self._cluster, cluster_device_filters=cluster_device_filters)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
    with ops.device('/job:my_ps/task:1/device:CPU:0'):
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

    # The following remote calls would fail because the ps nodes cannot see
    # each other due to the device filters.
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:0/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:1/device:CPU:0 unknown device',
                  cm.exception.message)

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:1/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:0/device:CPU:0 unknown device',
                  cm.exception.message)

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 7)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 6)
    # Explicitly delete the variables to avoid triggering errors when they are
    # garbage-collected in subsequent tests.
    del v1, v2

  def testConnectWithClusterResolver(self):
    remote.connect_to_cluster(self._cluster_resolver)

    v1 = variables.Variable(initial_value=0)
    v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

  def testConnectToClusterTwiceOk(self):
    remote.connect_to_cluster(self._cluster_resolver)
    remote.connect_to_cluster(self._cluster_resolver)

  def testConnectToClusterOnMismatchedDevice(self):
    remote.connect_to_cluster(self._cluster_resolver)

    # Enter another device scope.
    ops.device('/job:my_worker/task:0/device:CPU:0').__enter__()

    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)

  def testConnectToClusterWithLocalMaster(self):
    local_resolver = SimpleClusterResolver(ClusterSpec({}), master='local')
    remote.connect_to_cluster(local_resolver)

  def testConnectToClusterInGraphModeWillFail(self):
    ops.disable_eager_execution()
    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)
    ops.enable_eager_execution()


def _strip_prefix(s, prefix):
  return s[len(prefix):] if s.startswith(prefix) else s


if __name__ == '__main__':
  test.main()