1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops related to candidate sampling."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import embedding_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import nn
27from tensorflow.python.ops import nn_impl
28from tensorflow.python.ops import nn_ops
29
30
def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
                   resampling_temperature, partition_strategy):
  """Selects the `num_resampled` best-scoring classes out of `sampled_values`.

  Internal helper for `rank_sampled_softmax_loss`. Every sampled class i gets
  the score

      log(sum_j exp((w_i * x_j + b_i) / resampling_temperature))

  where w_i and b_i are the weight row and bias of class i, and j ranges over
  the rows of `inputs`. Since b_i does not depend on j, it can be pulled out
  of the sum, giving the cheaper form

      log(sum_j exp(w_i * (x_j / resampling_temperature))) +
          b_i / resampling_temperature,

  which in batched tensorflow ops reads

      reduce_logsumexp(matmul(embeddings,
                       transpose(inputs / resampling_temperature))) +
          biases / resampling_temperature

  The log-sum-exp term is evaluated through the `transform_fn` hook of
  `embedding_ops._embedding_lookup_and_transform`, which colocates it with the
  embeddings. The bias term is not the bottleneck and is added afterwards on
  the worker.

  Args:
    weights: As passed to `rank_sampled_softmax_loss`.
    biases: As passed to `rank_sampled_softmax_loss`.
    inputs: As passed to `rank_sampled_softmax_loss`.
    sampled_values: A (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) tuple as produced by a
        `*_candidate_sampler` function.
    num_resampled: An `int`: how many of the sampled classes to keep. The
        caller is responsible for making this smaller than the number of
        sampled candidates.
    resampling_temperature: A scalar `Tensor`; the temperature used by the
        adaptive resampling.
    partition_strategy: As passed to `rank_sampled_softmax_loss`.

  Returns:
    A (`resampled_candidates`, `true_expected_count`,
    `resampled_expected_count`) tuple in the same layout as `sampled_values`,
    restricted to the `num_resampled` kept classes.
  """
  # num_resampled is documented as an int because that is all the only caller
  # passes, but nothing below would break if it were a Tensor. Update the doc
  # if this helper is ever made public.

  candidates, true_counts, candidate_counts = sampled_values

  # Sampler outputs are treated as constants (no gradient flows through
  # them); candidate ids are cast to int64 for the lookups below.
  candidates = math_ops.cast(
      array_ops.stop_gradient(candidates), dtypes.int64)
  true_counts = array_ops.stop_gradient(true_counts)
  candidate_counts = array_ops.stop_gradient(candidate_counts)

  scaled_inputs = inputs / resampling_temperature

  def _logsumexp_over_inputs(class_embeddings):
    # matmul(E, X, transpose_b=True) gives one row per looked-up class and one
    # column per row of `inputs`; the log-sum-exp reduces across the columns.
    logits = math_ops.matmul(class_embeddings, scaled_inputs, transpose_b=True)
    return math_ops.reduce_logsumexp(logits, axis=1, keepdims=False)

  # The protected lookup-and-transform entry point runs the log-sum-exp next
  # to each weight partition, which gives a large speedup in practice.
  class_scores = embedding_ops._embedding_lookup_and_transform(  # pylint: disable=protected-access
      weights, candidates, partition_strategy,
      transform_fn=_logsumexp_over_inputs)
  candidate_biases = embedding_ops.embedding_lookup(
      biases, candidates, partition_strategy)
  class_scores += (
      array_ops.reshape(candidate_biases, [-1]) / resampling_temperature)

  # sorted=False: the kept classes come back in no particular order.
  _, kept = nn.top_k(class_scores, k=num_resampled, sorted=False)
  resampled = array_ops.gather(candidates, indices=kept)
  resampled_counts = array_ops.gather(candidate_counts, indices=kept)
  return resampled, true_counts, resampled_counts
