1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A Python interface for creating TensorFlow servers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import cluster_pb2
22from tensorflow.core.protobuf import tensorflow_server_pb2
23from tensorflow.python import pywrap_tensorflow as c_api
24from tensorflow.python.framework import errors
25from tensorflow.python.util import compat
26from tensorflow.python.util import deprecation
27from tensorflow.python.util.tf_export import tf_export
28
29
def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
  """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
      protocol buffer, a `tf.train.ClusterSpec` object, or any other value
      accepted by the `tf.train.ClusterSpec` constructor (e.g. a dict mapping
      job names to task addresses), describing the server to be defined
      and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server
      is a member. Defaults to the value in `server_or_cluster_def`, if it is
      a `tf.train.ServerDef`. Otherwise, if the cluster has exactly one job,
      defaults to that job's name.
    task_index: (Optional.) Specifies the task index of the server in its job.
      Defaults to the value in `server_or_cluster_def`, if it is a
      `tf.train.ServerDef`. Otherwise, if the job has exactly one task,
      defaults to that task's index (which need not be 0 for sparse jobs).
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
      in `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
    config: (Optional.) A `tf.ConfigProto` that specifies default configuration
      options for all sessions that run on this server. It is merged into the
      returned proto's `default_session_config`.

  Returns:
    A new `tf.train.ServerDef`; the input proto (if any) is not modified.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
    ValueError: If an argument is not specified and cannot be inferred.
  """
  if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
    # Copy the caller's proto so it is never mutated, then apply any
    # explicitly provided overrides on top of it.
    server_def = tensorflow_server_pb2.ServerDef()
    server_def.MergeFrom(server_or_cluster_def)
    if job_name is not None:
      server_def.job_name = job_name
    if task_index is not None:
      server_def.task_index = task_index
    if protocol is not None:
      server_def.protocol = protocol
  else:
    try:
      cluster_spec = ClusterSpec(server_or_cluster_def)
    except TypeError:
      raise TypeError("Could not convert `server_or_cluster_def` to a "
                      "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
    # Infer job/task only when the cluster leaves no ambiguity.
    if job_name is None:
      if len(cluster_spec.jobs) == 1:
        job_name = cluster_spec.jobs[0]
      else:
        raise ValueError("Must specify an explicit `job_name`.")
    if task_index is None:
      task_indices = cluster_spec.task_indices(job_name)
      if len(task_indices) == 1:
        task_index = task_indices[0]
      else:
        raise ValueError("Must specify an explicit `task_index`.")
    if protocol is None:
      protocol = "grpc"

    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_spec.as_cluster_def(),
        job_name=job_name, task_index=task_index, protocol=protocol)

  # `default_session_config` is independent of the fields set above, so the
  # merge can be applied uniformly for both input kinds.
  if config is not None:
    server_def.default_session_config.MergeFrom(config)
  return server_def
