# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU system metadata and associated tooling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from contextlib import contextmanager
import copy

from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import _tpu_estimator_embedding
from tensorflow.python.tpu import device_assignment as tpu_device_assignment
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib


_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')
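# Maps `TPUConfig.num_cores_per_replica` to the 3-D computation shape used for
# model parallelism; the product of the three dimensions equals the number of
# cores per replica.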
_NUM_CORES_TO_COMPUTATION_SHAPE = {
    1: [1, 1, 1],
    2: [1, 1, 2],
    4: [1, 2, 2],
    8: [2, 2, 2],
    16: [4, 2, 2],
}


class TPUContext(object):
  """A context that holds the current configuration of the TPU computation."""

  def __init__(self,
               internal_ctx,
               input_device=None,
               invocation_index=None,
               call_from_input_fn=True):
    self._internal_ctx = internal_ctx
    self._input_device = input_device
    self._invocation_index = invocation_index
    self._call_from_input_fn = call_from_input_fn

  def current_input_fn_deployment(self):
    """The configuration of the current input_fn invocation.

    The configuration depends on `TPUConfig.per_host_input_for_training`. See
    `TPUConfig` for details.

    This information is only available from within `input_fn`, via the
    `params` dict.

    Returns:
      A tuple of
        1. Device spec string: String, the current CPU host where the
           input_fn is invoked.
        2. Current invocation index: Int, 0-based index of the input_fn
           invocation. See the next item for details.
        3. Total invocation count: Int, the total number of times the
           input_fn is invoked across all CPU hosts. Each invocation is
           passed a new `TPUContext` instance with the current invocation
           index set properly.
        4. Total number of replicas consumed by the current invocation: Int,
           the number of replicas fed by the data returned by the current
           input_fn. For example, for a per-core input pipeline deployment
           without model parallelism, the total invocation count equals the
           number of cores in the system and the number of replicas consumed
           by each invocation is 1. For a per-host v2 input pipeline
           deployment, the total invocation count equals the number of hosts
           in the system and the number of replicas consumed by each
           invocation equals the number of cores per host.

    Raises:
      RuntimeError: If this method is called from model_fn instead of
        input_fn.
    """
    if not self._call_from_input_fn:
      raise RuntimeError('This TPUContext instance must not be called from'
                         ' model_fn.')

    if self._internal_ctx.is_input_sharded_per_core():
      total_invocation_count = (self._internal_ctx.num_hosts
                                * self._internal_ctx.num_of_replicas_per_host)
      replicas_consumed = 1
    elif self._internal_ctx.is_input_broadcast_with_iterators():
      total_invocation_count = 1
      replicas_consumed = self._internal_ctx.num_replicas
    else:
      total_invocation_count = self._internal_ctx.num_hosts
      replicas_consumed = self._internal_ctx.num_of_replicas_per_host
    return (self._input_device, self._invocation_index,
            total_invocation_count, replicas_consumed)

  @property
  def num_replicas(self):
    """The total number of replicas.

    For non-model-parallelism, num_replicas equals the total number of TPU
    cores in the system.

    Returns:
      The number of replicas.
    """
    return self._internal_ctx.num_replicas

  @property
  def num_hosts(self):
    """The number of hosts for the TPU system."""
    return self._internal_ctx.num_hosts

  @property
  def current_host(self):
    """The current host index for the TPU system."""
    return self._invocation_index

  @property
  def num_of_replicas_per_host(self):
    """The number of replicas for each host."""
    if self._internal_ctx.model_parallelism_enabled:
      raise ValueError(
          'num_of_replicas_per_host is not supported for model_parallelism')
    return self._internal_ctx.num_of_replicas_per_host

  @property
  def device_assignment(self):
    """Returns the device_assignment object."""
    if self._call_from_input_fn:
      raise RuntimeError('This TPUContext instance must not be called from'
                         ' input_fn.')
    return self._internal_ctx.device_assignment

  def device_for_replica(self, replica_id):
    """Returns the tuple of (CPU device, device ordinal) for the replica.

    This should be used when the model is fully replicated, i.e. for
    non-model-parallelism.

    Args:
       replica_id: Int, the replica index.

    Returns:
       A tuple of the device spec for the CPU device and the int device
       ordinal.
    """
    # Note: for non-model-parallelism, the mapping could be a random
    # permutation. The order should not matter in most cases, as long as the
    # model is replicated to all cores in the system.
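    # For example, this might return ('/job:tpu_worker/task:0/device:CPU:0', 2),
    # i.e. the replica is fed from host 0 and uses TPU ordinal 2 on that host
    # (values illustrative; the actual job name and mapping depend on the
    # cluster).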
    return self._internal_ctx.device_for_replica(replica_id)

  @property
  def tpu_host_placement_function(self):
    """Returns the TPU host placement function.

    The placement function takes host_id as the input and returns the TF
    device for the corresponding host.
    """

    def _placement_function(host_id):
      """Return the host device given host_id."""
      return self._internal_ctx.tpu_host_placement_function(host_id=host_id)

    return _placement_function


class _InternalTPUContext(object):
  """A context that holds the immutable state of a TPU computation.

  This immutable object holds the TPUEstimator config, train/eval batch size,
  and `TPUEstimator.use_tpu`, and is expected to be passed around. It also
  provides utility functions, based on the current state, to determine other
  information commonly required by TPU computation, such as TPU device names,
  TPU hosts, shard batch size, etc.

  If eval_on_tpu is False, evaluation on TPU is disabled.
  If eval_on_tpu is True but use_tpu is False, a warning is issued and TPU
  execution is disabled for all modes.

  N.B. As `mode` is not immutable state in Estimator, but is essential to
  distinguish between TPU training and evaluation, a common usage of
  _InternalTPUContext with `mode` is as follows:
  ```
  with _ctx.with_mode(mode) as ctx:
    if ctx.is_running_on_cpu():
       ...
  ```
  """

  def __init__(self,
               config,
               train_batch_size,
               eval_batch_size,
               predict_batch_size,
               use_tpu,
               eval_on_tpu=True,
               embedding_config_spec=None):
    self._config = config
    self._train_batch_size = train_batch_size
    self._eval_batch_size = eval_batch_size
    self._predict_batch_size = predict_batch_size
    self._use_tpu = use_tpu
    logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu)
    if not use_tpu and eval_on_tpu:
      logging.warning('eval_on_tpu ignored because use_tpu is False.')

    self._eval_on_tpu = eval_on_tpu
    self._model_parallelism_enabled = (
        use_tpu and config.tpu_config.num_cores_per_replica)
    self._mode = None
    num_cores_per_replica = config.tpu_config.num_cores_per_replica
    if self._model_parallelism_enabled:
      self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
          num_cores_per_replica]
    else:
      self._computation_shape = None
    self._lazy_tpu_system_metadata_dict = {}  # key by master address
    self._lazy_device_assignment_dict = {}  # key by master address
    self._lazy_validation_dict = {}  # key by ModeKeys
    self._embedding_config_spec = embedding_config_spec
    self._lazy_embedding_config_dict = {}  # key by master address

  def _assert_mode(self):
    if self._mode is None:
      raise RuntimeError(
          '`mode` needs to be set via contextmanager `with_mode`.')
    return self._mode

  @contextmanager
  def with_mode(self, mode):
    # NOTE(xiejw): A shallow copy is enough. It shares the lazy dictionaries,
    # such as _lazy_tpu_system_metadata_dict, between the new copy and the
    # original one. Note that all lazy states stored in the _lazy_foo
    # properties are effectively immutable, as they should stay the same for
    # the process lifetime.
    new_ctx = copy.copy(self)
    new_ctx._mode = mode  # pylint: disable=protected-access
    yield new_ctx

  @property
  def mode(self):
    return self._assert_mode()

  def _get_master_address(self):
    mode = self._assert_mode()
    config = self._config
    master = (
        config.master
        if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master)
    return master

  def _get_tpu_system_metadata(self):
    """Gets the (maybe cached) TPU system metadata."""
    master = self._get_master_address()
    tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
    if tpu_system_metadata is not None:
      return tpu_system_metadata

    cluster_def = None
    if (self._config.session_config and
        self._config.session_config.cluster_def.job):
      cluster_def = self._config.session_config.cluster_def

    # pylint: disable=protected-access
    tpu_system_metadata = (
        tpu_system_metadata_lib._query_tpu_system_metadata(
            master,
            cluster_def=cluster_def,
            query_topology=self.model_parallelism_enabled))

    self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
    return tpu_system_metadata

  def _get_device_assignment(self):
    """Gets the (maybe cached) TPU device assignment."""
    master = self._get_master_address()
    device_assignment = self._lazy_device_assignment_dict.get(master)
    if device_assignment is not None:
      return device_assignment

    tpu_system_metadata = self._get_tpu_system_metadata()

    device_assignment = tpu_device_assignment.device_assignment(
        tpu_system_metadata.topology,
        computation_shape=self._computation_shape,
        num_replicas=self.num_replicas)

    logging.info('num_cores_per_replica: %s',
                 str(self._config.tpu_config.num_cores_per_replica))
    logging.info('computation_shape: %s', str(self._computation_shape))
    logging.info('num_replicas: %d', self.num_replicas)
    logging.info('device_assignment.topology.device_coordinates: %s',
                 str(device_assignment.topology.device_coordinates))
    logging.info('device_assignment.core_assignment: %s',
                 str(device_assignment.core_assignment))

    self._lazy_device_assignment_dict[master] = device_assignment
    return device_assignment

  @property
  def embedding_config(self):
    """Returns the embedding config based on the current mode."""
    master = self._get_master_address()
    if master in self._lazy_embedding_config_dict:
      embedding_config = self._lazy_embedding_config_dict[master]
    else:
      embedding_config = None
      if self._use_tpu and self._embedding_config_spec:
        embedding_config = _tpu_estimator_embedding.EmbeddingConfig(
            self._embedding_config_spec, self._train_batch_size,
            self._eval_batch_size, self.num_hosts, self.num_cores, self.config)
        if not embedding_config.has_embedding_tables():
          embedding_config = None
      self._lazy_embedding_config_dict[master] = embedding_config

    if embedding_config is not None:
      mode = self._assert_mode()
      # Dynamically attach tpu_embedding based on mode. This keeps
      # embedding_config immutable while call sites always access the unified
      # API '.tpu_embedding'.
      embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode)
    return embedding_config

  @property
  def model_parallelism_enabled(self):
    return self._model_parallelism_enabled

  @property
  def input_partition_dims(self):
    return self._config.tpu_config.input_partition_dims

  @property
  def device_assignment(self):
    return (self._get_device_assignment()
            if self._model_parallelism_enabled else None)

  @property
  def num_of_cores_per_host(self):
    metadata = self._get_tpu_system_metadata()
    return metadata.num_of_cores_per_host

  @property
  def num_cores(self):
    metadata = self._get_tpu_system_metadata()
    return metadata.num_cores

  @property
  def num_of_replicas_per_host(self):
    """Return the number of replicas per host."""
    if self.model_parallelism_enabled:
      return self.num_replicas // self.num_hosts
    else:
      return self.num_of_cores_per_host

  @property
  def num_replicas(self):
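    # Illustrative example (numbers assumed): on a system reporting 32 TPU
    # cores with TPUConfig.num_cores_per_replica=8, model parallelism yields
    # 32 // 8 == 4 replicas; without model parallelism, num_replicas is 32.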
    num_cores_in_system = self.num_cores

    if self.model_parallelism_enabled:
      num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
      if num_cores_per_replica > num_cores_in_system:
        raise ValueError(
            'The number of cores required by the model parallelism, '
            'specified by TPUConfig.num_cores_per_replica, is larger than '
            'the total number of TPU cores in the system. '
            'num_cores_per_replica: {}, num cores in the system: {}'.format(
                num_cores_per_replica, num_cores_in_system))

      if num_cores_in_system % num_cores_per_replica != 0:
        raise RuntimeError(
            'The number of cores in the system ({}) is not divisible by the '
            'number of cores ({}) required by the model parallelism, '
            'specified by TPUConfig.num_cores_per_replica. This should never '
            'happen!'.format(num_cores_in_system, num_cores_per_replica))

      return num_cores_in_system // num_cores_per_replica
    else:
      return num_cores_in_system

  @property
  def num_hosts(self):
    metadata = self._get_tpu_system_metadata()
    return metadata.num_hosts

  @property
  def config(self):
    return self._config

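  # The three helpers below map `TPUConfig.per_host_input_for_training` (and,
  # for eval/predict, `eval_training_input_configuration`) to the input
  # pipeline deployment: PER_SHARD_V1 shards the input per core during
  # training, PER_HOST_V2 runs one input_fn per host that feeds all replicas
  # on that host, and BROADCAST (or SLICED for eval) uses a single input_fn
  # invocation that feeds all replicas; otherwise the per-host v1 deployment
  # applies.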
  def is_input_sharded_per_core(self):
    """Return true if input_fn is invoked per-core (rather than per-host)."""
    mode = self._assert_mode()
    return (mode == model_fn_lib.ModeKeys.TRAIN and
            (self._config.tpu_config.per_host_input_for_training is
             tpu_config.InputPipelineConfig.PER_SHARD_V1))

  def is_input_per_host_with_iterators(self):
    """Return true if input_fn should be run in the per-host v2 config."""
    return (self._config.tpu_config.per_host_input_for_training is
            tpu_config.InputPipelineConfig.PER_HOST_V2)

  def is_input_broadcast_with_iterators(self):
    """Return true if input_fn should be run in the full-replication config."""
    mode = self._assert_mode()
    return ((self._config.tpu_config.per_host_input_for_training is
             tpu_config.InputPipelineConfig.BROADCAST) or
            (mode != model_fn_lib.ModeKeys.TRAIN and
             self._config.tpu_config.eval_training_input_configuration is
             tpu_config.InputPipelineConfig.SLICED))

  def is_running_on_cpu(self, is_export_mode=False):
    """Determines whether the input_fn and model_fn should be invoked on CPU.

    This API also validates user-provided configuration, such as batch size,
    according to the lazily initialized TPU system metadata.

    Args:
      is_export_mode: Indicates whether the current mode is for exporting the
        model, when mode == PREDICT. Only with this bool can we tell whether
        the user is calling Estimator.predict or Estimator.export_savedmodel,
        which run on TPU and CPU respectively. The parent class Estimator
        does not distinguish these two.

    Returns:
      bool, whether the current input_fn or model_fn should run on CPU.

    Raises:
      ValueError: If any configuration is invalid.
    """

    is_running_on_cpu = self._is_running_on_cpu(is_export_mode)
    if not is_running_on_cpu:
      self._validate_tpu_configuration()
    return is_running_on_cpu

  def _is_running_on_cpu(self, is_export_mode):
    """Determines whether the input_fn and model_fn should be invoked on CPU."""
    mode = self._assert_mode()

    if not self._use_tpu:
      return True

    if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu:
      logging.info('_is_running_on_cpu: eval_on_tpu disabled')
      return True

    if is_export_mode:
      return True

    return False

  @property
  def global_batch_size(self):
    mode = self._assert_mode()
    if mode == model_fn_lib.ModeKeys.TRAIN:
      return self._train_batch_size
    elif mode == model_fn_lib.ModeKeys.EVAL:
      return self._eval_batch_size
    elif mode == model_fn_lib.ModeKeys.PREDICT:
      return self._predict_batch_size
    else:
      return None

  @property
  def batch_size_for_input_fn(self):
    """Returns the shard batch size for `input_fn`."""
    global_batch_size = self.global_batch_size
    if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
      return global_batch_size

    # On TPU
    if self.is_input_sharded_per_core() or (
        self.is_input_per_host_with_iterators()):
      return global_batch_size // self.num_replicas
    else:
      return global_batch_size // self.num_hosts

  @property
  def batch_size_for_model_fn(self):
    """Returns the shard batch size for `model_fn`."""
    global_batch_size = self.global_batch_size

    if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
      return global_batch_size

    # On TPU, the batch is always sharded per replica.
    return global_batch_size // self.num_replicas

  @property
  def master_job(self):
    """Returns the job name to use to place TPU computations on.

    Returns:
      A string containing the job name, or None if no job should be specified.

    Raises:
      ValueError: If the user needs to specify a tpu_job_name, because we are
        unable to infer the job name automatically, or if the user-specified
        job names are inappropriate.
    """
    run_config = self._config
    # If the user specifies the tpu_job_name, use that.
    if run_config.tpu_config.tpu_job_name:
      return run_config.tpu_config.tpu_job_name

    # The tpu job is determined by the run_config. Right now, this method is
    # required as tpu_config is not part of the RunConfig.
    mode = self._assert_mode()
    master = (
        run_config.evaluation_master
        if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
    cluster_def = (run_config.session_config.cluster_def
                   if run_config.session_config else None)

    return tpu_system_metadata_lib.master_job(master, cluster_def)

  @property
  def tpu_host_placement_function(self):
    """Returns the TPU host placement function."""

    master = self.master_job

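    # The returned function maps either a replica_id or a host_id to the CPU
    # device of the corresponding host, e.g.
    # '/job:tpu_worker/task:3/device:CPU:0' (job name illustrative); with no
    # master it falls back to the local '/replica:0/task:0/device:CPU:0'.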
    def _placement_function(_sentinal=None, replica_id=None, host_id=None):  # pylint: disable=invalid-name
      """Return the host device given replica_id or host_id."""
      assert _sentinal is None
      if replica_id is not None and host_id is not None:
        raise RuntimeError(
            'replica_id and host_id can have only one non-None value.')

      if master is None:
        return '/replica:0/task:0/device:CPU:0'
      else:
        if replica_id is not None:
          if self.model_parallelism_enabled:
            return self.device_assignment.host_device(
                replica=replica_id, job=master)
          else:
            host_id = replica_id // self.num_of_cores_per_host

        return '/job:%s/task:%d/device:CPU:0' % (master, host_id)

    return _placement_function

  @property
  def tpu_device_placement_function(self):
    """Returns the TPU device placement function."""
    master = self.master_job
    job_device = '' if master is None else ('/job:%s' % master)

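    # The returned function maps a global core index i to a TPU device string,
    # e.g. with 8 cores per host, i == 10 maps to
    # '/job:tpu_worker/task:1/device:TPU:2' (job name illustrative); with
    # model parallelism the device comes from the device assignment instead.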
    def _placement_function(i):
      if self.model_parallelism_enabled:
        return self.device_assignment.tpu_device(replica=i, job=master)
      else:
        num_of_cores_per_host = self.num_of_cores_per_host
        host_id = i // num_of_cores_per_host
        ordinal_id = i % num_of_cores_per_host
        return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id)

    return _placement_function

  def tpu_ordinal_function(self, host_id):
    """Returns the TPU ordinal function for the given host."""

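    # The returned function maps a shard index within `host_id` to the TPU
    # ordinal the shard's infeed should use: without model parallelism this is
    # simply shard_index_in_host % num_of_cores_per_host; with model
    # parallelism the ordinal is looked up from the device assignment.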
    def _tpu_ordinal_function(shard_index_in_host):
      """Return the TPU ordinal associated with a shard.

      Required because the enqueue ops are placed on CPU.

      Args:
        shard_index_in_host: the shard index

      Returns:
        The ordinal of the TPU device the shard's infeed should be placed on.
      """
      if self.model_parallelism_enabled:
        # We put both enqueue/dequeue ops at tpu.core(0) in each replica.
        replica = self.device_assignment.lookup_replicas(host_id,
                                                         0)[shard_index_in_host]
        return self.device_assignment.tpu_ordinal(replica=replica)
      else:
        return shard_index_in_host % self.num_of_cores_per_host

    return _tpu_ordinal_function

  def _validate_tpu_configuration(self):
    """Validates the configuration based on the TPU system metadata."""
    mode = self._assert_mode()
    if self._lazy_validation_dict.get(mode):
      return

    # All following information is obtained from TPU system metadata.
    num_cores = self.num_cores
    num_replicas = self.num_replicas
    num_hosts = self.num_hosts

    if not num_cores:
      tpu_system_metadata = self._get_tpu_system_metadata()
      raise RuntimeError(
          'Cannot find any TPU cores in the system. Please double check '
          'Tensorflow master address and TPU worker(s). Available devices '
          'are {}.'.format(tpu_system_metadata.devices))

    if self._config.tpu_config.num_shards:
      user_provided_num_replicas = self._config.tpu_config.num_shards
      if user_provided_num_replicas != num_replicas:
        message = (
            'TPUConfig.num_shards is not set correctly. According to the TPU '
            'system metadata for Tensorflow master ({}): num_replicas should '
            'be ({}), got ({}). For non-model-parallelism, num_replicas '
            'should be the total number of TPU cores in the system. For '
            'model-parallelism, the total number of TPU cores should be '
            'num_cores_per_replica * num_replicas. Please set it '
            'accordingly or leave it as `None`.'.format(
                self._get_master_address(), num_replicas,
                user_provided_num_replicas))

        raise ValueError(message)

    if self._config.tpu_config.num_cores_per_replica:
      num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
      num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
      if num_cores_per_replica > num_cores_per_host:
        raise ValueError(
            'The number of cores required by the model parallelism, '
            'specified by TPUConfig.num_cores_per_replica, is larger than '
            'the num_cores_per_host. num_cores_per_replica: {}, '
            'num_cores_per_host: {}'.format(num_cores_per_replica,
                                            num_cores_per_host))

    if mode == model_fn_lib.ModeKeys.TRAIN:
      if (self._train_batch_size % num_replicas != 0 and
          not self.is_input_broadcast_with_iterators()):
        raise ValueError(
            'train batch size {} must be divisible by number of replicas {}'
            .format(self._train_batch_size, num_replicas))

    elif mode == model_fn_lib.ModeKeys.EVAL:
      if self._eval_batch_size is None:
        raise ValueError(
            'eval_batch_size in TPUEstimator constructor cannot be `None` '
            'if .evaluate is running on TPU.')
      if (self._eval_batch_size % num_replicas != 0 and
          not self.is_input_broadcast_with_iterators()):
        raise ValueError(
            'eval batch size {} must be divisible by number of replicas {}'
            .format(self._eval_batch_size, num_replicas))
      if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
        raise ValueError(
            'TPUEstimator.evaluate should be running on a single TPU'
            ' instead of a Pod.')
    else:
      assert mode == model_fn_lib.ModeKeys.PREDICT
      if self._predict_batch_size is None:
        raise ValueError(
            'predict_batch_size in TPUEstimator constructor should not be '
            '`None` if .predict is running on TPU.')
      if (self._predict_batch_size % num_replicas != 0 and
          not self.is_input_broadcast_with_iterators()):
        raise ValueError(
            'predict batch size {} must be divisible by number of replicas {}'
            .format(self._predict_batch_size, num_replicas))
      if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
        raise ValueError(
            'TPUEstimator.predict should be running on a single TPU worker. '
            'Got {} hosts.'.format(num_hosts))

    # Record the state "validated" into the lazy dictionary.
    self._lazy_validation_dict[mode] = True

  def device_for_replica(self, replica_id):
    """Returns the tuple of (CPU device, device ordinal) for the replica.

    This should be used when the model is fully replicated, i.e. for
    non-model-parallelism.

    Args:
       replica_id: Int, the replica index.

    Returns:
       A tuple of the device spec for the CPU device and the int device
       ordinal.
    """
    master = self.master_job

    if self.model_parallelism_enabled:
      return (self.device_assignment.host_device(
          replica=replica_id, job=master),
              self.device_assignment.tpu_ordinal(replica=replica_id))

    job_device = '' if master is None else ('/job:%s' % master)

    num_of_replicas_per_host = self.num_of_replicas_per_host
    host_id = replica_id // num_of_replicas_per_host
    ordinal_id = replica_id % num_of_replicas_per_host

    host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
    return (host_device, ordinal_id)


class _OneCoreTPUContext(_InternalTPUContext):
  """Special _InternalTPUContext for single-core usage."""

  def __init__(self, config, train_batch_size, eval_batch_size,
               predict_batch_size, use_tpu):

    super(_OneCoreTPUContext, self).__init__(
        config, train_batch_size, eval_batch_size,
        predict_batch_size, use_tpu)

  def _get_tpu_system_metadata(self):
    """Gets the (maybe cached) TPU system metadata."""
    master = self._get_master_address()
    tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
    if tpu_system_metadata is not None:
      return tpu_system_metadata

    tpu_system_metadata = (
        tpu_system_metadata_lib._TPUSystemMetadata(  # pylint: disable=protected-access
            num_cores=1,
            num_hosts=1,
            num_of_cores_per_host=1,
            topology=None,
            devices=[]))

    self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
    return tpu_system_metadata


def _get_tpu_context(config, train_batch_size, eval_batch_size,
                     predict_batch_size, use_tpu, eval_on_tpu,
                     embedding_config_spec):
  """Returns an instance of `_InternalTPUContext`."""

  if (config.tpu_config.num_shards == 1 and
      config.tpu_config.num_cores_per_replica is None):
    if embedding_config_spec is not None:
      raise ValueError('Setting TPUConfig.num_shards==1 is unsupported '
                       'when embedding_config_spec is not None.')
    logging.warning(
        'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
        'Please fix as soon as possible (leaving num_shards as None).')
    return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
                              predict_batch_size, use_tpu)

  return _InternalTPUContext(config, train_batch_size, eval_batch_size,
                             predict_batch_size, use_tpu, eval_on_tpu,
                             embedding_config_spec)
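

# Illustrative usage sketch (names and values assumed, not part of this
# module's API): a TPUEstimator-like caller builds a context from its
# constructor arguments and then scopes it to a mode before querying it,
# roughly:
#
#   ctx = _get_tpu_context(
#       config=run_config, train_batch_size=1024, eval_batch_size=128,
#       predict_batch_size=None, use_tpu=True, eval_on_tpu=True,
#       embedding_config_spec=None)
#   with ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as train_ctx:
#     if not train_ctx.is_running_on_cpu():
#       per_replica_batch = train_ctx.batch_size_for_model_fn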