1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class to represent a device."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.util.tf_export import tf_export
22
23
# "EPU" currently denotes the TPU embedding device; this is subject to change.
# Device type names that the string parser accepts without a "device:" prefix.
_VALID_DEVICE_TYPES = frozenset({"CPU", "GPU", "TPU", "CUSTOM", "EPU"})
26
27
# ==============================================================================
# == Global Implementation Details =============================================
# ==============================================================================
# Memoizes `DeviceSpecV2._string_to_components`: maps a spec string to the
# parsed `(job, replica, task, device_type, device_index)` tuple.
_STRING_TO_COMPONENTS_CACHE = {}
# Memoizes `DeviceSpecV2._components_to_string`: maps a
# `(job, replica, task, device_type, device_index)` tuple to its string form.
_COMPONENTS_TO_STRING_CACHE = {}
33
34
35def _as_str_or_none(inp):
36  return None if inp is None else str(inp)
37
38
39def _as_int_or_none(inp):
40  return None if inp is None else int(inp)
41
42
43def _as_device_str_or_none(device_type):
44  # For backwards compatibility only, we support lowercase variants of
45  # cpu and gpu but turn them into uppercase here.
46  if device_type in ("cpu", "gpu"):
47    return device_type.upper()
48  return _as_str_or_none(device_type)
49
50
@tf_export("DeviceSpec", v1=[])
class DeviceSpecV2(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec.to_string()):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  With eager execution disabled (by default in TensorFlow 1.x and by calling
  disable_eager_execution() in TensorFlow 2.x), the following syntax
  can be used:

  ```python
  tf.compat.v1.disable_eager_execution()

  # Same as previous
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  # No need to call the .to_string() method.
  with tf.device(device_spec):
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  gpu0_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(DeviceSpec(job="train").to_string()):
    with tf.device(gpu0_spec.to_string()):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
    with tf.device(DeviceSpec(device_type="GPU", device_index=1).to_string()):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.
  """

  __slots__ = ("_job", "_replica", "_task", "_device_type", "_device_index",
               "_as_string", "_hash")

  def __init__(self, job=None, replica=None, task=None, device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Every argument may be `None`, meaning "unspecified". Values that are not
    `None` are coerced: `job` with `str()`, and `replica`, `task` and
    `device_index` with `int()`.

    Args:
      job: string.  Optional job name.
      replica: int.  Optional replica index.
      task: int.  Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU").
        Lowercase "cpu" and "gpu" are accepted and stored uppercased.
      device_index: int.  Optional device index.  If left
        unspecified, device represents 'any' device_index.
    """
    self._job = _as_str_or_none(job)
    self._replica = _as_int_or_none(replica)
    self._task = _as_int_or_none(task)
    self._device_type = _as_device_str_or_none(device_type)
    self._device_index = _as_int_or_none(device_index)
    # The string form and its hash are computed once, up front, and reused by
    # `to_string`, `__eq__` and `__hash__`.
    self._as_string = self._components_to_string(
        job=self._job, replica=self._replica, task=self._task,
        device_type=self._device_type, device_index=self._device_index)
    self._hash = hash(self.to_string())

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
      Components that are unspecified are omitted.
    """
    return self._as_string

  @classmethod
  def from_string(cls, spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if `spec` names more than one device type or contains an
        unknown attribute.
    """
    return cls(*cls._string_to_components(spec))

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    **2.x behavior change**:

    In TensorFlow 1.x, this function mutates its own state and returns itself.
    In 2.x, DeviceSpecs are immutable, and this function will return a
      DeviceSpec which contains the spec.

    * Recommended:

      ```
      # my_spec and my_updated_spec are unrelated.
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_updated_spec = tf.DeviceSpec.from_string("/GPU:0")
      with tf.device(my_updated_spec):
        ...
      ```

    * Will work in 1.x and 2.x (though deprecated in 2.x):

      ```
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_updated_spec = my_spec.parse_from_string("/GPU:0")
      with tf.device(my_updated_spec):
        ...
      ```

    * Will NOT work in 2.x:

      ```
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_spec.parse_from_string("/GPU:0")  # <== Will not update my_spec
      with tf.device(my_spec):
        ...
      ```

    In general, `DeviceSpec.from_string` should completely replace
    `DeviceSpec.parse_from_string`, and `DeviceSpec.replace` should
    completely replace setting attributes directly.

    Args:
      spec: an optional string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.

    Returns:
      The `DeviceSpec`.

    Raises:
      ValueError: if the spec was not valid.
    """
    return self.from_string(spec)

  def make_merged_spec(self, dev):
    """Returns a new DeviceSpec which incorporates `dev`.

    When combining specs, `dev` will take precedence over the current spec.
    So for instance:
    ```
    first_spec = tf.DeviceSpec(job=0, device_type="CPU")
    second_spec = tf.DeviceSpec(device_type="GPU")
    combined_spec = first_spec.make_merged_spec(second_spec)
    ```

    is equivalent to:
    ```
    combined_spec = tf.DeviceSpec(job=0, device_type="GPU")
    ```

    Args:
      dev: a `DeviceSpec`

    Returns:
      A new `DeviceSpec` which combines `self` and `dev`
    """
    return self.__class__(*self._get_combined_properties(dev))

  def replace(self, **kwargs):
    """Convenience method for making a new DeviceSpec by overriding fields.

    For instance:
    ```
    my_spec = DeviceSpec(job="my_job", device_type="CPU")
    my_updated_spec = my_spec.replace(device_type="GPU")
    my_other_spec = my_spec.replace(device_type=None)
    ```

    Args:
      **kwargs: This method takes the same args as the DeviceSpec constructor

    Returns:
      A DeviceSpec with the fields specified in kwargs overridden.
    """
    init_kwargs = dict(
        job=self.job, replica=self.replica, task=self.task,
        device_type=self.device_type, device_index=self.device_index)

    # Explicitly provided kwargs take precedence.
    init_kwargs.update(kwargs)
    return self.__class__(**init_kwargs)

  @property
  def job(self):
    """The job name, or `None` if unspecified."""
    return self._job

  @property
  def replica(self):
    """The replica index, or `None` if unspecified."""
    return self._replica

  @property
  def task(self):
    """The task index, or `None` if unspecified."""
    return self._task

  @property
  def device_type(self):
    """The device type string, or `None` if unspecified."""
    return self._device_type

  @property
  def device_index(self):
    """The device index, or `None` if unspecified."""
    return self._device_index

  def _get_combined_properties(self, dev):
    """Combine the current DeviceSpec with another DeviceSpec.

    The combination of DeviceSpecs gives priority to `dev`: each component is
    taken from `dev` unless it is `None` there, in which case the value from
    `self` is used.

    Args:
      dev: a `DeviceSpec`

    Returns:
      A tuple of (job, replica, task, device_type, device_index) which
      represents the combination of self and dev.
    """
    return (
        dev.job if dev.job is not None else self.job,
        dev.replica if dev.replica is not None else self.replica,
        dev.task if dev.task is not None else self.task,
        dev.device_type if dev.device_type is not None else self.device_type,
        dev.device_index if dev.device_index is not None else self.device_index,
    )

  @staticmethod
  def _string_to_components(spec=None):
    """Stateless portion of device spec string parsing.

    Results are memoized in `_STRING_TO_COMPONENTS_CACHE`, keyed by the
    original `spec` value.

    Args:
      spec: An optional string specifying a device specification.

    Returns:
      The parsed components of `spec` as a tuple
      `(job, replica, task, device_type, device_index)`. `job`, `replica` and
      `task` are returned as the raw substrings of `spec` (or `None`). Note
      that the result of this function must go through attribute setters of
      DeviceSpec, and should therefore NOT be used directly.

    Raises:
      ValueError: if `spec` names more than one device type or contains an
        unknown attribute.
    """
    cached_result = _STRING_TO_COMPONENTS_CACHE.get(spec)
    if cached_result is not None:
      return cached_result

    raw_spec = spec  # keep a copy of the original to update the cache
    job, replica, task, device_type, device_index = None, None, None, None, None

    spec = spec or ""
    splits = [x.split(":") for x in spec.split("/")]
    for y in splits:
      # `str.split` always returns at least one element, so `y[0]` exists.
      ly = len(y)
      # NOTE(taylorrobie): these will go through setters later.
      if ly == 2 and y[0] == "job":
        job = y[1]
      elif ly == 2 and y[0] == "replica":
        replica = y[1]
      elif ly == 2 and y[0] == "task":
        task = y[1]
      elif (ly == 1 or ly == 2) and y[0].upper() in _VALID_DEVICE_TYPES:
        # Short form: "<TYPE>" or "<TYPE>:<index>", matched case-insensitively.
        if device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        device_type = y[0].upper()
        # An index of "*" leaves the device index unspecified.
        if ly == 2 and y[1] != "*":
          device_index = int(y[1])
      elif ly == 3 and y[0] == "device":
        # Long form: "device:<type>:<index>". The type is kept as written.
        if device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        device_type = y[1]
        if y[2] != "*":
          device_index = int(y[2])
      elif y[0]:
        # Segments whose first field is empty (e.g. from the leading "/") are
        # skipped; anything else is an unrecognized attribute.
        raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))

    output = (job, replica, task, device_type, device_index)
    _STRING_TO_COMPONENTS_CACHE[raw_spec] = output
    return output

  @staticmethod
  def _components_to_string(job, replica, task, device_type, device_index):
    """Stateless portion of `to_string` (separated to allow caching).

    Results are memoized in `_COMPONENTS_TO_STRING_CACHE`, keyed by the tuple
    of arguments.

    Args:
      job: The job name string, or `None`.
      replica: The replica index, or `None`.
      task: The task index, or `None`.
      device_type: The device type string, or `None`.
      device_index: The device index, or `None`.

    Returns:
      The concatenation of "/job:...", "/replica:...", "/task:..." and
      "/device:<type>:<index>" for each component that is not `None`. When
      `device_type` is given without `device_index`, the index is "*".
    """
    key = (job, replica, task, device_type, device_index)
    cached_result = _COMPONENTS_TO_STRING_CACHE.get(key)
    if cached_result is not None:
      return cached_result

    parts = []
    if job is not None:
      parts.append("/job:" + job)
    if replica is not None:
      parts.append("/replica:" + str(replica))
    if task is not None:
      parts.append("/task:" + str(task))
    if device_type is not None:
      device_index_string = "*"
      if device_index is not None:
        # Unlike the others, device_index is stored as an int.
        device_index_string = str(device_index)
      parts.append("/device:%s:%s" % (device_type, device_index_string))

    output = "".join(parts)
    _COMPONENTS_TO_STRING_CACHE[key] = output
    return output

  def __eq__(self, other):
    """Checks whether `other` is a DeviceSpec equal to this instance.

    Two specs are compared through their string representations.

    Args:
      other: Another DeviceSpec

    Returns:
      Return `True` if `other` is an instance of this object's class and has
      the same string representation as the current instance.
      Return `False` otherwise.
    """
    return (isinstance(other, self.__class__) and
            self.to_string() == other.to_string())

  def __ne__(self, other):
    """Returns the negation of `self == other`.

    Python 2 does not derive `!=` from `__eq__`, so it is spelled out to keep
    the two operators consistent there. On Python 3 this matches the default.
    """
    return not self == other

  def __hash__(self):
    return self._hash
399
400
@tf_export(v1=["DeviceSpec"])  # pylint: disable=missing-docstring
class DeviceSpecV1(DeviceSpecV2):
  __doc__ = DeviceSpecV2.__doc__
  # All storage slots are inherited from DeviceSpecV2. Listing them again here
  # would create a second set of slot descriptors that shadows the inherited
  # ones, so this class declares no slots of its own.
  __slots__ = ()

  # Each setter coerces the new value the same way the constructor does, then
  # clears the cached string and hash so they are recomputed on next use.

  @DeviceSpecV2.job.setter
  def job(self, job):
    self._job = _as_str_or_none(job)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.replica.setter
  def replica(self, replica):
    self._replica = _as_int_or_none(replica)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.task.setter
  def task(self, task):
    self._task = _as_int_or_none(task)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_type.setter
  def device_type(self, device_type):
    self._device_type = _as_device_str_or_none(device_type)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_index.setter
  def device_index(self, device_index):
    self._device_index = _as_int_or_none(device_index)
    self._as_string, self._hash = None, None

  def __hash__(self):
    # The hash is cleared by the setters; rebuild it lazily from the string
    # form. Note that it therefore changes when a field is reassigned.
    if self._hash is None:
      self._hash = hash(self.to_string())
    return self._hash

  def to_string(self):
    # The cached string is cleared by the setters; rebuild it lazily.
    if self._as_string is None:
      self._as_string = self._components_to_string(
          job=self.job, replica=self.replica, task=self.task,
          device_type=self.device_type, device_index=self.device_index)
    return self._as_string

  def parse_from_string(self, spec):
    # Unlike the V2 implementation, this overwrites this instance's fields
    # (through the setters above) and returns the same instance.
    (self.job, self.replica, self.task, self.device_type, self.device_index
    ) = self._string_to_components(spec)

    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    This updates the instance in place: every component that is not `None` on
    `dev` replaces the corresponding component of this spec.

    Note: Will be removed in TensorFlow 2.x since DeviceSpecs will become
          immutable.

    Args:
      dev: a `DeviceSpec`.
    """
    (self.job, self.replica, self.task, self.device_type, self.device_index
    ) = self._get_combined_properties(dev)

  # Use parent class docstrings for public methods.
  to_string.__doc__ = DeviceSpecV2.to_string.__doc__
  parse_from_string.__doc__ = DeviceSpecV2.parse_from_string.__doc__
464