1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of Cluster Resolvers for Cloud TPUs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import re
23
24from tensorflow.python.distribute.cluster_resolver import cluster_resolver
25from tensorflow.python.framework import config as framework_config
26from tensorflow.python.framework import errors
27from tensorflow.python.platform import tf_logging as logging
28from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
29from tensorflow.python.training import server_lib
30from tensorflow.python.util import compat
31
32try:
33  from cloud_tpu_client import client  # pylint: disable=g-import-not-at-top
34except ImportError:
35  logging.debug(
36      'Falling back to TensorFlow client; we recommended you install the Cloud '
37      'TPU client directly with pip install cloud-tpu-client.')
38  from tensorflow.python.tpu.client import client  # pylint: disable=g-import-not-at-top
39
40
def is_running_in_gce():
  """Reports whether this process is treated as running on GCE.

  This implementation does not probe the environment; it answers yes
  unconditionally.

  Returns:
    Always `True`.
  """
  in_gce = True
  return in_gce
43
44
45class _LocalCloudTpuClient(object):
46  """Dummy local Cloud TPU client."""
47
48  def api_available(self):
49    return False
50
51
52_TPU_DEVICE_REGEX = re.compile(
53    r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
54_TPU_CONN_RETRIES = 120
55DeviceDetails = collections.namedtuple(
56    'DeviceDetails', ['device_map', 'total_cores'])
57
58
class TPUClusterResolver(cluster_resolver.ClusterResolver):
  """Cluster Resolver for Google Cloud TPUs.

  This is an implementation of cluster resolvers for the Google Cloud TPU
  service.

  TPUClusterResolver supports the following distinct environments:
  Google Compute Engine
  Google Kubernetes Engine
  Google internal

  Passing `tpu='local'` selects a TPU directly attached to the VM. In that mode
  the Cloud TPU client is replaced by a local stub, `master()` returns '' and
  `cluster_spec()` returns an empty ClusterSpec.

  It can be passed into `tf.distribute.TPUStrategy` to support TF2 training on
  Cloud TPUs.
  """

  @staticmethod
  def connect(tpu=None,
              zone=None,
              project=None):
    """Initializes TPU and returns a TPUClusterResolver.

    This API will connect to the remote TPU cluster and initialize the TPU
    hardware. Example usage:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
    ...     tpu='')

    It can be viewed as a convenient wrapper of the following code:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    >>> tf.config.experimental_connect_to_cluster(resolver)
    >>> tf.tpu.experimental.initialize_tpu_system(resolver)

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.

    Returns:
      An instance of TPUClusterResolver object.

    Raises:
      NotFoundError: If no TPU devices found in eager mode.
    """
    resolver = TPUClusterResolver(tpu, zone, project)
    # Imported lazily, presumably to avoid import cycles at module load time.
    from tensorflow.python.eager import remote  # pylint: disable=g-import-not-at-top
    remote.connect_to_cluster(resolver)
    from tensorflow.python.tpu import tpu_strategy_util  # pylint: disable=g-import-not-at-top
    tpu_strategy_util.initialize_tpu_system(resolver)
    return resolver

  @staticmethod
  def _get_device_dict_and_cores(devices):
    """Returns a dict of hosts to cores and total cores given devices names.

    Only devices whose `name` matches `_TPU_DEVICE_REGEX` are counted. Host and
    core ids are kept as the strings captured by the regex.

    Args:
      devices: A list of devices returned by session.list_devices()

    Returns:
      A `DeviceDetails` namedtuple with two attributes:
        device_map: A map of host_ids to a list of core_ids.
        total_cores: The total number of cores within the TPU system.
    """
    device_map = collections.defaultdict(list)
    num_cores = 0
    for device in devices:
      match = _TPU_DEVICE_REGEX.match(device.name)
      if match:
        host_id = match.group('host_id')
        core_id = match.group('core_id')
        device_map[host_id].append(core_id)
        num_cores += 1
    return DeviceDetails(device_map, num_cores)

  @staticmethod
  def _verify_and_return_same_core_count(device_dict):
    """Verifies that every device in device_dict has the same # of cores.

    Args:
      device_dict: A map of host_id to a list of core_ids.

    Returns:
      The common number of cores per host.

    Raises:
      RuntimeError: If hosts report different core counts (or the map is
        empty).
    """
    num_cores_per_host_set = (
        {len(core_ids) for core_ids in device_dict.values()})
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError('TPU cores on each device is not the same. This '
                         'should never happen. Devices: {}'.format(device_dict))
    return num_cores_per_host_set.pop()

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               job_name='worker',
               coordinator_name=None,
               coordinator_address=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU APIs
    for the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs. If set to "local", it will assume
        that the TPU is directly connected to the VM instead of over the
        network.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.
      job_name: Name of the TensorFlow job the TPUs belong to.
      coordinator_name: The name to use for the coordinator. Set to None if the
        coordinator should not be included in the computed ClusterSpec.
      coordinator_address: The address of the coordinator (typically an ip:port
        pair). If set to None, a TF server will be started. If coordinator_name
        is None, a TF server will not be started even if coordinator_address is
        None.
      credentials: GCE Credentials. If None, then we use default credentials
        from the oauth2client
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.
      discovery_url: A URL template that points to the location of the discovery
        service. It should have two parameters {api} and {apiVersion} that when
        filled in produce an absolute URL to the discovery document for that
        service. The environment variable 'TPU_API_DISCOVERY_URL' will override
        this.

    Raises:
      ImportError: If the googleapiclient is not installed.
      ValueError: If no TPUs are specified.
      RuntimeError: If an empty TPU name is specified and this is running in a
        Google Cloud environment.
    """

    if tpu != 'local':
      # Default Cloud environment
      self._cloud_tpu_client = client.Client(
          tpu=tpu,
          zone=zone,
          project=project,
          credentials=credentials,
          service=service,
          discovery_url=discovery_url)
      self._tpu = self._cloud_tpu_client.name()
    else:
      # Directly connected TPU environment
      self._cloud_tpu_client = _LocalCloudTpuClient()
      self._tpu = 'local'

    # By default the task_type is 'worker' and the task_id is 0 (which is the
    # first worker in the task).
    self.task_type = job_name
    self.task_id = 0
    # Backing field for the `environment` property. Nothing in this class
    # detects a special environment, so it stays empty. Without this
    # assignment, reading `environment` raised AttributeError.
    self._environment = ''
    self._coordinator_name = coordinator_name
    if (coordinator_name and not coordinator_address):
      self._start_local_server()
    else:
      self._coordinator_address = coordinator_address

  def __enter__(self):
    """Enters the Cloud TPU client's context and returns this resolver.

    Returning `self` lets `with TPUClusterResolver(...) as resolver:` bind the
    resolver instead of None.
    """
    self._cloud_tpu_client.enter()
    return self

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    """Forwards context exit to the Cloud TPU client; never suppresses errors."""
    self._cloud_tpu_client.exit(type, value, traceback)

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Get the Master string to be used for the session.

    In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
    first instance in the ClusterSpec returned by the cluster_spec function.

    If a non-TPU name is used when constructing a TPUClusterResolver, that will
    be returned instead (e.g. If the `tpu` argument's value when constructing
    this TPUClusterResolver was 'grpc://10.240.1.2:8470',
    'grpc://10.240.1.2:8470' will be returned).

    For a local TPU (`tpu='local'`) this returns the empty string.

    Args:
      task_type: (Optional, string) The type of the TensorFlow task of the
        master.
      task_id: (Optional, integer) The index of the TensorFlow task of the
        master.
      rpc_layer: (Optional, string) Currently ignored: the returned URL always
        uses the 'grpc' scheme.

    Returns:
      string, the connection string to use when creating a session.

    Raises:
      ValueError: If none of the TPUs specified exists.
    """

    if self._tpu != 'local':
      cluster_spec = self.cluster_spec()
      if task_type is not None and task_id is not None:
        # task_type and task_id is from the function parameter
        master = cluster_spec.task_address(task_type, task_id)
      elif self.task_type is not None and self.task_id is not None:
        # task_type and task_id is from the object
        master = cluster_spec.task_address(self.task_type, self.task_id)
      else:
        # by default we take the first item in the cluster with the right name
        job_tasks = cluster_spec.job_tasks(self.task_type)
        if not job_tasks:
          raise ValueError('No TPUs with the specified names exist.')
        master = job_tasks[0]
      return cluster_resolver.format_master_url(master, 'grpc')
    else:
      return ''

  def get_master(self):
    """Returns `self.master()` with default arguments."""
    return self.master()

  def get_job_name(self):
    """Returns the job name (the current `task_type`)."""
    return self.task_type

  def get_tpu_system_metadata(self):
    """Returns the metadata of the TPU system.

    Users can call this method to get some facts of the TPU system, like
    total number of cores, number of TPU workers and the devices. E.g.
    ```python

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tpu_system_metadata = resolver.get_tpu_system_metadata()
    num_hosts = tpu_system_metadata.num_hosts
    ```

    Returns:
      A `tf.tpu.experimental.TPUSystemMetadata` object.
    """
    cluster_spec = self.cluster_spec()
    # An empty ClusterSpec (the local case) is falsy, so no cluster_def is sent.
    cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
    tpu_system_metadata = (
        tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
            self.master(),
            cluster_def=cluster_def,
            query_topology=False))

    return tpu_system_metadata

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs. For
      a local TPU (`tpu='local'`) an empty ClusterSpec is returned.

    Raises:
      RuntimeError: If the provided TPU is not healthy.
    """
    ############################################################################
    # There are 5 potential cases this code must handle:
    #  0. [Local case.] When a TPU is connected directly to the VM.
    #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
    #      a. Create a ClusterSpec that includes the coordinator job
    #      b. Create a ClusterSpec without the coordinator job.
    #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
    #     tasks and
    #      a. Create a ClusterSpec with the coordinator
    #      b. Create a ClusterSpec without the coordinator
    ############################################################################

    if self._tpu != 'local':
      network_endpoints = self._cloud_tpu_client.network_endpoints()
      worker_list = [
          '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
          for endpoint in network_endpoints
      ]
      cluster_spec = {self.task_type: worker_list}
      if self._coordinator_address:
        # {1, 2}.a
        cluster_spec[self._coordinator_name] = [self._coordinator_address]
      return server_lib.ClusterSpec(cluster_spec)
    else:
      return server_lib.ClusterSpec({})

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of TPU cores per worker.

    Connects to the master and list all the devices present in the master,
    and counts them up. Also verifies that the device counts per host in the
    cluster is the same before returning the number of TPU cores per host.

    For a local TPU, the locally visible logical TPU devices are counted
    instead.

    Args:
      task_type: Unused.
      task_id: Unused.
      config_proto: Used to create a connection to a TPU master in order to
        retrieve the system metadata.

    Returns:
      A dict `{'TPU': n}` where `n` is the number of TPU cores per host, or 0
      if no TPU devices were found.

    Raises:
      RuntimeError: If we cannot talk to a TPU worker after retrying or if the
        number of TPU devices per host is different.
    """
    if self._tpu == 'local':
      return {
          'TPU':
              len([
                  d for d in framework_config.list_logical_devices()
                  if d.device_type == 'TPU'
              ])
      }

    retry_count = 1
    # TODO(b/120564445): Replace with standard library for retries.
    while True:
      try:
        device_details = TPUClusterResolver._get_device_dict_and_cores(
            cluster_resolver.get_accelerator_devices(
                self.master(), config_proto=config_proto))
        break
      except errors.DeadlineExceededError:
        error_message = ('Failed to connect to master. The TPU might not be '
                         'ready (e.g. still scheduling) or the master '
                         'address is incorrect: got (%s)' % self.master())
        if retry_count <= _TPU_CONN_RETRIES:
          logging.warning(error_message)
          logging.warning('Retrying (%d/%d)...', retry_count, _TPU_CONN_RETRIES)
          retry_count += 1
        else:
          raise RuntimeError(error_message)

    if device_details.total_cores:
      return {
          'TPU':
              TPUClusterResolver._verify_and_return_same_core_count(
                  device_details.device_map)
      }
    return {'TPU': 0}

  @property
  def environment(self):
    """Returns the current environment which TensorFlow is running in."""
    return self._environment

  def _start_local_server(self):
    """Starts an in-process gRPC TF server and uses it as the coordinator.

    Sets `self._server`, `self._coordinator_port` and
    `self._coordinator_address` ('<local ip>:<port>').
    """
    # NOTE(review): _LocalCloudTpuClient has no get_local_ip(), so this path
    # presumably only works for the Cloud client (tpu != 'local').
    address = compat.as_text(self._cloud_tpu_client.get_local_ip())
    self._server = server_lib.Server({'local': ['0.0.0.0:0']},
                                     protocol='grpc',
                                     config=None,
                                     start=True)
    # self._server.target is of the form: grpc://ipaddress:port
    target = compat.as_bytes(self._server.target)
    splits = target.split(compat.as_bytes(':'))
    assert len(splits) == 3, self._server.target
    assert splits[0] == compat.as_bytes('grpc'), self._server.target
    self._coordinator_port = compat.as_text(splits[2])
    self._coordinator_address = '%s:%s' % (address, self._coordinator_port)

  def __deepcopy__(self, memo):
    # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
    # Deep copies share the same resolver (and its client/server) instance.
    return self
423