# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ShardedVariable class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math
import numpy as np

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import save_context
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given `shape` and returns the partition results.

    Example of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32, axis=0)
    print(partitions)  # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.Dtype` indicating the type of the partition value.
      axis: The axis to partition along. Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where the i-th value corresponds to the i-th axis.
    """
    raise NotImplementedError


@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> partitions
  [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)

  """

  def __init__(self, num_shards):
    """Creates a new `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype
    result = [1] * len(shape)
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result


@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least `min_shard_bytes`, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  `max_shards`.

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if min_shard_bytes < 1:
      raise ValueError('min_shard_bytes must be positive, got: %r' %
                       min_shard_bytes)
    if max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)


@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large as
  possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> partitions
  [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: The maximum number of shards (an `int`); takes precedence
        over `max_shard_bytes`.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if max_shard_bytes < 1:
      raise ValueError('max_shard_bytes must be positive, got: %r' %
                       max_shard_bytes)
    if max_shards and max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)


class ShardedVariableSpec(type_spec.TypeSpec):
  """Type specification for a `ShardedVariable`."""

  __slots__ = ['_variable_specs']

  value_type = property(lambda self: ShardedVariable)

  def __init__(self, *variable_specs):
    self._variable_specs = tuple(variable_specs)

  def _serialize(self):
    return self._variable_specs

  @property
  def _component_specs(self):
    return self._variable_specs

  def _to_components(self, value):
    return value.variables

  def _from_components(self, variables):
    return ShardedVariable(variables)


class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUShardedVariable can't be a CompositeTensor.

  def __init__(self, variables, name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
        tf.Variable(..., shape=(10, 100), dtype=tf.float32),
        tf.Variable(..., shape=(15, 100), dtype=tf.float32),
        tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".
278 """ 279 super(ShardedVariableMixin, self).__init__() 280 self._variables = variables 281 self._name = name 282 283 first_var = variables[0] 284 285 if any(not isinstance(v, variables_lib.Variable) for v in variables): 286 raise ValueError( 287 'Expected a list of `Variable`s, found: {}'.format(variables)) 288 289 var_dtypes = {v.dtype for v in variables} 290 if len(var_dtypes) > 1: 291 raise ValueError( 292 'All `Variable`s must have the same dtype, found: {}'.format( 293 [v.dtype for v in variables])) 294 self._dtype = first_var.dtype 295 296 # All variables must have the same shape for axes > 0. 297 higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables} 298 if len(higher_dim_shapes) > 1: 299 raise ValueError( 300 'All `Variables`s must have the same shapes except for the first ' 301 'axis, found {}'.format([v.shape for v in variables])) 302 first_dim = sum(int(v.shape[0]) for v in variables) 303 self._shape = tensor_shape.TensorShape([first_dim] + first_var.shape[1:]) 304 self._var_offsets = [ 305 [0 for _ in range(len(first_var.shape))] for _ in range(len(variables)) 306 ] 307 for i in range(1, len(variables)): 308 # Always partition on the first axis. Offsets on other axes are 0. 309 self._var_offsets[i][0] += ( 310 self._var_offsets[i - 1][0] + variables[i - 1].shape[0]) 311 312 save_slice_info = [v._get_save_slice_info() for v in variables] # pylint: disable=protected-access 313 if any(slice_info is not None for slice_info in save_slice_info): 314 raise ValueError('`SaveSliceInfo` should not be set for `Variable`s. ' 315 '`ShardedVariable` will infer `SaveSliceInfo` according ' 316 'to the order of the `Variable`s in the list passed to ' 317 'the constructor. Found {}'.format(save_slice_info)) 318 319 # We create an uninitialized saving_variable with the full shape, which can 320 # be later captured in signatures so that the signatures can treat this 321 # ShardedVariable as one single variable. 322 self._saving_variable = resource_variable_ops.UninitializedVariable( 323 shape=self._shape, dtype=self._dtype, name=self._name) 324 325 def __iter__(self): 326 """Return an iterable for accessing the underlying sharded variables.""" 327 return iter(self._variables) 328 329 def __getitem__(self, slice_spec): 330 """Extracts the specified region as a Tensor from the sharded variable. 331 332 The API contract is identical to `Tensor.__getitem__`. Assignment to the 333 sliced range is not yet supported. 334 335 Args: 336 slice_spec: The arguments to __getitem__, specifying the global slicing of 337 the sharded variable. 338 339 Returns: 340 The appropriate slice of tensor based on `slice_spec`. 341 342 Raises: 343 IndexError: If a slice index is out of bound. 344 TypeError: If `spec_spec` contains Tensor. 345 """ 346 347 # TODO(b/177482728): Support tensor input. 348 # TODO(b/177482728): Support slice assign, similar to variable slice assign. 

    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError('slice index %d of dimension 0 out of bounds.' % s)
      # Find the shard whose first-dimension range contains `s`, then index
      # into it with the local (offset-adjusted) index.
      for i in range(len(self._variables)):
        if i == len(self._variables) - 1 or (s >= self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

    `ShardedVariable` only supports first dimension partitioning, thus
    `slice_spec` must be for the first dimension.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing for
      each component variable. None means no slicing.

      For example, given component variables:
        v0 = [0, 1, 2]
        v1 = [3, 4, 5]
        v2 = [6, 7, 8, 9]

      If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
        v0[returned[0]] = [0, 1, 2]
        v1[returned[1]] = [3, 4, 5]
        v2[returned[2]] = [6, 7, 8, 9]
      If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
        v0[returned[0]] = [2]
        v1[returned[1]] = [5]
        returned[2] == None
      If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
        returned[0] == None
        v1[returned[1]] = [5]
        v2[returned[2]] = [9, 7]
    """
    if isinstance(slice_spec.start, ops.Tensor) or isinstance(
        slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.')

    result = []
    # Normalize start, stop and step.
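    # Worked illustration (added commentary, assuming self._shape[0] == 10):
    # slice(None, None, -1) normalizes to start=9, stop=-1, step=-1, and
    # slice(-8, None, 2) normalizes to start=2, stop=10, step=2. After this
    # point, -1 for the stop only marks "one before the first element" and is
    # never reinterpreted as a negative index.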
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative index, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local_spec is determined
    # by using whatever is smaller between global slice end and variable range
    # end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      result.reverse()

    return result

  @property
  def _type_spec(self):
    return ShardedVariableSpec(*(
        resource_variable_ops.VariableSpec(v.shape, v.dtype)
        for v in self._variables))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object."""
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_add(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_sub(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
      saveables = []
      dims = len(self._variables[0].shape)
      var_offset = [0 for _ in range(dims)]
      for v in self._variables:
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    obj_map, resource_map = {}, {}
    for v in self._variables + [self._saving_variable]:
      v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
      obj_map.update(v_obj_map)
      resource_map.update(v_resource_map)
    obj_map[self] = ShardedVariable([obj_map[self._saving_variable]],
                                    name=self.name)

    return obj_map, resource_map


class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for `Variable`s that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (e.g., multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Objects of this class can be saved to SavedModel format using
  `tf.saved_model.save`. The SavedModel can be used by programs like TF serving
  APIs. It is not yet supported to load the SavedModel with
  `tf.saved_model.load`.

  Since a `ShardedVariable` saved with one number of shards can be restored
  with a different number of shards depending on the restore environment (for
  example, TF serving APIs restore to a single shard for serving efficiency),
  code that uses a `ShardedVariable` in a `tf.function` should generally not
  assume the number of shards stays the same across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...                     signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(*(
        resource_variable_ops.VariableSpec(v.shape, v.dtype)
        for v in self._variables))


def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor`."""
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
        'Incompatible type conversion requested to type {!r} for variable '
        'of type {!r}'.format(dtype.name, var.dtype.name))
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # We use the op dispatch mechanism to override embedding_lookup ops when
  # called with ShardedVariable. This requires embedding_lookup ops to raise
  # TypeError when called with ShardedVariable. However, since ShardedVariable
  # can be converted to a tensor via concat, embedding_lookup ops would
  # silently do the conversion and never raise a TypeError. To be able to
  # properly raise a TypeError, the name scope is used to detect whether this
  # method is called within an embedding_lookup op.
  # NOTE: This doesn't work in eager mode since the op name scope is always
  # cleared in eager. This also breaks if the user sets the name of the
  # embedding_lookup op to something that doesn't contain the string
  # "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this, which should not rely
  # on the name scope.
  if 'embedding_lookup' in ops.get_name_scope():
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  return array_ops.concat(var.variables, axis=0)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)


# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)


def _raise_when_load(_):
  # We don't have serialization and deserialization mechanisms for
  # `ShardedVariable` in 2.x style save/load yet.
  raise ValueError('Loading `ShardedVariable` is not supported')


revived_types.register_revived_type(
    '_tf_distribute_sharded_variable',
    lambda obj: isinstance(obj, ShardedVariable),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=_raise_when_load,
            version=0,
            min_producer_version=0,
            min_consumer_version=0)
    ])
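
# Illustrative usage sketch (added commentary; the values are hypothetical and
# the snippet is kept as comments so nothing runs at import time):
#
#   v0 = variables_lib.Variable([[0.], [1.], [2.]])
#   v1 = variables_lib.Variable([[3.], [4.]])
#   sv = ShardedVariable([v0, v1])
#   # sv.shape is TensorShape([5, 1]); sv[3] reads v1[0]; and
#   # sv.assign(array_ops.zeros([5, 1])) splits the value row-wise, assigning
#   # the first three rows to v0 and the remaining two rows to v1.
#   # Converting sv to a tensor (e.g. via ops.convert_to_tensor) concatenates
#   # the shards along axis 0.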