1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of Cluster Resolvers for Cloud TPUs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import os
23import re
24
25from six.moves import urllib
26from six.moves.urllib.error import URLError
27from six.moves.urllib.request import Request
28from six.moves.urllib.request import urlopen
29
30from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
31from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
32from tensorflow.python.distribute.cluster_resolver.cluster_resolver import get_accelerator_devices
33from tensorflow.python.framework import errors
34from tensorflow.python.platform import tf_logging as logging
35from tensorflow.python.training import server_lib
36from tensorflow.python.util import compat
37from tensorflow.python.util.tf_export import tf_export
38
# The Cloud TPU API client libraries are optional: they are only required when
# a TPU name actually has to be resolved through the Cloud TPU API.
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False
else:
  _GOOGLE_API_CLIENT_INSTALLED = True

# Environment variable holding the TPU endpoints when running under GKE.
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
# Separator between multiple endpoints in a single TPU string.
_ENDPOINTS_SEPARATOR = ','
# Environment variable consulted for the TPU name when none is passed in.
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
# Environment variable that overrides the Cloud TPU API discovery URL.
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'

# Extracts the host (task) index and TPU core index from a device name.
_TPU_DEVICE_REGEX = re.compile(r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
# Number of retries when the TPU master does not answer in time.
_TPU_CONN_RETRIES = 120

# device_map: host_id -> list of core_ids; total_cores: number of TPU cores.
DeviceDetails = collections.namedtuple('DeviceDetails', 'device_map total_cores')
57
58
@tf_export('distribute.cluster_resolver.TPUClusterResolver')
class TPUClusterResolver(ClusterResolver):
  """Cluster Resolver for Google Cloud TPUs.

  This is an implementation of cluster resolvers for the Google Cloud TPU
  service. As Cloud TPUs are in alpha, you will need to specify an API
  definition file for this to consume, in addition to a list of Cloud TPUs in
  your Google Cloud Platform project.
  """

  def _tpuService(self):
    """Returns a Cloud TPU API object.

    If a service object was passed to the constructor, it is returned as-is.
    Otherwise a new API object is built on every call. This works around an
    issue where the underlying HTTP connection sometimes times out when the
    script has been running for too long: other methods in this object call
    this method whenever they need to communicate with the Cloud API.

    Returns:
      A Google Cloud TPU API object.
    """
    if self._service:
      return self._service

    credentials = self._credentials
    if credentials is None or credentials == 'default':
      credentials = GoogleCredentials.get_application_default()

    # Only pass discoveryServiceUrl when a discovery URL is configured.
    build_kwargs = {'credentials': credentials}
    if self._discovery_url:
      build_kwargs['discoveryServiceUrl'] = self._discovery_url
    return discovery.build('tpu', 'v1alpha1', **build_kwargs)

  def _requestComputeMetadata(self, path):
    """Fetches `path` from the GCE metadata server.

    Args:
      path: Path relative to `http://metadata/computeMetadata/v1/`, e.g.
        'project/project-id'.

    Returns:
      The response body as bytes.
    """
    req = Request('http://metadata/computeMetadata/v1/%s' % path,
                  headers={'Metadata-Flavor': 'Google'})
    resp = urlopen(req)
    try:
      return compat.as_bytes(resp.read())
    finally:
      # Release the connection explicitly rather than relying on GC.
      resp.close()

  def _shouldResolve(self):
    """Returns whether the TPU name must be resolved via the Cloud TPU API.

    Names that are empty, 'local', or start with '/bns', 'localhost:',
    'grpc://' or 'uptc://' are used directly and are not resolved. A boolean
    stored in `self._should_resolve_override` takes precedence.
    """
    if isinstance(self._should_resolve_override, bool):
      return self._should_resolve_override
    if self._tpu in (compat.as_bytes(''), compat.as_bytes('local')):
      return False
    direct_prefixes = tuple(
        compat.as_bytes(prefix)
        for prefix in ('/bns', 'localhost:', 'grpc://', 'uptc://'))
    return not self._tpu.startswith(direct_prefixes)

  @staticmethod
  def _get_device_dict_and_cores(devices):
    """Returns a dict of hosts to cores and total cores given devices names.

    Returns a namedtuple with two attributes:
      device_map: A map of host_ids to a list of core_ids.
      total_cores: The total number of cores within the TPU system.

    Args:
      devices: A list of devices returned by session.list_devices()
    """
    device_map = collections.defaultdict(list)
    num_cores = 0
    for device in devices:
      match = _TPU_DEVICE_REGEX.match(device.name)
      if match:
        host_id = match.group('host_id')
        core_id = match.group('core_id')
        device_map[host_id].append(core_id)
        num_cores += 1
    return DeviceDetails(device_map, num_cores)

  @staticmethod
  def _verify_and_return_same_core_count(device_dict):
    """Verifies that every device in device_dict has the same # of cores.

    Args:
      device_dict: A map of host_ids to a list of core_ids.

    Returns:
      The number of cores per host.

    Raises:
      RuntimeError: If hosts report different core counts (or there are no
        hosts at all).
    """
    num_cores_per_host_set = (
        {len(core_ids) for core_ids in device_dict.values()})
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError('TPU cores on each device is not the same. This '
                         'should never happen. Devices: {}'.format(device_dict))
    return num_cores_per_host_set.pop()

  @staticmethod
  def _inGke():
    """When running in GKE, the environment variable will be set."""
    return _GKE_ENV_VARIABLE in os.environ

  @staticmethod
  def _gkeEndpoints():
    """Returns the raw TPU endpoint string provided by GKE."""
    return os.environ[_GKE_ENV_VARIABLE]

  @staticmethod
  def _envVarFallback():
    """Returns the TPU name from the default env variable, or None if unset."""
    return os.environ.get(_DEFAULT_ENV_VARIABLE)

  @staticmethod
  def _environmentDiscoveryUrl():
    """Returns the discovery URL override from the environment, if any."""
    return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)

  @staticmethod
  def _isRunningInGCE():
    """Checks for GCE presence by attempting to query the metadata service.

    Returns:
      True if the metadata server answers with a `Metadata-Flavor: Google`
      header within the 1 second timeout, False otherwise.
    """
    try:
      req = Request('http://metadata.google.internal/computeMetadata/v1',
                    headers={'Metadata-Flavor': 'Google'})
      resp = urllib.request.urlopen(req, timeout=1)
      try:
        info = resp.info()
        if 'Metadata-Flavor' in info and info['Metadata-Flavor'] == 'Google':
          return True
      finally:
        resp.close()
    except (URLError, IOError, OSError):
      # URLError covers DNS/connection failures. A timeout while reading the
      # response can surface as a bare socket.timeout instead, which is an
      # IOError/OSError subclass on both Python 2 and 3. Any such network
      # failure means "not detectably in GCE" rather than a crash.
      pass
    return False

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               job_name='worker',
               coordinator_name=None,
               coordinator_address=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU APIs
    for the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu: A string (or bytes) corresponding to the TPU to use. If it is the
        empty string, the string 'local', or a string that begins with
        'grpc://', 'localhost:', '/bns' or 'uptc://', then it is assumed to not
        correspond with a Cloud TPU: it is not resolved through the Cloud TPU
        API and is used directly as the session master. A single-element list
        is also accepted. In the future, this may also support a list of
        strings when multiple Cloud TPUs are used.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.
      job_name: Name of the TensorFlow job the TPUs belong to.
      coordinator_name: The name to use for the coordinator. Set to None if the
        coordinator should not be included in the computed ClusterSpec.
      coordinator_address: The address of the coordinator (typically an ip:port
        pair). If set to None, a TF server will be started. If coordinator_name
        is None, a TF server will not be started even if coordinator_address is
        None.
      credentials: GCE Credentials. If None or 'default', then we use default
        credentials from the oauth2client.
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.
      discovery_url: A URL template that points to the location of
        the discovery service. It should have two parameters {api} and
        {apiVersion} that when filled in produce an absolute URL to the
        discovery document for that service. The environment variable
        'TPU_API_DISCOVERY_URL' will override this.

    Raises:
      ImportError: If the googleapiclient is not installed.
      ValueError: If no TPUs are specified.
      NotImplementedError: If more than one TPU is specified.
      RuntimeError: If an empty TPU name is specified and this is running in a
        Google Cloud environment.
    """
    if isinstance(tpu, list):
      if not tpu:
        raise ValueError('At least one TPU must be specified.')
      if len(tpu) != 1:
        raise NotImplementedError(
            'Using multiple TPUs in a single session is not yet implemented')
      tpu = tpu[0]

    in_gke = self._inGke()
    # When using GKE with Cloud TPUs, the env variable will be set.
    if tpu is None:
      if in_gke:
        tpu = self._gkeEndpoints()
      else:
        tpu = self._envVarFallback()

    if tpu is None:
      raise ValueError('Please provide a TPU Name to connect to.')

    # The name may arrive as bytes; do all prefix checks on text so they work
    # on Python 3, while self._tpu is always stored as bytes.
    tpu = compat.as_text(tpu)
    self._tpu = compat.as_bytes(tpu)

    # If we are running in Cloud and don't specify a TPU name, fail. The cheap
    # emptiness check runs first so the metadata-server probe (1s timeout) is
    # only made when it can matter.
    if not self._tpu and self._isRunningInGCE():
      raise RuntimeError('You need to specify a TPU Name if you are running in '
                         'the Google Cloud environment.')

    # By default the task_type is 'worker` and the task_id is 0 (which is the
    # first worker in the task).
    self.task_type = job_name
    self.task_id = 0

    if tpu == 'local' or not tpu or tpu.startswith(('/bns', 'uptc://')):
      # Google environment, where the TPU is attached to the host or reached
      # through BNS/uptc. No RPC layer is used.
      self._environment = 'google'
      self.rpc_layer = None
    else:
      # Cloud environment ('grpc://' addresses, 'localhost:' addresses and
      # Cloud TPU names), where we use gRPC to communicate with TPUs.
      self._environment = ''
      self.rpc_layer = 'grpc'

    # Setting this overrides the return value of self._shouldResolve()
    self._should_resolve_override = None

    # We strip out the protocol if it is included, and override the
    # shouldResolve function to never resolve. We are adding the protocol back
    # in later in self.master().
    if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'):
      tpu = tpu[len(self.rpc_layer + '://'):]
      self._tpu = compat.as_bytes(tpu)  # self._tpu is always bytes
      self._should_resolve_override = False

    # Whether we should actually attempt to contact Cloud APIs
    should_resolve = self._shouldResolve()

    # We error out if we are in a non-Cloud environment which cannot talk to the
    # Cloud APIs using the standard class and a special object is not passed in.
    self._service = service
    if (self._service is None and should_resolve and
        not _GOOGLE_API_CLIENT_INSTALLED):
      raise ImportError('googleapiclient and oauth2client must be installed '
                        'before using the TPU cluster resolver. Execute: '
                        '`pip install --upgrade google-api-python-client` '
                        'and `pip install --upgrade oauth2client` to '
                        'install with pip.')

    # We save user-passed credentials, unless the user didn't pass in anything.
    self._credentials = credentials
    if (credentials == 'default' and should_resolve and
        _GOOGLE_API_CLIENT_INSTALLED):
      self._credentials = None

    # Automatically detect project and zone if unspecified.
    if not project and should_resolve:
      project = compat.as_str(
          self._requestComputeMetadata('project/project-id'))
    if not zone and should_resolve:
      zone_path = compat.as_str(self._requestComputeMetadata('instance/zone'))
      zone = zone_path.split('/')[-1]
    self._project = project
    self._zone = zone

    self._discovery_url = self._environmentDiscoveryUrl() or discovery_url

    self._coordinator_name = coordinator_name
    if (coordinator_name and not coordinator_address and
        (should_resolve or in_gke)):
      self._start_local_server()
    else:
      self._coordinator_address = coordinator_address

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Get the Master string to be used for the session.

    In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
    first instance in the ClusterSpec returned by the cluster_spec function.

    If a non-TPU name is used when constructing a TPUClusterResolver, that will
    be returned instead (e.g. If the tpus argument's value when constructing
    this TPUClusterResolver was 'grpc://10.240.1.2:8470',
    'grpc://10.240.1.2:8470' will be returned).

    Args:
      task_type: (Optional, string) The type of the TensorFlow task of the
        master.
      task_id: (Optional, integer) The index of the TensorFlow task of the
        master.
      rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
        communicate with TPUs.

    Returns:
      string, the connection string to use when creating a session.

    Raises:
      ValueError: If none of the TPUs specified exists.
    """
    if self._shouldResolve():
      # We are going to communicate with the Cloud TPU APIs to get a Cluster.
      cluster_spec = self.cluster_spec()
      if task_type is not None and task_id is not None:
        # task_type and task_id is from the function parameter
        master = cluster_spec.task_address(task_type, task_id)
      elif self.task_type is not None and self.task_id is not None:
        # task_type and task_id is from the object
        master = cluster_spec.task_address(self.task_type, self.task_id)
      else:
        # by default we take the first item in the cluster with the right name
        job_tasks = cluster_spec.job_tasks(self.task_type)
        if not job_tasks:
          raise ValueError('No TPUs with the specified names exist.')
        master = job_tasks[0]
    else:
      # self._tpu may hold several comma-separated endpoints; the first one is
      # used as the master. compat.as_text accepts both bytes and text.
      master = compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR)[0]
    return format_master_url(master, rpc_layer or self.rpc_layer)

  def get_master(self):
    """Returns the master address; equivalent to `self.master()`."""
    return self.master()

  def get_job_name(self):
    """Returns the TPU job name, or None.

    Returns:
      `self.task_type` when the TPU name is resolved through the Cloud TPU API
      or when running inside GCE; None otherwise.
    """
    if self._shouldResolve() or self._isRunningInGCE():
      return self.task_type
    return None

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs, or
      None when no RPC layer is used (e.g. 'local', '/bns' or 'uptc://' TPUs).

    Raises:
      RuntimeError: If the provided TPU is not healthy.
    """
    ############################################################################
    # There are 5 potential cases this code must handle:
    #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
    #      a. Create a ClusterSpec that includes the coordinator job
    #      b. Create a ClusterSpec without the coordinator job.
    #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
    #     tasks and
    #      a. Create a ClusterSpec with the coordinator
    #      b. Create a ClusterSpec without the coordinator
    #  3. [Other (legacy non-gRPC).] We should return None.
    ############################################################################

    if self._shouldResolve():
      # Case 1.
      full_name = 'projects/%s/locations/%s/nodes/%s' % (
          self._project, self._zone, compat.as_text(self._tpu))
      service = self._tpuService()
      request = service.projects().locations().nodes().get(name=full_name)
      response = request.execute()

      if 'state' in response and response['state'] != 'READY':
        raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
                           (compat.as_text(self._tpu), response['state']))

      if 'networkEndpoints' in response:
        worker_list = [
            '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
            for endpoint in response['networkEndpoints']
        ]
      else:
        # Fall back to the deprecated response format
        instance_url = '%s:%s' % (response['ipAddress'], response['port'])
        worker_list = [instance_url]

      cluster_spec = {self.task_type: worker_list}
    else:
      if self.rpc_layer is None:
        # Case 3.
        return None
      # Case 2.
      # We are working around the fact that GKE environment variable that is
      # supplied to us has the protocol string embedded in it, but we want
      # to strip it out for the ClusterSpec.
      protocol_prefix = self.rpc_layer + '://'
      tpus = []
      for endpoint in compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR):
        if endpoint.startswith(protocol_prefix):
          endpoint = endpoint[len(protocol_prefix):]
        tpus.append(endpoint)
      cluster_spec = {self.task_type: tpus}

    if self._coordinator_address:
      # {1, 2}.a
      cluster_spec[self._coordinator_name] = [self._coordinator_address]

    return server_lib.ClusterSpec(cluster_spec)

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of TPU cores per worker.

    Connects to the master and list all the devices present in the master,
    and counts them up. Also verifies that the device counts per host in the
    cluster is the same before returning the number of TPU cores per host.

    Args:
      task_type: Unused.
      task_id: Unused.
      config_proto: Used to create a connection to a TPU master in order to
        retrieve the system metadata.

    Returns:
      A dict `{'TPU': n}` where n is the number of TPU cores per host (0 if no
      TPU devices were found).

    Raises:
      RuntimeError: If we cannot talk to a TPU worker after retrying or if the
        number of TPU devices per host is different.
    """
    retry_count = 1
    # TODO(b/120564445): Replace with standard library for retries.
    while True:
      try:
        device_details = TPUClusterResolver._get_device_dict_and_cores(
            get_accelerator_devices(self.master(), config_proto=config_proto))
        break
      except errors.DeadlineExceededError:
        error_message = ('Failed to connect to master. The TPU might not be '
                         'ready (e.g. still scheduling) or the master '
                         'address is incorrect: got (%s)' % self.master())
        if retry_count <= _TPU_CONN_RETRIES:
          logging.warning(error_message)
          logging.warning('Retrying (%d/%d)...', retry_count, _TPU_CONN_RETRIES)
          retry_count += 1
        else:
          raise RuntimeError(error_message)

    if device_details.total_cores:
      return {'TPU': TPUClusterResolver._verify_and_return_same_core_count(
          device_details.device_map)}
    return {'TPU': 0}

  @property
  def environment(self):
    """Returns the current environment which TensorFlow is running in.

    'google' for local, '/bns' and 'uptc://' TPUs; '' (Cloud) otherwise.
    """
    return self._environment

  def _start_local_server(self):
    """Starts an in-process gRPC server and records it as the coordinator.

    The coordinator address is this VM's IP (from the GCE metadata server)
    combined with the port the server bound to.
    """
    address = compat.as_text(self._requestComputeMetadata(
        'instance/network-interfaces/0/ip'))
    self._server = server_lib.Server(
        {
            'local': ['0.0.0.0:0']
        }, protocol='grpc', config=None, start=True)
    # self._server.target is of the form: grpc://ipaddress:port
    target = compat.as_bytes(self._server.target)
    splits = target.split(compat.as_bytes(':'))
    assert len(splits) == 3, self._server.target
    assert splits[0] == compat.as_bytes('grpc'), self._server.target
    self._coordinator_port = compat.as_text(splits[2])
    self._coordinator_address = '%s:%s' % (address, self._coordinator_port)

  def __deepcopy__(self, memo):
    # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
    return self