1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers to connect to remote servers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22
23from absl import logging
24
25from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
26from tensorflow.python import pywrap_tfe
27from tensorflow.python.distribute import device_util
28from tensorflow.python.distribute.cluster_resolver import cluster_resolver
29from tensorflow.python.eager import context
30from tensorflow.python.framework import ops
31from tensorflow.python.platform import remote_utils
32from tensorflow.python.training import server_lib
33from tensorflow.python.util import nest
34from tensorflow.python.util.tf_export import tf_export
35
36
37_GRPC_PREFIX = "grpc://"
38_LOCAL_MASTERS = ("", "local")
39
40
@tf_export("config.experimental_connect_to_host")
def connect_to_remote_host(remote_host=None, job_name="worker"):
  """Connects to a single machine to enable remote execution on it.

  Will make devices on the remote host available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  Using the default job_name of worker, you can schedule ops to run remotely as
  follows:
  ```python
  # When eager execution is enabled, connect to the remote host.
  tf.config.experimental_connect_to_host("exampleaddr.com:9876")

  with tf.device("/job:worker/replica:0/task:0/device:CPU:0"):
    # The following tensors should be resident on the remote device, and the op
    # will also execute remotely.
    x1 = tf.ones([2, 2])
    x2 = tf.ones([2, 2])
    y = tf.matmul(x1, x2)
  ```

  This is a thin wrapper: it builds a `ClusterSpec` containing a single job
  named `job_name` and delegates to `tf.config.experimental_connect_to_cluster`
  with that function's default local job name (`"localhost"`).

  Args:
    remote_host: A single remote server address, or a (possibly nested) list of
      addresses, in host:port format. A leading "grpc://" is stripped from each
      address. Each address becomes one task of `job_name`.
    job_name: The job name under which the new server will be accessible.

  Raises:
    ValueError: if remote_host is None or empty.
  """
  if not remote_host:
    raise ValueError("Must provide at least one remote_host")

  # Accept either one address or a (nested) collection of addresses.
  remote_hosts = nest.flatten(remote_host)
  # ClusterSpec expects bare host:port strings, so drop any grpc:// scheme.
  cluster_spec = server_lib.ClusterSpec(
      {job_name: [_strip_prefix(host, _GRPC_PREFIX) for host in remote_hosts]})

  # NOTE(review): if `job_name` is "localhost" it collides with the default
  # local job name used by connect_to_cluster, so the remote address would be
  # treated as the local task. Confirm whether that should be rejected.
  connect_to_cluster(cluster_spec)
