1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities used by convolution layers."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import itertools
21
22import numpy as np
23from six.moves import range  # pylint: disable=redefined-builtin
24
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.keras import backend
28from tensorflow.python.ops import array_ops
29
30
def convert_data_format(data_format, ndim):
  """Translates a `channels_*` data format and an input rank to a layout string.

  Args:
    data_format: Either `'channels_last'` or `'channels_first'`.
    ndim: Rank of the input. Only 3, 4 and 5 are supported.

  Returns:
    `'NWC'`, `'NHWC'` or `'NDHWC'` for `'channels_last'`; `'NCW'`, `'NCHW'` or
    `'NCDHW'` for `'channels_first'` (ranks 3, 4 and 5 respectively).

  Raises:
    ValueError: If `data_format` is not one of the two accepted strings, or if
      `ndim` is not 3, 4 or 5.
  """
  # The data format is validated first, then the rank.
  if data_format == 'channels_last':
    layouts = ('NWC', 'NHWC', 'NDHWC')
  elif data_format == 'channels_first':
    layouts = ('NCW', 'NCHW', 'NCDHW')
  else:
    raise ValueError('Invalid data_format:', data_format)

  for rank, layout in zip((3, 4, 5), layouts):
    if ndim == rank:
      return layout
  raise ValueError('Input rank not supported:', ndim)
52
53
def normalize_tuple(value, n, name):
  """Converts an integer or an iterable of integers into a tuple of length `n`.

  Args:
    value: The value to validate and convert. Can be an int, or any iterable
      of ints.
    n: The size of the tuple to be returned.
    name: The name of the argument being validated, e.g. "strides" or
      "kernel_size". Only used when formatting error messages.

  Returns:
    A tuple with `n` elements. An int `value` is repeated `n` times; an
    iterable is returned as a tuple of its original (unconverted) elements.

  Raises:
    ValueError: If `value` is not iterable, does not have exactly `n`
      elements, or contains an element that `int()` cannot convert.
  """

  def _mismatch_message():
    # Built lazily so that nothing is stringified on the success path.
    return ('The `' + name + '` argument must be a tuple of ' + str(n) +
            ' integers. Received: ' + str(value))

  if isinstance(value, int):
    return (value,) * n

  try:
    items = tuple(value)
  except TypeError:
    raise ValueError(_mismatch_message())

  if len(items) != n:
    raise ValueError(_mismatch_message())

  for item in items:
    try:
      # Only checks convertibility; the element itself is returned untouched.
      int(item)
    except (ValueError, TypeError):
      raise ValueError(_mismatch_message() + ' including element ' +
                       str(item) + ' of type ' + str(type(item)))
  return items
91
92
def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
  """Computes the length of a convolution's output along one dimension.

  Args:
      input_length: integer, or `None` when the length is unknown.
      filter_size: integer.
      padding: one of "same", "valid", "full", "causal".
      stride: integer.
      dilation: dilation rate, integer.

  Returns:
      The output length (integer), or `None` if `input_length` is `None`.
  """
  if input_length is None:
    return None
  assert padding in {'same', 'valid', 'full', 'causal'}

  # Extent of the kernel once gaps introduced by dilation are accounted for.
  effective_size = filter_size + (filter_size - 1) * (dilation - 1)

  if padding == 'valid':
    unstrided = input_length - effective_size + 1
  elif padding == 'full':
    unstrided = input_length + effective_size - 1
  elif padding in ['same', 'causal']:
    unstrided = input_length

  # Ceiling division by the stride.
  return (unstrided + stride - 1) // stride
117
118
def conv_input_length(output_length, filter_size, padding, stride):
  """Computes the input length of a convolution from its output length.

  Args:
      output_length: integer, or `None` when the length is unknown.
      filter_size: integer.
      padding: one of "same", "valid", "full".
      stride: integer.

  Returns:
      The input length (integer), or `None` if `output_length` is `None`.
  """
  if output_length is None:
    return None
  assert padding in {'same', 'valid', 'full'}

  # Amount of padding assumed on each side of the input.
  if padding == 'valid':
    one_sided_pad = 0
  elif padding == 'full':
    one_sided_pad = filter_size - 1
  elif padding == 'same':
    one_sided_pad = filter_size // 2

  return (output_length - 1) * stride - 2 * one_sided_pad + filter_size
