1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations for embeddings."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from six.moves import xrange  # pylint: disable=redefined-builtin
21
22from tensorflow.python.compat import compat
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import clip_ops
30# Imports gradient definitions.
31from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
32from tensorflow.python.ops import data_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import resource_variable_ops
35from tensorflow.python.ops import sparse_ops
36from tensorflow.python.ops import variables
37from tensorflow.python.ops.ragged import ragged_functional_ops
38from tensorflow.python.ops.ragged import ragged_tensor
39from tensorflow.python.util import dispatch
40from tensorflow.python.util.tf_export import tf_export
41
42
def _clip(params, ids, max_norm):
  """Helper function for _embedding_lookup_and_transform.

  This function optionally clips embeddings to an l2-norm of max_norm.

  The norm of each embedding is taken over the trailing axes of `params` that
  are not indexed by `ids`, i.e. axes `rank(ids) .. rank(params) - 1`. When
  both ranks are known statically, the axes are passed as a Python list.
  Otherwise they are computed in-graph with a `range` op.

  Args:
    params: A `Tensor` of embeddings retrieved by `gather`.
    ids: The `ids` argument that was passed to `gather`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.

  Returns:
    A `Tensor` with the same type as `params`. If `max_norm` is `None`,
    `params` is returned unchanged.
  """

  def _rank(x):
    """Helper function to retrieve the rank of a tensor.

    Args:
      x: Something convertible to `Tensor`.

    Returns:
      Either a pair `(rank, True)` where `rank` is an integer or a pair
      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
      `rank` is the rank of `x`.
    """
    rank = ops.convert_to_tensor(x).get_shape().ndims
    # `ndims` is None only when the rank is unknown. A static rank of 0 (e.g.
    # scalar ids) is still static and must not fall through to the dynamic
    # path merely because 0 is falsy.
    if rank is not None:
      return rank, True
    return array_ops.rank(x), False

  if max_norm is None:
    return params
  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)
  if ids_static and params_static:
    # Both ranks known: use a static axes list, so no extra op is built.
    axes = list(range(ids_rank, params_rank))
  else:
    axes = math_ops.range(ids_rank, params_rank)
  return clip_ops.clip_by_norm(params, max_norm, axes=axes)
