1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""Inception-ResNet V2 model for Keras.
17
18
19Reference:
20  - [Inception-v4, Inception-ResNet and the Impact of
21     Residual Connections on Learning](https://arxiv.org/abs/1602.07261)
22    (AAAI 2017)
23"""
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27
28from tensorflow.python.keras import backend
29from tensorflow.python.keras.applications import imagenet_utils
30from tensorflow.python.keras.engine import training
31from tensorflow.python.keras.layers import VersionAwareLayers
32from tensorflow.python.keras.utils import data_utils
33from tensorflow.python.keras.utils import layer_utils
34from tensorflow.python.lib.io import file_io
35from tensorflow.python.util.tf_export import keras_export
36
37
# Remote directory that hosts the pre-trained ImageNet weight files.
BASE_WEIGHT_URL = (
    'https://storage.googleapis.com/tensorflow/keras-applications/'
    'inception_resnet_v2/')
# Layer namespace read by the builder functions in this module. It is a
# placeholder here and gets rebound each time `InceptionResNetV2` is called.
layers = None
41
42
@keras_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
              'keras.applications.InceptionResNetV2')
def InceptionResNetV2(include_top=True,
                      weights='imagenet',
                      input_tensor=None,
                      input_shape=None,
                      pooling=None,
                      classes=1000,
                      classifier_activation='softmax',
                      **kwargs):
  """Builds the Inception-ResNet v2 model.

  Reference:
  - [Inception-v4, Inception-ResNet and the Impact of
     Residual Connections on Learning](https://arxiv.org/abs/1602.07261)
    (AAAI 2017)

  ImageNet pre-trained weights can optionally be loaded. The data format
  convention used by the model is the one specified in your Keras config at
  `~/.keras/keras.json`.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For InceptionResNetV2, call
  `tf.keras.applications.inception_resnet_v2.preprocess_input`
  on your inputs before passing them to the model.

  Args:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is `False` (otherwise the input shape
      has to be `(299, 299, 3)` with the `'channels_last'` data format
      or `(3, 299, 299)` with the `'channels_first'` data format).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 75.
      E.g. `(150, 150, 3)` would be one valid value.
    pooling: optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the last convolutional block.
      - `'avg'` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `'max'` means that global max pooling will be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is `True`, and
      if no `weights` argument is specified.
    classifier_activation: a `str` or callable. The activation function to use
      on the "top" layer. Ignored unless `include_top=True`. Set
      `classifier_activation=None` to return the logits of the "top" layer.
    **kwargs: for backwards compatibility only. The single recognized key is
      `layers`, the object the layer classes are looked up on; when absent,
      a `VersionAwareLayers()` instance is used.

  Returns:
    A `keras.Model` instance.

  Raises:
    ValueError: if `kwargs` holds anything besides `layers`.
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape.
    ValueError: if `weights='imagenet'` and `include_top=True` while
      `classes` is not 1000.
    ValueError: if `classifier_activation` is not `softmax` or `None` when
      using a pretrained top layer.
  """
  global layers
  # `conv2d_bn` and `inception_resnet_block` read `layers` as a module global,
  # so it has to be rebound before any of them is called.
  layers = kwargs.pop('layers') if 'layers' in kwargs else VersionAwareLayers()
  if kwargs:
    raise ValueError('Unknown argument(s): %s' % (kwargs,))

  weights_ok = (
      weights in {'imagenet', None} or file_io.file_exists_v2(weights))
  if not weights_ok:
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Resolve the input shape and the channel axis for the active data format.
  data_format = backend.image_data_format()
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=299,
      min_size=75,
      data_format=data_format,
      require_flatten=include_top,
      weights=weights)
  channel_axis = 1 if data_format == 'channels_first' else 3

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  elif backend.is_keras_tensor(input_tensor):
    img_input = input_tensor
  else:
    # Wrap a non-Keras tensor so it can serve as the model input.
    img_input = layers.Input(tensor=input_tensor, shape=input_shape)

  # NOTE(review): layers below are created in the same order as before;
  # auto-generated layer names (and presumably weight loading) depend on that
  # order, so confirm before reordering any of the calls.

  # Stem block: 35 x 35 x 192
  net = conv2d_bn(img_input, 32, 3, strides=2, padding='valid')
  net = conv2d_bn(net, 32, 3, padding='valid')
  net = conv2d_bn(net, 64, 3)
  net = layers.MaxPooling2D(3, strides=2)(net)
  net = conv2d_bn(net, 80, 1, padding='valid')
  net = conv2d_bn(net, 192, 3, padding='valid')
  net = layers.MaxPooling2D(3, strides=2)(net)

  # Mixed 5b (Inception-A block): 35 x 35 x 320
  tower_0 = conv2d_bn(net, 96, 1)
  tower_1 = conv2d_bn(net, 48, 1)
  tower_1 = conv2d_bn(tower_1, 64, 5)
  tower_2 = conv2d_bn(net, 64, 1)
  tower_2 = conv2d_bn(tower_2, 96, 3)
  tower_2 = conv2d_bn(tower_2, 96, 3)
  tower_pool = layers.AveragePooling2D(3, strides=1, padding='same')(net)
  tower_pool = conv2d_bn(tower_pool, 64, 1)
  net = layers.Concatenate(axis=channel_axis, name='mixed_5b')(
      [tower_0, tower_1, tower_2, tower_pool])

  # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320
  for idx in range(1, 11):
    net = inception_resnet_block(
        net, scale=0.17, block_type='block35', block_idx=idx)

  # Mixed 6a (Reduction-A block): 17 x 17 x 1088
  tower_0 = conv2d_bn(net, 384, 3, strides=2, padding='valid')
  tower_1 = conv2d_bn(net, 256, 1)
  tower_1 = conv2d_bn(tower_1, 256, 3)
  tower_1 = conv2d_bn(tower_1, 384, 3, strides=2, padding='valid')
  tower_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(net)
  net = layers.Concatenate(axis=channel_axis, name='mixed_6a')(
      [tower_0, tower_1, tower_pool])

  # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088
  for idx in range(1, 21):
    net = inception_resnet_block(
        net, scale=0.1, block_type='block17', block_idx=idx)

  # Mixed 7a (Reduction-B block): 8 x 8 x 2080
  tower_0 = conv2d_bn(net, 256, 1)
  tower_0 = conv2d_bn(tower_0, 384, 3, strides=2, padding='valid')
  tower_1 = conv2d_bn(net, 256, 1)
  tower_1 = conv2d_bn(tower_1, 288, 3, strides=2, padding='valid')
  tower_2 = conv2d_bn(net, 256, 1)
  tower_2 = conv2d_bn(tower_2, 288, 3)
  tower_2 = conv2d_bn(tower_2, 320, 3, strides=2, padding='valid')
  tower_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(net)
  net = layers.Concatenate(axis=channel_axis, name='mixed_7a')(
      [tower_0, tower_1, tower_2, tower_pool])

  # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080
  # The first nine blocks share one configuration; the tenth uses scale 1.0
  # and has no trailing activation.
  for idx in range(1, 10):
    net = inception_resnet_block(
        net, scale=0.2, block_type='block8', block_idx=idx)
  net = inception_resnet_block(
      net, scale=1., activation=None, block_type='block8', block_idx=10)

  # Final convolution block: 8 x 8 x 1536
  net = conv2d_bn(net, 1536, 1, name='conv_7b')

  if include_top:
    # Classification block
    net = layers.GlobalAveragePooling2D(name='avg_pool')(net)
    imagenet_utils.validate_activation(classifier_activation, weights)
    net = layers.Dense(
        classes, activation=classifier_activation, name='predictions')(net)
  elif pooling == 'avg':
    net = layers.GlobalAveragePooling2D()(net)
  elif pooling == 'max':
    net = layers.GlobalMaxPooling2D()(net)

  # When an `input_tensor` was supplied, start the model from its source
  # inputs so that any predecessors of that tensor are taken into account.
  if input_tensor is None:
    inputs = img_input
  else:
    inputs = layer_utils.get_source_inputs(input_tensor)

  model = training.Model(inputs, net, name='inception_resnet_v2')

  # Load weights: either the hosted ImageNet file matching `include_top`, or
  # a user-supplied path.
  if weights == 'imagenet':
    if include_top:
      fname = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5'
      file_hash = 'e693bd0210a403b3192acc6073ad2e96'
    else:
      fname = ('inception_resnet_v2_weights_'
               'tf_dim_ordering_tf_kernels_notop.h5')
      file_hash = 'd19885ff4a710c122648d3b5c3b684e4'
    weights_path = data_utils.get_file(
        fname,
        BASE_WEIGHT_URL + fname,
        cache_subdir='models',
        file_hash=file_hash)
    model.load_weights(weights_path)
  elif weights is not None:
    model.load_weights(weights)

  return model
251
252
def conv2d_bn(x,
              filters,
              kernel_size,
              strides=1,
              padding='same',
              activation='relu',
              use_bias=False,
              name=None):
  """Applies a `Conv2D`, then optional batch norm and activation layers.

  Args:
    x: input tensor.
    filters: number of filters passed to `Conv2D`.
    kernel_size: kernel size passed to `Conv2D`.
    strides: strides passed to `Conv2D`.
    padding: padding mode passed to `Conv2D`.
    activation: activation applied by a separate `Activation` layer after the
      convolution (and after batch norm, when present). `None` skips the
      activation layer.
    use_bias: whether `Conv2D` uses a bias. `BatchNormalization` is added
      only when this is `False`.
    name: name of the `Conv2D` layer. The batch norm layer is named
      `name + '_bn'` and the activation layer `name + '_ac'`. When `None`,
      every layer is created with `name=None`.

  Returns:
    Output tensor of the last layer that was applied.
  """
  conv_options = dict(
      strides=strides, padding=padding, use_bias=use_bias, name=name)
  out = layers.Conv2D(filters, kernel_size, **conv_options)(x)

  if not use_bias:
    # Normalize over the channel axis of the active data format.
    channels_first = backend.image_data_format() == 'channels_first'
    out = layers.BatchNormalization(
        axis=1 if channels_first else 3,
        scale=False,
        name=name + '_bn' if name is not None else None)(out)

  if activation is None:
    return out
  return layers.Activation(
      activation, name=name + '_ac' if name is not None else None)(out)
293
294
def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
  """Adds one Inception-ResNet block on top of `x`.

  Three block variants from the paper are supported, selected through
  `block_type` (the block name used in the official TF-slim implementation):
  - Inception-ResNet-A: `block_type='block35'`
  - Inception-ResNet-B: `block_type='block17'`
  - Inception-ResNet-C: `block_type='block8'`

  Args:
    x: input tensor.
    scale: scaling factor applied to the residual (i.e., the output of passing
      `x` through the inception module) before it is added to the shortcut
      branch. With `r` the residual branch output, the block computes
      `x + scale * r`.
    block_type: `'block35'`, `'block17'` or `'block8'`; selects the structure
      of the residual branch.
    block_idx: value used for generating layer names, identifying one
      repetition of a block type. Layer names share the prefix
      `block_type + '_' + str(block_idx)`; for example,
      `block_type='block35', block_idx=1` gives the prefix `'block35_1'`.
    activation: activation function to use at the end of the block (see
      [activations](../activations.md)). When `activation=None`, no activation
      is applied (i.e., "linear" activation: `a(x) = x`).

  Returns:
    Output tensor for the block.

  Raises:
    ValueError: if `block_type` is not one of `'block35'`,
      `'block17'` or `'block8'`.
  """
  # Each tower is a chain of (filters, kernel_size) `conv2d_bn` calls that
  # starts from `x`. Towers, and the layers inside them, are built in the
  # order listed here.
  if block_type == 'block35':
    tower_specs = (
        ((32, 1),),
        ((32, 1), (32, 3)),
        ((32, 1), (48, 3), (64, 3)),
    )
  elif block_type == 'block17':
    tower_specs = (
        ((192, 1),),
        ((128, 1), (160, [1, 7]), (192, [7, 1])),
    )
  elif block_type == 'block8':
    tower_specs = (
        ((192, 1),),
        ((192, 1), (224, [1, 3]), (256, [3, 1])),
    )
  else:
    raise ValueError('Unknown Inception-ResNet block type. '
                     'Expects "block35", "block17" or "block8", '
                     'but got: ' + str(block_type))

  towers = []
  for spec in tower_specs:
    tower = x
    for n_filters, kernel in spec:
      tower = conv2d_bn(tower, n_filters, kernel)
    towers.append(tower)

  prefix = block_type + '_' + str(block_idx)
  channels_first = backend.image_data_format() == 'channels_first'
  channel_axis = 1 if channels_first else 3
  mixed = layers.Concatenate(axis=channel_axis, name=prefix + '_mixed')(towers)

  # Project the concatenated towers back to the channel count of `x` with a
  # biased 1x1 convolution (no batch norm, no activation) so the two tensors
  # can be summed.
  shortcut_shape = backend.int_shape(x)
  up = conv2d_bn(
      mixed,
      shortcut_shape[channel_axis],
      1,
      activation=None,
      use_bias=True,
      name=prefix + '_conv')

  # Scaled residual sum: x + scale * up.
  summed = layers.Lambda(
      lambda inputs, scale: inputs[0] + inputs[1] * scale,
      output_shape=shortcut_shape[1:],
      arguments={'scale': scale},
      name=prefix)([x, up])

  if activation is None:
    return summed
  return layers.Activation(activation, name=prefix + '_ac')(summed)
376
377
@keras_export('keras.applications.inception_resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None):
  # Delegates to the shared ImageNet helper, always in 'tf' mode. The
  # docstring is assigned at module level further down in this file.
  helper_kwargs = {'data_format': data_format, 'mode': 'tf'}
  return imagenet_utils.preprocess_input(x, **helper_kwargs)
381
382
@keras_export('keras.applications.inception_resnet_v2.decode_predictions')
def decode_predictions(preds, top=5):
  # Delegates to the shared ImageNet helper. The docstring is assigned at
  # module level further down in this file.
  decoded = imagenet_utils.decode_predictions(preds, top=top)
  return decoded
386
387
# Attach docstrings to the two public helpers from the shared
# `imagenet_utils` documentation instead of duplicating the text here.
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
    mode='',
    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,
    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF)
393