1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities used by convolution layers.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22import numpy as np
23from six.moves import range  # pylint: disable=redefined-builtin
24
25from tensorflow.python.keras import backend
26
27
def convert_data_format(data_format, ndim):
  """Maps a Keras `data_format` and an input rank to a TF layout string.

  Arguments:
    data_format: Either `'channels_last'` or `'channels_first'`.
    ndim: Rank of the input, counting batch and channel axes. Only 3, 4 and
      5 are supported.

  Returns:
    A layout string such as `'NWC'`, `'NHWC'`, `'NDHWC'` (channels last) or
    `'NCW'`, `'NCHW'`, `'NCDHW'` (channels first).

  Raises:
    ValueError: If `data_format` is not recognized or `ndim` is unsupported.
  """
  if data_format == 'channels_last':
    candidates = ('NWC', 'NHWC', 'NDHWC')
  elif data_format == 'channels_first':
    candidates = ('NCW', 'NCHW', 'NCDHW')
  else:
    raise ValueError('Invalid data_format:', data_format)

  # The candidates line up with ranks 3, 4 and 5 respectively.
  for rank, layout in zip((3, 4, 5), candidates):
    if ndim == rank:
      return layout
  raise ValueError('Input rank not supported:', ndim)
49
50
def normalize_tuple(value, n, name):
  """Transforms a single integer or iterable of integers into an integer tuple.

  Arguments:
    value: The value to validate and convert. Could be an int (including a
      NumPy integer scalar such as `np.int64`), or any iterable of ints.
    n: The size of the tuple to be returned.
    name: The name of the argument being validated, e.g. "strides" or
      "kernel_size". This is only used to format error messages.

  Returns:
    A tuple of n integers. A scalar `value` is repeated `n` times; an
    iterable `value` is converted to a tuple with its elements left as given.

  Raises:
    ValueError: If something else than an int/long or iterable thereof was
      passed, or if the iterable does not have exactly `n` elements.
  """
  # NumPy integer scalars are not `int` instances under Python 3 and are not
  # iterable either, so without this they would be rejected below.
  if isinstance(value, np.integer):
    value = int(value)
  if isinstance(value, int):
    return (value,) * n

  def _invalid(details=''):
    # Builds (does not raise) the error shared by every failure path below.
    return ValueError('The `' + name + '` argument must be a tuple of ' +
                      str(n) + ' integers. Received: ' + str(value) + details)

  try:
    value_tuple = tuple(value)
  except TypeError:
    raise _invalid()
  if len(value_tuple) != n:
    raise _invalid()
  for single_value in value_tuple:
    try:
      # Only checks convertibility; the element itself is returned unchanged.
      int(single_value)
    except (ValueError, TypeError):
      raise _invalid(' including element ' + str(single_value) +
                     ' of type ' + str(type(single_value)))
  return value_tuple
88
89
90def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
91  """Determines output length of a convolution given input length.
92
93  Arguments:
94      input_length: integer.
95      filter_size: integer.
96      padding: one of "same", "valid", "full", "causal"
97      stride: integer.
98      dilation: dilation rate, integer.
99
100  Returns:
101      The output length (integer).
102  """
103  if input_length is None:
104    return None
105  assert padding in {'same', 'valid', 'full', 'causal'}
106  dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
107  if padding in ['same', 'causal']:
108    output_length = input_length
109  elif padding == 'valid':
110    output_length = input_length - dilated_filter_size + 1
111  elif padding == 'full':
112    output_length = input_length + dilated_filter_size - 1
113  return (output_length + stride - 1) // stride
114
115
def conv_input_length(output_length, filter_size, padding, stride):
  """Computes the input length of a convolution from its output length.

  Arguments:
      output_length: integer, or `None` if unknown.
      filter_size: integer.
      padding: one of "same", "valid", "full".
      stride: integer.

  Returns:
      The input length (integer), or `None` if `output_length` is `None`.
  """
  if output_length is None:
    return None
  assert padding in {'same', 'valid', 'full'}
  # Amount of padding assumed on each side of the input for this mode.
  if padding == 'same':
    pad = filter_size // 2
  elif padding == 'valid':
    pad = 0
  elif padding == 'full':
    pad = filter_size - 1
  strided_extent = (output_length - 1) * stride
  return strided_extent + filter_size - 2 * pad
