1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Class to represent a device."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22from tensorflow.python.util.tf_export import tf_export
23
24
25@tf_export("DeviceSpec")
class DeviceSpec(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  with tf.device(DeviceSpec(job="train")):
    with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0)):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
      ...
    with tf.device(DeviceSpec(device_type="GPU", device_index=1)):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
      ...
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.
  """

  def __init__(self, job=None, replica=None, task=None, device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Args:
      job: string.  Optional job name.
      replica: int.  Optional replica index.
      task: int.  Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU").
        The lowercase spellings "cpu" and "gpu" are upper-cased.
      device_index: int.  Optional device index.  If left
        unspecified, device represents 'any' device_index.
    """
    self.job = job
    self.replica = replica
    self.task = task
    self.device_type = self._canonical_device_type(device_type)
    self.device_index = device_index

  @staticmethod
  def _canonical_device_type(device_type):
    """Upper-cases the lowercase device types "cpu" and "gpu".

    For backwards compatibility only, lowercase variants of cpu and gpu are
    accepted and turned into uppercase. Any other value (including None) is
    returned unchanged, so other device types stay case-sensitive.
    """
    if device_type == "cpu" or device_type == "gpu":
      return device_type.upper()
    return device_type

  def _clear(self):
    """Reset every component of this spec to None."""
    self._job = None
    self._replica = None
    self._task = None
    self.device_type = None
    self.device_index = None

  @property
  def job(self):
    """The job name as a string, or None if unspecified."""
    return self._job

  @job.setter
  def job(self, job):
    # Coerce to str so that to_string() can concatenate it directly.
    if job is not None:
      self._job = str(job)
    else:
      self._job = None

  @property
  def replica(self):
    """The replica index as an int, or None if unspecified."""
    return self._replica

  @replica.setter
  def replica(self, replica):
    # int() also accepts the string form produced while parsing a spec.
    if replica is not None:
      self._replica = int(replica)
    else:
      self._replica = None

  @property
  def task(self):
    """The task index as an int, or None if unspecified."""
    return self._task

  @task.setter
  def task(self, task):
    # int() also accepts the string form produced while parsing a spec.
    if task is not None:
      self._task = int(task)
    else:
      self._task = None

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    Every component of this spec is replaced: components absent from `spec`
    are reset to None. If `spec` is invalid, this object is left unchanged.

    Args:
      spec: a string of the form
        /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>
        All entries are optional. The short forms `/cpu:<id>` and `/gpu:<id>`
        (any letter case, `:<id>` optional) are also accepted. A device index
        of `*` means "any device index". At most one device type may be
        given. Lowercase "cpu"/"gpu" device types are upper-cased.

    Returns:
      This `DeviceSpec` (`self`).

    Raises:
      ValueError: if the spec was not valid (an unknown component, more than
        one device type, or a non-integer replica, task or device index).
    """
    # Parse into locals first so that a ValueError part-way through does not
    # leave this object cleared and half-populated.
    job = replica = task = device_type = device_index = None
    for component in spec.split("/"):
      y = component.split(":")
      ly = len(y)
      if ly == 2 and y[0] == "job":
        job = y[1]
      elif ly == 2 and y[0] == "replica":
        replica = int(y[1])
      elif ly == 2 and y[0] == "task":
        task = int(y[1])
      elif ly in (1, 2) and y[0].upper() in ("CPU", "GPU"):
        # Short form: /cpu:<id>, /gpu:<id>, /cpu, /gpu.
        if device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        device_type = y[0].upper()
        if ly == 2 and y[1] != "*":
          device_index = int(y[1])
      elif ly == 3 and y[0] == "device":
        if device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        # Normalize like the constructor so "/device:gpu:0" == "/device:GPU:0".
        device_type = self._canonical_device_type(y[1])
        if y[2] != "*":
          device_index = int(y[2])
      elif y[0]:
        # Empty components (leading "/" or "//") are ignored; anything else
        # that did not match above is an error.
        raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))

    # Commit only after the whole string parsed successfully.
    self.job = job
    self.replica = replica
    self.task = task
    self.device_type = device_type
    self.device_index = device_index
    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    Each component that is not None in `dev` overwrites the corresponding
    component of this spec; components that are None in `dev` are left as is.

    Args:
      dev: a `DeviceSpec`.
    """
    if dev.job is not None:
      self.job = dev.job
    if dev.replica is not None:
      self.replica = dev.replica
    if dev.task is not None:
      self.task = dev.task
    if dev.device_type is not None:
      self.device_type = dev.device_type
    if dev.device_index is not None:
      self.device_index = dev.device_index

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Components that are None are omitted. If a device type is set but the
    device index is not, the index is rendered as `*`.

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
    """
    dev = ""
    if self.job is not None:
      dev += "/job:" + self.job
    if self.replica is not None:
      dev += "/replica:" + str(self.replica)
    if self.task is not None:
      dev += "/task:" + str(self.task)
    if self.device_type is not None:
      device_index_string = "*"
      if self.device_index is not None:
        device_index_string = str(self.device_index)
      dev += "/device:%s:%s" % (self.device_type, device_index_string)
    return dev

  @staticmethod
  def from_string(spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
        /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>
        (or the short forms `/cpu:<id>` / `/gpu:<id>`). All entries are
        optional and at most one device type may be given.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if the spec was not valid.
    """
    return DeviceSpec().parse_from_string(spec)
