# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for remote execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import time

from absl.testing import parameterized
import numpy as np
import six

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util import compat


class SingleWorkerTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(SingleWorkerTest, self).setUp()

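    # Start a single in-process worker and connect the eager context to it.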
    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(SingleWorkerTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testMultiDeviceFunctionBasic(self):

    @def_function.function
    def basic(i):
      with ops.device('/job:localhost/replica:0/task:0/cpu:0'):
        a = constant_op.constant([2]) + i
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        b = constant_op.constant([1])

      return a + b

    self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5])
    self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4])

  def testMultiDeviceFunctionVariable(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    # Add a sync point to avoid the out-of-order issue of eager async execution
    # (b/155789951).
    context.async_wait()

    @def_function.function
    def with_variable(i):
      return i + variable_b

    self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])

  def testMultiDeviceFunctionRemoteOutput(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def remote_output(i):
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        c = variable_b + 1
      return i + variable_b, c

    rets = remote_output(constant_op.constant([1]))
    self.assertAllEqual(rets[0].numpy(), [2])
    self.assertAllEqual(rets[1].numpy(), 2)
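    # The first output is fetched back to the client, while the tensor
    # produced under the explicit worker scope stays on the worker.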
    self.assertEqual(rets[0].backing_device,
                     '/job:localhost/replica:0/task:0/device:CPU:0')
    self.assertEqual(rets[1].backing_device,
                     '/job:worker/replica:0/task:0/device:CPU:0')

  def testStreaming(self):
    """A mini stress test for streaming - issuing many RPCs back to back."""
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 2])
      y = array_ops.zeros([2, 2])
      num_iters = 200
      for _ in range(num_iters):
        y = x + y
        # Fetch y's shape roughly once every 10 additions on average. This
        # exercises the remote-shape waiting logic in TensorHandle.
        if random.randint(1, 10) == 1:
          _ = y.shape
    np.testing.assert_array_equal(
        [[num_iters, num_iters], [num_iters, num_iters]], y.numpy())

  def testShapeError_OpByOp(self):
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 3])
      y = array_ops.zeros([2, 2])
      with self.assertRaises(errors.InvalidArgumentError) as cm:
        math_ops.matmul(x, y)

    self.assertIn('Dimensions must be equal', cm.exception.message)

  def testShapeError_Function(self):

    @def_function.function
    def matmul_func(x, y):
      return math_ops.matmul(x, y)

    x = array_ops.ones([2, 3])
    y = array_ops.zeros([2, 2])

    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      with self.assertRaises(ValueError) as cm:
        matmul_func(x, y)

    if six.PY2:
      self.assertIn('Dimensions must be equal', cm.exception.message)
    else:
      self.assertIn('Dimensions must be equal', cm.exception.args[0])

  def testClientVariable(self):
    var = variables.Variable(initial_value=0)

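    # The variable lives on the client (localhost); the function below runs on
    # the worker but reads the variable's value back from the client device.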
    @def_function.function
    def func():
      with ops.device('/job:localhost/task:0'):
        read = var.read_value()
      return read + 1

    with ops.device('/job:worker/task:0'):
      self.assertAllEqual(func(), 1)

  def testRemoteCall(self):

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def _remote_fn(x):
      return constant_op.constant(1) + x

    remote_fn = _remote_fn.get_concrete_function()

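    # remote_call dispatches the concrete function `remote_fn` to the given
    # target device; `Tout` declares the dtypes of the call's outputs.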
    @def_function.function
    def func(x):
      return functional_ops.remote_call(
          args=[x],
          Tout=[dtypes.int32],
          f=remote_fn,
          target='/job:worker/task:0')

    with ops.device('/job:localhost/task:0'):
      self.assertAllEqual(func(constant_op.constant(1)), [2])


class RemoteAsyncTest(test.TestCase):

  def setUp(self):
    super(RemoteAsyncTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(RemoteAsyncTest, self).tearDown()

    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def test_out_of_range_with_while_loop(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

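    # With async eager execution, the OutOfRangeError from the exhausted
    # iterator may surface late (possibly wrapped as an InternalError), so the
    # pending error is cleared before leaving the loop.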
    while True:
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
      except (errors.OutOfRangeError, errors.InternalError):
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

  def test_out_of_range_with_for_loop(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

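    # Sync on the last step so that any pending OutOfRangeError is raised
    # inside the try block instead of leaking into a later test.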
    num_steps = 3
    for i in range(num_steps):
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
        if i == num_steps - 1:
          context.async_wait()
      except errors.OutOfRangeError:
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

  def test_out_of_range_with_async_scope(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

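    # context.async_scope() runs the body with async execution enabled and
    # syncs on exit, so errors from pending ops surface when the scope closes.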
    num_steps = 3
    try:
      with context.async_scope():
        for _ in range(num_steps):
          with ops.device('/job:worker/task:0'):
            train_step(iterator)
    except errors.OutOfRangeError:
      context.async_clear_error()

    self.assertAllEqual(v.numpy(), 4.0)


class MultiWorkersTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiWorkersTest, self).setUp()

    workers, _ = test_util.create_local_cluster(3, 0)
    remote.connect_to_remote_host(
        [workers[0].target, workers[1].target, workers[2].target])

  def tearDown(self):
    super(MultiWorkersTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testReturnRemoteArgument(self):

    @def_function.function
    def local_func(i):
      return i

    with ops.device('/job:worker/replica:0/task:0'):
      x = constant_op.constant([2, 1])

    with ops.device('/job:worker/replica:0/task:1'):
      self.assertAllEqual(local_func(x), [2, 1])

  def testMultiDeviceFunctionAmbiguousDevice(self):

    @def_function.function
    def ambiguous_device(i):
      with ops.device('/job:worker'):
        # With multiple worker tasks, this partial device specification is
        # ambiguous, so an ambiguous-device error is raised.
        return i + constant_op.constant([2])

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      ambiguous_device(constant_op.constant([2])).numpy()

    self.assertIn('the output node must match exactly one device',
                  cm.exception.message)

  # Note that the following remote function cancellation tests only work with
  # non-streaming RPCs. Streaming must therefore be disabled explicitly, and
  # the config restored to its initial value at the end of each test case.
  def testCancelRemoteFunctionBeforeExecution(self):
    remote_async_env_var = 'TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE'
    default_streaming = os.environ.get(remote_async_env_var)
    os.environ[remote_async_env_var] = str(False)

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

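    # Cancelling before execution: start_cancel() marks the manager as
    # cancelled, so the call below fails immediately with CancelledError
    # instead of blocking on the empty queue.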
    c_mgr.start_cancel()
    with self.assertRaises(errors.CancelledError):
      with ops.device('/job:worker/replica:0/task:1'):
        cancelable_func()

    if default_streaming is None:
      del os.environ[remote_async_env_var]
    else:
      os.environ[remote_async_env_var] = default_streaming

  def testCancelRemoteFunctionDuringExecution(self):
    remote_async_env_var = 'TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE'
    default_streaming = os.environ.get(remote_async_env_var)
    os.environ[remote_async_env_var] = str(False)

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

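    # The dequeue blocks on the empty queue; a background thread cancels the
    # in-flight call after a short delay.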
    def cancel_thread():
      time.sleep(0.5)
      c_mgr.start_cancel()

    t = self.checkedThread(cancel_thread)
    t.start()
    with self.assertRaises(errors.CancelledError):
      with ops.device('/job:worker/replica:0/task:1'):
        cancelable_func()
    t.join()

    if default_streaming is None:
      del os.environ[remote_async_env_var]
    else:
      os.environ[remote_async_env_var] = default_streaming

  def testMultiDeviceFunctionOnLocalDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testMultiDeviceFunctionExecutionOrderingWithPackedInput(self):
    shape = [2]
    with ops.device('/job:worker/replica:0/task:2/device:CPU:0'):
      # Send 20 remote requests to simulate heavy load on worker:2.
      unused_values = []
      for _ in range(20):
        unused_values.append(array_ops.zeros(shape))
      func_input = array_ops.zeros(shape)

    packed_input = ops.pack_eager_tensors([func_input])

    @def_function.function
    def func(packed_input):
      # When worker:2 receives the component function request, packed_input
      # should be ready on worker:2.
      with ops.device('/job:worker/replica:0/task:2/device:CPU:0'):
        ret = packed_input + constant_op.constant(1.0)
      return ret + constant_op.constant(1.0)

    # Run the function on worker:1.
    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
      self.assertAllEqual(func(packed_input).numpy(),
                          array_ops.ones(shape).numpy() * 2)

  def testMultiDeviceFunctionWithPackedVariable(self):
    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
      var0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
      var1 = resource_variable_ops.ResourceVariable(2.0)

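    # Packing the two remote variable handles yields a single handle placed on
    # the local COMPOSITE device; a read under a task's device scope resolves
    # to that task's component.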
    packed_var = ops.pack_eager_tensors([var0.handle, var1.handle])
    self.assertEqual(packed_var.device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')
    self.assertEqual(packed_var.backing_device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')

    @def_function.function
    def add_variables():
      with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
        read0 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)
      with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
        read1 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)

      return read0 + read1

    # Run the function on a remote device.
    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(add_variables().numpy(), 3.0)

    # Run the function on the local host.
    self.assertAllEqual(add_variables().numpy(), 3.0)

  def testMultiDeviceFunctionOnRemoteDeviceWithWait(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable([1.0])

    @def_function.function
    def remote_function(i):
      x = array_ops.ones([1000, 1000])
      for _ in range(1, 1000):
        x = x * x
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    @def_function.function
    def remote_function2(i):
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    # Runs the first function:
    # - on a remote device
    # - needs remote input
    # - has side effects
    # - runs much slower
    with ops.device('/job:worker/replica:0/task:0'):
      remote_function(constant_op.constant([2.0]))

    # Runs the second function:
    # - on a remote device
    # - has side effects
    # There should be a sync point here, so the second function executes only
    # after the first function has completed.
    with ops.device('/job:worker/replica:0/task:2'):
      self.assertAllEqual(remote_function2(constant_op.constant([3.0])), [7.0])

  def testMultiDeviceFunctionOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testMultiDeviceFunctionRemoteOutput(self):
    with ops.device('/job:worker/replica:0/task:1/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def remote_output(i):
      with ops.device('/job:worker/replica:0/task:1/cpu:0'):
        c = variable_b + 1
      return i + variable_b, c

    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      rets = remote_output(constant_op.constant([1]))
    self.assertEqual(rets[0].backing_device,
                     '/job:worker/replica:0/task:0/device:CPU:0')
    self.assertEqual(rets[1].backing_device,
                     '/job:worker/replica:0/task:1/device:CPU:0')
    self.assertAllEqual(rets[0].numpy(), [2])
    self.assertAllEqual(rets[1].numpy(), 2)

  def testMultiDeviceWhileLoopOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):

      def body(i, _):
        with ops.device('/job:worker/replica:0/task:0'):
          a = i + variable_b
        return a + 1.0, 1

      return control_flow_ops.while_loop_v2(lambda _, d: d < 1, body, [i, 0])[0]

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testSimpleParameterServer(self):

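    # worker:2 acts as a parameter server holding the variables; worker:0 and
    # worker:1 run the update function against it.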
    with ops.device('/job:worker/task:2/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)


_GRPC_PREFIX = 'grpc://'


class MultiJobsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiJobsTest, self).setUp()

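    # create_local_cluster returns grpc:// targets, while ClusterSpec expects
    # bare host:port addresses, so the scheme prefix is stripped.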
    workers, ps = test_util.create_local_cluster(num_workers=2, num_ps=2)
    cluster = {
        'my_worker': [_strip_prefix(t.target, _GRPC_PREFIX) for t in workers],
        'my_ps': [_strip_prefix(t.target, _GRPC_PREFIX) for t in ps],
    }
    self._cluster = server_lib.ClusterSpec(cluster)
    self._cluster_resolver = SimpleClusterResolver(
        cluster_spec=self._cluster, master=ps[0].target)

  def tearDown(self):
    super(MultiJobsTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testMultipleDeviceFoundCheck(self):
    remote.connect_to_cluster(self._cluster)

    @def_function.function
    def func():
      with ops.device('cpu:0'):
        # Multiple matching CPU:0 devices exist in the cluster, but the CPU:0
        # from the parent device scope should be picked.
        x = test_ops.device_placement_op()
        y = string_ops.string_upper(x)
        packed_var_0 = array_ops.stack([x, y], 0)
        return packed_var_0

    with ops.device('/job:my_worker/task:1'):
      output = self.evaluate(func())
      self.assertEqual(
          compat.as_bytes('/job:my_worker/replica:0/task:1/device:CPU:0'),
          output[0])
      self.assertIn(compat.as_bytes('/JOB:MY_WORKER'), output[1])
    with ops.device('/job:my_ps/task:1'):
      output = self.evaluate(func())
      self.assertEqual(
          compat.as_bytes('/job:my_ps/replica:0/task:1/device:CPU:0'),
          output[0])
      self.assertIn(compat.as_bytes('/JOB:MY_PS'), output[1])

  def testSimpleParameterServer(self):
    remote.connect_to_cluster(self._cluster)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

  # TODO(b/152224115): Re-enable this test.
  def DISABLED_testSimpleParameterServerWithDeviceFilters(self):
    cluster_device_filters = server_lib.ClusterDeviceFilters()
    for i in range(2):
      cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps'])
      cluster_device_filters.set_device_filters('my_ps', i, ['/job:my_worker'])
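    # With these filters, each worker can only see ps devices and each ps can
    # only see worker devices, so the two ps tasks cannot see each other.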
    remote.connect_to_cluster(
        self._cluster, cluster_device_filters=cluster_device_filters)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
    with ops.device('/job:my_ps/task:1/device:CPU:0'):
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

    # The following remote calls fail because the ps tasks cannot see each
    # other due to the device filters.
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:0/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:1/device:CPU:0 unknown device',
                  cm.exception.message)

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:1/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:0/device:CPU:0 unknown device',
                  cm.exception.message)

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 7)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 6)
    # Explicitly delete the variables to avoid triggering errors when they are
    # garbage-collected during subsequent tests.
    del v1, v2

  def testConnectWithClusterResolver(self):
    remote.connect_to_cluster(self._cluster_resolver)

    v1 = variables.Variable(initial_value=0)
    v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

  def testConnectToClusterTwiceOk(self):
    remote.connect_to_cluster(self._cluster_resolver)
    remote.connect_to_cluster(self._cluster_resolver)

  def testConnectToClusterOnMismatchedDevice(self):
    remote.connect_to_cluster(self._cluster_resolver)

    # Enter another device scope.
    ops.device('/job:my_worker/task:0/device:CPU:0').__enter__()

    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)

  def testConnectToClusterWithLocalMaster(self):
    local_resolver = SimpleClusterResolver(ClusterSpec({}), master='local')
    remote.connect_to_cluster(local_resolver)

  def testConnectToClusterInGraphModeWillFail(self):
    ops.disable_eager_execution()
    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)
    ops.enable_eager_execution()


def _strip_prefix(s, prefix):
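  """Removes `prefix` from `s` if present (e.g. drops a 'grpc://' scheme)."""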
  return s[len(prefix):] if s.startswith(prefix) else s


if __name__ == '__main__':
  test.main()