108
109
def rank_sampled_softmax_loss(weights,
                              biases,
                              labels,
                              inputs,
                              num_sampled,
                              num_resampled,
                              num_classes,
                              num_true,
                              sampled_values,
                              resampling_temperature,
                              remove_accidental_hits,
                              partition_strategy,
                              name=None):
  """Computes softmax loss using rank-based adaptive resampling.

  This has been shown to improve rank loss after training compared to
  `tf.nn.sampled_softmax_loss`. For a description of the algorithm and some
  experimental results, please see: [TAPAS: Two-pass Approximate Adaptive
  Sampling for Softmax](https://arxiv.org/abs/1707.03073).

  Sampling follows two phases:
  * In the first phase, `num_sampled` classes are selected using
    `tf.nn.learned_unigram_candidate_sampler` or supplied `sampled_values`.
    The logits are calculated on those sampled classes. This phase is
    similar to `tf.nn.sampled_softmax_loss`.
  * In the second phase, the `num_resampled` sampled classes with the highest
    scores are kept. A class's score is
    `LogSumExp(logits / resampling_temperature)`, where the sum runs over the
    rows of `inputs`.

  The `resampling_temperature` parameter controls the "adaptiveness" of the
  resampling. At lower temperatures, resampling is more adaptive because it
  picks more candidates close to the predicted classes. A common strategy is
  to decrease the temperature as training proceeds.

  See `tf.nn.sampled_softmax_loss` for more documentation on sampling and
  for typical default values for some of the parameters.

  This operation is for training only. It is generally an underestimate of
  the full softmax loss.

  A common use case is to use this method for training, and calculate the full
  softmax loss for evaluation or inference. In this case, you must set
  `partition_strategy="div"` for the two losses to be consistent, as in the
  following example:

  ```python
  if mode == "train":
    loss = rank_sampled_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        ...,
        partition_strategy="div")
  elif mode == "eval":
    # Assumes num_true == 1, i.e. labels has shape [batch_size, 1].
    logits = tf.matmul(inputs, tf.transpose(weights))
    logits = tf.nn.bias_add(logits, biases)
    labels_one_hot = tf.one_hot(tf.squeeze(labels, axis=[1]), num_classes)
    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels_one_hot,
        logits=logits)
  ```

  Args:
    weights: A `Tensor` or `PartitionedVariable` of shape `[num_classes, dim]`,
        or a list of `Tensor` objects whose concatenation along dimension 0
        has shape [num_classes, dim]. The (possibly-sharded) class embeddings.
    biases: A `Tensor` or `PartitionedVariable` of shape `[num_classes]`.
        The (possibly-sharded) class biases.
    labels: A `Tensor` of type `int64` and shape `[batch_size,
        num_true]`. The target classes. Note that this format differs from
        the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
    inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
        activations of the input network.
    num_sampled: An `int`. The number of classes to randomly sample per batch.
        Must not be greater than `num_classes`.
    num_resampled: An `int`. The number of classes to select from the
        `num_sampled` classes using the adaptive resampling algorithm. Must be
        less than `num_sampled`.
    num_classes: An `int`. The number of possible classes.
    num_true: An `int`.  The number of target classes per training example.
    sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
        If None (or otherwise falsy), default to
        `nn.learned_unigram_candidate_sampler`.
    resampling_temperature: A scalar `Tensor` with the temperature parameter
        for the adaptive resampling algorithm.
    remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
        where a sampled class equals one of the target classes.
    partition_strategy: A string specifying the partitioning strategy, relevant
        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
        See `tf.nn.embedding_lookup` for more details.
    name: A name for the operation (optional).

  Returns:
    A `batch_size` 1-D tensor of per-example sampled softmax losses.

  Raises:
    ValueError: If `num_sampled > num_classes`, if
        `num_sampled <= num_resampled`, or if `partition_strategy` is neither
        `"div"` nor `"mod"`.
  """
  # Argument checks run before any ops are created.
  if num_sampled > num_classes:
    raise ValueError("num_sampled ({}) cannot be greater than num_classes ({})".
                     format(num_sampled, num_classes))
  if num_sampled <= num_resampled:
    raise ValueError("num_resampled ({}) must be less than num_sampled ({})".
                     format(num_resampled, num_sampled))
  if partition_strategy not in ("div", "mod"):
    raise ValueError(
        "unsupported partition_strategy ({})".format(partition_strategy))
  with ops.name_scope(name, "rank_sampled_softmax_loss", [
      weights, biases, labels, inputs, sampled_values, resampling_temperature
  ]) as name:
    # Phase 1: draw num_sampled unique candidate classes unless the caller
    # supplied them.
    if not sampled_values:
      sampled_values = nn.learned_unigram_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes)
    # Phase 2: from sampled_values, select the top num_resampled values using
    # the adaptive rank resampling strategy.
    resampled_values = _rank_resample(weights, biases, inputs, sampled_values,
                                      num_resampled, resampling_temperature,
                                      partition_strategy)
    # The loss itself is a plain sampled softmax over the resampled
    # candidates, so the sample count handed on is num_resampled.
    return nn.sampled_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        num_sampled=num_resampled,
        num_classes=num_classes,
        num_true=num_true,
        sampled_values=resampled_values,
        remove_accidental_hits=remove_accidental_hits,
        partition_strategy=partition_strategy,
        name=name)