141
142
def deconv_output_length(input_length,
                         filter_size,
                         padding,
                         output_padding=None,
                         stride=0,
                         dilation=1):
  """Computes the output length of a transposed convolution.

  Args:
      input_length: Integer, or `None` when the length is unknown.
      filter_size: Integer.
      padding: one of `"same"`, `"valid"`, `"full"`.
      output_padding: Integer, amount of padding along the output dimension. Can
        be set to `None` in which case the output length is inferred.
      stride: Integer.
      dilation: Integer.

  Returns:
      The output length (integer), or `None` if `input_length` is `None`.
  """
  # The padding mode is checked even when the input length is unknown.
  assert padding in {'same', 'valid', 'full'}
  if input_length is None:
    return None

  # Kernel extent after dilation.
  kernel_extent = filter_size + (filter_size - 1) * (dilation - 1)

  if output_padding is not None:
    # Exact length from the explicitly requested output padding.
    if padding == 'valid':
      pad = 0
    elif padding == 'full':
      pad = kernel_extent - 1
    elif padding == 'same':
      pad = kernel_extent // 2
    return ((input_length - 1) * stride + kernel_extent - 2 * pad +
            output_padding)

  # No explicit output padding: infer the length from the padding mode.
  if padding == 'same':
    inferred = input_length * stride
  elif padding == 'valid':
    inferred = input_length * stride + max(kernel_extent - stride, 0)
  elif padding == 'full':
    inferred = input_length * stride - (stride + kernel_extent - 2)
  return inferred
190
191
def normalize_data_format(value):
  """Returns a lower-cased, validated data format string.

  Args:
    value: A data format string (case-insensitive), or `None` to use
      `backend.image_data_format()`.

  Returns:
    Either `'channels_first'` or `'channels_last'`.

  Raises:
    ValueError: If the lower-cased value is not one of the two accepted
      strings.
  """
  requested = backend.image_data_format() if value is None else value
  normalized = requested.lower()
  if normalized in {'channels_first', 'channels_last'}:
    return normalized
  raise ValueError('The `data_format` argument must be one of '
                   '"channels_first", "channels_last". Received: ' +
                   str(requested))
201
202
def normalize_padding(value):
  """Returns a validated padding specification.

  Args:
    value: A list/tuple, which is returned unchanged, or a padding mode string
      (case-insensitive).

  Returns:
    `value` itself when it is a list or tuple, otherwise the lower-cased
    padding string.

  Raises:
    ValueError: If `value` is a string whose lower-cased form is not
      `'valid'`, `'same'` or `'causal'`.
  """
  if isinstance(value, (list, tuple)):
    return value

  mode = value.lower()
  if mode in {'valid', 'same', 'causal'}:
    return mode
  raise ValueError('The `padding` argument must be a list/tuple or one of '
                   '"valid", "same" (or "causal", only for `Conv1D). '
                   'Received: ' + str(mode))
212
213
def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
  """Compute a mask representing the connectivity of a convolution operation.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
  output with shape `(d_out1, ..., d_outN)`. This method returns a boolean array
  of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True` entries
  indicating pairs of input and output locations that are connected by a weight.

  Example:

    >>> input_shape = (4,)
    >>> kernel_shape = (2,)
    >>> strides = (1,)
    >>> padding = "valid"
    >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
    array([[ True, False, False],
           [ True,  True, False],
           [False,  True,  True],
           [False, False,  True]])

    where rows and columns correspond to inputs and outputs respectively.


  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field. A single int is repeated for every input dimension.
    strides: tuple of size N, strides along each spatial dimension. A single
      int is repeated for every input dimension.
    padding: type of padding, string `"same"` or `"valid"`.
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.

  Returns:
    A boolean 2N-D `np.ndarray` of shape
    `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)`
    is the spatial shape of the output. `True` entries in the mask represent
    pairs of input-output locations that are connected by a weight.

  Raises:
    ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the
        same number of dimensions.
    NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
  """
  if padding not in {'same', 'valid'}:
    raise NotImplementedError('Padding type %s not supported. '
                              'Only "valid" and "same" '
                              'are implemented.' % padding)

  in_dims = len(input_shape)
  if isinstance(kernel_shape, int):
    kernel_shape = (kernel_shape,) * in_dims
  if isinstance(strides, int):
    strides = (strides,) * in_dims

  kernel_dims = len(kernel_shape)
  stride_dims = len(strides)
  if kernel_dims != in_dims or stride_dims != in_dims:
    raise ValueError('Number of strides, input and kernel dimensions must all '
                     'match. Received: %d, %d, %d.' %
                     (stride_dims, in_dims, kernel_dims))

  output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)

  # `tuple(...)` lets a list `input_shape` be concatenated with the tuple
  # returned by `conv_output_shape`.
  mask_shape = tuple(input_shape) + output_shape
  # `np.bool_` instead of the `np.bool` alias, which NumPy deprecated in 1.20
  # and removed in 1.24.
  mask = np.zeros(mask_shape, np.bool_)

  output_axes_ticks = [range(dim) for dim in output_shape]
  for output_position in itertools.product(*output_axes_ticks):
    input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
                                             output_position, strides, padding)
    # Mark every input location feeding this output location.
    for input_position in itertools.product(*input_axes_ticks):
      mask[input_position + output_position] = True

  return mask
