1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for ImageNet data preprocessing & prediction decoding."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import json
21import warnings
22
23import numpy as np
24
25from tensorflow.python.keras import activations
26from tensorflow.python.keras import backend
27from tensorflow.python.keras.utils import data_utils
28from tensorflow.python.util.tf_export import keras_export
29
30
# Module-level cache of the ImageNet class index. It stays `None` until
# `decode_predictions` loads the JSON file for the first time.
CLASS_INDEX = None
# Download location of the JSON class index. `decode_predictions` looks
# entries up by the stringified class index and unpacks each entry into
# `(class_name, class_description)`.
CLASS_INDEX_PATH = (
    'https://storage.googleapis.com/download.tensorflow.org/'
    'data/imagenet_class_index.json'
)
34
35
# Shared docstring template for the `preprocess_input` functions. The
# `{mode}`, `{ret}` and `{error}` placeholders are filled in with `str.format`,
# so the template must not contain any other literal braces.
PREPROCESS_INPUT_DOC = """
  Preprocesses a tensor or Numpy array encoding a batch of images.

  Usage example with `applications.MobileNet`:

  ```python
  i = tf.keras.layers.Input([None, None, 3], dtype=tf.uint8)
  x = tf.cast(i, tf.float32)
  x = tf.keras.applications.mobilenet.preprocess_input(x)
  core = tf.keras.applications.MobileNet()
  x = core(x)
  model = tf.keras.Model(inputs=[i], outputs=[x])

  image = tf.image.decode_png(tf.io.read_file('file.png'), channels=3)
  # The model expects a batch of images, so add a leading batch dimension.
  result = model(tf.expand_dims(image, axis=0))
  ```

  Args:
    x: A floating point `numpy.array` or a `tf.Tensor`, 3D or 4D with 3 color
      channels, with values in the range [0, 255].
      The preprocessed data are written over the input data
      if the data types are compatible. To avoid this
      behaviour, `numpy.copy(x)` can be used.
    data_format: Optional data format of the image tensor/array. Defaults to
      None, in which case the global setting
      `tf.keras.backend.image_data_format()` is used (unless you changed it,
      it defaults to "channels_last").{mode}

  Returns:
      Preprocessed `numpy.array` or a `tf.Tensor` with type `float32`.
      {ret}

  Raises:
      {error}
  """

# Description of the `mode` argument, substituted for `{mode}` above.
PREPROCESS_INPUT_MODE_DOC = """
    mode: One of "caffe", "tf" or "torch". Defaults to "caffe".
      - caffe: will convert the images from RGB to BGR,
          then will zero-center each color channel with
          respect to the ImageNet dataset,
          without scaling.
      - tf: will scale pixels between -1 and 1,
          sample-wise.
      - torch: will scale pixels between 0 and 1 and then
          will normalize each channel with respect to the
          ImageNet dataset.
  """

# `{error}` text for variants that accept a `mode` argument.
PREPROCESS_INPUT_DEFAULT_ERROR_DOC = """
    ValueError: In case of unknown `mode` or `data_format` argument."""

# `{error}` text for variants with a fixed mode.
PREPROCESS_INPUT_ERROR_DOC = """
    ValueError: In case of unknown `data_format` argument."""

# `{ret}` texts describing the output of each fixed preprocessing mode.
PREPROCESS_INPUT_RET_DOC_TF = """
      The input pixel values are scaled between -1 and 1, sample-wise."""

PREPROCESS_INPUT_RET_DOC_TORCH = """
      The input pixel values are scaled between 0 and 1 and each channel is
      normalized with respect to the ImageNet dataset."""

PREPROCESS_INPUT_RET_DOC_CAFFE = """
      The images are converted from RGB to BGR, then each color channel is
      zero-centered with respect to the ImageNet dataset, without scaling."""
101
102
@keras_export('keras.applications.imagenet_utils.preprocess_input')
def preprocess_input(x, data_format=None, mode='caffe'):
  """Preprocesses a tensor or Numpy array encoding a batch of images."""
  # Validate arguments before touching the data.
  if mode not in {'caffe', 'tf', 'torch'}:
    raise ValueError('Unknown mode ' + str(mode))

  if data_format is None:
    # Fall back to the global Keras image data format.
    data_format = backend.image_data_format()
  elif data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format ' + str(data_format))

  # NumPy arrays are processed (in place where possible) with NumPy ops;
  # anything else is treated as a tensor and processed with backend ops.
  preprocess_fn = (_preprocess_numpy_input if isinstance(x, np.ndarray)
                   else _preprocess_symbolic_input)
  return preprocess_fn(x, data_format=data_format, mode=mode)


