1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Embedding functions."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from six.moves import xrange  # pylint: disable=redefined-builtin
21
22from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
23from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
24
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import clip_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import data_flow_ops
34from tensorflow.python.ops import embedding_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import resource_variable_ops
37from tensorflow.python.ops import sparse_ops
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import tf_logging as logging
40
# Public names exported by a star-import of this module.
__all__ = [
    "safe_embedding_lookup_sparse", "scattered_embedding_lookup",
    "scattered_embedding_lookup_sparse", "embedding_lookup_unique",
    "embedding_lookup_sparse_with_distributed_aggregation"
]
46
47
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner=None,
                                 default_id=None,
                                 name=None,
                                 partition_strategy="div",
                                 max_norm=None):
  """Lookup embedding results, accounting for invalid IDs and empty features.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as the
  vocabulary size is not necessarily a multiple of `P`.  `embedding_weights`
  may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  Invalid IDs (< 0) are pruned from input IDs and weights. Unless `combiner` is
  "sum", any IDs with non-positive weight are pruned as well. For an entry with
  no features, the embedding vector for `default_id` is returned, or the
  0-vector if `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always aggregated
  along the last dimension.

  Args:
    embedding_weights:  A list of `P` float tensors or values representing
        partitioned embedding tensors.  Alternatively, a `PartitionedVariable`,
        created by partitioning along dimension 0.  The total unpartitioned
        shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
        vocab size and `e_1, ..., e_m` are the embedding dimensions.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
        ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
        float weights corresponding to `sparse_ids`, or `None` if all weights
        are assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
        entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
        the default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
        Currently `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not None, all embeddings are l2-normalized to max_norm before
        combining.

  Returns:
    Dense tensor of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.

  Raises:
    ValueError: if `embedding_weights` is `None` or empty.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if embedding_weights is None:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)
  if isinstance(embedding_weights, variables.PartitionedVariable):
    embedding_weights = list(embedding_weights)  # get underlying Variables.
  if not isinstance(embedding_weights, list):
    embedding_weights = [embedding_weights]
  if not embedding_weights:
    raise ValueError("Missing embedding_weights %s." % embedding_weights)

  # When weights are given, the embedding tensors are converted using the
  # weights' dtype so that the dtype assertion below can compare them.
  dtype = sparse_weights.dtype if sparse_weights is not None else None
  embedding_weights = [
      ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
  ]

  contrib_tensor_util.assert_same_float_dtype(embedding_weights +
                                              [sparse_weights])

  with ops.name_scope(name, "embedding_lookup",
                      embedding_weights + [sparse_ids,
                                           sparse_weights]) as scope:
    # Reshape higher-rank sparse ids and weights to linear segment ids.
    original_shape = sparse_ids.dense_shape
    original_rank_dim = tensor_shape.Dimension(tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0]))
    # Use the static rank when it is known, otherwise compute it in-graph.
    original_rank = (
        array_ops.size(original_shape)
        if original_rank_dim.value is None
        else original_rank_dim.value)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
        math_ops.reduce_prod(
            array_ops.slice(original_shape, [0], [original_rank - 1])),
        array_ops.gather(original_shape, original_rank - 1)])
    if sparse_weights is not None:
      # Re-index the weights with the reshaped ids' indices and shape.
      sparse_weights = sparse_tensor.SparseTensor(
          sparse_ids.indices,
          sparse_weights.values, sparse_ids.dense_shape)

    # Prune invalid ids and weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != "sum":
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Fill in dummy values for empty features, if necessary. The fill id is
    # chosen with an explicit None test so that a `default_id` is never
    # evaluated for truthiness.
    fill_id = 0 if default_id is None else default_id
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
                                                                 fill_id)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    # The final op of this function carries the scope name: the lookup itself
    # when a default id is used, otherwise the `where` below.
    result = embedding_ops.embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=None if default_id is None else scope,
        max_norm=max_norm)

    if default_id is None:
      # Broadcast is_row_empty to the same shape as embedding_lookup_result,
      # for use in Select.
      is_row_empty = array_ops.tile(
          array_ops.reshape(is_row_empty, [-1, 1]),
          array_ops.stack([1, array_ops.shape(result)[1]]))

      # Rows that had no features yield the 0-vector instead of the dummy
      # lookup produced by the fill above.
      result = array_ops.where(is_row_empty,
                               array_ops.zeros_like(result),
                               result,
                               name=scope)

    # Reshape back from linear ids back into higher-dimensional dense result.
    final_result = array_ops.reshape(
        result,
        array_ops.concat([
            array_ops.slice(
                math_ops.cast(original_shape, dtypes.int32), [0],
                [original_rank - 1]),
            array_ops.slice(array_ops.shape(result), [1], [-1])
        ], 0))
    final_result.set_shape(tensor_shape.unknown_shape(
        (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
    return final_result
188
189
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drops entries whose id is negative from the ids and, if given, weights.

  Args:
    sparse_ids: `SparseTensor` of ids; entries with a value < 0 are removed.
    sparse_weights: `SparseTensor` of weights aligned with `sparse_ids`, or
      `None`.

  Returns:
    A `(sparse_ids, sparse_weights)` tuple with the same mask applied to both;
    `sparse_weights` is passed through unchanged when it is `None`.
  """
  has_weights = sparse_weights is not None
  keep_mask = math_ops.greater_equal(sparse_ids.values, 0)
  if has_weights:
    # And-ing with an all-True tensor shaped like the weights leaves the mask
    # values untouched while tying the mask to the weights' values.
    all_true = array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)
    keep_mask = math_ops.logical_and(keep_mask, all_true)
  pruned_ids = sparse_ops.sparse_retain(sparse_ids, keep_mask)
  if not has_weights:
    return pruned_ids, sparse_weights
  pruned_weights = sparse_ops.sparse_retain(sparse_weights, keep_mask)
  return pruned_ids, pruned_weights
