1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for conv_utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22
23from absl.testing import parameterized
24import numpy as np
25
26from tensorflow.python.keras.utils import conv_utils
27from tensorflow.python.platform import test
28
29
30def _get_const_output_shape(input_shape, dim):
31  return tuple([min(d, dim) for d in input_shape])
32
33
# Input shapes fed to every test of the parameterized TestConvUtils class.
# They span ranks 1 through 6 and deliberately include zero-size dimensions.
# The order is significant: parameterized derives test-case suffixes from it.
input_shapes = [
    (0,), (0, 0), (1,), (2,), (3,),
    (1, 0), (0, 3), (1, 1), (1, 2), (3, 1), (2, 2), (3, 3),
    (1, 0, 1), (5, 2, 3),
    (3, 5, 6, 7, 0), (3, 2, 2, 4, 4),
    (1, 2, 3, 4, 7, 2),
]
53
54
class TestBasicConvUtilsTest(test.TestCase):
  """Tests for the format/padding normalizers and length helpers of conv_utils."""

  def test_convert_data_format(self):
    """Keras data-format names plus a rank map to TF layout strings."""
    expectations = [
        ('NCDHW', 'channels_first', 5),
        ('NCHW', 'channels_first', 4),
        ('NCW', 'channels_first', 3),
        ('NHWC', 'channels_last', 4),
        ('NWC', 'channels_last', 3),
        ('NDHWC', 'channels_last', 5),
    ]
    for expected, data_format, ndim in expectations:
      self.assertEqual(expected,
                       conv_utils.convert_data_format(data_format, ndim))

    with self.assertRaises(ValueError):
      conv_utils.convert_data_format('invalid', 2)

  def test_normalize_tuple(self):
    """An int is repeated n times; a wrong-length sequence or None is rejected."""
    self.assertEqual((2, 2, 2),
                     conv_utils.normalize_tuple(2, n=3, name='strides'))
    self.assertEqual((2, 1, 2),
                     conv_utils.normalize_tuple((2, 1, 2), n=3, name='strides'))

    for bad_value in ((2, 1), None):
      with self.assertRaises(ValueError):
        conv_utils.normalize_tuple(bad_value, n=3, name='strides')

  def test_normalize_data_format(self):
    """Mixed- and upper-case names come back lower-cased; unknown names fail."""
    for expected, raw in (('channels_last', 'Channels_Last'),
                          ('channels_first', 'CHANNELS_FIRST')):
      self.assertEqual(expected, conv_utils.normalize_data_format(raw))

    with self.assertRaises(ValueError):
      conv_utils.normalize_data_format('invalid')

  def test_normalize_padding(self):
    """Upper-case padding names come back lower-cased; unknown names fail."""
    for expected, raw in (('same', 'SAME'), ('valid', 'VALID')):
      self.assertEqual(expected, conv_utils.normalize_padding(raw))

    with self.assertRaises(ValueError):
      conv_utils.normalize_padding('invalid')

  def test_conv_output_length(self):
    """Forward-convolution output lengths for each padding mode."""
    # Each case: (expected length, positional args to conv_output_length).
    cases = [
        (4, (4, 2, 'same', 1, 1)),
        (2, (4, 2, 'same', 2, 1)),
        (3, (4, 2, 'valid', 1, 1)),
        (2, (4, 2, 'valid', 2, 1)),
        (5, (4, 2, 'full', 1, 1)),
        (3, (4, 2, 'full', 2, 1)),
        (2, (5, 2, 'valid', 2, 2)),
    ]
    for expected, args in cases:
      self.assertEqual(expected, conv_utils.conv_output_length(*args))

  def test_conv_input_length(self):
    """Input lengths recovered from output lengths for each padding mode."""
    # Each case: (expected length, positional args to conv_input_length).
    cases = [
        (3, (4, 2, 'same', 1)),
        (2, (2, 2, 'same', 2)),
        (4, (3, 2, 'valid', 1)),
        (4, (2, 2, 'valid', 2)),
        (3, (4, 2, 'full', 1)),
        (4, (3, 2, 'full', 2)),
    ]
    for expected, args in cases:
      self.assertEqual(expected, conv_utils.conv_input_length(*args))

  def test_deconv_output_length(self):
    """Transposed-convolution output lengths, with and without output_padding."""
    # Every call uses input length 4 and filter size 2; each case is
    # (expected length, padding, keyword arguments).
    cases = [
        (4, 'same', dict(stride=1)),
        (8, 'same', dict(stride=2)),
        (5, 'valid', dict(stride=1)),
        (8, 'valid', dict(stride=2)),
        (3, 'full', dict(stride=1)),
        (6, 'full', dict(stride=2)),
        (5, 'same', dict(output_padding=2, stride=1)),
        (7, 'same', dict(output_padding=1, stride=2)),
        (7, 'valid', dict(output_padding=2, stride=1)),
        (9, 'valid', dict(output_padding=1, stride=2)),
        (5, 'full', dict(output_padding=2, stride=1)),
        (7, 'full', dict(output_padding=1, stride=2)),
        (5, 'same', dict(output_padding=1, stride=1, dilation=2)),
        (12, 'valid', dict(output_padding=2, stride=2, dilation=3)),
        (6, 'full', dict(output_padding=2, stride=2, dilation=3)),
    ]
    for expected, padding, kwargs in cases:
      self.assertEqual(
          expected, conv_utils.deconv_output_length(4, 2, padding, **kwargs))
