1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ===================================================================
15
16"""A RunConfig subclass with TPU support."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import json
24import os
25
26from tensorflow.core.protobuf import config_pb2
27from tensorflow.python.estimator import run_config as run_config_lib
28from tensorflow.python.platform import tf_logging as logging
29from tensorflow.python.tpu import util as util_lib
30
31# pylint: disable=protected-access
32_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
33_SERVICE_KEY = run_config_lib._SERVICE_KEY
34_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
35# pylint: enable=protected-access
36
37
38class InputPipelineConfig(object):
39  r"""Please see the definition of these values in TPUConfig."""
40  PER_SHARD_V1 = 1
41  PER_HOST_V1 = 2
42  PER_HOST_V2 = 3
43  BROADCAST = 4
44  SLICED = 5
45
46
47class TPUConfig(
48    collections.namedtuple('TPUConfig', [
49        'iterations_per_loop',
50        'num_shards',
51        'num_cores_per_replica',
52        'per_host_input_for_training',
53        'tpu_job_name',
54        'initial_infeed_sleep_secs',
55        'input_partition_dims',
56        'eval_training_input_configuration',
57    ])):
58  r"""TPU related configuration required by `TPUEstimator`.
59
60  Args:
61    iterations_per_loop: This is the number of train steps running in TPU
62      system before returning to CPU host for each `Session.run`. This means
63      global step is increased `iterations_per_loop` times in one `Session.run`.
64      It is recommended to be set as number of global steps for next checkpoint.
65    num_shards: (Deprecated, ignored by TPUEstimator).
66      The number of model replicas in the system. For non-model-parallelism
67      case, this number equals the total number of TPU cores. For
68      model-parallelism, the total number of TPU cores equals
69      num_cores_per_replica * num_shards.
70    num_cores_per_replica: Defaults to `None`, which disables model parallelism.
71      An integer which describes the number of TPU cores per model replica. This
72      is required by model-parallelism which enables partitioning
73      the model to multiple cores. Currently num_cores_per_replica must be
74      1, 2, 4, or 8.
75    per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
76      `input_fn` is invoked once on each host. With the per-core input pipeline
77      configuration, it is invoked once for each core.
78      With a global batch size `train_batch_size` in `TPUEstimator` constructor,
79      the batch size for each shard is `train_batch_size` // #hosts in the
80      `True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
81      `train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only
82      invoked once on host 0 and the tensors are broadcasted to all other
83      replicas. The batch size equals to train_batch_size`. With the per-core
84      input pipeline configuration, the shard batch size is also
85      `train_batch_size` // #cores.
86      Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
87    tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
88      within TPUEstimator, however when using ClusterSpec propagation in more
89      esoteric cluster configurations, you may need to specify the job name as a
90      string.
91    initial_infeed_sleep_secs: The number of seconds the infeed thread should
92      wait before enqueueing the first batch. This helps avoid timeouts for
93      models that require a long compilation time.
94    input_partition_dims: A nested list to describe the partition dims
95      for all the tensors from input_fn(). The structure of
96      input_partition_dims must match the structure of `features` and
97      `labels` from input_fn(). The total number of partitions must match
98      `num_cores_per_replica`. For example, if input_fn() returns two tensors:
99      images with shape [N, H, W, C] and labels [N].
100      input_partition_dims = [[1, 2, 2, 1], None] will split the images to 4
101      pieces and feed into 4 TPU cores. labels tensor are directly broadcasted
102      to all the TPU cores since the partition dims is `None`.
103      Current limitations: This feature is only supported with the PER_HOST_V2
104      input mode.
105    eval_training_input_configuration: If `SLICED`, `input_fn` is only
106      invoked once on host 0 and the tensors are broadcasted to all other
107      replicas. Unlike per_host_input_for_training=BROADCAST, each replica will
108      only get a slice of the data instead of a whole copy. If `PER_HOST_V1`,
109      the behaviour is determined by per_host_input_for_training.
110
111    Raises:
112      ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
113  """
114
115  def __new__(
116      cls,
117      iterations_per_loop=2,
118      num_shards=None,
119      num_cores_per_replica=None,
120      per_host_input_for_training=True,
121      tpu_job_name=None,
122      initial_infeed_sleep_secs=None,
123      input_partition_dims=None,
124      eval_training_input_configuration=InputPipelineConfig.PER_HOST_V1):
125
126    # Check iterations_per_loop.
127    util_lib.check_positive_integer(iterations_per_loop,
128                                    'TPUConfig iterations_per_loop')
129
130    # Check num_shards.
131    if num_shards is not None:
132      util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
133
134    if input_partition_dims is not None:
135      if len(input_partition_dims) != 1 and len(input_partition_dims) != 2:
136        raise ValueError(
137            'input_partition_dims must be a list/tuple with one or two'
138            ' elements.')
139
140      if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2:
141        raise ValueError(
142            'input_partition_dims is only supported in PER_HOST_V2 mode.')
143
144      if num_cores_per_replica is None:
145        raise ValueError(
146            'input_partition_dims requires setting num_cores_per_replica.')
147
148    # Check num_cores_per_replica
149    if num_cores_per_replica is not None:
150      if num_cores_per_replica not in [1, 2, 4, 8, 16]:
151        raise ValueError(
152            'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
153                str(num_cores_per_replica)))
154
155    if eval_training_input_configuration not in [
156        InputPipelineConfig.PER_HOST_V1, InputPipelineConfig.SLICED
157    ]:
158      raise ValueError(
159          'eval_training_input_configuration must be PER_HOST_V1 or SLICED;'
160          ' got {}'.format(str(eval_training_input_configuration)))
161
162    # per_host_input_for_training may be True, False, or integer in [1..3].
163    # Map legacy values (True, False) to numeric values.
164    if per_host_input_for_training is False:
165      per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
166    elif per_host_input_for_training is True:
167      per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
168
169    # Check initial_infeed_sleep_secs.
170    if initial_infeed_sleep_secs:
171      util_lib.check_positive_integer(initial_infeed_sleep_secs,
172                                      'TPUConfig initial_infeed_sleep_secs')
173
174    tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()
175
176    return super(TPUConfig, cls).__new__(
177        cls,
178        iterations_per_loop=iterations_per_loop,
179        num_shards=num_shards,
180        num_cores_per_replica=num_cores_per_replica,
181        per_host_input_for_training=per_host_input_for_training,
182        tpu_job_name=tpu_job_name,
183        initial_infeed_sleep_secs=initial_infeed_sleep_secs,
184        input_partition_dims=input_partition_dims,
185        eval_training_input_configuration=eval_training_input_configuration)
186
187
188class RunConfig(run_config_lib.RunConfig):
189  """RunConfig with TPU support."""
190
191  def __init__(self,
192               tpu_config=None,
193               evaluation_master=None,
194               master=None,
195               cluster=None,
196               **kwargs):
197    """Constructs a RunConfig.
198
199    Args:
200      tpu_config: the TPUConfig that specifies TPU-specific configuration.
201      evaluation_master: a string. The address of the master to use for eval.
202        Defaults to master if not set.
203      master: a string. The address of the master to use for training.
204      cluster: a ClusterResolver
205      **kwargs: keyword config parameters.
206
207    Raises:
208      ValueError: if cluster is not None and the provided session_config has a
209        cluster_def already.
210    """
211    super(RunConfig, self).__init__(**kwargs)
212    self._tpu_config = tpu_config or TPUConfig()
213    self._cluster = cluster
214
215    # If user sets master and/or evaluation_master explicitly, including empty
216    # string '', take it. Otherwise, take the values set by parent class.
217    if master is not None:
218      if cluster is not None:
219        raise ValueError('Both master and cluster are set.')
220      self._master = master
221    else:
222      if cluster:
223        self._master = cluster.master()
224
225    if evaluation_master is not None:
226      self._evaluation_master = evaluation_master
227    elif (not self._evaluation_master and
228          self.task_type != run_config_lib.TaskType.EVALUATOR):
229      # If the task type is EVALUATOR, it means some cluster manager sets the
230      # TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
231      #
232      # Otherwise, it means user executes the code without external cluster
233      # manager. For that, we optimize the user experience by setting
234      # evaluation_master to master, unless user overwrites it.
235      self._evaluation_master = self._master
236
237    # Set the ClusterSpec to use
238    if cluster:
239      self._cluster_spec = cluster.cluster_spec()
240
241      # Merge the cluster_def into the ConfigProto.
242      if self._session_config is None:  # pylint: disable=access-member-before-definition
243        self._session_config = config_pb2.ConfigProto(
244            allow_soft_placement=True, isolate_session_state=True)
245      if self._session_config.HasField('cluster_def'):
246        raise ValueError(
247            'You cannot provide a ClusterResolver and '
248            'session_config.cluster_def.')
249      if self._cluster_spec:
250        self._session_config.cluster_def.CopyFrom(
251            self._cluster_spec.as_cluster_def())
252
253  def _maybe_overwrite_session_config_for_distributed_training(self):
254    # Overrides the parent class session_config overwrite for between-graph. TPU
255    # runs with in-graph, which should not have device filter. Doing nothing
256    # ("pass") basically disables it.
257    pass
258
259  @property
260  def evaluation_master(self):
261    return self._evaluation_master
262
263  @property
264  def master(self):
265    return self._master
266
267  @property
268  def tpu_config(self):
269    return self._tpu_config
270
271  @property
272  def cluster(self):
273    return self._cluster
274
275  def replace(self, **kwargs):
276    if 'tpu_config' not in kwargs:
277      return super(RunConfig, self).replace(**kwargs)
278
279    tpu_config = kwargs.pop('tpu_config')
280    new_instance = super(RunConfig, self).replace(**kwargs)
281    new_instance._tpu_config = tpu_config  # pylint: disable=protected-access
282    return new_instance
283
284
285def _get_tpu_job_name_from_tf_config():
286  """Extracts the TPU job name from TF_CONFIG env variable."""
287  # TODO(xiejw): Extends this to support both TF_CONFIG env variable and cluster
288  # spec propagation.
289  tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
290  tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
291  if tpu_job_name:
292    logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
293  return tpu_job_name
294