201
202
def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Drops entries whose weight is not strictly positive (<= 0).

  Args:
    sparse_ids: `SparseTensor` of ids.
    sparse_weights: `SparseTensor` of weights aligned with `sparse_ids`, or
      `None`, in which case nothing is pruned.

  Returns:
    A `(sparse_ids, sparse_weights)` tuple keeping only the entries whose
    weight is greater than zero.
  """
  if sparse_weights is None:
    # Without weights there is nothing to filter on.
    return sparse_ids, sparse_weights
  positive_mask = math_ops.greater(sparse_weights.values, 0)
  kept_ids = sparse_ops.sparse_retain(sparse_ids, positive_mask)
  kept_weights = sparse_ops.sparse_retain(sparse_weights, positive_mask)
  return kept_ids, kept_weights
210
211
def scattered_embedding_lookup(params,
                               values,
                               dimension,
                               name=None,
                               hash_key=None):
  """Embeds every entry of `values` by hashing into a shared parameter pool.

  Component i of the embedding of a value v is the weight whose index is a
  fingerprint of the pair (v, i). This is the "feature hashing" approach to
  model compression described in http://arxiv.org/pdf/1504.04788.pdf

  Because indices come from a hash, no pre-determined vocabulary is needed,
  and embeddings for possibly trillions of features can be kept in a fixed
  amount of memory.

  Compared with out-of-vocabulary shared "hash buckets", the embedding of each
  token is extremely likely to be unique rather than shared across
  probably-colliding tokens. The cost is one hash per scalar of the token's
  embedding instead of one hash per token.

  If `params` is a list, it represents a partition of the embedding
  parameters. Each tensor in the list should have the same length, except for
  the first ones which may have an additional element. For instance 10
  parameters can be partitioned in 4 tensors with length `[3, 3, 2, 2]`.

  Args:
    params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
      Each tensor must be of rank 1 with fully-defined shape.
    values: `Tensor` of values to be embedded with shape `[d0, ..., dn]`.
    dimension: Embedding dimension.
    name: An optional name for this op.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crosses fingerprints on SparseFeatureCrossOp
      (optional).

  Returns:
    A `Tensor` with shape `[d0, ..., dn, dimension]`.

  Raises:
    ValueError: if `dimension` is `None` or not positive, or the partition
      size is invalid.
  """
  if dimension is None:
    raise ValueError("You must specify dimension.")
  # Every embedding dimension is looked up, so no candidates are sampled.
  lookup_options = dict(
      dimension=dimension,
      sampled_candidates=None,
      hash_key=hash_key,
      name=name)
  return _sampled_scattered_embedding_lookup(params, values, **lookup_options)
261
262
def _sampled_scattered_embedding_lookup(
    params, values, dimension=None, sampled_candidates=None, hash_key=None,
    name=None):
  """Hash-based embedding lookup, optionally restricted to sampled dimensions.

  When `sampled_candidates` is given only those embedding dimensions are
  looked up; otherwise all `dimension` components are.

  Component i of the embedding of a value v is the weight whose index is a
  fingerprint of the pair (v, i). This is the "feature hashing" approach to
  model compression described in http://arxiv.org/pdf/1504.04788.pdf

  Because indices come from a hash, no pre-determined vocabulary is needed,
  and embeddings for possibly trillions of features can be kept in a fixed
  amount of memory.

  Compared with out-of-vocabulary shared "hash buckets", the embedding of each
  token is extremely likely to be unique rather than shared across
  probably-colliding tokens. The cost is one hash per scalar of the token's
  embedding instead of one hash per token.

  If `params` is a list, it represents a partition of the embedding
  parameters. Each tensor in the list should have the same length, except for
  the first ones which may have an additional element. For instance 10
  parameters can be partitioned in 4 tensors with length `[3, 3, 2, 2]`.

  Args:
    params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
      Each tensor must be of rank 1 with fully-defined shape.
    values: `Tensor` of values to be embedded with shape `[d0, ..., dn]`.
    dimension: Embedding dimension. The user must specify either `dimension` or
      `sampled_candidates`.
    sampled_candidates: An optional `Tensor` of slice indices to keep along the
      final dimension with shape `[d0, ..., dn, N]`. If given, `dimension` is
      ignored. If `None`, looks up all candidates.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crosses fingerprints on SparseFeatureCrossOp
      (optional).
    name: An optional name for this op.

  Returns:
    A `Tensor` with shape `[d0, ..., dn, dimension]`.
    If `sampled_candidates` is given, the output shape is `[d0, ..., dn, N]`

  Raises:
    ValueError: if neither `dimension` nor `sampled_candidates` is given, if
      `dimension` is not positive, or if the partition size is invalid.
  """
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)
  elif not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "scattered_embedding_lookup",
                      params + [dimension, values]):
    # Keep the caller's shape for the final reshape, then flatten the values
    # into a single column.
    values_shape = array_ops.shape(values)
    values = array_ops.reshape(values, [-1, 1])

    if sampled_candidates is not None:
      # The embedding dimension is read from the innermost axis of
      # sampled_candidates; any `dimension` argument is overridden.
      shape_for_dimension = array_ops.shape(sampled_candidates)
      last_axis = math_ops.subtract(array_ops.rank(sampled_candidates), 1)
      dimension = shape_for_dimension[last_axis]
      candidates_shape = array_ops.shape(sampled_candidates)
      expected_shape = array_ops.concat(
          [values_shape, array_ops.reshape(dimension, shape=[1,])], 0)
      shapes_agree = math_ops.reduce_all(
          math_ops.equal(candidates_shape, expected_shape))
      shape_check = control_flow_ops.Assert(
          shapes_agree,
          ["The shape of sampled_candidates: ", candidates_shape,
           " does not match the shape of values: ", values_shape])
      with ops.control_dependencies([shape_check]):
        # Flatten sampled_candidates the same way the values were flattened.
        sampled_candidates = array_ops.reshape(sampled_candidates,
                                               [-1, dimension])
    else:
      if dimension is None:
        raise ValueError(
            "You must specify either dimension or sampled_candidates.")
      if dimension <= 0:
        raise ValueError("Dimension must be >0. Given is %d" % dimension)
      # One row [0, 1, ..., dimension - 1] repeated for every flattened value.
      every_dimension = array_ops.expand_dims(math_ops.range(0, dimension), 0)
      sampled_candidates = array_ops.tile(every_dimension,
                                          array_ops.shape(values))

    # Collect the statically known length of every partition.
    partition_sizes = []
    for partition in params:
      partition_shape = partition.get_shape()
      partition_shape.assert_has_rank(1)
      partition_shape.assert_is_fully_defined()
      partition_sizes.append(tensor_shape.dimension_value(partition_shape[0]))
    num_partitions = len(params)
    num_params = sum(partition_sizes)  # Total number of parameters.

    # Each partition must have the size an even split would give it, with the
    # leading partitions absorbing the remainder.
    for index, actual_size in enumerate(partition_sizes):
      expected_size = (num_params - index - 1) // num_partitions + 1
      if actual_size != expected_size:
        raise ValueError("Tensor %d in params has size %d, expected %d." %
                         (index, actual_size, expected_size))

    # With two values v1 and v2 and 3 dimensions, we will cross
    # [[0, 1, 2], [0, 1, 2]] with [[v1], [v2]].
    hashed_ids = sparse_feature_cross_op.sparse_feature_cross(
        [sampled_candidates, values], hashed_output=True,
        num_buckets=num_params, hash_key=hash_key)
    dense_ids = sparse_ops.sparse_tensor_to_dense(hashed_ids)

    # No need to validate the indices since we have checked the params
    # dimensions and we know the largest id.
    looked_up = embedding_ops.embedding_lookup(
        params, dense_ids, partition_strategy="div")

    output_shape = array_ops.concat([values_shape, [dimension]], 0)
    return array_ops.reshape(looked_up, output_shape)
378
379
def scattered_embedding_lookup_sparse(params,
                                      sparse_values,
                                      dimension,
                                      combiner=None,
                                      default_value=None,
                                      name=None,
                                      hash_key=None):
  """Looks up embeddings of a sparse feature using parameter hashing.

  See `tf.contrib.layers.scattered_embedding_lookup` for embedding with hashing.

  Args:
    params: A `Tensor`, `list` of `Tensors`, or `PartitionedVariable`.
      Each tensor must be of rank 1 with fully-defined shape.
    sparse_values: A 2-D `SparseTensor` containing the values to be embedded.
      Some rows may be empty.
    dimension: Embedding dimension
    combiner: A string specifying how to combine embedding results for each
        entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
        the default.
    default_value: The value to use for an entry with no features. If `None`,
      a fixed placeholder is used: a string for string-typed `sparse_values`,
      an integer otherwise.
    name: An optional name for this op.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crosses fingerprints on SparseFeatureCrossOp
      (optional).

  Returns:
     Dense tensor with shape [N, dimension] with N the number of rows in
       sparse_values.

  Raises:
    TypeError: If sparse_values is not a SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  # Reject an unsupported combiner up front, before any op is added to the
  # graph, instead of after the lookup ops have already been built.
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sparse_values, sparse_tensor.SparseTensor):
    raise TypeError("sparse_values must be SparseTensor")

  with ops.name_scope(name, "scattered_embedding_lookup_sparse",
                      params + [sparse_values]) as scope:
    # Fill in the empty rows.
    if default_value is None:
      # Random default values to reduce the risk of collision.
      if sparse_values.dtype == dtypes.string:
        default_value = "6ZxWzWOHxZ"
      else:
        default_value = 1288896567
    sparse_values, _ = sparse_ops.sparse_fill_empty_rows(
        sparse_values, default_value)

    # The row index of every entry decides which output row it contributes to.
    segment_ids = sparse_values.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    # Embed each distinct value once; `idx` maps every entry back to the row
    # of its value in the de-duplicated embeddings.
    values, idx = array_ops.unique(sparse_values.values)

    embeddings = scattered_embedding_lookup(
        params, values, dimension, hash_key=hash_key)

    # The combiner was validated above, so the lookup cannot miss.
    segment_reducers = {
        "sum": math_ops.sparse_segment_sum,
        "mean": math_ops.sparse_segment_mean,
        "sqrtn": math_ops.sparse_segment_sqrt_n,
    }
    return segment_reducers[combiner](embeddings, idx, segment_ids, name=scope)