84
85
def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Two code paths exist:

  * Fast path: there is exactly one params tensor, and either no
    `transform_fn` is given or `ids` is statically 1-D. The lookup is a
    single `gather` colocated with `params[0]`, followed by optional
    clipping and the transform.
  * Partitioned path: all other cases. `ids` are flattened and each id is
    assigned to a partition according to `partition_strategy`. Each
    partition is gathered while colocated with its shard. The pieces are
    stitched back together with `parallel_dynamic_stitch`, and the result
    is reshaped to `shape(ids) + element_shape`.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding. If
      max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is `None` or an empty list/tuple, or if
      `partition_strategy` is neither "mod" nor "div". The strategy is only
      checked on the partitioned path.
  """
  if params is None:
    raise ValueError("params must be specified")
  if isinstance(params, (list, tuple)) and not params:
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  # Anything that is not already a list becomes a single-shard list.
  # NOTE(review): this includes a non-empty tuple, which is wrapped as ONE
  # element rather than being treated as a sequence of shards, even though
  # the empty check above accepts tuples. Confirm this is intended.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    # Fast path: one shard, and no flattening is needed for transform_fn.
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with ops.colocate_with(params[0]):
        result = _clip(
            array_ops.gather(params[0], ids, name=name), ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
      # Make sure the final result does not have colocation constraints on the
      # params. Similar to the case np > 1 where parallel_dynamic_stitch is
      # outside the scope of all with ops.colocate_with(params[p]).
      return array_ops.identity(result)
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      # Position of each flat id; used to stitch results back in order.
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      # p_assignments[i] is the shard that holds flat_ids[i], and new_ids[i]
      # is the row of that id inside its shard.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in xrange(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          # At least one shard's dim 0 is unknown statically: sum the static
          # sizes where available and dynamic sizes (read on the shard's own
          # device) otherwise.
          dim_0_sizes = []
          for p in xrange(np):
            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        # The first `extras` shards hold ids_per_partition + 1 ids each; the
        # remaining shards hold ids_per_partition ids each.
        # NOTE(review): if num_total_ids < np, ids_per_partition is 0, and
        # the floor-div/mod by ids_per_partition below divides by zero.
        # Confirm callers never hit this.
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        # For ids that fall in the first `extras` (larger) shards, the first
        # term gives the shard and is the larger of the two. For the other
        # ids, the second term (offset past the extras) is the larger one.
        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        pids = gather_ids[p]
        with ops.colocate_with(params[p]):
          result = array_ops.gather(params[p], pids)
          if transform_fn:
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      # Without a transform, merge the trailing (non-dim-0) shapes of all
      # shards. With a transform, the stitched result is the only source.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(
          ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret
251
252
@tf_export(v1=["nn.embedding_lookup"])
@dispatch.add_dispatch_support
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up embeddings for `ids` in a (possibly sharded) embedding table.

  This generalizes `tf.gather`: `params` is treated as a partitioning of one
  large embedding tensor, and the per-shard lookups are performed in
  parallel. A `PartitionedVariable`, as returned by
  `tf.compat.v1.get_variable()` with a partitioner, may be passed as
  `params`.

  When `len(params) > 1`, each `id` in `ids` is routed to one element of
  `params` according to `partition_strategy`. Under every strategy, if the id
  space does not divide evenly among the partitions, each of the first
  `(max_id + 1) % len(params)` partitions is assigned one extra id.

  With `partition_strategy="mod"`, id `id` goes to partition
  `p = id % len(params)`. For example, 13 ids over 5 partitions are split as
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`.

  With `partition_strategy="div"`, ids are assigned to partitions
  contiguously. For example, 13 ids over 5 partitions are split as
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  When `ids` is a `RaggedTensor`, the lookup is delegated to
  `embedding_lookup_ragged`, which receives `partition_strategy`, `max_norm`
  and `name`. Partitioned variables are not supported on that path.

  For dense `ids`, the looked-up embeddings are assembled into a dense tensor
  of shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: Either a single tensor holding the whole embedding table, or a list
      of P tensors (sharded embedding tensors) that share every dimension
      except the first. A `PartitionedVariable` created by partitioning along
      dimension 0 is also accepted. Every element must be sized consistently
      with `partition_strategy`.
    ids: A `Tensor` or a 'RaggedTensor' of type `int32` or `int64` holding the
      ids to look up in `params`.
    partition_strategy: Partitioning strategy, used only when
      `len(params) > 1`. `"div"` and `"mod"` are supported; defaults to
      `"mod"`.
    name: Optional name for the operation.
    validate_indices: DEPRECATED and unused. On CPU, values in `indices` are
      always range-checked. On GPU, out-of-bound indices give safe but
      unspecified behavior, which may include raising an error.
    max_norm: If not `None`, any embedding whose l2-norm exceeds this value is
      clipped to it.

  Returns:
    A `Tensor`, or a 'RaggedTensor' when `ids` is ragged, with the same dtype
    as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  is_ragged = isinstance(ids, ragged_tensor.RaggedTensor)
  if not is_ragged:
    # Dense ids: ordinary (possibly partitioned) lookup without a transform.
    return _embedding_lookup_and_transform(
        params=params,
        ids=ids,
        partition_strategy=partition_strategy,
        name=name,
        max_norm=max_norm,
        transform_fn=None)

  # Ragged ids are handled by the dedicated ragged implementation.
  return embedding_lookup_ragged(
      params,
      ids,
      partition_strategy=partition_strategy,
      max_norm=max_norm,
      name=name)
329
330
@tf_export("nn.embedding_lookup", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_v2(params, ids, max_norm=None, name=None):
  """Looks up embeddings for `ids` in a (possibly sharded) embedding table.

  This generalizes `tf.gather`: `params` is treated as a partitioning of one
  large embedding tensor, and the per-shard lookups run in parallel.

  When `len(params) > 1`, each `id` in `ids` is routed to one element of
  `params` using the "div" partition strategy, which assigns ids to
  partitions contiguously. For example, 13 ids over 5 partitions are split as
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not divide evenly among the partitions, each of the
  first `(max_id + 1) % len(params)` partitions is assigned one extra id.

  The looked-up embeddings are assembled into a dense tensor of shape
  `shape(ids) + shape(params)[1:]`.

  Args:
    params: Either a single tensor holding the whole embedding table, or a list
      of tensors (sharded embedding tensors laid out with the "div" partition
      strategy) that share every dimension except the first.
    ids: A `Tensor` of type `int32` or `int64` holding the ids to look up in
      `params`.
    max_norm: If not `None`, any embedding whose l2-norm exceeds this value is
      clipped to it.
    name: Optional name for the operation.

  Returns:
    A `Tensor` with the same dtype as the tensors in `params`.

    For example, if `params` is the 5x2 matrix:

    ```python
    [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    ```

    or the equivalent list of shards:

    ```python
    params[0]: [[1, 2], [3, 4]]
    params[1]: [[5, 6], [7, 8]]
    params[2]: [[9, 10]]
    ```

    and `ids` is:

    ```python
    [0, 3, 4]
    ```

    then the output is the 3x2 matrix:

    ```python
    [[1, 2], [7, 8], [9, 10]]
    ```

  Raises:
    ValueError: If `params` is empty.
  """
  # The v2 API fixes the partition strategy to "div".
  return embedding_lookup(
      params,
      ids,
      partition_strategy="div",
      name=name,
      max_norm=max_norm)
395
396
@tf_export(v1=["nn.embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Looks up and combines embeddings for sparse ids and optional weights.

  The op assumes two things about `sp_ids`:

  * every row of the dense tensor represented by `sp_ids` contains at least
    one id (no empty rows);
  * the indices of `sp_ids` are in canonical row-major order.

  `sp_ids` and `sp_weights` (when given) are rank-2 `SparseTensor`s.
  Embeddings are always aggregated along the last dimension.

  All id values are assumed to lie in `[0, p0)`, where `p0` is the total size
  of `params` along dimension 0.

  Args:
    params: Either a single tensor holding the whole embedding table, or a list
      of tensors (sharded embedding tensors) that share every dimension except
      the first. A `PartitionedVariable` created by partitioning along
      dimension 0 is also accepted. Every element must be sized consistently
      with `partition_strategy`.
    sp_ids: N x M `SparseTensor` of int64 ids, where N is typically the batch
      size and M is arbitrary.
    sp_weights: Either a `SparseTensor` of float / double weights, or `None`
      to treat every weight as 1. When given, `sp_weights` must have exactly
      the same shape and indices as `sp_ids`.
    partition_strategy: Partitioning strategy, used only when
      `len(params) > 1`. `"div"` and `"mod"` are supported; defaults to
      `"mod"`. See `tf.nn.embedding_lookup` for details.
    name: Optional name for the op.
    combiner: Reduction to apply per row: "mean", "sqrtn" or "sum".
      - "sum" returns the weighted sum of the row's embeddings.
      - "mean" divides that weighted sum by the total weight.
      - "sqrtn" divides it by the square root of the sum of squared weights.
      Defaults to `mean`.
    max_norm: If not `None`, any embedding whose l2-norm exceeds this value is
      clipped to it before combining.

  Returns:
    A dense tensor of combined embeddings for the sparse ids. For each row of
    the dense tensor represented by `sp_ids`, the op looks up the embeddings
    of every id in that row, scales each by its weight, and combines them
    with `combiner`.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

      `shape(output) = [d0, p1, ..., pm]`.

    For example, if params is a 10x20 matrix and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    then with `combiner`="mean" the output is a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  combiner = "mean" if combiner is None else combiner
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    # Unpack the partitioned variable into its underlying per-shard Variables.
    params = list(params)
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")

  has_weights = sp_weights is not None
  if has_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    # Only the static shapes of the three components are compared here.
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.
    component_pairs = (
        (sp_ids.values, sp_weights.values),
        (sp_ids.indices, sp_weights.indices),
        (sp_ids.dense_shape, sp_weights.dense_shape),
    )
    for ids_part, weights_part in component_pairs:
      ids_part.get_shape().assert_is_compatible_with(weights_part.get_shape())

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    # Row of each sparse entry: entries sharing a row are combined together.
    segment_ids = sp_ids.indices[:, 0]

    # Look up each distinct id only once. `idx` maps every sparse entry to
    # its position in `unique_ids`.
    unique_ids, idx = array_ops.unique(sp_ids.values)
    embeddings = embedding_lookup(
        params, unique_ids, partition_strategy=partition_strategy,
        max_norm=max_norm)
    # Half-precision embeddings are combined in float32.
    if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
      embeddings = math_ops.cast(embeddings, dtypes.float32)

    if has_weights:
      if segment_ids.dtype != dtypes.int32:
        segment_ids = math_ops.cast(segment_ids, dtypes.int32)

      weights = sp_weights.values
      if weights.dtype != embeddings.dtype:
        weights = math_ops.cast(weights, embeddings.dtype)

      # Expand the per-unique-id embeddings back to one row per sparse entry.
      embeddings = array_ops.gather(embeddings, idx)

      # Reshape the 1-D weights to [num_entries, 1, ..., 1] so that they
      # broadcast across the embedding dimensions.
      trailing_ones = array_ops.fill(
          array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
      bcast_weights_shape = array_ops.concat(
          [array_ops.shape(weights), trailing_ones], 0)
      static_weights_shape = weights.get_shape()
      weights = array_ops.reshape(weights, bcast_weights_shape)

      # The dynamic reshape above loses the static shape; restore it when the
      # embedding rank is known.
      embeddings_ndims = embeddings.get_shape().ndims
      if embeddings_ndims is not None:
        weights.set_shape(
            static_weights_shape.concatenate([1] * (embeddings_ndims - 1)))

      embeddings *= weights

      if combiner == "sum":
        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
      else:
        # "mean" and "sqrtn" both divide the weighted row sum by a
        # per-row normalizer derived from the weights.
        weighted_sum = math_ops.segment_sum(embeddings, segment_ids)
        if combiner == "mean":
          normalizer = math_ops.segment_sum(weights, segment_ids)
        else:  # combiner == "sqrtn"
          normalizer = math_ops.sqrt(
              math_ops.segment_sum(math_ops.pow(weights, 2), segment_ids))
        embeddings = math_ops.divide(weighted_sum, normalizer, name=name)
    else:
      # Under forward compatibility, int64 segment ids are passed through
      # as-is; any other dtype is cast to int32.
      if compat.forward_compatible(2020, 5, 14):
        accepted_segment_dtypes = (dtypes.int32, dtypes.int64)
      else:
        accepted_segment_dtypes = (dtypes.int32,)
      if segment_ids.dtype not in accepted_segment_dtypes:
        segment_ids = math_ops.cast(segment_ids, dtypes.int32)
      assert idx is not None
      # Unweighted: gather-and-reduce in one sparse segment op.
      sparse_reducers = {
          "sum": math_ops.sparse_segment_sum,
          "mean": math_ops.sparse_segment_mean,
          "sqrtn": math_ops.sparse_segment_sqrt_n,
      }
      embeddings = sparse_reducers[combiner](
          embeddings, idx, segment_ids, name=name)

    return embeddings
578
579
@tf_export("nn.embedding_lookup_sparse", v1=[])
@dispatch.add_dispatch_support
def embedding_lookup_sparse_v2(params,
                               sp_ids,
                               sp_weights,
                               combiner=None,
                               max_norm=None,
                               name=None):
  """Looks up and combines embeddings for sparse ids and optional weights.

  The op assumes two things about `sp_ids`:

  * every row of the dense tensor represented by `sp_ids` contains at least
    one id (no empty rows);
  * the indices of `sp_ids` are in canonical row-major order.

  `sp_ids` and `sp_weights` (when given) are rank-2 `SparseTensor`s.
  Embeddings are always aggregated along the last dimension.

  All id values are assumed to lie in `[0, p0)`, where `p0` is the total size
  of `params` along dimension 0.

  When `len(params) > 1`, each element of `sp_ids` is routed to one element
  of `params` using the "div" partition strategy, which assigns ids to
  partitions contiguously. For example, 13 ids over 5 partitions are split as
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

  If the id space does not divide evenly among the partitions, each of the
  first `(max_id + 1) % len(params)` partitions is assigned one extra id.

  Args:
    params: Either a single tensor holding the whole embedding table, or a list
      of tensors (sharded embedding tensors laid out with the "div" partition
      strategy) that share every dimension except the first.
    sp_ids: N x M `SparseTensor` of int64 ids, where N is typically the batch
      size and M is arbitrary.
    sp_weights: Either a `SparseTensor` of float / double weights, or `None`
      to treat every weight as 1. When given, `sp_weights` must have exactly
      the same shape and indices as `sp_ids`.
    combiner: Reduction to apply per row: "mean", "sqrtn" or "sum".
      - "sum" returns the weighted sum of the row's embeddings.
      - "mean" divides that weighted sum by the total weight.
      - "sqrtn" divides it by the square root of the sum of squared weights.
      Defaults to `mean`.
    max_norm: If not `None`, any embedding whose l2-norm exceeds this value is
      clipped to it before combining.
    name: Optional name for the op.

  Returns:
    A dense tensor of combined embeddings for the sparse ids. For each row of
    the dense tensor represented by `sp_ids`, the op looks up the embeddings
    of every id in that row, scales each by its weight, and combines them
    with `combiner`.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1]`

    then

      `shape(output) = [d0, p1, ..., pm]`.

    For example, if params is a 10x20 matrix and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    then with `combiner`="mean" the output is a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  # The v2 API fixes the partition strategy to "div".
  return embedding_lookup_sparse(
      params,
      sp_ids,
      sp_weights,
      partition_strategy="div",
      name=name,
      combiner=combiner,
      max_norm=max_norm)
669
670
671@tf_export("nn.safe_embedding_lookup_sparse", v1=[])
672@dispatch.add_dispatch_support
673def safe_embedding_lookup_sparse_v2(embedding_weights,
674                                    sparse_ids,
675                                    sparse_weights=None,
676                                    combiner="mean",
677                                    default_id=None,
678                                    max_norm=None,
679                                    name=None):
680  """Lookup embedding results, accounting for invalid IDs and empty features.
681
682  The partitioned embedding in `embedding_weights` must all be the same shape
683  except for the first dimension. The first dimension is allowed to vary as the
684  vocabulary size is not necessarily a multiple of num of shards.
685
686  Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
687  with non-positive weight. For an entry with no features, the embedding vector
688  for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
689
690  The ids and weights may be multi-dimensional. Embeddings are always aggregated
691  along the last dimension.
692
693  If `len(embedding_weights) > 1`, each element `id` of `ids` is partitioned
694  between the elements of `embedding_weights` according to the "div" partition
695  strategy, which means we assign ids to partitions in a contiguous manner. For
696  instance, 13 ids are split across 5 partitions as:
697  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.
698
699  If the id space does not evenly divide the number of partitions, each of the
700  first `(max_id + 1) % len(embedding_weights)` partitions will be assigned one
701  more id.
702
703  Args:
704    embedding_weights: A single tensor representing the complete embedding
705      tensor, or a list of tensors all of same shape except for the first
706      dimension, representing sharded embedding tensors following "div"
707      partition strategy.
708    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
709      ids. `d_0` is typically batch size.
710    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
711      float weights corresponding to `sparse_ids`, or `None` if all weights are
712      be assumed to be 1.0.
713    combiner: A string specifying how to combine embedding results for each
714      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
715      default.
716    default_id: The id to use for an entry with no features. Defaults to
717      0-vector.
718    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
719      combining.
720    name: A name for this operation (optional).
721
722  Returns:
723    A dense tensor representing the combined embeddings for the
724    sparse ids. For each row in the dense tensor represented by `sparse_ids`,
725    the op looks up the embeddings for all ids in that row, multiplies them by
726    the corresponding weight, and combines these embeddings as specified.
727
728    In other words, if
729
730      `shape(combined embedding_weights) = [p0, p1, ..., pm]`
731
732    and
733
734      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`
735
736    then
737
738      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.
739
740    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
741
742      ```python
743      [0, 0]: id 1, weight 2.0
744      [0, 1]: id 3, weight 0.5
745      [1, 0]: id -1, weight 1.0
746      [2, 3]: id 1, weight 3.0
747      ```
748
749    `default_id` is 0.
750
751    with `combiner`="mean", then the output will be a 3x20 matrix where
752
753      ```python
754      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
755      output[1, :] = (params[0, :] * 1.0) / 1.0
756      output[2, :] = (params[1, :] * 3.0) / 3.0
757      ```
758
759  Raises:
760    ValueError: if `embedding_weights` is empty.
761  """
762  return safe_embedding_lookup_sparse(
763      embedding_weights,
764      sparse_ids,
765      sparse_weights=sparse_weights,
766      combiner=combiner,
767      default_id=default_id,
768      name=name,
769      partition_strategy="div",
770      max_norm=max_norm)
771
772
@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
@dispatch.add_dispatch_support
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner="mean",
                                 default_id=None,
                                 name=None,
                                 partition_strategy="div",
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of the number of shards.
  `embedding_weights` may be a `PartitionedVariable` as returned by using
  `tf.compat.v1.get_variable()` with a partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights. Unless `combiner` is
  `"sum"`, any IDs with non-positive weight are pruned as well. For an entry
  with no features, the embedding vector for `default_id` is returned, or the
  0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  Args:
    embedding_weights: A single tensor representing the complete embedding
      tensor, or a list of tensors all of same shape except for the first
      dimension, representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights are
      assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean" the
      default.
    default_id: The id to use for an entry with no features. Defaults to
      0-vector.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy. Currently
      `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
      combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sparse_ids`,
    the op looks up the embeddings for all ids in that row, multiplies them by
    the corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined embedding_weights) = [p0, p1, ..., pm]`

    and

      `shape(sparse_ids) = shape(sparse_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ... dn-1, p1, ..., pm]`.

    For instance, if `params` (i.e. `embedding_weights`) is a 10x20 matrix, and
    `sparse_ids` / `sparse_weights` are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id -1, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    and `default_id` is 0,

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    ValueError: if `embedding_weights` is `None` or empty.
  """
  if embedding_weights is None:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)

  # When weights are given, convert every shard to the weights' dtype. Resource
  # variables whose dtype already matches (or when no dtype is forced) are kept
  # as-is rather than being converted to tensors.
  dtype = sparse_weights.dtype if sparse_weights is not None else None
  embedding_weights = [
      w if (isinstance(w, resource_variable_ops.ResourceVariable)
            and dtype in (None, w.dtype))
      else ops.convert_to_tensor(w, dtype=dtype)
      for w in embedding_weights
  ]

  with ops.name_scope(name, "embedding_lookup", embedding_weights +
                      [sparse_ids, sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
    original_shape = sparse_ids.dense_shape
    original_rank_dim = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    # Use the static rank when known; otherwise compute it at run time.
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim is None else original_rank_dim)
    # Collapse all leading dimensions into one and keep the last dimension:
    # [d_0, ..., d_{n-1}, d_n] -> [d_0 * ... * d_{n-1}, d_n].
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)
    ])
    if sparse_weights is not None:
      # Assumes `sparse_weights` has the same indices as the original
      # `sparse_ids`, so the reshaped ids' indices and shape are reused.
      sparse_weights = sparse_tensor.SparseTensor(sparse_ids.indices,
                                                  sparse_weights.values,
                                                  sparse_ids.dense_shape)

    # Prune invalid ids and weights. Non-positive weights are only pruned for
    # combiners other than "sum".
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != "sum":
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
        sparse_ids, default_id or 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    result = embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Rows that were empty were padded with id 0 above; replace their
      # embeddings with zeros. `is_row_empty` is a vector with one entry per
      # row of `result`, and `where` (Select) with a vector condition selects
      # whole rows. This works for embeddings of any rank, whereas tiling the
      # mask to `[N, shape(result)[1]]` only matched rank-2 results.
      result = array_ops.where(
          is_row_empty, array_ops.zeros_like(result), result, name=scope)

    # Reshape back from linear ids back into higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    final_result.set_shape(
        tensor_shape.unknown_shape(
            (tensor_shape.Dimension(original_rank_dim) - 1).value).concatenate(
                result.get_shape()[1:]))
    return final_result
