# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Device function for replicated training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util.tf_export import tf_export

# This is a tuple of PS ops used by tf.estimator.Estimator, which should work
# in almost all cases.
STANDARD_PS_OPS = ("Variable", "VariableV2", "AutoReloadVariable",
                   "MutableHashTable", "MutableHashTableV2",
                   "MutableHashTableOfTensors", "MutableHashTableOfTensorsV2",
                   "MutableDenseHashTable", "MutableDenseHashTableV2",
                   "VarHandleOp", "BoostedTreesEnsembleResourceHandleOp")


class _RoundRobinStrategy(object):
  """Returns the next ps task index for placement in round-robin order.

  This class is not to be used directly by users. See instead
  `replica_device_setter()` below.
  """

  def __init__(self, num_tasks):
    """Create a new `_RoundRobinStrategy`.

    Args:
      num_tasks: Number of ps tasks to cycle among.
    """
    self._num_tasks = num_tasks
    self._next_task = 0

  def __call__(self, unused_op):
    """Choose a ps task index for the given `Operation`.

    Args:
      unused_op: An `Operation` to be placed on ps.

    Returns:
      The next ps task index to use for the `Operation`. Indices are assigned
      in round-robin order, cycling through the range `[0, num_tasks)`.
    """
    task = self._next_task
    self._next_task = (self._next_task + 1) % self._num_tasks
    return task
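

# A minimal sketch of the round-robin behavior above (for exposition only;
# `_RoundRobinStrategy` is private, and the op argument is unused, so any
# placeholder value works):
#
#   strategy = _RoundRobinStrategy(num_tasks=3)
#   [strategy(None) for _ in range(5)]  # -> [0, 1, 2, 0, 1]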


class _ReplicaDeviceChooser(object):
  """Class to choose devices for Ops in a replicated training setup.

  This class is not to be used directly by users. See instead
  `replica_device_setter()` below.
  """

  def __init__(self, ps_tasks, ps_device, worker_device, merge_devices, ps_ops,
               ps_strategy):
    """Create a new `_ReplicaDeviceChooser`.

    Args:
      ps_tasks: Number of tasks in the `ps` job.
      ps_device: String. Name of the `ps` job.
      worker_device: String. Name of the `worker` job.
      merge_devices: Boolean. Set to True to allow merging of device specs.
      ps_ops: List of strings representing `Operation` types that need to be
        placed on `ps` devices.
      ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
        `ps_ops`), that takes the `Operation` and returns the ps task index to
        use.
    """
    self._ps_tasks = ps_tasks
    self._ps_device = ps_device
    self._worker_device = worker_device
    self._merge_devices = merge_devices
    self._ps_ops = ps_ops
    self._ps_strategy = ps_strategy

  def device_function(self, op):
    """Choose a device for `op`.

    Args:
      op: An `Operation`.

    Returns:
      The device to use for the `Operation`.
    """
    # If we don't return early here, either merge_devices is True, or op.device
    # is empty (in which case merging is a no-op). So we can always merge below.
    if not self._merge_devices and op.device:
      return op.device

    current_device = pydev.DeviceSpec.from_string(op.device or "")

    # The ps_device will be used for the ops listed in ps_ops whenever it is
    # present and ps_tasks is non-zero. However, its task index is only set
    # (using ps_strategy) if ps_device specifies a job and that job is not
    # overridden by a different job in current_device.
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops:
      ps_device = pydev.DeviceSpec.from_string(self._ps_device)

      current_job, ps_job = current_device.job, ps_device.job
      if ps_job and (not current_job or current_job == ps_job):
        ps_device.task = self._ps_strategy(op)

      ps_device.merge_from(current_device)
      return ps_device.to_string()

    worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
    worker_device.merge_from(current_device)
    return worker_device.to_string()


@tf_export(v1=["train.replica_device_setter"])
def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
                          worker_device="/job:worker", merge_devices=True,
                          cluster=None, ps_ops=None, ps_strategy=None):
  """Return a `device function` to use when building a Graph for replicas.

  Device functions are used in the `with tf.device(device_function):` statement
  to automatically assign devices to `Operation` objects as they are
  constructed. Device constraints are added from the inner-most context first,
  working outwards. The merging behavior adds constraints to fields that are
  yet unset by a more inner context. Currently the fields are (job, task,
  cpu/gpu).

  If `cluster` is `None` and `ps_tasks` is 0, the returned function is a no-op.
  Otherwise, the value of `ps_tasks` is derived from `cluster`.

  By default, only Variable ops are placed on ps tasks, and the placement
  strategy is round-robin over all ps tasks. A custom `ps_strategy` may be used
  to do more intelligent placement, such as
  `tf.contrib.training.GreedyLoadBalancingStrategy`.

  For example,

  ```python
  # To build a cluster with two ps tasks on hosts ps0 and ps1, and three
  # worker tasks on hosts worker0, worker1 and worker2.
  cluster_spec = {
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
  with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
    # Build your graph
    v1 = tf.Variable(...)  # assigned to /job:ps/task:0
    v2 = tf.Variable(...)  # assigned to /job:ps/task:1
    v3 = tf.Variable(...)  # assigned to /job:ps/task:0
    # Run compute
  ```
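
  As a further sketch (illustrative only; `pin_to_task_0` is a hypothetical
  strategy, not part of this module), `ps_strategy` accepts any callable that
  maps an `Operation` to a ps task index:

  ```python
  def pin_to_task_0(op):
    # Place every ps op on task 0 instead of round-robin placement.
    return 0

  setter = tf.train.replica_device_setter(ps_tasks=2,
                                          ps_strategy=pin_to_task_0)
  with tf.device(setter):
    v1 = tf.Variable(...)  # assigned to /job:ps/task:0
    v2 = tf.Variable(...)  # also assigned to /job:ps/task:0
  ```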

  Args:
    ps_tasks: Number of tasks in the `ps` job. Ignored if `cluster` is
      provided.
    ps_device: String. Device of the `ps` job. If empty no `ps` job is used.
      Defaults to `/job:ps`.
    worker_device: String. Device of the `worker` job. If empty no `worker`
      job is used.
    merge_devices: `Boolean`. If `True`, merges device specifications rather
      than overriding them: fields already set on an op's device are kept and
      only unset fields are filled in. If `False`, an op's pre-set device is
      returned unchanged.
    cluster: `ClusterDef` proto or `ClusterSpec`.
    ps_ops: List of strings representing `Operation` types that need to be
      placed on `ps` devices. If `None`, defaults to `STANDARD_PS_OPS`.
    ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
      `ps_ops`), that takes the `Operation` and returns the ps task index to
      use. If `None`, defaults to a round-robin strategy across all `ps`
      devices.

  Returns:
    A function to pass to `tf.device()`.

  Raises:
    TypeError if `cluster` is not a dictionary or `ClusterDef` protocol buffer,
    or if `ps_strategy` is provided but not a callable.
  """
  if cluster is not None:
    if isinstance(cluster, server_lib.ClusterSpec):
      cluster_spec = cluster.as_dict()
    else:
      cluster_spec = server_lib.ClusterSpec(cluster).as_dict()
    # Get ps_job_name from ps_device by stripping "/job:".
    ps_job_name = pydev.DeviceSpec.from_string(ps_device).job
    if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
      return None
    ps_tasks = len(cluster_spec[ps_job_name])

  if ps_tasks == 0:
    return None

  if ps_ops is None:
    # TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be
    # placed on the parameter server.
    ps_ops = list(STANDARD_PS_OPS)

  if not merge_devices:
    logging.warning(
        "DEPRECATION: It is recommended to set merge_devices=True in "
        "replica_device_setter")
  if ps_strategy is None:
    ps_strategy = _RoundRobinStrategy(ps_tasks)
  if not six.callable(ps_strategy):
    raise TypeError("ps_strategy must be callable")
  chooser = _ReplicaDeviceChooser(
      ps_tasks, ps_device, worker_device, merge_devices, ps_ops, ps_strategy)
  return chooser.device_function
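

# Illustrative sketch of the resulting placement (for exposition only; assumes
# the defaults ps_device="/job:ps" and worker_device="/job:worker", two ps
# tasks, and TF1-style graph construction):
#
#   setter = replica_device_setter(ps_tasks=2)
#   with tf.device(setter):
#     v = tf.Variable(1.0)  # variable op -> "/job:ps/task:0"
#     w = tf.Variable(2.0)  # round-robin -> "/job:ps/task:1"
#     y = v + w             # compute op  -> "/job:worker"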