460
461
def embedding_lookup_unique(params, ids, partition_strategy="mod", name=None):
  """Embedding lookup that fetches each distinct id only once.

  Repeated ids are de-duplicated before the lookup and the results are
  expanded back afterwards, which can save communication when ids repeat.
  The interface matches embedding_lookup, except that `ids` may be
  multi-dimensional, so callers need not reshape input/output to fit gather.

  Args:
    params: A list of tensors with the same shape and type, or a
      `PartitionedVariable`. Shape `[index, d1, d2, ...]`.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be
      looked up in `params`. Shape `[ids1, ids2, ...]`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params` and dimension of
    `[ids1, ids2, d1, d2, ...]`.

  Raises:
    ValueError: If `params` is empty.
  """
  with ops.name_scope(name, "EmbeddingLookupUnique", [params, ids]):
    ids = ops.convert_to_tensor(ids)
    ids_shape = array_ops.shape(ids)
    # The product of all dims, kept as a one-element vector, is the flat shape.
    flat_shape = math_ops.reduce_prod(ids_shape, keepdims=True)
    flat_ids = array_ops.reshape(ids, flat_shape)
    distinct_ids, inverse_index = array_ops.unique(flat_ids)
    distinct_embeddings = embedding_ops.embedding_lookup(
        params, distinct_ids, partition_strategy)
    # Expand the de-duplicated rows back to one row per original id.
    flat_embeddings = array_ops.gather(distinct_embeddings, inverse_index)
    embedding_dims = array_ops.shape(distinct_embeddings)[1:]
    full_shape = array_ops.concat([ids_shape, embedding_dims], 0)
    result = array_ops.reshape(flat_embeddings, full_shape)
    # Restore whatever static shape information is available.
    static_shape = ids.get_shape().concatenate(
        distinct_embeddings.get_shape()[1:])
    result.set_shape(static_shape)
    return result