941
942
def embedding_lookup_ragged(embedding_weights,
                            ragged_ids,
                            partition_strategy="mod",
                            max_norm=None,
                            name=None):
  """Look up the ragged ids in a list of embedding tensors.

  The lookup is applied to the flat values of `ragged_ids` through
  `ragged_functional_ops.map_flat_values`, which calls `embedding_lookup` on
  those values and reattaches the row partitions of `ragged_ids`.

  Args:
    embedding_weights: A tensor representing the complete embedding tensor
      having the shape [e1, ...eM]. A list or tuple (of shards) is also
      accepted and is forwarded to `embedding_lookup` as `params`; it must not
      be empty.
    ragged_ids: A 'RaggedTensor' with type 'int32' or 'int64' containing the ids
      to be looked up in 'embedding_weights' of shape [r0, ..rN]. Values must be
      in the half-open range `[0, embedding_weights.shape[0])`.
    partition_strategy: A string specifying the partitioning strategy, forwarded
      to `embedding_lookup`. Defaults to `"mod"`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value.
    name: A name for the operation (optional)

  Returns:
    A ragged tensor of shape [r0, r1, ...rN, e1, ...eM].

  Raises:
    ValueError: If `embedding_weights` is `None` or an empty list/tuple, or if
      `ragged_ids` does not have dtype `int32` or `int64`.
  """
  if embedding_weights is None:
    raise ValueError("The embedding weights must be specified.")
  if isinstance(embedding_weights, (list, tuple)) and not embedding_weights:
    raise ValueError("The embedding weights should not be empty.")
  if ragged_ids.dtype != dtypes.int32 and ragged_ids.dtype != dtypes.int64:
    raise ValueError("The values contained by the inputs have type " +
                     str(ragged_ids.dtype) +
                     " and cannot be processed. All values"
                     " should be indices, either of type `int32` or `int64`.")

  with ops.name_scope(name, "embedding_lookup_ragged"):
    looked_up_ragged = ragged_functional_ops.map_flat_values(
        embedding_lookup,
        params=embedding_weights,
        ids=ragged_ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm)

    return looked_up_ragged
