# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used by convolution layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np
from six.moves import range  # pylint: disable=redefined-builtin

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.ops import array_ops


def convert_data_format(data_format, ndim):
  if data_format == 'channels_last':
    if ndim == 3:
      return 'NWC'
    elif ndim == 4:
      return 'NHWC'
    elif ndim == 5:
      return 'NDHWC'
    else:
      raise ValueError('Input rank not supported:', ndim)
  elif data_format == 'channels_first':
    if ndim == 3:
      return 'NCW'
    elif ndim == 4:
      return 'NCHW'
    elif ndim == 5:
      return 'NCDHW'
    else:
      raise ValueError('Input rank not supported:', ndim)
  else:
    raise ValueError('Invalid data_format:', data_format)


def normalize_tuple(value, n, name):
  """Transforms a single integer or iterable of integers into an integer tuple.

  Args:
    value: The value to validate and convert. Could be an int, or any iterable
      of ints.
    n: The size of the tuple to be returned.
    name: The name of the argument being validated, e.g. "strides" or
      "kernel_size". This is only used to format error messages.

  Returns:
    A tuple of n integers.

  Raises:
    ValueError: If something other than an int/long or iterable thereof was
      passed.
  """
  if isinstance(value, int):
    return (value,) * n
  else:
    try:
      value_tuple = tuple(value)
    except TypeError:
      raise ValueError('The `' + name + '` argument must be a tuple of ' +
                       str(n) + ' integers. Received: ' + str(value))
    if len(value_tuple) != n:
      raise ValueError('The `' + name + '` argument must be a tuple of ' +
                       str(n) + ' integers. Received: ' + str(value))
    for single_value in value_tuple:
      try:
        int(single_value)
      except (ValueError, TypeError):
        raise ValueError('The `' + name + '` argument must be a tuple of ' +
                         str(n) + ' integers. Received: ' + str(value) + ' '
                         'including element ' + str(single_value) + ' of type' +
                         ' ' + str(type(single_value)))
    return value_tuple


def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
  """Determines output length of a convolution given input length.

  Args:
    input_length: integer.
    filter_size: integer.
    padding: one of "same", "valid", "full", "causal".
    stride: integer.
    dilation: dilation rate, integer.

  Returns:
    The output length (integer).
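
  For example, a length-10 input with a size-3 kernel gives 8 `"valid"`
  outputs at stride 1 and 5 `"same"` outputs at stride 2:

  >>> conv_output_length(10, 3, 'valid', 1)
  8
  >>> conv_output_length(10, 3, 'same', 2)
  5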
  """
  if input_length is None:
    return None
  assert padding in {'same', 'valid', 'full', 'causal'}
  dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
  if padding in ['same', 'causal']:
    output_length = input_length
  elif padding == 'valid':
    output_length = input_length - dilated_filter_size + 1
  elif padding == 'full':
    output_length = input_length + dilated_filter_size - 1
  return (output_length + stride - 1) // stride


def conv_input_length(output_length, filter_size, padding, stride):
  """Determines input length of a convolution given output length.

  Args:
    output_length: integer.
    filter_size: integer.
    padding: one of "same", "valid", "full".
    stride: integer.

  Returns:
    The input length (integer).
  """
  if output_length is None:
    return None
  assert padding in {'same', 'valid', 'full'}
  if padding == 'same':
    pad = filter_size // 2
  elif padding == 'valid':
    pad = 0
  elif padding == 'full':
    pad = filter_size - 1
  return (output_length - 1) * stride - 2 * pad + filter_size


def deconv_output_length(input_length,
                         filter_size,
                         padding,
                         output_padding=None,
                         stride=0,
                         dilation=1):
  """Determines output length of a transposed convolution given input length.

  Args:
    input_length: Integer.
    filter_size: Integer.
    padding: one of `"same"`, `"valid"`, `"full"`.
    output_padding: Integer, amount of padding along the output dimension. Can
      be set to `None` in which case the output length is inferred.
    stride: Integer.
    dilation: Integer.

  Returns:
    The output length (integer).
  """
  assert padding in {'same', 'valid', 'full'}
  if input_length is None:
    return None

  # Get the dilated kernel size
  filter_size = filter_size + (filter_size - 1) * (dilation - 1)

  # Infer length if output padding is None, else compute the exact length
  if output_padding is None:
    if padding == 'valid':
      length = input_length * stride + max(filter_size - stride, 0)
    elif padding == 'full':
      length = input_length * stride - (stride + filter_size - 2)
    elif padding == 'same':
      length = input_length * stride

  else:
    if padding == 'same':
      pad = filter_size // 2
    elif padding == 'valid':
      pad = 0
    elif padding == 'full':
      pad = filter_size - 1

    length = ((input_length - 1) * stride + filter_size - 2 * pad +
              output_padding)
  return length


def normalize_data_format(value):
  if value is None:
    value = backend.image_data_format()
  data_format = value.lower()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('The `data_format` argument must be one of '
                     '"channels_first", "channels_last". Received: ' +
                     str(value))
  return data_format


def normalize_padding(value):
  if isinstance(value, (list, tuple)):
    return value
  padding = value.lower()
  if padding not in {'valid', 'same', 'causal'}:
    raise ValueError('The `padding` argument must be a list/tuple or one of '
                     '"valid", "same" (or "causal", only for `Conv1D`). '
                     'Received: ' + str(padding))
  return padding


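# The helpers below describe the connectivity pattern of an N-D convolution:
# which input positions are connected, via the kernel weights, to which
# output positions.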
def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
  """Compute a mask representing the connectivity of a convolution operation.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
  output with shape `(d_out1, ..., d_outN)`. This method returns a boolean array
  of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True` entries
  indicating pairs of input and output locations that are connected by a weight.

  Example:

    >>> input_shape = (4,)
    >>> kernel_shape = (2,)
    >>> strides = (1,)
    >>> padding = "valid"
    >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
    array([[ True, False, False],
           [ True,  True, False],
           [False,  True,  True],
           [False, False,  True]])

  where rows and columns correspond to inputs and outputs respectively.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.

  Returns:
    A boolean 2N-D `np.ndarray` of shape
    `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)`
    is the spatial shape of the output. `True` entries in the mask represent
    pairs of input-output locations that are connected by a weight.

  Raises:
    ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the
      same number of dimensions.
    NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
  """
  if padding not in {'same', 'valid'}:
    raise NotImplementedError('Padding type %s not supported. '
                              'Only "valid" and "same" '
                              'are implemented.' % padding)

  in_dims = len(input_shape)
  if isinstance(kernel_shape, int):
    kernel_shape = (kernel_shape,) * in_dims
  if isinstance(strides, int):
    strides = (strides,) * in_dims

  kernel_dims = len(kernel_shape)
  stride_dims = len(strides)
  if kernel_dims != in_dims or stride_dims != in_dims:
    raise ValueError('Number of strides, input and kernel dimensions must all '
                     'match. Received: %d, %d, %d.' %
                     (stride_dims, in_dims, kernel_dims))

  output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)

  mask_shape = input_shape + output_shape
  mask = np.zeros(mask_shape, bool)

  output_axes_ticks = [range(dim) for dim in output_shape]
  for output_position in itertools.product(*output_axes_ticks):
    input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
                                             output_position, strides, padding)
    for input_position in itertools.product(*input_axes_ticks):
      mask[input_position + output_position] = True

  return mask


def conv_kernel_idxs(input_shape, kernel_shape, strides, padding, filters_in,
                     filters_out, data_format):
  """Yields output-input tuples of indices in a CNN layer.

  The generator iterates over all `(output_idx, input_idx)` tuples, where
  `output_idx` is an integer index in a flattened tensor representing a single
  output image of a convolutional layer that is connected (via the layer
  weights) to the respective single input image at `input_idx`.

  Example:

    >>> input_shape = (2, 2)
    >>> kernel_shape = (2, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> filters_in = 1
    >>> filters_out = 1
    >>> data_format = "channels_last"
    >>> list(conv_kernel_idxs(input_shape, kernel_shape, strides, padding,
    ...                       filters_in, filters_out, data_format))
    [(0, 0), (0, 2), (1, 1), (1, 3)]

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.
    filters_in: `int`, number of filters in the input to the layer.
    filters_out: `int`, number of filters in the output of the layer.
    data_format: string, "channels_first" or "channels_last".

  Yields:
    The next tuple `(output_idx, input_idx)`, where
    `output_idx` is an integer index in a flattened tensor representing a single
    output image of a convolutional layer that is connected (via the layer
    weights) to the respective single input image at `input_idx`.

  Raises:
    ValueError: if `data_format` is neither
      `"channels_last"` nor `"channels_first"`, or if number of strides, input,
      and kernel number of dimensions do not match.

    NotImplementedError: if `padding` is neither `"same"` nor `"valid"`.
  """
  if padding not in ('same', 'valid'):
    raise NotImplementedError('Padding type %s not supported. '
                              'Only "valid" and "same" '
                              'are implemented.' % padding)

  in_dims = len(input_shape)
  if isinstance(kernel_shape, int):
    kernel_shape = (kernel_shape,) * in_dims
  if isinstance(strides, int):
    strides = (strides,) * in_dims

  kernel_dims = len(kernel_shape)
  stride_dims = len(strides)
  if kernel_dims != in_dims or stride_dims != in_dims:
    raise ValueError('Number of strides, input and kernel dimensions must all '
                     'match. Received: %d, %d, %d.' %
                     (stride_dims, in_dims, kernel_dims))

  output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)
  output_axes_ticks = [range(dim) for dim in output_shape]

  if data_format == 'channels_first':
    concat_idxs = lambda spatial_idx, filter_idx: (filter_idx,) + spatial_idx
  elif data_format == 'channels_last':
    concat_idxs = lambda spatial_idx, filter_idx: spatial_idx + (filter_idx,)
  else:
    raise ValueError('Data format %s not recognized. '
                     '`data_format` must be "channels_first" or '
                     '"channels_last".' % data_format)
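
  # For every output spatial position, find the window of input positions it
  # sees, then pair each output channel with each input channel and yield the
  # corresponding flat (output_idx, input_idx) indices.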
  for output_position in itertools.product(*output_axes_ticks):
    input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
                                             output_position, strides, padding)
    for input_position in itertools.product(*input_axes_ticks):
      for f_in in range(filters_in):
        for f_out in range(filters_out):
          out_idx = np.ravel_multi_index(
              multi_index=concat_idxs(output_position, f_out),
              dims=concat_idxs(output_shape, filters_out))
          in_idx = np.ravel_multi_index(
              multi_index=concat_idxs(input_position, f_in),
              dims=concat_idxs(input_shape, filters_in))
          yield (out_idx, in_idx)


def conv_connected_inputs(input_shape, kernel_shape, output_position, strides,
                          padding):
  """Return locations of the input connected to an output position.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)`. This method
  returns N ranges specifying the input region that was convolved with the
  kernel to produce the output at position
  `output_position = (p_out1, ..., p_outN)`.

  Example:

    >>> input_shape = (4, 4)
    >>> kernel_shape = (2, 1)
    >>> output_position = (1, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> conv_connected_inputs(input_shape, kernel_shape, output_position,
    ...                       strides, padding)
    [range(1, 3), range(1, 2)]

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    output_position: tuple of size N: `(p_out1, ..., p_outN)`, a single
      position in the output of the convolution.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.

  Returns:
    N ranges `[[p_in_left1, ..., p_in_right1], ...,
    [p_in_leftN, ..., p_in_rightN]]` specifying the region in the
    input connected to output_position.
  """
  ranges = []

  ndims = len(input_shape)
  for d in range(ndims):
    left_shift = int(kernel_shape[d] / 2)
    right_shift = kernel_shape[d] - left_shift

    center = output_position[d] * strides[d]

    if padding == 'valid':
      center += left_shift

    start = max(0, center - left_shift)
    end = min(input_shape[d], center + right_shift)

    ranges.append(range(start, end))

  return ranges


def conv_output_shape(input_shape, kernel_shape, strides, padding):
  """Return the output shape of an N-D convolution.

  Forces dimensions where input is empty (size 0) to remain empty.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.

  Returns:
    tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
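
  For example, a 4x4 input with a 2x2 kernel, unit strides and `"valid"`
  padding:

  >>> conv_output_shape((4, 4), (2, 2), (1, 1), 'valid')
  (3, 3)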
  """
  dims = range(len(kernel_shape))
  output_shape = [
      conv_output_length(input_shape[d], kernel_shape[d], padding, strides[d])
      for d in dims
  ]
  output_shape = tuple(
      [0 if input_shape[d] == 0 else output_shape[d] for d in dims])
  return output_shape


def squeeze_batch_dims(inp, op, inner_rank):
  """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.

  Where `squeeze_batch` reshapes `inp` to shape
  `[prod(inp.shape[:-inner_rank])] + inp.shape[-inner_rank:]`
  and `unsqueeze_batch` does the reverse reshape but on the output.

  Args:
    inp: A tensor with dims `batch_shape + inner_shape` where `inner_shape`
      is length `inner_rank`.
    op: A callable that takes a single input tensor and returns a single
      output tensor.
    inner_rank: A python integer.

  Returns:
    `unsqueeze_batch(op(squeeze_batch(inp)))`.
  """
  with ops.name_scope_v2('squeeze_batch_dims'):
    shape = inp.shape

    inner_shape = shape[-inner_rank:]
    if not inner_shape.is_fully_defined():
      inner_shape = array_ops.shape(inp)[-inner_rank:]

    batch_shape = shape[:-inner_rank]
    if not batch_shape.is_fully_defined():
      batch_shape = array_ops.shape(inp)[:-inner_rank]

    if isinstance(inner_shape, tensor_shape.TensorShape):
      inp_reshaped = array_ops.reshape(inp, [-1] + inner_shape.as_list())
    else:
      inp_reshaped = array_ops.reshape(
          inp, array_ops.concat(([-1], inner_shape), axis=-1))

    out_reshaped = op(inp_reshaped)

    out_inner_shape = out_reshaped.shape[-inner_rank:]
    if not out_inner_shape.is_fully_defined():
      out_inner_shape = array_ops.shape(out_reshaped)[-inner_rank:]

    out = array_ops.reshape(
        out_reshaped, array_ops.concat((batch_shape, out_inner_shape), axis=-1))

    out.set_shape(inp.shape[:-inner_rank] + out.shape[-inner_rank:])
    return out