1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Device function for replicated training.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import six 21 22from tensorflow.core.framework import node_def_pb2 23from tensorflow.python.framework import device as pydev 24from tensorflow.python.platform import tf_logging as logging 25from tensorflow.python.training import server_lib 26from tensorflow.python.util.tf_export import tf_export 27 28# This is a tuple of PS ops used by tf.estimator.Estimator which should work in 29# almost all of cases. 30STANDARD_PS_OPS = ("Variable", "VariableV2", "AutoReloadVariable", 31 "MutableHashTable", "MutableHashTableV2", 32 "MutableHashTableOfTensors", "MutableHashTableOfTensorsV2", 33 "MutableDenseHashTable", "MutableDenseHashTableV2", 34 "VarHandleOp", "BoostedTreesEnsembleResourceHandleOp", 35 "BoostedTreesQuantileStreamResourceHandleOp", 36 "ResourceConditionalAccumulator", 37 "DecisionTreeResource") 38 39 40class _RoundRobinStrategy(object): 41 """Returns the next ps task index for placement in round-robin order. 42 43 This class is not to be used directly by users. See instead 44 `replica_device_setter()` below. 45 """ 46 47 def __init__(self, num_tasks): 48 """Create a new `_RoundRobinStrategy`. 49 50 Args: 51 num_tasks: Number of ps tasks to cycle among. 52 """ 53 self._num_tasks = num_tasks 54 self._next_task = 0 55 56 def __call__(self, unused_op): 57 """Choose a ps task index for the given `Operation`. 58 59 Args: 60 unused_op: An `Operation` to be placed on ps. 61 62 Returns: 63 The next ps task index to use for the `Operation`. Returns the next 64 index, in the range `[offset, offset + num_tasks)`. 65 """ 66 task = self._next_task 67 self._next_task = (self._next_task + 1) % self._num_tasks 68 return task 69 70 71class _ReplicaDeviceChooser(object): 72 """Class to choose devices for Ops in a replicated training setup. 73 74 This class is not to be used directly by users. See instead 75 `replica_device_setter()` below. 76 """ 77 78 def __init__(self, ps_tasks, ps_device, worker_device, merge_devices, ps_ops, 79 ps_strategy): 80 """Create a new `_ReplicaDeviceChooser`. 81 82 Args: 83 ps_tasks: Number of tasks in the `ps` job. 84 ps_device: String. Name of the `ps` job. 85 worker_device: String. Name of the `worker` job. 86 merge_devices: Boolean. Set to True to allow merging of device specs. 87 ps_ops: List of strings representing `Operation` types that need to be 88 placed on `ps` devices. 89 ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by 90 `ps_ops`), that takes the `Operation` and returns the ps task index to 91 use. 92 """ 93 self._ps_tasks = ps_tasks 94 self._ps_device = ps_device 95 self._worker_device = worker_device 96 self._merge_devices = merge_devices 97 self._ps_ops = ps_ops 98 self._ps_strategy = ps_strategy 99 100 def device_function(self, op): 101 """Choose a device for `op`. 102 103 Args: 104 op: an `Operation`. 105 106 Returns: 107 The device to use for the `Operation`. 108 """ 109 # If we don't return early here, either merge_devices is True, or op.device 110 # is empty (in which case merging is a no-op). So we can always merge below. 111 if not self._merge_devices and op.device: 112 return op.device 113 114 current_device = pydev.DeviceSpec.from_string(op.device or "") 115 116 # The ps_device will be used for specified ops (ps_ops) whenever it is 117 # present and ps_tasks is non-zero. However, its task number will only be 118 # set (using ps_strategy) if there is a job field in ps_device that won't be 119 # changed by the job field (if present) in current_device. 120 node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def 121 if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops: 122 ps_device = pydev.DeviceSpec.from_string(self._ps_device) 123 124 current_job, ps_job = current_device.job, ps_device.job 125 if ps_job and (not current_job or current_job == ps_job): 126 ps_device = ps_device.replace(task=self._ps_strategy(op)) 127 128 ps_device = ps_device.make_merged_spec(current_device) 129 return ps_device.to_string() 130 131 worker_device = pydev.DeviceSpec.from_string(self._worker_device or "") 132 worker_device = worker_device.make_merged_spec(current_device) 133 return worker_device.to_string() 134 135 136@tf_export(v1=["train.replica_device_setter"]) 137def replica_device_setter(ps_tasks=0, 138 ps_device="/job:ps", 139 worker_device="/job:worker", 140 merge_devices=True, 141 cluster=None, 142 ps_ops=None, 143 ps_strategy=None): 144 """Return a `device function` to use when building a Graph for replicas. 145 146 Device Functions are used in `with tf.device(device_function):` statement to 147 automatically assign devices to `Operation` objects as they are constructed, 148 Device constraints are added from the inner-most context first, working 149 outwards. The merging behavior adds constraints to fields that are yet unset 150 by a more inner context. Currently the fields are (job, task, cpu/gpu). 151 152 If `cluster` is `None`, and `ps_tasks` is 0, the returned function is a no-op. 153 Otherwise, the value of `ps_tasks` is derived from `cluster`. 154 155 By default, only Variable ops are placed on ps tasks, and the placement 156 strategy is round-robin over all ps tasks. A custom `ps_strategy` may be used 157 to do more intelligent placement, such as 158 `tf.contrib.training.GreedyLoadBalancingStrategy`. 159 160 For example, 161 162 ```python 163 # To build a cluster with two ps jobs on hosts ps0 and ps1, and 3 worker 164 # jobs on hosts worker0, worker1 and worker2. 165 cluster_spec = { 166 "ps": ["ps0:2222", "ps1:2222"], 167 "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]} 168 with 169 tf.device(tf.compat.v1.train.replica_device_setter(cluster=cluster_spec)): 170 # Build your graph 171 v1 = tf.Variable(...) # assigned to /job:ps/task:0 172 v2 = tf.Variable(...) # assigned to /job:ps/task:1 173 v3 = tf.Variable(...) # assigned to /job:ps/task:0 174 # Run compute 175 ``` 176 177 Args: 178 ps_tasks: Number of tasks in the `ps` job. Ignored if `cluster` is 179 provided. 180 ps_device: String. Device of the `ps` job. If empty no `ps` job is used. 181 Defaults to `ps`. 182 worker_device: String. Device of the `worker` job. If empty no `worker` 183 job is used. 184 merge_devices: `Boolean`. If `True`, merges or only sets a device if the 185 device constraint is completely unset. merges device specification rather 186 than overriding them. 187 cluster: `ClusterDef` proto or `ClusterSpec`. 188 ps_ops: List of strings representing `Operation` types that need to be 189 placed on `ps` devices. If `None`, defaults to `STANDARD_PS_OPS`. 190 ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by 191 `ps_ops`), that takes the `Operation` and returns the ps task index to 192 use. If `None`, defaults to a round-robin strategy across all `ps` 193 devices. 194 195 Returns: 196 A function to pass to `tf.device()`. 197 198 Raises: 199 TypeError if `cluster` is not a dictionary or `ClusterDef` protocol buffer, 200 or if `ps_strategy` is provided but not a callable. 201 """ 202 if cluster is not None: 203 if isinstance(cluster, server_lib.ClusterSpec): 204 cluster_spec = cluster.as_dict() 205 else: 206 cluster_spec = server_lib.ClusterSpec(cluster).as_dict() 207 # Get ps_job_name from ps_device by stripping "/job:". 208 ps_job_name = pydev.DeviceSpec.from_string(ps_device).job 209 if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None: 210 return None 211 ps_tasks = len(cluster_spec[ps_job_name]) 212 213 if ps_tasks == 0: 214 return None 215 216 if ps_ops is None: 217 # TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be 218 # placed in the parameter server. 219 ps_ops = list(STANDARD_PS_OPS) 220 221 if not merge_devices: 222 logging.warning( 223 "DEPRECATION: It is recommended to set merge_devices=true in " 224 "replica_device_setter") 225 if ps_strategy is None: 226 ps_strategy = _RoundRobinStrategy(ps_tasks) 227 if not six.callable(ps_strategy): 228 raise TypeError("ps_strategy must be callable") 229 chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device, 230 merge_devices, ps_ops, ps_strategy) 231 return chooser.device_function 232