1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for sampling_ops.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.nn.python.ops import sampling_ops
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import nn
26from tensorflow.python.platform import test
27
28
class RankSampledSoftmaxLossTest(test.TestCase):
  """Tests for `sampling_ops.rank_sampled_softmax_loss`.

  The argument-validation tests check the ValueError messages raised for bad
  `num_sampled`/`num_resampled`/`partition_strategy` values. The comparison
  tests pass explicit `sampled_values` and check that the rank-sampled loss
  matches `nn.sampled_softmax_loss` computed directly on the classes that the
  resampling step is expected to keep.
  """

  def setUp(self):
    # Chain to test.TestCase.setUp so this override does not silently skip
    # the base class's per-test initialization.
    super(RankSampledSoftmaxLossTest, self).setUp()
    self._sampled = [3, 4, 5, 6, 7]
    self._num_sampled = len(self._sampled)
    # Weights and biases increase with class id and both input rows are
    # non-negative, so logits increase with class id. Of the sampled classes
    # above, resampling is therefore expected to keep these largest ones.
    self._resampled = [5, 6, 7]
    self._num_resampled = len(self._resampled)
    self._num_classes = 10
    self._num_true = 2
    # Explicit sampled_values for the rank-sampled loss and for the reference
    # nn loss respectively. NOTE(review): presumably the
    # (candidates, true_expected_count, sampled_expected_count) triple accepted
    # by nn.sampled_softmax_loss — confirm against its documentation.
    self._sampled_values = (self._sampled, [[0.5], [0.5]],
                            [0.5, 0.5, 0.5, 0.5, 0.5])
    self._resampled_values = (self._resampled, [[0.5], [0.5]], [0.5, 0.5, 0.5])
    self._remove_accidental_hits = False
    self._embed_dim = 5
    self._batch_size = 2

  def _weights(self):
    """Returns the unsharded 10x5 weights; row i is [i.0, i.1, ..., i.4]."""
    return constant_op.constant([
        [0.0, 0.1, 0.2, 0.3, 0.4],
        [1.0, 1.1, 1.2, 1.3, 1.4],
        [2.0, 2.1, 2.2, 2.3, 2.4],
        [3.0, 3.1, 3.2, 3.3, 3.4],
        [4.0, 4.1, 4.2, 4.3, 4.4],
        [5.0, 5.1, 5.2, 5.3, 5.4],
        [6.0, 6.1, 6.2, 6.3, 6.4],
        [7.0, 7.1, 7.2, 7.3, 7.4],
        [8.0, 8.1, 8.2, 8.3, 8.4],
        [9.0, 9.1, 9.2, 9.3, 9.4],
    ])

  def _div_sharded_weights(self):
    """Returns `_weights()` split into 5 shards of 2 contiguous rows each."""
    return [
        constant_op.constant([
            [0.0, 0.1, 0.2, 0.3, 0.4],
            [1.0, 1.1, 1.2, 1.3, 1.4],
        ]),
        constant_op.constant([
            [2.0, 2.1, 2.2, 2.3, 2.4],
            [3.0, 3.1, 3.2, 3.3, 3.4],
        ]),
        constant_op.constant([
            [4.0, 4.1, 4.2, 4.3, 4.4],
            [5.0, 5.1, 5.2, 5.3, 5.4],
        ]),
        constant_op.constant([
            [6.0, 6.1, 6.2, 6.3, 6.4],
            [7.0, 7.1, 7.2, 7.3, 7.4],
        ]),
        constant_op.constant([
            [8.0, 8.1, 8.2, 8.3, 8.4],
            [9.0, 9.1, 9.2, 9.3, 9.4],
        ]),
    ]

  def _mod_sharded_weights(self):
    """Returns `_weights()` split into 5 shards; row i lives in shard i % 5."""
    return [
        constant_op.constant([
            [0.0, 0.1, 0.2, 0.3, 0.4],
            [5.0, 5.1, 5.2, 5.3, 5.4],
        ]),
        constant_op.constant([
            [1.0, 1.1, 1.2, 1.3, 1.4],
            [6.0, 6.1, 6.2, 6.3, 6.4],
        ]),
        constant_op.constant([
            [2.0, 2.1, 2.2, 2.3, 2.4],
            [7.0, 7.1, 7.2, 7.3, 7.4],
        ]),
        constant_op.constant([
            [3.0, 3.1, 3.2, 3.3, 3.4],
            [8.0, 8.1, 8.2, 8.3, 8.4],
        ]),
        constant_op.constant([
            [4.0, 4.1, 4.2, 4.3, 4.4],
            [9.0, 9.1, 9.2, 9.3, 9.4],
        ]),
    ]

  def _biases(self):
    """Returns the unsharded biases; bias of class i is i / 10."""
    return constant_op.constant(
        [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

  def _div_sharded_biases(self):
    """Returns `_biases()` split into 5 shards of 2 contiguous entries each."""
    return [
        constant_op.constant([0.0, 0.1]),
        constant_op.constant([0.2, 0.3]),
        constant_op.constant([0.4, 0.5]),
        constant_op.constant([0.6, 0.7]),
        constant_op.constant([0.8, 0.9]),
    ]

  def _mod_sharded_biases(self):
    """Returns `_biases()` split into 5 shards; entry i lives in shard i % 5."""
    return [
        constant_op.constant([0.0, 0.5]),
        constant_op.constant([0.1, 0.6]),
        constant_op.constant([0.2, 0.7]),
        constant_op.constant([0.3, 0.8]),
        constant_op.constant([0.4, 0.9]),
    ]

  def _labels(self):
    """Returns int64 labels of shape (batch_size, num_true)."""
    return constant_op.constant(
        [[0, 1], [1, 2]],
        shape=(self._batch_size, self._num_true),
        name='labels',
        dtype=dtypes.int64)

  def _inputs(self):
    """Returns float inputs of shape (batch_size, embed_dim)."""
    return constant_op.constant(
        [
            [0., 1., 2., 3., 4.],
            [10., 11., 12., 13., 14.],
        ],
        shape=(self._batch_size, self._embed_dim),
        name='inputs')

  def _assertInvalidArgs(self, expected_regexp, num_sampled, num_resampled,
                         partition_strategy):
    """Asserts rank_sampled_softmax_loss raises ValueError for these args.

    All other arguments come from the class fixtures, with
    `sampled_values=None` and `remove_accidental_hits=True`.

    Args:
      expected_regexp: Regex that the ValueError message must match.
      num_sampled: Value passed as `num_sampled`.
      num_resampled: Value passed as `num_resampled`.
      partition_strategy: Value passed as `partition_strategy`.
    """
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, expected_regexp):
        sampling_ops.rank_sampled_softmax_loss(
            weights=self._weights(),
            biases=self._biases(),
            labels=self._labels(),
            inputs=self._inputs(),
            num_sampled=num_sampled,
            num_resampled=num_resampled,
            num_classes=self._num_classes,
            num_true=self._num_true,
            sampled_values=None,
            resampling_temperature=1.,
            remove_accidental_hits=True,
            partition_strategy=partition_strategy)

  def testInvalidNumSampled0(self):
    self._assertInvalidArgs(
        r'num_resampled \(3\) must be less than num_sampled \(3\)',
        num_sampled=3,
        num_resampled=3,
        partition_strategy='div')

  def testInvalidNumSampled1(self):
    self._assertInvalidArgs(
        r'num_resampled \(3\) must be less than num_sampled \(2\)',
        num_sampled=2,
        num_resampled=3,
        partition_strategy='div')

  def testMissingPartitionStrategy(self):
    self._assertInvalidArgs(
        r'unsupported partition_strategy \(None\)',
        num_sampled=2,
        num_resampled=1,
        partition_strategy=None)

  def _testCompareWithNN(self, weights, biases, partition_strategy):
    """Compares rank_sampled_softmax_loss against nn.sampled_softmax_loss.

    The reference loss is given the expected resampled classes directly as
    its sampled_values, so the two losses should agree.

    Args:
      weights: Zero-argument callable returning the weights (a tensor or a
        list of shards). Called once per loss so each gets its own tensors.
      biases: Zero-argument callable returning the biases (a tensor or a list
        of shards).
      partition_strategy: 'div' or 'mod', matching how the shards were split.
    """
    with ops.Graph().as_default():
      loss = sampling_ops.rank_sampled_softmax_loss(
          weights=weights(),
          biases=biases(),
          labels=self._labels(),
          inputs=self._inputs(),
          num_sampled=self._num_sampled,
          num_resampled=self._num_resampled,
          num_classes=self._num_classes,
          num_true=self._num_true,
          sampled_values=self._sampled_values,
          resampling_temperature=1.,
          remove_accidental_hits=self._remove_accidental_hits,
          partition_strategy=partition_strategy)
      loss_nn = nn.sampled_softmax_loss(
          weights=weights(),
          biases=biases(),
          labels=self._labels(),
          inputs=self._inputs(),
          num_sampled=self._num_resampled,
          num_classes=self._num_classes,
          num_true=self._num_true,
          sampled_values=self._resampled_values,
          remove_accidental_hits=self._remove_accidental_hits,
          partition_strategy=partition_strategy)
      with self.cached_session() as sess:
        loss_val, loss_nn_val = sess.run([loss, loss_nn])

    self.assertAllClose(loss_val, loss_nn_val)

  def testCompareWithNNUnsharded(self):
    self._testCompareWithNN(self._weights, self._biases, 'div')

  def testCompareWithNNShardWeightsDiv(self):
    self._testCompareWithNN(self._div_sharded_weights, self._biases, 'div')

  def testCompareWithNNShardWeightsAndBiasesDiv(self):
    self._testCompareWithNN(self._div_sharded_weights, self._div_sharded_biases,
                            'div')

  def testCompareWithNNShardWeightsMod(self):
    self._testCompareWithNN(self._mod_sharded_weights, self._biases, 'mod')

  def testCompareWithNNShardWeightsAndBiasesMod(self):
    self._testCompareWithNN(self._mod_sharded_weights, self._mod_sharded_biases,
                            'mod')

  def _testCompareWithNNTemperature(self, temperature, resampled):
    """Checks which single class resampling keeps at a given temperature.

    Args:
      temperature: Resampling temperature for the rank-sampled loss.
      resampled: Class id (0 or 1) expected to survive resampling; the
        reference nn loss is computed with only this class sampled.
    """
    weights = [[1., 2.], [3., 4.]]  # two sampled classes
    inputs = [[6., -5. / 2.], [-11., 21. / 2.]]
    # Let w0, w1 = weights of sampled classes (biases set to 0 for simplicity)
    # Let x0, x1 = inputs
    # logits:
    #   w0.x0 = 1
    #   w0.x1 = 10
    #   w1.x0 = 8
    #   w1.x1 = 9
    # Resampling 1 class with temperature = t will pick the larger of:
    #   exp(1/t) + exp(10/t)  ==> w0, for values of t < 2.12
    #   exp(8/t) + exp(9/t)   ==> w1, for values of t > 2.13
    # The crossover lies between t = 2.12 and t = 2.13; the Lo2 and Hi1 tests
    # below probe it from either side.
    num_sampled = 2
    num_resampled = 1
    num_classes = 2
    num_true = 1
    sampled_values = [0, 1], [[1.], [1.]], [1., 1.]
    resampled_values = [resampled], [[1.], [1.]], [1.]
    remove_accidental_hits = False
    with ops.Graph().as_default():
      weights = constant_op.constant(weights)
      biases = constant_op.constant([0., 0.])
      labels = constant_op.constant([[0], [1]], dtype=dtypes.int64)
      inputs = constant_op.constant(inputs)
      loss = sampling_ops.rank_sampled_softmax_loss(
          weights=weights,
          biases=biases,
          labels=labels,
          inputs=inputs,
          num_sampled=num_sampled,
          num_resampled=num_resampled,
          num_classes=num_classes,
          num_true=num_true,
          sampled_values=sampled_values,
          resampling_temperature=constant_op.constant(temperature),
          remove_accidental_hits=remove_accidental_hits,
          partition_strategy='div')
      loss_nn = nn.sampled_softmax_loss(
          weights=weights,
          biases=biases,
          labels=labels,
          inputs=inputs,
          num_sampled=num_resampled,
          num_classes=num_classes,
          num_true=num_true,
          sampled_values=resampled_values,
          remove_accidental_hits=remove_accidental_hits,
          partition_strategy='div')
      with self.cached_session() as sess:
        loss_val, loss_nn_val = sess.run([loss, loss_nn])

    self.assertAllClose(loss_val, loss_nn_val)

  def testCompareWithNNTemperatureLo1(self):
    self._testCompareWithNNTemperature(1., 0)

  def testCompareWithNNTemperatureLo2(self):
    self._testCompareWithNNTemperature(2.12, 0)

  def testCompareWithNNTemperatureHi1(self):
    self._testCompareWithNNTemperature(2.13, 1)

  def testCompareWithNNTemperatureHi2(self):
    self._testCompareWithNNTemperature(3., 1)
319
320
if __name__ == '__main__':
  # Run this module's test cases via tensorflow.python.platform.test.main().
  test.main()
323