1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for ImageNet data preprocessing & prediction decoding.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import json 21import warnings 22 23import numpy as np 24 25from tensorflow.python.keras import activations 26from tensorflow.python.keras import backend 27from tensorflow.python.keras.utils import data_utils 28from tensorflow.python.util.tf_export import keras_export 29 30 31CLASS_INDEX = None 32CLASS_INDEX_PATH = ('https://storage.googleapis.com/download.tensorflow.org/' 33 'data/imagenet_class_index.json') 34 35 36PREPROCESS_INPUT_DOC = """ 37 Preprocesses a tensor or Numpy array encoding a batch of images. 38 39 Usage example with `applications.MobileNet`: 40 41 ```python 42 i = tf.keras.layers.Input([None, None, 3], dtype = tf.uint8) 43 x = tf.cast(i, tf.float32) 44 x = tf.keras.applications.mobilenet.preprocess_input(x) 45 core = tf.keras.applications.MobileNet() 46 x = core(x) 47 model = tf.keras.Model(inputs=[i], outputs=[x]) 48 49 image = tf.image.decode_png(tf.io.read_file('file.png')) 50 result = model(image) 51 ``` 52 53 Args: 54 x: A floating point `numpy.array` or a `tf.Tensor`, 3D or 4D with 3 color 55 channels, with values in the range [0, 255]. 56 The preprocessed data are written over the input data 57 if the data types are compatible. To avoid this 58 behaviour, `numpy.copy(x)` can be used. 59 data_format: Optional data format of the image tensor/array. Defaults to 60 None, in which case the global setting 61 `tf.keras.backend.image_data_format()` is used (unless you changed it, 62 it defaults to "channels_last").{mode} 63 64 Returns: 65 Preprocessed `numpy.array` or a `tf.Tensor` with type `float32`. 66 {ret} 67 68 Raises: 69 {error} 70 """ 71 72PREPROCESS_INPUT_MODE_DOC = """ 73 mode: One of "caffe", "tf" or "torch". Defaults to "caffe". 74 - caffe: will convert the images from RGB to BGR, 75 then will zero-center each color channel with 76 respect to the ImageNet dataset, 77 without scaling. 78 - tf: will scale pixels between -1 and 1, 79 sample-wise. 80 - torch: will scale pixels between 0 and 1 and then 81 will normalize each channel with respect to the 82 ImageNet dataset. 83 """ 84 85PREPROCESS_INPUT_DEFAULT_ERROR_DOC = """ 86 ValueError: In case of unknown `mode` or `data_format` argument.""" 87 88PREPROCESS_INPUT_ERROR_DOC = """ 89 ValueError: In case of unknown `data_format` argument.""" 90 91PREPROCESS_INPUT_RET_DOC_TF = """ 92 The inputs pixel values are scaled between -1 and 1, sample-wise.""" 93 94PREPROCESS_INPUT_RET_DOC_TORCH = """ 95 The input pixels values are scaled between 0 and 1 and each channel is 96 normalized with respect to the ImageNet dataset.""" 97 98PREPROCESS_INPUT_RET_DOC_CAFFE = """ 99 The images are converted from RGB to BGR, then each color channel is 100 zero-centered with respect to the ImageNet dataset, without scaling.""" 101 102 103@keras_export('keras.applications.imagenet_utils.preprocess_input') 104def preprocess_input(x, data_format=None, mode='caffe'): 105 """Preprocesses a tensor or Numpy array encoding a batch of images.""" 106 if mode not in {'caffe', 'tf', 'torch'}: 107 raise ValueError('Unknown mode ' + str(mode)) 108 109 if data_format is None: 110 data_format = backend.image_data_format() 111 elif data_format not in {'channels_first', 'channels_last'}: 112 raise ValueError('Unknown data_format ' + str(data_format)) 113 114 if isinstance(x, np.ndarray): 115 return _preprocess_numpy_input( 116 x, data_format=data_format, mode=mode) 117 else: 118 return _preprocess_symbolic_input( 119 x, data_format=data_format, mode=mode) 120 121 122preprocess_input.__doc__ = PREPROCESS_INPUT_DOC.format( 123 mode=PREPROCESS_INPUT_MODE_DOC, 124 ret='', 125 error=PREPROCESS_INPUT_DEFAULT_ERROR_DOC) 126 127 128@keras_export('keras.applications.imagenet_utils.decode_predictions') 129def decode_predictions(preds, top=5): 130 """Decodes the prediction of an ImageNet model. 131 132 Args: 133 preds: Numpy array encoding a batch of predictions. 134 top: Integer, how many top-guesses to return. Defaults to 5. 135 136 Returns: 137 A list of lists of top class prediction tuples 138 `(class_name, class_description, score)`. 139 One list of tuples per sample in batch input. 140 141 Raises: 142 ValueError: In case of invalid shape of the `pred` array 143 (must be 2D). 144 """ 145 global CLASS_INDEX 146 147 if len(preds.shape) != 2 or preds.shape[1] != 1000: 148 raise ValueError('`decode_predictions` expects ' 149 'a batch of predictions ' 150 '(i.e. a 2D array of shape (samples, 1000)). ' 151 'Found array with shape: ' + str(preds.shape)) 152 if CLASS_INDEX is None: 153 fpath = data_utils.get_file( 154 'imagenet_class_index.json', 155 CLASS_INDEX_PATH, 156 cache_subdir='models', 157 file_hash='c2c37ea517e94d9795004a39431a14cb') 158 with open(fpath) as f: 159 CLASS_INDEX = json.load(f) 160 results = [] 161 for pred in preds: 162 top_indices = pred.argsort()[-top:][::-1] 163 result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices] 164 result.sort(key=lambda x: x[2], reverse=True) 165 results.append(result) 166 return results 167 168 169def _preprocess_numpy_input(x, data_format, mode): 170 """Preprocesses a Numpy array encoding a batch of images. 171 172 Args: 173 x: Input array, 3D or 4D. 174 data_format: Data format of the image array. 175 mode: One of "caffe", "tf" or "torch". 176 - caffe: will convert the images from RGB to BGR, 177 then will zero-center each color channel with 178 respect to the ImageNet dataset, 179 without scaling. 180 - tf: will scale pixels between -1 and 1, 181 sample-wise. 182 - torch: will scale pixels between 0 and 1 and then 183 will normalize each channel with respect to the 184 ImageNet dataset. 185 186 Returns: 187 Preprocessed Numpy array. 188 """ 189 if not issubclass(x.dtype.type, np.floating): 190 x = x.astype(backend.floatx(), copy=False) 191 192 if mode == 'tf': 193 x /= 127.5 194 x -= 1. 195 return x 196 elif mode == 'torch': 197 x /= 255. 198 mean = [0.485, 0.456, 0.406] 199 std = [0.229, 0.224, 0.225] 200 else: 201 if data_format == 'channels_first': 202 # 'RGB'->'BGR' 203 if x.ndim == 3: 204 x = x[::-1, ...] 205 else: 206 x = x[:, ::-1, ...] 207 else: 208 # 'RGB'->'BGR' 209 x = x[..., ::-1] 210 mean = [103.939, 116.779, 123.68] 211 std = None 212 213 # Zero-center by mean pixel 214 if data_format == 'channels_first': 215 if x.ndim == 3: 216 x[0, :, :] -= mean[0] 217 x[1, :, :] -= mean[1] 218 x[2, :, :] -= mean[2] 219 if std is not None: 220 x[0, :, :] /= std[0] 221 x[1, :, :] /= std[1] 222 x[2, :, :] /= std[2] 223 else: 224 x[:, 0, :, :] -= mean[0] 225 x[:, 1, :, :] -= mean[1] 226 x[:, 2, :, :] -= mean[2] 227 if std is not None: 228 x[:, 0, :, :] /= std[0] 229 x[:, 1, :, :] /= std[1] 230 x[:, 2, :, :] /= std[2] 231 else: 232 x[..., 0] -= mean[0] 233 x[..., 1] -= mean[1] 234 x[..., 2] -= mean[2] 235 if std is not None: 236 x[..., 0] /= std[0] 237 x[..., 1] /= std[1] 238 x[..., 2] /= std[2] 239 return x 240 241 242def _preprocess_symbolic_input(x, data_format, mode): 243 """Preprocesses a tensor encoding a batch of images. 244 245 Args: 246 x: Input tensor, 3D or 4D. 247 data_format: Data format of the image tensor. 248 mode: One of "caffe", "tf" or "torch". 249 - caffe: will convert the images from RGB to BGR, 250 then will zero-center each color channel with 251 respect to the ImageNet dataset, 252 without scaling. 253 - tf: will scale pixels between -1 and 1, 254 sample-wise. 255 - torch: will scale pixels between 0 and 1 and then 256 will normalize each channel with respect to the 257 ImageNet dataset. 258 259 Returns: 260 Preprocessed tensor. 261 """ 262 if mode == 'tf': 263 x /= 127.5 264 x -= 1. 265 return x 266 elif mode == 'torch': 267 x /= 255. 268 mean = [0.485, 0.456, 0.406] 269 std = [0.229, 0.224, 0.225] 270 else: 271 if data_format == 'channels_first': 272 # 'RGB'->'BGR' 273 if backend.ndim(x) == 3: 274 x = x[::-1, ...] 275 else: 276 x = x[:, ::-1, ...] 277 else: 278 # 'RGB'->'BGR' 279 x = x[..., ::-1] 280 mean = [103.939, 116.779, 123.68] 281 std = None 282 283 mean_tensor = backend.constant(-np.array(mean)) 284 285 # Zero-center by mean pixel 286 if backend.dtype(x) != backend.dtype(mean_tensor): 287 x = backend.bias_add( 288 x, backend.cast(mean_tensor, backend.dtype(x)), data_format=data_format) 289 else: 290 x = backend.bias_add(x, mean_tensor, data_format) 291 if std is not None: 292 x /= std 293 return x 294 295 296def obtain_input_shape(input_shape, 297 default_size, 298 min_size, 299 data_format, 300 require_flatten, 301 weights=None): 302 """Internal utility to compute/validate a model's input shape. 303 304 Args: 305 input_shape: Either None (will return the default network input shape), 306 or a user-provided shape to be validated. 307 default_size: Default input width/height for the model. 308 min_size: Minimum input width/height accepted by the model. 309 data_format: Image data format to use. 310 require_flatten: Whether the model is expected to 311 be linked to a classifier via a Flatten layer. 312 weights: One of `None` (random initialization) 313 or 'imagenet' (pre-training on ImageNet). 314 If weights='imagenet' input channels must be equal to 3. 315 316 Returns: 317 An integer shape tuple (may include None entries). 318 319 Raises: 320 ValueError: In case of invalid argument values. 321 """ 322 if weights != 'imagenet' and input_shape and len(input_shape) == 3: 323 if data_format == 'channels_first': 324 if input_shape[0] not in {1, 3}: 325 warnings.warn('This model usually expects 1 or 3 input channels. ' 326 'However, it was passed an input_shape with ' + 327 str(input_shape[0]) + ' input channels.') 328 default_shape = (input_shape[0], default_size, default_size) 329 else: 330 if input_shape[-1] not in {1, 3}: 331 warnings.warn('This model usually expects 1 or 3 input channels. ' 332 'However, it was passed an input_shape with ' + 333 str(input_shape[-1]) + ' input channels.') 334 default_shape = (default_size, default_size, input_shape[-1]) 335 else: 336 if data_format == 'channels_first': 337 default_shape = (3, default_size, default_size) 338 else: 339 default_shape = (default_size, default_size, 3) 340 if weights == 'imagenet' and require_flatten: 341 if input_shape is not None: 342 if input_shape != default_shape: 343 raise ValueError('When setting `include_top=True` ' 344 'and loading `imagenet` weights, ' 345 '`input_shape` should be ' + str(default_shape) + '.') 346 return default_shape 347 if input_shape: 348 if data_format == 'channels_first': 349 if input_shape is not None: 350 if len(input_shape) != 3: 351 raise ValueError('`input_shape` must be a tuple of three integers.') 352 if input_shape[0] != 3 and weights == 'imagenet': 353 raise ValueError('The input must have 3 channels; got ' 354 '`input_shape=' + str(input_shape) + '`') 355 if ((input_shape[1] is not None and input_shape[1] < min_size) or 356 (input_shape[2] is not None and input_shape[2] < min_size)): 357 raise ValueError('Input size must be at least ' + str(min_size) + 358 'x' + str(min_size) + '; got `input_shape=' + 359 str(input_shape) + '`') 360 else: 361 if input_shape is not None: 362 if len(input_shape) != 3: 363 raise ValueError('`input_shape` must be a tuple of three integers.') 364 if input_shape[-1] != 3 and weights == 'imagenet': 365 raise ValueError('The input must have 3 channels; got ' 366 '`input_shape=' + str(input_shape) + '`') 367 if ((input_shape[0] is not None and input_shape[0] < min_size) or 368 (input_shape[1] is not None and input_shape[1] < min_size)): 369 raise ValueError('Input size must be at least ' + str(min_size) + 370 'x' + str(min_size) + '; got `input_shape=' + 371 str(input_shape) + '`') 372 else: 373 if require_flatten: 374 input_shape = default_shape 375 else: 376 if data_format == 'channels_first': 377 input_shape = (3, None, None) 378 else: 379 input_shape = (None, None, 3) 380 if require_flatten: 381 if None in input_shape: 382 raise ValueError('If `include_top` is True, ' 383 'you should specify a static `input_shape`. ' 384 'Got `input_shape=' + str(input_shape) + '`') 385 return input_shape 386 387 388def correct_pad(inputs, kernel_size): 389 """Returns a tuple for zero-padding for 2D convolution with downsampling. 390 391 Args: 392 inputs: Input tensor. 393 kernel_size: An integer or tuple/list of 2 integers. 394 395 Returns: 396 A tuple. 397 """ 398 img_dim = 2 if backend.image_data_format() == 'channels_first' else 1 399 input_size = backend.int_shape(inputs)[img_dim:(img_dim + 2)] 400 if isinstance(kernel_size, int): 401 kernel_size = (kernel_size, kernel_size) 402 if input_size[0] is None: 403 adjust = (1, 1) 404 else: 405 adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) 406 correct = (kernel_size[0] // 2, kernel_size[1] // 2) 407 return ((correct[0] - adjust[0], correct[0]), 408 (correct[1] - adjust[1], correct[1])) 409 410 411def validate_activation(classifier_activation, weights): 412 """validates that the classifer_activation is compatible with the weights. 413 414 Args: 415 classifier_activation: str or callable activation function 416 weights: The pretrained weights to load. 417 418 Raises: 419 ValueError: if an activation other than `None` or `softmax` are used with 420 pretrained weights. 421 """ 422 if weights is None: 423 return 424 425 classifier_activation = activations.get(classifier_activation) 426 if classifier_activation not in { 427 activations.get('softmax'), 428 activations.get(None) 429 }: 430 raise ValueError('Only `None` and `softmax` activations are allowed ' 431 'for the `classifier_activation` argument when using ' 432 'pretrained weights, with `include_top=True`') 433