# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Helper library for sharding during TPU compilation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import tensor_shape

_DEFAULT_NUMBER_OF_SHARDS = 1
_DEFAULT_SHARD_DIMENSION = 0


# TODO(b/36777903) change other parts of tpu.py to use this class.
class ShardingPolicy(object):
  """An object used to hold the sharding policy for a Tensor."""

  def __init__(self):
    self._number_of_shards = None
    self._number_of_partitions = 1
    self._shard_dimension = None
    self._frozen = False

  def __str__(self):
    if self.number_of_shards is None or self.shard_dimension is None:
      return "ShardingPolicy(unset)"
    else:
      return ("ShardingPolicy(%d shards dimension %d)" %
              (self.number_of_shards, self.shard_dimension))

  def _fill_default_values(self):
    if self._number_of_shards is None:
      self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS
    if self._shard_dimension is None:
      self._shard_dimension = tensor_shape.as_dimension(
          _DEFAULT_SHARD_DIMENSION)

  def freeze(self):
    """Prevents further modification to the sharding policy.

    Any values that have not been set when freeze is called are set to
    defaults. If the ShardingPolicy is already frozen, this is a NoOp.
    """
    if not self._frozen:
      self._fill_default_values()
      self._frozen = True

  @property
  def number_of_shards(self):
    """Returns the number of shards in the policy or None if unspecified."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards for the current policy.

    If the policy has been frozen then number_of_shards must match the
    existing setting.

    Args:
      number_of_shards: The number of shards to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_shards
        differs from the frozen value; or number_of_shards <= 0.
    """
    if self._frozen:
      if self._number_of_shards != number_of_shards:
        raise ValueError(
            "Can't set sharding policy to use %d shards since it has been "
            "frozen to use %d." % (number_of_shards, self._number_of_shards))
    else:
      if number_of_shards > 0:
        self._number_of_shards = number_of_shards
      else:
        raise ValueError(
            "Can't set sharding policy to use %s shards; value must be >0" %
            str(number_of_shards))

  @property
  def number_of_partitions(self):
    """Returns the number of partitions of the policy."""
    return self._number_of_partitions

  def set_number_of_partitions(self, number_of_partitions):
    """Sets the number of partitions for the current policy.

    If the policy has been frozen then number_of_partitions must match the
    existing setting.

    Args:
      number_of_partitions: The number of partitions to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and number_of_partitions
        differs from the frozen value.
    """
    if self._frozen:
      if self._number_of_partitions != number_of_partitions:
        raise ValueError(
            "Can't set number_of_partitions to %d since it has been frozen to "
            "use %d." % (number_of_partitions, self._number_of_partitions))
    else:
      self._number_of_partitions = number_of_partitions

  @property
  def shard_dimension(self):
    """Returns the shard dimension of the policy or None if unspecified."""
    return self._shard_dimension

  def set_shard_dimension(self, shard_dimension):
    """Sets the shard dimension for the current policy.

    If the policy has been frozen then shard_dimension must match the
    existing setting.

    Args:
      shard_dimension: The shard dimension to use in the policy.

    Raises:
      ValueError: If the policy has been frozen and shard_dimension
        differs from the frozen value, or shard_dimension can't be
        interpreted as a Dimension.
    """
    if self._frozen:
      if self._shard_dimension != shard_dimension:
        raise ValueError(
            "Can't set shard dimension to %d since it has been frozen to "
            "use %d." % (shard_dimension, self._shard_dimension))
    else:
      self._shard_dimension = tensor_shape.as_dimension(shard_dimension)

  def merge(self, other):
    """Merges another policy into the current policy.

    Args:
      other: The policy to merge into this one.

    Raises:
      ValueError: If this policy has been frozen and the merge conflicts with
        the frozen policy.
    """
    if other.number_of_shards is not None:
      self.set_number_of_shards(other.number_of_shards)
    if other.shard_dimension is not None:
      self.set_shard_dimension(other.shard_dimension)

  def get_unpartitioned_shape(self, shape):
    """Returns the shape of an unpartitioned Tensor.

    When given the shape of a 'sharded-size' Tensor, returns the full
    shape of its unpartitioned Tensor.

    Args:
      shape: The shape of the sharded Tensor.

    Returns:
      The shape of the unpartitioned version of the Tensor.

    Raises:
      ValueError: if shape has an unknown sharded dimension.
    """
    shape = tensor_shape.as_shape(shape)
    dims = shape.as_list()
    if (self._shard_dimension is None or self._number_of_partitions is None or
        not dims):
      return None
    if dims[self._shard_dimension] is None:
      raise ValueError("shape %s must have a fixed size for dimension %d "
                       "that is known at graph construction time." %
                       (shape.as_list(), self._shard_dimension))
    if self._number_of_partitions > 1:
      dims[self._shard_dimension] *= self._number_of_partitions
    return tensor_shape.as_shape(dims)

  def get_sharded_shape(self, shape, shard_index=None):
    """Returns the shape of a shard of a full Tensor.

    When given the shape of a 'full-size' Tensor, returns the shape of
    the sub-Tensor after it has been sharded. Freezes the policy if it
    has not yet been frozen.

    Args:
      shape: The shape of the full-size Tensor to be sharded.
      shard_index: The index of the shard whose shape should be returned.
        shard_index can be None for sharding policies that use the same
        shape for every shard.

    Returns:
      The shape of the sharded version of the Tensor.

    Raises:
      ValueError: If shard_index is None when shards are of different
        shapes; or shard_index is not None and
        !(0<=shard_index<number_of_shards); or shape does not have at
        least self.shard_dimension+1 dimensions; or the value of
        shape's shard dimension is not a multiple of
        self.number_of_shards.
    """
    if self._shard_dimension is None or self._number_of_shards is None:
      # Don't raise an error if the config is unset.
      return None
    if shard_index is not None:
      if shard_index < 0 or shard_index >= self.number_of_shards:
        raise ValueError("shard_index %d, but must be in [0,%d)." %
                         (shard_index, self._number_of_shards))
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError("shape must be a specified shape, not unknown")
    if ndims <= self._shard_dimension:
      raise ValueError("shape %s does not contain shard_dimension %d" %
                       (shape.as_list(), self._shard_dimension))
    dims = shape.as_list()
    if dims[self._shard_dimension] is None:
      raise ValueError("shape %s must have a fixed size for dimension %d "
                       "that is known at graph construction time." %
                       (shape.as_list(), self._shard_dimension))
    if (dims[self._shard_dimension] % self._number_of_shards) != 0:
      raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
                       (shape.as_list(), self._number_of_shards,
                        self._shard_dimension))
    dims[self._shard_dimension] //= self._number_of_shards
    return tensor_shape.TensorShape(dims)

  def _unshard_shape(self, shape):
    """Returns the unsharded shape that would generate a given sharded shape.

    Args:
      shape: The sharded shape to unshard.

    Returns:
      The unsharded shape.

    Raises:
      ValueError: if shape is unknown or does not contain
        self.shard_dimension.
      TypeError: if shape is not convertible to a TensorShape.
    """
    shape = tensor_shape.as_shape(shape)
    if self._number_of_shards == 1:
      # Don't do anything when there's only one shard.
      return shape
    ndims = shape.ndims
    if ndims is None:
      raise ValueError("shape must be a specified shape, not unknown")
    if ndims <= self._shard_dimension:
      raise ValueError("shape %s does not contain shard_dimension %d" %
                       (shape.as_list(), self._shard_dimension))
    dims = shape.as_list()
    dims[self._shard_dimension] *= self._number_of_shards
    return tensor_shape.TensorShape(dims)

  def get_unsharded_shape(self, shapes):
    """Returns the shape of an unsharded Tensor given a list of shards.

    When given a list of shapes of shards, returns the shape of the
    unsharded Tensor that would generate the shards. Sets defaults for the
    policy if number_of_shards or shard_dimension is None.

    Args:
      shapes: The shapes of the Tensor shards to be combined.

    Returns:
      The shape of the unsharded version of the Tensor.

    Raises:
      ValueError: if shapes is not a list of length
        self.number_of_shards; or any element of shapes is not a valid
        shape consistent with the sharding policy; or the list of
        shapes is not a valid sharding of a full shape.
      TypeError: if an element of shapes is not convertible to a
        TensorShape.
    """
    self._fill_default_values()
    if len(shapes) != self.number_of_shards:
      raise ValueError(
          "shapes is %s but must be a list of length number_of_shards=%d" % (
              str(shapes), self.number_of_shards))
    unsharded_shapes = [self._unshard_shape(s) for s in shapes]
    for i in xrange(self.number_of_shards - 1):
      if not unsharded_shapes[i].is_compatible_with(
          unsharded_shapes[self.number_of_shards - 1]):
        raise ValueError(
            "sharded shapes %s are not consistent shards of a full shape "
            "sharded %d ways along dimension %d" % (
                str(shapes), self.number_of_shards, self.shard_dimension))
    return unsharded_shapes[0]
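

# Illustrative usage sketch (an editorial addition, not part of the original
# module): a minimal walk-through of the shape arithmetic above. It shards an
# [8, 16] full shape 4 ways along dimension 0 with get_sharded_shape, then
# recovers the full shape from the per-shard shapes with get_unsharded_shape.
if __name__ == "__main__":
  policy = ShardingPolicy()
  policy.set_number_of_shards(4)
  policy.set_shard_dimension(0)
  policy.freeze()  # Fills any unset values with defaults and locks the policy.

  per_shard = policy.get_sharded_shape(tensor_shape.TensorShape([8, 16]))
  print(per_shard)  # (2, 16): dimension 0 is divided by number_of_shards.

  # Combining four such shard shapes recovers the original full shape.
  print(policy.get_unsharded_shape([per_shard] * 4))  # (8, 16)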