1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for multi-worker distribution strategies."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import cluster_pb2
22from tensorflow.python.distribute import distribute_coordinator_context as dc_context
23from tensorflow.python.training import server_lib
24
25
def normalize_cluster_spec(cluster_spec):
  """Makes `cluster_spec` into a `ClusterSpec` object.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.

  Returns:
    a `ClusterSpec` object. A dict or `ClusterDef` input is wrapped in a new
    `ClusterSpec`; a `ClusterSpec` input is returned unchanged.

  Raises:
    ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
      `ClusterDef`.
  """
  if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
    return server_lib.ClusterSpec(cluster_spec)
  if isinstance(cluster_spec, server_lib.ClusterSpec):
    # Already in the normalized form; hand back the same object.
    return cluster_spec
  raise ValueError(
      "`cluster_spec` should be dict or a `tf.train.ClusterSpec` or a "
      "`tf.train.ClusterDef` object")
47
48
def task_count(cluster_spec, task_type):
  """Returns the number of tasks of `task_type` in `cluster_spec`.

  Args:
    cluster_spec: an object exposing `num_tasks(task_type)`.
    task_type: the job name whose tasks are counted.

  Returns:
    the value of `cluster_spec.num_tasks(task_type)`, or 0 if that call raises
    `ValueError`.
  """
  try:
    count = cluster_spec.num_tasks(task_type)
  except ValueError:
    # `num_tasks` signalled a problem with `task_type`; report zero tasks.
    count = 0
  return count
54
55
def _validate_cluster_spec(cluster_spec, task_type, task_id):
  """Checks that `cluster_spec`, `task_type` and `task_id` are consistent.

  The checks run in this order and the first failure raises:
  1) every job in `cluster_spec` is one of "chief", "worker", "evaluator" or
     "ps".
  2) `task_type` is one of those job names, or None.
  3) a non-empty `task_type` appears in `cluster_spec`. The "evaluator" type is
     exempt: it may be absent from `cluster_spec`, which keeps this compatible
     with `TF_CONFIG` in Estimator.
  4) there is at most one "chief" task.
  5) there is at most one "evaluator" task.
  6) if `task_type` is in `cluster_spec`, `task_id` is smaller than the number
     of tasks of that type.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Raises:
    ValueError: if `cluster_spec` fails any check.
  """
  allowed_task_types = ("chief", "worker", "evaluator", "ps", None)

  cluster_spec = normalize_cluster_spec(cluster_spec)
  jobs = cluster_spec.jobs

  for job in jobs:
    if job not in allowed_task_types:
      raise ValueError("Disallowed task type found in cluster spec. Allowed "
                       "types are {} and the cluster spec is {}.".format(
                           allowed_task_types, cluster_spec))

  if task_type not in allowed_task_types:
    raise ValueError(
        "Unrecognized task_type: {}, valid task types are: {}".format(
            task_type, allowed_task_types))

  task_type_in_cluster = task_type in jobs

  # An "evaluator" may legitimately be missing from `cluster_spec`.
  if task_type and task_type != "evaluator" and not task_type_in_cluster:
    raise ValueError("`task_type` %r not found in cluster_spec." % task_type)

  num_chiefs = task_count(cluster_spec, "chief")
  if num_chiefs > 1:
    raise ValueError("There must be at most one 'chief' job.")

  num_evaluators = task_count(cluster_spec, "evaluator")
  if num_evaluators > 1:
    raise ValueError("There must be at most one 'evaluator' job.")

  # Only bound-check `task_id` when the task type is actually listed.
  if task_type_in_cluster and task_id >= task_count(cluster_spec, task_type):
    raise ValueError(
        "The `task_id` %d exceeds the maximum id of %s." % (task_id, task_type))
110
111
def is_chief(cluster_spec=None, task_type=None, task_id=None):
  """Returns whether the given task is chief in the cluster.

  A task counts as chief when its type is "chief" or "evaluator" (the
  evaluator is independent of the training cluster and there is at most one,
  so it is a chief on its own), or when it is worker 0 in a cluster that has
  no "chief" job.

  When a worker context of the distribute coordinator is active, the arguments
  are ignored and the context's `is_chief` value is returned instead.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a boolean indicating whether the given task is chief.

  Raises:
    ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
      the maximum id of the `task_type`.
  """
  if has_worker_context():
    # The active worker context already knows the answer.
    worker_context = dc_context.get_current_worker_context()
    return worker_context.is_chief

  _validate_cluster_spec(cluster_spec, task_type, task_id)
  cluster_dict = normalize_cluster_spec(cluster_spec).as_dict()

  if task_type in ("chief", "evaluator"):
    return True

  # Without a dedicated chief the first worker takes the chief role. This is
  # common in CollectiveAllReduceStrategy.
  return bool("chief" not in cluster_dict and task_type == "worker" and
              task_id == 0)
