1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functional tests for Unstack Op.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22from six.moves import xrange # pylint: disable=redefined-builtin 23 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import test_util 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import gradient_checker 28from tensorflow.python.platform import test 29 30 31def np_split_squeeze(array, axis): 32 axis_len = array.shape[axis] 33 return [ 34 np.squeeze( 35 arr, axis=(axis,)) for arr in np.split( 36 array, axis_len, axis=axis) 37 ] 38 39 40class UnstackOpTest(test.TestCase): 41 42 def testSimple(self): 43 np.random.seed(7) 44 for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): 45 for dtype in [ 46 np.bool, np.float16, np.float32, np.float64, np.uint8, np.int32, 47 np.int64 48 ]: 49 data = np.random.randn(*shape).astype(dtype) 50 # Convert data to a single tensorflow tensor 51 x = constant_op.constant(data) 52 # Unstack into a list of tensors 53 cs = array_ops.unstack(x, num=shape[0]) 54 self.assertEqual(type(cs), list) 55 self.assertEqual(len(cs), shape[0]) 56 cs = [self.evaluate(c) for c in cs] 57 self.assertAllEqual(cs, data) 58 59 def testSimpleGpu(self): 60 if not test_util.is_gpu_available(): 61 self.skipTest('No GPU available') 62 63 np.random.seed(7) 64 with test_util.force_gpu(): 65 for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): 66 for dtype in [ 67 np.bool, np.float16, np.float32, np.float64, np.uint8, np.int32, 68 np.int64 69 ]: 70 data = np.random.randn(*shape).astype(dtype) 71 # Convert data to a single tensorflow tensor 72 x = constant_op.constant(data) 73 # Unstack into a list of tensors 74 cs = array_ops.unstack(x, num=shape[0]) 75 self.assertEqual(type(cs), list) 76 self.assertEqual(len(cs), shape[0]) 77 cs = [self.evaluate(c) for c in cs] 78 self.assertAllEqual(cs, data) 79 80 @test_util.run_deprecated_v1 81 def testGradientsAxis0(self): 82 for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): 83 data = np.random.randn(*shape) 84 shapes = [shape[1:]] * shape[0] 85 for i in xrange(shape[0]): 86 with self.cached_session(): 87 x = constant_op.constant(data) 88 cs = array_ops.unstack(x, num=shape[0]) 89 err = gradient_checker.compute_gradient_error(x, shape, cs[i], 90 shapes[i]) 91 self.assertLess(err, 1e-6) 92 93 @test_util.run_deprecated_v1 94 def testGradientsAxis1(self): 95 for shape in (2, 3), (3, 2), (4, 3, 2): 96 data = np.random.randn(*shape) 97 out_shape = list(shape) 98 del out_shape[1] 99 for i in xrange(shape[1]): 100 with self.cached_session(): 101 x = constant_op.constant(data) 102 cs = array_ops.unstack(x, num=shape[1], axis=1) 103 err = gradient_checker.compute_gradient_error(x, shape, cs[i], 104 out_shape) 105 self.assertLess(err, 1e-6) 106 107 @test_util.run_deprecated_v1 108 def testInferNum(self): 109 for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): 110 x = array_ops.placeholder(np.float32, shape=shape) 111 cs = array_ops.unstack(x) 112 self.assertEqual(type(cs), list) 113 self.assertEqual(len(cs), shape[0]) 114 115 @test_util.run_deprecated_v1 116 def testCannotInferNumFromUnknownShape(self): 117 x = array_ops.placeholder(np.float32) 118 with self.assertRaisesRegexp(ValueError, 119 r'Cannot infer num from shape <unknown>'): 120 array_ops.unstack(x) 121 122 @test_util.run_deprecated_v1 123 def testUnknownShapeOkWithNum(self): 124 x = array_ops.placeholder(np.float32) 125 array_ops.unstack(x, num=2) 126 127 @test_util.run_deprecated_v1 128 def testCannotInferNumFromNoneShape(self): 129 x = array_ops.placeholder(np.float32, shape=(None,)) 130 with self.assertRaisesRegexp(ValueError, 131 r'Cannot infer num from shape \((\?|None),\)'): 132 array_ops.unstack(x) 133 134 def testAgainstNumpy(self): 135 # For 1 to 5 dimensions. 136 for i in range(1, 6): 137 a = np.random.random(np.random.permutation(i) + 1) 138 139 # For all the possible axis to split it, including negative indices. 140 for j in range(-i, i): 141 expected = np_split_squeeze(a, j) 142 143 actual_unstack = self.evaluate(array_ops.unstack(a, axis=j)) 144 145 self.assertAllEqual(expected, actual_unstack) 146 147 def testAxis0Default(self): 148 a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a') 149 unstacked = self.evaluate(array_ops.unstack(a)) 150 151 self.assertEqual(len(unstacked), 2) 152 self.assertAllEqual(unstacked[0], [1, 2, 3]) 153 self.assertAllEqual(unstacked[1], [4, 5, 6]) 154 155 def testAxisOutOfRange(self): 156 a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a') 157 with self.assertRaisesRegexp(ValueError, r'axis = 2 not in \[-2, 2\)'): 158 array_ops.unstack(a, axis=2) 159 160 def testAxisOutOfNegativeRange(self): 161 a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a') 162 with self.assertRaisesRegexp(ValueError, r'axis = -3 not in \[-2, 2\)'): 163 array_ops.unstack(a, axis=-3) 164 165 def testZeroLengthDim(self): 166 x = array_ops.zeros(shape=(0, 1, 2)) 167 y = self.evaluate(array_ops.unstack(x, axis=1)[0]) 168 self.assertEqual(y.shape, (0, 2)) 169 170 171if __name__ == '__main__': 172 test.main() 173