1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Tests for common shapes.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import common_shapes 24from tensorflow.python.framework import tensor_shape 25from tensorflow.python.framework import test_util 26from tensorflow.python.platform import googletest 27 28 29class CommonShapesTest(test_util.TensorFlowTestCase): 30 31 # Asserts that we get the same result with numpy (for known shapes), and that 32 # the order of arguments does not matter (i.e., broadcasting is reflexive). 33 def _assert_incompatible_broadcast(self, shape1, shape2): 34 if shape1.dims is not None and shape2.dims is not None: 35 zeros1 = np.zeros(shape1.as_list()) 36 zeros2 = np.zeros(shape2.as_list()) 37 with self.assertRaises(ValueError): 38 np.broadcast(zeros1, zeros2) 39 with self.assertRaises(ValueError): 40 np.broadcast(zeros2, zeros1) 41 self.assertFalse(common_shapes.is_broadcast_compatible(shape1, shape2)) 42 self.assertFalse(common_shapes.is_broadcast_compatible(shape2, shape1)) 43 with self.assertRaises(ValueError): 44 common_shapes.broadcast_shape(shape1, shape2) 45 with self.assertRaises(ValueError): 46 common_shapes.broadcast_shape(shape2, shape1) 47 48 # Asserts that we get the same result with numpy (for known shapes), and that 49 # the order of arguments does not matter (i.e., broadcasting is reflexive). 50 def _assert_broadcast(self, expected, shape1, shape2): 51 if shape1.dims is not None and shape2.dims is not None: 52 expected_np = expected.as_list() 53 zeros1 = np.zeros(shape1.as_list()) 54 zeros2 = np.zeros(shape2.as_list()) 55 self.assertAllEqual(expected_np, np.broadcast(zeros1, zeros2).shape) 56 self.assertAllEqual(expected_np, np.broadcast(zeros2, zeros1).shape) 57 self.assertEqual( 58 expected, common_shapes.broadcast_shape(shape1, shape2)) 59 self.assertEqual( 60 expected, common_shapes.broadcast_shape(shape2, shape1)) 61 else: 62 self.assertEqual(expected, common_shapes.broadcast_shape(shape1, shape2)) 63 self.assertEqual(expected, common_shapes.broadcast_shape(shape2, shape1)) 64 65 def testBroadcast_one_dimension(self): 66 s1 = tensor_shape.vector(5) 67 s2 = tensor_shape.vector(7) 68 69 unknown = tensor_shape.unknown_shape() 70 scalar = tensor_shape.scalar() 71 expanded_scalar = tensor_shape.TensorShape([1]) 72 73 # Tensors with same shape should have the same broadcast result. 74 for shape in (s1, s2, unknown, scalar, expanded_scalar): 75 self._assert_broadcast(expected=shape, shape1=shape, shape2=shape) 76 77 # [] and [1] act like identity. 78 self._assert_broadcast(expected=s1, shape1=s1, shape2=scalar) 79 self._assert_broadcast(expected=s2, shape1=s2, shape2=scalar) 80 self._assert_broadcast(expected=s1, shape1=s1, shape2=expanded_scalar) 81 self._assert_broadcast(expected=s2, shape1=s2, shape2=expanded_scalar) 82 83 self._assert_broadcast(expected=unknown, shape1=s1, shape2=unknown) 84 self._assert_broadcast(expected=unknown, shape1=s2, shape2=unknown) 85 86 self._assert_broadcast( 87 expected=expanded_scalar, shape1=scalar, shape2=expanded_scalar) 88 89 self._assert_incompatible_broadcast(shape1=s1, shape2=s2) 90 91 def testBroadcast_many_dimensions(self): 92 unknown = tensor_shape.unknown_shape() 93 shape_0 = tensor_shape.scalar() 94 shape_1 = tensor_shape.vector(1) 95 shape_4 = tensor_shape.vector(4) 96 shape_1x4 = tensor_shape.matrix(1, 4) 97 shape_4x1 = tensor_shape.matrix(4, 1) 98 shape_3x4 = tensor_shape.matrix(3, 4) 99 shape_4x3 = tensor_shape.matrix(4, 3) 100 101 # Tensors with same shape should have the same broadcast result. 102 for shape in ( 103 shape_0, shape_1, shape_4, shape_1x4, shape_4x1, shape_3x4, shape_4x3): 104 self._assert_broadcast(expected=shape, shape1=shape, shape2=shape) 105 106 # [] and [1] act like identity. 107 for identity in (shape_0, shape_1): 108 for shape in (shape_4, shape_1x4, shape_4x1, shape_3x4, shape_4x3): 109 self._assert_broadcast(expected=shape, shape1=identity, shape2=shape) 110 111 # Unknown in, unknown out. 112 for shape in (shape_4, shape_1x4, shape_4x1, shape_3x4, shape_4x3): 113 self._assert_broadcast(expected=unknown, shape1=shape, shape2=unknown) 114 115 self._assert_broadcast(expected=shape_1x4, shape1=shape_4, shape2=shape_1x4) 116 shape_4x4 = tensor_shape.matrix(4, 4) 117 self._assert_broadcast(expected=shape_4x4, shape1=shape_4, shape2=shape_4x1) 118 self._assert_broadcast(expected=shape_3x4, shape1=shape_4, shape2=shape_3x4) 119 self._assert_incompatible_broadcast(shape1=shape_4, shape2=shape_4x3) 120 self._assert_broadcast( 121 expected=shape_4x4, shape1=shape_1x4, shape2=shape_4x1) 122 self._assert_broadcast( 123 expected=shape_3x4, shape1=shape_1x4, shape2=shape_3x4) 124 self._assert_incompatible_broadcast(shape1=shape_1x4, shape2=shape_4x3) 125 self._assert_incompatible_broadcast(shape1=shape_4x1, shape2=shape_3x4) 126 self._assert_broadcast( 127 expected=shape_4x3, shape1=shape_4x1, shape2=shape_4x3) 128 self._assert_incompatible_broadcast(shape1=shape_3x4, shape2=shape_4x3) 129 130 # Asserts that the order of arguments does not matter (i.e., broadcasting is 131 # reflexive). 132 def _assert_broadcast_with_unknown_dims(self, expected, shape1, shape2): 133 actual_dims = common_shapes.broadcast_shape(shape1, shape2).dims 134 reflexive_actual_dims = common_shapes.broadcast_shape(shape2, shape1).dims 135 136 if actual_dims is None: 137 self.assertIsNone(reflexive_actual_dims) 138 elif reflexive_actual_dims is None: 139 self.assertIsNone(actual_dims) 140 else: 141 self.assertEqual(len(actual_dims), len(reflexive_actual_dims)) 142 for actual_dim, reflexive_actual_dim in zip( 143 actual_dims, reflexive_actual_dims): 144 self.assertEqual(actual_dim.value, reflexive_actual_dim.value) 145 146 expected_dims = expected.dims 147 if expected_dims is None: 148 self.assertIsNone(actual_dims) 149 elif actual_dims is None: 150 self.assertIsNone(expected_dims) 151 else: 152 self.assertEqual(len(expected_dims), len(actual_dims)) 153 for expected_dim, actual_dim in zip(expected_dims, actual_dims): 154 self.assertEqual(expected_dim.value, actual_dim.value) 155 156 def testBroadcast_unknown_dims(self): 157 unknown = tensor_shape.unknown_shape() 158 shape_0 = tensor_shape.scalar() 159 shape_1 = tensor_shape.vector(1) 160 # pylint: disable=invalid-name 161 shape_U = tensor_shape.vector(None) 162 shape_1xU = tensor_shape.matrix(1, None) 163 shape_Ux1 = tensor_shape.matrix(None, 1) 164 shape_4xU = tensor_shape.matrix(4, None) 165 shape_Ux4 = tensor_shape.matrix(None, 4) 166 # pylint: enable=invalid-name 167 168 # Tensors with same shape should have the same broadcast result. 169 for shape in (shape_U, shape_1xU, shape_Ux1, shape_4xU, shape_Ux4): 170 self._assert_broadcast_with_unknown_dims( 171 expected=shape, shape1=shape, shape2=shape) 172 173 # [] and [1] act like identity. 174 for identity in (shape_0, shape_1): 175 for shape in (shape_U, shape_1xU, shape_Ux1, shape_4xU, shape_Ux4): 176 self._assert_broadcast_with_unknown_dims( 177 expected=shape, shape1=identity, shape2=shape) 178 179 # Unknown in, unknown out. 180 for shape in (shape_U, shape_1xU, shape_Ux1, shape_4xU, shape_Ux4): 181 self._assert_broadcast_with_unknown_dims( 182 expected=unknown, shape1=shape, shape2=unknown) 183 184 self._assert_broadcast_with_unknown_dims( 185 expected=shape_1xU, shape1=shape_U, shape2=shape_1xU) 186 shape_UxU = tensor_shape.matrix(None, None) # pylint: disable=invalid-name 187 self._assert_broadcast_with_unknown_dims( 188 expected=shape_UxU, shape1=shape_U, shape2=shape_Ux1) 189 self._assert_broadcast_with_unknown_dims( 190 expected=shape_4xU, shape1=shape_U, shape2=shape_4xU) 191 self._assert_broadcast_with_unknown_dims( 192 expected=shape_Ux4, shape1=shape_U, shape2=shape_Ux4) 193 self._assert_broadcast_with_unknown_dims( 194 expected=shape_UxU, shape1=shape_1xU, shape2=shape_Ux1) 195 self._assert_broadcast_with_unknown_dims( 196 expected=shape_4xU, shape1=shape_1xU, shape2=shape_4xU) 197 self._assert_broadcast_with_unknown_dims( 198 expected=shape_Ux4, shape1=shape_1xU, shape2=shape_Ux4) 199 self._assert_broadcast_with_unknown_dims( 200 expected=shape_4xU, shape1=shape_Ux1, shape2=shape_4xU) 201 self._assert_broadcast_with_unknown_dims( 202 expected=shape_Ux4, shape1=shape_Ux1, shape2=shape_Ux4) 203 shape_4x4 = tensor_shape.matrix(4, 4) 204 self._assert_broadcast_with_unknown_dims( 205 expected=shape_4x4, shape1=shape_4xU, shape2=shape_Ux4) 206 207 208if __name__ == "__main__": 209 googletest.main() 210