1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Helper library for sharding during TPU compilation."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from six.moves import xrange  # pylint: disable=redefined-builtin
23
24from tensorflow.python.framework import tensor_shape
25
26_DEFAULT_NUMBER_OF_SHARDS = 1
27_DEFAULT_SHARD_DIMENSION = 0
28
29
30# TODO(b/36777903) change other parts of tpu.py to use this class.
31class ShardingPolicy(object):
32  """An object use to hold the sharding policy for a Tensor.
33  """
34
35  def __init__(self):
36    self._number_of_shards = None
37    self._number_of_partitions = 1
38    self._shard_dimension = None
39    self._frozen = False
40
41  def __str__(self):
42    if self.number_of_shards is None or self.shard_dimension is None:
43      return "ShardingPolicy(unset)"
44    else:
45      return ("ShardingPolicy(%d shards dimension %d)" %
46              (self.number_of_shards, self.shard_dimension))
47
48  def _fill_default_values(self):
49    if self._number_of_shards is None:
50      self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
51    if self._shard_dimension is None:
52      self._shard_dimension = tensor_shape.as_dimension(
53          _DEFAULT_SHARD_DIMENSION)
54
55  def freeze(self):
56    """Prevents further modification to the sharding policy.
57
58    Any values that have not been set when freeze is called are set to
59    defaults. If the ShardingPolicy is already frozen, this is a NoOp.
60    """
61    if not self._frozen:
62      self._fill_default_values()
63      self._frozen = True
64
65  @property
66  def number_of_shards(self):
67    """Returns the number of shards in the policy or None if unspecified."""
68    return self._number_of_shards
69
70  def set_number_of_shards(self, number_of_shards):
71    """Sets the number of shards for the current policy.
72
73    If the policy has been frozen then number_of_shards must match the
74    existing setting.
75
76    Args:
77      number_of_shards: The number of shards to use in the policy.
78
79    Raises:
80      ValueError: If the policy has been frozen and number_of_shards
81        differs from the frozen value; or number_of_shards <= 0.
82    """
83    if self._frozen:
84      if self._number_of_shards != number_of_shards:
85        raise ValueError(
86            "Can't set sharding policy to use %d shards since it has been "
87            "frozen to use %d." % (number_of_shards, self._number_of_shards))
88    else:
89      if number_of_shards > 0:
90        self._number_of_shards = number_of_shards
91      else:
92        raise ValueError(
93            "Can't set sharding policy to use %s shards; value must be >0" %
94            str(number_of_shards))
95
96  @property
97  def number_of_partitions(self):
98    """Returns the number of partitions of the policy or None if unspecified."""
99    return self._number_of_partitions
100
101  def set_number_of_partitions(self, number_of_partitions):
102    """Sets the number of partitions for the current policy.
103
104    If the policy has been frozen then shard_dimension must match the
105    existing setting.
106
107    Args:
108      number_of_partitions: The number of partitions to use in the policy.
109
110    Raises:
111      ValueError: If the policy has been frozen and shard_dimension
112        differs from the frozen value.
113    """
114    if self._frozen:
115      if self._number_of_partitions != number_of_partitions:
116        raise ValueError(
117            "Can't set number_of_partitions to %d since it has been frozen to "
118            "use %d." % (number_of_partitions, self._number_of_partitions))
119    else:
120      self._number_of_partitions = number_of_partitions
121
122  @property
123  def shard_dimension(self):
124    """Returns the shard dimension of the policy or None if unspecified."""
125    return self._shard_dimension
126
127  def set_shard_dimension(self, shard_dimension):
128    """Sets the shard dimension for the current policy.
129
130    If the policy has been frozen then shard_dimension must match the
131    existing setting.
132
133    Args:
134      shard_dimension: The shard dimension to use in the policy.
135
136    Raises:
137      ValueError: If the policy has been frozen and shard_dimension
138        differs from the frozen value, or shard_dimension can't be
139        interpreted as a Dimension.
140    """
141    if self._frozen:
142      if self._shard_dimension != shard_dimension:
143        raise ValueError(
144            "Can't set shard dimension to %d since it has been frozen to "
145            "use %d." % (shard_dimension, self._shard_dimension))
146    else:
147      self._shard_dimension = tensor_shape.as_dimension(shard_dimension)
148
149  def merge(self, other):
150    """Merges the policy of another policy into the current policy.
151
152    Args:
153      other: The policy to merge into this one.
154
155    Raises:
156      ValueError: If this policy has been frozen and the merge conflicts with
157      the frozen policy.
158    """
159    if other.number_of_shards is not None:
160      self.set_number_of_shards(other.number_of_shards)
161    if other.shard_dimension is not None:
162      self.set_shard_dimension(other.shard_dimension)
163
164  def get_unpartitioned_shape(self, shape):
165    """Returns the shape of an unpartitioned Tensor.
166
167    When given the shape of a 'sharded-size' Tensor, returns the shape
168    of the full shape of its unpartitioned Tensor.
169
170    Args:
171      shape: The shape of the sharded Tensor.
172
173    Returns:
174      The shape of the unpartitioned version of the Tensor.
175
176    Raises:
177      ValueError: if shape has unknown sharded dimension
178    """
179    shape = tensor_shape.as_shape(shape)
180    dims = shape.as_list()
181    if (self._shard_dimension is None or self._number_of_partitions is None or
182        not dims):
183      return None
184    if dims[self._shard_dimension] is None:
185      raise ValueError("shape %s must have a fixed size for dimension %d "
186                       "that is known at graph construction time." %
187                       (shape.as_list(), self._shard_dimension))
188    if self._number_of_partitions > 1:
189      dims[self._shard_dimension] *= self._number_of_partitions
190    return tensor_shape.as_shape(dims)
191
192  def get_sharded_shape(self, shape, shard_index=None):
193    """Returns the shape of a shard of a full Tensor.
194
195    When given the shape of a 'full-size' Tensor, returns the shape of
196    the sub-Tensor after it has been sharded. Freezes the policy if it
197    has not yet been frozen.
198
199    Args:
200      shape: The shape of the full-size Tensor to be sharded.
201      shard_index: The index of the shard whose shape should be returned.
202        shard_index can be None for sharding policies that use the same
203        shape for every shard.
204
205    Returns:
206      The shape of the sharded version of the Tensor.
207
208    Raises:
209      ValueError: If shard_index is None when shards are of different
210        shapes; or shard_index is not None and
211        !(0<=shard_index<number_of_shards); or shape does not have at
212        least self.shard_dimension+1 dimensions; or the value of
213        shape's shard dimension is not a multiple of
214        self.number_of_shards
215    """
216    if self._shard_dimension is None or self._number_of_shards is None:
217      # Don't raise an error if the config is unset.
218      return None
219    if shard_index is not None:
220      if shard_index < 0 or shard_index >= self.number_of_shards:
221        raise ValueError("shard_index %d, but must be in [0,%d)." %
222                         (shard_index, self._number_of_shards))
223    shape = tensor_shape.as_shape(shape)
224    if self._number_of_shards == 1:
225      # Don't do anything when there's only one shard.
226      return shape
227    ndims = shape.ndims
228    if ndims is None:
229      raise ValueError("shape must be a specified shape not Unknown")
230    if ndims <= self._shard_dimension:
231      raise ValueError("shape %s does not contain shard_dimension %d" %
232                       (shape.as_list(), self._shard_dimension))
233    dims = shape.as_list()
234    if dims[self._shard_dimension] is None:
235      raise ValueError("shape %s must have a fixed size for dimension %d "
236                       "that is known at graph construction time." %
237                       (shape.as_list(), self._shard_dimension))
238    if (dims[self._shard_dimension] % self._number_of_shards) != 0:
239      raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
240                       (shape.as_list(), self._number_of_shards,
241                        self._shard_dimension))
242    dims[self._shard_dimension] //= self._number_of_shards
243    return tensor_shape.TensorShape(dims)
244
245  def _unshard_shape(self, shape):
246    """Return the unsharded shape that would generate a given sharded shape.
247
248    Args:
249      shape: the sharded shape to unshard
250
251    Returns:
252      The unsharded shape.
253
254    Raises:
255      ValueError: if shape is unknown or does not contain
256        self.shard_dimension
257      TypeError: if shape is not convertible to a TensorShape
258    """
259    shape = tensor_shape.as_shape(shape)
260    if self._number_of_shards == 1:
261      # Don't do anything when there's only one shard.
262      return shape
263    ndims = shape.ndims
264    if ndims is None:
265      raise ValueError("shape must be a specified shape not Unknown")
266    if ndims <= self._shard_dimension:
267      raise ValueError("shape %s does not contain shard_dimension %d" %
268                       (shape.as_list(), self._shard_dimension))
269    dims = shape.as_list()
270    dims[self._shard_dimension] *= self._number_of_shards
271    return tensor_shape.TensorShape(dims)
272
273  def get_unsharded_shape(self, shapes):
274    """Returns the shape of an unsharded Tensor given a list of shards.
275
276    When given a list of shapes of shards, returns the shape of the
277    unsharded Tensor that would generate the shards. Sets defaults for the
278    policy if number_of_shards or shard_dimension is None.
279
280    Args:
281      shapes: The shapes of the Tensor shards to be combined.
282
283    Returns:
284      The shape of the unsharded version of the Tensor.
285
286    Raises:
287      ValueError: if shapes is not a list of length
288        self.number_of_shards; or any element of shapes is not a valid
289        shape consistent with the sharding policy; or the list of
290        shapes is not a valid sharding of a full shape.
291      TypeError: if an element of shapes is not convertible to a
292        TensorShape
293    """
294    self._fill_default_values()
295    if len(shapes) != self.number_of_shards:
296      raise ValueError(
297          "shapes is %s but must be a list of length number_of_shards=%d" % (
298              str(shapes), self.number_of_shards))
299    unsharded_shapes = [self._unshard_shape(s) for s in shapes]
300    for i in xrange(self.number_of_shards - 1):
301      if not unsharded_shapes[i].is_compatible_with(
302          unsharded_shapes[self.number_of_shards - 1]):
303        raise ValueError(
304            "sharded shapes %s are not consistent shards of a full shape "
305            "sharded %d ways along dimension %d" % (
306                str(shapes), self.number_of_shards, self.shard_dimension))
307    return unsharded_shapes[0]
308