# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Functional tests for Split Op."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import errors_impl
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import gradients_impl
30from tensorflow.python.ops import math_ops
31from tensorflow.python.platform import test
32
33_TEST_DTYPES = (dtypes.float32, dtypes.float64, dtypes.complex64,
34                dtypes.complex128)
35
36
37class SplitOpTest(test.TestCase):
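  """Tests for array_ops.split."""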

  def _makeData(self, shape, dtype):
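    """Returns random data of `shape`; complex dtypes get an imaginary part."""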
    data = np.random.rand(*shape).astype(dtype.as_numpy_dtype)
    if dtype.is_complex:
      data -= 1j * data
    return data

  @test_util.run_deprecated_v1
  def testShapeInference(self):
    model_input = array_ops.placeholder(dtypes.float32, shape=(1, 10))

    # check that we fail during static shape inference if sizes are known
    with self.assertRaises(ValueError):
      # pylint: disable=expression-not-assigned
      array_ops.split(model_input, [4], axis=1)[0]
      # pylint: enable=expression-not-assigned

    model_input = array_ops.placeholder(dtypes.float32)
    inp = np.zeros((1, 10))
    # check that we still fail at runtime if the shapes were unknown
    with self.cached_session(use_gpu=True) as sess:
      with self.assertRaises(errors_impl.InvalidArgumentError):
        sess.run(array_ops.split(model_input, [4]), {model_input: inp})

    # scalar Tensors are not permitted as num_splits
    for axis in [0, -2]:
      with self.cached_session(use_gpu=True) as sess:
        with self.assertRaises(ValueError):
          # pylint: disable=expression-not-assigned
          sess.run(
              array_ops.split(
                  array_ops.ones([4, 4]),
                  num_or_size_splits=constant_op.constant(2),
                  axis=axis))
          # pylint: enable=expression-not-assigned

    # Test that the non-split dimensions keep their static size, even though
    # we only know the split axis and not how the split dimension will be
    # split.
    result = array_ops.split(
        array_ops.ones([5, 2]), array_ops.constant([2, 1, 2]) * 1, axis=0)

    self.assertEqual(result[0].shape[1], 2)
    self.assertEqual(result[1].shape[1], 2)
    self.assertEqual(result[2].shape[1], 2)

    model_input2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    result = array_ops.split(model_input2, [2, 2], axis=0)[0]

    with self.cached_session(use_gpu=True) as sess:
      sess.run(result, feed_dict={model_input2: np.ones([4, 2])})

  @test_util.run_deprecated_v1
  def testFailWithoutExplicitNum(self):
    size_splits = array_ops.placeholder(dtype=dtypes.int32, shape=[None])

    value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    with self.session(use_gpu=True) as sess:
      with self.assertRaises(ValueError) as context:
        sess.run(array_ops.split(value, size_splits), {size_splits: [2, 2, 6]})
      self.assertTrue("Cannot infer num from shape" in str(context.exception))

  @test_util.run_in_graph_and_eager_modes
  def testExplicitNum(self):
    size_splits = array_ops.constant([2, 2, 6], dtype=dtypes.int32)
    value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    # Eager and Graph modes raise different exceptions
    with self.assertRaises((errors_impl.InvalidArgumentError, ValueError)):
      array_ops.split(value, size_splits, num=4)

    r = self.evaluate(array_ops.split(value, size_splits, num=3))
    self.assertAllEqual(r[0], value[0:2])
    self.assertAllEqual(r[1], value[2:4])
    self.assertAllEqual(r[2], value[4:])

  @test_util.run_in_graph_and_eager_modes
  def testListOfScalarTensors(self):
    a = math_ops.cast(5, dtypes.int32)
    b = math_ops.cast(6, dtypes.int32)

    value = np.random.rand(11, 11)

    with test_util.device(use_gpu=True):
      result = self.evaluate(array_ops.split(value, [a, b]))

    self.assertAllEqual(result[0], value[0:5, :])
    self.assertAllEqual(result[1], value[5:, :])

  def _RunAndVerifyVariable(self, dtype, large_num_splits=False):
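    """Splits random rank-5 data with random size_splits and checks slices."""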
    # Random dims of rank 5
    shape = np.random.randint(1, 5, size=5)
    split_dim = np.random.randint(-5, 5)
    if large_num_splits:
      num_split = np.random.randint(16, 25)
    else:
      num_split = np.random.randint(2, 8)
    size_splits = np.random.randint(2, 8, num_split, dtype=np.int32)
    shape[split_dim] = np.sum(size_splits)
    inp = self._makeData(shape, dtype)
    with test_util.device(use_gpu=True):
      result = self.evaluate(array_ops.split(inp, size_splits, split_dim))
    slices = [slice(0, x) for x in shape]
    offset = 0
    for i in range(num_split):
      slices[split_dim] = slice(offset, offset + size_splits[i])
      offset += size_splits[i]
      self.assertAllEqual(result[i], inp[slices])

  def _testSpecialCasesVariable(self):
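    """Checks size_splits covering the whole dim and one containing -1."""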
    inp = np.random.rand(4, 4).astype("f")

    with test_util.device(use_gpu=True):
      result = self.evaluate(array_ops.split(inp, [4], 0))
      self.assertAllEqual(result[0], inp)

      result = self.evaluate(array_ops.split(inp, [-1, 3], 0))
      self.assertAllEqual(result[0], inp[0:1, :])
      self.assertAllEqual(result[1], inp[1:4, :])

  def _testHugeNumberOfTensorsVariable(self, dtype):
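    """Splits into 1000 pieces of random size and checks each slice."""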
    num_split = 1000
    size_splits = np.random.randint(1, 3, num_split, dtype=np.int32)
    shape = [3, np.sum(size_splits)]
    split_dim = 1
    inp = self._makeData(shape, dtype)
    with test_util.device(use_gpu=True):
      result = self.evaluate(array_ops.split(inp, size_splits, split_dim))
    slices = [slice(0, x) for x in shape]
    offset = 0
    for i in range(num_split):
      slices[split_dim] = slice(offset, offset + size_splits[i])
      offset += size_splits[i]
      self.assertAllEqual(result[i], inp[slices])

  @test_util.run_in_graph_and_eager_modes
  def testSpecialCasesVariable(self):
    self._testSpecialCasesVariable()
    for dtype in _TEST_DTYPES:
      self._testHugeNumberOfTensorsVariable(dtype)

  @test_util.run_in_graph_and_eager_modes
  def testDegenerateVariable(self):
    inp = np.random.rand(4, 4).astype("f")
    with test_util.device(use_gpu=True):
      result = self.evaluate(array_ops.split(inp, [-1, 4], 0))
      self.assertAllEqual(result[0], inp[0:0, :])
      self.assertAllEqual(result[1], inp[0:4, :])

      result = self.evaluate(array_ops.split(inp, [4, -1], 0))
      self.assertAllEqual(result[0], inp[0:4, :])
      self.assertAllEqual(result[1], inp[4:4, :])

      result = self.evaluate(array_ops.split(inp, [-1, 4], 1))
      self.assertAllEqual(result[0], inp[:, 0:0])
      self.assertAllEqual(result[1], inp[:, 0:4])

      result = self.evaluate(array_ops.split(inp, [4, -1], 1))
      self.assertAllEqual(result[0], inp[:, 0:4])
      self.assertAllEqual(result[1], inp[:, 4:4])

  def _testGradientsSimpleVariable(self, dtype):
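    """Checks gradients of a size_splits=[1, 3] split along axis 1."""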
    inp = self._makeData((4, 4), dtype)
    with test_util.device(use_gpu=True):
      inp_tensor = ops.convert_to_tensor(inp)
      s = array_ops.split(inp_tensor, [1, 3], 1)
      inp_grads = [
          self._makeData((4, 1), dtype), self._makeData((4, 3), dtype)
      ]
      grad_tensors = [constant_op.constant(x) for x in inp_grads]
      grad = gradients_impl.gradients(s, [inp_tensor], grad_tensors)[-1]
      result = self.evaluate(grad)

    self.assertAllEqual(result[:, 0:1], inp_grads[0])
    self.assertAllEqual(result[:, 1:4], inp_grads[1])

  @test_util.run_deprecated_v1
  def testOutputShape(self):
    for axis in [1, -1]:
      with self.cached_session(use_gpu=True):
        tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
        size_splits = [3, 7, 2]
        outputs = array_ops.split(tensor, size_splits, axis)
        for i, output in enumerate(outputs):
          self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]])

  def _compare(self, x, dim, num):
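    """Compares array_ops.split against np.split for an even num-way split."""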
    np_ans = np.split(x, num, dim)
    with test_util.device(use_gpu=True):
      tf_ans = array_ops.split(value=x, num_or_size_splits=num, axis=dim)
      out = self.evaluate(tf_ans)
    self.assertEqual(num, len(np_ans))
    self.assertEqual(num, len(out))
    for i in range(num):
      self.assertAllEqual(np_ans[i], out[i])
      self.assertShapeEqual(np_ans[i], tf_ans[i])

  @test_util.run_in_graph_and_eager_modes
  def testSplitRows(self):
    for dtype in _TEST_DTYPES:
      inp = self._makeData((4, 4), dtype)
      self._compare(inp, 0, 4)

  @test_util.run_in_graph_and_eager_modes
  def testSplitCols(self):
    for dtype in _TEST_DTYPES:
      inp = self._makeData((4, 4), dtype)
      self._compare(inp, 1, 4)

  def _testEmpty(self, x, dim, num, expected_shape):
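    """Splits an empty tensor and checks static and evaluated shapes."""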
    with test_util.device(use_gpu=True):
      tf_ans = array_ops.split(value=x, num_or_size_splits=num, axis=dim)
      out = self.evaluate(tf_ans)
    self.assertEqual(x.size, 0)
    self.assertEqual(len(out), num)
    for i in range(num):
      self.assertEqual(out[i].shape, expected_shape)
      self.assertEqual(expected_shape, tf_ans[i].get_shape())

  @test_util.run_in_graph_and_eager_modes
  def testEmpty(self):
    # Note: np.split returns a rank-0 empty ndarray
    # if the input ndarray is empty.
    for dtype in _TEST_DTYPES:
      inp = self._makeData((8, 0, 21), dtype)
      self._testEmpty(inp, 0, 2, (4, 0, 21))
      self._testEmpty(inp, 0, 4, (2, 0, 21))
      self._testEmpty(inp, 1, 4, (8, 0, 21))
      self._testEmpty(inp, 2, 3, (8, 0, 7))
      self._testEmpty(inp, 2, 7, (8, 0, 3))

  @test_util.run_in_graph_and_eager_modes
  def testIdentity(self):
    for dtype in _TEST_DTYPES:
      inp = self._makeData((2, 2, 2), dtype)
      self._compare(inp, 0, 1)
      self._compare(inp, 1, 1)
      self._compare(inp, 2, 1)

  @test_util.run_in_graph_and_eager_modes
  def testSplitDim0(self):
    for dtype in _TEST_DTYPES:
      self._compare(self._makeData((6, 10, 18), dtype), 0, 3)
      self._compare(self._makeData((6, 7, 18), dtype), 0, 3)
      self._compare(self._makeData((6, 7, 9), dtype), 0, 3)

  def _RunAndVerify(self, dtype, large_num_splits=False):
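    """Evenly splits random rank-5 data and checks each output slice."""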
    # Random dims of rank 5
    shape = np.random.randint(0, 5, size=5)
    split_dim = np.random.randint(-5, 5)
    if large_num_splits:
      num_split = np.random.randint(9, 15)
    else:
      num_split = np.random.randint(2, 8)
    shape[split_dim] = np.random.randint(2, 5) * num_split
    inp = self._makeData(shape, dtype)
    with test_util.device(use_gpu=True):
      result = self.evaluate(
          array_ops.split(
              value=inp, num_or_size_splits=num_split, axis=split_dim))
    slices = [slice(0, x) for x in shape]
    offset = 0
    length = shape[split_dim] // num_split
    for i in range(num_split):
      slices[split_dim] = slice(offset, offset + length)
      offset += length
      self.assertAllEqual(result[i], inp[slices])

  @test_util.run_in_graph_and_eager_modes
  def testRandom(self):
    for dtype in _TEST_DTYPES:
      for _ in range(5):
        self._RunAndVerify(dtype)
        self._RunAndVerify(dtype, large_num_splits=True)
        self._RunAndVerifyVariable(dtype)
        self._RunAndVerifyVariable(dtype, large_num_splits=True)

  def _testGradientsSimple(self, dtype):
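    """Checks gradients of an even 4-way split along axis 1."""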
    inp = self._makeData((4, 4), dtype)
    with self.cached_session(use_gpu=True):
      inp_tensor = ops.convert_to_tensor(inp)
      s = array_ops.split(value=inp_tensor, num_or_size_splits=4, axis=1)
      inp_grads = [self._makeData((4, 1), dtype) for _ in range(4)]
      grad_tensors = [constant_op.constant(x) for x in inp_grads]
      grad = gradients_impl.gradients(s, [inp_tensor], grad_tensors)[0]
      result = self.evaluate(grad)
    for i in range(4):
      self.assertAllEqual(result[:, i:i + 1], inp_grads[i])

  @test_util.run_deprecated_v1
  def testGradientsAll(self):
    for dtype in _TEST_DTYPES:
      self._testGradientsSimple(dtype)
      self._testGradientsSimpleVariable(dtype)

  @test_util.run_deprecated_v1
  def testShapeFunctionEdgeCases(self):
    # split_dim greater than rank of input.
    with self.assertRaises(ValueError):
      array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=2)

    # split dim less than -(rank of input)
    with self.assertRaises(ValueError):
      array_ops.split(value=[[0, 1], [2, 3]], num_or_size_splits=4, axis=-3)

    # num_split does not evenly divide the size in split_dim.
    with self.assertRaisesRegexp(ValueError, "should evenly divide"):
      array_ops.split(value=[0, 1, 2, 3], num_or_size_splits=3, axis=0)

    # Unknown split_dim.
    splits = array_ops.split(
        value=[[0, 1, 2, 3]],
        num_or_size_splits=4,
        axis=array_ops.placeholder(dtypes.int32))
    for s in splits:
      self.assertEqual([None, None], s.get_shape().as_list())

    # Unknown split_dim and input shape.
    splits = array_ops.split(
        value=array_ops.placeholder(dtypes.float32),
        num_or_size_splits=4,
        axis=array_ops.placeholder(dtypes.int32))
    for s in splits:
      self.assertEqual(None, s.get_shape().ndims)

  @test_util.run_deprecated_v1
  def testVariableShapeFunction(self):
    # size_splits too big
    with self.assertRaises(ValueError):
      array_ops.split([0, 1], [3, -1], axis=0)

    # Correct inference of variable dimension
    s0, s1 = array_ops.split([0, 1, 2], [2, -1], axis=0)
    assert s0.shape.as_list() == [2]
    assert s1.shape.as_list() == [1]

  @test_util.run_deprecated_v1
  @test_util.disable_xla("b/123337890")  # Error messages differ
  def testNonexistentDimTensor(self):
    x = array_ops.placeholder(dtypes.int32)
    values = np.zeros([5, 30])
    splits = array_ops.placeholder(dtypes.int32)
    with self.assertRaisesRegexp(ValueError, "Cannot infer"):
      y = array_ops.split(values, splits, axis=x)

    splits = array_ops.placeholder(dtypes.int32, [3])
    y = array_ops.split(values, splits, axis=x)
    with self.session(use_gpu=True) as sess:
      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                   "must have exactly one element"):
        sess.run(y, {x: np.array([], dtype=np.int32), splits: [4, 11, 15]})


if __name__ == "__main__":
  test.main()