1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for TPUClusterResolver."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23import six
24from six.moves.urllib.error import URLError
25
26from tensorflow.python import eager
27from tensorflow.python.client import session
28from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
29from tensorflow.python.framework import errors
30from tensorflow.python.framework import test_util
31from tensorflow.python.platform import test
32from tensorflow.python.training import server_lib
33from tensorflow.python.util import compat
34
35mock = test.mock
36
37
class MockRequestClass(object):
  """Fake for the request object returned by the mock `nodes().get()` call."""

  def __init__(self, name, tpu_map):
    """Remembers which resource to resolve and where to resolve it from.

    Args:
      name: Resource name that `execute` will look up in `tpu_map`.
      tpu_map: Dict mapping resource names to the response for that resource.
    """
    self._name = name
    self._tpu_map = tpu_map

  def execute(self):
    """Returns the response stored for this request's resource name.

    Raises:
      KeyError: If the resource name is not a key of the backing map.
    """
    if self._name not in self._tpu_map:
      raise KeyError('Resource %s was not found' % self._name)
    return self._tpu_map[self._name]
49
50
class MockNodeClass(object):
  """Fake for the `nodes()` collection of the mocked Cloud TPU API client."""

  def __init__(self, tpu_map):
    """Stores the lookup table used to answer requests.

    Args:
      tpu_map: Dict mapping resource names to the response for that resource.
    """
    self._tpu_map = tpu_map

  def get(self, name):
    """Returns a fake request that looks up `name` when executed."""
    request = MockRequestClass(name, self._tpu_map)
    return request
58
59
def mock_request_compute_metadata(cls, *args, **kwargs):
  """Fake GCE metadata lookup returning canned values for a test VM.

  Args:
    cls: Ignored first positional argument.
    *args: `args[0]` is the metadata path being requested.
    **kwargs: Ignored.

  Returns:
    A canned string for the project id, zone and primary network IP paths,
    and the empty string for any other path.
  """
  del cls, kwargs  # Unused.
  canned_responses = (
      ('project/project-id', 'test-project'),
      ('instance/zone', 'projects/test-project/locations/us-central1-c'),
      ('instance/network-interfaces/0/ip', '10.128.1.2'),
  )
  requested_path = args[0]
  for path, value in canned_responses:
    if requested_path == path:
      return value
  return ''
69
70
def mock_is_running_in_gce(cls, *args, **kwargs):
  """Replacement for `_isRunningInGCE` that unconditionally reports True.

  All arguments are accepted and ignored so it can stand in for the real
  method regardless of how it is invoked.
  """
  del cls, args, kwargs  # Unused.
  return True
74
75
def mock_is_not_running_in_gce(cls, *args, **kwargs):
  """Replacement for `_isRunningInGCE` that unconditionally reports False.

  All arguments are accepted and ignored so it can stand in for the real
  method regardless of how it is invoked.
  """
  del cls, args, kwargs  # Unused.
  return False
79
80
def mock_running_in_gce_urlopen(cls, *args, **kwargs):
  """Fake `urlopen` whose response carries the GCE metadata-server header.

  Returns:
    A `MagicMock` response whose `info()` returns
    `{'Metadata-Flavor': 'Google'}`.
  """
  del cls, args, kwargs  # Unused.
  headers = {'Metadata-Flavor': 'Google'}
  # Dotted keyword names configure child mocks (`info`) at construction time.
  return mock.MagicMock(**{'info.return_value': headers})
86
87
def mock_not_running_in_gce_urlopen(cls, *args, **kwargs):
  """Fake `urlopen` that always fails to reach its host.

  Raises:
    URLError: Always, with reason 'Host does not exist.'.
  """
  del cls, args, kwargs  # Unused.
  error = URLError(reason='Host does not exist.')
  raise error