150
151
def collective_leader(cluster_spec, task_type, task_id):
  """Returns the job name of the leader for collective ops.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a string indicating the leader job name or empty string if no need to set
    leader job. The empty string is returned for an empty (local) cluster and
    for the "evaluator" task type.

  Raises:
    ValueError: if `cluster_spec`, `task_type` or `task_id` fail validation, or
      if the cluster contains neither a "chief" nor a "worker" job that could
      act as the leader.
  """
  cluster_spec = normalize_cluster_spec(cluster_spec)

  # No need to set collective leader for local.
  if not cluster_spec.as_dict():
    return ""

  _validate_cluster_spec(cluster_spec, task_type, task_id)

  # Only one evaluator, so no need to set collective leader.
  if task_type == "evaluator":
    return ""

  # Use chief if chief is in the cluster.
  if "chief" in cluster_spec.jobs:
    return "/job:chief/replica:0/task:0"

  # Use worker 0 if no chief job. This is raised explicitly rather than
  # asserted so the check survives `python -O` and a cluster without workers
  # (e.g. only "ps" jobs) never yields a leader that does not exist.
  if "worker" not in cluster_spec.jobs:
    raise ValueError(
        "Cannot determine the collective leader: `cluster_spec` has neither a "
        "'chief' nor a 'worker' job.")
  return "/job:worker/replica:0/task:0"
184
185
def worker_count(cluster_spec, task_type):
  """Returns the number of workers in the cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the type of the calling task; must be "chief", "worker" or
      "evaluator".

  Returns:
    For "chief" and "worker", the total number of "chief" and "worker" tasks
    (the chief is also a worker). For "evaluator", the number of "evaluator"
    tasks, or 1 when the evaluator is not listed in `cluster_spec`.

  Raises:
    ValueError: if `cluster_spec` fails validation or `task_type` is not one of
      "chief", "worker" or "evaluator".
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id=0)
  cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

  # Other jobs such as "ps" shouldn't call this function.
  if task_type not in ["chief", "worker", "evaluator"]:
    raise ValueError("Unexpected `task_type` %r" % task_type)

  if task_type == "evaluator":
    # The "evaluator" is in its own cluster or its own partition of a cluster.
    # So we don't have to count "chief" or "worker" if the current task is an
    # "evaluator".
    evaluators = cluster_spec.get("evaluator")
    if evaluators is None:
      # `_validate_cluster_spec` accepts an "evaluator" task that is absent
      # from `cluster_spec`; indexing the dict would raise KeyError here. The
      # calling task is itself the single evaluator, so count it.
      return 1
    return len(evaluators)

  # In the non-evaluator case, we return the total number of "chief" and
  # "worker" tasks as the "chief" is also a worker.
  return (len(cluster_spec.get("chief", [])) + len(
      cluster_spec.get("worker", [])))
205
206
def id_in_cluster(cluster_spec, task_type, task_id):
  """Returns a unique id for the task in the `task_type`'s cluster.

  The id falls in the range [0, `worker_count(cluster_spec, task_type)`): the
  "chief" is always 0, "worker" tasks are numbered after the chief (if any),
  and an "evaluator" keeps its own `task_id`.

  Note: this function assumes that the "evaluator" job is in its own cluster or
  its own partition of a cluster.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object to be validated.
    task_type: string indicating the type of the task.
    task_id: the id of the `task_type` in this cluster.

  Returns:
    an int indicating the unique id.

  Raises:
    ValueError: if `cluster_spec` fails validation or `task_type` is not
      "chief", "worker" or "evaluator".
  """
  _validate_cluster_spec(cluster_spec, task_type, task_id)

  # There is at most one "chief" and it always gets id 0.
  if task_type == "chief":
    return 0

  # The "evaluator" is in its own cluster or its own partition of a cluster.
  if task_type == "evaluator":
    return task_id

  # We currently don't assign ids to other tasks.
  if task_type != "worker":
    raise ValueError("There is no id for task_type %r" % task_type)

  # "worker" ids come right after the chief's, when a chief exists.
  cluster_dict = normalize_cluster_spec(cluster_spec).as_dict()
  chief_tasks = cluster_dict.get("chief", [])
  return task_id + len(chief_tasks)
243
244
def should_save_checkpoint():
  """Returns whether the current worker should save checkpoints.

  In multi-worker training the cluster as a whole should save a checkpoint when
  the user requests it or fault-tolerance needs it, but not every worker in the
  cluster necessarily has to. The answer is read from the `should_checkpoint`
  attribute of the current worker context.

  TODO(rchao): Consider generalizing this util to be `should_save_file` as there
  can be other files to save such as summary.

  Returns:
      Whether this particular worker in the cluster should save checkpoints.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context.should_checkpoint
259
260
def should_load_checkpoint():
  """Returns whether the current worker should load checkpoints.

  In multi-worker training the cluster as a whole should load a checkpoint when
  the user requests it or fault-tolerance needs it, but not every worker in the
  cluster necessarily has to. The answer is read from the
  `experimental_should_init` attribute of the current worker context.

  Returns:
      Whether this particular worker in the cluster should load checkpoints.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context.experimental_should_init
272
273
def wait_for_other_workers():
  """Waits for other workers to reach the same call to this method.

  Returns:
      Whatever the current worker context's `wait_for_other_workers()` returns.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context.wait_for_other_workers()
277
278
def has_worker_context():
  """Returns whether a worker context has been entered.

  Returns:
      True if `dc_context.get_current_worker_context()` yields a context, False
      if it yields None.
  """
  worker_context = dc_context.get_current_worker_context()
  return worker_context is not None
282