# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Class to represent a device."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import threading
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["DeviceSpec"])
class DeviceSpec(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  with tf.device(DeviceSpec(job="train")):
    with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0)):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
    with tf.device(DeviceSpec(device_type="GPU", device_index=1)):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.

  Equality and hashing are both defined by the string returned from
  `to_string()`. Instances are mutable, so a `DeviceSpec` should not be
  modified while it is stored in a set or used as a dictionary key.
  """

  def __init__(self, job=None, replica=None, task=None, device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Args:
      job: string.  Optional job name.
      replica: int.  Optional replica index.
      task: int.  Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU")
      device_index: int.  Optional device index.  If left
        unspecified, device represents 'any' device_index.
    """
    self.job = job
    self.replica = replica
    self.task = task
    if device_type == "cpu" or device_type == "gpu":
      # For backwards compatibility only, we support lowercase variants of
      # cpu and gpu but turn them into uppercase here.
      self.device_type = device_type.upper()
    else:
      self.device_type = device_type
    self.device_index = device_index

  def _clear(self):
    """Reset every component of this spec to None."""
    self._job = None
    self._replica = None
    self._task = None
    self.device_type = None
    self.device_index = None

  @property
  def job(self):
    """The job name as a string, or None if unspecified."""
    return self._job

  @job.setter
  def job(self, job):
    if job is not None:
      self._job = str(job)
    else:
      self._job = None

  @property
  def replica(self):
    """The replica index as an int, or None if unspecified."""
    return self._replica

  @replica.setter
  def replica(self, replica):
    if replica is not None:
      self._replica = int(replica)
    else:
      self._replica = None

  @property
  def task(self):
    """The task index as an int, or None if unspecified."""
    return self._task

  @task.setter
  def task(self, task):
    if task is not None:
      self._task = int(task)
    else:
      self._task = None

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    Any components previously held by this object are cleared before parsing.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      The `DeviceSpec`.

    Raises:
      ValueError: if the spec was not valid.
    """
    self._clear()
    for component in spec.split("/"):
      fields = component.split(":")
      count = len(fields)
      key = fields[0]
      # NOTE(touts): we use the property setters here.
      if count == 2 and key == "job":
        self.job = fields[1]
      elif count == 2 and key == "replica":
        self.replica = fields[1]
      elif count == 2 and key == "task":
        self.task = fields[1]
      elif count in (1, 2) and key.upper() in ("GPU", "CPU"):
        # Short form without the "device:" prefix, e.g. "/gpu:0" or "/cpu".
        if self.device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        self.device_type = key.upper()
        if count == 2 and fields[1] != "*":
          self.device_index = int(fields[1])
      elif count == 3 and key == "device":
        if self.device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        self.device_type = fields[1]
        if fields[2] != "*":
          self.device_index = int(fields[2])
      elif key:
        # Components whose first field is empty (e.g. the text before a
        # leading "/") are skipped; anything else is unrecognized.
        raise ValueError("Unknown attribute: '%s' in '%s'" % (key, spec))

    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    Components that are None in `dev` leave the corresponding component of
    this object unchanged.

    Args:
      dev: a `DeviceSpec`.
    """
    if dev.job is not None:
      self.job = dev.job
    if dev.replica is not None:
      self.replica = dev.replica
    if dev.task is not None:
      self.task = dev.task
    if dev.device_type is not None:
      self.device_type = dev.device_type
    if dev.device_index is not None:
      self.device_index = dev.device_index

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Unspecified components are omitted. If a device type is set without a
    device index, the index is rendered as "*".

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
    """
    parts = []
    if self.job is not None:
      parts.append("/job:" + self.job)
    if self.replica is not None:
      parts.append("/replica:" + str(self.replica))
    if self.task is not None:
      parts.append("/task:" + str(self.task))
    if self.device_type is not None:
      device_index_string = "*"
      if self.device_index is not None:
        device_index_string = str(self.device_index)
      parts.append("/device:%s:%s" % (self.device_type, device_index_string))
    return "".join(parts)

  @staticmethod
  def from_string(spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if the spec was not valid.
    """
    return DeviceSpec().parse_from_string(spec)

  def __eq__(self, other):
    """Two specs are equal when their string representations match."""
    if not isinstance(other, DeviceSpec):
      # Let Python try the reflected comparison instead of failing on a
      # missing `to_string` attribute.
      return NotImplemented
    return self.to_string() == other.to_string()

  def __hash__(self):
    # Derived from the current state so that it always agrees with __eq__,
    # including for objects populated after construction (e.g. by
    # `from_string` or `merge_from`).
    return hash(self.to_string())


def check_valid(spec):
  """Check that a device spec is valid.

  Args:
    spec: a string.

  Raises:
    An exception if the spec is invalid.
  """
  # Construct a DeviceSpec.  It will assert a failure if spec is invalid.
  DeviceSpec.from_string(spec)


def canonical_name(device):
  """Returns a canonical name for the given `DeviceSpec` or device name.

  Args:
    device: a `DeviceSpec`, a device spec string, or None.

  Returns:
    The canonical device string; the empty string when `device` is None.
  """
  if device is None:
    return ""
  if not isinstance(device, DeviceSpec):
    device = DeviceSpec.from_string(device)
  return device.to_string()


# Cache from DeviceSpec objects to their corresponding device functions.
# This cache is maintained for correctness, not performance: it makes it
# possible to compare the device function stacks belonging to different
# graphs in a meaningful way.
_cached_device_functions = {}
_cached_device_specs = {}
_cache_lock = threading.Lock()


def merge_device(spec):
  """Returns a device function that merges devices specifications.

  The returned function takes a node definition and combines the device
  already recorded on it with `spec`; fields present on the node win over
  fields from `spec`. This lets partial device specifications be layered,
  with the innermost setting for each field taking precedence. For example:

    with tf.device(merge_device("/device:GPU:0")):
      # Nodes created here have device "/device:GPU:0"
      with tf.device(merge_device("/job:worker")):
        # Nodes created here have device "/job:worker/device:GPU:0"
        with tf.device(merge_device("/device:CPU:0")):
          # Nodes created here have device "/job:worker/device:CPU:0"
          with tf.device(merge_device("/job:ps")):
            # Nodes created here have device "/job:ps/device:CPU:0"

  Parsed specs and the functions built for them are memoized in module-level
  dictionaries, which are accessed while holding `_cache_lock`.

  Args:
    spec: A `DeviceSpec` or a device spec string (partially) describing the
      device that should be used for all nodes created in the scope of
      the returned device function's with block.

  Returns:
    A device function with the above-described behavior.

  Raises:
    ValueError: if the spec was not valid.
  """
  with _cache_lock:
    if isinstance(spec, DeviceSpec):
      resolved_spec = spec
    else:
      # Strings (and None) are parsed once and remembered under the raw key.
      resolved_spec = _cached_device_specs.get(spec, None)
      if resolved_spec is None:
        resolved_spec = DeviceSpec.from_string(spec or "")
        _cached_device_specs[spec] = resolved_spec

    device_function = _cached_device_functions.get(resolved_spec, None)
    if device_function is None:

      def _device_function(node_def):
        node_device = DeviceSpec.from_string(node_def.device or "")
        # Work on a copy so the cached spec itself is never modified.
        merged = copy.copy(resolved_spec)
        merged.merge_from(node_device)  # The node's own device takes precedence.
        return merged

      device_function = _device_function
      _cached_device_functions[resolved_spec] = device_function

    return device_function