95
96
@tf_export("distribute.Server", v1=["distribute.Server", "train.Server"])
@deprecation.deprecated_endpoints("train.Server")
class Server(object):
  """An in-process TensorFlow server, for use in distributed training.

  A `tf.train.Server` instance encapsulates a set of devices and a
  `tf.Session` target that can participate in distributed training. A server
  belongs to a cluster (specified by a `tf.train.ClusterSpec`), and
  corresponds to a particular task in a named job. The server can
  communicate with any other server in the same cluster.
  """

  def __init__(self,
               server_or_cluster_def,
               job_name=None,
               task_index=None,
               protocol=None,
               config=None,
               start=True):
    """Creates a new server with the given definition.

    The `job_name`, `task_index`, and `protocol` arguments are optional, and
    override any information provided in `server_or_cluster_def`.

    Args:
      server_or_cluster_def: A `tf.train.ServerDef` or
        `tf.train.ClusterDef` protocol buffer, or a
        `tf.train.ClusterSpec` object, describing the server to be
        created and/or the cluster of which it is a member.
      job_name: (Optional.) Specifies the name of the job of which the server
        is a member. Defaults to the value in `server_or_cluster_def`, if
        specified. Otherwise, if the cluster has exactly one job, defaults to
        that job.
      task_index: (Optional.) Specifies the task index of the server in its
        job. Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise, if the job has exactly one task, defaults to that task's
        index.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the
        value in `server_or_cluster_def`, if specified. Otherwise defaults to
        `"grpc"`.
      config: (Optional.) A `tf.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server
        after creating it. Defaults to `True`.

    Raises:
      TypeError: If `server_or_cluster_def` cannot be converted to a server
        definition.
      ValueError: If `job_name` or `task_index` is omitted and cannot be
        inferred from the cluster.
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
    """
    self._server_def = _make_server_def(server_or_cluster_def,
                                        job_name, task_index, protocol, config)
    self._server = c_api.TF_NewServer(self._server_def.SerializeToString())
    if start:
      self.start()

  def __del__(self):
    # Best-effort cleanup. This may run during interpreter shutdown, or on an
    # object whose `__init__` raised before `self._server` was assigned, so
    # the failures below are deliberately swallowed.
    try:
      c_api.TF_ServerStop(self._server)
      # Clean shutdown of servers is not yet implemented, so
      # we leak instead of calling c_api.TF_DeleteServer here.
      # See:
      # https://github.com/tensorflow/tensorflow/blob/0495317a6e9dd4cac577b9d5cf9525e62b571018/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h#L73
    except AttributeError:
      # At shutdown, `c_api` may have been garbage collected, or
      # `self._server` may never have been set. This clause is matched first
      # because evaluating `errors.UnimplementedError` below may itself fail
      # once module globals have been torn down.
      pass
    except errors.UnimplementedError:
      pass
    self._server = None

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the TensorFlow server.
    """
    c_api.TF_ServerStart(self._server)

  def join(self):
    """Blocks until the server has shut down.

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the TensorFlow server.
    """
    c_api.TF_ServerJoin(self._server)

  @property
  def server_def(self):
    """Returns the `tf.train.ServerDef` for this server.

    Returns:
      A `tf.train.ServerDef` protocol buffer that describes the configuration
      of this server.
    """
    return self._server_def

  @property
  def target(self):
    """Returns the target for a `tf.Session` to connect to this server.

    To create a `tf.Session` that connects to this server, use the following
    snippet:

    ```python
    server = tf.train.Server(...)
    with tf.Session(server.target):
      # ...
    ```

    Returns:
      A string containing a session target for this server.
    """
    return c_api.TF_ServerTarget(self._server)

  @staticmethod
  def create_local_server(config=None, start=True):
    """Creates a new single-process cluster running on the local host.

    This method is a convenience wrapper for creating a
    `tf.train.Server` with a `tf.train.ServerDef` that specifies a
    single-process cluster containing a single task in a job called
    `"local"`.

    Args:
      config: (Optional.) A `tf.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Returns:
      A local `tf.train.Server`.
    """
    # Specifying port 0 means that the OS will choose a free port for the
    # server.
    return Server({"local": ["localhost:0"]}, protocol="grpc", config=config,
                  start=start)
237
238
@tf_export("train.ClusterSpec")
class ClusterSpec(object):
  """Represents a cluster as a set of "tasks", organized into "jobs".

  A `tf.train.ClusterSpec` represents the set of processes that
  participate in a distributed TensorFlow computation. Every
  `tf.train.Server` is constructed in a particular cluster.

  To create a cluster with two jobs and five tasks, you specify the
  mapping from job names to lists of network addresses (typically
  hostname-port pairs).

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222",
                                             "worker2.example.com:2222"],
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```

  Each job may also be specified as a sparse mapping from task indices
  to network addresses. This enables a server to be configured without
  needing to know the identity of (for example) all other worker
  tasks:

  ```python
  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```
  """

  def __init__(self, cluster):
    """Creates a `ClusterSpec`.

    Args:
      cluster: A dictionary mapping one or more job names to (i) a
        list of network addresses, or (ii) a dictionary mapping integer
        task indices to network addresses; or a `tf.train.ClusterDef`
        protocol buffer; or another `tf.train.ClusterSpec`, which is copied.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping job names to lists
        (or dictionaries) of addresses, not a `tf.train.ClusterDef` protobuf,
        and not a `tf.train.ClusterSpec`; or if a job name or task address
        is not bytes or unicode.
    """
    if isinstance(cluster, dict):
      self._cluster_spec = {}
      for job_name, tasks in cluster.items():
        if isinstance(tasks, (list, tuple)):
          # Dense job: the position in the sequence is the task index.
          job_tasks = dict(enumerate(tasks))
        elif isinstance(tasks, dict):
          # Sparse job: copy so later edits to the caller's dict don't leak in.
          job_tasks = dict(tasks)
        else:
          raise TypeError("The tasks for job %r must be a list or a dictionary "
                          "from integers to strings." % (job_name,))
        self._cluster_spec[job_name] = job_tasks
      self._make_cluster_def()
    elif isinstance(cluster, cluster_pb2.ClusterDef):
      # NOTE: the caller's proto is kept (and returned by `as_cluster_def()`)
      # without copying.
      self._cluster_def = cluster
      self._cluster_spec = self._spec_from_cluster_def(cluster)
    elif isinstance(cluster, ClusterSpec):
      self._cluster_def = cluster_pb2.ClusterDef()
      self._cluster_def.MergeFrom(cluster.as_cluster_def())
      self._cluster_spec = self._spec_from_cluster_def(self._cluster_def)
    else:
      raise TypeError("`cluster` must be a dictionary mapping one or more "
                      "job names to lists of network addresses, or a "
                      "`ClusterDef` protocol buffer")

  @staticmethod
  def _spec_from_cluster_def(cluster_def):
    """Converts a `ClusterDef` to the internal `{job: {index: address}}` dict."""
    return {job_def.name: dict(job_def.tasks.items())
            for job_def in cluster_def.job}

  def _get_job(self, job_name):
    """Returns the `{task_index: address}` dict for `job_name`.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      return self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % (job_name,))

  def __nonzero__(self):
    return bool(self._cluster_spec)

  # Python 3.x
  __bool__ = __nonzero__

  def __eq__(self, other):
    # Compares the internal `{job: {index: address}}` dict against `other`.
    # When `other` is a ClusterSpec, dict comparison returns NotImplemented
    # and Python falls back to `other.__eq__`, comparing the two internal
    # dicts.
    return self._cluster_spec == other

  def __ne__(self, other):
    return self._cluster_spec != other

  def __str__(self):
    key_values = self.as_dict()
    string_items = [
        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)]
    return "ClusterSpec({" + ", ".join(string_items) + "})"

  def as_dict(self):
    """Returns a dictionary from job names to their tasks.

    For each job, if the task index space is dense, the corresponding
    value will be a list of network addresses; otherwise it will be a
    dictionary mapping (sparse) task indices to the corresponding
    addresses. A job with no tasks maps to an empty dictionary.

    Returns:
      A dictionary mapping job names to lists or dictionaries
      describing the tasks in those jobs.
    """
    ret = {}
    for job in self.jobs:
      task_indices = self.task_indices(job)
      if not task_indices:
        ret[job] = {}
      elif max(task_indices) + 1 == len(task_indices):
        # Return a list because the task indices are dense. This
        # matches the behavior of `as_dict()` before support for
        # sparse jobs was added.
        ret[job] = self.job_tasks(job)
      else:
        ret[job] = {i: self.task_address(job, i) for i in task_indices}
    return ret

  def as_cluster_def(self):
    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
    return self._cluster_def

  @property
  def jobs(self):
    """Returns a list of job names in this cluster.

    Returns:
      A list of strings, corresponding to the names of jobs in this cluster.
    """
    return list(self._cluster_spec.keys())

  def num_tasks(self, job_name):
    """Returns the number of tasks defined in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      The number of tasks defined in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    return len(self._get_job(job_name))

  def task_indices(self, job_name):
    """Returns a list of valid task indices in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A sorted list of valid task indices in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    return sorted(self._get_job(job_name))

  def task_address(self, job_name, task_index):
    """Returns the address of the given task in the given job.

    Args:
      job_name: The string name of a job in this cluster.
      task_index: A non-negative integer.

    Returns:
      The address of the given task in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
      or no task with index `task_index` is defined in that job.
    """
    job = self._get_job(job_name)
    try:
      return job[task_index]
    except KeyError:
      raise ValueError("No task with index %r in job %r"
                       % (task_index, job_name))

  def job_tasks(self, job_name):
    """Returns a mapping from task ID to address in the given job.

    NOTE: For backwards compatibility, this method returns a list. If
    the given job was defined with a sparse set of task indices, the
    length of this list may not reflect the number of tasks defined in
    this job. Use the `tf.train.ClusterSpec.num_tasks` method
    to find the number of tasks defined in a particular job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of task addresses, where the index in the list
      corresponds to the task index of each task. The list may contain
      `None` if the job was defined with a sparse set of task indices.
      An empty list is returned for a job with no tasks.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    job = self._get_job(job_name)
    if not job:
      # `max()` below would fail on an empty job.
      return []
    ret = [None] * (max(job) + 1)
    for i, task in job.items():
      ret[i] = task
    return ret

  def _make_cluster_def(self):
    """Builds `self._cluster_def` from `self._cluster_spec`.

    Jobs are added in sorted job-name order and tasks in sorted index order,
    so equal specs produce identical protos.

    Raises:
      TypeError: If a job name or task address is not bytes or unicode.
    """
    self._cluster_def = cluster_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._cluster_spec.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % (job_name,))

      job_def = self._cluster_def.job.add()
      job_def.name = job_name

      for i, task_address in sorted(tasks.items()):
        try:
          task_address = compat.as_bytes(task_address)
        except TypeError:
          raise TypeError(
              "Task address %r must be bytes or unicode" % (task_address,))
        job_def.tasks[i] = task_address
488