# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ShardedVariable class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math
import numpy as np

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import save_context
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export('distribute.experimental.partitioners.Partitioner', v1=[])
class Partitioner(object):
  """Partitioner base class: all partitioners inherit from this class.

  Partitioners should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype, axis=0):
    # Partitions the given `shape` and returns the partition results.
    # See docstring of `__call__` method for the format of partition results.
  ```
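
  For example, a trivial partitioner that never splits a variable could be
  sketched as follows (illustrative only; the concrete partitioners below
  provide real partitioning policies):

  ```python
  class SingleShardPartitioner(Partitioner):

    def __call__(self, shape, dtype, axis=0):
      # One partition on every axis, i.e. the variable is left unsharded.
      del dtype, axis
      return [1] * len(shape)
  ```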
  """

  def __call__(self, shape, dtype, axis=0):
    """Partitions the given `shape` and returns the partition results.

    Example of a partitioner that allocates a fixed number of shards:

    ```python
    partitioner = FixedShardsPartitioner(num_shards=2)
    partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
    print(partitions)  # [2, 1]
    ```

    Args:
      shape: a `tf.TensorShape`, the shape to partition.
      dtype: a `tf.dtypes.DType` indicating the type of the partition value.
      axis: The axis to partition along.  Default: outermost axis.

    Returns:
      A list of integers representing the number of partitions on each axis,
      where the i-th value corresponds to the i-th axis.
    """
    raise NotImplementedError


@tf_export('distribute.experimental.partitioners.FixedShardsPartitioner', v1=[])
class FixedShardsPartitioner(Partitioner):
  """Partitioner that allocates a fixed number of shards.

  Examples:

  >>> # standalone usage:
  >>> partitioner = FixedShardsPartitioner(num_shards=2)
  >>> partitions = partitioner(tf.TensorShape([10, 3]), tf.float32)
  >>> [2, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)

  """

  def __init__(self, num_shards):
    """Creates a new `FixedShardsPartitioner`.

    Args:
      num_shards: `int`, number of shards to partition.
    """
    self._num_shards = num_shards

  def __call__(self, shape, dtype, axis=0):
    del dtype
    result = [1] * len(shape)
    result[axis] = min(self._num_shards, shape.dims[axis].value)
    return result


@tf_export('distribute.experimental.partitioners.MinSizePartitioner', v1=[])
class MinSizePartitioner(Partitioner):
  """Partitioner that allocates a minimum size per shard.

  This partitioner ensures each shard has at least `min_shard_bytes`, and tries
  to allocate as many shards as possible, i.e., keeping shard size as small as
  possible. The maximum number of such shards (upper bound) is given by
  `max_shards`.
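  For example, a `[6, 1]` float32 variable occupies 24 bytes; with
  `min_shard_bytes=4`, each shard must hold at least one 4-byte element, so at
  most 6 shards can be created before the `max_shards` cap is applied.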

  Examples:

  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [2, 1]
  >>> partitioner = MinSizePartitioner(min_shard_bytes=4, max_shards=10)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [6, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self,
               min_shard_bytes=256 << 10,
               max_shards=1,
               bytes_per_string=16):
    """Creates a new `MinSizePartitioner`.

    Args:
      min_shard_bytes: Minimum bytes of each shard. Defaults to 256K.
      max_shards: Upper bound on the number of shards. Defaults to 1.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if min_shard_bytes < 1:
      raise ValueError('min_shard_bytes must be positive, got: %r' %
                       min_shard_bytes)
    if max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)
    self._min_shard_bytes = min_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.min_max_variable_partitioner(
        max_partitions=self._max_shards,
        axis=axis,
        min_slice_size=self._min_shard_bytes,
        bytes_per_string_element=self._bytes_per_string)(shape, dtype)


@tf_export('distribute.experimental.partitioners.MaxSizePartitioner', v1=[])
class MaxSizePartitioner(Partitioner):
  """Partitioner that keeps shards below `max_shard_bytes`.

  This partitioner ensures each shard has at most `max_shard_bytes`, and tries
  to allocate as few shards as possible, i.e., keeping shard size as large
  as possible.

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.
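  For example, a `[6, 1]` float32 variable occupies 24 bytes: `max_shard_bytes=4`
  splits it into 6 single-element shards (capped at 2 shards when
  `max_shards=2`), while `max_shard_bytes=1024` leaves it in a single shard.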

  Examples:

  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [6, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=4, max_shards=2)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [2, 1]
  >>> partitioner = MaxSizePartitioner(max_shard_bytes=1024)
  >>> partitions = partitioner(tf.TensorShape([6, 1]), tf.float32)
  >>> [1, 1]
  >>>
  >>> # use in ParameterServerStrategy
  >>> # strategy = tf.distribute.experimental.ParameterServerStrategy(
  >>> #   cluster_resolver=cluster_resolver, variable_partitioner=partitioner)
  """

  def __init__(self, max_shard_bytes, max_shards=None, bytes_per_string=16):
    """Creates a new `MaxSizePartitioner`.

    Args:
      max_shard_bytes: The maximum size any given shard is allowed to be.
      max_shards: `int`, the maximum number of shards to create. Takes
        precedence over `max_shard_bytes`.
      bytes_per_string: If the partition value is of type string, this provides
        an estimate of how large each string is.
    """
    if max_shard_bytes < 1:
      raise ValueError('max_shard_bytes must be positive, got: %r' %
                       max_shard_bytes)
    if max_shards and max_shards < 1:
      raise ValueError('max_shards must be positive, got: %r' % max_shards)
    if bytes_per_string < 1:
      raise ValueError('bytes_per_string must be positive, got: %r' %
                       bytes_per_string)

    self._max_shard_bytes = max_shard_bytes
    self._max_shards = max_shards
    self._bytes_per_string = bytes_per_string

  def __call__(self, shape, dtype, axis=0):
    return partitioned_variables.variable_axis_size_partitioner(
        max_shard_bytes=self._max_shard_bytes,
        max_shards=self._max_shards,
        bytes_per_string_element=self._bytes_per_string,
        axis=axis)(shape, dtype)


class ShardedVariableSpec(type_spec.TypeSpec):
  """Type specification for a `ShardedVariable`."""

  __slots__ = ['_variable_specs']

  value_type = property(lambda self: ShardedVariable)

  def __init__(self, *variable_specs):
    self._variable_specs = tuple(variable_specs)

  def _serialize(self):
    return self._variable_specs

  @property
  def _component_specs(self):
    return self._variable_specs

  def _to_components(self, value):
    return value.variables

  def _from_components(self, variables):
    return ShardedVariable(variables)


class ShardedVariableMixin(trackable.Trackable):
  """Mixin for ShardedVariable."""

  # TODO(b/170877138): Remove this mixin once fixed. This mixin is required
  # since TPUShardedVariable can't be a CompositeTensor.

  def __init__(self, variables, name='ShardedVariable'):
    """Treats `variables` as shards of a larger Variable.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(15, 100), dtype=tf.float32),
      tf.Variable(..., shape=(5, 100), dtype=tf.float32)
    ]
    sharded_variable = ShardedVariableMixin(variables)
    assert sharded_variable.shape.as_list() == [30, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this sharded
        variable. Variables should not be shared between different
        `ShardedVariableMixin` objects.
      name: String. Name of this container. Defaults to "ShardedVariable".
    """
    super(ShardedVariableMixin, self).__init__()
    self._variables = variables
    self._name = name

    first_var = variables[0]

    if any(not isinstance(v, variables_lib.Variable) for v in variables):
      raise ValueError(
          'Expected a list of `Variable`s, found: {}'.format(variables))

    var_dtypes = {v.dtype for v in variables}
    if len(var_dtypes) > 1:
      raise ValueError(
          'All `Variable`s must have the same dtype, found: {}'.format(
              [v.dtype for v in variables]))
    self._dtype = first_var.dtype

    # All variables must have the same shape for axes > 0.
    higher_dim_shapes = {tuple(v.shape.as_list()[1:]) for v in variables}
    if len(higher_dim_shapes) > 1:
      raise ValueError(
          'All `Variable`s must have the same shapes except for the first '
          'axis, found {}'.format([v.shape for v in variables]))
    first_dim = sum(int(v.shape[0]) for v in variables)
    self._shape = tensor_shape.TensorShape([first_dim] + first_var.shape[1:])
    self._var_offsets = [
        [0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
    ]
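    # For example, with the shard shapes from the docstring above
    # ([10, 100], [15, 100] and [5, 100]), the offsets computed below are
    # [[0, 0], [10, 0], [25, 0]].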
    for i in range(1, len(variables)):
      # Always partition on the first axis. Offsets on other axes are 0.
      self._var_offsets[i][0] += (
          self._var_offsets[i - 1][0] + variables[i - 1].shape[0])

    save_slice_info = [v._get_save_slice_info() for v in variables]  # pylint: disable=protected-access
    if any(slice_info is not None for slice_info in save_slice_info):
      raise ValueError('`SaveSliceInfo` should not be set for `Variable`s. '
                       '`ShardedVariable` will infer `SaveSliceInfo` according '
                       'to the order of the `Variable`s in the list passed to '
                       'the constructor. Found {}'.format(save_slice_info))

    # We create an uninitialized saving_variable with the full shape, which can
    # be later captured in signatures so that the signatures can treat this
    # ShardedVariable as one single variable.
    self._saving_variable = resource_variable_ops.UninitializedVariable(
        shape=self._shape, dtype=self._dtype, name=self._name)

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._variables)

  def __getitem__(self, slice_spec):
    """Extracts the specified region as a Tensor from the sharded variable.

    The API contract is identical to `Tensor.__getitem__`. Assignment to the
    sliced range is not yet supported.
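
    For example (illustrative), given two shards that together hold a length-4
    vector:

    ```
    sv = ShardedVariable([tf.Variable([0., 1.]), tf.Variable([2., 3.])])
    sv[1:3]   # [1., 2.]
    sv[::-1]  # [3., 2., 1., 0.]
    ```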

    Args:
      slice_spec: The arguments to __getitem__, specifying the global slicing of
        the sharded variable.

    Returns:
      The appropriate slice of tensor based on `slice_spec`.

    Raises:
      IndexError: If a slice index is out of bound.
      TypeError: If `slice_spec` contains a Tensor.
    """

    # TODO(b/177482728): Support tensor input.
    # TODO(b/177482728): Support slice assign, similar to variable slice assign.

    if (isinstance(slice_spec, bool) or (isinstance(slice_spec, ops.Tensor) and
                                         slice_spec.dtype == dtypes.bool) or
        (isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool)):
      tensor = _var_to_tensor(self)
      return array_ops.boolean_mask(tensor=tensor, mask=slice_spec)

    if not isinstance(slice_spec, (list, tuple)):
      slice_spec = (slice_spec,)

    s = slice_spec[0]
    if isinstance(s, slice):
      first_dim_slice_specs = self._decompose_slice_spec(s)
      values = []
      for i, var in enumerate(self._variables):
        if first_dim_slice_specs[i] is not None:
          all_dim_slice_spec = (first_dim_slice_specs[i],) + slice_spec[1:]
          values.append(var[all_dim_slice_spec])
      if s.step is not None and s.step < 0:
        values.reverse()
      if not values:
        return constant_op.constant([],
                                    dtype=self._dtype,
                                    shape=((0,) + self._shape[1:]))
      return array_ops.concat(values, axis=0)
    elif s is Ellipsis:
      return array_ops.concat([var[slice_spec] for var in self._variables],
                              axis=0)
    elif s is array_ops.newaxis:
      return array_ops.concat([var[slice_spec[1:]] for var in self._variables],
                              axis=0)[array_ops.newaxis]
    else:
      if isinstance(s, ops.Tensor):
        raise TypeError(
            'ShardedVariable: using Tensor for indexing is not allowed.')
      if s < 0:
        s += self._shape[0]
      if s < 0 or s >= self._shape[0]:
        raise IndexError('slice index %d of dimension 0 out of bounds.' % s)
      for i in range(len(self._variables)):
        if i == len(self._variables) - 1 or (s >= self._var_offsets[i][0] and
                                             s < self._var_offsets[i + 1][0]):
          return self._variables[i][(s - self._var_offsets[i][0],) +
                                    slice_spec[1:]]

  def _decompose_slice_spec(self, slice_spec):
    """Decompose a global slice_spec into a list of per-variable slice_spec.

    `ShardedVariable` only supports first dimension partitioning, thus
    `slice_spec` must be for the first dimension.

    Args:
      slice_spec: A python `slice` object that specifies the global slicing.

    Returns:
      A list of python `slice` objects or None specifying the local slicing for
      each component variable. None means no slicing.

    For example, given component variables:
      v0 = [0, 1, 2]
      v1 = [3, 4, 5]
      v2 = [6, 7, 8, 9]

    If `slice_spec` is slice(start=None, stop=None, step=None), we will have:
      v0[returned[0]] = [0, 1, 2]
      v1[returned[1]] = [3, 4, 5]
      v2[returned[2]] = [6, 7, 8, 9]
    If `slice_spec` is slice(start=2, stop=8, step=3), we will have:
      v0[returned[0]] = [2]
      v1[returned[1]] = [5]
      returned[2] == None
    If `slice_spec` is slice(start=9, stop=3, step=-2), we will have:
      returned[0] == None
      v1[returned[1]] = [5]
      v2[returned[2]] = [9, 7]
    """
    if isinstance(slice_spec.start, ops.Tensor) or isinstance(
        slice_spec.stop, ops.Tensor) or isinstance(slice_spec.step, ops.Tensor):
      raise TypeError(
          'ShardedVariable: using Tensor in slice_spec is not allowed. Please '
          'file a feature request with the TensorFlow team.')

    result = []
    # Normalize the slice's start, stop and step.
    slice_step = slice_spec.step if slice_spec.step is not None else 1
    if slice_step == 0:
      raise ValueError('slice step cannot be zero')
    slice_start = slice_spec.start
    if slice_start is None:
      slice_start = 0 if slice_step > 0 else self._shape[0] - 1
    elif slice_start < 0:
      slice_start += self._shape[0]
    slice_end = slice_spec.stop
    if slice_end is None:
      # After the normalization, we no longer interpret negative index, thus
      # "-1" conceptually refers to the element before the first one, which
      # doesn't exist. This is to ease the decomposition code.
      slice_end = self._shape[0] if slice_step > 0 else -1
    elif slice_end < 0:
      slice_end += self._shape[0]

    # To find the local slice_spec of each component variable, we start from
    # the start of the global slice, and iterate through each variable.
    # When iterating on a variable, we move the cursor (`cur`) to the first
    # index that falls into the variable's range, which becomes the start of
    # the variable's local slice_spec. The end of the local slice_spec is
    # determined by using whatever is smaller between the global slice end and
    # the variable's range end.
    cur = slice_start
    if slice_step > 0:
      for i in range(len(self._var_offsets)):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur < var_start:
          cur += slice_step * int(math.ceil((var_start - cur) / slice_step))
        if cur >= var_end or cur >= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          end = min(slice_end, var_end) - var_start
          result.append(slice(start, end, slice_step))
    else:  # slice_step < 0
      for i in range(len(self._var_offsets) - 1, -1, -1):
        var_start = self._var_offsets[i][0]
        var_end = (
            self._var_offsets[i + 1][0]
            if i < len(self._var_offsets) - 1 else self._shape[0])
        if cur >= var_end:
          cur += slice_step * int(math.ceil((var_end - cur - 1) / slice_step))
        if cur < var_start or cur <= slice_end:
          result.append(None)
        else:
          start = cur - var_start
          if slice_end >= var_start:
            end = slice_end - var_start
          else:
            end = None  # no explicit end: slice until hitting the boundary.
          result.append(slice(start, end, slice_step))

      result.reverse()

    return result

  @property
  def _type_spec(self):
    return ShardedVariableSpec(*(
        resource_variable_ops.VariableSpec(v.shape, v.dtype)
        for v in self._variables))

  @property
  def variables(self):
    """The list of `Variable`s that make up the shards of this object."""
    if save_context.in_save_context():
      return [self._saving_variable]
    return self._variables

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._dtype

  @property
  def shape(self):
    """The overall shape, combining all shards along axis `0`."""
    return self._shape

  def assign(self, value, use_locking=None, name=None, read_value=True):
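    # Each shard is assigned the slice of `value` that matches the shard's
    # offset and shape along the first axis; the remaining arguments are
    # currently ignored.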
    for i, v in enumerate(self._variables):
      v.assign(array_ops.slice(value, self._var_offsets[i], v.shape.as_list()))

  def assign_add(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_add(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))

  def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
    for i, v in enumerate(self._variables):
      v.assign_sub(
          array_ops.slice(delta, self._var_offsets[i], v.shape.as_list()))

  def _gather_saveables_for_checkpoint(self):
    """Return a `Saveable` for each shard. See `Trackable`."""

    def _saveable_factory(name=self.name):
      """Creates `SaveableObject`s for this `ShardedVariable`."""
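      # Each shard is described as a slice of one logical variable via
      # `SaveSliceInfo`; this is what allows a checkpoint written with one
      # number of shards to be restored with a different number of shards.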
      saveables = []
      dims = len(self._variables[0].shape)
      var_offset = [0 for _ in range(dims)]
      for v in self._variables:
        save_slice_info = variables_lib.Variable.SaveSliceInfo(
            full_name=self.name,
            full_shape=self.shape.as_list(),
            var_offset=copy.copy(var_offset),
            var_shape=v.shape.as_list())
        saveables.append(
            saveable_object_util.ResourceVariableSaveable(
                v, save_slice_info.spec, name))
        var_offset[0] += int(v.shape[0])
      return saveables

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    obj_map, resource_map = {}, {}
    for v in self._variables + [self._saving_variable]:
      v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
      obj_map.update(v_obj_map)
      resource_map.update(v_resource_map)
    obj_map[self] = ShardedVariable([obj_map[self._saving_variable]],
                                    name=self.name)

    return obj_map, resource_map


class ShardedVariable(ShardedVariableMixin, composite_tensor.CompositeTensor):
  """A container for `Variable`s that should be treated as shards.

  Variables that are too large to fit on a single device (e.g., large
  embeddings) may need to be sharded over multiple devices. This class
  maintains a list of smaller variables that can be independently stored on
  separate devices (e.g., multiple parameter servers), and saves and restores
  those variables as if they were a single larger variable.

  Objects of this class can be saved with a given number of shards and then
  restored from a checkpoint into a different number of shards.

  Objects of this class can be saved to SavedModel format using
  `tf.saved_model.save`. The SavedModel can be used by programs like TF serving
  APIs. It is not yet supported to load the SavedModel with
  `tf.saved_model.load`.

  Since `ShardedVariable` can be saved and then restored to a different number
  of shards depending on the restore environment (for example, TF serving APIs
  restore to one shard for serving efficiency), when using `ShardedVariable` in
  a `tf.function` one should generally not assume it has the same number of
  shards across save and load.

  Sharding is only supported along the first dimension.

  >>> class Model(tf.Module):
  ...   def __init__(self):
  ...     self.sharded_variable = ShardedVariable([
  ...       tf.Variable([3.0], dtype=tf.float32),
  ...       tf.Variable([2.0], dtype=tf.float32)
  ...     ])
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  ...
  ...   @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int32)])
  ...   def serve_fn(self, x):
  ...     return tf.nn.embedding_lookup(self.sharded_variable.variables, x)
  >>>
  >>> model = Model()
  >>> model.fn(1).numpy()
  2.0
  >>> tf.saved_model.save(model, export_dir='/tmp/saved_model',
  ...   signatures=model.serve_fn)
  """

  @property
  def _type_spec(self):
    return ShardedVariableSpec(*(
        resource_variable_ops.VariableSpec(v.shape, v.dtype)
        for v in self._variables))


def _var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `ShardedVariable` to a `Tensor`."""
  del name
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
        'Incompatible type conversion requested to type {!r} for variable '
        'of type {!r}'.format(dtype.name, var.dtype.name))
  if as_ref:
    raise NotImplementedError(
        "ShardedVariable doesn't support being used as a reference.")
  # We use the op dispatch mechanism to override embedding_lookup ops when
  # called with a ShardedVariable. This requires embedding_lookup ops to raise
  # TypeError when called with a ShardedVariable. However, since a
  # ShardedVariable can be converted to a tensor via concat, embedding_lookup
  # ops would silently do the conversion and never raise a TypeError. To be
  # able to properly raise a TypeError, the name scope is used to detect if
  # this method is called within an embedding_lookup op.
  # NOTE: This doesn't work in eager mode since the op name scope is always
  # cleared in eager. This also breaks if the user sets the name of the
  # embedding_lookup op to something that doesn't contain the string
  # "embedding_lookup".
  #
  # TODO(chenkai): Find a more robust way to do this, which should not rely
  # on the name scope.
  if 'embedding_lookup' in ops.get_name_scope():
    raise TypeError('Converting ShardedVariable to tensor in embedding lookup'
                    ' ops is disallowed.')
  return array_ops.concat(var.variables, axis=0)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
ops.register_tensor_conversion_function(ShardedVariable, _var_to_tensor)


# Override the behavior of embedding_lookup(sharded_variable, ...)
@dispatch.dispatch_for_types(embedding_ops.embedding_lookup, ShardedVariable)
def embedding_lookup(params,
                     ids,
                     partition_strategy='mod',
                     name=None,
                     validate_indices=True,
                     max_norm=None):
  if isinstance(params, list):
    params = params[0]
  return embedding_ops.embedding_lookup(params.variables, ids,
                                        partition_strategy, name,
                                        validate_indices, max_norm)
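

# For illustration: given `sv = ShardedVariable([tf.Variable([1.]),
# tf.Variable([2.])])`, a call like `tf.nn.embedding_lookup(sv, ids)` is routed
# to the override above and performs the lookup across `sv.variables` directly,
# instead of first concatenating the shards into a single tensor (which
# `_var_to_tensor` deliberately forbids inside embedding_lookup name scopes).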


def _raise_when_load(_):
  # We don't have serialization and deserialization mechanisms for
  # `ShardedVariable` in 2.x style save/load yet.
  raise ValueError('Loading `ShardedVariable` is not supported')


revived_types.register_revived_type(
    '_tf_distribute_sharded_variable',
    lambda obj: isinstance(obj, ShardedVariable),
    versions=[
        revived_types.VersionedTypeRegistration(
            object_factory=_raise_when_load,
            version=0,
            min_producer_version=0,
            min_consumer_version=0)
    ])