# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

"""A RunConfig subclass with TPU support."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import util as util_lib

# pylint: disable=protected-access
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
# pylint: enable=protected-access


class InputPipelineConfig(object):
  r"""Please see the definition of these values in TPUConfig."""
  PER_SHARD_V1 = 1
  PER_HOST_V1 = 2
  PER_HOST_V2 = 3
  BROADCAST = 4
  SLICED = 5


class TPUConfig(
    collections.namedtuple('TPUConfig', [
        'iterations_per_loop',
        'num_shards',
        'num_cores_per_replica',
        'per_host_input_for_training',
        'tpu_job_name',
        'initial_infeed_sleep_secs',
        'input_partition_dims',
        'eval_training_input_configuration',
    ])):
  r"""TPU-related configuration required by `TPUEstimator`.

  Args:
    iterations_per_loop: The number of train steps running in the TPU system
      before returning to the CPU host for each `Session.run`. This means the
      global step is increased `iterations_per_loop` times in one
      `Session.run`. It is recommended to set this to the number of global
      steps until the next checkpoint.
    num_shards: (Deprecated, ignored by TPUEstimator). The number of model
      replicas in the system. For the non-model-parallelism case, this number
      equals the total number of TPU cores. For model parallelism, the total
      number of TPU cores equals num_cores_per_replica * num_shards.
    num_cores_per_replica: Defaults to `None`, which disables model
      parallelism. An integer which describes the number of TPU cores per
      model replica. This is required by model parallelism, which partitions
      the model across multiple cores. Currently num_cores_per_replica must
      be 1, 2, 4, 8, or 16.
    per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
      `input_fn` is invoked once on each host. With the per-core input
      pipeline configuration, it is invoked once for each core. With a global
      batch size `train_batch_size` in the `TPUEstimator` constructor, the
      batch size for each shard is `train_batch_size` // #hosts in the `True`
      or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
      `train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only
      invoked once on host 0 and the tensors are broadcast to all other
      replicas. The batch size equals `train_batch_size`.
      With the per-core input pipeline configuration, the shard batch size is
      also `train_batch_size` // #cores.
      Note: per_host_input_for_training==PER_SHARD_V1 only supports
      mode.TRAIN.
    tpu_job_name: The name of the TPU job. Typically, this name is
      auto-inferred within TPUEstimator, however when using ClusterSpec
      propagation in more esoteric cluster configurations, you may need to
      specify the job name as a string.
    initial_infeed_sleep_secs: The number of seconds the infeed thread should
      wait before enqueueing the first batch. This helps avoid timeouts for
      models that require a long compilation time.
    input_partition_dims: A nested list describing the partition dims for all
      the tensors from input_fn(). The structure of input_partition_dims must
      match the structure of `features` and `labels` from input_fn(). The
      total number of partitions must match `num_cores_per_replica`. For
      example, if input_fn() returns two tensors, images with shape
      [N, H, W, C] and labels with shape [N], then
      input_partition_dims = [[1, 2, 2, 1], None] will split the images into
      4 pieces and feed them into 4 TPU cores. The labels tensor is broadcast
      directly to all the TPU cores since its partition dims is `None`.
      Current limitation: this feature is only supported with the PER_HOST_V2
      input mode.
    eval_training_input_configuration: If `SLICED`, `input_fn` is only
      invoked once on host 0 and the tensors are broadcast to all other
      replicas. Unlike per_host_input_for_training=BROADCAST, each replica
      will only get a slice of the data instead of a whole copy. If
      `PER_HOST_V1`, the behaviour is determined by
      per_host_input_for_training.

  Raises:
    ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8, or 16.
  """

  def __new__(
      cls,
      iterations_per_loop=2,
      num_shards=None,
      num_cores_per_replica=None,
      per_host_input_for_training=True,
      tpu_job_name=None,
      initial_infeed_sleep_secs=None,
      input_partition_dims=None,
      eval_training_input_configuration=InputPipelineConfig.PER_HOST_V1):

    # Check iterations_per_loop.
    util_lib.check_positive_integer(iterations_per_loop,
                                    'TPUConfig iterations_per_loop')

    # Check num_shards.
    if num_shards is not None:
      util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')

    if input_partition_dims is not None:
      if len(input_partition_dims) != 1 and len(input_partition_dims) != 2:
        raise ValueError(
            'input_partition_dims must be a list/tuple with one or two'
            ' elements.')

      if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2:
        raise ValueError(
            'input_partition_dims is only supported in PER_HOST_V2 mode.')

      if num_cores_per_replica is None:
        raise ValueError(
            'input_partition_dims requires setting num_cores_per_replica.')

    # Check num_cores_per_replica.
    if num_cores_per_replica is not None:
      if num_cores_per_replica not in [1, 2, 4, 8, 16]:
        raise ValueError(
            'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
                str(num_cores_per_replica)))

    if eval_training_input_configuration not in [
        InputPipelineConfig.PER_HOST_V1, InputPipelineConfig.SLICED
    ]:
      raise ValueError(
          'eval_training_input_configuration must be PER_HOST_V1 or SLICED;'
          ' got {}'.format(str(eval_training_input_configuration)))

    # per_host_input_for_training may be True, False, or an
    # InputPipelineConfig value. Map the legacy boolean values to the
    # corresponding enum values.
    if per_host_input_for_training is False:
      per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
    elif per_host_input_for_training is True:
      per_host_input_for_training = InputPipelineConfig.PER_HOST_V1

    # Check initial_infeed_sleep_secs.
    if initial_infeed_sleep_secs:
      util_lib.check_positive_integer(initial_infeed_sleep_secs,
                                      'TPUConfig initial_infeed_sleep_secs')

    tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()

    return super(TPUConfig, cls).__new__(
        cls,
        iterations_per_loop=iterations_per_loop,
        num_shards=num_shards,
        num_cores_per_replica=num_cores_per_replica,
        per_host_input_for_training=per_host_input_for_training,
        tpu_job_name=tpu_job_name,
        initial_infeed_sleep_secs=initial_infeed_sleep_secs,
        input_partition_dims=input_partition_dims,
        eval_training_input_configuration=eval_training_input_configuration)
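

# Illustrative usage sketch (not part of this module's API): constructing a
# TPUConfig for two-way model parallelism with spatial partitioning of the
# input images. All concrete values below are assumptions chosen for the
# example.
#
#   config = TPUConfig(
#       iterations_per_loop=100,
#       num_cores_per_replica=2,
#       per_host_input_for_training=InputPipelineConfig.PER_HOST_V2,
#       input_partition_dims=[[1, 2, 1, 1], None])
#
# The images tensor [N, H, W, C] is split along H across the 2 cores of each
# replica, while the labels tensor is broadcast because its partition dims are
# `None`. As enforced in TPUConfig.__new__, input_partition_dims requires
# PER_HOST_V2 and a num_cores_per_replica value; per the docstring, the
# product of the partition dims (here 1 * 2 * 1 * 1 = 2) must match
# num_cores_per_replica.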


class RunConfig(run_config_lib.RunConfig):
  """RunConfig with TPU support."""

  def __init__(self,
               tpu_config=None,
               evaluation_master=None,
               master=None,
               cluster=None,
               **kwargs):
    """Constructs a RunConfig.

    Args:
      tpu_config: the TPUConfig that specifies TPU-specific configuration.
      evaluation_master: a string. The address of the master to use for eval.
        Defaults to master if not set.
      master: a string. The address of the master to use for training.
      cluster: a ClusterResolver instance.
      **kwargs: keyword config parameters.

    Raises:
      ValueError: if cluster is not None and the provided session_config has
        a cluster_def already.
    """
    super(RunConfig, self).__init__(**kwargs)
    self._tpu_config = tpu_config or TPUConfig()
    self._cluster = cluster

    # If the user sets master and/or evaluation_master explicitly, including
    # the empty string '', take it. Otherwise, take the values set by the
    # parent class.
    if master is not None:
      if cluster is not None:
        raise ValueError('Both master and cluster are set.')
      self._master = master
    else:
      if cluster:
        self._master = cluster.master()

    if evaluation_master is not None:
      self._evaluation_master = evaluation_master
    elif (not self._evaluation_master and
          self.task_type != run_config_lib.TaskType.EVALUATOR):
      # If the task type is EVALUATOR, it means some cluster manager sets the
      # TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
      #
      # Otherwise, it means the user executes the code without an external
      # cluster manager. For that, we optimize the user experience by setting
      # evaluation_master to master, unless the user overwrites it.
      self._evaluation_master = self._master

    # Set the ClusterSpec to use.
    if cluster:
      self._cluster_spec = cluster.cluster_spec()

      # Merge the cluster_def into the ConfigProto.
      if self._session_config is None:  # pylint: disable=access-member-before-definition
        self._session_config = config_pb2.ConfigProto(
            allow_soft_placement=True, isolate_session_state=True)
      if self._session_config.HasField('cluster_def'):
        raise ValueError('You cannot provide a ClusterResolver and '
                         'session_config.cluster_def.')
      if self._cluster_spec:
        self._session_config.cluster_def.CopyFrom(
            self._cluster_spec.as_cluster_def())

  def _maybe_overwrite_session_config_for_distributed_training(self):
    # Overrides the parent class session_config overwrite for between-graph
    # training. TPU runs with in-graph replication, which should not have a
    # device filter. Doing nothing ("pass") disables it.
    pass

  @property
  def evaluation_master(self):
    return self._evaluation_master

  @property
  def master(self):
    return self._master

  @property
  def tpu_config(self):
    return self._tpu_config

  @property
  def cluster(self):
    return self._cluster

  def replace(self, **kwargs):
    if 'tpu_config' not in kwargs:
      return super(RunConfig, self).replace(**kwargs)

    tpu_config = kwargs.pop('tpu_config')
    new_instance = super(RunConfig, self).replace(**kwargs)
    new_instance._tpu_config = tpu_config  # pylint: disable=protected-access
    return new_instance
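

# Illustrative usage sketch (not part of this module): combining TPUConfig and
# this RunConfig. The TPU name and model_dir path are assumptions for the
# example, and the TPUClusterResolver import path varies across TF versions.
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')
#   run_config = RunConfig(
#       cluster=resolver,
#       model_dir='gs://my-bucket/model',
#       tpu_config=TPUConfig(iterations_per_loop=100, num_shards=8))
#
# Passing `cluster` derives `master` from the resolver and merges the resolved
# ClusterSpec into session_config.cluster_def; passing both `master` and
# `cluster` raises ValueError, as shown in __init__ above.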


def _get_tpu_job_name_from_tf_config():
  """Extracts the TPU job name from the TF_CONFIG env variable."""
  # TODO(xiejw): Extend this to support both the TF_CONFIG env variable and
  # cluster spec propagation.
  tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
  tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
  if tpu_job_name:
    logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
  return tpu_job_name
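

# Illustrative usage sketch (not part of this module): the TF_CONFIG layout
# read by _get_tpu_job_name_from_tf_config(), assuming
# run_config_lib._SERVICE_KEY names the 'service' section as in the Estimator
# library. The job name value is made up for the example.
#
#   os.environ['TF_CONFIG'] = json.dumps(
#       {'service': {'tpu_worker_job_name': 'custom_tpu_worker'}})
#   assert TPUConfig().tpu_job_name == 'custom_tpu_worker'
#
# When the key is absent, tpu_job_name stays None unless it is passed
# explicitly to TPUConfig, and TPUEstimator auto-infers the job name instead.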