1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for SparseConcat.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import sparse_tensor 26from tensorflow.python.framework import test_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import sparse_ops 29from tensorflow.python.platform import test 30 31 32class SparseConcatTest(test.TestCase): 33 34 def _SparseTensor_UnknownShape(self, 35 ind_shape=None, 36 val_shape=None, 37 shape_shape=None): 38 return sparse_tensor.SparseTensor( 39 array_ops.placeholder( 40 dtypes.int64, shape=ind_shape), 41 array_ops.placeholder( 42 dtypes.float32, shape=val_shape), 43 array_ops.placeholder( 44 dtypes.int64, shape=shape_shape)) 45 46 def _SparseTensorValue_3x3(self): 47 # [ 1] 48 # [2 ] 49 # [3 4] 50 ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]]) 51 val = np.array([1, 2, 3, 4]) 52 shape = np.array([3, 3]) 53 return sparse_tensor.SparseTensorValue( 54 np.array(ind, np.int64), 55 np.array(val, np.float32), np.array(shape, np.int64)) 56 57 def _SparseTensor_3x3(self): 58 return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x3()) 59 60 def _SparseTensorValue_3x5(self): 61 # [ ] 62 # [ 1 ] 63 # [2 1 0] 64 ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]]) 65 val = np.array([1, 2, 1, 0]) 66 shape = np.array([3, 5]) 67 return sparse_tensor.SparseTensorValue( 68 np.array(ind, np.int64), 69 np.array(val, np.float32), np.array(shape, np.int64)) 70 71 def _SparseTensor_3x5(self): 72 return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x5()) 73 74 def _SparseTensor_3x2(self): 75 # [ ] 76 # [1 ] 77 # [2 ] 78 ind = np.array([[1, 0], [2, 0]]) 79 val = np.array([1, 2]) 80 shape = np.array([3, 2]) 81 return sparse_tensor.SparseTensor( 82 constant_op.constant(ind, dtypes.int64), 83 constant_op.constant(val, dtypes.float32), 84 constant_op.constant(shape, dtypes.int64)) 85 86 def _SparseTensor_2x3(self): 87 # [ 1 ] 88 # [1 2] 89 ind = np.array([[0, 1], [1, 0], [1, 2]]) 90 val = np.array([1, 1, 2]) 91 shape = np.array([2, 3]) 92 return sparse_tensor.SparseTensor( 93 constant_op.constant(ind, dtypes.int64), 94 constant_op.constant(val, dtypes.float32), 95 constant_op.constant(shape, dtypes.int64)) 96 97 def _SparseTensor_2x3x4(self): 98 ind = np.array([ 99 [0, 0, 1], 100 [0, 1, 0], [0, 1, 2], 101 [1, 0, 3], 102 [1, 1, 1], [1, 1, 3], 103 [1, 2, 2]]) 104 val = np.array([1, 10, 12, 103, 111, 113, 122]) 105 shape = np.array([2, 3, 4]) 106 return sparse_tensor.SparseTensor( 107 constant_op.constant(ind, dtypes.int64), 108 constant_op.constant(val, dtypes.float32), 109 constant_op.constant(shape, dtypes.int64)) 110 111 def _SparseTensor_String3x3(self): 112 # [ a] 113 # [b ] 114 # [c d] 115 ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]]) 116 val = np.array(["a", "b", "c", "d"]) 117 shape = np.array([3, 3]) 118 return sparse_tensor.SparseTensor( 119 constant_op.constant(ind, dtypes.int64), 120 constant_op.constant(val, dtypes.string), 121 constant_op.constant(shape, dtypes.int64)) 122 123 def _SparseTensor_String3x5(self): 124 # [ ] 125 # [ e ] 126 # [f g h] 127 ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]]) 128 val = np.array(["e", "f", "g", "h"]) 129 shape = np.array([3, 5]) 130 return sparse_tensor.SparseTensor( 131 constant_op.constant(ind, dtypes.int64), 132 constant_op.constant(val, dtypes.string), 133 constant_op.constant(shape, dtypes.int64)) 134 135 def testConcat1(self): 136 with self.session(use_gpu=False) as sess: 137 # concat(A): 138 # [ 1] 139 # [2 ] 140 # [3 4] 141 for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): 142 # Note that we ignore concat_dim in this case since we short-circuit the 143 # single-input case in python. 144 for concat_dim in (-2000, 1, 2000): 145 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a]) 146 147 self.assertEqual(sp_concat.indices.get_shape(), [4, 2]) 148 self.assertEqual(sp_concat.values.get_shape(), [4]) 149 self.assertEqual(sp_concat.dense_shape.get_shape(), [2]) 150 151 concat_out = self.evaluate(sp_concat) 152 153 self.assertAllEqual(concat_out.indices, 154 [[0, 2], [1, 0], [2, 0], [2, 2]]) 155 self.assertAllEqual(concat_out.values, [1, 2, 3, 4]) 156 self.assertAllEqual(concat_out.dense_shape, [3, 3]) 157 158 def testConcat2(self): 159 with self.session(use_gpu=False) as sess: 160 # concat(A, B): 161 # [ 1 ] 162 # [2 1 ] 163 # [3 4 2 1 0] 164 for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): 165 for sp_b in (self._SparseTensorValue_3x5(), self._SparseTensor_3x5()): 166 for concat_dim in (-1, 1): 167 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b]) 168 169 self.assertEqual(sp_concat.indices.get_shape(), [8, 2]) 170 self.assertEqual(sp_concat.values.get_shape(), [8]) 171 self.assertEqual(sp_concat.dense_shape.get_shape(), [2]) 172 173 concat_out = self.evaluate(sp_concat) 174 175 self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4], 176 [2, 0], [2, 2], [2, 3], 177 [2, 6], [2, 7]]) 178 self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0]) 179 self.assertAllEqual(concat_out.dense_shape, [3, 8]) 180 181 def testConcatDim0(self): 182 with self.session(use_gpu=False) as sess: 183 # concat(A, D): 184 # [ 1] 185 # [2 ] 186 # [3 4] 187 # [ 1 ] 188 # [1 2] 189 sp_a = self._SparseTensor_3x3() 190 sp_d = self._SparseTensor_2x3() 191 192 for concat_dim in (-2, 0): 193 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_d]) 194 195 self.assertEqual(sp_concat.indices.get_shape(), [7, 2]) 196 self.assertEqual(sp_concat.values.get_shape(), [7]) 197 self.assertEqual(sp_concat.dense_shape.get_shape(), [2]) 198 199 concat_out = self.evaluate(sp_concat) 200 201 self.assertAllEqual( 202 concat_out.indices, 203 [[0, 2], [1, 0], [2, 0], [2, 2], [3, 1], [4, 0], [4, 2]]) 204 self.assertAllEqual(concat_out.values, np.array([1, 2, 3, 4, 1, 1, 2])) 205 self.assertAllEqual(concat_out.dense_shape, np.array([5, 3])) 206 207 def testConcat3(self): 208 with self.session(use_gpu=False) as sess: 209 # concat(A, B, C): 210 # [ 1 ] 211 # [2 1 1 ] 212 # [3 4 2 1 0 2 ] 213 sp_a = self._SparseTensor_3x3() 214 sp_b = self._SparseTensor_3x5() 215 sp_c = self._SparseTensor_3x2() 216 217 for concat_dim in (-1, 1): 218 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c]) 219 220 self.assertEqual(sp_concat.indices.get_shape(), [10, 2]) 221 self.assertEqual(sp_concat.values.get_shape(), [10]) 222 self.assertEqual(sp_concat.dense_shape.get_shape(), [2]) 223 224 concat_out = self.evaluate(sp_concat) 225 226 self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4], [1, 8], 227 [2, 0], [2, 2], [2, 3], [2, 6], 228 [2, 7], [2, 8]]) 229 self.assertAllEqual(concat_out.values, [1, 2, 1, 1, 3, 4, 2, 1, 0, 2]) 230 self.assertAllEqual(concat_out.dense_shape, [3, 10]) 231 232 def testConcatNonNumeric(self): 233 with self.session(use_gpu=False) as sess: 234 # concat(A, B): 235 # [ a ] 236 # [b e ] 237 # [c d f g h] 238 sp_a = self._SparseTensor_String3x3() 239 sp_b = self._SparseTensor_String3x5() 240 241 for concat_dim in (-1, 1): 242 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b]) 243 244 self.assertEqual(sp_concat.indices.get_shape(), [8, 2]) 245 self.assertEqual(sp_concat.values.get_shape(), [8]) 246 self.assertEqual(sp_concat.dense_shape.get_shape(), [2]) 247 248 concat_out = self.evaluate(sp_concat) 249 250 self.assertAllEqual( 251 concat_out.indices, 252 [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]]) 253 self.assertAllEqual(concat_out.values, 254 [b"a", b"b", b"e", b"c", b"d", b"f", b"g", b"h"]) 255 self.assertAllEqual(concat_out.dense_shape, [3, 8]) 256 257 @test_util.run_deprecated_v1 258 def testMismatchedRank(self): 259 with self.session(use_gpu=False): 260 sp_a = self._SparseTensor_3x3() 261 sp_e = self._SparseTensor_2x3x4() 262 263 # Rank mismatches can be caught at shape-inference time 264 for concat_dim in (-1, 1): 265 with self.assertRaises(ValueError): 266 sparse_ops.sparse_concat(concat_dim, [sp_a, sp_e]) 267 268 @test_util.run_deprecated_v1 269 def testMismatchedRankExpandNonconcatDim(self): 270 with self.session(use_gpu=False): 271 sp_a = self._SparseTensor_3x3() 272 sp_e = self._SparseTensor_2x3x4() 273 274 # Rank mismatches should be caught at shape-inference time, even for 275 # expand_nonconcat_dim=True. 276 for concat_dim in (-1, 1): 277 with self.assertRaises(ValueError): 278 sparse_ops.sparse_concat( 279 concat_dim, [sp_a, sp_e], expand_nonconcat_dim=True) 280 281 @test_util.run_deprecated_v1 282 def testMismatchedShapes(self): 283 with self.session(use_gpu=False) as sess: 284 sp_a = self._SparseTensor_3x3() 285 sp_b = self._SparseTensor_3x5() 286 sp_c = self._SparseTensor_3x2() 287 sp_d = self._SparseTensor_2x3() 288 for concat_dim in (-1, 1): 289 sp_concat = sparse_ops.sparse_concat(concat_dim, 290 [sp_a, sp_b, sp_c, sp_d]) 291 292 # Shape mismatches can only be caught when the op is run 293 with self.assertRaisesOpError("Input shapes must match"): 294 self.evaluate(sp_concat) 295 296 def testMismatchedShapesExpandNonconcatDim(self): 297 with self.session(use_gpu=False) as sess: 298 sp_a = self._SparseTensor_3x3() 299 sp_b = self._SparseTensor_3x5() 300 sp_c = self._SparseTensor_3x2() 301 sp_d = self._SparseTensor_2x3() 302 for concat_dim0 in (-2, 0): 303 for concat_dim1 in (-1, 1): 304 sp_concat_dim0 = sparse_ops.sparse_concat( 305 concat_dim0, [sp_a, sp_b, sp_c, sp_d], expand_nonconcat_dim=True) 306 sp_concat_dim1 = sparse_ops.sparse_concat( 307 concat_dim1, [sp_a, sp_b, sp_c, sp_d], expand_nonconcat_dim=True) 308 309 sp_concat_dim0_out = self.evaluate(sp_concat_dim0) 310 sp_concat_dim1_out = self.evaluate(sp_concat_dim1) 311 312 self.assertAllEqual(sp_concat_dim0_out.indices, 313 [[0, 2], [1, 0], [2, 0], [2, 2], [4, 1], [5, 0], 314 [5, 3], [5, 4], [7, 0], [8, 0], [9, 1], [10, 0], 315 [10, 2]]) 316 self.assertAllEqual(sp_concat_dim0_out.values, 317 [1, 2, 3, 4, 1, 2, 1, 0, 1, 2, 1, 1, 2]) 318 self.assertAllEqual(sp_concat_dim0_out.dense_shape, [11, 5]) 319 320 self.assertAllEqual(sp_concat_dim1_out.indices, 321 [[0, 2], [0, 11], [1, 0], [1, 4], [1, 8], [1, 10], 322 [1, 12], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7], 323 [2, 8]]) 324 self.assertAllEqual(sp_concat_dim1_out.values, 325 [1, 1, 2, 1, 1, 1, 2, 3, 4, 2, 1, 0, 2]) 326 self.assertAllEqual(sp_concat_dim1_out.dense_shape, [3, 13]) 327 328 @test_util.run_deprecated_v1 329 def testShapeInferenceUnknownShapes(self): 330 with self.session(use_gpu=False): 331 sp_inputs = [ 332 self._SparseTensor_UnknownShape(), 333 self._SparseTensor_UnknownShape(val_shape=[3]), 334 self._SparseTensor_UnknownShape(ind_shape=[1, 3]), 335 self._SparseTensor_UnknownShape(shape_shape=[3]) 336 ] 337 338 for concat_dim in (-2, 0): 339 sp_concat = sparse_ops.sparse_concat(concat_dim, sp_inputs) 340 341 self.assertEqual(sp_concat.indices.get_shape().as_list(), [None, 3]) 342 self.assertEqual(sp_concat.values.get_shape().as_list(), [None]) 343 self.assertEqual(sp_concat.dense_shape.get_shape(), [3]) 344 345 def testConcatShape(self): 346 # Test case for GitHub 21964. 347 x = sparse_tensor.SparseTensor( 348 indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2]) 349 y = sparse_tensor.SparseTensor( 350 indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2]) 351 z = sparse_ops.sparse_concat(-1, [x, y]) 352 self.assertEqual(z.get_shape().as_list(), [2, 4]) 353 354 355if __name__ == "__main__": 356 test.main() 357