1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Tests for common shapes."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import common_shapes
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import test_util
26from tensorflow.python.platform import googletest
27
28
class CommonShapesTest(test_util.TensorFlowTestCase):
  """Tests `common_shapes.broadcast_shape` / `is_broadcast_compatible`."""

  # Asserts that `shape1` and `shape2` cannot be broadcast together. For fully
  # defined shapes, numpy must agree (it raises ValueError as well). Both
  # argument orders are checked, because broadcasting is symmetric
  # (commutative): the order of the arguments must not matter.
  def _assert_incompatible_broadcast(self, shape1, shape2):
    # np.zeros needs concrete sizes, so only cross-check fully defined shapes;
    # a partially known shape would otherwise crash np.zeros with TypeError.
    if shape1.is_fully_defined() and shape2.is_fully_defined():
      zeros1 = np.zeros(shape1.as_list())
      zeros2 = np.zeros(shape2.as_list())
      with self.assertRaises(ValueError):
        np.broadcast(zeros1, zeros2)
      with self.assertRaises(ValueError):
        np.broadcast(zeros2, zeros1)
    self.assertFalse(common_shapes.is_broadcast_compatible(shape1, shape2))
    self.assertFalse(common_shapes.is_broadcast_compatible(shape2, shape1))
    with self.assertRaises(ValueError):
      common_shapes.broadcast_shape(shape1, shape2)
    with self.assertRaises(ValueError):
      common_shapes.broadcast_shape(shape2, shape1)

  # Asserts that broadcasting `shape1` with `shape2` yields `expected`. For
  # fully defined shapes, the result is also cross-checked against numpy. Both
  # argument orders are checked, because broadcasting is symmetric.
  def _assert_broadcast(self, expected, shape1, shape2):
    if shape1.is_fully_defined() and shape2.is_fully_defined():
      expected_np = expected.as_list()
      zeros1 = np.zeros(shape1.as_list())
      zeros2 = np.zeros(shape2.as_list())
      self.assertAllEqual(expected_np, np.broadcast(zeros1, zeros2).shape)
      self.assertAllEqual(expected_np, np.broadcast(zeros2, zeros1).shape)
    # The TensorFlow result is checked regardless of whether numpy could be
    # consulted.
    self.assertEqual(expected, common_shapes.broadcast_shape(shape1, shape2))
    self.assertEqual(expected, common_shapes.broadcast_shape(shape2, shape1))

  def testBroadcast_one_dimension(self):
    """Broadcasting of rank-0/rank-1 shapes and of the unknown shape."""
    s1 = tensor_shape.vector(5)
    s2 = tensor_shape.vector(7)

    unknown = tensor_shape.unknown_shape()
    scalar = tensor_shape.scalar()
    expanded_scalar = tensor_shape.TensorShape([1])

    # Tensors with same shape should have the same broadcast result.
    for shape in (s1, s2, unknown, scalar, expanded_scalar):
      self._assert_broadcast(expected=shape, shape1=shape, shape2=shape)

    # [] and [1] act like identity.
    self._assert_broadcast(expected=s1, shape1=s1, shape2=scalar)
    self._assert_broadcast(expected=s2, shape1=s2, shape2=scalar)
    self._assert_broadcast(expected=s1, shape1=s1, shape2=expanded_scalar)
    self._assert_broadcast(expected=s2, shape1=s2, shape2=expanded_scalar)

    self._assert_broadcast(expected=unknown, shape1=s1, shape2=unknown)
    self._assert_broadcast(expected=unknown, shape1=s2, shape2=unknown)

    self._assert_broadcast(
        expected=expanded_scalar, shape1=scalar, shape2=expanded_scalar)

    self._assert_incompatible_broadcast(shape1=s1, shape2=s2)

  def testBroadcast_many_dimensions(self):
    """Broadcasting of fully defined shapes of rank up to 2."""
    unknown = tensor_shape.unknown_shape()
    shape_0 = tensor_shape.scalar()
    shape_1 = tensor_shape.vector(1)
    shape_4 = tensor_shape.vector(4)
    shape_1x4 = tensor_shape.matrix(1, 4)
    shape_4x1 = tensor_shape.matrix(4, 1)
    shape_3x4 = tensor_shape.matrix(3, 4)
    shape_4x3 = tensor_shape.matrix(4, 3)

    # Tensors with same shape should have the same broadcast result.
    for shape in (
        shape_0, shape_1, shape_4, shape_1x4, shape_4x1, shape_3x4, shape_4x3):
      self._assert_broadcast(expected=shape, shape1=shape, shape2=shape)

    # [] and [1] act like identity.
    for identity in (shape_0, shape_1):
      for shape in (shape_4, shape_1x4, shape_4x1, shape_3x4, shape_4x3):
        self._assert_broadcast(expected=shape, shape1=identity, shape2=shape)

    # Unknown in, unknown out.
    for shape in (shape_4, shape_1x4, shape_4x1, shape_3x4, shape_4x3):
      self._assert_broadcast(expected=unknown, shape1=shape, shape2=unknown)

    self._assert_broadcast(expected=shape_1x4, shape1=shape_4, shape2=shape_1x4)
    shape_4x4 = tensor_shape.matrix(4, 4)
    self._assert_broadcast(expected=shape_4x4, shape1=shape_4, shape2=shape_4x1)
    self._assert_broadcast(expected=shape_3x4, shape1=shape_4, shape2=shape_3x4)
    self._assert_incompatible_broadcast(shape1=shape_4, shape2=shape_4x3)
    self._assert_broadcast(
        expected=shape_4x4, shape1=shape_1x4, shape2=shape_4x1)
    self._assert_broadcast(
        expected=shape_3x4, shape1=shape_1x4, shape2=shape_3x4)
    self._assert_incompatible_broadcast(shape1=shape_1x4, shape2=shape_4x3)
    self._assert_incompatible_broadcast(shape1=shape_4x1, shape2=shape_3x4)
    self._assert_broadcast(
        expected=shape_4x3, shape1=shape_4x1, shape2=shape_4x3)
    self._assert_incompatible_broadcast(shape1=shape_3x4, shape2=shape_4x3)

  @staticmethod
  def _dim_values(shape):
    """Returns `None` for unknown rank, else the list of dimension values.

    Unknown dimensions appear as `None` entries. This gives a plain-Python
    value that `assertEqual` can compare exactly, including unknown dims.
    """
    if shape.dims is None:
      return None
    return [dim.value for dim in shape.dims]

  # Asserts that broadcasting `shape1` with `shape2` yields `expected`,
  # comparing dimension values directly so that unknown (None) dimensions
  # compare equal to each other. Both argument orders are checked, because
  # broadcasting is symmetric.
  def _assert_broadcast_with_unknown_dims(self, expected, shape1, shape2):
    actual = self._dim_values(common_shapes.broadcast_shape(shape1, shape2))
    swapped = self._dim_values(common_shapes.broadcast_shape(shape2, shape1))

    self.assertEqual(actual, swapped)
    self.assertEqual(self._dim_values(expected), actual)

  def testBroadcast_unknown_dims(self):
    """Broadcasting of shapes that contain unknown dimensions."""
    unknown = tensor_shape.unknown_shape()
    shape_0 = tensor_shape.scalar()
    shape_1 = tensor_shape.vector(1)
    # pylint: disable=invalid-name
    shape_U = tensor_shape.vector(None)
    shape_1xU = tensor_shape.matrix(1, None)
    shape_Ux1 = tensor_shape.matrix(None, 1)
    shape_4xU = tensor_shape.matrix(4, None)
    shape_Ux4 = tensor_shape.matrix(None, 4)
    # pylint: enable=invalid-name

    # Tensors with same shape should have the same broadcast result.
    for shape in (shape_U, shape_1xU, shape_Ux1, shape_4xU, shape_Ux4):
      self._assert_broadcast_with_unknown_dims(
          expected=shape, shape1=shape, shape2=shape)

    # [] and [1] act like identity.
    for identity in (shape_0, shape_1):
      for shape in (shape_U, shape_1xU, shape_Ux1, shape_4xU, shape_Ux4):
        self._assert_broadcast_with_unknown_dims(
            expected=shape, shape1=identity, shape2=shape)

    # Unknown in, unknown out.
    for shape in (shape_U, shape_1xU, shape_Ux1, shape_4xU, shape_Ux4):
      self._assert_broadcast_with_unknown_dims(
          expected=unknown, shape1=shape, shape2=unknown)

    self._assert_broadcast_with_unknown_dims(
        expected=shape_1xU, shape1=shape_U, shape2=shape_1xU)
    shape_UxU = tensor_shape.matrix(None, None)  # pylint: disable=invalid-name
    self._assert_broadcast_with_unknown_dims(
        expected=shape_UxU, shape1=shape_U, shape2=shape_Ux1)
    self._assert_broadcast_with_unknown_dims(
        expected=shape_4xU, shape1=shape_U, shape2=shape_4xU)
    self._assert_broadcast_with_unknown_dims(
        expected=shape_Ux4, shape1=shape_U, shape2=shape_Ux4)
    self._assert_broadcast_with_unknown_dims(
        expected=shape_UxU, shape1=shape_1xU, shape2=shape_Ux1)
    self._assert_broadcast_with_unknown_dims(
        expected=shape_4xU, shape1=shape_1xU, shape2=shape_4xU)
    self._assert_broadcast_with_unknown_dims(
        expected=shape_Ux4, shape1=shape_1xU, shape2=shape_Ux4)
    self._assert_broadcast_with_unknown_dims(
        expected=shape_4xU, shape1=shape_Ux1, shape2=shape_4xU)
    self._assert_broadcast_with_unknown_dims(
        expected=shape_Ux4, shape1=shape_Ux1, shape2=shape_Ux4)
    shape_4x4 = tensor_shape.matrix(4, 4)
    self._assert_broadcast_with_unknown_dims(
        expected=shape_4x4, shape1=shape_4xU, shape2=shape_Ux4)
206
207
def _main():
  """Runs every test case defined in this module through googletest."""
  googletest.main()


if __name__ == "__main__":
  _main()
210