138
139
def deconv_output_length(input_length, filter_size, padding,
                         output_padding=None, stride=0, dilation=1):
  """Computes the output length of a transposed convolution along one axis.

  Arguments:
      input_length: Integer, or `None` if unknown.
      filter_size: Integer.
      padding: one of `"same"`, `"valid"`, `"full"`.
      output_padding: Integer, amount of padding along the output dimension.
          Can be set to `None` in which case the output length is inferred.
      stride: Integer.
      dilation: Integer.

  Returns:
      The output length (integer), or `None` if `input_length` is `None`.
  """
  assert padding in {'same', 'valid', 'full'}
  if input_length is None:
    return None

  # Kernel extent after inserting (dilation - 1) gaps between taps.
  kernel_extent = filter_size + (filter_size - 1) * (dilation - 1)

  if output_padding is not None:
    # Exact length: padding per side as in the forward convolution, plus the
    # explicitly requested output padding.
    if padding == 'same':
      pad = kernel_extent // 2
    elif padding == 'valid':
      pad = 0
    elif padding == 'full':
      pad = kernel_extent - 1
    length = ((input_length - 1) * stride + kernel_extent - 2 * pad +
              output_padding)
  elif padding == 'valid':
    length = input_length * stride + max(kernel_extent - stride, 0)
  elif padding == 'full':
    length = input_length * stride - (stride + kernel_extent - 2)
  elif padding == 'same':
    length = input_length * stride
  return length
183
184
def normalize_data_format(value):
  """Validates a `data_format` argument and returns it lower-cased.

  Arguments:
    value: `'channels_first'` or `'channels_last'` (case-insensitive), or
      `None` to use `backend.image_data_format()`.

  Returns:
    The lower-cased data format string.

  Raises:
    ValueError: If the (resolved) value is not one of the two formats.
  """
  resolved = backend.image_data_format() if value is None else value
  lowered = resolved.lower()
  if lowered in {'channels_first', 'channels_last'}:
    return lowered
  raise ValueError('The `data_format` argument must be one of '
                   '"channels_first", "channels_last". Received: ' +
                   str(resolved))
194
195
def normalize_padding(value):
  """Validates a `padding` argument.

  Arguments:
    value: A list/tuple (returned unchanged), or one of the strings
      `"valid"`, `"same"` or `"causal"` (case-insensitive).

  Returns:
    `value` itself if it is a list/tuple, otherwise the lower-cased padding
    string.

  Raises:
    ValueError: If `value` is neither a list/tuple nor one of the supported
      padding strings (including non-string values such as `None`).
  """
  if isinstance(value, (list, tuple)):
    return value
  try:
    padding = value.lower()
  except AttributeError:
    # Non-string values (e.g. None) get the documented ValueError below
    # rather than leaking an AttributeError.
    padding = None
  if padding not in {'valid', 'same', 'causal'}:
    # Report what the caller actually passed, not the lower-cased copy.
    raise ValueError('The `padding` argument must be a list/tuple or one of '
                     '"valid", "same" (or "causal", only for `Conv1D`). '
                     'Received: ' + str(value))
  return padding
205
206
def convert_kernel(kernel):
  """Converts a Numpy kernel matrix from Theano format to TensorFlow format.

  The kernel is reversed along every axis except the last two, which are left
  untouched. Also works reciprocally, since the transformation is its own
  inverse.

  Arguments:
      kernel: Numpy array (3D, 4D or 5D), or anything `np.asarray` accepts.

  Returns:
      The converted kernel, as a new array (never a view of the input).

  Raises:
      ValueError: in case of invalid kernel rank (not 3D, 4D or 5D).
  """
  kernel = np.asarray(kernel)
  if not 3 <= kernel.ndim <= 5:
    raise ValueError('Invalid kernel shape:', kernel.shape)
  # Reverse every leading axis; keep the last two axes as-is.
  slices = [slice(None, None, -1)] * (kernel.ndim - 2) + [slice(None)] * 2
  # Multidimensional indexing needs a tuple of slices: NumPy deprecated and
  # then rejected indexing with a plain list of slices.
  return np.copy(kernel[tuple(slices)])