# Replace the short inline docstring with the full shared template.
preprocess_input.__doc__ = PREPROCESS_INPUT_DOC.format(
    mode=PREPROCESS_INPUT_MODE_DOC,
    ret='',
    error=PREPROCESS_INPUT_DEFAULT_ERROR_DOC)
126
127
@keras_export('keras.applications.imagenet_utils.decode_predictions')
def decode_predictions(preds, top=5):
  """Decodes the prediction of an ImageNet model.

  Args:
    preds: Numpy array (or an array-like accepted by `np.asarray`, such as an
      eager tensor) encoding a batch of predictions, of shape
      `(samples, 1000)`.
    top: Integer, how many top-guesses to return. Defaults to 5. With
      `top=0` an empty list is returned for every sample.

  Returns:
    A list of lists of top class prediction tuples
    `(class_name, class_description, score)`, ordered by decreasing score.
    One list of tuples per sample in batch input.

  Raises:
    ValueError: In case of invalid shape of the `preds` array
      (must be 2D with 1000 columns), or if `top` is negative.
  """
  global CLASS_INDEX

  # No-op for ndarrays; lets array-likes (e.g. eager tensors) be sorted and
  # indexed with NumPy methods below.
  preds = np.asarray(preds)
  if len(preds.shape) != 2 or preds.shape[1] != 1000:
    raise ValueError('`decode_predictions` expects '
                     'a batch of predictions '
                     '(i.e. a 2D array of shape (samples, 1000)). '
                     'Found array with shape: ' + str(preds.shape))
  if top < 0:
    raise ValueError('`top` must be a non-negative integer. Received: ' +
                     str(top))
  if CLASS_INDEX is None:
    # Fetch the hash-checked class index file once (stored under the
    # 'models' cache subdirectory) and keep it in the module-level cache.
    fpath = data_utils.get_file(
        'imagenet_class_index.json',
        CLASS_INDEX_PATH,
        cache_subdir='models',
        file_hash='c2c37ea517e94d9795004a39431a14cb')
    with open(fpath) as f:
      CLASS_INDEX = json.load(f)
  results = []
  for pred in preds:
    # `argsort` is ascending: reverse it, then keep the first `top` indices.
    # Unlike `argsort()[-top:]`, this yields no entries (not all of them)
    # when `top == 0`.
    top_indices = pred.argsort()[::-1][:top]
    result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices]
    result.sort(key=lambda x: x[2], reverse=True)
    results.append(result)
  return results
