# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Device function for replicated training."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import six
21
22from tensorflow.core.framework import node_def_pb2
23from tensorflow.python.framework import device as pydev
24from tensorflow.python.platform import tf_logging as logging
25from tensorflow.python.training import server_lib
26from tensorflow.python.util.tf_export import tf_export
27
28# This is a tuple of PS ops used by tf.estimator.Estimator which should work in
29# almost all of cases.
30STANDARD_PS_OPS = ("Variable", "VariableV2", "AutoReloadVariable",
31                   "MutableHashTable", "MutableHashTableV2",
32                   "MutableHashTableOfTensors", "MutableHashTableOfTensorsV2",
33                   "MutableDenseHashTable", "MutableDenseHashTableV2",
34                   "VarHandleOp", "BoostedTreesEnsembleResourceHandleOp")
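# Ops are matched against this tuple by their `NodeDef.op` type string in
# `_ReplicaDeviceChooser.device_function` below.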


class _RoundRobinStrategy(object):
  """Returns the next ps task index for placement in round-robin order.

  This class is not to be used directly by users.  See instead
  `replica_device_setter()` below.
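
  For example, a minimal sketch of the cycling behavior (this class is
  private, so direct use here is purely illustrative):

  ```python
  strategy = _RoundRobinStrategy(num_tasks=3)
  # Successive calls cycle through task indices 0, 1, 2, 0, 1, ...
  assert [strategy(None) for _ in range(5)] == [0, 1, 2, 0, 1]
  ```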
42  """
43
44  def __init__(self, num_tasks):
45    """Create a new `_RoundRobinStrategy`.
46
47    Args:
48      num_tasks: Number of ps tasks to cycle among.
49    """
50    self._num_tasks = num_tasks
51    self._next_task = 0
52
53  def __call__(self, unused_op):
54    """Choose a ps task index for the given `Operation`.
55
56    Args:
57      unused_op: An `Operation` to be placed on ps.
58
59    Returns:
      The next ps task index to use for the `Operation`, cycling through the
      range `[0, num_tasks)`.
62    """
63    task = self._next_task
64    self._next_task = (self._next_task + 1) % self._num_tasks
65    return task
66
67
68class _ReplicaDeviceChooser(object):
69  """Class to choose devices for Ops in a replicated training setup.
70
71  This class is not to be used directly by users.  See instead
72  `replica_device_setter()` below.
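
  A sketch of typical internal usage (normally this object is constructed by
  `replica_device_setter()`, not by hand), assuming a 2-task ps job:

  ```python
  chooser = _ReplicaDeviceChooser(
      ps_tasks=2, ps_device="/job:ps", worker_device="/job:worker",
      merge_devices=True, ps_ops=list(STANDARD_PS_OPS),
      ps_strategy=_RoundRobinStrategy(2))
  with tf.device(chooser.device_function):
    v = tf.Variable(...)  # placed on /job:ps/task:0
  ```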
73  """
74
75  def __init__(self, ps_tasks, ps_device, worker_device, merge_devices, ps_ops,
76               ps_strategy):
77    """Create a new `_ReplicaDeviceChooser`.
78
79    Args:
80      ps_tasks: Number of tasks in the `ps` job.
81      ps_device: String.  Name of the `ps` job.
82      worker_device: String.  Name of the `worker` job.
83      merge_devices: Boolean. Set to True to allow merging of device specs.
84      ps_ops: List of strings representing `Operation` types that need to be
85        placed on `ps` devices.
86      ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
87        `ps_ops`), that takes the `Operation` and returns the ps task index to
88        use.
89    """
90    self._ps_tasks = ps_tasks
91    self._ps_device = ps_device
92    self._worker_device = worker_device
93    self._merge_devices = merge_devices
94    self._ps_ops = ps_ops
95    self._ps_strategy = ps_strategy
96

  def device_function(self, op):
    """Choose a device for `op`.

    Args:
      op: an `Operation`.

    Returns:
      The device to use for the `Operation`.
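
      For example (a sketch): a `VariableV2` op with no existing constraint
      maps to "/job:ps/task:0", while an op already pinned to "/gpu:0" merges
      with `worker_device` to give "/job:worker/device:GPU:0".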
105    """
106    # If we don't return early here, either merge_devices is True, or op.device
107    # is empty (in which case merging is a no-op). So we can always merge below.
108    if not self._merge_devices and op.device:
109      return op.device
110
111    current_device = pydev.DeviceSpec.from_string(op.device or "")
112
113    # The ps_device will be used for specified ops (ps_ops) whenever it is
114    # present and ps_tasks is non-zero. However, its task number will only be
115    # set (using ps_strategy) if there is a job field in ps_device that won't be
116    # changed by the job field (if present) in current_device.
117    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
118    if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops:
119      ps_device = pydev.DeviceSpec.from_string(self._ps_device)
120
121      current_job, ps_job = current_device.job, ps_device.job
122      if ps_job and (not current_job or current_job == ps_job):
123        ps_device.task = self._ps_strategy(op)
124
125      ps_device.merge_from(current_device)
126      return ps_device.to_string()
127
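    # Ops not placed on ps fall through to the worker device, merged with any
    # constraint already present on the op.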
    worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
    worker_device.merge_from(current_device)
    return worker_device.to_string()


@tf_export(v1=["train.replica_device_setter"])
def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
                          worker_device="/job:worker", merge_devices=True,
                          cluster=None, ps_ops=None, ps_strategy=None):
137  """Return a `device function` to use when building a Graph for replicas.
138
139  Device Functions are used in `with tf.device(device_function):` statement to
140  automatically assign devices to `Operation` objects as they are constructed,
141  Device constraints are added from the inner-most context first, working
142  outwards. The merging behavior adds constraints to fields that are yet unset
143  by a more inner context. Currently the fields are (job, task, cpu/gpu).
144
145  If `cluster` is `None`, and `ps_tasks` is 0, the returned function is a no-op.
146  Otherwise, the value of `ps_tasks` is derived from `cluster`.
147
148  By default, only Variable ops are placed on ps tasks, and the placement
149  strategy is round-robin over all ps tasks. A custom `ps_strategy` may be used
150  to do more intelligent placement, such as
151  `tf.contrib.training.GreedyLoadBalancingStrategy`.
152
153  For example,
154
155  ```python
156  # To build a cluster with two ps jobs on hosts ps0 and ps1, and 3 worker
157  # jobs on hosts worker0, worker1 and worker2.
158  cluster_spec = {
159      "ps": ["ps0:2222", "ps1:2222"],
160      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
161  with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
162    # Build your graph
163    v1 = tf.Variable(...)  # assigned to /job:ps/task:0
164    v2 = tf.Variable(...)  # assigned to /job:ps/task:1
165    v3 = tf.Variable(...)  # assigned to /job:ps/task:0
166  # Run compute
167  ```
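
  A custom `ps_strategy` is any callable that maps an `Operation` to a ps task
  index. As a minimal sketch (the `name_hash_strategy` helper below is
  hypothetical, not part of TensorFlow):

  ```python
  def name_hash_strategy(op):
    # Pick one of the 2 ps tasks based on the op's name.
    return hash(op.name) % 2

  with tf.device(tf.train.replica_device_setter(
      ps_tasks=2, ps_strategy=name_hash_strategy)):
    w = tf.Variable(...)  # placed on the ps task chosen by name_hash_strategy
  ```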

  Args:
    ps_tasks: Number of tasks in the `ps` job.  Ignored if `cluster` is
      provided.
    ps_device: String.  Device of the `ps` job.  If empty no `ps` job is used.
      Defaults to `/job:ps`.
    worker_device: String.  Device of the `worker` job.  If empty no `worker`
      job is used.
    merge_devices: `Boolean`. If `True`, merges device specifications rather
      than overriding them: fields already set on an op's device are kept, and
      only unset fields are filled in.  If `False`, ops that already have a
      device keep it unchanged.
    cluster: `ClusterDef` proto or `ClusterSpec`.
    ps_ops: List of strings representing `Operation` types that need to be
      placed on `ps` devices.  If `None`, defaults to `STANDARD_PS_OPS`.
    ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
      `ps_ops`), that takes the `Operation` and returns the ps task index to
      use.  If `None`, defaults to a round-robin strategy across all `ps`
      devices.

  Returns:
    A function to pass to `tf.device()`.

  Raises:
    TypeError if `cluster` is not a dictionary or `ClusterDef` protocol buffer,
    or if `ps_strategy` is provided but not a callable.
  """
  if cluster is not None:
    if isinstance(cluster, server_lib.ClusterSpec):
      cluster_spec = cluster.as_dict()
    else:
      cluster_spec = server_lib.ClusterSpec(cluster).as_dict()
    # Get ps_job_name by parsing ps_device as a device spec (e.g. "/job:ps"
    # yields the job name "ps").
    ps_job_name = pydev.DeviceSpec.from_string(ps_device).job
    if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
      return None
    ps_tasks = len(cluster_spec[ps_job_name])

  if ps_tasks == 0:
    return None

  if ps_ops is None:
    # TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be
    # placed in the parameter server.
    ps_ops = list(STANDARD_PS_OPS)

  if not merge_devices:
    logging.warning(
        "DEPRECATION: It is recommended to set merge_devices=true in "
        "replica_device_setter")
  if ps_strategy is None:
    ps_strategy = _RoundRobinStrategy(ps_tasks)
  if not six.callable(ps_strategy):
    raise TypeError("ps_strategy must be callable")
  chooser = _ReplicaDeviceChooser(
      ps_tasks, ps_device, worker_device, merge_devices, ps_ops, ps_strategy)
  return chooser.device_function