1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for conv_utils.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import itertools 22 23from absl.testing import parameterized 24import numpy as np 25 26from tensorflow.python.keras.utils import conv_utils 27from tensorflow.python.platform import test 28 29 30def _get_const_output_shape(input_shape, dim): 31 return tuple([min(d, dim) for d in input_shape]) 32 33 34input_shapes = [ 35 (0,), 36 (0, 0), 37 (1,), 38 (2,), 39 (3,), 40 (1, 0), 41 (0, 3), 42 (1, 1), 43 (1, 2), 44 (3, 1), 45 (2, 2), 46 (3, 3), 47 (1, 0, 1), 48 (5, 2, 3), 49 (3, 5, 6, 7, 0), 50 (3, 2, 2, 4, 4), 51 (1, 2, 3, 4, 7, 2), 52] 53 54 55class TestBasicConvUtilsTest(test.TestCase): 56 57 def test_convert_data_format(self): 58 self.assertEqual('NCDHW', conv_utils.convert_data_format( 59 'channels_first', 5)) 60 self.assertEqual('NCHW', conv_utils.convert_data_format( 61 'channels_first', 4)) 62 self.assertEqual('NCW', conv_utils.convert_data_format('channels_first', 3)) 63 self.assertEqual('NHWC', conv_utils.convert_data_format('channels_last', 4)) 64 self.assertEqual('NWC', conv_utils.convert_data_format('channels_last', 3)) 65 self.assertEqual('NDHWC', conv_utils.convert_data_format( 66 'channels_last', 5)) 67 68 with self.assertRaises(ValueError): 69 conv_utils.convert_data_format('invalid', 2) 70 71 def test_normalize_tuple(self): 72 self.assertEqual((2, 2, 2), 73 conv_utils.normalize_tuple(2, n=3, name='strides')) 74 self.assertEqual((2, 1, 2), 75 conv_utils.normalize_tuple((2, 1, 2), n=3, name='strides')) 76 77 with self.assertRaises(ValueError): 78 conv_utils.normalize_tuple((2, 1), n=3, name='strides') 79 80 with self.assertRaises(ValueError): 81 conv_utils.normalize_tuple(None, n=3, name='strides') 82 83 def test_normalize_data_format(self): 84 self.assertEqual('channels_last', 85 conv_utils.normalize_data_format('Channels_Last')) 86 self.assertEqual('channels_first', 87 conv_utils.normalize_data_format('CHANNELS_FIRST')) 88 89 with self.assertRaises(ValueError): 90 conv_utils.normalize_data_format('invalid') 91 92 def test_normalize_padding(self): 93 self.assertEqual('same', conv_utils.normalize_padding('SAME')) 94 self.assertEqual('valid', conv_utils.normalize_padding('VALID')) 95 96 with self.assertRaises(ValueError): 97 conv_utils.normalize_padding('invalid') 98 99 def test_conv_output_length(self): 100 self.assertEqual(4, conv_utils.conv_output_length(4, 2, 'same', 1, 1)) 101 self.assertEqual(2, conv_utils.conv_output_length(4, 2, 'same', 2, 1)) 102 self.assertEqual(3, conv_utils.conv_output_length(4, 2, 'valid', 1, 1)) 103 self.assertEqual(2, conv_utils.conv_output_length(4, 2, 'valid', 2, 1)) 104 self.assertEqual(5, conv_utils.conv_output_length(4, 2, 'full', 1, 1)) 105 self.assertEqual(3, conv_utils.conv_output_length(4, 2, 'full', 2, 1)) 106 self.assertEqual(2, conv_utils.conv_output_length(5, 2, 'valid', 2, 2)) 107 108 def test_conv_input_length(self): 109 self.assertEqual(3, conv_utils.conv_input_length(4, 2, 'same', 1)) 110 self.assertEqual(2, conv_utils.conv_input_length(2, 2, 'same', 2)) 111 self.assertEqual(4, conv_utils.conv_input_length(3, 2, 'valid', 1)) 112 self.assertEqual(4, conv_utils.conv_input_length(2, 2, 'valid', 2)) 113 self.assertEqual(3, conv_utils.conv_input_length(4, 2, 'full', 1)) 114 self.assertEqual(4, conv_utils.conv_input_length(3, 2, 'full', 2)) 115 116 def test_deconv_output_length(self): 117 self.assertEqual(4, conv_utils.deconv_output_length(4, 2, 'same', stride=1)) 118 self.assertEqual(8, conv_utils.deconv_output_length(4, 2, 'same', stride=2)) 119 self.assertEqual(5, conv_utils.deconv_output_length( 120 4, 2, 'valid', stride=1)) 121 self.assertEqual(8, conv_utils.deconv_output_length( 122 4, 2, 'valid', stride=2)) 123 self.assertEqual(3, conv_utils.deconv_output_length(4, 2, 'full', stride=1)) 124 self.assertEqual(6, conv_utils.deconv_output_length(4, 2, 'full', stride=2)) 125 self.assertEqual( 126 5, 127 conv_utils.deconv_output_length( 128 4, 2, 'same', output_padding=2, stride=1)) 129 self.assertEqual( 130 7, 131 conv_utils.deconv_output_length( 132 4, 2, 'same', output_padding=1, stride=2)) 133 self.assertEqual( 134 7, 135 conv_utils.deconv_output_length( 136 4, 2, 'valid', output_padding=2, stride=1)) 137 self.assertEqual( 138 9, 139 conv_utils.deconv_output_length( 140 4, 2, 'valid', output_padding=1, stride=2)) 141 self.assertEqual( 142 5, 143 conv_utils.deconv_output_length( 144 4, 2, 'full', output_padding=2, stride=1)) 145 self.assertEqual( 146 7, 147 conv_utils.deconv_output_length( 148 4, 2, 'full', output_padding=1, stride=2)) 149 self.assertEqual( 150 5, 151 conv_utils.deconv_output_length( 152 4, 2, 'same', output_padding=1, stride=1, dilation=2)) 153 self.assertEqual( 154 12, 155 conv_utils.deconv_output_length( 156 4, 2, 'valid', output_padding=2, stride=2, dilation=3)) 157 self.assertEqual( 158 6, 159 conv_utils.deconv_output_length( 160 4, 2, 'full', output_padding=2, stride=2, dilation=3)) 161 162 163@parameterized.parameters(input_shapes) 164class TestConvUtils(test.TestCase, parameterized.TestCase): 165 166 def test_conv_kernel_mask_fc(self, *input_shape): 167 padding = 'valid' 168 kernel_shape = input_shape 169 ndims = len(input_shape) 170 strides = (1,) * ndims 171 output_shape = _get_const_output_shape(input_shape, dim=1) 172 mask = np.ones(input_shape + output_shape, np.bool) 173 self.assertAllEqual( 174 mask, 175 conv_utils.conv_kernel_mask( 176 input_shape, 177 kernel_shape, 178 strides, 179 padding 180 ) 181 ) 182 183 def test_conv_kernel_mask_diag(self, *input_shape): 184 ndims = len(input_shape) 185 kernel_shape = (1,) * ndims 186 strides = (1,) * ndims 187 188 for padding in ['valid', 'same']: 189 mask = np.identity(int(np.prod(input_shape)), np.bool) 190 mask = np.reshape(mask, input_shape * 2) 191 self.assertAllEqual( 192 mask, 193 conv_utils.conv_kernel_mask( 194 input_shape, 195 kernel_shape, 196 strides, 197 padding 198 ) 199 ) 200 201 def test_conv_kernel_mask_full_stride(self, *input_shape): 202 padding = 'valid' 203 ndims = len(input_shape) 204 kernel_shape = (1,) * ndims 205 strides = tuple([max(d, 1) for d in input_shape]) 206 output_shape = _get_const_output_shape(input_shape, dim=1) 207 208 mask = np.zeros(input_shape + output_shape, np.bool) 209 if all(d > 0 for d in mask.shape): 210 mask[(0,) * len(output_shape)] = True 211 212 self.assertAllEqual( 213 mask, 214 conv_utils.conv_kernel_mask( 215 input_shape, 216 kernel_shape, 217 strides, 218 padding 219 ) 220 ) 221 222 def test_conv_kernel_mask_almost_full_stride(self, *input_shape): 223 padding = 'valid' 224 ndims = len(input_shape) 225 kernel_shape = (1,) * ndims 226 strides = tuple([max(d - 1, 1) for d in input_shape]) 227 output_shape = _get_const_output_shape(input_shape, dim=2) 228 229 mask = np.zeros(input_shape + output_shape, np.bool) 230 if all(d > 0 for d in mask.shape): 231 for in_position in itertools.product(*[[0, d - 1] for d in input_shape]): 232 out_position = tuple([min(p, 1) for p in in_position]) 233 mask[in_position + out_position] = True 234 235 self.assertAllEqual( 236 mask, 237 conv_utils.conv_kernel_mask( 238 input_shape, 239 kernel_shape, 240 strides, 241 padding 242 ) 243 ) 244 245 def test_conv_kernel_mask_rect_kernel(self, *input_shape): 246 padding = 'valid' 247 ndims = len(input_shape) 248 strides = (1,) * ndims 249 250 for d in range(ndims): 251 kernel_shape = [1] * ndims 252 kernel_shape[d] = input_shape[d] 253 254 output_shape = list(input_shape) 255 output_shape[d] = min(1, input_shape[d]) 256 257 mask = np.identity(int(np.prod(input_shape)), np.bool) 258 mask = np.reshape(mask, input_shape * 2) 259 260 for p in itertools.product(*[range(input_shape[dim]) 261 for dim in range(ndims)]): 262 p = list(p) 263 p[d] = slice(None) 264 mask[p * 2] = True 265 266 mask = np.take(mask, range(0, min(1, input_shape[d])), ndims + d) 267 268 self.assertAllEqual( 269 mask, 270 conv_utils.conv_kernel_mask( 271 input_shape, 272 kernel_shape, 273 strides, 274 padding 275 ) 276 ) 277 278 def test_conv_kernel_mask_wrong_padding(self, *input_shape): 279 ndims = len(input_shape) 280 kernel_shape = (1,) * ndims 281 strides = (1,) * ndims 282 283 conv_utils.conv_kernel_mask( 284 input_shape, 285 kernel_shape, 286 strides, 287 'valid' 288 ) 289 290 conv_utils.conv_kernel_mask( 291 input_shape, 292 kernel_shape, 293 strides, 294 'same' 295 ) 296 297 self.assertRaises(NotImplementedError, 298 conv_utils.conv_kernel_mask, 299 input_shape, kernel_shape, strides, 'full') 300 301 def test_conv_kernel_mask_wrong_dims(self, *input_shape): 302 kernel_shape = 1 303 strides = 1 304 305 conv_utils.conv_kernel_mask( 306 input_shape, 307 kernel_shape, 308 strides, 309 'valid' 310 ) 311 312 ndims = len(input_shape) 313 314 kernel_shape = (2,) * (ndims + 1) 315 self.assertRaises(ValueError, 316 conv_utils.conv_kernel_mask, 317 input_shape, kernel_shape, strides, 'same') 318 319 strides = (1,) * ndims 320 self.assertRaises(ValueError, 321 conv_utils.conv_kernel_mask, 322 input_shape, kernel_shape, strides, 'valid') 323 324 kernel_shape = (1,) * ndims 325 strides = (2,) * (ndims - 1) 326 self.assertRaises(ValueError, 327 conv_utils.conv_kernel_mask, 328 input_shape, kernel_shape, strides, 'valid') 329 330 strides = (2,) * ndims 331 conv_utils.conv_kernel_mask( 332 input_shape, 333 kernel_shape, 334 strides, 335 'valid' 336 ) 337 338 339if __name__ == '__main__': 340 test.main() 341