167
168
169def _preprocess_numpy_input(x, data_format, mode):
170  """Preprocesses a Numpy array encoding a batch of images.
171
172  Args:
173    x: Input array, 3D or 4D.
174    data_format: Data format of the image array.
175    mode: One of "caffe", "tf" or "torch".
176      - caffe: will convert the images from RGB to BGR,
177          then will zero-center each color channel with
178          respect to the ImageNet dataset,
179          without scaling.
180      - tf: will scale pixels between -1 and 1,
181          sample-wise.
182      - torch: will scale pixels between 0 and 1 and then
183          will normalize each channel with respect to the
184          ImageNet dataset.
185
186  Returns:
187      Preprocessed Numpy array.
188  """
189  if not issubclass(x.dtype.type, np.floating):
190    x = x.astype(backend.floatx(), copy=False)
191
192  if mode == 'tf':
193    x /= 127.5
194    x -= 1.
195    return x
196  elif mode == 'torch':
197    x /= 255.
198    mean = [0.485, 0.456, 0.406]
199    std = [0.229, 0.224, 0.225]
200  else:
201    if data_format == 'channels_first':
202      # 'RGB'->'BGR'
203      if x.ndim == 3:
204        x = x[::-1, ...]
205      else:
206        x = x[:, ::-1, ...]
207    else:
208      # 'RGB'->'BGR'
209      x = x[..., ::-1]
210    mean = [103.939, 116.779, 123.68]
211    std = None
212
213  # Zero-center by mean pixel
214  if data_format == 'channels_first':
215    if x.ndim == 3:
216      x[0, :, :] -= mean[0]
217      x[1, :, :] -= mean[1]
218      x[2, :, :] -= mean[2]
219      if std is not None:
220        x[0, :, :] /= std[0]
221        x[1, :, :] /= std[1]
222        x[2, :, :] /= std[2]
223    else:
224      x[:, 0, :, :] -= mean[0]
225      x[:, 1, :, :] -= mean[1]
226      x[:, 2, :, :] -= mean[2]
227      if std is not None:
228        x[:, 0, :, :] /= std[0]
229        x[:, 1, :, :] /= std[1]
230        x[:, 2, :, :] /= std[2]
231  else:
232    x[..., 0] -= mean[0]
233    x[..., 1] -= mean[1]
234    x[..., 2] -= mean[2]
235    if std is not None:
236      x[..., 0] /= std[0]
237      x[..., 1] /= std[1]
238      x[..., 2] /= std[2]
239  return x
240
241
def _preprocess_symbolic_input(x, data_format, mode):
  """Preprocesses a tensor encoding a batch of images.

  Args:
    x: Input tensor, 3D or 4D.
    data_format: Data format of the image tensor.
    mode: One of "caffe", "tf" or "torch".
      - caffe: will convert the images from RGB to BGR,
          then will zero-center each color channel with
          respect to the ImageNet dataset,
          without scaling.
      - tf: will scale pixels between -1 and 1,
          sample-wise.
      - torch: will scale pixels between 0 and 1 and then
          will normalize each channel with respect to the
          ImageNet dataset.

  Returns:
      Preprocessed tensor.
  """
  if mode == 'tf':
    x /= 127.5
    x -= 1.
    return x
  elif mode == 'torch':
    x /= 255.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
  else:
    if data_format == 'channels_first':
      # 'RGB'->'BGR'
      if backend.ndim(x) == 3:
        x = x[::-1, ...]
      else:
        x = x[:, ::-1, ...]
    else:
      # 'RGB'->'BGR'
      x = x[..., ::-1]
    mean = [103.939, 116.779, 123.68]
    std = None

  mean_tensor = backend.constant(-np.array(mean))

  # Zero-center by mean pixel: add the negated per-channel mean along the
  # channel axis selected by `data_format`, in the input's dtype.
  # NOTE(review): channels_first `bias_add` presumably uses an 'NC...' layout
  # (channel axis 1); for a 3D channels_first input that would be the height
  # axis — confirm whether 3D channels_first tensors need special handling.
  if backend.dtype(x) != backend.dtype(mean_tensor):
    x = backend.bias_add(
        x, backend.cast(mean_tensor, backend.dtype(x)), data_format=data_format)
  else:
    x = backend.bias_add(x, mean_tensor, data_format)
  if std is not None:
    # Dividing by a plain list broadcasts along the LAST axis, which is the
    # channel axis only for channels_last data. Build the std as a tensor in
    # the input's dtype and, for channels_first, shape it (C, 1, 1) so it
    # broadcasts along the channel axis of both 3D and 4D inputs.
    std_tensor = backend.constant(np.array(std), dtype=backend.dtype(x))
    if data_format == 'channels_first':
      std_tensor = backend.reshape(std_tensor, (-1, 1, 1))
    x /= std_tensor
  return x
