# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for embeddings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
# Imports gradient definitions.
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


def _gather(params, ids, name=None):
  """Helper function for _embedding_lookup_and_transform.

  This function gathers embeddings from a single tensor. The gather deals with
  resource variables specially.

  Args:
    params: A `Tensor` of embeddings.
    ids: A `Tensor` indexing the embeddings to be retrieved from `params`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `params`.
  """
  if isinstance(params, resource_variable_ops.ResourceVariable):
    return params.sparse_read(ids, name=name)
  else:
    return array_ops.gather(params, ids, name=name)


def _clip(params, ids, max_norm):
  """Helper function for _embedding_lookup_and_transform.

  This function optionally clips embeddings to an l2-norm of max_norm.

  Args:
    params: A `Tensor` of embeddings retrieved by `_gather`.
    ids: The `ids` argument that was passed to `_gather`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value.

  Returns:
    A `Tensor` with the same type as `params`.
  """

  def _rank(x):
    """Helper function to retrieve the rank of a tensor.

    Args:
      x: Something convertible to `Tensor`.

    Returns:
      Either a pair `(rank, True)` where `rank` is an integer or a pair
      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
      `rank` is the rank of `x`.
    """
    rank = ops.convert_to_tensor(x).get_shape().ndims
    if rank:
      return rank, True
    else:
      return array_ops.rank(x), False

  if max_norm is None:
    return params
  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)
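  # Clip each embedding independently along its trailing dimensions (those
  # beyond the rank of `ids`); e.g. for 1-D ids gathered from a 2-D params
  # tensor this clips along axis 1, i.e. one row at a time.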
  return clip_ops.clip_by_norm(
      params,
      max_norm,
      axes=(list(range(ids_rank, params_rank))
            if ids_static and params_static
            else math_ops.range(ids_rank, params_rank)))


def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.
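  For example, `transform_fn=lambda emb: emb * 2.0` (an illustrative choice)
  doubles every retrieved embedding while satisfying this shape contract.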

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding.
      If max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty.
  """
  if params is None or params in ((), []):
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with ops.colocate_with(params[0]):
        result = _clip(_gather(params[0], ids, name=name), ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
        return result
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
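        # For example, with np = 5 id 12 goes to partition 12 % 5 = 2 at
        # offset 12 // 5 = 2, matching the "mod" layout described in the
        # embedding_lookup docstring.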
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np
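        # For example, with num_total_ids = 13 and np = 5, ids_per_partition
        # is 2 and extras is 3, so the partitions hold [0, 1, 2], [3, 4, 5],
        # [6, 7, 8], [9, 10], [11, 12], matching the "div" layout in the
        # embedding_lookup docstring.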

        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for gather_ids[p] in
      # params[p].
      partitioned_result = []
      for p in xrange(np):
        pids = gather_ids[p]
        with ops.colocate_with(params[p]):
          result = _gather(params[p], pids)
          if transform_fn:
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(ret,
                              array_ops.concat(
                                  [array_ops.shape(ids), element_shape_d], 0))
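      # For example, if ids has shape [batch, seq] and each embedding has
      # size dim, ret now has shape [batch, seq, dim].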

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret


@tf_export("nn.embedding_lookup")
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of
  tensors in `params`.  It is a generalization of
  @{tf.gather}, where `params` is
  interpreted as a partitioning of a large embedding tensor.  `params` may be
  a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the number of ids is not a multiple of the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
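
  For example (an illustrative sketch; the shapes are arbitrary), with a
  single `params` tensor of shape `[10, 20]` and `ids` of shape `[2, 3]`:

    output = tf.nn.embedding_lookup(params, ids)  # shape [2, 3, 20]
    # output[i, j, :] == params[ids[i, j], :]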

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be
      looked up in `params`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported.
      Default is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in `indices` are always validated to be within range.  If assigned to
      GPU, out-of-bound indices result in safe but unspecified behavior, which
      may include raising an error.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value.

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)


@tf_export("nn.embedding_lookup_sparse")
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Computes embeddings for the given ids and weights.

  This op assumes that there is at least one id for each row in the dense
  tensor represented by sp_ids (i.e. there are no rows with empty features),
  and that all the indices of sp_ids are in canonical row-major order.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors.  Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
      where N is typically batch size and M is arbitrary.
    sp_weights: either a SparseTensor of float / double weights, or None to
      indicate all weights should be taken to be 1. If specified, sp_weights
      must have exactly the same shape and indices as sp_ids.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported.
      Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by sp_ids, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      shape(combined params) = [p0, p1, ..., pm]

    and

      shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]

    then

      shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0

    with `combiner`="mean", then the output will be a 3x20 matrix where

      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = params[0, :] * 1.0
      output[2, :] = params[1, :] * 3.0

  Raises:
    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
      None nor SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    ids = sp_ids.values
    if ignore_weights:
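      # Deduplicate the ids before the lookup; e.g. ids [1, 3, 1] become
      # unique ids [1, 3] with idx [0, 1, 0]. The duplicates are restored by
      # the sparse_segment_* ops below.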
      ids, idx = array_ops.unique(ids)
    else:
      idx = None

    embeddings = embedding_lookup(
        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
    if not ignore_weights:
      weights = sp_weights.values
      if weights.dtype != embeddings.dtype:
        weights = math_ops.cast(weights, embeddings.dtype)

      # Reshape weights to allow broadcast
      ones = array_ops.fill(
          array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
      bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                             0)
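      # For example, with weights of shape [n] and rank-2 embeddings, the
      # broadcast shape is [n, 1], so each per-id weight scales the whole
      # embedding vector.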

      orig_weights_shape = weights.get_shape()
      weights = array_ops.reshape(weights, bcast_weights_shape)

      # Set the weight shape, since after reshaping to bcast_weights_shape,
      # the shape becomes None.
      if embeddings.get_shape().ndims is not None:
        weights.set_shape(
            orig_weights_shape.concatenate(
                [1 for _ in range(embeddings.get_shape().ndims - 1)]))

      embeddings *= weights
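      # The combiner branches below reduce the weighted embeddings per row.
      # With the weights 2.0 and 0.5 from the docstring example, "mean"
      # divides the weighted sum for row 0 by 2.0 + 0.5, while "sqrtn"
      # divides it by sqrt(2.0**2 + 0.5**2).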

      if combiner == "sum":
        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weight_sum = math_ops.segment_sum(weights, segment_ids)
        embeddings = math_ops.div(embeddings, weight_sum, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weights_squared = math_ops.pow(weights, 2)
        weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
        weight_sum_sqrt = math_ops.sqrt(weight_sum)
        embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
      else:
        assert False, "Unrecognized combiner"
    else:
      assert idx is not None
      if combiner == "sum":
        embeddings = math_ops.sparse_segment_sum(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.sparse_segment_mean(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.sparse_segment_sqrt_n(
            embeddings, idx, segment_ids, name=name)
      else:
        assert False, "Unrecognized combiner"

    return embeddings