1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for SparseConcat."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import sparse_ops
29from tensorflow.python.platform import test
30
31
class SparseConcatTest(test.TestCase):
  """Tests for `sparse_ops.sparse_concat`.

  Fixtures are drawn in comments as dense matrices in which blank cells are
  entries that are absent from the sparse representation.
  """

  def _SparseTensor_UnknownShape(self,
                                 ind_shape=None,
                                 val_shape=None,
                                 shape_shape=None):
    """Builds a SparseTensor whose three components are placeholders.

    Used to exercise static shape inference when some or all component
    shapes are unknown.

    Args:
      ind_shape: Static shape of the int64 `indices` placeholder, or None for
        a fully unknown shape.
      val_shape: Static shape of the float32 `values` placeholder, or None.
      shape_shape: Static shape of the int64 `dense_shape` placeholder, or
        None.

    Returns:
      A `sparse_tensor.SparseTensor` backed by the three placeholders.
    """
    return sparse_tensor.SparseTensor(
        array_ops.placeholder(dtypes.int64, shape=ind_shape),
        array_ops.placeholder(dtypes.float32, shape=val_shape),
        array_ops.placeholder(dtypes.int64, shape=shape_shape))

  def _SparseTensorValue_3x3(self):
    """Returns a 3x3 float32 `SparseTensorValue` with four entries."""
    # [    1]
    # [2    ]
    # [3   4]
    ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]])
    val = np.array([1, 2, 3, 4])
    shape = np.array([3, 3])
    return sparse_tensor.SparseTensorValue(
        np.array(ind, np.int64),
        np.array(val, np.float32), np.array(shape, np.int64))

  def _SparseTensor_3x3(self):
    """Returns the 3x3 fixture as a `SparseTensor` (same data as the value)."""
    return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x3())

  def _SparseTensorValue_3x5(self):
    """Returns a 3x5 float32 `SparseTensorValue` with four entries.

    One of the stored entries is an explicit 0, which must survive the concat.
    """
    # [         ]
    # [  1      ]
    # [2     1 0]
    ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]])
    val = np.array([1, 2, 1, 0])
    shape = np.array([3, 5])
    return sparse_tensor.SparseTensorValue(
        np.array(ind, np.int64),
        np.array(val, np.float32), np.array(shape, np.int64))

  def _SparseTensor_3x5(self):
    """Returns the 3x5 fixture as a `SparseTensor` (same data as the value)."""
    return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x5())

  def _SparseTensor_3x2(self):
    """Returns a 3x2 float32 `SparseTensor` with two entries in column 0."""
    # [   ]
    # [1  ]
    # [2  ]
    ind = np.array([[1, 0], [2, 0]])
    val = np.array([1, 2])
    shape = np.array([3, 2])
    return sparse_tensor.SparseTensor(
        constant_op.constant(ind, dtypes.int64),
        constant_op.constant(val, dtypes.float32),
        constant_op.constant(shape, dtypes.int64))

  def _SparseTensor_2x3(self):
    """Returns a 2x3 float32 `SparseTensor` with three entries."""
    # [  1  ]
    # [1   2]
    ind = np.array([[0, 1], [1, 0], [1, 2]])
    val = np.array([1, 1, 2])
    shape = np.array([2, 3])
    return sparse_tensor.SparseTensor(
        constant_op.constant(ind, dtypes.int64),
        constant_op.constant(val, dtypes.float32),
        constant_op.constant(shape, dtypes.int64))

  def _SparseTensor_2x3x4(self):
    """Returns a rank-3 (2x3x4) float32 `SparseTensor` with seven entries.

    Only used to provoke rank mismatches against the rank-2 fixtures.
    """
    # Each value spells out its own [i, j, k] index as decimal digits,
    # e.g. 103 is stored at [1, 0, 3] and 12 at [0, 1, 2].
    ind = np.array([
        [0, 0, 1],
        [0, 1, 0], [0, 1, 2],
        [1, 0, 3],
        [1, 1, 1], [1, 1, 3],
        [1, 2, 2]])
    val = np.array([1, 10, 12, 103, 111, 113, 122])
    shape = np.array([2, 3, 4])
    return sparse_tensor.SparseTensor(
        constant_op.constant(ind, dtypes.int64),
        constant_op.constant(val, dtypes.float32),
        constant_op.constant(shape, dtypes.int64))

  def _SparseTensor_String3x3(self):
    """Returns a 3x3 string `SparseTensor` laid out like the numeric 3x3."""
    # [    a]
    # [b    ]
    # [c   d]
    ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]])
    val = np.array(["a", "b", "c", "d"])
    shape = np.array([3, 3])
    return sparse_tensor.SparseTensor(
        constant_op.constant(ind, dtypes.int64),
        constant_op.constant(val, dtypes.string),
        constant_op.constant(shape, dtypes.int64))

  def _SparseTensor_String3x5(self):
    """Returns a 3x5 string `SparseTensor` laid out like the numeric 3x5."""
    # [         ]
    # [  e      ]
    # [f     g h]
    ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]])
    val = np.array(["e", "f", "g", "h"])
    shape = np.array([3, 5])
    return sparse_tensor.SparseTensor(
        constant_op.constant(ind, dtypes.int64),
        constant_op.constant(val, dtypes.string),
        constant_op.constant(shape, dtypes.int64))

  def testConcat1(self):
    with self.session(use_gpu=False):
      # concat(A):
      # [    1]
      # [2    ]
      # [3   4]
      for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
        # Note that we ignore concat_dim in this case since we short-circuit the
        # single-input case in python.
        for concat_dim in (-2000, 1, 2000):
          sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a])

          self.assertEqual(sp_concat.indices.get_shape(), [4, 2])
          self.assertEqual(sp_concat.values.get_shape(), [4])
          self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

          concat_out = self.evaluate(sp_concat)

          self.assertAllEqual(concat_out.indices,
                              [[0, 2], [1, 0], [2, 0], [2, 2]])
          self.assertAllEqual(concat_out.values, [1, 2, 3, 4])
          self.assertAllEqual(concat_out.dense_shape, [3, 3])

  def testConcat2(self):
    with self.session(use_gpu=False):
      # concat(A, B):
      # [    1          ]
      # [2       1      ]
      # [3   4 2     1 0]
      # Mix SparseTensorValue and SparseTensor inputs in every combination.
      for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
        for sp_b in (self._SparseTensorValue_3x5(), self._SparseTensor_3x5()):
          for concat_dim in (-1, 1):
            sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b])

            self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
            self.assertEqual(sp_concat.values.get_shape(), [8])
            self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

            concat_out = self.evaluate(sp_concat)

            self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4],
                                                     [2, 0], [2, 2], [2, 3],
                                                     [2, 6], [2, 7]])
            self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0])
            self.assertAllEqual(concat_out.dense_shape, [3, 8])

  def testConcatDim0(self):
    with self.session(use_gpu=False):
      # concat(A, D):
      # [    1]
      # [2    ]
      # [3   4]
      # [  1  ]
      # [1   2]
      sp_a = self._SparseTensor_3x3()
      sp_d = self._SparseTensor_2x3()

      for concat_dim in (-2, 0):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_d])

        self.assertEqual(sp_concat.indices.get_shape(), [7, 2])
        self.assertEqual(sp_concat.values.get_shape(), [7])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(
            concat_out.indices,
            [[0, 2], [1, 0], [2, 0], [2, 2], [3, 1], [4, 0], [4, 2]])
        self.assertAllEqual(concat_out.values, np.array([1, 2, 3, 4, 1, 1, 2]))
        self.assertAllEqual(concat_out.dense_shape, np.array([5, 3]))

  def testConcat3(self):
    with self.session(use_gpu=False):
      # concat(A, B, C):
      # [    1              ]
      # [2       1       1  ]
      # [3   4 2     1 0 2  ]
      sp_a = self._SparseTensor_3x3()
      sp_b = self._SparseTensor_3x5()
      sp_c = self._SparseTensor_3x2()

      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c])

        self.assertEqual(sp_concat.indices.get_shape(), [10, 2])
        self.assertEqual(sp_concat.values.get_shape(), [10])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4], [1, 8],
                                                 [2, 0], [2, 2], [2, 3], [2, 6],
                                                 [2, 7], [2, 8]])
        self.assertAllEqual(concat_out.values, [1, 2, 1, 1, 3, 4, 2, 1, 0, 2])
        self.assertAllEqual(concat_out.dense_shape, [3, 10])

  def testConcatNonNumeric(self):
    with self.session(use_gpu=False):
      # concat(A, B):
      # [    a          ]
      # [b       e      ]
      # [c   d f     g h]
      sp_a = self._SparseTensor_String3x3()
      sp_b = self._SparseTensor_String3x5()

      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b])

        self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
        self.assertEqual(sp_concat.values.get_shape(), [8])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(
            concat_out.indices,
            [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]])
        # String values come back from evaluation as bytes.
        self.assertAllEqual(concat_out.values,
                            [b"a", b"b", b"e", b"c", b"d", b"f", b"g", b"h"])
        self.assertAllEqual(concat_out.dense_shape, [3, 8])

  @test_util.run_deprecated_v1
  def testMismatchedRank(self):
    with self.session(use_gpu=False):
      sp_a = self._SparseTensor_3x3()
      sp_e = self._SparseTensor_2x3x4()

      # Rank mismatches can be caught at shape-inference time
      for concat_dim in (-1, 1):
        with self.assertRaises(ValueError):
          sparse_ops.sparse_concat(concat_dim, [sp_a, sp_e])

  @test_util.run_deprecated_v1
  def testMismatchedRankExpandNonconcatDim(self):
    with self.session(use_gpu=False):
      sp_a = self._SparseTensor_3x3()
      sp_e = self._SparseTensor_2x3x4()

      # Rank mismatches should be caught at shape-inference time, even for
      # expand_nonconcat_dim=True.
      for concat_dim in (-1, 1):
        with self.assertRaises(ValueError):
          sparse_ops.sparse_concat(
              concat_dim, [sp_a, sp_e], expand_nonconcat_dim=True)

  @test_util.run_deprecated_v1
  def testMismatchedShapes(self):
    with self.session(use_gpu=False):
      sp_a = self._SparseTensor_3x3()
      sp_b = self._SparseTensor_3x5()
      sp_c = self._SparseTensor_3x2()
      sp_d = self._SparseTensor_2x3()
      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim,
                                             [sp_a, sp_b, sp_c, sp_d])

        # Shape mismatches can only be caught when the op is run
        with self.assertRaisesOpError("Input shapes must match"):
          self.evaluate(sp_concat)

  def testMismatchedShapesExpandNonconcatDim(self):
    with self.session(use_gpu=False):
      sp_a = self._SparseTensor_3x3()
      sp_b = self._SparseTensor_3x5()
      sp_c = self._SparseTensor_3x2()
      sp_d = self._SparseTensor_2x3()
      sp_inputs = [sp_a, sp_b, sp_c, sp_d]

      # Row-wise concat: the 3, 3, 3 and 2 rows are stacked (11 rows total)
      # and every input's column count is expanded to the widest one (5).
      # Each concat_dim is checked once; the two axes are independent, so
      # there is no need to cross them in nested loops.
      for concat_dim in (-2, 0):
        sp_concat = sparse_ops.sparse_concat(
            concat_dim, sp_inputs, expand_nonconcat_dim=True)
        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(concat_out.indices,
                            [[0, 2], [1, 0], [2, 0], [2, 2], [4, 1], [5, 0],
                             [5, 3], [5, 4], [7, 0], [8, 0], [9, 1], [10, 0],
                             [10, 2]])
        self.assertAllEqual(concat_out.values,
                            [1, 2, 3, 4, 1, 2, 1, 0, 1, 2, 1, 1, 2])
        self.assertAllEqual(concat_out.dense_shape, [11, 5])

      # Column-wise concat: widths 3, 5, 2 and 3 are laid side by side
      # (13 columns total) and every input's row count is expanded to 3.
      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(
            concat_dim, sp_inputs, expand_nonconcat_dim=True)
        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(concat_out.indices,
                            [[0, 2], [0, 11], [1, 0], [1, 4], [1, 8], [1, 10],
                             [1, 12], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7],
                             [2, 8]])
        self.assertAllEqual(concat_out.values,
                            [1, 1, 2, 1, 1, 1, 2, 3, 4, 2, 1, 0, 2])
        self.assertAllEqual(concat_out.dense_shape, [3, 13])

  @test_util.run_deprecated_v1
  def testShapeInferenceUnknownShapes(self):
    with self.session(use_gpu=False):
      # Each input pins down a different piece of static shape information;
      # the output shape must combine them (rank 3 from ind_shape/shape_shape).
      sp_inputs = [
          self._SparseTensor_UnknownShape(),
          self._SparseTensor_UnknownShape(val_shape=[3]),
          self._SparseTensor_UnknownShape(ind_shape=[1, 3]),
          self._SparseTensor_UnknownShape(shape_shape=[3])
      ]

      for concat_dim in (-2, 0):
        sp_concat = sparse_ops.sparse_concat(concat_dim, sp_inputs)

        self.assertEqual(sp_concat.indices.get_shape().as_list(), [None, 3])
        self.assertEqual(sp_concat.values.get_shape().as_list(), [None])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [3])

  def testConcatShape(self):
    # Test case for GitHub 21964: the static shape of the concatenated
    # SparseTensor must reflect the concatenated dense shapes.
    x = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
    y = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
    z = sparse_ops.sparse_concat(-1, [x, y])
    self.assertEqual(z.get_shape().as_list(), [2, 4])
353
354
def _main():
  """Hands control to the TensorFlow test runner for this module."""
  test.main()


if __name__ == "__main__":
  _main()
357