78
79
@tf_export("config.experimental_connect_to_cluster")
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True,
                       cluster_device_filters=None):
  """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this more
  than once will work, but will invalidate any tensor handles on the old remote
  devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added (as a single task, index 0), using an unused port
  on the localhost.

  If a `ClusterResolver` is passed whose `master()` is `""` or `"local"`, this
  function returns immediately without connecting to anything.

  Device filters can be specified to isolate groups of remote tasks to avoid
  undesired accesses between workers. Workers accessing resources or launching
  ops / functions on filtered remote devices will result in errors (unknown
  devices). For any remote task, if no device filter is present, all cluster
  devices will be visible; if any device filter is specified, it can only
  see devices matching at least one filter. Devices on the task itself are
  always visible. Device filters can be partially specified.

  For example, for a cluster set up for parameter server training, the following
  device filters might be specified:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  # For any worker, only the devices on PS nodes and itself are visible
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  # Similarly for any ps, only the devices on workers and itself are visible
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_spec,
                                            cluster_device_filters=cdf)
  ```

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
      use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Raises an error if the caller is currently in
      a device scope other than the master device.
    cluster_device_filters: an instance of
      `tf.config.experimental.ClusterDeviceFilters` that specify device filters
      to the remote tasks in cluster.

  Raises:
    ValueError: If not executing eagerly; if `cluster_spec_or_resolver` is
      neither a `ClusterSpec` nor a `ClusterResolver`; if
      `cluster_device_filters` is given but is not a `ClusterDeviceFilters`;
      if `make_master_device_default` is set but the resolver's master cannot
      be located in the cluster; or if called inside a device scope that
      differs from the master device.
  """
  if not context.executing_eagerly():
    raise ValueError(
        "`tf.config.experimental_connect_to_cluster` can only be called in "
        "eager mode."
    )
  protocol = protocol or remote_utils.get_default_communication_protocol()
  if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
    cluster_spec = cluster_spec_or_resolver
  elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
    if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
      # Do nothing if the master is local.
      return
    cluster_spec = cluster_spec_or_resolver.cluster_spec()
  else:
    raise ValueError(
        "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
        "`ClusterResolver`.")

  # Deep-copy before appending the local job below, in case as_cluster_def()
  # hands back a proto owned by `cluster_spec`.
  cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
  if cluster_device_filters:
    if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
      cluster_device_filters = copy.deepcopy(
          cluster_device_filters._as_cluster_device_filters())  # pylint: disable=protected-access
    else:
      raise ValueError("`cluster_device_filters` must be an instance of "
                       "`tf.config.experimental.ClusterDeviceFilters`.")

  # Automatically add local job, if not part of the cluster spec.
  if job_name not in cluster_spec.jobs:
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    job_def = cluster_def.job.add()
    job_def.name = job_name
    # TODO(fishx): Update this to make sure remote worker has valid ip address
    # to connect with local.
    # NOTE(review): only task 0 is added, so a non-zero `task_index` will not
    # name an existing task of the auto-added job.
    job_def.tasks[0] = "localhost:{}".format(local_port)

  server_def = ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol=protocol,
      default_session_config=context.context().config,
      cluster_device_filters=cluster_device_filters)

  # First connection installs the server def; later calls update it in place.
  if context.get_server_def() is None:
    context.set_server_def(server_def)
  else:
    context.update_server_def(server_def)

  if make_master_device_default and isinstance(
      cluster_spec_or_resolver,
      cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
    master = cluster_spec_or_resolver.master()
    master_job_name, master_task_id = _find_master_task(cluster_spec, master)

    if master_job_name is None:
      raise ValueError(
          "`make_master_device_default` is set to True but cannot find "
          "master %s in the cluster" % master)

    master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
                                                       master_task_id)
    master_device = device_util.canonicalize(master_device)
    current_device = device_util.current()
    if current_device:
      current_device = device_util.canonicalize(current_device)
    if current_device and current_device != master_device:
      raise ValueError("`connect_to_cluster` is called inside existing device "
                       "scope %s, which is different from the master device "
                       "scope %s to enter. This is not allowed." %
                       (current_device, master_device))
    # TODO(b/138389076): Think of the entering device scope behavior in the
    # failure recovery case when dealing with preemptions.
    if not current_device:
      logging.info("Entering into master device scope: %s", master_device)
      # Deliberately never exited, so the master device remains the default
      # device after this function returns.
      ops.device(master_device).__enter__()


def _find_master_task(cluster_spec, master):
  """Returns the first `(job_name, task_id)` whose address matches `master`.

  A task matches when `master` is a substring of the task address or the task
  address is a substring of `master`, so e.g. a master of
  "grpc://10.0.0.1:8470" matches a task at "10.0.0.1:8470". Jobs and tasks are
  scanned in the order `cluster_spec.jobs` and `cluster_spec.task_indices`
  yield them, and the search stops at the first match.

  Args:
    cluster_spec: The `ClusterSpec` to search.
    master: The master address reported by a `ClusterResolver`.

  Returns:
    A `(job_name, task_id)` tuple for the first matching task, or
    `(None, None)` if no task matches.
  """
  for job in cluster_spec.jobs:
    for task_id in cluster_spec.task_indices(job):
      task_address = cluster_spec.task_address(job, task_id)
      # NOTE(review): substring matching can give false positives (a task at
      # "host:84" also matches a master "grpc://host:8470"); kept for
      # compatibility with existing resolvers' master formats.
      if master in task_address or task_address in master:
        return job, task_id
  return None, None
220
221
222def _strip_prefix(s, prefix):
223  return s[len(prefix):] if s.startswith(prefix) else s
224