1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ====================================== 15"""Defines the `Topology` class, that describes a TPU fabric topology.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22from six.moves import xrange # pylint: disable=redefined-builtin 23 24from tensorflow.core.protobuf.tpu import topology_pb2 25 26 27def _tpu_device_name(job, task, device): 28 """Returns the device name for the TPU `device` on `task` of `job`.""" 29 if job is None: 30 return "/task:%d/device:TPU:%d" % (task, device) 31 else: 32 return "/job:%s/task:%d/device:TPU:%d" % (job, task, device) 33 34 35def _tpu_host_device_name(job, task): 36 """Returns the device name for the CPU device on `task` of `job`.""" 37 if job is None: 38 return "/task:%d/device:CPU:0" % task 39 else: 40 return "/job:%s/task:%d/device:CPU:0" % (job, task) 41 42 43class Topology(object): 44 """Describes a set of TPU devices. 45 46 Represents both the shape of the physical mesh, and the mapping between 47 TensorFlow TPU devices to physical mesh coordinates. 48 """ 49 50 def __init__(self, serialized=None, mesh_shape=None, device_coordinates=None): 51 """Builds a Topology object. 52 53 If `serialized` is not `None`, the topology is parsed from `serialized` and 54 the other arguments are ignored. Otherwise, the topology is computed from 55 `mesh_shape` and `device_coordinates`. 56 57 Args: 58 serialized: A serialized `TopologyProto`, or `None`. If not `None`, the 59 serialized proto is parsed to discover the topology. 60 mesh_shape: A sequence of 3 positive integers, or `None`. If not `None`, 61 the shape of the TPU topology, in number of cores. Ignored if 62 `serialized` is not `None`. 63 device_coordinates: A rank 3 numpy array that describes the mapping from 64 TensorFlow TPU devices to TPU fabric coordinates, or `None`. Ignored 65 if `serialized is not `None`. 66 67 Raises: 68 ValueError: If `serialized` does not describe a well-formed topology. 69 ValueError: If `serialized` is `None` and `mesh_shape` is not a sequence 70 of 3 positive integers. 71 ValueError: If `serialized` is `None` and `device_coordinates` is not a 72 rank 3 numpy int32 array that describes a valid coordinate mapping. 73 """ 74 75 self._serialized = serialized 76 77 if serialized: 78 self._parse_topology(serialized) 79 else: 80 self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32) 81 self._device_coordinates = np.asarray(device_coordinates, np.int32) 82 if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1): 83 raise ValueError("`mesh_shape` must be a sequence of 3 positive " 84 "entries; got {}".format(self._mesh_shape)) 85 86 if (len(self._device_coordinates.shape) != 3 or 87 self._device_coordinates.shape[2] != len(self._mesh_shape)): 88 raise ValueError("`device_coordinates` must be a rank 3 int32 array " 89 "with minor dimension equal to the mesh shape rank") 90 91 self._topology_tasks, self._topology_devices = self._invert_topology() 92 93 def _parse_topology(self, serialized): 94 """Parses a serialized `TopologyProto` into `self`.""" 95 proto = topology_pb2.TopologyProto() 96 proto.ParseFromString(serialized) 97 98 self._mesh_shape = np.array(proto.mesh_shape, dtype=np.int32) 99 if len(self._mesh_shape) != 3 or any(self._mesh_shape < 1): 100 raise ValueError("`mesh_shape` must be a vector of size 3 with positive " 101 "entries; got {}".format(self._mesh_shape)) 102 103 if proto.num_tasks < 0: 104 raise ValueError("`num_tasks` must be >= 0; got {}".format( 105 proto.num_tasks)) 106 if proto.num_tpu_devices_per_task < 0: 107 raise ValueError("`num_tpu_devices_per_task` must be >= 0; got {}".format( 108 proto.num_tpu_devices_per_task)) 109 110 expected_coordinates_size = ( 111 proto.num_tasks * proto.num_tpu_devices_per_task * len( 112 proto.mesh_shape)) 113 if len(proto.device_coordinates) != expected_coordinates_size: 114 raise ValueError("`device_coordinates` must have shape num_tasks ({}) * " 115 "num_tpu_devices_per_task ({}) * len(mesh_shape) ({}); " 116 "got shape {}".format(proto.num_tasks, 117 proto.num_tpu_devices_per_task, 118 proto.mesh_shape, 119 len(proto.device_coordinates))) 120 121 coords = np.array(proto.device_coordinates, dtype=np.int32) 122 if any(coords < 0): 123 raise ValueError("`device_coordinates` must be >= 0") 124 coords = coords.reshape((proto.num_tasks, proto.num_tpu_devices_per_task, 125 len(proto.mesh_shape))) 126 self._device_coordinates = coords 127 128 def _invert_topology(self): 129 """Inverts a [task,device,axis] topology to [x,y,z] -> task/device maps.""" 130 tasks = np.full(list(self.mesh_shape), -1, dtype=np.int32) 131 devices = np.full(list(self.mesh_shape), -1, dtype=np.int32) 132 for task in xrange(self.device_coordinates.shape[0]): 133 for device in xrange(self.device_coordinates.shape[1]): 134 x, y, z = self.device_coordinates[task, device, :] 135 tasks[x, y, z] = task 136 devices[x, y, z] = device 137 return tasks, devices 138 139 @property 140 def mesh_shape(self): 141 """A rank 1 int32 array describing the shape of the TPU topology.""" 142 return self._mesh_shape 143 144 @property 145 def mesh_rank(self): 146 """Returns the number of dimensions in the mesh.""" 147 return len(self._mesh_shape) 148 149 @property 150 def device_coordinates(self): 151 """Describes the mapping from TPU devices to topology coordinates. 152 153 Returns: 154 A rank 3 int32 array with shape `[tasks, devices, axis]`. 155 `tasks` is the number of tasks in the TPU cluster, `devices` is the number 156 of TPU devices per task, and `axis` is the number of axes in the TPU 157 cluster topology. Each entry gives the `axis`-th coordinate in the 158 topology of a task/device pair. TPU topologies are 3-dimensional, with 159 dimensions `(x, y, core number)`. 160 """ 161 return self._device_coordinates 162 163 def task_ordinal_at_coordinates(self, device_coordinates): 164 """Returns the TensorFlow task number attached to `device_coordinates`. 165 166 Args: 167 device_coordinates: An integer sequence describing a device's physical 168 coordinates in the TPU fabric. 169 170 Returns: 171 Returns the TensorFlow task number that contains the TPU device with those 172 physical coordinates. 173 """ 174 return self._topology_tasks[tuple(device_coordinates)] 175 176 def tpu_device_ordinal_at_coordinates(self, device_coordinates): 177 """Returns the TensorFlow device number at `device_coordinates`. 178 179 Args: 180 device_coordinates: An integer sequence describing a device's physical 181 coordinates in the TPU fabric. 182 183 Returns: 184 Returns the TensorFlow device number within the task corresponding to 185 attached to the device with those physical coordinates. 186 """ 187 return self._topology_devices[tuple(device_coordinates)] 188 189 def cpu_device_name_at_coordinates(self, device_coordinates, job=None): 190 """Returns the CPU device attached to a logical core.""" 191 return _tpu_host_device_name( 192 job, self._topology_tasks[tuple(device_coordinates)]) 193 194 def tpu_device_name_at_coordinates(self, device_coordinates, job=None): 195 """Returns the name of the TPU device assigned to a logical core.""" 196 return _tpu_device_name(job, 197 self._topology_tasks[tuple(device_coordinates)], 198 self._topology_devices[tuple(device_coordinates)]) 199 200 @property 201 def num_tasks(self): 202 """Returns the number of TensorFlow tasks in the TPU slice.""" 203 return self._device_coordinates.shape[0] 204 205 @property 206 def num_tpus_per_task(self): 207 """Returns the number of TPU devices per task in the TPU slice.""" 208 return self._device_coordinates.shape[1] 209 210 def serialized(self): 211 """Returns the serialized form of the topology.""" 212 if self._serialized is None: 213 proto = topology_pb2.TopologyProto() 214 proto.mesh_shape[:] = list(self._mesh_shape) 215 proto.num_tasks = self._device_coordinates.shape[0] 216 proto.num_tpu_devices_per_task = self._device_coordinates.shape[1] 217 proto.device_coordinates.extend(list(self._device_coordinates.flatten())) 218 self._serialized = proto.SerializeToString() 219 220 return self._serialized 221