1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A Python interface for creating TensorFlow servers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import cluster_pb2
22from tensorflow.core.protobuf import device_filters_pb2
23from tensorflow.core.protobuf import tensorflow_server_pb2
24from tensorflow.python.client import pywrap_tf_session as c_api
25from tensorflow.python.framework import errors
26from tensorflow.python.util import compat
27from tensorflow.python.util import deprecation
28from tensorflow.python.util.tf_export import tf_export
29
30
def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
  """Builds the `tf.train.ServerDef` proto describing a single server.

  Explicit arguments take precedence over whatever is recorded in
  `server_or_cluster_def`.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
      protocol buffer, or a `tf.train.ClusterSpec` object (or anything the
      `ClusterSpec` constructor accepts), describing the server and/or the
      cluster it belongs to.
    job_name: (Optional.) Name of the job the server belongs to. When omitted,
      the value from `server_or_cluster_def` is used; for a cluster
      description this only works if the cluster has exactly one job.
    task_index: (Optional.) Index of the server's task within its job. When
      omitted, the value from `server_or_cluster_def` is used; for a cluster
      description this only works if the job has exactly one task.
    protocol: (Optional.) Protocol the server should use, e.g. `"grpc"` or
      `"grpc+verbs"`. When omitted, the value from a given `ServerDef` is
      kept; for a cluster description it defaults to `"grpc"`.
    config: (Optional.) A `tf.compat.v1.ConfigProto` merged into the default
      session configuration of the resulting server definition.

  Returns:
    A `tf.train.ServerDef`.

  Raises:
    TypeError: If `server_or_cluster_def` is neither a `ServerDef` nor
      convertible to a `ClusterSpec`.
    ValueError: If `job_name` or `task_index` is not given and cannot be
      inferred from the cluster.
  """
  if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
    # Work on a copy of the caller's proto, then apply the explicit overrides.
    merged_def = tensorflow_server_pb2.ServerDef()
    merged_def.MergeFrom(server_or_cluster_def)
    overrides = (("job_name", job_name),
                 ("task_index", task_index),
                 ("protocol", protocol))
    for field_name, override in overrides:
      if override is not None:
        setattr(merged_def, field_name, override)
    if config is not None:
      merged_def.default_session_config.MergeFrom(config)
    return merged_def

  # Anything else must be understood by the `ClusterSpec` constructor.
  try:
    spec = ClusterSpec(server_or_cluster_def)
  except TypeError:
    raise TypeError("Could not convert `server_or_cluster_def` to a "
                    "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")

  if job_name is None:
    # The job can only be inferred when there is no ambiguity.
    known_jobs = spec.jobs
    if len(known_jobs) != 1:
      raise ValueError("Must specify an explicit `job_name`.")
    job_name = known_jobs[0]

  if task_index is None:
    # Likewise, the task can only be inferred for a single-task job.
    known_indices = spec.task_indices(job_name)
    if len(known_indices) != 1:
      raise ValueError("Must specify an explicit `task_index`.")
    task_index = known_indices[0]

  built_def = tensorflow_server_pb2.ServerDef(
      cluster=spec.as_cluster_def(),
      job_name=job_name,
      task_index=task_index,
      protocol="grpc" if protocol is None else protocol)
  if config is not None:
    built_def.default_session_config.MergeFrom(config)
  return built_def