987
988
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drops every entry whose id is negative (< 0).

  Args:
    sparse_ids: `SparseTensor` of ids.
    sparse_weights: `SparseTensor` of weights paired with `sparse_ids`, or
      `None`.

  Returns:
    A `(sparse_ids, sparse_weights)` tuple with the negative-id entries removed
    from both tensors. The second element is `None` when `sparse_weights` is
    `None`.
  """
  keep = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is None:
    return sparse_ops.sparse_retain(sparse_ids, keep), None
  # AND-ing with an all-True mask built from the weight values leaves `keep`
  # unchanged but ties it to `sparse_weights.values`. Presumably this makes a
  # mismatch between the id and weight counts fail here. NOTE(review): confirm.
  keep = math_ops.logical_and(
      keep, array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  pruned_ids = sparse_ops.sparse_retain(sparse_ids, keep)
  pruned_weights = sparse_ops.sparse_retain(sparse_weights, keep)
  return pruned_ids, pruned_weights
1000
1001
def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune entries with non-positive weights (<= 0) from the ids and weights.

  Args:
    sparse_ids: `SparseTensor` of ids.
    sparse_weights: `SparseTensor` of weights, or `None`. The same mask is
      applied to both tensors, so it is expected to have as many values as
      `sparse_ids`.

  Returns:
    A `(sparse_ids, sparse_weights)` tuple. If `sparse_weights` is `None`, both
    inputs are returned unchanged. Otherwise every entry whose weight is not
    strictly positive is removed from both tensors.
  """
  if sparse_weights is None:
    return sparse_ids, sparse_weights
  # `greater` (not `greater_equal`) keeps only strictly positive weights, so
  # zero-weight entries are dropped as well.
  is_weights_valid = math_ops.greater(sparse_weights.values, 0)
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
  sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights
1009