1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Class to represent a device."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import threading
23from tensorflow.python.util.tf_export import tf_export
24
25
@tf_export(v1=["DeviceSpec"])
class DeviceSpec(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  with tf.device(DeviceSpec(job="train")):
    with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0)):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
    with tf.device(DeviceSpec(device_type="GPU", device_index=1)):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.

  A `DeviceSpec` is mutable. Equality and hashing are both derived from the
  current value of `to_string()`. Do not mutate a spec while it is stored as a
  dictionary key or set member.
  """

  def __init__(self, job=None, replica=None, task=None, device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Args:
      job: string.  Optional job name.
      replica: int.  Optional replica index.
      task: int.  Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU"). The
        lowercase spellings "cpu" and "gpu" are accepted and upper-cased.
      device_index: int.  Optional device index.  If left
        unspecified, device represents 'any' device_index.
    """
    self.job = job
    self.replica = replica
    self.task = task
    if device_type == "cpu" or device_type == "gpu":
      # For backwards compatibility only, we support lowercase variants of
      # cpu and gpu but turn them into uppercase here.
      self.device_type = device_type.upper()
    else:
      self.device_type = device_type
    self.device_index = device_index
    # No hash is cached here. __hash__ derives it from the current state, so
    # it stays consistent with __eq__ after later mutation.

  def _clear(self):
    """Resets every component of this spec to None."""
    self._job = None
    self._replica = None
    self._task = None
    self.device_type = None
    self.device_index = None

  @property
  def job(self):
    return self._job

  @job.setter
  def job(self, job):
    # Any non-None value is stored as its string form.
    if job is not None:
      self._job = str(job)
    else:
      self._job = None

  @property
  def replica(self):
    return self._replica

  @replica.setter
  def replica(self, replica):
    # Any non-None value is coerced with int(); non-numeric strings raise
    # ValueError.
    if replica is not None:
      self._replica = int(replica)
    else:
      self._replica = None

  @property
  def task(self):
    return self._task

  @task.setter
  def task(self, task):
    # Any non-None value is coerced with int(); non-numeric strings raise
    # ValueError.
    if task is not None:
      self._task = int(task)
    else:
      self._task = None

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    All existing components of this spec are cleared first. The spec is then
    mutated in place.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      The `DeviceSpec` (i.e. `self`).

    Raises:
      ValueError: if the spec was not valid.
    """
    self._clear()
    splits = [x.split(":") for x in spec.split("/")]
    for y in splits:
      ly = len(y)
      if y:
        # NOTE(touts): we use the property setters here, so replica/task
        # values are coerced with int() and job with str().
        if ly == 2 and y[0] == "job":
          self.job = y[1]
        elif ly == 2 and y[0] == "replica":
          self.replica = y[1]
        elif ly == 2 and y[0] == "task":
          self.task = y[1]
        elif ((ly == 1 or ly == 2) and
              ((y[0].upper() == "GPU") or (y[0].upper() == "CPU"))):
          # Legacy short form: "CPU", "gpu:0", "GPU:*", ...
          if self.device_type is not None:
            raise ValueError("Cannot specify multiple device types: %s" % spec)
          self.device_type = y[0].upper()
          if ly == 2 and y[1] != "*":
            self.device_index = int(y[1])
        elif ly == 3 and y[0] == "device":
          # Full form "device:<type>:<index>". The type is stored verbatim
          # (not upper-cased); "*" leaves the index unspecified.
          if self.device_type is not None:
            raise ValueError("Cannot specify multiple device types: %s" % spec)
          self.device_type = y[1]
          if y[2] != "*":
            self.device_index = int(y[2])
        elif ly and y[0] != "":  # pylint: disable=g-explicit-bool-comparison
          # Empty pieces (from the leading "/" or "//") are ignored; anything
          # else is unrecognized.
          raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))

    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    Every component that is set (not None) on `dev` overwrites the matching
    component of this spec. Components unset on `dev` are left unchanged.

    Args:
      dev: a `DeviceSpec`.
    """
    if dev.job is not None:
      self.job = dev.job
    if dev.replica is not None:
      self.replica = dev.replica
    if dev.task is not None:
      self.task = dev.task
    if dev.device_type is not None:
      self.device_type = dev.device_type
    if dev.device_index is not None:
      self.device_index = dev.device_index

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Unset components are omitted. An unset device index is rendered as "*"
    when a device type is present.

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
    """
    dev = ""
    if self.job is not None:
      dev += "/job:" + self.job
    if self.replica is not None:
      dev += "/replica:" + str(self.replica)
    if self.task is not None:
      dev += "/task:" + str(self.task)
    if self.device_type is not None:
      device_index_string = "*"
      if self.device_index is not None:
        device_index_string = str(self.device_index)
      dev += "/device:%s:%s" % (self.device_type, device_index_string)
    return dev

  @staticmethod
  def from_string(spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if the spec was not valid.
    """
    return DeviceSpec().parse_from_string(spec)

  def __eq__(self, other):
    # Comparing against a non-DeviceSpec (e.g. a str or None) used to raise
    # AttributeError. Defer to Python's default comparison instead.
    if not isinstance(other, DeviceSpec):
      return NotImplemented
    return self.to_string() == other.to_string()

  def __ne__(self, other):
    # Python 2 does not derive `!=` from __eq__, so spell it out.
    result = self.__eq__(other)
    if result is NotImplemented:
      return result
    return not result

  def __hash__(self):
    # Recomputed on each call so it always agrees with __eq__. A hash cached
    # in __init__ goes stale once the spec is mutated through its setters,
    # parse_from_string() or merge_from(); for example, every from_string()
    # result would otherwise hash like an empty spec.
    return hash(self.to_string())