228
229
def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
  """Compute a mask representing the connectivity of a convolution operation.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
  output with shape `(d_out1, ..., d_outN)`. This method returns a boolean array
  of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True` entries
  indicating pairs of input and output locations that are connected by a weight.

  Example:
    ```python
        >>> input_shape = (4,)
        >>> kernel_shape = (2,)
        >>> strides = (1,)
        >>> padding = "valid"
        >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
        array([[ True, False, False],
               [ True,  True, False],
               [False,  True,  True],
               [False, False,  True]])
    ```
    where rows and columns correspond to inputs and outputs respectively.


  Args:
    input_shape: tuple (or other sequence) of size N: `(d_in1, ..., d_inN)`,
                 spatial shape of the input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
                  / receptive field, or a single int used for every
                  dimension.
    strides: tuple of size N, strides along each spatial dimension, or a
             single int used for every dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    A boolean 2N-D `np.ndarray` of shape
    `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)`
    is the spatial shape of the output. `True` entries in the mask represent
    pairs of input-output locations that are connected by a weight.

  Raises:
    ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the
        same number of dimensions.
    NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
  """
  if padding not in {'same', 'valid'}:
    raise NotImplementedError('Padding type %s not supported. '
                              'Only "valid" and "same" '
                              'are implemented.' % padding)

  # The mask shape is built by tuple concatenation below, so normalize a
  # list-valued shape up front.
  input_shape = tuple(input_shape)
  in_dims = len(input_shape)
  if isinstance(kernel_shape, int):
    kernel_shape = (kernel_shape,) * in_dims
  if isinstance(strides, int):
    strides = (strides,) * in_dims

  kernel_dims = len(kernel_shape)
  stride_dims = len(strides)
  if kernel_dims != in_dims or stride_dims != in_dims:
    raise ValueError('Number of strides, input and kernel dimensions must all '
                     'match. Received: %d, %d, %d.' %
                     (stride_dims, in_dims, kernel_dims))

  output_shape = tuple(
      conv_output_shape(input_shape, kernel_shape, strides, padding))

  mask_shape = input_shape + output_shape
  # `np.bool` was only a deprecated alias for the builtin `bool` and has been
  # removed from NumPy, so use the builtin directly.
  mask = np.zeros(mask_shape, dtype=bool)

  output_axes_ticks = [range(dim) for dim in output_shape]
  for output_position in itertools.product(*output_axes_ticks):
    input_axes_ticks = conv_connected_inputs(input_shape,
                                             kernel_shape,
                                             output_position,
                                             strides,
                                             padding)
    for input_position in itertools.product(*input_axes_ticks):
      mask[input_position + output_position] = True

  return mask
307
308
def conv_connected_inputs(input_shape,
                          kernel_shape,
                          output_position,
                          strides,
                          padding):
  """Return locations of the input connected to an output position.

  For a convolution over N spatial dimensions applied to an input with
  `input_shape = (d_in1, ..., d_inN)`, this returns one range per dimension
  describing the input region that the kernel covered to produce the output
  at `output_position = (p_out1, ..., p_outN)`. Ranges are clipped to the
  input bounds.

  Example:
    ```python
        >>> input_shape = (4, 4)
        >>> kernel_shape = (2, 1)
        >>> output_position = (1, 1)
        >>> strides = (1, 1)
        >>> padding = "valid"
        >>> conv_connected_inputs(input_shape, kernel_shape, output_position,
        >>>                       strides, padding)
        [range(1, 3), range(1, 2)]
    ```
  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
                 spatial shape of the input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
                  / receptive field.
    output_position: tuple of size N: `(p_out1, ..., p_outN)`,
                     a single position in the output of the convolution.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    A list of N ranges `[range(p_in_left1, p_in_right1 + 1), ...,
    range(p_in_leftN, p_in_rightN + 1)]` specifying the region in the input
    connected to `output_position`.
  """
  def _axis_range(axis):
    # Taps before the kernel center, and taps from the center onwards.
    before = int(kernel_shape[axis] / 2)
    after = kernel_shape[axis] - before

    center = output_position[axis] * strides[axis]
    if padding == 'valid':
      # Without padding the first output is centered `before` steps in.
      center += before

    return range(max(0, center - before),
                 min(input_shape[axis], center + after))

  return [_axis_range(axis) for axis in range(len(input_shape))]
366
367
def conv_output_shape(input_shape, kernel_shape, strides, padding):
  """Return the spatial output shape of an N-D convolution.

  Each dimension is computed with `conv_output_length`; any dimension whose
  input size is 0 is forced to stay 0 in the output.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
                 spatial shape of the input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
                  / receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
  """
  output_dims = []
  for axis in range(len(kernel_shape)):
    in_len = input_shape[axis]
    out_len = conv_output_length(in_len,
                                 kernel_shape[axis],
                                 padding,
                                 strides[axis])
    # An empty input axis stays empty regardless of the computed length.
    output_dims.append(0 if in_len == 0 else out_len)
  return tuple(output_dims)
393