291
292
def conv_kernel_idxs(input_shape, kernel_shape, strides, padding, filters_in,
                     filters_out, data_format):
  """Yields output-input tuples of indices in a CNN layer.

  The generator iterates over all `(output_idx, input_idx)` tuples, where
    `output_idx` is an integer index in a flattened tensor representing a single
    output image of a convolutional layer that is connected (via the layer
    weights) to the respective single input image at `input_idx`.

  Because this is a generator, argument validation runs when the first item is
  requested, not when the function is called.

  Example:

    >>> input_shape = (2, 2)
    >>> kernel_shape = (2, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> filters_in = 1
    >>> filters_out = 1
    >>> data_format = "channels_last"
    >>> list(conv_kernel_idxs(input_shape, kernel_shape, strides, padding,
    ...                       filters_in, filters_out, data_format))
    [(0, 0), (0, 2), (1, 1), (1, 3)]

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field. A single int is repeated for every input dimension.
    strides: tuple of size N, strides along each spatial dimension. A single
      int is repeated for every input dimension.
    padding: type of padding, string `"same"` or `"valid"`.
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.
    filters_in: `int`, number of filters in the input to the layer.
    filters_out: `int`, number of filters in the output of the layer.
    data_format: string, "channels_first" or "channels_last".

  Yields:
    The next tuple `(output_idx, input_idx)`, where
    `output_idx` is an integer index in a flattened tensor representing a single
    output image of a convolutional layer that is connected (via the layer
    weights) to the respective single input image at `input_idx`.

  Raises:
      ValueError: if `data_format` is neither
      `"channels_last"` nor `"channels_first"`, or if number of strides, input,
      and kernel number of dimensions do not match.

      NotImplementedError: if `padding` is neither `"same"` nor `"valid"`.
  """
  if padding not in ('same', 'valid'):
    raise NotImplementedError('Padding type %s not supported. '
                              'Only "valid" and "same" '
                              'are implemented.' % padding)

  in_dims = len(input_shape)
  if isinstance(kernel_shape, int):
    kernel_shape = (kernel_shape,) * in_dims
  if isinstance(strides, int):
    strides = (strides,) * in_dims

  kernel_dims = len(kernel_shape)
  stride_dims = len(strides)
  if kernel_dims != in_dims or stride_dims != in_dims:
    raise ValueError('Number of strides, input and kernel dimensions must all '
                     'match. Received: %d, %d, %d.' %
                     (stride_dims, in_dims, kernel_dims))

  output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)
  output_axes_ticks = [range(dim) for dim in output_shape]

  # `concat_idxs` places the filter index before or after the spatial indices
  # depending on the data format.
  if data_format == 'channels_first':
    concat_idxs = lambda spatial_idx, filter_idx: (filter_idx,) + spatial_idx
  elif data_format == 'channels_last':
    concat_idxs = lambda spatial_idx, filter_idx: spatial_idx + (filter_idx,)
  else:
    raise ValueError('Data format %s not recognized. '
                     '`data_format` must be "channels_first" or '
                     '"channels_last".' % data_format)

  # Full (spatial + filter) shapes of the flattened tensors. They do not change
  # across iterations, so they are built once. `tuple(...)` lets a list
  # `input_shape` be concatenated with the filter count.
  flat_output_dims = concat_idxs(output_shape, filters_out)
  flat_input_dims = concat_idxs(tuple(input_shape), filters_in)

  for output_position in itertools.product(*output_axes_ticks):
    input_axes_ticks = conv_connected_inputs(input_shape, kernel_shape,
                                             output_position, strides, padding)
    for input_position in itertools.product(*input_axes_ticks):
      for f_in in range(filters_in):
        # The flat input index does not depend on the output filter.
        in_idx = np.ravel_multi_index(
            multi_index=concat_idxs(input_position, f_in),
            dims=flat_input_dims)
        for f_out in range(filters_out):
          out_idx = np.ravel_multi_index(
              multi_index=concat_idxs(output_position, f_out),
              dims=flat_output_dims)
          yield (out_idx, in_idx)
