1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for SparseTensorsMap."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.client import session
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import sparse_ops
30from tensorflow.python.ops import variables
31from tensorflow.python.platform import benchmark
32from tensorflow.python.platform import test
33
# The ops exercised below are underscore-prefixed (module-private) names in
# sparse_ops; alias them once here so the tests can call them by short names.
# pylint: disable=protected-access
add_sparse_to_tensors_map = sparse_ops._add_sparse_to_tensors_map
add_many_sparse_to_tensors_map = sparse_ops._add_many_sparse_to_tensors_map
take_many_sparse_from_tensors_map = (
    sparse_ops._take_many_sparse_from_tensors_map)

# pylint: enable=protected-access
41
42
class SparseTensorsMapTest(test.TestCase):
  """Tests for the add / add-many / take-many SparseTensorsMap ops."""

  def _SparseTensorPlaceholder(self, dtype=None):
    """Returns a SparseTensor whose three components are placeholders.

    Args:
      dtype: dtype of the values placeholder; `dtypes.int32` when None.

    Returns:
      A `SparseTensor` built from an int64 indices placeholder, a values
      placeholder of `dtype`, and an int64 dense-shape placeholder.
    """
    values_dtype = dtypes.int32 if dtype is None else dtype
    # Placeholders are created in (indices, values, dense_shape) order.
    indices_ph = array_ops.placeholder(dtypes.int64)
    values_ph = array_ops.placeholder(values_dtype)
    shape_ph = array_ops.placeholder(dtypes.int64)
    return sparse_tensor_lib.SparseTensor(indices_ph, values_ph, shape_ph)

  def _SparseTensorValue_5x6(self, permutation):
    """Returns a six-entry SparseTensorValue with dense shape [5, 6].

    Args:
      permutation: index array applied to both the indices rows and the
        values, so that the entries can be reordered.

    Returns:
      A `SparseTensorValue` with int64 indices, int32 values and int64 shape.
    """
    all_indices = np.array(
        [[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]], dtype=np.int64)
    all_values = np.array([0, 10, 13, 14, 32, 33], dtype=np.int32)
    dense_shape = np.array([5, 6], dtype=np.int64)
    return sparse_tensor_lib.SparseTensorValue(
        all_indices[permutation], all_values[permutation], dense_shape)

  def _SparseTensorValue_3x4(self, permutation):
    """Returns a six-entry SparseTensorValue with dense shape [3, 4].

    Args:
      permutation: index array applied to both the indices rows and the
        values, so that the entries can be reordered.

    Returns:
      A `SparseTensorValue` with int64 indices, int32 values and int64 shape.
    """
    all_indices = np.array(
        [[0, 0], [1, 0], [1, 2], [1, 3], [2, 2], [2, 3]], dtype=np.int64)
    all_values = np.array([0, 10, 13, 14, 32, 33], dtype=np.int32)
    dense_shape = np.array([3, 4], dtype=np.int64)
    return sparse_tensor_lib.SparseTensorValue(
        all_indices[permutation], all_values[permutation], dense_shape)

  def _SparseTensorValue_1x1x1(self):
    """Returns a rank-3 SparseTensorValue holding one entry at [0, 0, 0].

    Note that, despite the method name, the dense shape is [3, 4, 5].
    """
    single_index = np.array([[0, 0, 0]], dtype=np.int64)
    single_value = np.array([0], dtype=np.int32)
    dense_shape = np.array([3, 4, 5], dtype=np.int64)
    return sparse_tensor_lib.SparseTensorValue(
        single_index, single_value, dense_shape)

  def _AssertStackedPair(self, batched_value, first, second):
    """Asserts `batched_value` is `first` and `second` stacked as a minibatch.

    Both inputs are expected to hold six entries, and the combined dense shape
    is expected to be [2, 5, 6].

    Args:
      batched_value: the evaluated (indices, values, dense_shape) triple.
      first: `SparseTensorValue` expected at minibatch index 0.
      second: `SparseTensorValue` expected at minibatch index 1.
    """
    batched_indices, batched_values, batched_shape = batched_value

    self.assertAllEqual(batched_indices[:6, 0], [0] * 6)  # minibatch 0
    self.assertAllEqual(batched_indices[:6, 1:], first.indices)
    self.assertAllEqual(batched_indices[6:, 0], [1] * 6)  # minibatch 1
    self.assertAllEqual(batched_indices[6:, 1:], second.indices)
    self.assertAllEqual(batched_values[:6], first.values)
    self.assertAllEqual(batched_values[6:], second.values)
    self.assertAllEqual(batched_shape, [2, 5, 6])

  @test_util.run_deprecated_v1
  def testAddTakeMany(self):
    """Two tensors added under one shared name come back as one minibatch."""
    with self.session(graph=ops.Graph(), use_gpu=False):
      first = self._SparseTensorValue_5x6(np.arange(6))
      second = self._SparseTensorValue_3x4(np.arange(6))
      first_handle = add_sparse_to_tensors_map(first, shared_name="a")
      second_handle = add_sparse_to_tensors_map(second, shared_name="a")
      # A single add yields a scalar handle.
      self.assertEqual(first_handle.get_shape(), ())
      stacked_handles = array_ops.stack([first_handle, second_handle])

      batched = take_many_sparse_from_tensors_map(
          sparse_map_op=first_handle.op, sparse_handles=stacked_handles)

      self._AssertStackedPair(self.evaluate(batched), first, second)

  @test_util.run_deprecated_v1
  def testFeedAddTakeMany(self):
    """Tensors fed one at a time through a single add op can be taken back."""
    with self.session(use_gpu=False) as sess:
      sp_placeholder = self._SparseTensorPlaceholder()
      first = self._SparseTensorValue_5x6(np.arange(6))
      second = self._SparseTensorValue_3x4(np.arange(6))
      add_op_handle = add_sparse_to_tensors_map(sp_placeholder)

      # Run the same add op once per input, in order, collecting the handles.
      fed_handles = [
          sess.run(add_op_handle, feed_dict={sp_placeholder: fed_value})
          for fed_value in (first, second)
      ]

      handles_tensor = ops.convert_to_tensor(fed_handles, dtype=dtypes.int64)

      batched = take_many_sparse_from_tensors_map(
          sparse_map_op=add_op_handle.op, sparse_handles=handles_tensor)

      self._AssertStackedPair(self.evaluate(batched), first, second)

  @test_util.run_deprecated_v1
  def testAddManyTakeManyRoundTrip(self):
    """add-many followed by take-many reproduces the fed SparseTensor."""
    with self.session(use_gpu=False) as sess:
      # The dense shape is [4, 5], so N == 4 handles are expected.
      fed_shape = np.array([4, 5], dtype=np.int64)
      fed_indices = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
      fed_values = np.array([b"a", b"b", b"c"])
      sp_placeholder = self._SparseTensorPlaceholder(dtype=dtypes.string)
      handles = add_many_sparse_to_tensors_map(sp_placeholder)
      roundtrip = take_many_sparse_from_tensors_map(
          sparse_map_op=handles.op, sparse_handles=handles)
      feeds = {
          sp_placeholder.indices: fed_indices,
          sp_placeholder.values: fed_values,
          sp_placeholder.dense_shape: fed_shape,
      }
      handles_out, roundtrip_out = sess.run(
          [handles, roundtrip], feed_dict=feeds)
      self.assertEqual(handles_out.shape, (4,))
      self.assertAllEqual(roundtrip_out.indices, fed_indices)
      self.assertAllEqual(roundtrip_out.values, fed_values)
      self.assertAllEqual(roundtrip_out.dense_shape, fed_shape)

  @test_util.run_deprecated_v1
  def testDeserializeFailsInconsistentRank(self):
    """Taking tensors of different rank together raises a rank error.

    The inputs have 2 and 3 index columns; the expected message reports ranks
    3 and 4, presumably because it counts the added minibatch dimension.
    """
    with self.session(use_gpu=False) as sess:
      sp_placeholder = self._SparseTensorPlaceholder()
      rank2_value = self._SparseTensorValue_5x6(np.arange(6))
      rank3_value = self._SparseTensorValue_1x1x1()
      add_op_handle = add_sparse_to_tensors_map(sp_placeholder)

      # Run the same add op once per input, in order, collecting the handles.
      fed_handles = [
          sess.run(add_op_handle, feed_dict={sp_placeholder: fed_value})
          for fed_value in (rank2_value, rank3_value)
      ]

      handles_tensor = ops.convert_to_tensor(fed_handles, dtype=dtypes.int64)

      batched = take_many_sparse_from_tensors_map(
          sparse_map_op=add_op_handle.op, sparse_handles=handles_tensor)

      expected_error = (
          r"Inconsistent rank across SparseTensors: rank prior to "
          r"SparseTensor\[1\] was: 3 but rank of SparseTensor\[1\] is: 4")
      with self.assertRaisesOpError(expected_error):
        self.evaluate(batched)

  @test_util.run_deprecated_v1
  def testTakeManyFailsWrongInputOp(self):
    """Taking a handle that was never added raises a lookup error."""
    with self.session(use_gpu=False):
      stored_value = self._SparseTensorValue_5x6(np.arange(6))
      add_op_handle = add_sparse_to_tensors_map(stored_value)
      valid_handle = self.evaluate(add_op_handle)
      # Offset the real handle to obtain one the map has not issued.
      missing_handle = valid_handle + 10
      batched = take_many_sparse_from_tensors_map(
          sparse_map_op=add_op_handle.op,
          sparse_handles=[valid_handle, missing_handle])

      with self.assertRaisesOpError(r"Unable to find SparseTensor: 10"):
        self.evaluate(batched)
