1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional tests for Unstack Op."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22from six.moves import xrange  # pylint: disable=redefined-builtin
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import test_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import gradient_checker
28from tensorflow.python.platform import test
29
30
def np_split_squeeze(array, axis):
  """NumPy reference implementation of unstack.

  Splits `array` into `array.shape[axis]` pieces of size one along `axis`,
  then drops that size-one dimension from every piece.

  Args:
    array: NumPy array to split.
    axis: Axis to split along; negative values count from the end.

  Returns:
    A list of `array.shape[axis]` arrays, each with `axis` removed.
  """
  pieces = np.split(array, array.shape[axis], axis=axis)
  squeezed = []
  for piece in pieces:
    # Each piece still has a length-1 dimension at `axis`; remove it.
    squeezed.append(np.squeeze(piece, axis=(axis,)))
  return squeezed
38
39
class UnstackOpTest(test.TestCase):
  """Tests for `array_ops.unstack`: values, gradients and num/axis checks."""

  # Shapes used by the forward-value tests; each is unstacked along axis 0.
  _SIMPLE_SHAPES = ((2,), (3,), (2, 3), (3, 2), (4, 3, 2))

  # Dtypes used by the forward-value tests. `np.bool_` is used instead of the
  # bare `np.bool` alias, which NumPy deprecated in 1.20 and which does not
  # exist in NumPy 1.24-1.26; `np.bool_` is available in every release.
  _SIMPLE_DTYPES = (np.bool_, np.float16, np.float32, np.float64, np.uint8,
                    np.int32, np.int64)

  def _checkUnstackMatchesInput(self):
    """Unstacks random data along axis 0 and compares it with the input.

    Seeds NumPy's RNG for reproducibility. Then, for every shape in
    `_SIMPLE_SHAPES` and dtype in `_SIMPLE_DTYPES`, checks that
    `array_ops.unstack` returns a Python list of `shape[0]` tensors whose
    evaluated values equal the original array.
    """
    np.random.seed(7)
    for shape in self._SIMPLE_SHAPES:
      for dtype in self._SIMPLE_DTYPES:
        data = np.random.randn(*shape).astype(dtype)
        # Convert data to a single tensorflow tensor.
        x = constant_op.constant(data)
        # Unstack into a list of tensors.
        cs = array_ops.unstack(x, num=shape[0])
        self.assertIsInstance(cs, list)
        self.assertEqual(len(cs), shape[0])
        cs = [self.evaluate(c) for c in cs]
        self.assertAllEqual(cs, data)

  def testSimple(self):
    """Unstack values match the input on the default device."""
    self._checkUnstackMatchesInput()

  def testSimpleGpu(self):
    """Unstack values match the input when ops are forced onto a GPU."""
    if not test_util.is_gpu_available():
      self.skipTest('No GPU available')

    with test_util.force_gpu():
      self._checkUnstackMatchesInput()

  @test_util.run_deprecated_v1
  def testGradientsAxis0(self):
    """Gradient-checks every output slice of an axis-0 unstack."""
    for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
      data = np.random.randn(*shape)
      # Every slice along axis 0 has the input shape minus its first dim.
      out_shape = shape[1:]
      for i in xrange(shape[0]):
        with self.cached_session():
          x = constant_op.constant(data)
          cs = array_ops.unstack(x, num=shape[0])
          err = gradient_checker.compute_gradient_error(x, shape, cs[i],
                                                        out_shape)
          self.assertLess(err, 1e-6)

  @test_util.run_deprecated_v1
  def testGradientsAxis1(self):
    """Gradient-checks every output slice of an axis-1 unstack."""
    for shape in (2, 3), (3, 2), (4, 3, 2):
      data = np.random.randn(*shape)
      # Every slice along axis 1 has the input shape minus its second dim.
      out_shape = list(shape)
      del out_shape[1]
      for i in xrange(shape[1]):
        with self.cached_session():
          x = constant_op.constant(data)
          cs = array_ops.unstack(x, num=shape[1], axis=1)
          err = gradient_checker.compute_gradient_error(x, shape, cs[i],
                                                        out_shape)
          self.assertLess(err, 1e-6)

  @test_util.run_deprecated_v1
  def testInferNum(self):
    """`num` is inferred from a fully static leading dimension."""
    for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
      x = array_ops.placeholder(np.float32, shape=shape)
      cs = array_ops.unstack(x)
      self.assertIsInstance(cs, list)
      self.assertEqual(len(cs), shape[0])

  @test_util.run_deprecated_v1
  def testCannotInferNumFromUnknownShape(self):
    """Omitting `num` for an input of unknown rank raises ValueError."""
    x = array_ops.placeholder(np.float32)
    with self.assertRaisesRegex(ValueError,
                                r'Cannot infer num from shape <unknown>'):
      array_ops.unstack(x)

  @test_util.run_deprecated_v1
  def testUnknownShapeOkWithNum(self):
    """An explicit `num` allows unstacking an input of unknown shape."""
    x = array_ops.placeholder(np.float32)
    array_ops.unstack(x, num=2)

  @test_util.run_deprecated_v1
  def testCannotInferNumFromNoneShape(self):
    """Omitting `num` when the unstack dimension is None raises ValueError."""
    x = array_ops.placeholder(np.float32, shape=(None,))
    # The unknown dimension prints as '?' or 'None' depending on TF version.
    with self.assertRaisesRegex(ValueError,
                                r'Cannot infer num from shape \((\?|None),\)'):
      array_ops.unstack(x)

  def testAgainstNumpy(self):
    """Unstack matches the NumPy reference for every valid axis."""
    # For 1 to 5 dimensions.
    for i in range(1, 6):
      # Dimension sizes are a random permutation of 1..i.
      a = np.random.random(np.random.permutation(i) + 1)

      # For all the possible axis to split it, including negative indices.
      for j in range(-i, i):
        expected = np_split_squeeze(a, j)

        actual_unstack = self.evaluate(array_ops.unstack(a, axis=j))

        self.assertAllEqual(expected, actual_unstack)

  def testAxis0Default(self):
    """Without `axis`, unstack splits along the first dimension."""
    a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
    unstacked = self.evaluate(array_ops.unstack(a))

    self.assertEqual(len(unstacked), 2)
    self.assertAllEqual(unstacked[0], [1, 2, 3])
    self.assertAllEqual(unstacked[1], [4, 5, 6])

  def testAxisOutOfRange(self):
    """A positive axis equal to the rank is rejected."""
    a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
    with self.assertRaisesRegex(ValueError, r'axis = 2 not in \[-2, 2\)'):
      array_ops.unstack(a, axis=2)

  def testAxisOutOfNegativeRange(self):
    """A negative axis below -rank is rejected."""
    a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
    with self.assertRaisesRegex(ValueError, r'axis = -3 not in \[-2, 2\)'):
      array_ops.unstack(a, axis=-3)

  def testZeroLengthDim(self):
    """Unstacking along a non-empty axis keeps a zero-length dimension."""
    x = array_ops.zeros(shape=(0, 1, 2))
    y = self.evaluate(array_ops.unstack(x, axis=1)[0])
    self.assertEqual(y.shape, (0, 2))
169
170
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point
  # (`tensorflow.python.platform.test.main`).
  test.main()
173