386
def conv_connected_inputs(input_shape, kernel_shape, output_position, strides,
                          padding):
  """Return locations of the input connected to an output position.

  For a convolution applied to an input with N spatial dimensions and
  `input_shape = (d_in1, ..., d_inN)`, this computes, per dimension, the range
  of input indices that the kernel covers when producing the output at
  `output_position = (p_out1, ..., p_outN)`. Ranges are clipped to the bounds
  of the input.

  Example:

    >>> input_shape = (4, 4)
    >>> kernel_shape = (2, 1)
    >>> output_position = (1, 1)
    >>> strides = (1, 1)
    >>> padding = "valid"
    >>> conv_connected_inputs(input_shape, kernel_shape, output_position,
    ...                       strides, padding)
    [range(1, 3), range(1, 2)]

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    output_position: tuple of size N: `(p_out1, ..., p_outN)`, a single position
      in the output of the convolution.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`. Any value other
      than `"valid"` is handled like `"same"`.

  Returns:
    A list of N ranges `[[p_in_left1, ..., p_in_right1], ...,
              [p_in_leftN, ..., p_in_rightN]]` specifying the region in the
    input connected to output_position.
  """

  def _connected_span(axis):
    """Input index range covered by the kernel along one axis."""
    kernel_size = kernel_shape[axis]
    reach_before = int(kernel_size / 2)
    reach_after = kernel_size - reach_before

    anchor = output_position[axis] * strides[axis]
    if padding == 'valid':
      # Without padding the kernel's first tap sits at the strided position,
      # so its center lies `reach_before` further in.
      anchor = anchor + reach_before

    first = max(0, anchor - reach_before)
    stop = min(input_shape[axis], anchor + reach_after)
    return range(first, stop)

  return [_connected_span(axis) for axis in range(len(input_shape))]
444
445
def conv_output_shape(input_shape, kernel_shape, strides, padding):
  """Return the output shape of an N-D convolution.

  Each dimension is computed with `conv_output_length`; dimensions where the
  input is empty (size 0) are forced to remain empty.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`, spatial shape of the
      input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel /
      receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.
      `"valid"` means no padding. `"same"` results in padding evenly to
      the left/right or up/down of the input such that output has the same
      height/width dimension as the input.

  Returns:
    tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
  """
  lengths = []
  for axis in range(len(kernel_shape)):
    # Always computed, so argument checks in `conv_output_length` still run
    # for empty dimensions.
    length = conv_output_length(input_shape[axis], kernel_shape[axis], padding,
                                strides[axis])
    lengths.append(length)

  for axis in range(len(kernel_shape)):
    if input_shape[axis] == 0:
      lengths[axis] = 0
  return tuple(lengths)
473
474
def squeeze_batch_dims(inp, op, inner_rank):
  """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`.

  Where `squeeze_batch` reshapes `inp` to shape
  `[prod(inp.shape[:-inner_rank])] + inp.shape[-inner_rank:]`
  and `unsqueeze_batch` does the reverse reshape but on the output.

  Static shapes are used wherever they are fully defined; otherwise the
  corresponding slice of the dynamic shape `array_ops.shape(...)` is used.

  Args:
    inp: A tensor with dims `batch_shape + inner_shape` where `inner_shape`
      is length `inner_rank`.
    op: A callable that takes a single input tensor and returns a single
      output tensor.
    inner_rank: A python integer.

  Returns:
    `unsqueeze_batch_op(squeeze_batch(inp))`.
  """
  with ops.name_scope_v2('squeeze_batch_dims'):
    static_shape = inp.shape

    trailing_dims = static_shape[-inner_rank:]
    if not trailing_dims.is_fully_defined():
      trailing_dims = array_ops.shape(inp)[-inner_rank:]

    leading_dims = static_shape[:-inner_rank]
    if not leading_dims.is_fully_defined():
      leading_dims = array_ops.shape(inp)[:-inner_rank]

    # Collapse all batch dims into a single leading dimension.
    if isinstance(trailing_dims, tensor_shape.TensorShape):
      collapsed_shape = [-1] + trailing_dims.as_list()
    else:
      collapsed_shape = array_ops.concat(([-1], trailing_dims), axis=-1)
    collapsed_inp = array_ops.reshape(inp, collapsed_shape)

    collapsed_out = op(collapsed_inp)

    result_trailing_dims = collapsed_out.shape[-inner_rank:]
    if not result_trailing_dims.is_fully_defined():
      result_trailing_dims = array_ops.shape(collapsed_out)[-inner_rank:]

    # Restore the original batch dims in front of the op's inner dims.
    restored_shape = array_ops.concat((leading_dims, result_trailing_dims),
                                      axis=-1)
    restored = array_ops.reshape(collapsed_out, restored_shape)

    # Re-attach whatever static shape information is available.
    restored.set_shape(inp.shape[:-inner_rank] + restored.shape[-inner_rank:])
    return restored
520