# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Inception-ResNet V2 model for Keras.


Reference:
  - [Inception-v4, Inception-ResNet and the Impact of
     Residual Connections on Learning](https://arxiv.org/abs/1602.07261)
    (AAAI 2017)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras import backend
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.lib.io import file_io
from tensorflow.python.util.tf_export import keras_export


BASE_WEIGHT_URL = ('https://storage.googleapis.com/tensorflow/'
                   'keras-applications/inception_resnet_v2/')
# Layer namespace shared by the builder helpers below. It is assigned by
# `InceptionResNetV2` (from the `layers` kwarg, or `VersionAwareLayers()`)
# before `conv2d_bn` / `inception_resnet_block` create any layers.
layers = None


@keras_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
              'keras.applications.InceptionResNetV2')
def InceptionResNetV2(include_top=True,
                      weights='imagenet',
                      input_tensor=None,
                      input_shape=None,
                      pooling=None,
                      classes=1000,
                      classifier_activation='softmax',
                      **kwargs):
  """Instantiates the Inception-ResNet v2 architecture.

  Reference:
  - [Inception-v4, Inception-ResNet and the Impact of
     Residual Connections on Learning](https://arxiv.org/abs/1602.07261)
    (AAAI 2017)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For InceptionResNetV2, call
  `tf.keras.applications.inception_resnet_v2.preprocess_input`
  on your inputs before passing them to the model.

  Args:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is `False` (otherwise the input shape
      has to be `(299, 299, 3)` (with `'channels_last'` data format)
      or `(3, 299, 299)` (with `'channels_first'` data format).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 75.
      E.g. `(150, 150, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the last convolutional block.
      - `'avg'` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `'max'` means that global max pooling will be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is `True`, and
      if no `weights` argument is specified.
    classifier_activation: A `str` or callable. The activation function to use
      on the "top" layer. Ignored unless `include_top=True`. Set
      `classifier_activation=None` to return the logits of the "top" layer.
    **kwargs: For backwards compatibility only. The only accepted key is
      `layers`, which overrides the layer namespace used to build the model.

  Returns:
    A `keras.Model` instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape, or unknown keyword arguments.
    ValueError: if `classifier_activation` is not `softmax` or `None` when
      using a pretrained top layer.
  """
  # The helper builders read the module-level `layers`, so it must be set
  # before any layer is created below.
  global layers
  if 'layers' in kwargs:
    layers = kwargs.pop('layers')
  else:
    layers = VersionAwareLayers()
  if kwargs:
    raise ValueError('Unknown argument(s): %s' % (kwargs,))
  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=299,
      min_size=75,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  # Stem block: 35 x 35 x 192
  x = conv2d_bn(img_input, 32, 3, strides=2, padding='valid')
  x = conv2d_bn(x, 32, 3, padding='valid')
  x = conv2d_bn(x, 64, 3)
  x = layers.MaxPooling2D(3, strides=2)(x)
  x = conv2d_bn(x, 80, 1, padding='valid')
  x = conv2d_bn(x, 192, 3, padding='valid')
  x = layers.MaxPooling2D(3, strides=2)(x)

  # Mixed 5b (Inception-A block): 35 x 35 x 320
  branch_0 = conv2d_bn(x, 96, 1)
  branch_1 = conv2d_bn(x, 48, 1)
  branch_1 = conv2d_bn(branch_1, 64, 5)
  branch_2 = conv2d_bn(x, 64, 1)
  branch_2 = conv2d_bn(branch_2, 96, 3)
  branch_2 = conv2d_bn(branch_2, 96, 3)
  branch_pool = layers.AveragePooling2D(3, strides=1, padding='same')(x)
  branch_pool = conv2d_bn(branch_pool, 64, 1)
  branches = [branch_0, branch_1, branch_2, branch_pool]
  channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3
  x = layers.Concatenate(axis=channel_axis, name='mixed_5b')(branches)

  # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320
  for block_idx in range(1, 11):
    x = inception_resnet_block(
        x, scale=0.17, block_type='block35', block_idx=block_idx)

  # Mixed 6a (Reduction-A block): 17 x 17 x 1088
  branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='valid')
  branch_1 = conv2d_bn(x, 256, 1)
  branch_1 = conv2d_bn(branch_1, 256, 3)
  branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='valid')
  branch_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x)
  branches = [branch_0, branch_1, branch_pool]
  x = layers.Concatenate(axis=channel_axis, name='mixed_6a')(branches)

  # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088
  for block_idx in range(1, 21):
    x = inception_resnet_block(
        x, scale=0.1, block_type='block17', block_idx=block_idx)

  # Mixed 7a (Reduction-B block): 8 x 8 x 2080
  branch_0 = conv2d_bn(x, 256, 1)
  branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='valid')
  branch_1 = conv2d_bn(x, 256, 1)
  branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='valid')
  branch_2 = conv2d_bn(x, 256, 1)
  branch_2 = conv2d_bn(branch_2, 288, 3)
  branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='valid')
  branch_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x)
  branches = [branch_0, branch_1, branch_2, branch_pool]
  x = layers.Concatenate(axis=channel_axis, name='mixed_7a')(branches)

  # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080
  # Blocks 1-9 use scale 0.2 with a ReLU; the 10th (last) block adds the
  # unscaled residual and applies no activation.
  for block_idx in range(1, 10):
    x = inception_resnet_block(
        x, scale=0.2, block_type='block8', block_idx=block_idx)
  x = inception_resnet_block(
      x, scale=1., activation=None, block_type='block8', block_idx=10)

  # Final convolution block: 8 x 8 x 1536
  x = conv2d_bn(x, 1536, 1, name='conv_7b')

  if include_top:
    # Classification block
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    imagenet_utils.validate_activation(classifier_activation, weights)
    x = layers.Dense(classes, activation=classifier_activation,
                     name='predictions')(x)
  else:
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D()(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input

  # Create model.
  model = training.Model(inputs, x, name='inception_resnet_v2')

  # Load weights.
  if weights == 'imagenet':
    if include_top:
      fname = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5'
      weights_path = data_utils.get_file(
          fname,
          BASE_WEIGHT_URL + fname,
          cache_subdir='models',
          file_hash='e693bd0210a403b3192acc6073ad2e96')
    else:
      fname = ('inception_resnet_v2_weights_'
               'tf_dim_ordering_tf_kernels_notop.h5')
      weights_path = data_utils.get_file(
          fname,
          BASE_WEIGHT_URL + fname,
          cache_subdir='models',
          file_hash='d19885ff4a710c122648d3b5c3b684e4')
    model.load_weights(weights_path)
  elif weights is not None:
    model.load_weights(weights)

  return model


def conv2d_bn(x,
              filters,
              kernel_size,
              strides=1,
              padding='same',
              activation='relu',
              use_bias=False,
              name=None):
  """Utility function to apply conv + BN (+ activation).

  Args:
    x: input tensor.
    filters: filters in `Conv2D`.
    kernel_size: kernel size as in `Conv2D`.
    strides: strides in `Conv2D`.
    padding: padding mode in `Conv2D`.
    activation: activation applied by a separate `Activation` layer after the
      convolution (and batch norm, if any). `None` applies no activation.
    use_bias: whether to use a bias in `Conv2D`. When `True`, the batch
      normalization layer is skipped.
    name: name of the `Conv2D` layer; will become `name + '_ac'` for the
      activation and `name + '_bn'` for the batch norm layer.

  Returns:
    Output tensor after applying `Conv2D`, then `BatchNormalization` (only
    when `use_bias=False`), then the activation (unless `activation=None`).
  """
  x = layers.Conv2D(
      filters,
      kernel_size,
      strides=strides,
      padding=padding,
      use_bias=use_bias,
      name=name)(
          x)
  # A conv bias is redundant in front of batch norm, so BN is only added for
  # bias-free convolutions.
  if not use_bias:
    bn_axis = 1 if backend.image_data_format() == 'channels_first' else 3
    bn_name = None if name is None else name + '_bn'
    x = layers.BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)
  if activation is not None:
    ac_name = None if name is None else name + '_ac'
    x = layers.Activation(activation, name=ac_name)(x)
  return x