188
189
class BenchmarkSparseTensorsMapVsSerialization(test.Benchmark):
  """Compares a SparseTensorsMap round trip with serialize/deserialize."""

  def benchmarkVeryLarge2DFloatSparseTensor(self):
    """Benchmarks both round trips on a 64 x 10000 SparseTensor.

    Despite "Float" in the name, the values fed here are strings. Before
    timing, the two round trips are checked to produce equal values, indices
    and dense shapes.
    """
    np.random.seed(127)
    num_elements = 10000
    batch_size = 64

    # One entry per column, each assigned to a random batch row; sorting the
    # (row, column) pairs orders the indices by row, then column.
    row_ids = np.random.randint(
        batch_size, size=num_elements, dtype=np.int64)
    column_ids = np.arange(num_elements, dtype=np.int64)
    sorted_pairs = sorted(zip(row_ids, column_ids))
    indices_np = np.asarray(sorted_pairs, dtype=np.int64)
    values_list = ["feature_value_for_embedding_lookup"] * num_elements
    shape_np = np.asarray([batch_size, num_elements], dtype=np.int64)

    with session.Session(
        config=benchmark.benchmark_config()) as sess, ops.device("/cpu:0"):
      indices_var = variables.Variable(indices_np)
      values_var = variables.Variable(values_list)
      shape_var = variables.Variable(shape_np)
      st = sparse_tensor_lib.SparseTensor(indices_var, values_var, shape_var)

      # Path 1: store in the SparseTensorsMap and take everything back out.
      st_handles = add_many_sparse_to_tensors_map(st)
      st_roundtrip = take_many_sparse_from_tensors_map(
          sparse_map_op=st_handles.op, sparse_handles=st_handles)
      st_roundtrip_op = st_roundtrip.values.op

      # Path 2: serialize and deserialize.
      st_serialized = sparse_ops.serialize_many_sparse(st)
      st_deserialized = sparse_ops.deserialize_many_sparse(
          st_serialized, dtype=values_var.dtype)
      st_deserialized_op = st_deserialized.values.op

      self.evaluate(variables.global_variables_initializer())

      roundtrip_result = self.evaluate(st_roundtrip)
      deserialized_result = self.evaluate(st_deserialized)
      # Both paths must agree before their timings are worth comparing.
      for component in ("values", "indices", "dense_shape"):
        np.testing.assert_equal(
            getattr(roundtrip_result, component),
            getattr(deserialized_result, component))

      timed_ops = (
          (st_roundtrip_op, "benchmark_very_large_2d_float_st_tensor_maps"),
          (st_deserialized_op,
           "benchmark_very_large_2d_float_st_serialization"),
      )
      for timed_op, benchmark_name in timed_ops:
        self.run_op_benchmark(
            sess, timed_op, min_iters=2000, name=benchmark_name)
241
242
# Invoke the TensorFlow test entry point when this file is run as a script.
if __name__ == "__main__":
  test.main()
245