# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Class to represent a device."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
from tensorflow.python.util.tf_export import tf_export


@tf_export("DeviceSpec")
class DeviceSpec(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  with tf.device(DeviceSpec(job="train")):
    with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0)):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
    with tf.device(DeviceSpec(device_type="GPU", device_index=1)):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.
  """

  def __init__(self, job=None, replica=None, task=None, device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Args:
      job: string. Optional job name. Converted with `str()` when not None.
      replica: int. Optional replica index. Converted with `int()` when not
        None.
      task: int. Optional task index. Converted with `int()` when not None.
      device_type: Optional device type string (e.g. "CPU" or "GPU"). The
        exact lowercase spellings "cpu" and "gpu" are stored in uppercase;
        any other value is stored as given.
      device_index: int. Optional device index, stored as given. If left
        unspecified, device represents 'any' device_index.
    """
    # job, replica and task go through the property setters below, which
    # perform the str()/int() conversion.
    self.job = job
    self.replica = replica
    self.task = task
    if device_type in ("cpu", "gpu"):
      # For backwards compatibility only, we support lowercase variants of
      # cpu and gpu but turn them into uppercase here.
      device_type = device_type.upper()
    self.device_type = device_type
    self.device_index = device_index

  def _clear(self):
    """Reset all five components to None (unspecified)."""
    self._job = None
    self._replica = None
    self._task = None
    self.device_type = None
    self.device_index = None

  @property
  def job(self):
    """The job name as a string, or None if unspecified."""
    return self._job

  @job.setter
  def job(self, job):
    self._job = None if job is None else str(job)

  @property
  def replica(self):
    """The replica index as an int, or None if unspecified."""
    return self._replica

  @replica.setter
  def replica(self, replica):
    self._replica = None if replica is None else int(replica)

  @property
  def task(self):
    """The task index as an int, or None if unspecified."""
    return self._task

  @task.setter
  def task(self, task):
    self._task = None if task is None else int(task)

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    Any components previously held by this object are cleared before parsing
    starts, so after a successful call the object reflects `spec` only.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional. A device index of "*" leaves the device index
      unspecified. The legacy forms "/cpu:<id>" and "/gpu:<id>" (any letter
      case, index optional) are also accepted and stored with an uppercase
      device type.

    Returns:
      The `DeviceSpec`.

    Raises:
      ValueError: if the spec was not valid.
    """
    self._clear()
    for component in spec.split("/"):
      # str.split always yields at least one field, so fields[0] is safe; it
      # is the empty string for empty components (e.g. before the leading "/").
      fields = component.split(":")
      count = len(fields)
      key = fields[0]
      # NOTE(touts): job, replica and task are assigned through the property
      # setters, which convert the parsed text to str/int.
      if count == 2 and key == "job":
        self.job = fields[1]
      elif count == 2 and key == "replica":
        self.replica = fields[1]
      elif count == 2 and key == "task":
        self.task = fields[1]
      elif count in (1, 2) and key.upper() in ("GPU", "CPU"):
        # Legacy "/cpu:0" or "/gpu:0" form without the "device:" prefix.
        if self.device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        self.device_type = key.upper()
        if count == 2 and fields[1] != "*":
          self.device_index = int(fields[1])
      elif count == 3 and key == "device":
        if self.device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        self.device_type = fields[1]
        if fields[2] != "*":
          self.device_index = int(fields[2])
      elif key:
        # Empty components are ignored; anything else is unrecognized.
        raise ValueError("Unknown attribute: '%s' in '%s'" % (key, spec))

    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    Every component of `dev` that is not None overwrites the corresponding
    component of this object; components that are None in `dev` leave this
    object's value untouched.

    Args:
      dev: a `DeviceSpec`.
    """
    if dev.job is not None:
      self.job = dev.job
    if dev.replica is not None:
      self.replica = dev.replica
    if dev.task is not None:
      self.task = dev.task
    if dev.device_type is not None:
      self.device_type = dev.device_type
    if dev.device_index is not None:
      self.device_index = dev.device_index

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Unspecified components are omitted. When a device type is set but the
    device index is not, the index is rendered as "*".

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
    """
    parts = []
    if self.job is not None:
      parts.append("/job:" + self.job)
    if self.replica is not None:
      parts.append("/replica:" + str(self.replica))
    if self.task is not None:
      parts.append("/task:" + str(self.task))
    if self.device_type is not None:
      if self.device_index is None:
        index_text = "*"
      else:
        index_text = str(self.device_index)
      parts.append("/device:%s:%s" % (self.device_type, index_text))
    return "".join(parts)

  @staticmethod
  def from_string(spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if the spec was not valid.
    """
    return DeviceSpec().parse_from_string(spec)


def check_valid(spec):
  """Check that a device spec is valid.

  Args:
    spec: a string.

  Raises:
    ValueError: if the spec is invalid.
  """
  # Parsing is the validation: from_string raises if spec is invalid. The
  # resulting DeviceSpec is intentionally discarded.
  DeviceSpec.from_string(spec)


def canonical_name(device):
  """Returns a canonical name for the given `DeviceSpec` or device name.

  Args:
    device: a `DeviceSpec`, a device spec string, or None.

  Returns:
    The empty string if `device` is None; otherwise the `to_string()` form of
    the `DeviceSpec`, parsing `device` first when it is not already one.
  """
  if device is None:
    return ""
  if not isinstance(device, DeviceSpec):
    device = DeviceSpec.from_string(device)
  return device.to_string()


def merge_device(spec):
  """Returns a device function that merges devices specifications.

  This can be used to merge partial specifications of devices. The
  innermost setting for a device field takes precedence. For example:

    with tf.device(merge_device("/device:GPU:0")):
      # Nodes created here have device "/device:GPU:0"
      with tf.device(merge_device("/job:worker")):
        # Nodes created here have device "/job:worker/device:GPU:0"
        with tf.device(merge_device("/device:CPU:0")):
          # Nodes created here have device "/job:worker/device:CPU:0"
          with tf.device(merge_device("/job:ps")):
            # Nodes created here have device "/job:ps/device:CPU:0"

  Args:
    spec: A `DeviceSpec` or a device spec string (partially) describing the
      device that should be used for all nodes created in the scope of
      the returned device function's with block.

  Returns:
    A device function with the above-described behavior.

  Raises:
    ValueError: if the spec was not valid.
  """
  if not isinstance(spec, DeviceSpec):
    # None and other falsy values are treated as the empty spec.
    spec = DeviceSpec.from_string(spec or "")

  def _device_function(node_def):
    """Merge `node_def.device` on top of `spec` and return the result.

    Args:
      node_def: an object whose `device` attribute is a device spec string
        (or a falsy value, treated as the empty spec).

    Returns:
      A new `DeviceSpec`; the captured `spec` is not modified.
    """
    current_device = DeviceSpec.from_string(node_def.device or "")
    # Work on a copy so repeated calls never mutate the captured spec.
    merged = copy.copy(spec)
    merged.merge_from(current_device)  # current_device takes precedence.
    return merged

  return _device_function