1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ===================================================================
15"""TPU system metadata and associated tooling."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import re
23
24from tensorflow.core.protobuf import config_pb2
25from tensorflow.python.client import session as session_lib
26from tensorflow.python.eager import context
27from tensorflow.python.framework import device as tf_device
28from tensorflow.python.framework import errors
29from tensorflow.python.framework import ops
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.tpu import tpu
32
# Operation timeout for a single attempt to list devices on the master.
_PINGING_MASTER_TIMEOUT_IN_MS = 1 * 60 * 1000  # 1 min
# Maximum number of retries after a deadline error while querying the master.
_RETRY_TIMES = 120
# Operation timeout used while initializing the TPU system for its topology.
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 5 * 60 * 1000  # 5 mins

# Captures (task id, TPU core index) from a fully-qualified TPU device name.
_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
# Captures the device type (e.g. CPU/TPU) from a device name.
_DEVICE_TYPE_REGEX = re.compile(r'.*device:([^:]+).*')

# Job name returned when no cluster definition is available.
_DEFAULT_JOB_NAME = 'tpu_worker'
# Job name ignored when inferring the TPU job from a two-job cluster.
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
# Master addresses for which no job name is specified.
_LOCAL_MASTERS = ('', 'local')

# Holds the TPU configuration discovered for TPUEstimator: core and host
# counts, cores per host, the optional topology and the sorted device list.
_TPUSystemMetadata = collections.namedtuple(
    '_TPUSystemMetadata',
    'num_cores num_hosts num_of_cores_per_host topology devices')
53
54
def _query_tpu_system_metadata(master_address, cluster_def=None,
                               query_topology=False):
  """Discovers the TPU cores and hosts visible to this program.

  When executing eagerly, the locally listed devices are inspected. Otherwise
  a session is opened against `master_address` and its device list is read,
  retrying on `DeadlineExceededError` up to `_RETRY_TIMES` times.

  Args:
    master_address: TensorFlow master to query when not executing eagerly;
      also forwarded to `_obtain_topology` and used in error messages.
    cluster_def: Optional ClusterDef placed into the session config.
    query_topology: If True, also call `_obtain_topology` and store its
      result in the `topology` field.

  Returns:
    A `_TPUSystemMetadata` whose `devices` field is a tuple sorted by
    (job, replica, task, device_type, device_index).

  Raises:
    ValueError: If the master still times out after all retries.
    RuntimeError: If hosts report different numbers of TPU cores, or if
      `query_topology` is set but no TPU core was found.
  """
  if context.executing_eagerly():
    # Wrap eager device names in the same attribute type that
    # `Session.list_devices()` yields, so both modes return alike objects.
    devices = []
    for device_name in context.list_devices():
      type_match = _DEVICE_TYPE_REGEX.match(device_name)
      device_type = type_match.group(1) if type_match else 'CPU'
      devices.append(
          session_lib._DeviceAttributes(device_name, device_type, 0, 0))  # pylint: disable=protected-access
  else:
    # TODO(b/120564445): Replace with standard library for retries.
    retry_count = 1
    while True:
      logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
                   master_address)
      try:
        with ops.Graph().as_default():
          ping_config = get_session_config_with_timeout(
              _PINGING_MASTER_TIMEOUT_IN_MS, cluster_def)
          with session_lib.Session(master_address, config=ping_config) as sess:
            devices = sess.list_devices()
            break
      except errors.DeadlineExceededError:
        msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
               'not be ready (still scheduling) or the Tensorflow master '
               'address is incorrect: got (%s).' % master_address)
        # TODO(xiejw): For local or grpc master we might not need retry logic
        # here.
        if retry_count > _RETRY_TIMES:
          raise ValueError(msg)
        logging.warning('%s', msg)
        logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
        retry_count += 1

  # Group TPU core ids by the task (host) id parsed from each device name.
  cores_by_host = collections.defaultdict(list)
  for device in devices:
    tpu_match = _TPU_DEVICE_REG.match(device.name)
    if tpu_match is not None:
      host_id, core_id = tpu_match.groups()
      cores_by_host[host_id].append(core_id)
  tpu_core_count = sum(len(core_ids) for core_ids in cores_by_host.values())

  num_of_cores_per_host = 0
  if tpu_core_count:
    # Every host is expected to expose the same number of TPU cores.
    distinct_counts = set(len(core_ids) for core_ids in cores_by_host.values())
    if len(distinct_counts) != 1:
      raise RuntimeError(
          'TPU cores on each host is not same. This should not happen!. '
          'devices: {}'.format(devices))
    num_of_cores_per_host, = distinct_counts

  topology = None
  if query_topology:
    if not tpu_core_count:
      raise RuntimeError(
          'Cannot find any TPU cores in the system (master address {}). '
          'This usually means the master address is incorrect or the '
          'TPU worker has some problems. Available devices: {}'.format(
              master_address, devices))
    topology = _obtain_topology(master_address, cluster_def)

  def _device_order(device):
    # Sorting gives downstream users a stable device order, which matters
    # for creating mirrored variables correctly.
    spec = tf_device.DeviceSpec.from_string(device.name)
    return (spec.job, spec.replica, spec.task, spec.device_type,
            spec.device_index)

  metadata = _TPUSystemMetadata(
      num_cores=tpu_core_count,
      num_hosts=len(cores_by_host),
      num_of_cores_per_host=num_of_cores_per_host,
      topology=topology,
      devices=tuple(sorted(devices, key=_device_order)))

  if not tpu_core_count:
    logging.info('Failed to find TPU: %s', metadata)
    return metadata

  logging.info('Found TPU system:')
  logging.info('*** Num TPU Cores: %d', metadata.num_cores)
  logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
  logging.info('*** Num TPU Cores Per Worker: %d',
               metadata.num_of_cores_per_host)
  for device in metadata.devices:
    logging.info('*** Available Device: %s', device)
  return metadata
158
159
def _obtain_topology(master_address, cluster_def):
  """Initializes the TPU system through `master_address` and returns the result.

  Runs `tpu.initialize_system()` in a fresh graph and session whose operation
  timeout is `_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS`. The returned value is
  whatever `sess.run` yields for that op (presumably the serialized TPU
  topology — confirm against `tpu.initialize_system`).

  Args:
    master_address: TensorFlow master to open the session against.
    cluster_def: Optional ClusterDef forwarded into the session config.

  Returns:
    The fetched result of `tpu.initialize_system()`.

  Raises:
    ValueError: If initialization exceeds the operation deadline.
  """
  logging.info('Initializing TPU system (master: %s) to fetch topology '
               'for model parallelism. This might take a while.',
               master_address)
  try:
    with ops.Graph().as_default():
      init_config = get_session_config_with_timeout(
          _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
      with session_lib.Session(master_address, config=init_config) as sess:
        return sess.run(tpu.initialize_system())
  except errors.DeadlineExceededError:
    raise ValueError(
        'Fail to initialize TPU system with master (%s). '
        'Please double check the TPU system is functional.' % master_address)
178
179
def get_session_config_with_timeout(timeout_in_secs, cluster_def):
  """Returns a session `ConfigProto` with an operation timeout and cluster.

  Note: this builds a config, not a session.

  Args:
    timeout_in_secs: Operation timeout in MILLISECONDS, despite the parameter
      name. It is stored unchanged in `ConfigProto.operation_timeout_in_ms`.
      The name is kept for compatibility with callers passing it by keyword.
    cluster_def: A ClusterDef describing the cluster, or None to leave the
      field unset.

  Returns:
    A `config_pb2.ConfigProto` with `operation_timeout_in_ms` and
    `cluster_def` populated from the arguments.
  """
  return config_pb2.ConfigProto(
      operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def)
185
186
def master_job(master, cluster_def):
  """Returns the canonical job name to use to place TPU computations on.

  Args:
    master: A `string` representing the TensorFlow master to use.
    cluster_def: A ClusterDef object describing the TPU cluster.

  Returns:
    A string containing the job name, or None if no job should be specified.

  Raises:
    ValueError: If the user needs to specify a tpu_job_name, because we are
      unable to infer the job name automatically, or if the user-specified job
      names are inappropriate.
  """
  # Masters listed in _LOCAL_MASTERS get no job name.
  if master in _LOCAL_MASTERS:
    return None

  # Without any jobs in the cluster definition, use the default TPU job name.
  if not cluster_def or not cluster_def.job:
    return _DEFAULT_JOB_NAME

  job_names = {job.name for job in cluster_def.job}
  if _DEFAULT_JOB_NAME in job_names:
    # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
    raise ValueError('Currently, tpu_worker is not an allowed job name.')
  if len(job_names) == 1:
    return cluster_def.job[0].name
  if len(job_names) == 2 and _DEFAULT_COORDINATOR_JOB_NAME in job_names:
    # With exactly a coordinator plus one other job, the other job is used.
    job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
    return job_names.pop()
  # TODO(b/67716447): Include more sophisticated heuristics.
  raise ValueError(
      'Could not infer TPU job name. Please specify a tpu_job_name as part '
      'of your TPUConfig.')
224