91
92
@test_util.run_all_in_graph_and_eager_modes
class TPUClusterResolverTest(test.TestCase):
  """Unit tests for `TPUClusterResolver`.

  The Cloud TPU API is replaced by `mock_service_client`, and GCE metadata /
  environment probes by the module-level `mock_*` helpers. Environment
  variables are set per test with `mock.patch.dict(os.environ, ...)`, which
  restores the original environment even when an assertion fails, so no test
  leaks configuration into another.
  """

  def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
    """Verifies that the ClusterSpec generates the correct proto.

    We are testing this four different ways to ensure that the ClusterSpec
    returned by the TPUClusterResolver behaves identically to a normal
    ClusterSpec when passed into the generic ClusterSpec libraries.

    Args:
      cluster_spec: ClusterSpec returned by the TPUClusterResolver
      expected_proto: Expected protobuf
    """
    self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec).as_cluster_def())
    self.assertProtoEquals(expected_proto,
                           server_lib.ClusterSpec(
                               cluster_spec.as_cluster_def()).as_cluster_def())
    self.assertProtoEquals(expected_proto,
                           server_lib.ClusterSpec(
                               cluster_spec.as_dict()).as_cluster_def())

  def mock_service_client(self, tpu_map=None):
    """Builds a fake Cloud TPU API client backed by `tpu_map`.

    The returned client answers the call chain
    `projects().locations().nodes().get(name).execute()` by returning
    `tpu_map[name]`, or raising `KeyError` when `name` is not in the map.

    Args:
      tpu_map: Dict mapping fully-qualified node names to the node resource
        the API would return. Defaults to an empty map.

    Returns:
      A `MagicMock` usable as the resolver's `service` argument.
    """
    if tpu_map is None:
      tpu_map = {}

    mock_locations = mock.MagicMock()
    mock_locations.nodes.return_value = MockNodeClass(tpu_map)

    mock_project = mock.MagicMock()
    mock_project.locations.return_value = mock_locations

    mock_client = mock.MagicMock()
    mock_client.projects.return_value = mock_project

    return mock_client

  @mock.patch.object(TPUClusterResolver,
                     '_isRunningInGCE',
                     mock_is_running_in_gce)
  def testCheckRunningInGceWithNoTpuName(self):
    with self.assertRaisesRegexp(RuntimeError, '.*Google Cloud.*'):
      TPUClusterResolver(tpu='')

  @mock.patch.object(six.moves.urllib.request,
                     'urlopen',
                     mock_running_in_gce_urlopen)
  def testIsRunningInGce(self):
    self.assertTrue(TPUClusterResolver._isRunningInGCE())

  @mock.patch.object(six.moves.urllib.request,
                     'urlopen',
                     mock_not_running_in_gce_urlopen)
  def testIsNotRunningInGce(self):
    self.assertFalse(TPUClusterResolver._isRunningInGCE())

  @mock.patch.object(TPUClusterResolver,
                     '_requestComputeMetadata',
                     mock_request_compute_metadata)
  def testRetrieveProjectAndZoneFromMetadata(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'health': 'HEALTHY'
        }
    }

    resolver = TPUClusterResolver(
        project=None,
        zone=None,
        tpu=['test-tpu-1'],
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_map),
        coordinator_name='coordinator')

    actual_cluster_spec = resolver.cluster_spec()
    # The coordinator address comes from the mocked metadata IP; its port is
    # chosen by the resolver, so it is read back rather than hard-coded.
    expected_proto = """
    job {
      name: 'coordinator'
      tasks { key: 0 value: '10.128.1.2:%s' }
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.1.2.3:8470' }
    }
    """ % resolver._coordinator_port
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')

  @mock.patch.object(TPUClusterResolver,
                     '_requestComputeMetadata',
                     mock_request_compute_metadata)
  def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'health': 'HEALTHY'
        }
    }

    resolver = TPUClusterResolver(
        project=None,
        zone=None,
        tpu=['test-tpu-1'],
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_map))

    actual_cluster_spec = resolver.cluster_spec()
    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')

  @mock.patch.object(TPUClusterResolver,
                     '_requestComputeMetadata',
                     mock_request_compute_metadata)
  def testNotReadyCloudTpu(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'state': 'CREATING'
        }
    }

    resolver = TPUClusterResolver(
        project=None,
        zone=None,
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_map))

    with self.assertRaises(RuntimeError):
      resolver.cluster_spec()

  def testSimpleSuccessfulRetrieval(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'ipAddress': '10.1.2.3',
            'port': '8470',
            'health': 'HEALTHY'
        }
    }

    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu=['test-tpu-1'],
        coordinator_name='coordinator',
        coordinator_address='10.128.1.5:10203',
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_map))

    actual_cluster_spec = resolver.cluster_spec()
    expected_proto = """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')

  def testNewNetworkEndpointFormat(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'health': 'HEALTHY',
            'networkEndpoints': [{
                'ipAddress': '10.2.3.4',
                'port': 8470,
            }]
        }
    }

    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name='coordinator',
        coordinator_address='10.128.1.5:10203',
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_map))

    actual_cluster_spec = resolver.cluster_spec()
    expected_proto = """
    job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
    job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    self.assertEqual('grpc://10.2.3.4:8470', resolver.master())

  @mock.patch.object(TPUClusterResolver,
                     '_requestComputeMetadata',
                     mock_request_compute_metadata)
  def testPodResolution(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'health':
                'HEALTHY',
            'networkEndpoints': [
                {
                    'ipAddress': '10.2.3.4',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.5',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.6',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.7',
                    'port': 8470,
                },
            ]
        }
    }

    resolver = TPUClusterResolver(
        tpu='test-tpu-1',
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_map),
        coordinator_name='coordinator')

    actual_cluster_spec = resolver.cluster_spec()
    expected_proto = """
    job {
      name: 'coordinator',
      tasks { key: 0 value: '10.128.1.2:%s'}
    }
    job {
      name: 'worker'
      tasks { key: 0 value: '10.2.3.4:8470' }
      tasks { key: 1 value: '10.2.3.5:8470' }
      tasks { key: 2 value: '10.2.3.6:8470' }
      tasks { key: 3 value: '10.2.3.7:8470' }
    }
    """ % resolver._coordinator_port
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    self.assertEqual(resolver.master(), 'grpc://10.2.3.4:8470')

  def testPodResolutionNoCoordinator(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'health':
                'HEALTHY',
            'networkEndpoints': [
                {
                    'ipAddress': '10.2.3.4',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.5',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.6',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.7',
                    'port': 8470,
                },
            ]
        }
    }

    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_map))

    actual_cluster_spec = resolver.cluster_spec()
    expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.2.3.4:8470' }
      tasks { key: 1 value: '10.2.3.5:8470' }
      tasks { key: 2 value: '10.2.3.6:8470' }
      tasks { key: 3 value: '10.2.3.7:8470' }
    }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    self.assertEqual(resolver.master(), 'grpc://10.2.3.4:8470')

  def testGetMasterNoEntries(self):
    tpu_map = {}

    with self.assertRaises(ValueError):
      TPUClusterResolver(
          project='test-project',
          zone='us-central1-c',
          tpu=[],
          coordinator_name=None,
          credentials=None,
          service=self.mock_service_client(tpu_map=tpu_map))

  # TODO(saeta): Convert to parameterized test when included in OSS TF.
  def verifyShouldResolve(self, tpu, should_resolve):
    """Asserts whether a resolver built for `tpu` reports it should resolve.

    Args:
      tpu: Value passed as the resolver's `tpu` argument.
      should_resolve: Expected return value of `_shouldResolve()`.
    """
    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu=tpu,
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map={}))
    self.assertEqual(should_resolve, resolver._shouldResolve(),
                     "TPU: '%s'" % tpu)

  @mock.patch.object(TPUClusterResolver,
                     '_isRunningInGCE',
                     mock_is_not_running_in_gce)
  def testShouldResolveNoName(self):
    self.verifyShouldResolve('', False)

  def testShouldResolveLocal(self):
    self.verifyShouldResolve('local', False)

  def testShouldResolveGrpc(self):
    self.verifyShouldResolve('grpc://10.1.2.3:8470', False)

  def testShouldResolveBns(self):
    self.verifyShouldResolve('/bns/foo/bar', False)

  def testShouldResolveName(self):
    self.verifyShouldResolve('mytpu', True)

  def testShouldResolveList(self):
    self.verifyShouldResolve(['myothertpu'], True)

  def testShouldResolveGrpcPrefix(self):
    self.verifyShouldResolve('grpctpu', True)

  def testNoCallComputeMetadata(self):
    resolver = TPUClusterResolver(
        tpu='/bns/foo/bar')
    self.assertEqual('/bns/foo/bar', resolver.master())
    self.assertEqual(None, resolver.cluster_spec())

  # patch.dict restores os.environ after the test (even on failure), so the
  # variable neither leaks into later tests nor clobbers a pre-existing value.
  @mock.patch.dict(
      os.environ,
      {'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS': 'grpc://10.120.27.5:8470'})
  def testGkeEnvironmentForDonut(self):
    self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
    self.assertTrue(TPUClusterResolver._inGke())
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470'),
        compat.as_bytes(TPUClusterResolver._gkeEndpoints()))

    resolver = TPUClusterResolver()
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470'),
        compat.as_bytes(resolver.master()))
    actual_cluster_spec = resolver.cluster_spec()
    expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.120.27.5:8470' }
    }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  # patch.dict restores os.environ after the test (even on failure), so the
  # variable neither leaks into later tests nor clobbers a pre-existing value.
  @mock.patch.dict(
      os.environ, {
          'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS': ('grpc://10.120.27.5:8470,'
                                              'grpc://10.120.27.6:8470,'
                                              'grpc://10.120.27.7:8470,'
                                              'grpc://10.120.27.8:8470')
      })
  def testGkeEnvironmentForPod(self):
    self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
    self.assertTrue(TPUClusterResolver._inGke())
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470,'
                        'grpc://10.120.27.6:8470,'
                        'grpc://10.120.27.7:8470,'
                        'grpc://10.120.27.8:8470'),
        compat.as_bytes(TPUClusterResolver._gkeEndpoints()))

    resolver = TPUClusterResolver()
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470'),
        compat.as_bytes(resolver.master()))
    actual_cluster_spec = resolver.cluster_spec()
    expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.120.27.5:8470' }
      tasks { key: 1 value: '10.120.27.6:8470' }
      tasks { key: 2 value: '10.120.27.7:8470' }
      tasks { key: 3 value: '10.120.27.8:8470' }
    }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  # Previously this variable was set and never removed, leaking into every
  # later test; patch.dict scopes it to this test only.
  @mock.patch.dict(
      os.environ,
      {'TPU_API_DISCOVERY_URL': 'https://{api}.internal/{apiVersion}'})
  def testEnvironmentDiscoveryUrl(self):
    self.assertEqual('https://{api}.internal/{apiVersion}',
                     TPUClusterResolver._environmentDiscoveryUrl())

  def testEnvironmentAndRpcDetectionForGoogle(self):
    resolver = TPUClusterResolver(
        tpu='/bns/ab/cd/ef')
    self.assertEqual(resolver.environment, 'google')
    self.assertEqual(resolver.rpc_layer, None)

  def testEnvironmentAndRpcDetectionForGrpcString(self):
    resolver = TPUClusterResolver(
        tpu='grpc://10.1.2.3:8470')
    self.assertEqual(resolver.environment, '')
    self.assertEqual(resolver.rpc_layer, 'grpc')
    self.assertEqual(resolver.master(), 'grpc://10.1.2.3:8470')

  def testOverrideTaskTypeAndIndexAndGetMaster(self):
    tpu_map = {
        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
            'health':
                'HEALTHY',
            'networkEndpoints': [
                {
                    'ipAddress': '10.2.3.4',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.5',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.6',
                    'port': 8470,
                },
                {
                    'ipAddress': '10.2.3.7',
                    'port': 8470,
                },
            ]
        }
    }

    resolver = TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu='test-tpu-1',
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_map))

    self.assertEqual(resolver.master(), 'grpc://10.2.3.4:8470')

    resolver.task_type = 'worker'
    resolver.task_id = 3
    self.assertEqual(resolver.master(), 'grpc://10.2.3.7:8470')

    self.assertEqual(
        resolver.master(
            task_type='worker', task_id=2, rpc_layer='test'),
        'test://10.2.3.6:8470')

  def testGetDeviceDictAndCoresWithTPUs(self):
    device_names = [
        '/job:tpu_worker/task:0/device:TPU:0',
        '/job:tpu_worker/task:1/device:TPU:1',
        '/job:tpu_worker/task:2/device:TPU:0',
        '/job:tpu_worker/task:3/device:TPU:1',
        '/job:tpu_worker/task:0/device:TPU:4',
        '/job:tpu_worker/task:1/device:TPU:5',
        '/job:tpu_worker/task:2/device:TPU:4',
        '/job:tpu_worker/task:3/device:TPU:5',
    ]
    device_list = [
        session._DeviceAttributes(
            name, 'TPU', 1024, 0) for name in device_names
    ]

    device_details = TPUClusterResolver._get_device_dict_and_cores(
        device_list)
    self.assertEqual(device_details.total_cores, 8)
    self.assertEqual(device_details.device_map,
                     {'0': ['0', '4'],
                      '1': ['1', '5'],
                      '2': ['0', '4'],
                      '3': ['1', '5']})

  def testGetDeviceDictAndCoresWithCPUsAndGPUs(self):
    device_names = [
        '/job:tpu_worker/task:0/device:CPU:0',
        '/job:tpu_worker/task:1/device:CPU:0',
        '/job:tpu_worker/task:2/device:CPU:0',
        '/job:tpu_worker/task:3/device:CPU:0',
        '/job:tpu_worker/task:0/device:GPU:1',
        '/job:tpu_worker/task:1/device:GPU:1',
        '/job:tpu_worker/task:2/device:GPU:1',
        '/job:tpu_worker/task:3/device:GPU:1',
    ]
    device_list = [
        session._DeviceAttributes(
            name, 'XLA', 1024, 0) for name in device_names
    ]

    # Access fields by name, as in the TPU variant above, rather than relying
    # on positional unpacking order.
    device_details = TPUClusterResolver._get_device_dict_and_cores(
        device_list)
    self.assertEqual(device_details.total_cores, 0)
    self.assertEqual(device_details.device_map, {})

  def testVerifySameCoreCount(self):
    self.assertEqual(
        TPUClusterResolver._verify_and_return_same_core_count(
            {0: [0, 1, 2, 3, 4, 5, 6, 7]}), 8)
    self.assertEqual(
        TPUClusterResolver._verify_and_return_same_core_count(
            {0: [0, 1], 1: [2, 3]}), 2)
    with self.assertRaises(RuntimeError):
      TPUClusterResolver._verify_and_return_same_core_count(
          {0: [0], 1: [1, 2]})

  # Patch decorators inject mocks bottom-up. The `_isRunningInGCE` patch
  # supplies its own replacement and injects nothing, so the first injected
  # argument is the `BaseSession.list_devices` mock and the second is the
  # `eager.context.list_devices` mock.
  @mock.patch.object(eager.context, 'list_devices')
  @mock.patch.object(session.BaseSession, 'list_devices')
  @mock.patch.object(TPUClusterResolver,
                     '_isRunningInGCE',
                     mock_is_not_running_in_gce)
  def testNumAcceleratorsSuccess(self, mock_list_devices,
                                 mock_eager_list_devices):
    device_names = [
        '/job:tpu_worker/task:0/device:TPU:0',
        '/job:tpu_worker/task:1/device:TPU:1',
        '/job:tpu_worker/task:2/device:TPU:0',
        '/job:tpu_worker/task:3/device:TPU:1',
        '/job:tpu_worker/task:0/device:TPU:4',
        '/job:tpu_worker/task:1/device:TPU:5',
        '/job:tpu_worker/task:2/device:TPU:4',
        '/job:tpu_worker/task:3/device:TPU:5',
    ]
    device_list = [
        session._DeviceAttributes(
            name, 'TPU', 1024, 0) for name in device_names
    ]
    mock_eager_list_devices.return_value = device_names
    mock_list_devices.return_value = device_list

    resolver = TPUClusterResolver(tpu='')
    self.assertEqual(resolver.num_accelerators(), {'TPU': 2})

  # Same injection order as testNumAcceleratorsSuccess.
  @mock.patch.object(eager.context, 'list_devices')
  @mock.patch.object(session.BaseSession, 'list_devices')
  @mock.patch.object(TPUClusterResolver,
                     '_isRunningInGCE',
                     mock_is_not_running_in_gce)
  def testNumAcceleratorsRetryFailure(self, mock_list_devices,
                                      mock_eager_list_devices):
    resolver = TPUClusterResolver(tpu='')
    mock_list_devices.side_effect = errors.DeadlineExceededError(
        None, None, 'timeout')
    mock_eager_list_devices.side_effect = errors.DeadlineExceededError(
        None, None, 'timeout')
    with self.assertRaises(RuntimeError):
      resolver.num_accelerators()
663
664
665if __name__ == '__main__':
666  test.main()
667