# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental support for defining XLA shardings."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as _np  # Avoids becoming a part of public Tensorflow API.

from tensorflow.compiler.tf2xla.python import xla as tf2xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.core.framework import attr_value_pb2


class Sharding(object):
  """A class to support adding sharding attributes to Ops.

  Use the factory constructors and then call apply_to_tensor:
    Sharding.replicate().apply_to_tensor(tensor)
  """

  def __init__(self, proto=None):
    """Do not use this constructor; use the factory functions below."""
    self._proto = proto

  @classmethod
  def replicate(cls):
    """Returns a replicated sharding attribute.

    This causes an op to be computed in its entirety independently on all
    cores in the XLA device.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))

  @classmethod
  def manual(cls):
    """Returns a manual sharding attribute.

    This means the op is manually partitioned by the user and XLA will not
    change the shapes.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL))

  @classmethod
  def assign_device(cls, core):
    """Returns an AssignDevice sharding attribute.

    This causes an op to be computed in its entirety only on one core in
    the XLA device.

    Args:
      core: The core to assign this Op to.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.MAXIMAL,
            tile_assignment_dimensions=[1],
            tile_assignment_devices=[core]))
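
  # Example (illustrative sketch, not part of the original module): pinning an
  # op's output to core 0. `t` is a hypothetical graph-mode tf.Tensor built
  # for an XLA device.
  #
  #   t = Sharding.assign_device(0).apply_to_tensor(t)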

  @classmethod
  def tile(cls, tile_assignment):
    """Returns a Tiled sharding attribute.

    This causes an op to be partially computed on multiple cores in the
    XLA device.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.

    TODO(jmolloy): This concept is nefarious and is not
    something we really want to expose to users (especially as the
    contract for tile_assignment is very strict).
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('Tile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices)))
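
  # Example (illustrative sketch, not part of the original module): tiling a
  # rank-2 tensor across 4 cores in a 2x2 layout. The np.ndarray gives both
  # the tiling topology and the device order; `t` is a hypothetical graph-mode
  # tf.Tensor of matching rank.
  #
  #   assignment = _np.array([[0, 1], [2, 3]])
  #   t = Sharding.tile(assignment).apply_to_tensor(t)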

  @classmethod
  def partial_tile(cls, tile_assignment):
    """Returns a partially tiled sharding attribute.

    This is similar to tile(), but tile_assignment has one more dimension than
    the tensor, and tiles in the last dimension of tile_assignment are
    replicated.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('PartialTile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices),
            replicate_on_last_tile_dim=True))
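
  # Example (illustrative sketch, not part of the original module): splitting
  # the single dimension of a rank-1 tensor into 2 tiles, with each tile
  # replicated across the 2 devices in the last dimension of the assignment.
  # `t` is a hypothetical graph-mode tf.Tensor.
  #
  #   assignment = _np.array([[0, 1], [2, 3]])  # shape [2, 2]: 2 tiles x 2 replicas
  #   t = Sharding.partial_tile(assignment).apply_to_tensor(t)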

  @classmethod
  def split(cls, tensor, split_dimension, num_devices, input_shape=None):
    """Returns a Sharding that splits a tensor across a dimension.

    This creates a Tiled attribute, similar to tile(), but easier to use for
    the common case of tiling a tensor N ways in one dimension.

    Args:
      tensor: A tf.Tensor to split.
      split_dimension: The dimension number to split.
      num_devices: The number of cores to split `tensor` over.
      input_shape: The shape of the original tensor.

    Raises:
      ValueError: The tensor to split was smaller in the split dimension than
        the number of devices to split over.
    """
    if input_shape:
      shape = input_shape
    else:
      shape = tensor.shape.as_list()
    if (shape[split_dimension] is not None and
        shape[split_dimension] < num_devices):
      raise ValueError('Split dimension was smaller than the required number '
                       'of splits: shape=%r, dimension=%r, num_devices=%r' %
                       (shape, split_dimension, num_devices))

    tile_assignment_dims = [1] * len(shape)
    tile_assignment_dims[split_dimension] = num_devices

    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=tile_assignment_dims,
            tile_assignment_devices=range(num_devices)))
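
  # Example (illustrative sketch, not part of the original module): splitting
  # a [8, 128] tensor 4 ways along dimension 0. `t` is a hypothetical
  # graph-mode tf.Tensor.
  #
  #   t = Sharding.split(t, split_dimension=0, num_devices=4).apply_to_tensor(t)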

  def apply_to_tensor(self,
                      tensor,
                      assign_tuple_sharding=False,
                      use_sharding_op=False):
    """Applies this Sharding attribute to `tensor`.

    Args:
      tensor: A tf.Tensor to annotate with this sharding.
      assign_tuple_sharding: If the sharding type should be a tuple.
      use_sharding_op: whether to create a sharding op on `tensor`.

    Returns:
      The tensor with Sharding attribute.
    """
    proto = self._proto
    if use_sharding_op:
      if assign_tuple_sharding:
        proto = self._create_tuple_proto(num_outputs=1)
        tensor = tf2xla.sharding(tensor, sharding=proto.SerializeToString())
      else:
        tensor = tf2xla.sharding(
            tensor, sharding=proto.SerializeToString())
    elif assign_tuple_sharding or len(tensor.op.outputs) > 1:
      proto = self._get_or_create_tuple_proto(tensor.op)
      # We can't mutate an element of old_proto.tuple_shardings, so create
      # a new proto.
      tuple_shardings = list(proto.tuple_shardings)
      tuple_shardings[tensor.value_index] = self._proto
      proto = xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)
    # TODO(jmolloy): This needs to be seriously revisited before declaring this
    # API available for public use.
    # pylint: disable=protected-access
    tensor.op._set_attr('_XlaSharding',
                        attr_value_pb2.AttrValue(s=proto.SerializeToString()))
    return tensor
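
  # Example (illustrative sketch, not part of the original module): attaching
  # a sharding via an explicit XlaSharding op rather than by annotating the
  # producing op directly. `t` is a hypothetical graph-mode tf.Tensor.
  #
  #   t = Sharding.replicate().apply_to_tensor(t, use_sharding_op=True)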

  def apply_to_operation(self, operation):
    """Applies this Sharding attribute to `operation`.

    Args:
      operation: A tf.Operation to add the sharding annotation to.
    """
    attr_value = attr_value_pb2.AttrValue(s=self._proto.SerializeToString())
    # pylint: disable=protected-access
    operation._set_attr('_XlaSharding', attr_value)

  @property
  def proto(self):
    """Return the sharding protobuf of type xla_data_pb2.OpSharding."""
    return self._proto

  def _get_or_create_tuple_proto(self, op):
    try:
      attr = op.get_attr('_XlaSharding')
      proto = xla_data_pb2.OpSharding()
      proto.ParseFromString(attr)
      return proto
    except ValueError:
      return self._create_tuple_proto(len(op.outputs))

  def _create_tuple_proto(self, num_outputs):
    shardings = [
        xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
    ] * num_outputs
    return xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings)


def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
  """Copies a tensor's sharding to another tensor.

  Args:
    from_tensor: Source tensor. Must be the sole output of an op.
    to_tensor: the tensor to annotate with the copy.
    use_sharding_op: whether to create a sharding op on `to_tensor`.

  Returns:
    A tensor with sharding annotation copied from `from_tensor`.
  """
  sharding = get_tensor_sharding(from_tensor)
  if sharding is None:
    return to_tensor

  if use_sharding_op:
    to_tensor = tf2xla.sharding(to_tensor, sharding=sharding)
  attr_value = attr_value_pb2.AttrValue(s=sharding)
  # pylint: disable=protected-access
  to_tensor.op._set_attr('_XlaSharding', attr_value)
  return to_tensor
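
# Example (illustrative sketch, not part of the original module): propagating
# an existing annotation. `a` and `b` are hypothetical graph-mode tf.Tensors,
# where `a` already carries an _XlaSharding attribute.
#
#   b = copy_sharding(a, b, use_sharding_op=True)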

# Helpers for the above factory functions that allow easy application of
# shardings, for example:
#   tensor = xla_sharding.replicate(tensor)


def replicate(tensor, assign_tuple_sharding=False, use_sharding_op=False):
  return Sharding.replicate().apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def assign_device(tensor,
                  device,
                  assign_tuple_sharding=False,
                  use_sharding_op=False):
  """Returns a tensor that has AssignDevice sharding attribute."""
  return Sharding.assign_device(device).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def tile(tensor,
         tile_assignment,
         assign_tuple_sharding=False,
         use_sharding_op=False):
  """Returns a tensor that has tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  return Sharding.tile(tile_assignment).apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding,
      use_sharding_op=use_sharding_op)


def split(tensor,
          split_dimension,
          num_devices,
          assign_tuple_sharding=False,
          use_sharding_op=False,
          input_shape=None):
  """Returns a tensor that is split along the given dimension.

  Args:
    tensor: A tf.Tensor to split.
    split_dimension: The dimension to split.
    num_devices: The number of devices to partition the dimension.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
    input_shape: The full shape of the input tensor.
  """
  return Sharding.split(tensor, split_dimension, num_devices,
                        input_shape).apply_to_tensor(
                            tensor,
                            assign_tuple_sharding=assign_tuple_sharding,
                            use_sharding_op=use_sharding_op)
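
# Example (illustrative sketch, not part of the original module): the
# functional form of the helpers above. `t` is a hypothetical graph-mode
# tf.Tensor with a leading batch dimension.
#
#   t = split(t, split_dimension=0, num_devices=4, use_sharding_op=True)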


def partial_tile(tensor, tile_assignment, use_sharding_op=False):
  """Returns a tensor that has partially tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology. It must have one
      more dimension than tensor, and the last dimension represents partially
      replicated tiles.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  return Sharding.partial_tile(tile_assignment).apply_to_tensor(
      tensor, use_sharding_op=use_sharding_op)


def get_op_sharding(op):
  """Returns sharding attribute of an op.

  Args:
    op: a TensorFlow op.

  Returns:
    The attribute representing XLA sharding on this op.
  """
  try:
    return op.get_attr('_XlaSharding')
  except ValueError:
    return None
  except AttributeError:
    # AttributeError: 'DistributedVarOp' object has no attribute 'get_attr'.
    return None


def get_tensor_sharding(tensor):
  """Returns sharding attribute of a Tensor.

  Args:
    tensor: a Tensor.

  Returns:
    The attribute representing XLA sharding on tensor's op.
  """
  try:
    return get_op_sharding(tensor.op)
  except AttributeError:
    # AttributeError: Tensor.op is meaningless when eager execution is enabled.
    return None


def auto_to_manual_spmd_partition(tensor, manual_sharding):
  """Switches from automatic SPMD partitioning to manual partitioning.

  Converts a full-shaped tensor (to be automatically partitioned by SPMD
  partitioner) to a shard-shaped tensor to be consumed by manually partitioned
  ops.

  Args:
    tensor: A tf.Tensor in full shape.
    manual_sharding: a serialized string of OpSharding to be used in manual
      partitioning.

  Returns:
    A shard-shaped tensor to be consumed by manually partitioned ops.
  """
  return tf2xla.spmd_full_to_shard_shape(
      tensor, manual_sharding=manual_sharding)


def manual_to_auto_spmd_partition(tensor, manual_sharding, full_shape):
  """Switches from manual partitioning to automatic SPMD partitioning.

  Converts a shard-shaped tensor (manually partitioned in SPMD-style) to a
  full-shaped tensor to be partitioned automatically by the SPMD partitioner.

  Args:
    tensor: A tf.Tensor in shard shape.
    manual_sharding: a serialized string of OpSharding to be used in manual
      partitioning.
    full_shape: the shape of tensor before partitioning.

  Returns:
    A full-shaped tensor to be partitioned automatically by the SPMD
    partitioner.
  """
  return tf2xla.spmd_shard_to_full_shape(
      tensor, manual_sharding=manual_sharding, full_shape=full_shape)
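
# Example (illustrative sketch, not part of the original module): a manual
# partitioning round trip. The full tensor's sharding is serialized and used
# both to switch into shard shape and to switch back. `t` is a hypothetical
# graph-mode tf.Tensor split over 2 cores.
#
#   sharding = Sharding.split(t, split_dimension=0, num_devices=2)
#   t = sharding.apply_to_tensor(t, use_sharding_op=True)
#   shard = auto_to_manual_spmd_partition(t, sharding.proto.SerializeToString())
#   # ... manually partitioned computation on `shard` ...
#   full = manual_to_auto_spmd_partition(
#       shard, sharding.proto.SerializeToString(), t.shape)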


def mesh_split_sharding(device_mesh, tensor_split_dims_mapping):
  """Returns a Sharding object representing sharding along multiple dimensions.

  Args:
    device_mesh: An np.ndarray describing the topology of the device mesh and
      each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.

  Raises:
    ValueError: The number of tensor split dimensions is larger than device mesh
      rank.
  """
  permutation = [d for d in tensor_split_dims_mapping if d >= 0]
  if len(permutation) > len(device_mesh.shape):
    raise ValueError(
        'Number of tensor split dimensions (%r) is larger than device mesh '
        'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' %
        (len(permutation), len(
            device_mesh.shape), tensor_split_dims_mapping, device_mesh.shape))
  # Append replicated dimensions to the end.
  transpose_permutation = permutation + [
      d for d in range(len(device_mesh.shape)) if d not in permutation
  ]
  tile_assignment = _np.transpose(device_mesh, transpose_permutation)
  tile_shape = [
      1 if d < 0 else device_mesh.shape[d] for d in tensor_split_dims_mapping
  ]
  partial = len(permutation) < len(device_mesh.shape)
  if partial:
    tile_shape.append(_np.prod(device_mesh.shape) // _np.prod(tile_shape))
  tile_assignment = _np.reshape(tile_assignment, tile_shape)

  if partial:
    return Sharding.partial_tile(tile_assignment)
  return Sharding.tile(tile_assignment)
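
# Example (illustrative sketch, not part of the original module): with a 2x4
# device mesh, mapping tensor dimension 0 to mesh axis 1 and leaving dimension
# 1 unsharded yields a partially tiled sharding; each of the 4 tiles is
# replicated across the 2 devices of mesh axis 0.
#
#   mesh = _np.arange(8).reshape((2, 4))
#   sharding = mesh_split_sharding(mesh, tensor_split_dims_mapping=[1, -1])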


def mesh_split(tensor,
               device_mesh,
               tensor_split_dims_mapping,
               use_sharding_op=False):
  """Returns a tensor that is split along multiple dimensions in a device mesh.

  Args:
    tensor: A tf.Tensor to split.
    device_mesh: An np.ndarray describing the topology of the device mesh and
      each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.
    use_sharding_op: If true, adds a sharding op to set the sharding.

  Raises:
    ValueError: The number of tensor split dimensions is larger than device mesh
      rank.
  """
  sharding = mesh_split_sharding(device_mesh, tensor_split_dims_mapping)
  return sharding.apply_to_tensor(tensor, use_sharding_op=use_sharding_op)
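
# Example (illustrative sketch, not part of the original module): sharding a
# [batch, model] weight over a 2x4 mesh, splitting dimension 0 over mesh axis
# 0 and dimension 1 over mesh axis 1. `w` is a hypothetical graph-mode
# tf.Tensor.
#
#   mesh = _np.arange(8).reshape((2, 4))
#   w = mesh_split(w, mesh, tensor_split_dims_mapping=[0, 1],
#                  use_sharding_op=True)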