294
295
def obtain_input_shape(input_shape,
                       default_size,
                       min_size,
                       data_format,
                       require_flatten,
                       weights=None):
  """Internal utility to compute/validate a model's input shape.

  Args:
    input_shape: Either None (will return the default network input shape),
      or a user-provided shape (tuple or list) to be validated.
    default_size: Default input width/height for the model.
    min_size: Minimum input width/height accepted by the model.
    data_format: Image data format to use.
    require_flatten: Whether the model is expected to
      be linked to a classifier via a Flatten layer.
    weights: One of `None` (random initialization)
      or 'imagenet' (pre-training on ImageNet).
      If weights='imagenet' input channels must be equal to 3.

  Returns:
    An integer shape tuple (may include None entries). A user-provided
    `input_shape` that passes validation is returned as given, except in the
    `weights='imagenet'` + `require_flatten` case, which always returns the
    default shape.

  Raises:
    ValueError: In case of invalid argument values.
  """
  channels_first = data_format == 'channels_first'

  # Default shape: without ImageNet weights, keep the user's channel count
  # (warning if it is unusual); otherwise assume 3 channels.
  if weights != 'imagenet' and input_shape and len(input_shape) == 3:
    user_channels = input_shape[0] if channels_first else input_shape[-1]
    if user_channels not in {1, 3}:
      warnings.warn('This model usually expects 1 or 3 input channels. '
                    'However, it was passed an input_shape with ' +
                    str(user_channels) + ' input channels.')
    if channels_first:
      default_shape = (user_channels, default_size, default_size)
    else:
      default_shape = (default_size, default_size, user_channels)
  elif channels_first:
    default_shape = (3, default_size, default_size)
  else:
    default_shape = (default_size, default_size, 3)

  if weights == 'imagenet' and require_flatten:
    if input_shape is not None:
      # Compare lists as tuples so e.g. [224, 224, 3] is not rejected
      # merely for being a list.
      given_shape = (tuple(input_shape) if isinstance(input_shape, list)
                     else input_shape)
      if given_shape != default_shape:
        raise ValueError('When setting `include_top=True` '
                         'and loading `imagenet` weights, '
                         '`input_shape` should be ' + str(default_shape) + '.')
    return default_shape

  if input_shape:
    if len(input_shape) != 3:
      raise ValueError('`input_shape` must be a tuple of three integers.')
    if channels_first:
      channels = input_shape[0]
      spatial = (input_shape[1], input_shape[2])
    else:
      channels = input_shape[-1]
      spatial = (input_shape[0], input_shape[1])
    if channels != 3 and weights == 'imagenet':
      raise ValueError('The input must have 3 channels; got '
                       '`input_shape=' + str(input_shape) + '`')
    # Unknown (None) spatial sizes are allowed; known ones must be large
    # enough for the model.
    if any(size is not None and size < min_size for size in spatial):
      raise ValueError('Input size must be at least ' + str(min_size) +
                       'x' + str(min_size) + '; got `input_shape=' +
                       str(input_shape) + '`')
  elif require_flatten:
    input_shape = default_shape
  else:
    input_shape = (3, None, None) if channels_first else (None, None, 3)

  # A Flatten-based classifier head needs fully static spatial dimensions.
  if require_flatten and None in input_shape:
    raise ValueError('If `include_top` is True, '
                     'you should specify a static `input_shape`. '
                     'Got `input_shape=' + str(input_shape) + '`')
  return input_shape
386
387
def correct_pad(inputs, kernel_size):
  """Returns a tuple for zero-padding for 2D convolution with downsampling.

  For each spatial dimension the padding is `(k // 2 - adjust, k // 2)`,
  where `k` is the kernel size and `adjust` is 0 for an odd input size and 1
  for an even one. If either spatial size is unknown (None), `adjust` is 1 for
  both dimensions.

  The spatial axes are located using the global
  `backend.image_data_format()`, not a per-call argument.

  Args:
    inputs: Input tensor.
    kernel_size: An integer or tuple/list of 2 integers.

  Returns:
    A tuple `((before_0, after_0), (before_1, after_1))` with the padding for
    the two spatial dimensions (presumably consumed by `ZeroPadding2D`).
  """
  # Spatial dims follow the batch (and, for channels_first, channel) axes.
  img_dim = 2 if backend.image_data_format() == 'channels_first' else 1
  input_size = backend.int_shape(inputs)[img_dim:(img_dim + 2)]
  if isinstance(kernel_size, int):
    kernel_size = (kernel_size, kernel_size)
  # Previously only the first size was checked for None, so a known first
  # size with an unknown second size crashed on `None % 2`.
  if input_size[0] is None or input_size[1] is None:
    adjust = (1, 1)
  else:
    adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
  correct = (kernel_size[0] // 2, kernel_size[1] // 2)
  return ((correct[0] - adjust[0], correct[0]),
          (correct[1] - adjust[1], correct[1]))
409
410
def validate_activation(classifier_activation, weights):
  """Validates that `classifier_activation` is compatible with `weights`.

  When `weights` is None every activation is accepted. Otherwise the
  activation, resolved through `activations.get`, must match the resolved
  `softmax` or `None` (linear) activation.

  Args:
    classifier_activation: str or callable activation function.
    weights: The pretrained weights to load.

  Raises:
    ValueError: if an activation other than `None` or `softmax` is used with
      pretrained weights.
  """
  if weights is not None:
    resolved = activations.get(classifier_activation)
    permitted = {activations.get('softmax'), activations.get(None)}
    if resolved not in permitted:
      raise ValueError('Only `None` and `softmax` activations are allowed '
                       'for the `classifier_activation` argument when using '
                       'pretrained weights, with `include_top=True`')
433