239
240
def check_valid(spec):
  """Validates a device spec string by parsing it.

  Args:
    spec: a device spec string, e.g. "/job:ps/task:0/device:GPU:0".

  Raises:
    ValueError: if `spec` is not a valid device spec.
  """
  # Parsing is the check itself: DeviceSpec.from_string raises on malformed
  # input. The parsed result is not needed, so it is discarded.
  _ = DeviceSpec.from_string(spec)
252
253
def canonical_name(device):
  """Returns the canonical string form of a device.

  Args:
    device: a `DeviceSpec`, a device spec string, or None.

  Returns:
    "" when `device` is None. Otherwise the `to_string()` of `device`, after
    first parsing it into a `DeviceSpec` if it is a string.

  Raises:
    ValueError: if `device` is a string that is not a valid device spec.
  """
  if device is None:
    return ""
  spec = (device if isinstance(device, DeviceSpec)
          else DeviceSpec.from_string(device))
  return spec.to_string()
263
264
# Module-level caches used by merge_device(). All access goes through
# _cache_lock.
#
# _cached_device_functions maps each DeviceSpec to the device function built
# for it. This cache is maintained for correctness, not performance. Handing
# back the same function object for equal specs makes it possible to compare
# the device function stacks belonging to different graphs in a meaningful
# way.
_cached_device_functions = dict()
# _cached_device_specs maps each raw (non-DeviceSpec) spec argument, such as
# a device string, to the DeviceSpec parsed from it.
_cached_device_specs = dict()
_cache_lock = threading.Lock()
272
273
def merge_device(spec):
  """Returns a device function that merges device specifications.

  This can be used to merge partial specifications of devices. The
  innermost setting for a device field takes precedence. For example:

    with tf.device(merge_device("/device:GPU:0")):
      # Nodes created here have device "/device:GPU:0"
      with tf.device(merge_device("/job:worker")):
        # Nodes created here have device "/job:worker/device:GPU:0"
        with tf.device(merge_device("/device:CPU:0")):
          # Nodes created here have device "/job:worker/device:CPU:0"
          with tf.device(merge_device("/job:ps")):
            # Nodes created here have device "/job:ps/device:CPU:0"

  Equal specs yield the very same function object. Both the parsed spec and
  the function are cached module-wide under a lock.

  Args:
    spec: A `DeviceSpec` or a device spec string (partially) describing the
      device that should be used for all nodes created in the scope of
      the returned device function's with block.

  Returns:
    A device function with the above-described behavior. It takes a node_def
    and returns a new `DeviceSpec`: a copy of `spec` in which every field
    already set in `node_def.device` takes precedence.

  Raises:
    ValueError: if the spec was not valid.
  """
  with _cache_lock:
    # Resolve the argument to a DeviceSpec, parsing each raw spec only once.
    if isinstance(spec, DeviceSpec):
      device_spec = spec
    else:
      device_spec = _cached_device_specs.get(spec, None)
      if device_spec is None:
        device_spec = DeviceSpec.from_string(spec or "")
        _cached_device_specs[spec] = device_spec

    existing = _cached_device_functions.get(device_spec, None)
    if existing is not None:
      return existing

    def _device_function(node_def):
      # Fields already present on the node override this scope's spec. The
      # cached spec is copied so it is never mutated.
      node_device = DeviceSpec.from_string(node_def.device or "")
      merged = copy.copy(device_spec)
      merged.merge_from(node_device)
      return merged

    _cached_device_functions[device_spec] = _device_function
    return _device_function
321