1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# =================================================================== 15"""TPU system metadata and associated tooling.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from contextlib import contextmanager 22import copy 23 24from tensorflow.python.estimator import model_fn as model_fn_lib 25from tensorflow.python.platform import tf_logging as logging 26from tensorflow.python.tpu import _tpu_estimator_embedding 27from tensorflow.python.tpu import device_assignment as tpu_device_assignment 28from tensorflow.python.tpu import tpu_config 29from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib 30 31 32_DEFAULT_JOB_NAME = 'tpu_worker' 33_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' 34_LOCAL_MASTERS = ('', 'local') 35_NUM_CORES_TO_COMPUTATION_SHAPE = { 36 1: [1, 1, 1], 37 2: [1, 1, 2], 38 4: [1, 2, 2], 39 8: [2, 2, 2], 40 16: [4, 2, 2], 41} 42 43 44class TPUContext(object): 45 """A context that holds the current configuration of the TPU computation.""" 46 47 def __init__(self, 48 internal_ctx, 49 input_device=None, 50 invocation_index=None, 51 call_from_input_fn=True): 52 self._internal_ctx = internal_ctx 53 self._input_device = input_device 54 self._invocation_index = invocation_index 55 self._call_from_input_fn = call_from_input_fn 56 57 def current_input_fn_deployment(self): 58 """The configuration of the current input_fn invocation. 59 60 The configuration depends on `TPUConfig.per_host_input_for_training`. See 61 `TPUConfig` for details. 62 63 Only set in params dict of input_fn 64 65 Returns: 66 A tuple of 67 1. Device spec string: String, is the current CPU host where the 68 input_fn is invoked. 69 2. Current invocation index: Int, 0-based index of the input_fn 70 invocation. See next item for details. 71 3. Total invocation count: Int, the total number of times to invoke the 72 input_fn on all CPU hosts. Each invocation will be passed with a new 73 `TPUContext` instance with current invocation index set properly. 74 4. Total number of replicas consumed by current_invocation: Int, the 75 number of replicas fed by the data returned by current input_fn. For 76 example, for per_core input pipeline deployment 77 and non-model-parallelism, total invocation count is equal to 78 the number of cores in the system and num replicas consumed by 79 current invocation is 1. For per-host v2 input pipeline deployment, 80 total invocation count is equal to the number of hosts in the system 81 and num replicas consumed by current invocation is equal to number of 82 cores per host. 83 84 Raises: 85 RuntimeError: If this method must not be called from input_fn. 86 """ 87 if not self._call_from_input_fn: 88 raise RuntimeError('This TPUContext instance must not be called from' 89 ' model_fn.') 90 91 if self._internal_ctx.is_input_sharded_per_core(): 92 total_invocation_count = (self._internal_ctx.num_hosts 93 * self._internal_ctx.num_of_replicas_per_host) 94 replicas_consumed = 1 95 elif self._internal_ctx.is_input_broadcast_with_iterators(): 96 total_invocation_count = 1 97 replicas_consumed = self._internal_ctx.num_replicas 98 else: 99 total_invocation_count = self._internal_ctx.num_hosts 100 replicas_consumed = self._internal_ctx.num_of_replicas_per_host 101 return (self._input_device, self._invocation_index, 102 total_invocation_count, replicas_consumed) 103 104 @property 105 def num_replicas(self): 106 """The total number of replicas. 107 108 For non-model-parallelism, num_replicas should be the total num of TPU 109 cores in the system. 110 111 Returns: 112 The number of replicas. 113 """ 114 return self._internal_ctx.num_replicas 115 116 @property 117 def num_hosts(self): 118 """The number of hosts for the TPU system.""" 119 return self._internal_ctx.num_hosts 120 121 @property 122 def current_host(self): 123 """The current host index for the TPU system.""" 124 return self._invocation_index 125 126 @property 127 def num_of_replicas_per_host(self): 128 """The number of replicas for each host.""" 129 if self._internal_ctx.model_parallelism_enabled: 130 raise ValueError( 131 'num_of_replicas_per_host is not supported for model_parallelism') 132 return self._internal_ctx.num_of_replicas_per_host 133 134 @property 135 def device_assignment(self): 136 """Returns device_assignment object.""" 137 if self._call_from_input_fn: 138 raise RuntimeError('This TPUContext instance must not be called from' 139 ' input_fn.') 140 return self._internal_ctx.device_assignment 141 142 def device_for_replica(self, replica_id): 143 """Returns the tuple of (CPU device and device ordinal) for replica. 144 145 This should be used for full replicate for non-model-parallelism. 146 147 Args: 148 replica_id: Int, the replica index. 149 150 Returns: 151 A tuple of device spec for CPU device and int device ordinal. 152 """ 153 # Note that: For the non-model parallelism, the mapping could be 154 # a random permutation. The order should not matter in most cases 155 # as far as model is replicated to all cores in the system. 156 return self._internal_ctx.device_for_replica(replica_id) 157 158 @property 159 def tpu_host_placement_function(self): 160 """Returns the TPU host place function. 161 162 The place function takes host_id as the input and returns the TF device 163 for the correspoding host. 164 """ 165 166 def _placement_function(host_id): 167 """Return the host device given host_id.""" 168 return self._internal_ctx.tpu_host_placement_function(host_id=host_id) 169 170 return _placement_function 171 172 173class _InternalTPUContext(object): 174 """A context holds immutable states of TPU computation. 175 176 This immutable object holds TPUEstimator config, train/eval batch size, and 177 `TPUEstimator.use_tpu`, which is expected to be passed around. It also 178 provides utility functions, based on the current state, to determine other 179 information commonly required by TPU computation, such as TPU device names, 180 TPU hosts, shard batch size, etc. 181 182 if eval_on_tpu is False, then execution of eval on TPU is disabled. 183 if eval_on_tpu is True, but use_tpu is False, a warning is issued, 184 and TPU execution is disabled for all modes. 185 186 N.B. As `mode` is not immutable state in Estimator, but essential to 187 distinguish between TPU training and evaluation, a common usage for 188 _InternalTPUContext with `mode` is as follows: 189 ``` 190 with _ctx.with_mode(mode) as ctx: 191 if ctx.is_running_on_cpu(): 192 ... 193 ``` 194 """ 195 196 def __init__(self, 197 config, 198 train_batch_size, 199 eval_batch_size, 200 predict_batch_size, 201 use_tpu, 202 eval_on_tpu=True, 203 embedding_config_spec=None): 204 self._config = config 205 self._train_batch_size = train_batch_size 206 self._eval_batch_size = eval_batch_size 207 self._predict_batch_size = predict_batch_size 208 self._use_tpu = use_tpu 209 logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu) 210 if not use_tpu and eval_on_tpu: 211 logging.warning('eval_on_tpu ignored because use_tpu is False.') 212 213 self._eval_on_tpu = eval_on_tpu 214 self._model_parallelism_enabled = ( 215 use_tpu and config.tpu_config.num_cores_per_replica) 216 self._mode = None 217 num_cores_per_replica = config.tpu_config.num_cores_per_replica 218 if self._model_parallelism_enabled: 219 self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[ 220 num_cores_per_replica] 221 else: 222 self._computation_shape = None 223 self._lazy_tpu_system_metadata_dict = {} # key by master address 224 self._lazy_device_assignment_dict = {} # key by master address 225 self._lazy_validation_dict = {} # key by ModeKeys 226 self._embedding_config_spec = embedding_config_spec 227 self._lazy_embedding_config_dict = {} # key by master address 228 229 def _assert_mode(self): 230 if self._mode is None: 231 raise RuntimeError( 232 '`mode` needs to be set via contextmanager `with_mode`.') 233 return self._mode 234 235 @contextmanager 236 def with_mode(self, mode): 237 # NOTE(xiejw): Shallow copy is enough. It will share he lazy dictionaries, 238 # such as _lazy_tpu_system_metadata_dict between new copy and the original 239 # one. Note that all lazy states stored in properties _lazy_foo are sort of 240 # immutable as they should be same for the process lifetime. 241 new_ctx = copy.copy(self) 242 new_ctx._mode = mode # pylint: disable=protected-access 243 yield new_ctx 244 245 @property 246 def mode(self): 247 return self._assert_mode() 248 249 def _get_master_address(self): 250 mode = self._assert_mode() 251 config = self._config 252 master = ( 253 config.master 254 if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master) 255 return master 256 257 def _get_tpu_system_metadata(self): 258 """Gets the (maybe cached) TPU system metadata.""" 259 master = self._get_master_address() 260 tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) 261 if tpu_system_metadata is not None: 262 return tpu_system_metadata 263 264 cluster_def = None 265 if (self._config.session_config and 266 self._config.session_config.cluster_def.job): 267 cluster_def = self._config.session_config.cluster_def 268 269 # pylint: disable=protected-access 270 tpu_system_metadata = ( 271 tpu_system_metadata_lib._query_tpu_system_metadata( 272 master, 273 cluster_def=cluster_def, 274 query_topology=self.model_parallelism_enabled)) 275 276 self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata 277 return tpu_system_metadata 278 279 def _get_device_assignment(self): 280 """Gets the (maybe cached) TPU device assignment.""" 281 master = self._get_master_address() 282 device_assignment = self._lazy_device_assignment_dict.get(master) 283 if device_assignment is not None: 284 return device_assignment 285 286 tpu_system_metadata = self._get_tpu_system_metadata() 287 288 device_assignment = tpu_device_assignment.device_assignment( 289 tpu_system_metadata.topology, 290 computation_shape=self._computation_shape, 291 num_replicas=self.num_replicas) 292 293 logging.info('num_cores_per_replica: %s', 294 str(self._config.tpu_config.num_cores_per_replica)) 295 logging.info('computation_shape: %s', str(self._computation_shape)) 296 logging.info('num_replicas: %d', self.num_replicas) 297 logging.info('device_assignment.topology.device_coordinates: %s', 298 str(device_assignment.topology.device_coordinates)) 299 logging.info('device_assignment.core_assignment: %s', 300 str(device_assignment.core_assignment)) 301 302 self._lazy_device_assignment_dict[master] = device_assignment 303 return device_assignment 304 305 @property 306 def embedding_config(self): 307 """Returns the embedding config based on current mode.""" 308 master = self._get_master_address() 309 if master in self._lazy_embedding_config_dict: 310 embedding_config = self._lazy_embedding_config_dict[master] 311 else: 312 embedding_config = None 313 if self._use_tpu and self._embedding_config_spec: 314 embedding_config = _tpu_estimator_embedding.EmbeddingConfig( 315 self._embedding_config_spec, self._train_batch_size, 316 self._eval_batch_size, self.num_hosts, self.num_cores, self.config) 317 if not embedding_config.has_embedding_tables(): 318 embedding_config = None 319 self._lazy_embedding_config_dict[master] = embedding_config 320 321 if embedding_config is not None: 322 mode = self._assert_mode() 323 # Dynamically attach tpu_embedding based on mode. With 324 # this, we could keep embedding_config immutable but call site always 325 # accesses the unified API '.tpu_embedding'. 326 embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode) 327 return embedding_config 328 329 @property 330 def model_parallelism_enabled(self): 331 return self._model_parallelism_enabled 332 333 @property 334 def input_partition_dims(self): 335 return self._config.tpu_config.input_partition_dims 336 337 @property 338 def device_assignment(self): 339 return (self._get_device_assignment() 340 if self._model_parallelism_enabled else None) 341 342 @property 343 def num_of_cores_per_host(self): 344 metadata = self._get_tpu_system_metadata() 345 return metadata.num_of_cores_per_host 346 347 @property 348 def num_cores(self): 349 metadata = self._get_tpu_system_metadata() 350 return metadata.num_cores 351 352 @property 353 def num_of_replicas_per_host(self): 354 """Return the number of replicas per host.""" 355 if self.model_parallelism_enabled: 356 return self.num_replicas // self.num_hosts 357 else: 358 return self.num_of_cores_per_host 359 360 @property 361 def num_replicas(self): 362 num_cores_in_system = self.num_cores 363 364 if self.model_parallelism_enabled: 365 num_cores_per_replica = self._config.tpu_config.num_cores_per_replica 366 if num_cores_per_replica > num_cores_in_system: 367 raise ValueError( 368 'The num of cores required by the model parallelism, specified by ' 369 'TPUConfig.num_cores_per_replica, is larger than the total num of ' 370 'TPU cores in the system. num_cores_per_replica: {}, num cores ' 371 'in the system: {}'.format(num_cores_per_replica, 372 num_cores_in_system)) 373 374 if num_cores_in_system % num_cores_per_replica != 0: 375 raise RuntimeError( 376 'The num of cores in the system ({}) is not divisible by the num ' 377 'of cores ({}) required by the model parallelism, specified by ' 378 'TPUConfig.num_cores_per_replica. This should never happen!'.format( 379 num_cores_in_system, num_cores_per_replica)) 380 381 return num_cores_in_system // num_cores_per_replica 382 else: 383 return num_cores_in_system 384 385 @property 386 def num_hosts(self): 387 metadata = self._get_tpu_system_metadata() 388 return metadata.num_hosts 389 390 @property 391 def config(self): 392 return self._config 393 394 def is_input_sharded_per_core(self): 395 """Return true if input_fn is invoked per-core (other than per-host).""" 396 mode = self._assert_mode() 397 return (mode == model_fn_lib.ModeKeys.TRAIN and 398 (self._config.tpu_config.per_host_input_for_training is 399 tpu_config.InputPipelineConfig.PER_SHARD_V1)) 400 401 def is_input_per_host_with_iterators(self): 402 """Return true if input_fn should be run in the per-host v2 config.""" 403 return (self._config.tpu_config.per_host_input_for_training is 404 tpu_config.InputPipelineConfig.PER_HOST_V2) 405 406 def is_input_broadcast_with_iterators(self): 407 """Return true if input_fn should be run in the full_replicae config.""" 408 mode = self._assert_mode() 409 return ((self._config.tpu_config.per_host_input_for_training is 410 tpu_config.InputPipelineConfig.BROADCAST) or 411 (mode != model_fn_lib.ModeKeys.TRAIN and 412 self._config.tpu_config.eval_training_input_configuration is 413 tpu_config.InputPipelineConfig.SLICED)) 414 415 def is_running_on_cpu(self, is_export_mode=False): 416 """Determines whether the input_fn and model_fn should be invoked on CPU. 417 418 This API also validates user provided configuration, such as batch size, 419 according the lazy initialized TPU system metadata. 420 421 Args: 422 is_export_mode: Indicates whether the current mode is for exporting the 423 model, when mode == PREDICT. Only with this bool, we could 424 tell whether user is calling the Estimator.predict or 425 Estimator.export_savedmodel, which are running on TPU and CPU 426 respectively. Parent class Estimator does not distinguish these two. 427 428 Returns: 429 bool, whether current input_fn or model_fn should be running on CPU. 430 431 Raises: 432 ValueError: any configuration is invalid. 433 """ 434 435 is_running_on_cpu = self._is_running_on_cpu(is_export_mode) 436 if not is_running_on_cpu: 437 self._validate_tpu_configuration() 438 return is_running_on_cpu 439 440 def _is_running_on_cpu(self, is_export_mode): 441 """Determines whether the input_fn and model_fn should be invoked on CPU.""" 442 mode = self._assert_mode() 443 444 if not self._use_tpu: 445 return True 446 447 if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu: 448 logging.info('_is_running_on_cpu: eval_on_tpu disabled') 449 return True 450 451 if is_export_mode: 452 return True 453 454 return False 455 456 @property 457 def global_batch_size(self): 458 mode = self._assert_mode() 459 if mode == model_fn_lib.ModeKeys.TRAIN: 460 return self._train_batch_size 461 elif mode == model_fn_lib.ModeKeys.EVAL: 462 return self._eval_batch_size 463 elif mode == model_fn_lib.ModeKeys.PREDICT: 464 return self._predict_batch_size 465 else: 466 return None 467 468 @property 469 def batch_size_for_input_fn(self): 470 """Returns the shard batch size for `input_fn`.""" 471 global_batch_size = self.global_batch_size 472 if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): 473 return global_batch_size 474 475 # On TPU 476 if self.is_input_sharded_per_core() or ( 477 self.is_input_per_host_with_iterators()): 478 return global_batch_size // self.num_replicas 479 else: 480 return global_batch_size // self.num_hosts 481 482 @property 483 def batch_size_for_model_fn(self): 484 """Returns the shard batch size for `model_fn`.""" 485 global_batch_size = self.global_batch_size 486 487 if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()): 488 return global_batch_size 489 490 # On TPU. always sharded per shard. 491 return global_batch_size // self.num_replicas 492 493 @property 494 def master_job(self): 495 """Returns the job name to use to place TPU computations on. 496 497 Returns: 498 A string containing the job name, or None if no job should be specified. 499 500 Raises: 501 ValueError: If the user needs to specify a tpu_job_name, because we are 502 unable to infer the job name automatically, or if the user-specified job 503 names are inappropriate. 504 """ 505 run_config = self._config 506 # If the user specifies the tpu_job_name, use that. 507 if run_config.tpu_config.tpu_job_name: 508 return run_config.tpu_config.tpu_job_name 509 510 # The tpu job is determined by the run_config. Right now, this method is 511 # required as tpu_config is not part of the RunConfig. 512 mode = self._assert_mode() 513 master = ( 514 run_config.evaluation_master 515 if mode == model_fn_lib.ModeKeys.EVAL else run_config.master) 516 cluster_def = (run_config.session_config.cluster_def 517 if run_config.session_config else None) 518 519 return tpu_system_metadata_lib.master_job(master, cluster_def) 520 521 @property 522 def tpu_host_placement_function(self): 523 """Returns the TPU host place function.""" 524 525 master = self.master_job 526 527 def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name 528 """Return the host device given replica_id or host_id.""" 529 assert _sentinal is None 530 if replica_id is not None and host_id is not None: 531 raise RuntimeError( 532 'replica_id and host_id can have only one non-None value.') 533 534 if master is None: 535 return '/replica:0/task:0/device:CPU:0' 536 else: 537 if replica_id is not None: 538 if self.model_parallelism_enabled: 539 return self.device_assignment.host_device( 540 replica=replica_id, job=master) 541 else: 542 host_id = replica_id / self.num_of_cores_per_host 543 544 return '/job:%s/task:%d/device:CPU:0' % (master, host_id) 545 546 return _placement_function 547 548 @property 549 def tpu_device_placement_function(self): 550 """Returns a TPU device placement Fn.""" 551 master = self.master_job 552 job_device = '' if master is None else ('/job:%s' % master) 553 554 def _placement_function(i): 555 if self.model_parallelism_enabled: 556 return self.device_assignment.tpu_device(replica=i, job=master) 557 else: 558 num_of_cores_per_host = self.num_of_cores_per_host 559 host_id = i / num_of_cores_per_host 560 ordinal_id = i % num_of_cores_per_host 561 return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id) 562 563 return _placement_function 564 565 def tpu_ordinal_function(self, host_id): 566 """Returns the TPU ordinal fn.""" 567 568 def _tpu_ordinal_function(shard_index_in_host): 569 """Return the TPU ordinal associated with a shard. 570 571 Required because the enqueue ops are placed on CPU. 572 573 Args: 574 shard_index_in_host: the shard index 575 576 Returns: 577 The ordinal of the TPU device the shard's infeed should be placed on. 578 """ 579 if self.model_parallelism_enabled: 580 # We put both enqueue/dequeue ops at tpu.core(0) in each replica. 581 replica = self.device_assignment.lookup_replicas(host_id, 582 0)[shard_index_in_host] 583 return self.device_assignment.tpu_ordinal(replica=replica) 584 else: 585 return shard_index_in_host % self.num_of_cores_per_host 586 587 return _tpu_ordinal_function 588 589 def _validate_tpu_configuration(self): 590 """Validates the configuration based on the TPU system metadata.""" 591 mode = self._assert_mode() 592 if self._lazy_validation_dict.get(mode): 593 return 594 595 # All following information is obtained from TPU system metadata. 596 num_cores = self.num_cores 597 num_replicas = self.num_replicas 598 num_hosts = self.num_hosts 599 600 if not num_cores: 601 tpu_system_metadata = self._get_tpu_system_metadata() 602 raise RuntimeError( 603 'Cannot find any TPU cores in the system. Please double check ' 604 'Tensorflow master address and TPU worker(s). Available devices ' 605 'are {}.'.format(tpu_system_metadata.devices)) 606 607 if self._config.tpu_config.num_shards: 608 user_provided_num_replicas = self._config.tpu_config.num_shards 609 if user_provided_num_replicas != num_replicas: 610 message = ( 611 'TPUConfig.num_shards is not set correctly. According to TPU ' 612 'system metadata for Tensorflow master ({}): num_replicas should ' 613 'be ({}), got ({}). For non-model-parallelism, num_replicas should ' 614 'be the total num of TPU cores in the system. For ' 615 'model-parallelism, the total number of TPU cores should be ' 616 'num_cores_per_replica * num_replicas. Please set it ' 617 'accordingly or leave it as `None`'.format( 618 self._get_master_address(), num_replicas, 619 user_provided_num_replicas)) 620 621 raise ValueError(message) 622 623 if self._config.tpu_config.num_cores_per_replica: 624 num_cores_per_replica = self._config.tpu_config.num_cores_per_replica 625 num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host 626 if num_cores_per_replica > num_cores_per_host: 627 raise ValueError( 628 'The num of cores required by the model parallelism, specified by ' 629 'TPUConfig.num_cores_per_replica, is larger than the ' 630 'num_cores_per_host. num_cores_per_replica: {}, ' 631 'num_cores_per_host: {}'.format(num_cores_per_replica, 632 num_cores_per_host)) 633 634 if mode == model_fn_lib.ModeKeys.TRAIN: 635 if (self._train_batch_size % num_replicas != 0 and 636 not self.is_input_broadcast_with_iterators()): 637 raise ValueError( 638 'train batch size {} must be divisible by number of replicas {}' 639 .format(self._train_batch_size, num_replicas)) 640 641 elif mode == model_fn_lib.ModeKeys.EVAL: 642 if self._eval_batch_size is None: 643 raise ValueError( 644 'eval_batch_size in TPUEstimator constructor cannot be `None`' 645 'if .evaluate is running on TPU.') 646 if (self._eval_batch_size % num_replicas != 0 and 647 not self.is_input_broadcast_with_iterators()): 648 raise ValueError( 649 'eval batch size {} must be divisible by number of replicas {}' 650 .format(self._eval_batch_size, num_replicas)) 651 if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): 652 raise ValueError( 653 'TPUEstimator.evaluate should be running on single TPU' 654 ' instead of a Pod.') 655 else: 656 assert mode == model_fn_lib.ModeKeys.PREDICT 657 if self._predict_batch_size is None: 658 raise ValueError( 659 'predict_batch_size in TPUEstimator constructor should not be ' 660 '`None` if .predict is running on TPU.') 661 if (self._predict_batch_size % num_replicas != 0 and 662 not self.is_input_broadcast_with_iterators()): 663 raise ValueError( 664 'predict batch size {} must be divisible by number of replicas {}' 665 .format(self._predict_batch_size, num_replicas)) 666 if num_hosts > 1 and not self.is_input_broadcast_with_iterators(): 667 raise ValueError( 668 'TPUEstimator.predict should be running on single TPU worker. ' 669 'got {}.'.format(num_hosts)) 670 671 # Record the state "validated" into lazy dictionary. 672 self._lazy_validation_dict[mode] = True 673 674 def device_for_replica(self, replica_id): 675 """Returns the tuple of (CPU device and device ordinal) for replica. 676 677 This should be used for full replicate for non-model-parallelism. 678 679 Args: 680 replica_id: Int, the replica index. 681 682 Returns: 683 A tuple of device spec for CPU device and int device ordinal. 684 """ 685 master = self.master_job 686 687 if self.model_parallelism_enabled: 688 return (self.device_assignment.host_device( 689 replica=replica_id, job=master), 690 self.device_assignment.tpu_ordinal(replica=replica_id)) 691 692 job_device = '' if master is None else ('/job:%s' % master) 693 694 num_of_replicas_per_host = self.num_of_replicas_per_host 695 host_id = replica_id / num_of_replicas_per_host 696 ordinal_id = replica_id % num_of_replicas_per_host 697 698 host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id) 699 return (host_device, ordinal_id) 700 701 702class _OneCoreTPUContext(_InternalTPUContext): 703 """Special _InternalTPUContext for one core usage.""" 704 705 def __init__(self, config, train_batch_size, eval_batch_size, 706 predict_batch_size, use_tpu): 707 708 super(_OneCoreTPUContext, self).__init__( 709 config, train_batch_size, eval_batch_size, 710 predict_batch_size, use_tpu) 711 712 def _get_tpu_system_metadata(self): 713 """Gets the (maybe cached) TPU system metadata.""" 714 master = self._get_master_address() 715 tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master) 716 if tpu_system_metadata is not None: 717 return tpu_system_metadata 718 719 tpu_system_metadata = ( 720 tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access 721 num_cores=1, 722 num_hosts=1, 723 num_of_cores_per_host=1, 724 topology=None, 725 devices=[])) 726 727 self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata 728 return tpu_system_metadata 729 730 731def _get_tpu_context(config, train_batch_size, eval_batch_size, 732 predict_batch_size, use_tpu, eval_on_tpu, 733 embedding_config_spec): 734 """Returns an instance of `_InternalTPUContext`.""" 735 736 if (config.tpu_config.num_shards == 1 and 737 config.tpu_config.num_cores_per_replica is None): 738 if embedding_config_spec is not None: 739 raise ValueError('Setting TPUConfig.num_shards==1 is unsupported ' 740 'when embedding_config_spec is not None.') 741 logging.warning( 742 'Setting TPUConfig.num_shards==1 is an unsupported behavior. ' 743 'Please fix as soon as possible (leaving num_shards as None.)') 744 return _OneCoreTPUContext(config, train_batch_size, eval_batch_size, 745 predict_batch_size, use_tpu) 746 747 return _InternalTPUContext(config, train_batch_size, eval_batch_size, 748 predict_batch_size, use_tpu, eval_on_tpu, 749 embedding_config_spec) 750