def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
  """Adds an Inception-ResNet block.

  This function builds 3 types of Inception-ResNet blocks mentioned
  in the paper, controlled by the `block_type` argument (which is the
  block name used in the official TF-slim implementation):
  - Inception-ResNet-A: `block_type='block35'`
  - Inception-ResNet-B: `block_type='block17'`
  - Inception-ResNet-C: `block_type='block8'`

  Args:
    x: input tensor.
    scale: scaling factor to scale the residuals (i.e., the output of passing
      `x` through an inception module) before adding them to the shortcut
      branch. Let `r` be the output from the residual branch, the output of this
      block will be `x + scale * r`.
    block_type: `'block35'`, `'block17'` or `'block8'`, determines the network
      structure in the residual branch.
    block_idx: an `int` used for generating layer names. The Inception-ResNet
      blocks are repeated many times in this network. We use `block_idx` to
      identify each of the repetitions. For example, in `InceptionResNetV2`
      the first Inception-ResNet-A block has `block_type='block35',
      block_idx=1`, so its layer names share the common prefix `'block35_1'`.
    activation: activation function to use at the end of the block (any value
      accepted by `layers.Activation`). When `activation=None`, no activation
      is applied (i.e., "linear" activation: `a(x) = x`).

  Returns:
    Output tensor for the block.

  Raises:
    ValueError: if `block_type` is not one of `'block35'`,
      `'block17'` or `'block8'`.
  """
  if block_type == 'block35':
    branch_0 = conv2d_bn(x, 32, 1)
    branch_1 = conv2d_bn(x, 32, 1)
    branch_1 = conv2d_bn(branch_1, 32, 3)
    branch_2 = conv2d_bn(x, 32, 1)
    branch_2 = conv2d_bn(branch_2, 48, 3)
    branch_2 = conv2d_bn(branch_2, 64, 3)
    branches = [branch_0, branch_1, branch_2]
  elif block_type == 'block17':
    branch_0 = conv2d_bn(x, 192, 1)
    branch_1 = conv2d_bn(x, 128, 1)
    branch_1 = conv2d_bn(branch_1, 160, [1, 7])
    branch_1 = conv2d_bn(branch_1, 192, [7, 1])
    branches = [branch_0, branch_1]
  elif block_type == 'block8':
    branch_0 = conv2d_bn(x, 192, 1)
    branch_1 = conv2d_bn(x, 192, 1)
    branch_1 = conv2d_bn(branch_1, 224, [1, 3])
    branch_1 = conv2d_bn(branch_1, 256, [3, 1])
    branches = [branch_0, branch_1]
  else:
    raise ValueError('Unknown Inception-ResNet block type. '
                     'Expects "block35", "block17" or "block8", '
                     'but got: ' + str(block_type))

  block_name = block_type + '_' + str(block_idx)
  channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3
  mixed = layers.Concatenate(
      axis=channel_axis, name=block_name + '_mixed')(
          branches)
  # 1x1 projection back to the input's channel count so the residual can be
  # added element-wise to the shortcut `x`.
  up = conv2d_bn(
      mixed,
      backend.int_shape(x)[channel_axis],
      1,
      activation=None,
      use_bias=True,
      name=block_name + '_conv')

  # Scaled residual connection: x + scale * up.
  x = layers.Lambda(
      lambda inputs, scale: inputs[0] + inputs[1] * scale,
      output_shape=backend.int_shape(x)[1:],
      arguments={'scale': scale},
      name=block_name)([x, up])
  if activation is not None:
    x = layers.Activation(activation, name=block_name + '_ac')(x)
  return x


# The docstrings of `preprocess_input` and `decode_predictions` are assigned
# from shared `imagenet_utils` templates at the bottom of this module.
@keras_export('keras.applications.inception_resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None):
  return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')


@keras_export('keras.applications.inception_resnet_v2.decode_predictions')
def decode_predictions(preds, top=5):
  return imagenet_utils.decode_predictions(preds, top=top)


preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
    mode='',
    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__