1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations for embeddings."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from six.moves import xrange  # pylint: disable=redefined-builtin
21
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import clip_ops
29# Imports gradient definitions.
30from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
31from tensorflow.python.ops import data_flow_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import resource_variable_ops
34from tensorflow.python.ops import sparse_ops
35from tensorflow.python.ops import variables
36from tensorflow.python.platform import tf_logging as logging
37from tensorflow.python.util.tf_export import tf_export
38
39
40def _clip(params, ids, max_norm):
41  """Helper function for _embedding_lookup_and_transform.
42
43  This function optionally clips embeddings to an l2-norm of max_norm.
44
45  Args:
46    params: A `Tensor` of embeddings retrieved by `gather`.
47    ids: The `ids` argument that was passed to `gather`.
48    max_norm: If not `None`, each embedding is clipped if its l2-norm is
49      larger than this value.
50
51  Returns:
52    A `Tensor` with the same type as `params`.
53  """
54
55  def _rank(x):
56    """Helper function to retrieve the rank of a tensor.
57
58    Args:
59      x: Something convertible to `Tensor`.
60
61    Returns:
62      Either a pair `(rank, True)` where `rank` is an integer or a pair
63      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
64      `rank` is the rank of `x`.
65    """
66    rank = ops.convert_to_tensor(x).get_shape().ndims
67    if rank:
68      return rank, True
69    else:
70      return array_ops.rank(x), False
71
72  if max_norm is None:
73    return params
74  ids_rank, ids_static = _rank(ids)
75  params_rank, params_static = _rank(params)
76  return clip_ops.clip_by_norm(
77      params,
78      max_norm,
79      axes=(list(range(ids_rank, params_rank))
80            if ids_static and params_static
81            else math_ops.range(ids_rank, params_rank)))
82
83
84def _embedding_lookup_and_transform(params,
85                                    ids,
86                                    partition_strategy="mod",
87                                    name=None,
88                                    max_norm=None,
89                                    transform_fn=None):
90  """Helper function for embedding_lookup and _compute_sampled_logits.
91
92  This function is a generalization of embedding_lookup that optionally
93  applies a caller-specified transformation to each embedding. This is
94  done through the `transform_fn` argument. If provided, the function is
95  applied to each partitioned tensor of retrieved embeddings, colocated
96  with the embeddings. This function will be called with a single `Tensor`
97  argument of the same type as the `params` tensor and should return a
98  `Tensor`. The shape of the argument will be the same as `params` except
99  for the size of the first dimension. The first dimension of the result's
100  shape must be the same size as the argument's.
101
102  Args:
103    params: See embedding_lookup.
104    ids: See embedding_lookup.
105    partition_strategy: See embedding_lookup.
106    name: See embedding_lookup.
107    max_norm: See embedding_lookup.
108    transform_fn: An optional function to apply to each retrieved embedding.
109      If max_norm is provided, transform_fn is applied to the norm-limited
110      embeddings.
111
112  Returns:
113    See embedding_lookup for details.
114  Raises:
115    ValueError: If `params` is empty.
116  """
117  if params is None or params in ((), []):
118    raise ValueError("Need at least one param")
119  if isinstance(params, variables.PartitionedVariable):
120    params = list(params)  # Iterate to get the underlying Variables.
121  if not isinstance(params, list):
122    params = [params]
123
124  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
125    np = len(params)  # Number of partitions
126    # Preserve the resource variable status to avoid accidental dense reads.
127    if not any(
128        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
129      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
130    ids = ops.convert_to_tensor(ids, name="ids")
131    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
132      with ops.colocate_with(params[0]):
133        result = _clip(array_ops.gather(params[0], ids, name=name),
134                       ids, max_norm)
135        if transform_fn:
136          result = transform_fn(result)
137      # Make sure the final result does not have colocation contraints on the
138      # params. Similar to the case np > 1 where parallel_dynamic_stitch is
139      # outside the scioe of all with ops.colocate_with(params[p]).
140      return array_ops.identity(result)
141    else:
142      # Flatten the ids. There are two cases where we need to do this.
143      # - There is more than one params tensor.
144      # - There is a transform_fn and ids is not statically known to be 1-D.
145      #   We must flatten in this case because transform_fn expects a flat
146      #   tensor of embeddings.
147      flat_ids = array_ops.reshape(ids, [-1])
148      original_indices = math_ops.range(array_ops.size(flat_ids))
149
150      # Create p_assignments and set new_ids depending on the strategy.
151      if partition_strategy == "mod":
152        p_assignments = flat_ids % np
153        new_ids = flat_ids // np
154      elif partition_strategy == "div":
155        # Compute num_total_ids as the sum of dim-0 of params, then assign to
156        # partitions based on a constant number of ids per partition. Optimize
157        # if we already know the full shape statically.
158        dim_0_size = tensor_shape.Dimension(tensor_shape.dimension_value(
159            params[0].get_shape()[0]))
160        for p in xrange(1, np):
161          dim_0_size += tensor_shape.Dimension(tensor_shape.dimension_value(
162              params[p].get_shape()[0]))
163        if dim_0_size.value:
164          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
165        else:
166          dim_0_sizes = []
167          for p in xrange(np):
168            param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
169            if param_p_dim is not None:
170              dim_0_sizes.append(param_p_dim)
171            else:
172              with ops.colocate_with(params[p]):
173                dim_0_sizes.append(array_ops.shape(params[p])[0])
174          num_total_ids = math_ops.reduce_sum(
175              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
176        ids_per_partition = num_total_ids // np
177        extras = num_total_ids % np
178
179        p_assignments = math_ops.maximum(
180            flat_ids // (ids_per_partition + 1),
181            (flat_ids - extras) // ids_per_partition)
182
183        # Emulate a conditional using a boolean indicator tensor
184        new_ids = array_ops.where(p_assignments < extras,
185                                  flat_ids % (ids_per_partition + 1),
186                                  (flat_ids - extras) % ids_per_partition)
187      else:
188        raise ValueError("Unrecognized partition strategy: " +
189                         partition_strategy)
190
191      # Cast partition assignments to int32 for use in dynamic_partition.
192      # There really should not be more than 2^32 partitions.
193      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
194      # Partition list of ids based on assignments into np separate lists
195      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
196      # Similarly, partition the original indices.
197      pindices = data_flow_ops.dynamic_partition(original_indices,
198                                                 p_assignments, np)
199      # Do np separate lookups, finding embeddings for plist[p] in params[p]
200      partitioned_result = []
201      for p in xrange(np):
202        pids = gather_ids[p]
203        with ops.colocate_with(params[p]):
204          result = array_ops.gather(params[p], pids)
205          if transform_fn:
206            # If transform_fn is provided, the clip_by_norm precedes
207            # the transform and hence must be co-located. See below
208            # for the counterpart if transform_fn is not proveded.
209            result = transform_fn(_clip(result, pids, max_norm))
210        partitioned_result.append(result)
211      # Stitch these back together
212      ret = data_flow_ops.parallel_dynamic_stitch(
213          pindices, partitioned_result, name=name)
214
215      # Determine the static element shape.
216      if transform_fn is None:
217        element_shape_s = params[0].get_shape()[1:]
218        for p in params[1:]:
219          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
220      else:
221        element_shape_s = ret.get_shape()[1:]
222
223      # Compute the dynamic element shape.
224      if element_shape_s.is_fully_defined():
225        element_shape_d = element_shape_s
226      elif transform_fn is None:
227        # It's important that we compute params[0].shape on the right device
228        # to avoid data motion.
229        with ops.colocate_with(params[0]):
230          params_shape = array_ops.shape(params[0])
231        element_shape_d = params_shape[1:]
232      else:
233        element_shape_d = array_ops.shape(ret)[1:]
234
235      # Reshape to reverse the flattening of ids.
236      ret = array_ops.reshape(ret,
237                              array_ops.concat(
238                                  [array_ops.shape(ids), element_shape_d], 0))
239
240      # Normally the reshape is sufficient, but setting shape explicitly
241      # teaches shape inference that params[1:].get_shape() matters
242      # (in the case that transform_fn is None).
243      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
244      if not transform_fn:
245        # If transform_fn was provided, the clip_by_norm was done above.
246        ret = _clip(ret, ids, max_norm)
247      return ret
248
249
@tf_export(v1=["nn.embedding_lookup"])
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up `ids` in a list of embedding tensors.

  Performs parallel lookups over the tensors in `params`. This generalizes
  `tf.gather`: `params` is treated as one large embedding tensor that has
  been split into shards along its first dimension. `params` may also be a
  `PartitionedVariable`, such as one returned by `tf.get_variable()` when a
  partitioner is supplied.

  When `len(params) > 1`, every element `id` of `ids` is routed to one of
  the shards according to `partition_strategy`. Under either strategy, if
  the id space is not an exact multiple of the number of partitions, each of
  the first `(max_id + 1) % len(params)` partitions receives one extra id.

  With `partition_strategy="mod"`, an id goes to partition
  `p = id % len(params)`. For example, 13 ids spread over 5 partitions give:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  With `partition_strategy="div"`, ids are assigned to partitions in
  contiguous ranges. The same 13 ids over 5 partitions give:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  The per-shard results are combined into one dense tensor of shape
  `shape(ids) + shape(params)[1:]`.

  Args:
    params: Either a single tensor holding the complete embedding tensor, or
      a list of P tensors that share their shape except for the first
      dimension (the sharded embedding tensor). May also be a
      `PartitionedVariable` created by partitioning along dimension 0. Each
      element must be sized appropriately for `partition_strategy`.
    ids: An `int32` or `int64` `Tensor` with the ids to look up in `params`.
    partition_strategy: String naming the partitioning strategy; only
      relevant when `len(params) > 1`. `"div"` and `"mod"` are supported and
      the default is `"mod"`.
    name: Optional name for the operation.
    validate_indices: DEPRECATED and ignored. On CPU, values in `indices` are
      always checked to be in range. On GPU, out-of-bound indices produce
      safe but unspecified behavior, which may include raising an error.
    max_norm: If not `None`, any embedding whose l2-norm exceeds this value
      is clipped to it.

  Returns:
    A `Tensor` with the same dtype as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  del validate_indices  # Deprecated; intentionally not forwarded.
  lookup_options = dict(
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)
  return _embedding_lookup_and_transform(params, ids, **lookup_options)
317
318
@tf_export("nn.embedding_lookup", v1=[])
def embedding_lookup_v2(
    params,
    ids,
    max_norm=None,
    name=None):
  """Looks up `ids` in a list of embedding tensors.

  Performs parallel lookups over the tensors in `params`. This generalizes
  `tf.gather`: `params` is treated as one large embedding tensor that has
  been split into shards along its first dimension. `params` may also be a
  `PartitionedVariable`, such as one returned by `tf.get_variable()` when a
  partitioner is supplied.

  When `len(params) > 1`, every element `id` of `ids` is routed to one of
  the shards. If the id space is not an exact multiple of the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions
  receives one extra id.

  This version always uses the `"div"` partition strategy: ids are assigned
  to partitions in contiguous ranges. For example, 13 ids over 5 partitions
  give:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  The per-shard results are combined into one dense tensor of shape
  `shape(ids) + shape(params)[1:]`.

  Args:
    params: Either a single tensor holding the complete embedding tensor, or
      a list of P tensors that share their shape except for the first
      dimension (the sharded embedding tensor). May also be a
      `PartitionedVariable` created by partitioning along dimension 0. Each
      element must be sized appropriately for the `"div"` strategy.
    ids: An `int32` or `int64` `Tensor` with the ids to look up in `params`.
    max_norm: If not `None`, any embedding whose l2-norm exceeds this value
      is clipped to it.
    name: Optional name for the operation.

  Returns:
    A `Tensor` with the same dtype as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  # The TF2 endpoint fixes the strategy to "div"; everything else is
  # delegated to the v1 implementation.
  return embedding_lookup(
      params,
      ids,
      partition_strategy="div",
      name=name,
      max_norm=max_norm)
368
369
@tf_export(v1=["nn.embedding_lookup_sparse"])
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Computes embeddings for the given ids and weights.

  This op assumes that there is at least one id for each row in the dense tensor
  represented by sp_ids (i.e. there are no rows with empty features), and that
  all the indices of sp_ids are in canonical row-major order.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
      and M is arbitrary.
    sp_weights: either a `SparseTensor` of float / double weights, or `None` to
      indicate all weights should be taken to be 1. If specified, `sp_weights`
      must have exactly the same shape and indices as `sp_ids`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported. If `None`, a warning is logged and "mean" is
      used.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ..., dn-1, p1, ..., pm]`.

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    with `combiner`="mean", then the output will be a 3x20 matrix where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

    If the looked-up embeddings are `float16` or `bfloat16`, they are cast to
    `float32` before weighting and combining, so the result is `float32`.

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    # `warning` rather than the deprecated `warn` alias.
    logging.warning("The default value of combiner will change from \"mean\" "
                    "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    # Rows are combined by their first index coordinate.
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    # Look up each distinct id only once; `idx` maps every original value to
    # its row in the de-duplicated lookup result.
    ids = sp_ids.values
    ids, idx = array_ops.unique(ids)

    embeddings = embedding_lookup(
        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
    if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
      embeddings = math_ops.cast(embeddings, dtypes.float32)
    if not ignore_weights:
      weights = sp_weights.values
      if weights.dtype != embeddings.dtype:
        weights = math_ops.cast(weights, embeddings.dtype)

      # Expand back to one embedding per (non-unique) sparse value.
      embeddings = array_ops.gather(embeddings, idx)

      # Reshape weights to allow broadcast
      ones = array_ops.fill(
          array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
      bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                             0)

      orig_weights_shape = weights.get_shape()
      weights = array_ops.reshape(weights, bcast_weights_shape)

      # Set the weight shape, since after reshaping to bcast_weights_shape,
      # the shape becomes None.
      if embeddings.get_shape().ndims is not None:
        weights.set_shape(
            orig_weights_shape.concatenate(
                [1 for _ in range(embeddings.get_shape().ndims - 1)]))

      embeddings *= weights

      if combiner == "sum":
        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weight_sum = math_ops.segment_sum(weights, segment_ids)
        embeddings = math_ops.div(embeddings, weight_sum, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weights_squared = math_ops.pow(weights, 2)
        weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
        weight_sum_sqrt = math_ops.sqrt(weight_sum)
        embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
      else:
        # Unreachable: `combiner` is validated above. Raise rather than
        # `assert` so the guard survives `python -O`.
        raise ValueError("Unrecognized combiner: %s" % combiner)
    else:
      # Unweighted: the sparse_segment_* ops gather rows of the de-duplicated
      # embeddings via `idx` and reduce them per segment in one op.
      if combiner == "sum":
        embeddings = math_ops.sparse_segment_sum(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.sparse_segment_mean(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.sparse_segment_sqrt_n(
            embeddings, idx, segment_ids, name=name)
      else:
        # Unreachable: `combiner` is validated above.
        raise ValueError("Unrecognized combiner: %s" % combiner)

    return embeddings
543
544
@tf_export("nn.embedding_lookup_sparse", v1=[])
def embedding_lookup_sparse_v2(params,
                               sp_ids,
                               sp_weights,
                               combiner=None,
                               max_norm=None,
                               name=None):
  """Computes embeddings for the given ids and weights.

  Every row of the dense tensor represented by `sp_ids` is assumed to hold at
  least one id (no rows with empty features), and the indices of `sp_ids` are
  assumed to be in canonical row-major order.

  All id values are also assumed to lie in `[0, p0)`, where `p0` is the total
  size of `params` along dimension 0.

  Args:
    params: Either a single tensor holding the complete embedding tensor, or
      a list of P tensors that share their shape except for the first
      dimension (the sharded embedding tensor). May also be a
      `PartitionedVariable` created by partitioning along dimension 0. Each
      element must be sized appropriately for the `"div"` partition strategy.
    sp_ids: N x M `SparseTensor` of int64 ids, where N is typically the batch
      size and M is arbitrary.
    sp_weights: Either a `SparseTensor` of float / double weights, or `None`
      to treat every weight as 1. When given, `sp_weights` must have exactly
      the same shape and indices as `sp_ids`.
    combiner: String naming the reduction op; "mean", "sqrtn" and "sum" are
      supported.
      "sum" is the weighted sum of the embeddings in each row.
      "mean" is that weighted sum divided by the total weight.
      "sqrtn" is that weighted sum divided by the square root of the sum of
      the squared weights.
    max_norm: If not `None`, any embedding whose l2-norm exceeds this value
      is clipped to it before combining.
    name: Optional name for the op.

  Returns:
    A dense tensor of combined embeddings for the sparse ids. For each row of
    the dense tensor represented by `sp_ids`, the embeddings of all ids in
    that row are looked up, scaled by their weights, and combined as
    requested.

    That is, if

      `shape(combined params) = [p0, p1, ..., pm]`

    and

      `shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]`

    then

      `shape(output) = [d0, d1, ..., dn-1, p1, ..., pm]`.

    For example, with a 10x20 `params` matrix and sp_ids / sp_weights

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

    and `combiner`="mean", the output is a 3x20 matrix with

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```

  Raises:
    TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
      neither `None` nor a `SparseTensor`.
    ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
  """
  # The TF2 endpoint always partitions with "div".
  return embedding_lookup_sparse(
      params,
      sp_ids,
      sp_weights,
      partition_strategy="div",
      name=name,
      combiner=combiner,
      max_norm=max_norm)
624
625
@tf_export("nn.safe_embedding_lookup_sparse", v1=[])
def safe_embedding_lookup_sparse_v2(embedding_weights,
                                    sparse_ids,
                                    sparse_weights=None,
                                    combiner="mean",
                                    default_id=None,
                                    max_norm=None,
                                    name=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  All partitions in `embedding_weights` must share a shape except for the
  first dimension, which may differ because the vocabulary size need not be
  a multiple of `P`. `embedding_weights` may be a `PartitionedVariable` as
  returned by `tf.get_variable()` with a partitioner.

  Invalid IDs (< 0) are removed from the input IDs and weights, as are IDs
  with a non-positive weight. An entry with no features yields the embedding
  of `default_id`, or the 0-vector when `default_id` is not supplied.

  The ids and weights may be multi-dimensional; embeddings are always
  aggregated along the last dimension.

  Note: the lookup into `embedding_weights` always uses the "div" partition
  strategy. Support for other partition strategies will be added later.

  Args:
    embedding_weights: A list of `P` float `Tensor`s or values representing
      partitioned embedding `Tensor`s, or a `PartitionedVariable` created by
      partitioning along dimension 0. The total unpartitioned shape should
      be `[e_0, e_1, ..., e_m]`, where `e_0` is the vocab size and
      `e_1, ..., e_m` are the embedding dimensions.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` holding the
      ids. `d_0` is typically the batch size.
    sparse_weights: `SparseTensor` with the same shape as `sparse_ids`
      holding the float weight for each id, or `None` if all weights are
      assumed to be 1.0.
    combiner: String naming how the embeddings of each entry are combined.
      "mean", "sqrtn" and "sum" are supported; "mean" is the default.
    default_id: The id to use for an entry with no features.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm
      before combining.
    name: Optional name for this operation.

  Returns:
    Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.

  Raises:
    ValueError: if `embedding_weights` is empty.
  """
  # The TF2 endpoint pins the partition strategy; everything else is
  # forwarded unchanged to the v1 implementation.
  strategy = "div"
  return safe_embedding_lookup_sparse(
      embedding_weights,
      sparse_ids,
      max_norm=max_norm,
      partition_strategy=strategy,
      name=name,
      default_id=default_id,
      combiner=combiner,
      sparse_weights=sparse_weights)
687
688
@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner='mean',
                                 default_id=None,
                                 name=None,
                                 partition_strategy='div',
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of `P`.  `embedding_weights`
  may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights. Unless `combiner`
  is "sum", any IDs with non-positive weight are pruned as well. For an entry
  with no features, the embedding vector for `default_id` is returned, or the
  0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  Args:
    embedding_weights:  A list of `P` float `Tensor`s or values representing
        partitioned embedding `Tensor`s.  Alternatively, a `PartitionedVariable`
        created by partitioning along dimension 0.  The total unpartitioned
        shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
        vocab size and `e_1, ..., e_m` are the embedding dimensions.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
        ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
        float weights corresponding to `sparse_ids`, or `None` if all weights
        are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
        entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
        the default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
        Currently `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
        combining.

  Returns:
    Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.

  Raises:
    ValueError: if `embedding_weights` is `None` or empty.
  """
  if embedding_weights is None:
    raise ValueError('Missing embedding_weights %s.' % embedding_weights)
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if len(embedding_weights) < 1:
    raise ValueError('Missing embedding_weights %s.' % embedding_weights)

  # Resource variables whose dtype already matches the weights are passed
  # through unchanged; everything else is converted to a Tensor of the
  # weights' dtype (or its own dtype when there are no weights).
  dtype = sparse_weights.dtype if sparse_weights is not None else None
  embedding_weights = [
      w if (isinstance(w, resource_variable_ops.ResourceVariable)
            and dtype in (None, w.dtype))
      else ops.convert_to_tensor(w, dtype=dtype)
      for w in embedding_weights
  ]

  with ops.name_scope(name, 'embedding_lookup',
                      embedding_weights + [sparse_ids,
                                           sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids: all
    # leading dimensions are collapsed into one, keeping the last dimension.
    original_shape = sparse_ids.dense_shape
    original_rank_dim = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim is None
        else original_rank_dim)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)])
    if sparse_weights is not None:
      # Reuses the reshaped ids' indices for the weights; this presumes the
      # weights share sparse_ids' original indices (same shape and order).
      sparse_weights = sparse_tensor.SparseTensor(
          sparse_ids.indices,
          sparse_weights.values, sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != 'sum':
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary.
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
                                                                 default_id or
                                                                 0)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    # When default_id is given, this lookup is the final combining op and so
    # takes the scope name; otherwise the `where` below takes it.
    result = embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Zero out rows that were empty before filling. `is_row_empty` is rank 1
      # and its length matches result's first dimension, so v1 `where` selects
      # whole rows of `result` regardless of result's rank. This keeps
      # multi-dimensional embeddings (e_1, ..., e_m with m > 1) working, which
      # tiling the condition to rank 2 did not.
      result = array_ops.where(is_row_empty,
                               array_ops.zeros_like(result),
                               result,
                               name=scope)

    # Reshape back from linear ids back into higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    # When the input rank is statically unknown, Dimension(None) - 1 has value
    # None and the static shape stays unknown.
    final_result.set_shape(tensor_shape.unknown_shape(
        (tensor_shape.Dimension(original_rank_dim) - 1).value).concatenate(
            result.get_shape()[1:]))
    return final_result
825
826
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drop negative ids, together with their weights, from the sparse inputs.

  Args:
    sparse_ids: `SparseTensor` of ids; entries whose value is < 0 are removed.
    sparse_weights: `SparseTensor` of weights filtered with the same mask as
      `sparse_ids`, or `None`.

  Returns:
    A `(sparse_ids, sparse_weights)` pair holding only the entries whose id is
    >= 0. `sparse_weights` is returned as `None` when it was `None`.
  """
  keep_mask = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is None:
    return sparse_ops.sparse_retain(sparse_ids, keep_mask), sparse_weights
  # AND-ing with an all-True tensor shaped like the weight values leaves the
  # mask unchanged. Presumably it makes the graph fail when the ids and
  # weights disagree in number of values.
  keep_mask = math_ops.logical_and(
      keep_mask,
      array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  kept_ids = sparse_ops.sparse_retain(sparse_ids, keep_mask)
  kept_weights = sparse_ops.sparse_retain(sparse_weights, keep_mask)
  return kept_ids, kept_weights
838
839
def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune non-positive weights (<= 0), and their ids, from the inputs.

  Args:
    sparse_ids: `SparseTensor` of ids, filtered with the same mask as
      `sparse_weights`.
    sparse_weights: `SparseTensor` of weights, or `None`. When `None`, both
      inputs are returned unchanged.

  Returns:
    A `(sparse_ids, sparse_weights)` pair holding only the entries whose weight
    is strictly greater than 0 (zero weights are pruned too).
  """
  if sparse_weights is None:
    return sparse_ids, sparse_weights
  # `greater` (not `greater_equal`): zero-weight entries are dropped as well.
  is_weights_valid = math_ops.greater(sparse_weights.values, 0)
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
  sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights
847