501
502
def _sampled_scattered_embedding_lookup_sparse(params,
                                               sp_values,
                                               dimension=None,
                                               sampled_candidates=None,
                                               hash_key=None,
                                               with_sign_hash=False,
                                               name=None):
  """Hash-based embedding of sparse values, summed per row.

  When `sampled_candidates` is given only those embedding dimensions are
  looked up; otherwise all `dimension` components are.

  Component i of the embedding of a value v is the weight whose index is a
  fingerprint of the pair (v, i). This is the "feature hashing" approach to
  model compression described in http://arxiv.org/pdf/1504.04788.pdf

  This is logically equivalent to:
  * Transforming `sp_values` (which has shape `[d0, d1]`) into a one-hot
    `Tensor` of shape `[d0, N]`.
  * Multiplying with a `Tensor` `h` of shape `[N, dimension]`, where
    `h(i, j) = params[hash(i, j)]`.

  Args:
    params: A float `Tensor` with rank 1 and fully-defined shape.
    sp_values: A 2D `SparseTensor` to be embedded with shape `[d0, d1]`.
    dimension: An int `Tensor` of the final dimension. The user needs to provide
      either `dimension` or `sampled_candidates`.
    sampled_candidates: An optional `Tensor` of column indices to keep along
      the final dimension with shape `[d0, N]`. If given, `dimension` is
      ignored. If `None`, looks up all candidates.
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crosses fingerprints on SparseFeatureCrossOp
      (optional).
    with_sign_hash:  A `bool` indicating whether `h(i, j)` should be multiplied
      by `+1` or `-1`, where the value selected is determined by hashing
      `(i, j)`. This is often necessary to remove bias resulting from hash
      collisions.
    name: An optional name for this op.

  Returns:
    A `Tensor` of shape `[d0, dimension]`.
    If `sampled_candidates` is given, the output shape is `[d0, N]`.

  Raises:
    TypeError: If sp_values is not `SparseTensor`.
    ValueError: If both `dimension` and `sampled_candidates` are `None`.
  """
  if not isinstance(sp_values, sparse_tensor.SparseTensor):
    raise TypeError("sp_values must be SparseTensor")

  with ops.name_scope(
      name=name,
      default_name="sampled_scattered_embedding_lookup_sparse",
      values=[sp_values, params, dimension, sampled_candidates]) as name_scope:
    row_ids = sp_values.indices[:, 0]
    if sampled_candidates is not None:
      # Repeat each row of sampled_candidates once per entry of that row, so
      # there is one line for every element of sp_values.values.
      sampled_candidates = array_ops.gather(sampled_candidates, row_ids)

    # Both lookups below share the same dimension/candidates/hash settings.
    shared_options = dict(
        dimension=dimension,
        sampled_candidates=sampled_candidates,
        hash_key=hash_key)
    embeddings = _sampled_scattered_embedding_lookup(
        params, sp_values.values, name="values_lookup", **shared_options)
    if with_sign_hash:
      # Hash into a two-element table to pick -1 or +1 for each component.
      sign_table = array_ops.constant([-1., 1.])
      signs = _sampled_scattered_embedding_lookup(
          sign_table, sp_values.values, name="signs_lookup", **shared_options)
      embeddings = math_ops.multiply(signs, embeddings, name="signs_hash")

    if row_ids.dtype != dtypes.int32:
      row_ids = math_ops.cast(row_ids, dtypes.int32)

    # Sum the per-entry embeddings into one output row per input row.
    return math_ops.unsorted_segment_sum(
        embeddings,
        row_ids,
        num_segments=array_ops.shape(sp_values)[0],
        name=name_scope)
