1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functional tests for Split Op.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import errors_impl 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import gradients_impl 30from tensorflow.python.ops import math_ops 31from tensorflow.python.platform import test 32 33_TEST_DTYPES = (dtypes.float32, dtypes.float64, dtypes.complex64, 34 dtypes.complex128) 35 36 37class SplitOpTest(test.TestCase): 38 39 def _makeData(self, shape, dtype): 40 data = np.random.rand(*shape).astype(dtype.as_numpy_dtype) 41 if dtype.is_complex: 42 data -= 1j * data 43 return data 44 45 @test_util.run_deprecated_v1 46 def testShapeInference(self): 47 model_input = array_ops.placeholder(dtypes.float32, shape=(1, 10)) 48 49 # check that we fail during static shape inference if sizes are known 50 with self.assertRaises(ValueError): 51 # pylint: disable=expression-not-assigned 52 array_ops.split(model_input, [4], axis=1)[0] 53 # pylint: enable=expression-not-assigned 54 55 model_input = array_ops.placeholder(dtypes.float32) 56 inp = np.zeros((1, 10)) 57 # check that we still fail at runtime if the shapes were unknown 58 with self.cached_session(use_gpu=True) as sess: 59 with self.assertRaises(errors_impl.InvalidArgumentError): 60 sess.run(array_ops.split(model_input, [4]), {model_input: inp}) 61 62 # scalar Tensors are not permitted as num_splits 63 for axis in [0, -2]: 64 with self.cached_session(use_gpu=True) as sess: 65 with self.assertRaises(ValueError): 66 # pylint: disable=expression-not-assigned 67 sess.run( 68 array_ops.split( 69 array_ops.ones([4, 4]), 70 num_or_size_splits=constant_op.constant(2), 71 axis=axis)) 72 # pylint: enable=expression-not-assigned 73 74 # test that none split dimensions remain, even if we don't know how 75 # the split_dim will be split, but we do know the axis 76 result = array_ops.split( 77 array_ops.ones([5, 2]), array_ops.constant([2, 1, 2]) * 1, axis=0) 78 79 self.assertEqual(result[0].shape[1], 2) 80 self.assertEqual(result[1].shape[1], 2) 81 self.assertEqual(result[2].shape[1], 2) 82 83 model_input2 = array_ops.placeholder(dtypes.float32, shape=[None, 2]) 84 result = array_ops.split(model_input2, [2, 2], axis=0)[0] 85 86 with self.cached_session(use_gpu=True) as sess: 87 sess.run(result, feed_dict={model_input2: np.ones([4, 2])}) 88 89 @test_util.run_deprecated_v1 90 def testFailWithoutExplicitNum(self): 91 size_splits = array_ops.placeholder(dtype=dtypes.int32, shape=[None]) 92 93 value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 94 95 with self.session(use_gpu=True) as sess: 96 with self.assertRaises(ValueError) as context: 97 sess.run(array_ops.split(value, size_splits), {size_splits: [2, 2, 6]}) 98 self.assertTrue("Cannot infer num from shape" in str(context.exception)) 99 100 @test_util.run_in_graph_and_eager_modes 101 def testExplicitNum(self): 102 size_splits = array_ops.constant([2, 2, 6], dtype=dtypes.int32) 103 value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 104 105 # Eager and Graph modes raise different exceptions 106 with self.assertRaises((errors_impl.InvalidArgumentError, ValueError)): 107 array_ops.split(value, size_splits, num=4) 108 109 r = self.evaluate(array_ops.split(value, size_splits, num=3)) 110 self.assertAllEqual(r[0], value[0:2]) 111 self.assertAllEqual(r[1], value[2:4]) 112 self.assertAllEqual(r[2], value[4:]) 113 114 @test_util.run_in_graph_and_eager_modes 115 def testListOfScalarTensors(self): 116 a = math_ops.cast(5, dtypes.int32) 117 b = math_ops.cast(6, dtypes.int32) 118 119 value = np.random.rand(11, 11) 120 121 with test_util.device(use_gpu=True): 122 result = self.evaluate(array_ops.split(value, [a, b])) 123 124 self.assertAllEqual(result[0], value[0:5, :]) 125 self.assertAllEqual(result[1], value[5:, :]) 126 127 def _RunAndVerifyVariable(self, dtype, large_num_splits=False): 128 # Random dims of rank 5 129 shape = np.random.randint(1, 5, size=5) 130 split_dim = np.random.randint(-5, 5) 131 if large_num_splits: 132 num_split = np.random.randint(16, 25) 133 else: 134 num_split = np.random.randint(2, 8) 135 size_splits = np.random.randint(2, 8, num_split, dtype=np.int32) 136 shape[split_dim] = np.sum(size_splits) 137 inp = self._makeData(shape, dtype) 138 with test_util.device(use_gpu=True): 139 result = self.evaluate(array_ops.split(inp, size_splits, split_dim)) 140 slices = [slice(0, x) for x in shape] 141 offset = 0 142 for i in range(num_split): 143 slices[split_dim] = slice(offset, offset + size_splits[i]) 144 offset += size_splits[i] 145 self.assertAllEqual(result[i], inp[slices]) 146 147 def _testSpecialCasesVariable(self): 148 inp = np.random.rand(4, 4).astype("f") 149 150 with test_util.device(use_gpu=True): 151 result = self.evaluate(array_ops.split(inp, [4], 0)) 152 self.assertAllEqual(result[0], inp) 153 154 result = self.evaluate(array_ops.split(inp, [-1, 3], 0)) 155 self.assertAllEqual(result[0], inp[0:1, :]) 156 self.assertAllEqual(result[1], inp[1:4, :]) 157 158 def _testHugeNumberOfTensorsVariable(self, dtype): 159 num_split = 1000 160 size_splits = np.random.randint(1, 3, num_split, dtype=np.int32) 161 shape = [3, np.sum(size_splits)] 162 split_dim = 1 163 inp = self._makeData(shape, dtype) 164 with test_util.device(use_gpu=True): 165 result = self.evaluate(array_ops.split(inp, size_splits, split_dim)) 166 slices = [slice(0, x) for x in shape] 167 offset = 0 168 for i in range(num_split): 169 slices[split_dim] = slice(offset, offset + size_splits[i]) 170 offset += size_splits[i] 171 self.assertAllEqual(result[i], inp[slices]) 172 173 @test_util.run_in_graph_and_eager_modes 174 def testSpecialCasesVariable(self): 175 self._testSpecialCasesVariable() 176 for dtype in _TEST_DTYPES: 177 self._testHugeNumberOfTensorsVariable(dtype) 178 179 @test_util.run_in_graph_and_eager_modes 180 def testDegenerateVariable(self): 181 inp = np.random.rand(4, 4).astype("f") 182 with test_util.device(use_gpu=True): 183 result = self.evaluate(array_ops.split(inp, [-1, 4], 0)) 184 self.assertAllEqual(result[0], inp[0:0, :]) 185 self.assertAllEqual(result[1], inp[0:4, :]) 186 187 result = self.evaluate(array_ops.split(inp, [4, -1], 0)) 188 self.assertAllEqual(result[0], inp[0:4, :]) 189 self.assertAllEqual(result[1], inp[4:4, :]) 190 191 result = self.evaluate(array_ops.split(inp, [-1, 4], 1)) 192 self.assertAllEqual(result[0], inp[:, 0:0]) 193 self.assertAllEqual(result[1], inp[:, 0:4]) 194 195 result = self.evaluate(array_ops.split(inp, [4, -1], 1)) 196 self.assertAllEqual(result[0], inp[:, 0:4]) 197 self.assertAllEqual(result[1], inp[:, 4:4]) 198 199 def _testGradientsSimpleVariable(self, dtype): 200 inp = self._makeData((4, 4), dtype) 201 with test_util.device(use_gpu=True): 202 inp_tensor = ops.convert_to_tensor(inp) 203 s = array_ops.split(inp_tensor, [1, 3], 1) 204 inp_grads = [ 205 self._makeData((4, 1), dtype), self._makeData((4, 3), dtype) 206 ] 207 grad_tensors = [constant_op.constant(x) for x in inp_grads] 208 grad = gradients_impl.gradients(s, [inp_tensor], grad_tensors)[-1] 209 result = self.evaluate(grad) 210 211 self.assertAllEqual(result[:, 0:1], inp_grads[0]) 212 self.assertAllEqual(result[:, 1:4], inp_grads[1]) 213 214 @test_util.run_deprecated_v1 215 def testOutputShape(self): 216 for axis in [1, -1]: 217 with self.cached_session(use_gpu=True): 218 tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12]) 219 size_splits = [3, 7, 2] 220 outputs = array_ops.split(tensor, size_splits, axis) 221 for i, output in enumerate(outputs): 222 self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]]) 223 224 def _compare(self, x, dim, num): 225 np_ans = np.split(x, num, dim) 226 with test_util.device(use_gpu=True): 227 tf_ans = array_ops.split(value=x, num_or_size_splits=num, axis=dim) 228 out = self.evaluate(tf_ans) 229 self.assertEqual(num, len(np_ans)) 230 self.assertEqual(num, len(np_ans)) 231 self.assertEqual(num, len(out)) 232 for i in range(num): 233 self.assertAllEqual(np_ans[i], out[i]) 234 self.assertShapeEqual(np_ans[i], tf_ans[i]) 235 236 @test_util.run_in_graph_and_eager_modes 237 def testSplitRows(self): 238 for dtype in _TEST_DTYPES: 239 inp = self._makeData((4, 4), dtype) 240 self._compare(inp, 0, 4) 241 242 @test_util.run_in_graph_and_eager_modes 243 def testSplitCols(self): 244 for dtype in _TEST_DTYPES: 245 inp = self._makeData((4, 4), dtype) 246 self._compare(inp, 1, 4) 247 248 def _testEmpty(self, x, dim, num, expected_shape): 249 with test_util.device(use_gpu=True): 250 tf_ans = array_ops.split(value=x, num_or_size_splits=num, axis=dim) 251 out = self.evaluate(tf_ans) 252 self.assertEqual(x.size, 0) 253 self.assertEqual(len(out), num) 254 for i in range(num): 255 self.assertEqual(out[i].shape, expected_shape) 256 self.assertEqual(expected_shape, tf_ans[i].get_shape()) 257 258 @test_util.run_in_graph_and_eager_modes 259 def testEmpty(self): 260 # Note: np.split returns a rank-0 empty ndarray 261 # if the input ndarray is empty. 262 for dtype in _TEST_DTYPES: 263 inp = self._makeData((8, 0, 21), dtype) 264 self._testEmpty(inp, 0, 2, (4, 0, 21)) 265 self._testEmpty(inp, 0, 4, (2, 0, 21)) 266 self._testEmpty(inp, 1, 4, (8, 0, 21)) 267 self._testEmpty(inp, 2, 3, (8, 0, 7)) 268 self._testEmpty(inp, 2, 7, (8, 0, 3)) 269 270 @test_util.run_in_graph_and_eager_modes 271 def testIdentity(self): 272 for dtype in _TEST_DTYPES: 273 inp = self._makeData((2, 2, 2), dtype) 274 self._compare(inp, 0, 1) 275 self._compare(inp, 1, 1) 276 self._compare(inp, 2, 1) 277 278 @test_util.run_in_graph_and_eager_modes 279 def testSplitDim0(self): 280 for dtype in _TEST_DTYPES: 281 self._compare(self._makeData((6, 10, 18), dtype), 0, 3) 282 self._compare(self._makeData((6, 7, 18), dtype), 0, 3) 283 self._compare(self._makeData((6, 7, 9), dtype), 0, 3) 284 285 def _RunAndVerify(self, dtype, large_num_splits=False): 286 # Random dims of rank 5 287 shape = np.random.randint(0, 5, size=5) 288 split_dim = np.random.randint(-5, 5) 289 if large_num_splits: 290 num_split = np.random.randint(9, 15) 291 else: 292 num_split = np.random.randint(2, 8) 293 shape[split_dim] = np.random.randint(2, 5) * num_split 294 inp = self._makeData(shape, dtype) 295 with test_util.device(use_gpu=True): 296 result = self.evaluate( 297 array_ops.split( 298 value=inp, num_or_size_splits=num_split, axis=split_dim)) 299 slices = [slice(0, x) for x in shape] 300 offset = 0 301 length = shape[split_dim] // num_split 302 for i in range(num_split): 303 slices[split_dim] = slice(offset, offset + length) 304 offset += length 305 self.assertAllEqual(result[i], inp[slices]) 306 307 @test_util.run_in_graph_and_eager_modes 308 def testRandom(self): 309 for dtype in _TEST_DTYPES: 310 for _ in range(5): 311 self._RunAndVerify(dtype) 312 self._RunAndVerify(dtype, large_num_splits=True) 313 self._RunAndVerifyVariable(dtype) 314 self._RunAndVerifyVariable(dtype, large_num_splits=True) 315 316 def _testGradientsSimple(self, dtype): 317 inp = self._makeData((4, 4), dtype) 318 with self.cached_session(use_gpu=True): 319 inp_tensor = ops.convert_to_tensor(inp) 320 s = array_ops.split(value=inp_tensor, num_or_size_splits=4, axis=1) 321 inp_grads = [self._makeData((4, 1), dtype)for _ in range(4)] 322 grad_tensors = [constant_op.constant(x) for x in inp_grads] 323 grad = gradients_impl.gradients(s, [inp_tensor], grad_tensors)[0] 324 result = self.evaluate(grad) 325 for i in range(4): 326 self.assertAllEqual(result[:, i:i + 1], inp_grads[i]) 327 328 @test_util.run_deprecated_v1 329 def testGradientsAll(self): 330 for dtype in _TEST_DTYPES: 331 self._testGradientsSimple(dtype) 332 self._testGradientsSimpleVariable(dtype) 333 334 @test_util.run_deprecated_v1 335 def testShapeFunctionEdgeCases(self): 336 # split_dim greater than rank of input. 337 with self.assertRaises(ValueError): 338 array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=2) 339 340 # split dim less than -(rank of input) 341 with self.assertRaises(ValueError): 342 array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=-3) 343 344 # num_split does not evenly divide the size in split_dim. 345 with self.assertRaisesRegexp(ValueError, "should evenly divide"): 346 array_ops.split(value=[0, 1, 2, 3], num_or_size_splits=3, axis=0) 347 348 # Unknown split_dim. 349 splits = array_ops.split( 350 value=[[0, 1, 2, 3]], 351 num_or_size_splits=4, 352 axis=array_ops.placeholder(dtypes.int32)) 353 for s in splits: 354 self.assertEqual([None, None], s.get_shape().as_list()) 355 356 # Unknown split_dim and input shape. 357 splits = array_ops.split( 358 value=array_ops.placeholder(dtypes.float32), 359 num_or_size_splits=4, 360 axis=array_ops.placeholder(dtypes.int32)) 361 for s in splits: 362 self.assertEqual(None, s.get_shape().ndims) 363 364 @test_util.run_deprecated_v1 365 def testVariableShapeFunction(self): 366 # size_splits too big 367 with self.assertRaises(ValueError): 368 array_ops.split([0, 1], [3, -1], axis=0) 369 370 # Correct inference of variable dimension 371 s0, s1 = array_ops.split([0, 1, 2], [2, -1], axis=0) 372 assert s0.shape.as_list() == [2] 373 assert s1.shape.as_list() == [1] 374 375 @test_util.run_deprecated_v1 376 @test_util.disable_xla("b/123337890") # Error messages differ 377 def testNonexistentDimTensor(self): 378 x = array_ops.placeholder(dtypes.int32) 379 values = np.zeros([5, 30]) 380 splits = array_ops.placeholder(dtypes.int32) 381 with self.assertRaisesRegexp(ValueError, "Cannot infer"): 382 y = array_ops.split(values, splits, axis=x) 383 384 splits = array_ops.placeholder(dtypes.int32, [3]) 385 y = array_ops.split(values, splits, axis=x) 386 with self.session(use_gpu=True) as sess: 387 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 388 "must have exactly one element"): 389 sess.run(y, {x: np.array([], dtype=np.int32), splits: [4, 11, 15]}) 390 391 392if __name__ == "__main__": 393 test.main() 394