# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Contains layer utilities for input validation and format conversion."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import smart_cond as smart_module
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables


def convert_data_format(data_format, ndim):
  """Maps a `channels_*` data format and an input rank to a layout string.

  Args:
    data_format: Either 'channels_last' or 'channels_first'.
    ndim: Rank of the input; 3, 4 and 5 are supported.

  Returns:
    One of 'NWC', 'NHWC', 'NDHWC' (for 'channels_last') or 'NCW', 'NCHW',
    'NCDHW' (for 'channels_first'), selected by `ndim`.

  Raises:
    ValueError: If `data_format` is not one of the two accepted strings, or
      if `ndim` is not 3, 4 or 5.
  """
  if data_format == 'channels_last':
    if ndim == 3:
      return 'NWC'
    elif ndim == 4:
      return 'NHWC'
    elif ndim == 5:
      return 'NDHWC'
    else:
      raise ValueError('Input rank not supported:', ndim)
  elif data_format == 'channels_first':
    if ndim == 3:
      return 'NCW'
    elif ndim == 4:
      return 'NCHW'
    elif ndim == 5:
      return 'NCDHW'
    else:
      raise ValueError('Input rank not supported:', ndim)
  else:
    raise ValueError('Invalid data_format:', data_format)


def normalize_tuple(value, n, name):
  """Transforms a single integer or iterable of integers into an integer tuple.

  Args:
    value: The value to validate and convert. Could be an int, or any iterable
      of ints.
    n: The size of the tuple to be returned.
    name: The name of the argument being validated, e.g. "strides" or
      "kernel_size". This is only used to format error messages.

  Returns:
    A tuple of n integers.

  Raises:
    ValueError: If something else than an int/long or iterable thereof was
      passed.
  """
  if isinstance(value, int):
    return (value,) * n
  else:
    try:
      value_tuple = tuple(value)
    except TypeError:
      raise ValueError('The `' + name + '` argument must be a tuple of ' +
                       str(n) + ' integers. Received: ' + str(value))
    if len(value_tuple) != n:
      raise ValueError('The `' + name + '` argument must be a tuple of ' +
                       str(n) + ' integers. Received: ' + str(value))
    for single_value in value_tuple:
      # Elements are only checked for int-convertibility; the tuple is
      # returned with its original elements, not the converted ones.
      try:
        int(single_value)
      except (ValueError, TypeError):
        raise ValueError('The `' + name + '` argument must be a tuple of ' +
                         str(n) + ' integers. Received: ' + str(value) + ' '
                         'including element ' + str(single_value) + ' of type' +
                         ' ' + str(type(single_value)))
    return value_tuple


def normalize_data_format(value):
  """Lower-cases and validates a `data_format` argument.

  Args:
    value: A string; case-insensitively 'channels_first' or 'channels_last'.

  Returns:
    The lower-cased data format string.

  Raises:
    ValueError: If the lower-cased value is not 'channels_first' or
      'channels_last'.
  """
  data_format = value.lower()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('The `data_format` argument must be one of '
                     '"channels_first", "channels_last". Received: ' +
                     str(value))
  return data_format


def normalize_padding(value):
  """Lower-cases and validates a `padding` argument.

  Args:
    value: A string; case-insensitively 'valid' or 'same'.

  Returns:
    The lower-cased padding string.

  Raises:
    ValueError: If the lower-cased value is not 'valid' or 'same'.
  """
  padding = value.lower()
  if padding not in {'valid', 'same'}:
    # Report the argument exactly as received (not the lower-cased copy), in
    # line with `normalize_data_format`.
    raise ValueError('The `padding` argument must be one of "valid", "same". '
                     'Received: ' + str(value))
  return padding


def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
  """Determines output length of a convolution given input length.

  Args:
    input_length: integer, or None if unknown.
    filter_size: integer.
    padding: one of "same", "valid", "full".
    stride: integer.
    dilation: dilation rate, integer.

  Returns:
    The output length (integer), or None if `input_length` is None.
  """
  if input_length is None:
    return None
  assert padding in {'same', 'valid', 'full'}
  # Effective extent of the filter once gaps of `dilation - 1` are inserted
  # between its taps.
  dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
  if padding == 'same':
    output_length = input_length
  elif padding == 'valid':
    output_length = input_length - dilated_filter_size + 1
  elif padding == 'full':
    output_length = input_length + dilated_filter_size - 1
  # Ceiling division by `stride`.
  return (output_length + stride - 1) // stride


def conv_input_length(output_length, filter_size, padding, stride):
  """Determines input length of a convolution given output length.

  Args:
    output_length: integer, or None if unknown.
    filter_size: integer.
    padding: one of "same", "valid", "full".
    stride: integer.

  Returns:
    The input length (integer), or None if `output_length` is None.
  """
  if output_length is None:
    return None
  assert padding in {'same', 'valid', 'full'}
  if padding == 'same':
    pad = filter_size // 2
  elif padding == 'valid':
    pad = 0
  elif padding == 'full':
    pad = filter_size - 1
  return (output_length - 1) * stride - 2 * pad + filter_size


def deconv_output_length(input_length, filter_size, padding, stride):
  """Determines output length of a transposed convolution given input length.

  Args:
    input_length: integer, or None if unknown.
    filter_size: integer.
    padding: one of "same", "valid", "full".
    stride: integer.

  Returns:
    The output length (integer), or None if `input_length` is None. For
    "valid" this is `input_length * stride + max(filter_size - stride, 0)`,
    for "full" it is `input_length * stride - (stride + filter_size - 2)`,
    and for any other padding value it is `input_length * stride`.
  """
  if input_length is None:
    return None
  input_length *= stride
  if padding == 'valid':
    input_length += max(filter_size - stride, 0)
  elif padding == 'full':
    input_length -= (stride + filter_size - 2)
  return input_length


def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

  If `pred` is a bool or has a constant value, we return either `true_fn()`
  or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  # A Variable predicate always goes through `cond`; it is never handed to
  # the smart_cond module for static evaluation.
  if isinstance(pred, variables.Variable):
    return control_flow_ops.cond(
        pred, true_fn=true_fn, false_fn=false_fn, name=name)
  return smart_module.smart_cond(
      pred, true_fn=true_fn, false_fn=false_fn, name=name)


def constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Args:
    pred: A scalar, either a Python bool or a TensorFlow boolean variable
      or tensor, or the Python integer 1 or 0.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.

  Raises:
    TypeError: If `pred` is not a Variable, Tensor or bool, or Python
      integer 1 or 0.
  """
  # Allow integer booleans. Any other integer is passed through unchanged to
  # `smart_constant_value` below.
  if isinstance(pred, int):
    if pred == 1:
      pred = True
    elif pred == 0:
      pred = False

  # A Variable is always reported as having no constant value.
  if isinstance(pred, variables.Variable):
    return None
  return smart_module.smart_constant_value(pred)