1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Contains layer utilities for input validation and format conversion."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numbers

from tensorflow.python.framework import smart_cond as smart_module
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
24
25
def convert_data_format(data_format, ndim):
  """Converts a Keras-style data format into a TensorFlow layout string.

  Args:
    data_format: Either 'channels_last' or 'channels_first'.
    ndim: Total number of dimensions of the input, batch dimension included
      (the returned layout string has exactly `ndim` letters). Must be 3, 4
      or 5.

  Returns:
    The layout string: 'NWC', 'NHWC' or 'NDHWC' for 'channels_last', and
    'NCW', 'NCHW' or 'NCDHW' for 'channels_first'.

  Raises:
    ValueError: If `data_format` is not recognized, or `ndim` is not 3, 4
      or 5.
  """
  if data_format not in ('channels_last', 'channels_first'):
    # Build a single formatted message; passing several positional args to
    # ValueError would render the message as a tuple repr.
    raise ValueError('Invalid data_format: ' + str(data_format))
  # Spatial axis letters for each supported rank (batch and channel axes are
  # added around them below). Compared with `==` so int-like ranks work.
  for rank, spatial_axes in ((3, 'W'), (4, 'HW'), (5, 'DHW')):
    if ndim == rank:
      if data_format == 'channels_first':
        return 'NC' + spatial_axes
      return 'N' + spatial_axes + 'C'
  raise ValueError('Input rank not supported: ' + str(ndim))
47
48
def normalize_tuple(value, n, name):
  """Transforms a single integer or iterable of integers into an integer tuple.

  Args:
    value: The value to validate and convert. Could be an integer scalar
      (any `numbers.Integral`, e.g. a Python int or a NumPy integer), or any
      iterable of ints.
    n: The size of the tuple to be returned.
    name: The name of the argument being validated, e.g. "strides" or
      "kernel_size". This is only used to format error messages.

  Returns:
    A tuple of n integers. A scalar `value` is repeated `n` times. For an
    iterable, the elements are returned as-is; they are only checked to be
    convertible with `int()`.

  Raises:
    ValueError: If something else than an integer or iterable thereof was
      passed, or the iterable does not have exactly `n` elements.
  """
  # numbers.Integral covers int, Py2 long and NumPy integer scalars, which a
  # plain `isinstance(value, int)` check would reject.
  if isinstance(value, numbers.Integral):
    return (value,) * n

  def _invalid(detail=''):
    # Built lazily so str(value) is only computed on the error path.
    return ValueError('The `' + name + '` argument must be a tuple of ' +
                      str(n) + ' integers. Received: ' + str(value) + detail)

  try:
    value_tuple = tuple(value)
  except TypeError:
    raise _invalid()
  if len(value_tuple) != n:
    raise _invalid()
  for single_value in value_tuple:
    try:
      int(single_value)
    except (ValueError, TypeError):
      raise _invalid(' including element ' + str(single_value) + ' of type ' +
                     str(type(single_value)))
  return value_tuple
86
87
def normalize_data_format(value):
  """Lower-cases and validates a `data_format` argument.

  Args:
    value: The user-supplied data format string (case-insensitive).

  Returns:
    'channels_first' or 'channels_last'.

  Raises:
    ValueError: If the lower-cased value is neither of the two formats.
  """
  lowered = value.lower()
  if lowered in ('channels_first', 'channels_last'):
    return lowered
  raise ValueError('The `data_format` argument must be one of '
                   '"channels_first", "channels_last". Received: ' +
                   str(value))
95
96
def normalize_padding(value):
  """Lower-cases and validates a `padding` argument.

  Args:
    value: The user-supplied padding string (case-insensitive).

  Returns:
    'valid' or 'same'.

  Raises:
    ValueError: If the lower-cased value is neither 'valid' nor 'same'.
  """
  padding = value.lower()
  if padding not in {'valid', 'same'}:
    # Report the value exactly as the caller passed it (not the lower-cased
    # copy), consistent with normalize_data_format.
    raise ValueError('The `padding` argument must be one of "valid", "same". '
                     'Received: ' + str(value))
  return padding
