1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# Lint as: python3
16"""Cloud TPU Client."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import datetime
23import json
24import logging
25import os
26import time
27
28from absl import flags
29from concurrent import futures
30from six.moves.urllib import request
31from six.moves.urllib.error import HTTPError
32
33_GOOGLE_API_CLIENT_INSTALLED = True
34try:
35  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
36  from oauth2client import client  # pylint: disable=g-import-not-at-top
37except ImportError:
38  _GOOGLE_API_CLIENT_INSTALLED = False
39
FLAGS = flags.FLAGS

# Consulted by Client.recoverable(): when enabled, a recent runtime / HBM
# out-of-memory symptom makes the TPU count as unrecoverable.
flags.DEFINE_bool(
    name='runtime_oom_exit',
    default=True,
    help='Exit the script when the TPU runtime is OOM.')
flags.DEFINE_bool(
    name='hbm_oom_exit',
    default=True,
    help='Exit the script when the TPU HBM is OOM.')

# Environment variables consulted for the TPU name / service locations.
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
_GCE_METADATA_URL_ENV_VARIABLE = 'GCE_METADATA_IP'

# Parsing of 'grpc://<ip>[:<port>]' endpoint lists.
_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENDPOINT_PORT = '8470'

# Symptoms younger than this many seconds count as a recent OOM event.
_OOM_EVENT_COOL_TIME_SEC = 90

# Version switcher endpoint on a TPU worker; '{}' is the worker IP address.
_VERSION_SWITCHER_ENDPOINT = 'http://{}:8475/requestversion'
55
56
def _utcnow():
  """Returns the current UTC time as a naive datetime.datetime.

  This thin wrapper around datetime.datetime.utcnow exists so unit tests can
  stub the clock; stubbing methods of the datetime.datetime type directly is
  awkward.

  Returns:
    datetime.datetime: the current UTC time without tzinfo.
  """
  current = datetime.datetime.utcnow()
  return current
67
68
def _environment_discovery_url():
  """Returns the Cloud TPU API discovery URL override from the environment.

  Returns:
    The value of the TPU_API_DISCOVERY_URL environment variable (possibly an
    empty string), or None when the variable is not set.
  """
  env = os.environ
  if _DISCOVERY_SERVICE_URL_ENV_VARIABLE in env:
    return env[_DISCOVERY_SERVICE_URL_ENV_VARIABLE]
  return None
71
72
def _gce_metadata_endpoint():
  """Returns the base URL of the GCE metadata server.

  The host comes from the GCE_METADATA_IP environment variable and falls back
  to 'metadata.google.internal' when it is unset.
  """
  host = os.environ.get(_GCE_METADATA_URL_ENV_VARIABLE,
                        'metadata.google.internal')
  scheme = 'http://'
  return scheme + host
76
77
def _request_compute_metadata(path):
  """Fetches `path` from the GCE metadata server and returns it as text.

  Args:
    path: Path below '/computeMetadata/v1/', e.g. 'project/project-id'.

  Returns:
    The response body decoded as UTF-8 text.

  Errors raised by urlopen (e.g. when not running on GCE) propagate.
  """
  req = request.Request(
      '%s/computeMetadata/v1/%s' % (_gce_metadata_endpoint(), path),
      headers={'Metadata-Flavor': 'Google'})
  resp = request.urlopen(req)
  try:
    return _as_text(resp.read())
  finally:
    # Release the connection deterministically instead of relying on GC.
    resp.close()
84
85
def _environment_var_to_network_endpoints(endpoints):
  """Yields a dict with ip address and port for each endpoint in `endpoints`.

  Args:
    endpoints: A string of endpoints separated by `_ENDPOINTS_SEPARATOR`, each
      of the form '[grpc://]<ip_address>[:<port>]', e.g.
      'grpc://10.0.0.2:8470,grpc://10.0.0.3:8470'. Whitespace around each
      endpoint is ignored, and empty entries (e.g. produced by a trailing
      separator) are skipped.

  Yields:
    A dict with the keys 'ipAddress' and 'port' (both strings). The port is
    `_DEFAULT_ENDPOINT_PORT` when the endpoint does not specify one.
  """
  grpc_prefix = 'grpc://'
  for endpoint in endpoints.split(_ENDPOINTS_SEPARATOR):
    endpoint = endpoint.strip()
    if not endpoint:
      continue
    if endpoint.startswith(grpc_prefix):
      endpoint = endpoint[len(grpc_prefix):]
    parts = endpoint.split(':')
    ip_address = parts[0]
    port = _DEFAULT_ENDPOINT_PORT
    if len(parts) > 1:
      port = parts[1]
    yield {
        'ipAddress': ip_address,
        'port': port
    }
101
102
def _get_tpu_name(tpu):
  """Resolves the TPU name to use.

  A truthy `tpu` argument is returned unchanged. Otherwise the value of the
  first set variable among KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS and TPU_NAME is
  used (even when it is set to an empty string).

  Returns:
    The TPU name, or None when none was given and neither variable is set.
  """
  if not tpu:
    env_vars = (_GKE_ENV_VARIABLE, _DEFAULT_ENV_VARIABLE)
    tpu = next(
        (os.environ[var] for var in env_vars if var in os.environ), None)
  return tpu
111
112
def _as_text(s):
  """Returns `s` as text: bytes are decoded as UTF-8, other values pass through."""
  return s.decode('utf-8') if isinstance(s, bytes) else s
117
118
class Client(object):
  """Client for working with the Cloud TPU API.

  This client is intended to be used for resolving tpu name to ip addresses.

  It's recommended to use this client as a context manager
  (`with Client(...) as c:`) to utilize all functionality.
  """

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Creates a client for a single TPU.

    Args:
      tpu: The TPU name, a 'grpc://' address (or comma-separated list of
        addresses), or a single-element list holding one of these. When
        empty, the name is read from the KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS or
        TPU_NAME environment variables.
      zone: Zone of the TPU. When unset and the Cloud TPU API is used, it is
        looked up from the GCE metadata server.
      project: Project of the TPU. When unset and the Cloud TPU API is used,
        it is looked up from the GCE metadata server.
      credentials: Credentials for the Cloud TPU API; 'default' means the
        application default credentials.
      service: A pre-built Cloud TPU API object to use instead of building
        one on demand.
      discovery_url: Discovery service URL for the Cloud TPU API. The
        TPU_API_DISCOVERY_URL environment variable takes precedence.

    Raises:
      ValueError: If `tpu` is an empty list or no TPU name can be found.
      NotImplementedError: If `tpu` is a list with more than one entry.
    """
    if isinstance(tpu, list):
      if not tpu:
        raise ValueError('At least one TPU must be specified.')
      if len(tpu) != 1:
        raise NotImplementedError(
            'Using multiple TPUs in a single session is not yet implemented')
      tpu = tpu[0]

    tpu = _get_tpu_name(tpu)

    if tpu is None:
      raise ValueError('Please provide a TPU Name to connect to.')

    self._tpu = _as_text(tpu)

    # 'grpc://' addresses are used directly; anything else is a TPU name that
    # has to be resolved through the Cloud TPU API.
    self._use_api = not self._tpu.startswith('grpc://')
    self._service = service

    self._credentials = None
    self._project = None
    self._zone = None
    self._discovery_url = None
    if self._use_api:
      if credentials != 'default':
        self._credentials = credentials
      # Automatically detect project and zone if unspecified.
      if project:
        self._project = _as_text(project)
      else:
        self._project = _request_compute_metadata('project/project-id')
      if zone:
        self._zone = _as_text(zone)
      else:
        # Keep only the last path component of the metadata value.
        zone_path = _request_compute_metadata('instance/zone')
        self._zone = zone_path.split('/')[-1]
      self._discovery_url = _environment_discovery_url() or discovery_url

  def _symptom_msg(self, msg):
    """Return the structured Symptom message."""
    return 'Symptom: ' + msg

  @staticmethod
  def _flag_value(name):
    """Returns the current value of the absl flag `name`.

    `FLAGS.<name>` raises UnparsedFlagAccessError when flags were never parsed
    (e.g. when this client is used from a notebook rather than an absl app).
    Reading through the Flag object works in both cases and yields the flag's
    default value when flags are unparsed.
    """
    return FLAGS[name].value

  @staticmethod
  def _parse_symptom_time(create_time):
    """Parses a symptom `createTime` string into a naive datetime.

    Accepts timestamps with or without fractional seconds and with or without
    a trailing 'Z', e.g. '2020-01-01T00:00:00.123456Z' or
    '2020-01-01T00:00:00Z'. Fractional seconds are dropped. The result is
    compared against `_utcnow()`, i.e. treated as UTC.
    """
    timestamp = create_time.split('.')[0]
    if timestamp.endswith('Z'):
      timestamp = timestamp[:-1]
    return datetime.datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S')

  def _recent_symptom_event(self, symptoms, symptom_type, message):
    """Checks for a `symptom_type` symptom newer than the OOM cool time.

    Symptoms are scanned from last to first. The first matching symptom that
    is less than `_OOM_EVENT_COOL_TIME_SEC` seconds old is logged as a
    warning using `message` and causes True to be returned.

    Args:
      symptoms: List of symptom dicts from the Cloud TPU API, or None.
      symptom_type: The 'symptomType' value to look for.
      message: Warning text with one '{}' placeholder that receives the
        number of seconds since the event.

    Returns:
      True if a recent matching symptom was found, False otherwise.
    """
    if not symptoms:
      return False
    cool_time = datetime.timedelta(seconds=_OOM_EVENT_COOL_TIME_SEC)
    for symptom in reversed(symptoms):
      if symptom.get('symptomType') != symptom_type:
        continue
      time_diff = _utcnow() - self._parse_symptom_time(symptom['createTime'])
      if time_diff < cool_time:
        # Clamp at 0 so clock skew (an event "in the future") does not show a
        # nonsensical elapsed time.
        elapsed = max(0, int(time_diff.total_seconds()))
        logging.warning(self._symptom_msg(message.format(elapsed)))
        return True
    return False

  def _oom_event(self, symptoms):
    """Check if a runtime OOM event is reported."""
    return self._recent_symptom_event(
        symptoms, 'OUT_OF_MEMORY',
        'a recent runtime OOM has occured ~{} seconds ago. The model '
        'script will terminate automatically. To prevent future OOM '
        'events, please consider reducing the model size. To disable this '
        'behavior, set flag --runtime_oom_exit=false when starting the '
        'script.')

  def _hbm_oom_event(self, symptoms):
    """Check if a HBM OOM event is reported."""
    return self._recent_symptom_event(
        symptoms, 'HBM_OUT_OF_MEMORY',
        'a recent HBM OOM has occured ~{} seconds ago. The model '
        'script will terminate automatically. To prevent future HBM OOM '
        'events, please consider reducing the model size. To disable this '
        'behavior, set flag --hbm_oom_exit=false when starting the '
        'script.')

  def _tpu_service(self):
    """Creates a new Cloud TPU API object.

    This works around an issue where the underlying HTTP connection sometimes
    times out when the script has been running for too long. Other methods in
    this object call this method to get a new API object whenever they need
    to communicate with the Cloud API.

    Raises:
      RuntimeError: If the dependent Python packages are missing.

    Returns:
      A Google Cloud TPU API object.
    """
    if self._service:
      return self._service

    if not _GOOGLE_API_CLIENT_INSTALLED:
      raise RuntimeError('Missing runtime dependency on the Google API client. '
                         'Run `pip install cloud-tpu-client` to fix.')

    credentials = self._credentials
    if credentials is None or credentials == 'default':
      credentials = client.GoogleCredentials.get_application_default()

    if self._discovery_url:
      return discovery.build(
          'tpu',
          'v1',
          credentials=credentials,
          discoveryServiceUrl=self._discovery_url,
          cache_discovery=False)
    else:
      return discovery.build(
          'tpu', 'v1', credentials=credentials, cache_discovery=False)

  def _full_name(self):
    """Returns the full Cloud name for this TPU."""
    return 'projects/%s/locations/%s/nodes/%s' % (
        self._project, self._zone, self._tpu)

  def _fetch_cloud_tpu_metadata(self):
    """Returns the TPU metadata object from the TPU Get API call.

    Raises:
      ValueError: If the lookup fails for any reason.
    """
    service = self._tpu_service()
    try:
      r = service.projects().locations().nodes().get(name=self._full_name())
      return r.execute()
    except Exception as e:
      raise ValueError("Could not lookup TPU metadata from name '%s'. Please "
                       'doublecheck the tpu argument in the TPUClusterResolver '
                       'constructor. Exception: %s' % (self._tpu, e))

  def _get_tpu_property(self, key):
    """Returns `key` from the TPU metadata, or None without the Cloud API."""
    if self._use_api:
      metadata = self._fetch_cloud_tpu_metadata()
      return metadata.get(key)

    return None

  def __enter__(self):
    """Enters the context and returns this client (for `with ... as c`)."""
    self._open = True
    return self

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    del type, value, traceback

  def recoverable(self):
    """Returns true if the TPU is in a state where training should eventually resume.

    If false the TPU is in a unrecoverable state and should be recreated.
    That is the case when the TPU is TERMINATED or PREEMPTED, or when a
    recent runtime / HBM OOM symptom is reported and the corresponding
    --runtime_oom_exit / --hbm_oom_exit flag is enabled.
    """
    state = None
    symptoms = None
    if self._use_api:
      # One API call, so state and symptoms come from the same snapshot.
      metadata = self._fetch_cloud_tpu_metadata()
      state = metadata.get('state')
      symptoms = metadata.get('symptoms')
    if state in ('TERMINATED', 'PREEMPTED'):
      return False
    if self._flag_value('runtime_oom_exit') and self._oom_event(symptoms):
      return False
    if self._flag_value('hbm_oom_exit') and self._hbm_oom_event(symptoms):
      return False
    return True

  def symptoms(self):
    """Return Cloud TPU Symptoms of the TPU."""
    return self._get_tpu_property('symptoms')

  def state(self):
    """Return state of the TPU."""
    return self._get_tpu_property('state')

  def health(self):
    """Return health of the TPU."""
    return self._get_tpu_property('health')

  def runtime_version(self):
    """Return runtime version of the TPU.

    Without the Cloud TPU API, the version is queried from the version
    switcher endpoint of the first worker; None is returned if that endpoint
    answers with HTTP 404. Other HTTP errors are re-raised.
    """
    if not self._use_api:
      # Fallback on getting version directly from TPU.
      url = _VERSION_SWITCHER_ENDPOINT.format(
          self.network_endpoints()[0]['ipAddress'])
      req = request.Request(url)
      try:
        resp = request.urlopen(req)
      except HTTPError as e:
        if e.code == 404:
          return None
        raise
      try:
        # Decode first: json.loads only accepts bytes from Python 3.6 on.
        version_details = json.loads(_as_text(resp.read()))
      finally:
        resp.close()
      return version_details.get('currentVersion')
    return self._get_tpu_property('tensorflowVersion')

  def accelerator_type(self):
    """Return accelerator type of the TPU."""
    return self._get_tpu_property('acceleratorType')

  def api_available(self):
    """Return if the Cloud TPU API is available, if not certain features will not work."""
    return self._use_api

  def name(self):
    """Return the name of the tpu, or the ip address if name is not provided."""
    return self._tpu

  def get_local_ip(self):
    """Return the local ip address of the Google Cloud VM the workload is running on."""
    return _request_compute_metadata('instance/network-interfaces/0/ip')

  def network_endpoints(self):
    """Return a list of tpu endpoints.

    Raises:
      RuntimeError: If the Cloud TPU API reports a state other than READY.
    """
    if not self._use_api:
      return list(_environment_var_to_network_endpoints(self._tpu))
    response = self._fetch_cloud_tpu_metadata()

    if response.get('state') != 'READY':
      raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
                         (self._tpu, response.get('state')))
    if 'networkEndpoints' in response:
      return response['networkEndpoints']
    else:
      return [{'ipAddress': response['ipAddress'], 'port': response['port']}]

  def wait_for_healthy(self, timeout_s=1200, interval=30):
    """Wait for TPU to become healthy or raise error if timeout reached.

    NOTE: without the Cloud TPU API (a 'grpc://' TPU), health() always
    returns None, so this waits until the timeout and then raises.

    Args:
      timeout_s (int): The timeout in seconds for waiting TPU to become healthy.
      interval (int): The interval in seconds to poll the TPU for health.

    Raises:
      RuntimeError: If the TPU doesn't become healthy by the timeout.
    """
    timeout = time.time() + timeout_s
    while self.health() != 'HEALTHY':
      logging.warning(
          ('Waiting for TPU "%s" with state "%s" '
           'and health "%s" to become healthy'),
          self.name(), self.state(), self.health())
      # Give up early if the next poll would land past the deadline.
      if time.time() + interval > timeout:
        raise RuntimeError(
            'Timed out waiting for TPU "%s" to become healthy' % self.name())
      time.sleep(interval)

    logging.warning('TPU "%s" is healthy.', self.name())

  def configure_tpu_version(self, version, restart_type='always'):
    """Configure TPU software version.

    Sends a version switch request to every worker in parallel. If there are
    no workers, a warning is logged and nothing is done.

    Args:
      version (string): Version of software to configure the TPU with.
      restart_type (string): Restart behaviour when switching versions,
        defaults to always restart. Options are 'always', 'ifNeeded'.

    Raises:
      Exception: If a worker rejects the request; the first worker failure is
        re-raised.
    """

    def configure_worker(worker):
      """Configure individual TPU worker.

      Args:
        worker: A dict with the field ipAddress where the configure request will
          be sent.
      """
      ip_address = worker['ipAddress']
      url = (_VERSION_SWITCHER_ENDPOINT + '/{}?restartType={}').format(
          ip_address, version, restart_type)
      # An (empty) body makes urllib send a POST request.
      req = request.Request(url, data=b'')
      try:
        resp = request.urlopen(req)
      except HTTPError as e:
        if e.code == 404:
          raise Exception(
              'Tensorflow version {} is not available on Cloud TPU, '
              'try a previous nightly version or refer to '
              'https://cloud.google.com/tpu/docs/release-notes for '
              'the latest official version.'.format(version))
        raise Exception('Failed to configure worker {}'.format(ip_address))
      resp.close()

    workers = self.network_endpoints()
    if not workers:
      # ThreadPoolExecutor rejects max_workers=0; nothing to configure anyway.
      logging.warning('No TPU workers found to configure for TPU "%s".',
                      self.name())
      return

    with futures.ThreadPoolExecutor(max_workers=len(workers)) as executor:
      # executor.map yields return values (None here); consuming the iterator
      # re-raises the first exception raised by a worker.
      for _ in executor.map(configure_worker, workers):
        pass
424