161
162
@parameterized.parameters(input_shapes)
class TestConvUtils(test.TestCase, parameterized.TestCase):
  """Checks `conv_utils.conv_kernel_mask` against hand-built boolean masks.

  Every test runs once per entry of `input_shapes`. The shape's dimensions
  arrive as positional arguments and are re-packed into `input_shape`.
  Expected masks have shape `input_shape + output_shape`. These tests treat
  entry `[in_position + out_position]` as True when that input position is
  connected to that output position by the kernel.

  Masks use the builtin `bool` dtype. The `np.bool` alias was deprecated in
  NumPy 1.20 and removed in NumPy 1.24.
  """

  def test_conv_kernel_mask_fc(self, *input_shape):
    """A kernel as large as the input acts like a fully-connected layer.

    With 'valid' padding each dimension collapses to min(d, 1) output
    positions, and every input position is connected to them.
    """
    padding = 'valid'
    kernel_shape = input_shape
    ndims = len(input_shape)
    strides = (1,) * ndims
    output_shape = _get_const_output_shape(input_shape, dim=1)
    mask = np.ones(input_shape + output_shape, bool)
    self.assertAllEqual(
        mask,
        conv_utils.conv_kernel_mask(input_shape, kernel_shape, strides,
                                    padding))

  def test_conv_kernel_mask_diag(self, *input_shape):
    """A size-1 kernel with stride 1 connects each position only to itself.

    Checked for both 'valid' and 'same' padding.
    """
    ndims = len(input_shape)
    kernel_shape = (1,) * ndims
    strides = (1,) * ndims

    # Identity over flattened positions, reshaped so that entry [p + p] is
    # True for every position p. It does not depend on padding, so build once.
    mask = np.identity(int(np.prod(input_shape)), bool)
    mask = np.reshape(mask, input_shape * 2)

    for padding in ['valid', 'same']:
      self.assertAllEqual(
          mask,
          conv_utils.conv_kernel_mask(input_shape, kernel_shape, strides,
                                      padding))

  def test_conv_kernel_mask_full_stride(self, *input_shape):
    """A stride spanning each whole dimension leaves one output position.

    Only the input origin is connected to it. When any dimension is zero the
    expected mask has no elements.
    """
    padding = 'valid'
    ndims = len(input_shape)
    kernel_shape = (1,) * ndims
    strides = tuple(max(d, 1) for d in input_shape)
    output_shape = _get_const_output_shape(input_shape, dim=1)

    mask = np.zeros(input_shape + output_shape, bool)
    if all(d > 0 for d in mask.shape):
      # Every output dimension has size 1 here, so this is the single
      # (input origin, output origin) entry.
      mask[(0,) * mask.ndim] = True

    self.assertAllEqual(
        mask,
        conv_utils.conv_kernel_mask(input_shape, kernel_shape, strides,
                                    padding))

  def test_conv_kernel_mask_almost_full_stride(self, *input_shape):
    """A stride of d - 1 (at least 1) yields up to two outputs per dimension.

    Only inputs whose every coordinate is 0 or d - 1 are connected.
    Coordinate 0 maps to output 0, and coordinate d - 1 maps to output
    min(d - 1, 1).
    """
    padding = 'valid'
    ndims = len(input_shape)
    kernel_shape = (1,) * ndims
    strides = tuple(max(d - 1, 1) for d in input_shape)
    output_shape = _get_const_output_shape(input_shape, dim=2)

    mask = np.zeros(input_shape + output_shape, bool)
    if all(d > 0 for d in mask.shape):
      for in_position in itertools.product(*[[0, d - 1] for d in input_shape]):
        out_position = tuple(min(p, 1) for p in in_position)
        mask[in_position + out_position] = True

    self.assertAllEqual(
        mask,
        conv_utils.conv_kernel_mask(input_shape, kernel_shape, strides,
                                    padding))

  def test_conv_kernel_mask_rect_kernel(self, *input_shape):
    """A kernel spanning one full axis collapses only that axis.

    For each axis d, the kernel has size input_shape[d] along d and 1
    elsewhere. Axis d shrinks to min(1, input_shape[d]) outputs, each
    connected to every input along d. The other axes stay one-to-one.
    """
    padding = 'valid'
    ndims = len(input_shape)
    strides = (1,) * ndims

    for d in range(ndims):
      kernel_shape = [1] * ndims
      kernel_shape[d] = input_shape[d]

      # Start from the one-to-one mask, then connect every input along axis d
      # to every output along axis d, holding the other coordinates fixed.
      mask = np.identity(int(np.prod(input_shape)), bool)
      mask = np.reshape(mask, input_shape * 2)
      for p in itertools.product(*[range(size) for size in input_shape]):
        p = list(p)
        p[d] = slice(None)
        # Multidimensional indices must be tuples. NumPy 1.23+ no longer
        # accepts a plain list containing slices here.
        mask[tuple(p) * 2] = True

      # Output axis d keeps only min(1, input_shape[d]) positions.
      mask = np.take(mask, np.arange(min(1, input_shape[d])), axis=ndims + d)

      self.assertAllEqual(
          mask,
          conv_utils.conv_kernel_mask(input_shape, kernel_shape, strides,
                                      padding))

  def test_conv_kernel_mask_wrong_padding(self, *input_shape):
    """'valid' and 'same' are accepted; 'full' raises NotImplementedError."""
    ndims = len(input_shape)
    kernel_shape = (1,) * ndims
    strides = (1,) * ndims

    conv_utils.conv_kernel_mask(input_shape, kernel_shape, strides, 'valid')
    conv_utils.conv_kernel_mask(input_shape, kernel_shape, strides, 'same')

    self.assertRaises(NotImplementedError,
                      conv_utils.conv_kernel_mask,
                      input_shape, kernel_shape, strides, 'full')

  def test_conv_kernel_mask_wrong_dims(self, *input_shape):
    """Rank mismatches between input, kernel and strides raise ValueError.

    Plain ints for kernel_shape and strides are accepted.
    """
    kernel_shape = 1
    strides = 1

    conv_utils.conv_kernel_mask(input_shape, kernel_shape, strides, 'valid')

    ndims = len(input_shape)

    # Kernel has one dimension too many.
    kernel_shape = (2,) * (ndims + 1)
    self.assertRaises(ValueError,
                      conv_utils.conv_kernel_mask,
                      input_shape, kernel_shape, strides, 'same')

    strides = (1,) * ndims
    self.assertRaises(ValueError,
                      conv_utils.conv_kernel_mask,
                      input_shape, kernel_shape, strides, 'valid')

    # Strides have one dimension too few.
    kernel_shape = (1,) * ndims
    strides = (2,) * (ndims - 1)
    self.assertRaises(ValueError,
                      conv_utils.conv_kernel_mask,
                      input_shape, kernel_shape, strides, 'valid')

    strides = (2,) * ndims
    conv_utils.conv_kernel_mask(input_shape, kernel_shape, strides, 'valid')
337
338
if __name__ == '__main__':
  # Run this module's tests through the TensorFlow test entry point.
  test.main()
341