1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A Python interface for creating TensorFlow servers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import cluster_pb2
22from tensorflow.core.protobuf import tensorflow_server_pb2
23from tensorflow.python import pywrap_tensorflow
24from tensorflow.python.framework import errors
25from tensorflow.python.util import compat
26from tensorflow.python.util.tf_export import tf_export
27
28
def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
  """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or
      `tf.train.ClusterDef` protocol buffer, or a
      `tf.train.ClusterSpec` object, describing the server to be
      defined and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server
      is a member. Defaults to the value in `server_or_cluster_def`, if
      specified.
    task_index: (Optional.) Specifies the task index of the server in its job.
      Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
      defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc"`. Defaults to the value in
      `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
    config: (Optional.) A `tf.ConfigProto` that specifies default configuration
      options for all sessions that run on this server.

  Returns:
    A `tf.train.ServerDef`.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
    ValueError: If an argument is not specified and cannot be inferred.
  """
  server_def = tensorflow_server_pb2.ServerDef()
  if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
    # Start from a copy of the given ServerDef, then apply any explicit
    # per-argument overrides on top of it.
    server_def.MergeFrom(server_or_cluster_def)
    if job_name is not None:
      server_def.job_name = job_name
    if task_index is not None:
      server_def.task_index = task_index
    if protocol is not None:
      server_def.protocol = protocol
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  else:
    # Otherwise interpret the argument as a cluster description (a
    # `ClusterDef` proto, a `ClusterSpec`, or a dict accepted by the
    # `ClusterSpec` constructor) and infer any unspecified fields.
    try:
      cluster_spec = ClusterSpec(server_or_cluster_def)
    except TypeError:
      raise TypeError("Could not convert `server_or_cluster_def` to a "
                      "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
    if job_name is None:
      # The job name can only be inferred when the cluster has exactly
      # one job.
      if len(cluster_spec.jobs) == 1:
        job_name = cluster_spec.jobs[0]
      else:
        raise ValueError("Must specify an explicit `job_name`.")
    if task_index is None:
      # Likewise, the task index can only be inferred when the job has
      # exactly one task.
      task_indices = cluster_spec.task_indices(job_name)
      if len(task_indices) == 1:
        task_index = task_indices[0]
      else:
        raise ValueError("Must specify an explicit `task_index`.")
    if protocol is None:
      protocol = "grpc"

    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_spec.as_cluster_def(),
        job_name=job_name, task_index=task_index, protocol=protocol)
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  return server_def
94
95
@tf_export("train.Server")
class Server(object):
  """An in-process TensorFlow server, for use in distributed training.

  A `tf.train.Server` instance encapsulates a set of devices and a
  @{tf.Session} target that
  can participate in distributed training. A server belongs to a
  cluster (specified by a @{tf.train.ClusterSpec}), and
  corresponds to a particular task in a named job. The server can
  communicate with any other server in the same cluster.
  """

  def __init__(self,
               server_or_cluster_def,
               job_name=None,
               task_index=None,
               protocol=None,
               config=None,
               start=True):
    """Creates a new server with the given definition.

    The `job_name`, `task_index`, and `protocol` arguments are optional, and
    override any information provided in `server_or_cluster_def`.

    Args:
      server_or_cluster_def: A `tf.train.ServerDef` or
        `tf.train.ClusterDef` protocol buffer, or a
        `tf.train.ClusterSpec` object, describing the server to be
        created and/or the cluster of which it is a member.
      job_name: (Optional.) Specifies the name of the job of which the server
        is a member. Defaults to the value in `server_or_cluster_def`, if
        specified.
      task_index: (Optional.) Specifies the task index of the server in its
        job. Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise defaults to 0 if the server's job has only one task.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc"`. Defaults to the value in
        `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
      config: (Optional.) A `tf.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server
        after creating it. Defaults to `True`.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
    """
    # Resolve all optional arguments into a single, fully-specified
    # ServerDef proto.
    self._server_def = _make_server_def(server_or_cluster_def,
                                        job_name, task_index, protocol, config)
    # Create the underlying C++ server object; a non-OK status is raised
    # as a `tf.errors.OpError` subclass by the context manager.
    with errors.raise_exception_on_not_ok_status() as status:
      self._server = pywrap_tensorflow.PyServer_New(
          self._server_def.SerializeToString(), status)
    if start:
      self.start()

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the TensorFlow server.
    """
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.PyServer_Start(self._server, status)

  def join(self):
    """Blocks until the server has shut down.

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the TensorFlow server.
    """
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.PyServer_Join(self._server, status)

  @property
  def server_def(self):
    """Returns the `tf.train.ServerDef` for this server.

    Returns:
      A `tf.train.ServerDef` protocol buffer that describes the configuration
      of this server.
    """
    return self._server_def

  @property
  def target(self):
    """Returns the target for a `tf.Session` to connect to this server.

    To create a
    @{tf.Session} that
    connects to this server, use the following snippet:

    ```python
    server = tf.train.Server(...)
    with tf.Session(server.target):
      # ...
    ```

    Returns:
      A string containing a session target for this server.
    """
    # Delegates to the C++ server object, which owns the actual target
    # string (e.g. "grpc://localhost:port").
    return self._server.target()

  @staticmethod
  def create_local_server(config=None, start=True):
    """Creates a new single-process cluster running on the local host.

    This method is a convenience wrapper for creating a
    `tf.train.Server` with a `tf.train.ServerDef` that specifies a
    single-process cluster containing a single task in a job called
    `"local"`.

    Args:
      config: (Optional.) A `tf.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Returns:
      A local `tf.train.Server`.
    """
    # Specifying port 0 means that the OS will choose a free port for the
    # server.
    return Server({"local": ["localhost:0"]}, protocol="grpc", config=config,
                  start=start)
224
225
@tf_export("train.ClusterSpec")
class ClusterSpec(object):
  """Represents a cluster as a set of "tasks", organized into "jobs".

  A `tf.train.ClusterSpec` represents the set of processes that
  participate in a distributed TensorFlow computation. Every
  @{tf.train.Server} is constructed in a particular cluster.

  To create a cluster with two jobs and five tasks, you specify the
  mapping from job names to lists of network addresses (typically
  hostname-port pairs).

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222",
                                             "worker2.example.com:2222"],
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```

  Each job may also be specified as a sparse mapping from task indices
  to network addresses. This enables a server to be configured without
  needing to know the identity of (for example) all other worker
  tasks:

  ```python
  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```
  """

  def __init__(self, cluster):
    """Creates a `ClusterSpec`.

    Args:
      cluster: A dictionary mapping one or more job names to (i) a
        list of network addresses, or (ii) a dictionary mapping integer
        task indices to network addresses; or a `tf.train.ClusterDef`
        protocol buffer.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping strings to lists
        of strings, and not a `tf.train.ClusterDef` protobuf.
    """
    # Two parallel representations are maintained:
    #   _cluster_spec: {job_name: {task_index: address}} — supports both
    #     dense (list-specified) and sparse (dict-specified) jobs uniformly.
    #   _cluster_def: the equivalent `ClusterDef` protocol buffer.
    if isinstance(cluster, dict):
      self._cluster_spec = {}
      for job_name, tasks in cluster.items():
        if isinstance(tasks, (list, tuple)):
          # Dense job: list positions become the task indices.
          job_tasks = {i: task for i, task in enumerate(tasks)}
        elif isinstance(tasks, dict):
          # Sparse job: the caller supplies the index-to-address mapping.
          job_tasks = {i: task for i, task in tasks.items()}
        else:
          raise TypeError("The tasks for job %r must be a list or a dictionary "
                          "from integers to strings." % job_name)
        self._cluster_spec[job_name] = job_tasks
      self._make_cluster_def()
    elif isinstance(cluster, cluster_pb2.ClusterDef):
      # NOTE: the given proto is stored by reference, not copied.
      self._cluster_def = cluster
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()}
    elif isinstance(cluster, ClusterSpec):
      # Copy constructor: deep-copy the other spec's ClusterDef so the two
      # objects do not share mutable protobuf state.
      self._cluster_def = cluster_pb2.ClusterDef()
      self._cluster_def.MergeFrom(cluster.as_cluster_def())
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()}
    else:
      raise TypeError("`cluster` must be a dictionary mapping one or more "
                      "job names to lists of network addresses, or a "
                      "`ClusterDef` protocol buffer")

  def __nonzero__(self):
    # A ClusterSpec is truthy iff it contains at least one job.
    return bool(self._cluster_spec)

  # Python 3.x
  __bool__ = __nonzero__

  def __eq__(self, other):
    # Compares the normalized {job: {index: address}} dict. When `other` is
    # another ClusterSpec, dict.__eq__ returns NotImplemented and Python
    # falls back to the reflected comparison, so two equivalent specs
    # compare equal; a raw dict in the same normalized format also works.
    return self._cluster_spec == other

  def __ne__(self, other):
    return self._cluster_spec != other

  def __str__(self):
    # Sort keys so the representation is deterministic.
    key_values = self.as_dict()
    string_items = [
        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)]
    return "ClusterSpec({" + ", ".join(string_items) + "})"

  def as_dict(self):
    """Returns a dictionary from job names to their tasks.

    For each job, if the task index space is dense, the corresponding
    value will be a list of network addresses; otherwise it will be a
    dictionary mapping (sparse) task indices to the corresponding
    addresses.

    Returns:
      A dictionary mapping job names to lists or dictionaries
      describing the tasks in those jobs.
    """
    ret = {}
    for job in self.jobs:
      task_indices = self.task_indices(job)
      # Indices 0..N-1 with no gaps means the job is dense.
      if max(task_indices) + 1 == len(task_indices):
        # Return a list because the task indices are dense. This
        # matches the behavior of `as_dict()` before support for
        # sparse jobs was added.
        ret[job] = self.job_tasks(job)
      else:
        ret[job] = {i: self.task_address(job, i) for i in task_indices}
    return ret

  def as_cluster_def(self):
    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
    return self._cluster_def

  @property
  def jobs(self):
    """Returns a list of job names in this cluster.

    Returns:
      A list of strings, corresponding to the names of jobs in this cluster.
    """
    return list(self._cluster_spec.keys())

  def num_tasks(self, job_name):
    """Returns the number of tasks defined in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      The number of tasks defined in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return len(job)

  def task_indices(self, job_name):
    """Returns a list of valid task indices in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of valid task indices in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
      or no task with index `task_index` is defined in that job.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    # Sorted for determinism; sparse jobs may have gaps between indices.
    return list(sorted(job.keys()))

  def task_address(self, job_name, task_index):
    """Returns the address of the given task in the given job.

    Args:
      job_name: The string name of a job in this cluster.
      task_index: A non-negative integer.

    Returns:
      The address of the given task in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
      or no task with index `task_index` is defined in that job.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    try:
      return job[task_index]
    except KeyError:
      raise ValueError("No task with index %r in job %r"
                       % (task_index, job_name))

  def job_tasks(self, job_name):
    """Returns a mapping from task ID to address in the given job.

    NOTE: For backwards compatibility, this method returns a list. If
    the given job was defined with a sparse set of task indices, the
    length of this list may not reflect the number of tasks defined in
    this job. Use the @{tf.train.ClusterSpec.num_tasks} method
    to find the number of tasks defined in a particular job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of task addresses, where the index in the list
      corresponds to the task index of each task. The list may contain
      `None` if the job was defined with a sparse set of task indices.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    # Pad missing (sparse) indices with None so that ret[i] is always the
    # address of task i.
    ret = [None for _ in range(max(job.keys()) + 1)]
    for i, task in job.items():
      ret[i] = task
    return ret

  def _make_cluster_def(self):
    """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.

    Raises:
      TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
        of strings.
    """
    self._cluster_def = cluster_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._cluster_spec.items()):
      try:
        # Proto string fields require bytes (or unicode convertible to it).
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      job_def = self._cluster_def.job.add()
      job_def.name = job_name

      # Sort task indices so the map entries are added deterministically.
      for i, task_address in sorted(tasks.items()):
        try:
          task_address = compat.as_bytes(task_address)
        except TypeError:
          raise TypeError(
              "Task address %r must be bytes or unicode" % task_address)
        job_def.tasks[i] = task_address
472