# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Experimental support for defining XLA shardings."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as _np  # Avoids becoming a part of public Tensorflow API.

from tensorflow.compiler.tf2xla.python import xla as tf2xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.core.framework import attr_value_pb2


class Sharding(object):
  """A class to support adding sharding attributes to Ops.

  Use the factory constructors and then call apply_to_tensor:
    Sharding.replicate().apply_to_tensor(tensor)
  """

  def __init__(self, proto=None):
    """Do not use this constructor; use the factory functions below."""
    self._proto = proto

  @classmethod
  def replicate(cls):
    """Returns a replicated sharding attribute.

    This causes an op to be computed in its entirety independently on all
    cores in the XLA device.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))

  @classmethod
  def manual(cls):
    """Returns a manual sharding attribute.

    This means the op is manually partitioned by the user and XLA will not
    change the shapes.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL))

  @classmethod
  def assign_device(cls, core):
    """Returns an AssignDevice sharding attribute.

    This causes an op to be computed in its entirety only on one core in
    the XLA device.

    Args:
      core: The core to assign this Op to.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.MAXIMAL,
            tile_assignment_dimensions=[1],
            tile_assignment_devices=[core]))

  @classmethod
  def tile(cls, tile_assignment):
    """Returns a Tiled sharding attribute.

    This causes an op to be partially computed on multiple cores in the
    XLA device.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.

    TODO(jmolloy): This concept is nefarious and is not
    something we really want to expose to users (especially as the
    contract for tile_assignment is very strict).
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('Tile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices)))

  @classmethod
  def partial_tile(cls, tile_assignment):
    """Returns a partially tiled sharding attribute.

    This is similar to tile(), but tile_assignment has one more dimension than
    the tensor, and tiles in the last dimension of tile_assignment are
    replicated.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('PartialTile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices),
            replicate_on_last_tile_dim=True))

  @classmethod
  def split(cls, tensor, split_dimension, num_devices, input_shape=None):
    """Returns a Sharding that splits a tensor across a dimension.

    This creates a Tiled attribute, similar to tile(), but easier to use for
    the common case of tiling a tensor N ways in one dimension.

    Args:
      tensor: A tf.Tensor to split.
      split_dimension: The dimension number to split.
      num_devices: The number of cores to split `tensor` over.
      input_shape: The shape of the original tensor.

    Raises:
      ValueError: The tensor to split was smaller in the split dimension than
        the number of devices to split over.
    """
    if input_shape:
      shape = input_shape
    else:
      shape = tensor.shape.as_list()
    if (shape[split_dimension] is not None and
        shape[split_dimension] < num_devices):
      raise ValueError('Split dimension was smaller than the required number '
                       'of splits: shape=%r, dimension=%r, num_devices=%r' %
                       (shape, split_dimension, num_devices))

    tile_assignment_dims = [1] * len(shape)
    tile_assignment_dims[split_dimension] = num_devices

    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=tile_assignment_dims,
            tile_assignment_devices=range(num_devices)))

  def apply_to_tensor(self,
                      tensor,
                      assign_tuple_sharding=False,
                      use_sharding_op=False):
    """Applies this Sharding attribute to `tensor`.

    Args:
      tensor: A tf.Tensor to annotate with this sharding.
      assign_tuple_sharding: If the sharding type should be a tuple.
      use_sharding_op: Whether to create a sharding op on `tensor`.

    Returns:
      The tensor with Sharding attribute.
    """
    proto = self._proto
    if use_sharding_op:
      if assign_tuple_sharding:
        proto = self._create_tuple_proto(num_outputs=1)
        tensor = tf2xla.sharding(tensor, sharding=proto.SerializeToString())
      else:
        tensor = tf2xla.sharding(
            tensor, sharding=proto.SerializeToString())
    elif assign_tuple_sharding or len(tensor.op.outputs) > 1:
      proto = self._get_or_create_tuple_proto(tensor.op)
      # We can't mutate an element of old_proto.tuple_shardings, so create
      # a new proto.
      tuple_shardings = list(proto.tuple_shardings)
      tuple_shardings[tensor.value_index] = self._proto
      proto = xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)

    # TODO(jmolloy): This needs to be seriously revisited before declaring this
    # API available for public use.
    # pylint: disable=protected-access
    tensor.op._set_attr('_XlaSharding',
                        attr_value_pb2.AttrValue(s=proto.SerializeToString()))
    return tensor

  def apply_to_operation(self, operation):
    """Applies this Sharding attribute to `operation`.

    Args:
      operation: A tf.Operation to add sharding annotation.
    """
    attr_value = attr_value_pb2.AttrValue(s=self._proto.SerializeToString())
    # pylint: disable=protected-access
    operation._set_attr('_XlaSharding', attr_value)

  @property
  def proto(self):
    """Return the sharding protobuf of type xla_data_pb2.OpSharding."""
    return self._proto

  def _get_or_create_tuple_proto(self, op):
    try:
      attr = op.get_attr('_XlaSharding')
      proto = xla_data_pb2.OpSharding()
      proto.ParseFromString(attr)
      return proto
    except ValueError:
      return self._create_tuple_proto(len(op.outputs))

  def _create_tuple_proto(self, num_outputs):
    shardings = [
        xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
    ] * num_outputs
    return xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings)

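
# The helper below is an illustrative sketch, not part of this module's API;
# its name and the `compute_tensor` / `summary_op` arguments are placeholders.
# It shows the pattern from the class docstring: build a Sharding with a
# factory constructor and attach it with apply_to_tensor(), or annotate a
# whole operation with apply_to_operation().
def _example_factory_usage(compute_tensor, summary_op):
  """Replicates one tensor and pins an op to core 0 (illustrative only)."""
  # Compute the op producing `compute_tensor` in full on every core.
  compute_tensor = Sharding.replicate().apply_to_tensor(compute_tensor)
  # Run `summary_op` in its entirety on core 0 only (MAXIMAL sharding).
  Sharding.assign_device(0).apply_to_operation(summary_op)
  return compute_tensor
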

def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
  """Copies a tensor's sharding to another tensor.

  Args:
    from_tensor: Source tensor. Must be the sole output of an op.
    to_tensor: The tensor to annotate with the copy.
    use_sharding_op: Whether to create a sharding op on `to_tensor`.

  Returns:
    A tensor with sharding annotation copied from `from_tensor`.
  """
  sharding = get_tensor_sharding(from_tensor)
  if sharding is None:
    return to_tensor

  if use_sharding_op:
    to_tensor = tf2xla.sharding(to_tensor, sharding=sharding)
  attr_value = attr_value_pb2.AttrValue(s=sharding)
  # pylint: disable=protected-access
  to_tensor.op._set_attr('_XlaSharding', attr_value)
  return to_tensor


# Helpers for the above factory functions that allow easy application of
# shardings, for example:
#   tensor = xla_sharding.replicate(tensor)


def replicate(tensor, assign_tuple_sharding=False, use_sharding_op=False):
  return Sharding.replicate().apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def assign_device(tensor,
                  device,
                  assign_tuple_sharding=False,
                  use_sharding_op=False):
  """Returns a tensor that has AssignDevice sharding attribute."""
  return Sharding.assign_device(device).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def tile(tensor,
         tile_assignment,
         assign_tuple_sharding=False,
         use_sharding_op=False):
  """Returns a tensor that has tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  return Sharding.tile(tile_assignment).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)

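
# Illustrative sketch only; the helper name and its `lhs` / `rhs` arguments are
# placeholders, not part of this module's API. It shows the wrapper form of
# the factories above plus copy_sharding(): tile() annotates a tensor in one
# call, and copy_sharding() propagates that annotation to another tensor.
def _example_wrapper_usage(lhs, rhs):
  """Tiles `lhs` over a 2x2 core grid and copies the sharding to `rhs`."""
  # Devices 0..3 in row-major order: device 0 computes the top-left tile of a
  # rank-2 `lhs`, device 3 the bottom-right tile.
  lhs = tile(lhs, _np.arange(4).reshape([2, 2]), use_sharding_op=True)
  rhs = copy_sharding(lhs, rhs, use_sharding_op=True)
  return lhs, rhs
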

def split(tensor,
          split_dimension,
          num_devices,
          assign_tuple_sharding=False,
          use_sharding_op=False,
          input_shape=None):
  """Returns a tensor that is split along the given dimension.

  Args:
    tensor: A tf.Tensor to split.
    split_dimension: The dimension to split.
    num_devices: The number of devices to partition the dimension over.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    input_shape: The full shape of the input tensor.
  """
  return Sharding.split(tensor, split_dimension, num_devices,
                        input_shape).apply_to_tensor(
                            tensor,
                            assign_tuple_sharding=assign_tuple_sharding,
                            use_sharding_op=use_sharding_op)


def partial_tile(tensor, tile_assignment, use_sharding_op=False):
  """Returns a tensor that has partially tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology. It must have one
      more dimension than `tensor`, and the last dimension represents
      partially replicated tiles.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  return Sharding.partial_tile(tile_assignment).apply_to_tensor(
      tensor, use_sharding_op=use_sharding_op)

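
# Illustrative sketch only; the helper name and its `activations` / `weights`
# arguments are placeholders, not part of this module's API. split() covers
# the common one-dimensional case of tile(), while partial_tile() replicates
# each tile over the devices listed in the trailing dimension of its
# assignment.
def _example_split_and_partial_tile(activations, weights):
  """Shards `activations` four ways and partially tiles `weights`."""
  # Dimension 0 of `activations` is split across four devices; it is assumed
  # to be at least 4 (or statically unknown) at graph construction time.
  activations = split(activations, split_dimension=0, num_devices=4)
  # Assuming `weights` is rank 2: dimension 0 is split into two shards, with
  # shard 0 replicated on devices {0, 1} and shard 1 on devices {2, 3} (the
  # trailing dimension of size 2 in the assignment).
  weights = partial_tile(weights, _np.arange(4).reshape([2, 1, 2]))
  return activations, weights
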

def get_op_sharding(op):
  """Returns the sharding attribute of an op.

  Args:
    op: A TensorFlow op.

  Returns:
    The attribute representing XLA sharding on this op.
  """
  try:
    return op.get_attr('_XlaSharding')
  except ValueError:
    return None
  except AttributeError:
    # AttributeError: 'DistributedVarOp' object has no attribute 'get_attr'.
    return None


def get_tensor_sharding(tensor):
  """Returns the sharding attribute of a Tensor.

  Args:
    tensor: A Tensor.

  Returns:
    The attribute representing XLA sharding on the tensor's op.
  """
  try:
    return get_op_sharding(tensor.op)
  except AttributeError:
    # AttributeError: Tensor.op is meaningless when eager execution is enabled.
    return None


def auto_to_manual_spmd_partition(tensor, manual_sharding):
  """Switches from automatic SPMD partitioning to manual partitioning.

  Converts a full-shaped tensor (to be automatically partitioned by the SPMD
  partitioner) to a shard-shaped tensor to be consumed by manually partitioned
  ops.

  Args:
    tensor: A tf.Tensor in full shape.
    manual_sharding: A serialized string of OpSharding to be used in manual
      partitioning.

  Returns:
    A shard-shaped tensor to be consumed by manually partitioned ops.
  """
  return tf2xla.spmd_full_to_shard_shape(
      tensor, manual_sharding=manual_sharding)


def manual_to_auto_spmd_partition(tensor, manual_sharding, full_shape):
  """Switches from manual partitioning to automatic SPMD partitioning.

  Converts a shard-shaped tensor (manually partitioned in SPMD-style) to a
  full-shaped tensor to be partitioned automatically by the SPMD partitioner.

  Args:
    tensor: A tf.Tensor in shard shape.
    manual_sharding: A serialized string of OpSharding to be used in manual
      partitioning.
    full_shape: The shape of `tensor` before partitioning.

  Returns:
    A full-shaped tensor to be partitioned automatically by the SPMD
    partitioner.
  """
  return tf2xla.spmd_shard_to_full_shape(
      tensor, manual_sharding=manual_sharding, full_shape=full_shape)

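
# Illustrative sketch only; the helper name and its `tensor` / `full_shape`
# arguments are placeholders, not part of this module's API. It shows one
# common way the two functions above are combined: annotate a tensor with a
# sharding, drop to manual partitioning for custom per-shard computation, and
# hand the shard-shaped result back to the automatic SPMD partitioner.
def _example_manual_round_trip(tensor, full_shape):
  """Round-trips `tensor` through manual SPMD partitioning."""
  tensor = split(tensor, split_dimension=0, num_devices=2,
                 use_sharding_op=True)
  # The serialized OpSharding string that was just attached to `tensor`.
  sharding = get_tensor_sharding(tensor)
  shard = auto_to_manual_spmd_partition(tensor, sharding)
  # ... manually partitioned, shard-shaped computation would go here ...
  return manual_to_auto_spmd_partition(shard, sharding, full_shape)
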

def mesh_split_sharding(device_mesh, tensor_split_dims_mapping):
  """Returns a Sharding object representing sharding along multiple dimensions.

  Args:
    device_mesh: An np.ndarray describing the topology of the device mesh,
      where each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for
      tensor dimension i. Use -1 for tensor dimensions that are not sharded.

  Raises:
    ValueError: The number of tensor split dimensions is larger than the
      device mesh rank.
  """
  permutation = [d for d in tensor_split_dims_mapping if d >= 0]
  if len(permutation) > len(device_mesh.shape):
    raise ValueError(
        'Number of tensor split dimensions (%r) is larger than device mesh '
        'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' %
        (len(permutation), len(
            device_mesh.shape), tensor_split_dims_mapping, device_mesh.shape))
  # Append replicated dimensions to the end.
  transpose_permutation = permutation + [
      d for d in range(len(device_mesh.shape)) if d not in permutation
  ]
  tile_assignment = _np.transpose(device_mesh, transpose_permutation)
  tile_shape = [
      1 if d < 0 else device_mesh.shape[d] for d in tensor_split_dims_mapping
  ]
  partial = len(permutation) < len(device_mesh.shape)
  if partial:
    tile_shape.append(_np.prod(device_mesh.shape) // _np.prod(tile_shape))
  tile_assignment = _np.reshape(tile_assignment, tile_shape)

  if partial:
    return Sharding.partial_tile(tile_assignment)
  return Sharding.tile(tile_assignment)


def mesh_split(tensor,
               device_mesh,
               tensor_split_dims_mapping,
               use_sharding_op=False):
  """Returns a tensor that is split along multiple dimensions in a device mesh.

  Args:
    tensor: A tf.Tensor to split.
    device_mesh: An np.ndarray describing the topology of the device mesh,
      where each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for
      tensor dimension i. Use -1 for tensor dimensions that are not sharded.
    use_sharding_op: If true, adds a sharding op to set the sharding.

  Raises:
    ValueError: The number of tensor split dimensions is larger than the
      device mesh rank.
  """
  sharding = mesh_split_sharding(device_mesh, tensor_split_dims_mapping)
  return sharding.apply_to_tensor(tensor, use_sharding_op=use_sharding_op)
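

# Illustrative sketch only; the helper name and its `tensor` argument are
# placeholders, not part of this module's API. It shows how mesh_split() maps
# tensor dimensions onto a device mesh.
def _example_mesh_split(tensor):
  """Shards dimension 0 of a rank-2 `tensor` along axis 0 of a 2x2 mesh."""
  # Devices 0..3 arranged as a 2x2 mesh. With tensor_split_dims_mapping set
  # to [0, -1], tensor dimension 0 is split along mesh axis 0 and tensor
  # dimension 1 is not split, so each shard is replicated across mesh axis 1;
  # mesh_split_sharding() builds a partial_tile() assignment of shape
  # [2, 1, 2] for this case.
  device_mesh = _np.arange(4).reshape([2, 2])
  return mesh_split(tensor, device_mesh, [0, -1], use_sharding_op=True)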