103
104
def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
  """Determines output length of a convolution given input length.

  Args:
    input_length: integer, or None if unknown.
    filter_size: integer.
    padding: one of "same", "valid", "full".
    stride: integer.
    dilation: dilation rate, integer.

  Returns:
    The output length (integer), or None if `input_length` is None.

  Raises:
    ValueError: If `padding` is not one of "same", "valid", "full".
  """
  if input_length is None:
    return None
  # Explicit check instead of `assert`: asserts are stripped under
  # `python -O`, which would turn a bad value into an UnboundLocalError.
  if padding not in ('same', 'valid', 'full'):
    raise ValueError('Invalid padding: ' + str(padding) +
                     '. Expected one of "same", "valid", "full".')
  # Effective extent of the kernel once (dilation - 1) gaps are inserted
  # between each pair of taps.
  dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
  if padding == 'same':
    output_length = input_length
  elif padding == 'valid':
    output_length = input_length - dilated_filter_size + 1
  else:  # 'full'
    output_length = input_length + dilated_filter_size - 1
  # Ceiling division by the stride.
  return (output_length + stride - 1) // stride
129
130
def conv_input_length(output_length, filter_size, padding, stride):
  """Determines input length of a convolution given output length.

  Args:
    output_length: integer, or None if unknown.
    filter_size: integer.
    padding: one of "same", "valid", "full".
    stride: integer.

  Returns:
    The input length (integer), or None if `output_length` is None.

  Raises:
    ValueError: If `padding` is not one of "same", "valid", "full".
  """
  if output_length is None:
    return None
  # Explicit check instead of `assert`: asserts are stripped under
  # `python -O`, which would turn a bad value into an UnboundLocalError.
  if padding not in ('same', 'valid', 'full'):
    raise ValueError('Invalid padding: ' + str(padding) +
                     '. Expected one of "same", "valid", "full".')
  if padding == 'same':
    pad = filter_size // 2
  elif padding == 'valid':
    pad = 0
  else:  # 'full'
    pad = filter_size - 1
  return (output_length - 1) * stride - 2 * pad + filter_size
153
154
def deconv_output_length(input_length, filter_size, padding, stride):
  """Determines output length of a transposed convolution given input length.

  Args:
    input_length: integer, or None if unknown.
    filter_size: integer.
    padding: one of "same", "valid", "full".
    stride: integer.

  Returns:
    The output length (integer), or None if `input_length` is None.
  """
  if input_length is None:
    return None
  strided_length = input_length * stride
  if padding == 'valid':
    # The kernel may overhang the last strided position.
    return strided_length + max(filter_size - stride, 0)
  if padding == 'full':
    return strided_length - (stride + filter_size - 2)
  # Any other padding value (e.g. "same") keeps the strided length unchanged.
  return strided_length
175
176
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

  A `variables.Variable` predicate always goes through
  `control_flow_ops.cond`. Every other predicate is handed to
  `smart_module.smart_cond`, which returns either `true_fn()` or `false_fn()`
  directly when `pred` is a bool or has a constant value, and otherwise uses
  `tf.cond` to dynamically route to both.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  if not isinstance(pred, variables.Variable):
    return smart_module.smart_cond(
        pred, true_fn=true_fn, false_fn=false_fn, name=name)
  # NOTE(review): Variables bypass the static short-circuit, presumably
  # because their value may change after graph construction — confirm.
  return control_flow_ops.cond(
      pred, true_fn=true_fn, false_fn=false_fn, name=name)
201
202
def constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Args:
    pred: A scalar, either a Python bool or a TensorFlow boolean variable
      or tensor, or the Python integer 1 or 0.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.

  Raises:
    TypeError: If `pred` is not a Variable, Tensor or bool, or Python
      integer 1 or 0.
  """
  # Integer booleans: map 1 -> True and 0 -> False. Other integers are passed
  # through unchanged, for `smart_constant_value` to handle.
  if isinstance(pred, int) and pred in (0, 1):
    pred = (pred == 1)

  # Variables are always reported as dynamic.
  if isinstance(pred, variables.Variable):
    return None
  return smart_module.smart_constant_value(pred)
227