1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ======================================
15"""Defines the `Topology` class, that describes a TPU fabric topology."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22from six.moves import xrange  # pylint: disable=redefined-builtin
23
24from tensorflow.core.protobuf.tpu import topology_pb2
25
26
def _tpu_device_name(job, task, device):
  """Builds the TensorFlow device string for a TPU device.

  Args:
    job: The job name, or `None` to omit the `/job:` component.
    task: The integer task index within the job.
    device: The integer TPU device ordinal within the task.

  Returns:
    A device string of the form `/job:<job>/task:<task>/device:TPU:<device>`,
    without the leading `/job:<job>` part when `job` is `None`.
  """
  if job is not None:
    return "/job:%s/task:%d/device:TPU:%d" % (job, task, device)
  return "/task:%d/device:TPU:%d" % (task, device)
33
34
def _tpu_host_device_name(job, task):
  """Builds the TensorFlow device string for the host CPU of a task.

  Args:
    job: The job name, or `None` to omit the `/job:` component.
    task: The integer task index within the job.

  Returns:
    A device string of the form `/job:<job>/task:<task>/device:CPU:0`,
    without the leading `/job:<job>` part when `job` is `None`.
  """
  if job is not None:
    return "/job:%s/task:%d/device:CPU:0" % (job, task)
  return "/task:%d/device:CPU:0" % task
41
42
class Topology(object):
  """Describes a set of TPU devices.

  Represents both the shape of the physical mesh, and the mapping between
  TensorFlow TPU devices to physical mesh coordinates.
  """

  def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None):
    """Builds a Topology object.

    If `serialized` is given (non-empty), the topology is parsed from
    `serialized` and the other arguments are ignored. Otherwise, the topology
    is computed from `mesh_shape` and `device_coordinates`.

    Args:
      serialized: A serialized `TopologyProto`, or `None`. If given, the
        serialized proto is parsed to discover the topology.
      mesh_shape: A sequence of 3 positive integers, or `None`. If not `None`,
        the shape of the TPU topology, in number of cores. Ignored if
        `serialized` is given.
      device_coordinates: A rank 3 numpy array that describes the mapping from
        TensorFlow TPU devices to TPU fabric coordinates, or `None`. Ignored
        if `serialized` is given.

    Raises:
      ValueError: If `serialized` does not describe a well-formed topology.
      ValueError: If `serialized` is not given and `mesh_shape` is not a
        sequence of 3 positive integers.
      ValueError: If `serialized` is not given and `device_coordinates` is not
        a rank 3 numpy int32 array that describes a valid coordinate mapping.
    """

    # Only cache a usable serialized proto. A falsy value (e.g. an empty
    # string) means the topology comes from the explicit arguments, so
    # `serialized()` must encode it lazily instead of returning that value.
    self._serialized = serialized if serialized else None

    if serialized:
      self._parse_topology(serialized)
    else:
      # `np.asarray(None, dtype=np.int32)` raises TypeError, so reject `None`
      # explicitly to honour the documented ValueError contract.
      if mesh_shape is None:
        raise ValueError("`mesh_shape` must be a sequence of 3 positive "
                         "entries; got {}".format(mesh_shape))
      self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32)
      if (self._mesh_shape.ndim != 1 or self._mesh_shape.size != 3 or
          np.any(self._mesh_shape < 1)):
        raise ValueError("`mesh_shape` must be a sequence of 3 positive "
                         "entries; got {}".format(self._mesh_shape))

      if device_coordinates is None:
        raise ValueError("`device_coordinates` must be a rank 3 int32 array "
                         "with minor dimension equal to the mesh shape rank")
      self._device_coordinates = np.asarray(device_coordinates, np.int32)
      if (self._device_coordinates.ndim != 3 or
          self._device_coordinates.shape[2] != len(self._mesh_shape)):
        raise ValueError("`device_coordinates` must be a rank 3 int32 array "
                         "with minor dimension equal to the mesh shape rank")

    self._check_coordinates_in_bounds()
    self._topology_tasks, self._topology_devices = self._invert_topology()

  def _parse_topology(self, serialized):
    """Parses a serialized `TopologyProto` into `self`.

    Sets `self._mesh_shape` and `self._device_coordinates` from the proto.

    Args:
      serialized: A serialized `TopologyProto`.

    Raises:
      ValueError: If the mesh shape is not 3 positive entries, if `num_tasks`
        or `num_tpu_devices_per_task` is negative, if the number of
        coordinates does not match
        `num_tasks * num_tpu_devices_per_task * len(mesh_shape)`, or if any
        coordinate is negative.
    """
    proto = topology_pb2.TopologyProto()
    proto.ParseFromString(serialized)

    self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32)
    if len(self._mesh_shape) != 3 or np.any(self._mesh_shape < 1):
      raise ValueError("`mesh_shape` must be a vector of size 3 with positive "
                       "entries; got {}".format(self._mesh_shape))

    if proto.num_tasks < 0:
      raise ValueError("`num_tasks` must be >= 0; got {}".format(
          proto.num_tasks))
    if proto.num_tpu_devices_per_task < 0:
      raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format(
          proto.num_tpu_devices_per_task))

    mesh_rank = len(proto.mesh_shape)
    expected_coordinates_size = (
        proto.num_tasks * proto.num_tpu_devices_per_task * mesh_rank)
    if len(proto.device_coordinates) != expected_coordinates_size:
      # Report the mesh rank (not the mesh shape itself) in the slot labelled
      # `len(mesh_shape)`.
      raise ValueError("`device_coordinates` must have shape num_tasks ({}) * "
                       "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); "
                       "got shape {}".format(proto.num_tasks,
                                             proto.num_tpu_devices_per_task,
                                             mesh_rank,
                                             len(proto.device_coordinates)))

    coords = np.array(proto.device_coordinates, dtype=np.int32)
    if np.any(coords < 0):
      raise ValueError("`device_coordinates` must be >= 0")
    # The proto stores the coordinates flat; restore [task, device, axis].
    self._device_coordinates = coords.reshape(
        (proto.num_tasks, proto.num_tpu_devices_per_task, mesh_rank))

  def _check_coordinates_in_bounds(self):
    """Verifies that every device coordinate lies inside the mesh.

    Without this check, an out-of-range coordinate would raise `IndexError`
    while inverting the topology, and a negative coordinate would silently
    wrap around to the opposite end of the mesh.

    Raises:
      ValueError: If any coordinate is negative, or is greater than or equal
        to the mesh extent along its axis.
    """
    coords = self._device_coordinates
    if np.any(coords < 0):
      raise ValueError("`device_coordinates` must be >= 0")
    # Broadcasts the [3] mesh shape against the trailing axis of `coords`.
    if np.any(coords >= self._mesh_shape):
      raise ValueError("`device_coordinates` must be less than `mesh_shape` "
                       "({}) along every axis".format(self._mesh_shape))

  def _invert_topology(self):
    """Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps.

    Returns:
      A pair `(tasks, devices)` of int32 arrays with shape `mesh_shape`. Entry
      `[x, y, z]` holds the task number (resp. device number within the task)
      mapped to those coordinates, or -1 if no device is mapped there.
    """
    tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32)
    devices = np.full(list(self.mesh_shape), -1, dtype=np.int32)
    for task, task_coordinates in enumerate(self.device_coordinates):
      for device, (x, y, z) in enumerate(task_coordinates):
        tasks[x, y, z] = task
        devices[x, y, z] = device
    return tasks, devices

  @property
  def mesh_shape(self):
    """A rank 1 int32 array describing the shape of the TPU topology."""
    return self._mesh_shape

  @property
  def mesh_rank(self):
    """Returns the number of dimensions in the mesh."""
    return len(self._mesh_shape)

  @property
  def device_coordinates(self):
    """Describes the mapping from TPU devices to topology coordinates.

    Returns:
      A rank 3 int32 array with shape `[tasks, devices, axis]`.
      `tasks` is the number of tasks in the TPU cluster, `devices` is the number
      of TPU devices per task, and `axis` is the number of axes in the TPU
      cluster topology. Each entry gives the `axis`-th coordinate in the
      topology of a task/device pair. TPU topologies are 3-dimensional, with
      dimensions `(x, y, core number)`.
    """
    return self._device_coordinates

  def task_ordinal_at_coordinates(self, device_coordinates):
    """Returns the TensorFlow task number attached to `device_coordinates`.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.

    Returns:
      Returns the TensorFlow task number that contains the TPU device with those
      physical coordinates.
    """
    return self._topology_tasks[tuple(device_coordinates)]

  def tpu_device_ordinal_at_coordinates(self, device_coordinates):
    """Returns the TensorFlow device number at `device_coordinates`.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.

    Returns:
      Returns the TensorFlow device number, within its task, of the device
      with those physical coordinates.
    """
    return self._topology_devices[tuple(device_coordinates)]

  def cpu_device_name_at_coordinates(self, device_coordinates, job=None):
    """Returns the CPU device attached to a logical core.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.
      job: The job name to embed in the device string, or `None` to omit it.

    Returns:
      The device string of the CPU on the task mapped to those coordinates.
    """
    return _tpu_host_device_name(
        job, self._topology_tasks[tuple(device_coordinates)])

  def tpu_device_name_at_coordinates(self, device_coordinates, job=None):
    """Returns the name of the TPU device assigned to a logical core.

    Args:
      device_coordinates: An integer sequence describing a device's physical
        coordinates in the TPU fabric.
      job: The job name to embed in the device string, or `None` to omit it.

    Returns:
      The device string of the TPU device mapped to those coordinates.
    """
    coordinates = tuple(device_coordinates)
    return _tpu_device_name(job,
                            self._topology_tasks[coordinates],
                            self._topology_devices[coordinates])

  @property
  def num_tasks(self):
    """Returns the number of TensorFlow tasks in the TPU slice."""
    return self._device_coordinates.shape[0]

  @property
  def num_tpus_per_task(self):
    """Returns the number of TPU devices per task in the TPU slice."""
    return self._device_coordinates.shape[1]

  def serialized(self):
    """Returns the serialized form of the topology.

    If the object was built from a serialized proto, that value is returned
    unchanged. Otherwise a `TopologyProto` is built from the mesh shape and
    device coordinates on first use, and the encoded result is cached.
    """
    if self._serialized is None:
      proto = topology_pb2.TopologyProto()
      # `tolist()` yields plain Python ints rather than numpy scalars.
      proto.mesh_shape[:] = self._mesh_shape.tolist()
      proto.num_tasks = self._device_coordinates.shape[0]
      proto.num_tpu_devices_per_task = self._device_coordinates.shape[1]
      proto.device_coordinates.extend(
          self._device_coordinates.flatten().tolist())
      self._serialized = proto.SerializeToString()

    return self._serialized
221