245
246
def sampled_sparse_softmax_loss(weights,
                                biases,
                                labels,
                                inputs,
                                num_sampled,
                                num_classes,
                                sampled_values=None,
                                remove_accidental_hits=True,
                                partition_strategy="mod",
                                name="sampled_sparse_softmax_loss"):
  """Computes and returns the sampled sparse softmax training loss.

  This is a faster way to train a softmax classifier over a huge number of
  classes.

  This operation is for training only.  It is generally an underestimate of
  the full softmax loss.

  A common use case is to use this method for training, and calculate the full
  softmax loss for evaluation or inference. In this case, you must set
  `partition_strategy="div"` for the two losses to be consistent, as in the
  following example:

  ```python
  if mode == "train":
    loss = tf.nn.sampled_sparse_softmax_loss(
        weights=weights,
        biases=biases,
        labels=labels,
        inputs=inputs,
        ...,
        partition_strategy="div")
  elif mode == "eval":
    logits = tf.matmul(inputs, tf.transpose(weights))
    logits = tf.nn.bias_add(logits, biases)
    # Squeeze only axis 1 so that a batch of size 1 keeps its batch dimension.
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.squeeze(labels, axis=[1]),
        logits=logits)
  ```

  See our [Candidate Sampling Algorithms Reference]
  (https://www.tensorflow.org/extras/candidate_sampling.pdf)

  Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
  ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.

  Args:
    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
        objects whose concatenation along dimension 0 has shape
        [num_classes, dim].  The (possibly-sharded) class embeddings.
    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
    labels: A `Tensor` of type `int64` and shape `[batch_size, 1]`.
        The index of the single target class for each row of logits.  Note that
        this format differs from the `labels` argument of
        `nn.sparse_softmax_cross_entropy_with_logits`.
    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
        activations of the input network.
    num_sampled: An `int`.  The number of classes to randomly sample per batch.
    num_classes: An `int`. The number of possible classes.
    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
        (if None, we default to `log_uniform_candidate_sampler`)
    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
        where a sampled class equals one of the target classes.  Default is
        True.
    partition_strategy: A string specifying the partitioning strategy, relevant
        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: A name for the operation (optional).

  Returns:
    A `batch_size` 1-D tensor of per-example sampled softmax losses.

  """
  # pylint: disable=protected-access
  logits, _ = nn_impl._compute_sampled_logits(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      num_sampled=num_sampled,
      num_classes=num_classes,
      num_true=1,
      sampled_values=sampled_values,
      subtract_log_q=True,
      remove_accidental_hits=remove_accidental_hits,
      partition_strategy=partition_strategy,
      name=name)
  # pylint: enable=protected-access

  # There is only one true label, and _compute_sampled_logits puts the true
  # logit at index 0, so every sparse label is 0. Build the labels directly
  # with shape [batch_size]. Creating [batch_size, 1] and calling squeeze()
  # without an axis would collapse to a scalar when batch_size == 1, and that
  # breaks the labels-rank check of sparse_softmax_cross_entropy_with_logits.
  labels = array_ops.zeros([array_ops.shape(logits)[0]], dtype=dtypes.int64)

  sampled_losses = nn_ops.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  # sampled_losses is a [batch_size] tensor.
  return sampled_losses
343