# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Device function for replicated training."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import six
21
22from tensorflow.core.framework import node_def_pb2
23from tensorflow.python.framework import device as pydev
24from tensorflow.python.platform import tf_logging as logging
25from tensorflow.python.training import server_lib
26from tensorflow.python.util.tf_export import tf_export
27
# This is a tuple of PS ops used by tf.estimator.Estimator which should work in
# almost all cases.
STANDARD_PS_OPS = ("Variable", "VariableV2", "AutoReloadVariable",
                   "MutableHashTable", "MutableHashTableV2",
                   "MutableHashTableOfTensors", "MutableHashTableOfTensorsV2",
                   "MutableDenseHashTable", "MutableDenseHashTableV2",
                   "VarHandleOp", "BoostedTreesEnsembleResourceHandleOp",
                   "BoostedTreesQuantileStreamResourceHandleOp",
                   "ResourceConditionalAccumulator",
                   "DecisionTreeResource")


class _RoundRobinStrategy(object):
  """Returns the next ps task index for placement in round-robin order.

  This class is not to be used directly by users.  See instead
  `replica_device_setter()` below.
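
  A minimal sketch of the behavior (the op argument is ignored, so any value
  serves for illustration):

    strategy = _RoundRobinStrategy(num_tasks=3)
    [strategy(None) for _ in range(5)]  # -> [0, 1, 2, 0, 1]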
45  """
46
47  def __init__(self, num_tasks):
48    """Create a new `_RoundRobinStrategy`.
49
50    Args:
51      num_tasks: Number of ps tasks to cycle among.
52    """
53    self._num_tasks = num_tasks
54    self._next_task = 0
55
56  def __call__(self, unused_op):
57    """Choose a ps task index for the given `Operation`.
58
59    Args:
60      unused_op: An `Operation` to be placed on ps.
61

    Returns:
      The next ps task index to use for the `Operation`, in the range
      `[0, num_tasks)`.
    """
    task = self._next_task
    self._next_task = (self._next_task + 1) % self._num_tasks
    return task


class _ReplicaDeviceChooser(object):
  """Class to choose devices for Ops in a replicated training setup.

  This class is not to be used directly by users.  See instead
  `replica_device_setter()` below.
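
  A minimal sketch of how the chooser places ops (illustrative only, assuming
  a TF1-style graph; `replica_device_setter()` builds the chooser and returns
  its `device_function`):

    chooser = _ReplicaDeviceChooser(
        ps_tasks=2, ps_device="/job:ps", worker_device="/job:worker",
        merge_devices=True, ps_ops=list(STANDARD_PS_OPS),
        ps_strategy=_RoundRobinStrategy(2))
    with tf.device(chooser.device_function):
      v = tf.Variable(0.0)  # variable op placed on "/job:ps/task:0"
      w = tf.Variable(1.0)  # next matched op on "/job:ps/task:1"
      total = v + w         # non-ps op placed on "/job:worker"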
76  """
77
78  def __init__(self, ps_tasks, ps_device, worker_device, merge_devices, ps_ops,
79               ps_strategy):
80    """Create a new `_ReplicaDeviceChooser`.
81
82    Args:
83      ps_tasks: Number of tasks in the `ps` job.
84      ps_device: String.  Name of the `ps` job.
85      worker_device: String.  Name of the `worker` job.
86      merge_devices: Boolean. Set to True to allow merging of device specs.
87      ps_ops: List of strings representing `Operation` types that need to be
88        placed on `ps` devices.
89      ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
90        `ps_ops`), that takes the `Operation` and returns the ps task index to
91        use.
92    """
93    self._ps_tasks = ps_tasks
94    self._ps_device = ps_device
95    self._worker_device = worker_device
96    self._merge_devices = merge_devices
97    self._ps_ops = ps_ops
98    self._ps_strategy = ps_strategy
99

  def device_function(self, op):
    """Choose a device for `op`.

    Args:
      op: an `Operation`.

    Returns:
      The device to use for the `Operation`.
    """
    # If we don't return early here, either merge_devices is True, or op.device
    # is empty (in which case merging is a no-op). So we can always merge below.
    if not self._merge_devices and op.device:
      return op.device

    current_device = pydev.DeviceSpec.from_string(op.device or "")

    # The ps_device will be used for the specified ops (ps_ops) whenever it is
    # present and ps_tasks is non-zero. However, its task number is only set
    # (using ps_strategy) when ps_device specifies a job and current_device
    # either has no job or names the same job.
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops:
      ps_device = pydev.DeviceSpec.from_string(self._ps_device)

      current_job, ps_job = current_device.job, ps_device.job
      if ps_job and (not current_job or current_job == ps_job):
        ps_device = ps_device.replace(task=self._ps_strategy(op))

      ps_device = ps_device.make_merged_spec(current_device)
      return ps_device.to_string()

    worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
    worker_device = worker_device.make_merged_spec(current_device)
    return worker_device.to_string()


@tf_export(v1=["train.replica_device_setter"])
def replica_device_setter(ps_tasks=0,
                          ps_device="/job:ps",
                          worker_device="/job:worker",
                          merge_devices=True,
                          cluster=None,
                          ps_ops=None,
                          ps_strategy=None):
144  """Return a `device function` to use when building a Graph for replicas.
145
146  Device Functions are used in `with tf.device(device_function):` statement to
147  automatically assign devices to `Operation` objects as they are constructed,
148  Device constraints are added from the inner-most context first, working
149  outwards. The merging behavior adds constraints to fields that are yet unset
150  by a more inner context. Currently the fields are (job, task, cpu/gpu).
151
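  For instance, an op created inside `with tf.device("/job:worker"):` nested
  within `with tf.device("/job:ps/task:7"):` keeps `job:worker` from the inner
  context, while the still-unset task field is filled in from the outer
  context, yielding `/job:worker/task:7`.
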
  If `cluster` is `None` and `ps_tasks` is 0, the returned function is a no-op.
  If `cluster` is provided, the value of `ps_tasks` is derived from it and the
  `ps_tasks` argument is ignored.

  By default, only Variable ops are placed on ps tasks, and the placement
  strategy is round-robin over all ps tasks. A custom `ps_strategy` may be used
  to do more intelligent placement, such as
  `tf.contrib.training.GreedyLoadBalancingStrategy`.

  For example,

  ```python
  # To build a cluster with two ps tasks on hosts ps0 and ps1, and three
  # worker tasks on hosts worker0, worker1 and worker2.
  cluster_spec = {
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
  with tf.device(tf.compat.v1.train.replica_device_setter(
      cluster=cluster_spec)):
    # Build your graph
    v1 = tf.Variable(...)  # assigned to /job:ps/task:0
    v2 = tf.Variable(...)  # assigned to /job:ps/task:1
    v3 = tf.Variable(...)  # assigned to /job:ps/task:0
  # Run compute
  ```
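
  A sketch of load-balanced placement with a custom `ps_strategy` (assuming
  the TF 1.x contrib strategy
  `tf.contrib.training.GreedyLoadBalancingStrategy` is available):

  ```python
  greedy = tf.contrib.training.GreedyLoadBalancingStrategy(
      num_tasks=2, load_fn=tf.contrib.training.byte_size_load_fn)
  with tf.device(tf.compat.v1.train.replica_device_setter(
      cluster=cluster_spec, ps_strategy=greedy)):
    # Each variable goes to the ps task with the fewest bytes placed so far.
    v1 = tf.Variable(...)
  ```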

  Args:
    ps_tasks: Number of tasks in the `ps` job.  Ignored if `cluster` is
      provided.
    ps_device: String.  Device of the `ps` job.  If empty no `ps` job is used.
      Defaults to `/job:ps`.
    worker_device: String.  Device of the `worker` job.  If empty no `worker`
      job is used.
    merge_devices: `Boolean`. If `True`, merges device specifications rather
      than overriding them: fields already set on an op's device are kept and
      only unset fields are filled in. If `False`, an op that already has any
      device set is left unchanged.
    cluster: `ClusterDef` proto or `ClusterSpec`.
    ps_ops: List of strings representing `Operation` types that need to be
      placed on `ps` devices.  If `None`, defaults to `STANDARD_PS_OPS`.
    ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
      `ps_ops`), that takes the `Operation` and returns the ps task index to
      use.  If `None`, defaults to a round-robin strategy across all `ps`
      devices.

  Returns:
    A function to pass to `tf.device()`.

  Raises:
    TypeError if `cluster` is not a dictionary or `ClusterDef` protocol buffer,
    or if `ps_strategy` is provided but not a callable.
  """
  if cluster is not None:
    if isinstance(cluster, server_lib.ClusterSpec):
      cluster_spec = cluster.as_dict()
    else:
      cluster_spec = server_lib.ClusterSpec(cluster).as_dict()
    # Get ps_job_name from ps_device by parsing it as a device spec.
    ps_job_name = pydev.DeviceSpec.from_string(ps_device).job
    if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
      return None
    ps_tasks = len(cluster_spec[ps_job_name])

  if ps_tasks == 0:
    return None

  if ps_ops is None:
    # TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be
    # placed on the parameter server.
    ps_ops = list(STANDARD_PS_OPS)

  if not merge_devices:
    logging.warning(
        "DEPRECATION: It is recommended to set merge_devices=True in "
        "replica_device_setter")
  if ps_strategy is None:
    ps_strategy = _RoundRobinStrategy(ps_tasks)
  if not six.callable(ps_strategy):
    raise TypeError("ps_strategy must be callable")
  chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
                                  merge_devices, ps_ops, ps_strategy)
  return chooser.device_function