231
232
def check_valid(spec):
  """Verify that `spec` is a well-formed device spec string.

  Args:
    spec: a string.

  Raises:
    ValueError: if `spec` is not a valid device spec. Non-string inputs may
      raise other exceptions while being parsed.
  """
  # Parsing doubles as validation; the resulting DeviceSpec is discarded.
  DeviceSpec().parse_from_string(spec)
244
245
def canonical_name(device):
  """Return the canonical string form of a device.

  Args:
    device: a `DeviceSpec`, a device spec string, or None.

  Returns:
    "" when `device` is None. Otherwise the `to_string()` form of the spec,
    after parsing `device` if it is not already a `DeviceSpec`.

  Raises:
    ValueError: if `device` is a string that is not a valid device spec.
  """
  if device is None:
    return ""
  spec = (device if isinstance(device, DeviceSpec)
          else DeviceSpec.from_string(device))
  return spec.to_string()
255
256
def merge_device(spec):
  """Returns a device function that merges device specifications.

  This can be used to merge partial specifications of devices. The
  innermost setting for a device field takes precedence. For example:

    with tf.device(merge_device("/device:GPU:0")):
      # Nodes created here have device "/device:GPU:0"
      with tf.device(merge_device("/job:worker")):
        # Nodes created here have device "/job:worker/device:GPU:0"
        with tf.device(merge_device("/device:CPU:0")):
          # Nodes created here have device "/job:worker/device:CPU:0"
          with tf.device(merge_device("/job:ps")):
            # Nodes created here have device "/job:ps/device:CPU:0"

  Args:
    spec: A `DeviceSpec` or a device spec string (partially) describing the
      device that should be used for all nodes created in the scope of
      the returned device function's with block. A falsy value (None or "")
      is treated as an empty spec.

  Returns:
    A device function with the above-described behavior. It takes an object
    with a `device` attribute (a device spec string, possibly empty or None)
    and returns a new `DeviceSpec`: a copy of `spec` in which every field
    already set on that object's device overrides `spec`'s value.

  Raises:
    ValueError: if the spec was not valid.
  """
  if not isinstance(spec, DeviceSpec):
    spec = DeviceSpec.from_string(spec or "")
  def _device_function(node_def):
    current_device = DeviceSpec.from_string(node_def.device or "")
    # Copy so the shared `spec` captured by this closure is never mutated.
    copy_spec = copy.copy(spec)
    copy_spec.merge_from(current_device)  # current_device takes precedence.
    return copy_spec
  return _device_function
291