582
583
def embedding_lookup_sparse_with_distributed_aggregation(
    params,
    sp_ids,
    sp_weights,
    partition_strategy="mod",
    name=None,
    combiner=None,
    max_norm=None):
  """Computes combined embeddings for the given sparse ids and weights.

  Embeddings belonging to same param are aggregated on that device first. This
  op is intended to decrease data transmission and improve parallelism. See
  `tf.nn.embedding_lookup_sparse` for the functionality and example of this op.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
      where N is typically batch size and M is arbitrary.
    sp_weights: either a SparseTensor of float / double weights, or None to
      indicate all weights should be taken to be 1. If specified, sp_weights
      must have exactly the same shape and indices as sp_ids.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If not None, each embedding is normalized to have l2 norm equal
      to max_norm before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by sp_ids, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

  Raises:
    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
      None nor SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  has_weights = sp_weights is not None
  if has_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    # Static-shape sanity checks on the three components of both tensors.
    component_pairs = (
        (sp_ids.values, sp_weights.values),
        (sp_ids.indices, sp_weights.indices),
        (sp_ids.dense_shape, sp_weights.dense_shape),
    )
    for ids_component, weights_component in component_pairs:
      ids_component.get_shape().assert_is_compatible_with(
          weights_component.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse", params + [sp_ids]):
    # The row index of every entry decides which output row it contributes to.
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    if has_weights:
      ids, idx = sp_ids.values, None
      weights = sp_weights.values
    else:
      # Without weights, ids are de-duplicated and `idx` maps entries back.
      ids, idx = array_ops.unique(sp_ids.values)
      weights = None

    embeddings = _embedding_lookup_with_distributed_aggregation(
        params,
        ids,
        partition_strategy=partition_strategy,
        max_norm=max_norm,
        weights=weights,
        idx=idx,
        segment_ids=segment_ids)
    # Missing weights are treated as all ones, one per entry.
    if not has_weights:
      weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
    if weights.dtype != embeddings.dtype:
      weights = math_ops.cast(weights, embeddings.dtype)
    # Append size-1 dims so the weights broadcast against the embeddings.
    trailing_ones = array_ops.fill(
        array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
    broadcast_shape = array_ops.concat(
        [array_ops.shape(weights), trailing_ones], 0)
    static_weights_shape = weights.get_shape()
    weights = array_ops.reshape(weights, broadcast_shape)
    embeddings_ndims = embeddings.get_shape().ndims
    if embeddings_ndims is not None:
      # Keep the static shape in step with the dynamic reshape above.
      weights.set_shape(
          static_weights_shape.concatenate([1] * (embeddings_ndims - 1)))

    if combiner == "mean":
      normalizer = math_ops.segment_sum(weights, segment_ids)
      embeddings = math_ops.div(embeddings, normalizer)
    elif combiner == "sqrtn":
      squared_weights = math_ops.pow(weights, 2)
      squared_sum = math_ops.segment_sum(squared_weights, segment_ids)
      normalizer = math_ops.sqrt(squared_sum)
      embeddings = math_ops.div(embeddings, normalizer)
    else:
      # "sum" needs no normalization; anything else was rejected above.
      assert combiner == "sum", "Unrecognized combiner"
    return embeddings
706
707
def _do_gather(params, ids, name=None):
  """Gathers the rows `ids` of `params`, special-casing resource variables.

  Args:
    params: The object to read rows from.
    ids: The indices of the rows to read.
    name: Optional name forwarded to the underlying read op.

  Returns:
    `params.sparse_read(ids)` when `params` is a `ResourceVariable`, otherwise
    `array_ops.gather(params, ids)`.
  """
  if not isinstance(params, resource_variable_ops.ResourceVariable):
    return array_ops.gather(params, ids, name=name)
  # Resource variables are read through their own sparse read method.
  return params.sparse_read(ids, name=name)
713
714
def _embedding_lookup_with_distributed_aggregation(params,
                                                   ids,
                                                   partition_strategy="mod",
                                                   name=None,
                                                   max_norm=None,
                                                   weights=None,
                                                   idx=None,
                                                   segment_ids=None):
  """Lookup helper for embedding_lookup_sparse_with_distributed_aggregation.

  Gathers the rows of `params` selected by `ids`, optionally clips them with
  `clip_by_norm`, optionally scales them by `weights`, and sums the rows that
  share a segment id. With several partitions, each partition is gathered and
  partially summed under `colocate_with(params[p])`, and the partial sums are
  then merged with a final `unsorted_segment_sum`.

  Args:
    params: A single tensor/variable, a list of them (one per partition), or
      a `PartitionedVariable`.
    ids: The ids to look up. With more than one partition they are flattened
      before being assigned to partitions.
    partition_strategy: `"mod"` or `"div"`. Only consulted when there is more
      than one partition.
    name: Optional name for the name scope and for the returned op.
    max_norm: If not None, gathered embeddings are clipped to this norm,
      computed over every axis except the first.
    weights: If not None, one weight per looked-up id; each gathered row is
      multiplied by its weight before the rows are summed with `segment_sum`.
      If None, rows are combined with `sparse_segment_sum` using `idx`.
    idx: Only used when `weights` is None. Indices into the gathered rows, one
      per entry of `segment_ids`.
    segment_ids: Segment id for each summed entry (one per id when `weights`
      is given, one per entry of `idx` otherwise).

  Returns:
    A tensor holding the per-segment sums of the (clipped, weighted)
    embeddings.

  Raises:
    ValueError: If `params` is None or an empty list, or if there is more than
      one partition and `partition_strategy` is neither "mod" nor "div".
  """
  if params is None or params == []:  # pylint: disable=g-explicit-bool-comparison
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  def maybe_normalize(x):
    """Clips `x` to `max_norm` over all but its first axis, if requested."""
    if max_norm is None:
      return x
    static_ndims = x.get_shape().ndims
    if static_ndims is not None:
      axes = list(range(1, static_ndims))
    else:
      # The rank is only known at run time, so the axes must be built as a
      # tensor; Python's range() cannot take a Tensor bound.
      axes = math_ops.range(1, array_ops.rank(x))
    return clip_ops.clip_by_norm(x, max_norm, axes=axes)

  def scale_by_weights(values, w):
    """Returns `values * w` with `w` cast and reshaped so it broadcasts."""
    if w.dtype != values.dtype:
      w = math_ops.cast(w, values.dtype)
    # Append one size-1 dimension per trailing dimension of `values`.
    ones = array_ops.fill(
        array_ops.expand_dims(array_ops.rank(values) - 1, 0), 1)
    bcast_weights_shape = array_ops.concat([array_ops.shape(w), ones], 0)
    orig_weights_shape = w.get_shape()
    w = array_ops.reshape(w, bcast_weights_shape)
    # Restore the static shape information lost by the dynamic reshape.
    values_ndims = values.get_shape().ndims
    if values_ndims is not None:
      w.set_shape(orig_weights_shape.concatenate([1] * (values_ndims - 1)))
    return values * w

  with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
                      params + [ids]) as name:
    num_partitions = len(params)
    ignore_weights = weights is None
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")

    if num_partitions == 1:
      with ops.colocate_with(params[0]):
        ret = maybe_normalize(_do_gather(params[0], ids))
        if ignore_weights:
          return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
        return math_ops.segment_sum(
            scale_by_weights(ret, weights), segment_ids, name=name)

    ids = ops.convert_to_tensor(ids, name="ids")
    flat_ids = array_ops.reshape(ids, [-1])
    original_indices = math_ops.range(array_ops.size(flat_ids))

    # Create p_assignments and set new_ids depending on the strategy.
    if partition_strategy == "mod":
      p_assignments = flat_ids % num_partitions
      new_ids = flat_ids // num_partitions
    elif partition_strategy == "div":
      # Compute num_total_ids as the sum of dim-0 of params, then assign to
      # partitions based on a constant number of ids per partition. Optimize
      # if we already know the full shape statically.
      dim_0_size = params[0].get_shape().dims[0]
      for p in xrange(1, num_partitions):
        dim_0_size += params[p].get_shape().dims[0]
      if dim_0_size.value:
        num_total_ids = constant_op.constant(dim_0_size, flat_ids.dtype)
      else:
        dim_0_sizes = []
        for p in xrange(num_partitions):
          if params[p].get_shape().dims[0].value is not None:
            dim_0_sizes.append(params[p].get_shape().dims[0].value)
          else:
            with ops.colocate_with(params[p]):
              dim_0_sizes.append(array_ops.shape(params[p])[0])
        num_total_ids = math_ops.reduce_sum(
            math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
      ids_per_partition = num_total_ids // num_partitions
      extras = num_total_ids % num_partitions

      # The first `extras` partitions hold (ids_per_partition + 1) ids each,
      # the remaining ones hold ids_per_partition ids each.
      p_assignments = math_ops.maximum(
          flat_ids // (ids_per_partition + 1),
          (flat_ids - extras) // ids_per_partition)

      # Emulate a conditional using a boolean indicator tensor
      is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
                                                    flat_ids.dtype)
      new_ids = (
          is_in_first_extras_partitions * (flat_ids %
                                           (ids_per_partition + 1)) +
          (1 - is_in_first_extras_partitions) * (
              (flat_ids - extras) % ids_per_partition))
    else:
      raise ValueError("Unrecognized partition strategy: " +
                       partition_strategy)

    # Cast partition assignments to int32 for use in dynamic_partition.
    # There really should not be more than 2^32 partitions.
    p_assignments = math_ops.cast(p_assignments, dtypes.int32)
    # Partition list of ids based on assignments into one list per partition.
    gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments,
                                                 num_partitions)
    # Similarly, partition the original indices.
    pindices = data_flow_ops.dynamic_partition(original_indices,
                                               p_assignments, num_partitions)
    # Do one lookup per partition, finding embeddings for gather_ids[p] in
    # params[p].
    partitioned_result = []
    for p in xrange(num_partitions):
      with ops.colocate_with(params[p]):
        partitioned_result.append(_do_gather(params[p], gather_ids[p]))

    if not ignore_weights:
      # Partition weights according to pindices.
      partitioned_weight = []
      for p in xrange(num_partitions):
        partitioned_weight.append(array_ops.gather(weights, pindices[p]))

    # Reshape each partition result.
    element_shape = params[0].get_shape()[1:]
    for p in params[1:]:
      element_shape = element_shape.merge_with(p.get_shape()[1:])
    if element_shape.is_fully_defined():
      for p in xrange(num_partitions):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = array_ops.reshape(
              partitioned_result[p],
              array_ops.concat([array_ops.shape(pindices[p]), element_shape],
                               0))
    else:
      with ops.colocate_with(params[0]):
        params_shape = array_ops.shape(params[0])
      for p in xrange(num_partitions):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = array_ops.reshape(
              partitioned_result[p],
              array_ops.concat([
                  array_ops.shape(pindices[p]), array_ops.slice(
                      params_shape, [1], [-1])
              ], 0))

    # Normalize each partition result.
    for p in xrange(num_partitions):
      with ops.colocate_with(params[p]):
        partitioned_result[p] = maybe_normalize(partitioned_result[p])

    if not ignore_weights:
      # Multiply each partition result with partition weights.
      for p in xrange(num_partitions):
        with ops.colocate_with(params[p]):
          partitioned_result[p] = scale_by_weights(partitioned_result[p],
                                                   partitioned_weight[p])

    partitioned_segment_ids = []
    for p in xrange(num_partitions):
      if not ignore_weights:
        # Partition segment_ids according to pindices.
        p_segment_ids = array_ops.gather(segment_ids, pindices[p])
        # Number the p_segment_ids to meet segment_sum's requirements. Note
        # that unique_p_segment_ids contains unique segment ids of this
        # partition and these ids' order is unchanged.
        unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
            p_segment_ids)
        partitioned_segment_ids.append(unique_p_segment_ids)
        # segment_sum this partition's result.
        with ops.colocate_with(params[p]):
          partitioned_result[p] = math_ops.segment_sum(
              partitioned_result[p], unique_p_segment_idx)
      else:
        # When ignoring weights, we need the positions in idx (and therefore
        # in segment_ids) whose values belong to this partition.
        _, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
        all_idx = math_ops.range(array_ops.shape(idx)[0])
        _, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
        # Gather segment_ids and idx according to those positions.
        p_segment_ids = array_ops.gather(segment_ids, include_idx)
        p_idx = array_ops.gather(idx, include_idx)
        # Number the p_segment_ids, same as in the weighted case above.
        unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
            p_segment_ids)
        _, unique_p_idx_idx = array_ops.unique(p_idx)
        partitioned_segment_ids.append(unique_p_segment_ids)
        with ops.colocate_with(params[p]):
          partitioned_result[p] = math_ops.sparse_segment_sum(
              partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)

    # Concat each partition's segment_ids and result for final segment_sum.
    concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
    concat_partitioned_result = array_ops.concat(partitioned_result, 0)
    return math_ops.unsorted_segment_sum(
        concat_partitioned_result,
        concat_segment_ids,
        math_ops.reduce_max(concat_segment_ids) + 1,
        name=name)
925