96
97
@tf_export("distribute.Server", v1=["distribute.Server", "train.Server"])
@deprecation.deprecated_endpoints("train.Server")
class Server(object):
  """An in-process TensorFlow server used for distributed training.

  Each `tf.distribute.Server` wraps a set of devices together with a
  `tf.compat.v1.Session` target and can take part in distributed training.
  A server is one task of a named job inside a cluster (described by a
  `tf.train.ClusterSpec`), and it can communicate with every other server
  in that cluster.
  """

  def __init__(self,
               server_or_cluster_def,
               job_name=None,
               task_index=None,
               protocol=None,
               config=None,
               start=True):
    """Creates a server from the given definition, optionally starting it.

    `job_name`, `task_index` and `protocol` are optional; when given, they
    override the corresponding information in `server_or_cluster_def`.

    Args:
      server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
        protocol buffer, or a `tf.train.ClusterSpec` object, describing the
        server to be created and/or the cluster of which it is a member.
      job_name: (Optional.) Name of the job the server belongs to. Defaults
        to the value in `server_or_cluster_def`, if specified.
      task_index: (Optional.) Index of the server's task within its job.
        Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise defaults to 0 if the server's job has only one task.
      protocol: (Optional.) Protocol to be used by the server. Acceptable
        values include `"grpc", "grpc+verbs"`. Defaults to the value in
        `server_or_cluster_def`, if specified. Otherwise defaults to
        `"grpc"`.
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean; whether to start the server right after
        creating it. Defaults to `True`.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
    """
    definition = _make_server_def(server_or_cluster_def, job_name, task_index,
                                  protocol, config)
    self._server_def = definition
    # The native server is built from the serialized definition.
    self._server = c_api.TF_NewServer(definition.SerializeToString())
    if start:
      self.start()

  def __del__(self):
    """Stops the native server on a best-effort basis."""
    # At shutdown, `errors` may have been garbage collected; fall back to the
    # generic `Exception` in that case.
    ignorable = Exception if errors is None else errors.UnimplementedError
    try:
      c_api.TF_ServerStop(self._server)
      # Clean shutdown of servers is not yet implemented, so
      # we leak instead of calling c_api.TF_DeleteServer here.
      # See:
      # https://github.com/tensorflow/tensorflow/blob/0495317a6e9dd4cac577b9d5cf9525e62b571018/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h#L73
    except (AttributeError, ignorable):
      # `AttributeError`: at shutdown, `c_api` may have been garbage
      # collected. The other exception type is deliberately ignored as well.
      pass
    self._server = None

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the TensorFlow server.
    """
    c_api.TF_ServerStart(self._server)

  def join(self):
    """Blocks until the server has shut down.

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the TensorFlow server.
    """
    c_api.TF_ServerJoin(self._server)

  @property
  def server_def(self):
    """The `tf.train.ServerDef` this server was created from.

    Returns:
      A `tf.train.ServerDef` protocol buffer that describes the configuration
      of this server.
    """
    return self._server_def

  @property
  def target(self):
    """The target a `tf.compat.v1.Session` uses to connect to this server.

    To create a `tf.compat.v1.Session` that connects to this server, use the
    following snippet:

    ```python
    server = tf.distribute.Server(...)
    with tf.compat.v1.Session(server.target):
      # ...
    ```

    Returns:
      A string containing a session target for this server.
    """
    return c_api.TF_ServerTarget(self._server)

  @staticmethod
  def create_local_server(config=None, start=True):
    """Creates a new single-process cluster running on the local host.

    Convenience wrapper that builds a `tf.distribute.Server` for a
    single-process cluster containing a single task in a job called
    `"localhost"`, using the `"grpc"` protocol.

    Args:
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean; whether to start the server right after
        creating it. Defaults to `True`.

    Returns:
      A local `tf.distribute.Server`.
    """
    # Specifying port 0 means that the OS will choose a free port for the
    # server.
    local_cluster = {"localhost": ["localhost:0"]}
    return Server(local_cluster, protocol="grpc", config=config, start=start)
244
245
@tf_export("train.ClusterSpec")
class ClusterSpec(object):
  """Represents a cluster as a set of "tasks", organized into "jobs".

  A `tf.train.ClusterSpec` represents the set of processes that
  participate in a distributed TensorFlow computation. Every
  `tf.distribute.Server` is constructed in a particular cluster.

  To create a cluster with two jobs and five tasks, you specify the
  mapping from job names to lists of network addresses (typically
  hostname-port pairs).

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222",
                                             "worker2.example.com:2222"],
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```

  Each job may also be specified as a sparse mapping from task indices
  to network addresses. This enables a server to be configured without
  needing to know the identity of (for example) all other worker
  tasks:

  ```python
  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```
  """

  def __init__(self, cluster):
    """Creates a `ClusterSpec`.

    Args:
      cluster: A dictionary mapping one or more job names to (i) a list of
        network addresses, or (ii) a dictionary mapping integer task indices to
        network addresses; or a `tf.train.ClusterDef` protocol buffer; or
        another `ClusterSpec`, whose cluster definition is copied.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping strings to lists
        of strings, and not a `tf.train.ClusterDef` protobuf or a
        `ClusterSpec`.
    """
    if isinstance(cluster, dict):
      # `_cluster_spec` maps job name -> {task index: address}.
      self._cluster_spec = {}
      for job_name, tasks in cluster.items():
        if isinstance(tasks, (list, tuple)):
          # A dense job: the position in the sequence is the task index.
          job_tasks = dict(enumerate(tasks))
        elif isinstance(tasks, dict):
          # A (possibly sparse) job; copy so later caller mutations of the
          # input do not leak into this spec.
          job_tasks = dict(tasks)
        else:
          raise TypeError("The tasks for job %r must be a list or a dictionary "
                          "from integers to strings." % job_name)
        self._cluster_spec[job_name] = job_tasks
      self._make_cluster_def()
    elif isinstance(cluster, cluster_pb2.ClusterDef):
      # NOTE: the given proto is stored as-is (not copied).
      self._cluster_def = cluster
      self._cluster_spec = self._spec_from_cluster_def(self._cluster_def)
    elif isinstance(cluster, ClusterSpec):
      # Copy the other spec's proto so the two specs do not share it.
      self._cluster_def = cluster_pb2.ClusterDef()
      self._cluster_def.MergeFrom(cluster.as_cluster_def())
      self._cluster_spec = self._spec_from_cluster_def(self._cluster_def)
    else:
      raise TypeError("`cluster` must be a dictionary mapping one or more "
                      "job names to lists of network addresses, or a "
                      "`ClusterDef` protocol buffer")

  @staticmethod
  def _spec_from_cluster_def(cluster_def):
    """Converts a `ClusterDef` proto into `{job name: {index: address}}`."""
    return {
        job_def.name: dict(job_def.tasks.items()) for job_def in cluster_def.job
    }

  def _lookup_job(self, job_name):
    """Returns the `{task index: address}` dict of `job_name`.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      return self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)

  def __bool__(self):
    """Returns True if at least one job is defined in this cluster."""
    return bool(self._cluster_spec)

  # Python 2.x
  __nonzero__ = __bool__

  def __eq__(self, other):
    # Compares the internal `{job: {index: address}}` dict against `other`.
    # When `other` is a `ClusterSpec`, the dict comparison does not apply and
    # Python falls back to `other`'s reflected comparison, which in turn
    # compares the two internal dicts.
    return self._cluster_spec == other

  def __ne__(self, other):
    # Mirror image of `__eq__`; see the comment there.
    return self._cluster_spec != other

  def __repr__(self):
    """Returns `ClusterSpec({...})` with jobs listed in sorted order."""
    key_values = self.as_dict()
    string_items = [
        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)
    ]
    return "ClusterSpec({" + ", ".join(string_items) + "})"

  def as_dict(self):
    """Returns a dictionary from job names to their tasks.

    For each job, if the task index space is dense, the corresponding
    value will be a list of network addresses; otherwise it will be a
    dictionary mapping (sparse) task indices to the corresponding
    addresses. A job without any tasks maps to an empty dictionary.

    Returns:
      A dictionary mapping job names to lists or dictionaries
      describing the tasks in those jobs.
    """
    ret = {}
    for job in self.jobs:
      task_indices = self.task_indices(job)
      if not task_indices:
        ret[job] = {}
        continue
      if max(task_indices) + 1 == len(task_indices):
        # Return a list because the task indices are dense. This
        # matches the behavior of `as_dict()` before support for
        # sparse jobs was added.
        ret[job] = self.job_tasks(job)
      else:
        ret[job] = {i: self.task_address(job, i) for i in task_indices}
    return ret

  def as_cluster_def(self):
    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
    return self._cluster_def

  @property
  def jobs(self):
    """Returns a list of job names in this cluster.

    Returns:
      A list of strings, corresponding to the names of jobs in this cluster.
    """
    return list(self._cluster_spec)

  def num_tasks(self, job_name):
    """Returns the number of tasks defined in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      The number of tasks defined in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    return len(self._lookup_job(job_name))

  def task_indices(self, job_name):
    """Returns a list of valid task indices in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of valid task indices in the given job, in ascending order.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    return sorted(self._lookup_job(job_name))

  def task_address(self, job_name, task_index):
    """Returns the address of the given task in the given job.

    Args:
      job_name: The string name of a job in this cluster.
      task_index: A non-negative integer.

    Returns:
      The address of the given task in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
      or no task with index `task_index` is defined in that job.
    """
    job = self._lookup_job(job_name)
    try:
      return job[task_index]
    except KeyError:
      raise ValueError("No task with index %r in job %r" %
                       (task_index, job_name))

  def job_tasks(self, job_name):
    """Returns the task addresses of the given job, indexed by task index.

    NOTE: For backwards compatibility, this method returns a list. If
    the given job was defined with a sparse set of task indices, the
    length of this list may not reflect the number of tasks defined in
    this job. Use the `tf.train.ClusterSpec.num_tasks` method
    to find the number of tasks defined in a particular job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of task addresses, where the index in the list
      corresponds to the task index of each task. The list may contain
      `None` if the job was defined with a sparse set of task indices.
      The list is empty if the job has no tasks.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    job = self._lookup_job(job_name)
    if not job:
      # `max()` raises on an empty sequence; a job without tasks simply has
      # no addresses to report.
      return []
    ret = [None] * (max(job) + 1)
    for i, task in job.items():
      ret[i] = task
    return ret

  def _make_cluster_def(self):
    """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.

    The result is stored in `self._cluster_def`.

    Raises:
      TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
        of strings.
    """
    self._cluster_def = cluster_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._cluster_spec.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      job_def = self._cluster_def.job.add()
      job_def.name = job_name

      # Tasks are also added in index order for determinism.
      for i, task_address in sorted(tasks.items()):
        try:
          task_address = compat.as_bytes(task_address)
        except TypeError:
          raise TypeError("Task address %r must be bytes or unicode" %
                          task_address)
        job_def.tasks[i] = task_address
497
498
@tf_export("config.experimental.ClusterDeviceFilters")
class ClusterDeviceFilters(object):
  """Represent a collection of device filters for the remote workers in cluster.

  NOTE: this is an experimental API and subject to changes.

  Set device filters for selective jobs and tasks. For each remote worker, the
  device filters are a list of strings. When any filters are present, the remote
  worker will ignore all devices which do not match any of its filters. Each
  filter can be partially specified, e.g. "/job:ps", "/job:worker/replica:3",
  etc. Note that a device is always visible to the worker it is located on.

  For example, to set the device filters for a parameter server cluster:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  The device filters can be partially specified. For remote tasks that do not
  have device filters specified, all devices will be visible to them.
  """

  def __init__(self):
    """Creates an empty collection of device filters."""
    # `_device_filters` is a dict mapping job names to job device filters.
    # Job device filters further maps task IDs to task device filters.
    # Task device filters are a list of strings, each one is a device filter.
    self._device_filters = {}

    # Cached `ClusterDeviceFilters` proto message built from
    # `_device_filters`, or None when it has to be (re)built.
    self._cluster_device_filters = None

  def set_device_filters(self, job_name, task_index, device_filters):
    """Set the device filters for given job name and task id.

    Any filters previously set for the same job and task are replaced.

    Args:
      job_name: Name of the job the task belongs to.
      task_index: Index of the task within the job.
      device_filters: An iterable of strings, each one a device filter.
    """
    # Materialize once up front: `device_filters` may be a single-pass
    # iterator (e.g. a generator), which validating first would exhaust and
    # leave nothing to store.
    filters = list(device_filters)
    assert all(isinstance(df, str) for df in filters), (
        "Device filters must be strings")
    self._device_filters.setdefault(job_name, {})[task_index] = filters
    # Due to updates in data, invalidate the cached proto.
    self._cluster_device_filters = None

  def _as_cluster_device_filters(self):
    """Returns the device filters as a `ClusterDeviceFilters` proto message.

    The message is built on first use and cached until the filters change.
    """
    # Explicit `None` check so the cache does not depend on the truthiness of
    # the proto message.
    if self._cluster_device_filters is None:
      self._make_cluster_device_filters()
    return self._cluster_device_filters

  def _make_cluster_device_filters(self):
    """Creates `ClusterDeviceFilters` proto based on the `_device_filters`.

    The result is stored in `self._cluster_device_filters`.

    Raises:
      TypeError: If `_device_filters` is not a dictionary mapping strings to
      a map of task indices and device filters.
    """
    self._cluster_device_filters = device_filters_pb2.ClusterDeviceFilters()

    # Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._device_filters.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      jdf = self._cluster_device_filters.jobs.add()
      jdf.name = job_name

      # Tasks are also visited in index order for determinism.
      for i, task_device_filters in sorted(tasks.items()):
        for tdf in task_device_filters:
          try:
            tdf = compat.as_bytes(tdf)
          except TypeError:
            raise TypeError("Device filter %r must be bytes or unicode" % tdf)
          jdf.tasks[i].device_filters.append(tdf)
579