# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""MobileNet v2 models for Keras.

MobileNetV2 is a general architecture and can be used for multiple use cases.
Depending on the use case, it can use different input layer sizes and
different width factors. This allows different width models to reduce
the number of multiply-adds and thereby reduce inference cost on mobile
devices.

MobileNetV2 is very similar to the original MobileNet, except that it uses
inverted residual blocks with bottlenecking features. It has a drastically
lower parameter count than the original MobileNet. MobileNets support any
input size greater than 32 x 32, with larger image sizes offering better
performance.

The number of parameters and number of multiply-adds can be modified by using
the `alpha` parameter, which increases/decreases the number of filters in
each layer. By altering the image size and `alpha` parameter, all 22 models
from the paper can be built, with ImageNet weights provided.

The paper demonstrates the performance of MobileNets using `alpha` values of
0.35, 0.5, 0.75, 1.0 (also called 100% MobileNet), 1.3, and 1.4. For each of
these `alpha` values, weights for 5 different input image sizes are provided
(224, 192, 160, 128, and 96).
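
For example, the smallest ImageNet-pretrained variant in the table below can
be selected at construction time (a usage sketch; assumes the public
`tf.keras` API):

```python
import tensorflow as tf

model = tf.keras.applications.MobileNetV2(
    input_shape=(96, 96, 3), alpha=0.35, weights='imagenet')
```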

The following table describes the performance of MobileNet on various input
sizes (MACs stands for Multiply-Adds):

| Classification Checkpoint | MACs (M) | Parameters (M) | Top 1 Accuracy | Top 5 Accuracy |
|---------------------------|----------|----------------|----------------|----------------|
| [mobilenet_v2_1.4_224]    | 582      | 6.06           | 75.0           | 92.5           |
| [mobilenet_v2_1.3_224]    | 509      | 5.34           | 74.4           | 92.1           |
| [mobilenet_v2_1.0_224]    | 300      | 3.47           | 71.8           | 91.0           |
| [mobilenet_v2_1.0_192]    | 221      | 3.47           | 70.7           | 90.1           |
| [mobilenet_v2_1.0_160]    | 154      | 3.47           | 68.8           | 89.0           |
| [mobilenet_v2_1.0_128]    | 99       | 3.47           | 65.3           | 86.9           |
| [mobilenet_v2_1.0_96]     | 56       | 3.47           | 60.3           | 83.2           |
| [mobilenet_v2_0.75_224]   | 209      | 2.61           | 69.8           | 89.6           |
| [mobilenet_v2_0.75_192]   | 153      | 2.61           | 68.7           | 88.9           |
| [mobilenet_v2_0.75_160]   | 107      | 2.61           | 66.4           | 87.3           |
| [mobilenet_v2_0.75_128]   | 69       | 2.61           | 63.2           | 85.3           |
| [mobilenet_v2_0.75_96]    | 39       | 2.61           | 58.8           | 81.6           |
| [mobilenet_v2_0.5_224]    | 97       | 1.95           | 65.4           | 86.4           |
| [mobilenet_v2_0.5_192]    | 71       | 1.95           | 63.9           | 85.4           |
| [mobilenet_v2_0.5_160]    | 50       | 1.95           | 61.0           | 83.2           |
| [mobilenet_v2_0.5_128]    | 32       | 1.95           | 57.7           | 80.8           |
| [mobilenet_v2_0.5_96]     | 18       | 1.95           | 51.2           | 75.8           |
| [mobilenet_v2_0.35_224]   | 59       | 1.66           | 60.3           | 82.9           |
| [mobilenet_v2_0.35_192]   | 43       | 1.66           | 58.2           | 81.2           |
| [mobilenet_v2_0.35_160]   | 30       | 1.66           | 55.7           | 79.1           |
| [mobilenet_v2_0.35_128]   | 20       | 1.66           | 50.8           | 75.0           |
| [mobilenet_v2_0.35_96]    | 11       | 1.66           | 45.5           | 70.4           |

Reference:
  - [MobileNetV2: Inverted Residuals and Linear Bottlenecks](
      https://arxiv.org/abs/1801.04381) (CVPR 2018)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras import backend
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

BASE_WEIGHT_PATH = ('https://storage.googleapis.com/tensorflow/'
                    'keras-applications/mobilenet_v2/')
layers = None


@keras_export('keras.applications.mobilenet_v2.MobileNetV2',
              'keras.applications.MobileNetV2')
def MobileNetV2(input_shape=None,
                alpha=1.0,
                include_top=True,
                weights='imagenet',
                input_tensor=None,
                pooling=None,
                classes=1000,
                classifier_activation='softmax',
                **kwargs):
  """Instantiates the MobileNetV2 architecture.

  Reference:
  - [MobileNetV2: Inverted Residuals and Linear Bottlenecks](
      https://arxiv.org/abs/1801.04381) (CVPR 2018)

  Optionally loads weights pre-trained on ImageNet.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For MobileNetV2, call `tf.keras.applications.mobilenet_v2.preprocess_input`
  on your inputs before passing them to the model.
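
  A minimal usage sketch (assumes `images` is a batch of RGB images with
  values in [0, 255] and `tf` is the imported TensorFlow module):

  ```python
  inputs = tf.keras.applications.mobilenet_v2.preprocess_input(images)
  model = tf.keras.applications.MobileNetV2(weights='imagenet')
  predictions = model(inputs)
  ```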

  Args:
    input_shape: Optional shape tuple, to be specified if you would like to
      use a model with an input image resolution that is not (224, 224, 3).
      It should have exactly 3 input channels, e.g. `(160, 160, 3)` would be
      one valid value. You can also omit this option if you would like to
      infer `input_shape` from an `input_tensor`. If you choose to include
      both `input_tensor` and `input_shape`, then `input_shape` will be used
      if they match; if the shapes do not match, an error is raised.
    alpha: Float, larger than zero, controls the width of the network. This
      is known as the width multiplier in the MobileNetV2 paper, but the name
      is kept for consistency with the `applications.MobileNetV1` model in
      Keras.
      - If `alpha` < 1.0, proportionally decreases the number
        of filters in each layer.
      - If `alpha` > 1.0, proportionally increases the number
        of filters in each layer.
      - If `alpha` = 1.0, the default number of filters from the paper
        is used at each layer.
    include_top: Boolean, whether to include the fully-connected
      layer at the top of the network. Defaults to `True`.
    weights: String, one of `None` (random initialization),
      `'imagenet'` (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: Optional Keras tensor (i.e. output of `layers.Input()`)
      to use as image input for the model.
    pooling: String, optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model
        will be the 4D tensor output of the
        last convolutional block.
      - `avg` means that global average pooling
        will be applied to the output of the
        last convolutional block, and thus
        the output of the model will be a
        2D tensor.
      - `max` means that global max pooling will
        be applied.
    classes: Integer, optional number of classes to classify images into,
      only to be specified if `include_top` is `True` and
      if no `weights` argument is specified.
    classifier_activation: A `str` or callable. The activation function to
      use on the "top" layer. Ignored unless `include_top=True`. Set
      `classifier_activation=None` to return the logits of the "top" layer.
    **kwargs: For backwards compatibility only.

  Returns:
    A `keras.Model` instance.

  Raises:
    ValueError: in case of invalid argument for `weights`, or invalid input
      shape, or invalid `alpha` or `rows` when `weights='imagenet'`.
    ValueError: if `classifier_activation` is not `softmax` or `None` when
      using a pretrained top layer.
  """
  global layers
  if 'layers' in kwargs:
    layers = kwargs.pop('layers')
  else:
    layers = VersionAwareLayers()
  if kwargs:
    raise ValueError('Unknown argument(s): %s' % (kwargs,))
  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top` '
                     'as true, `classes` should be 1000.')

  # Determine proper input shape and default size.
  # If both input_shape and input_tensor are used, they should match.
  if input_shape is not None and input_tensor is not None:
    try:
      is_input_t_tensor = backend.is_keras_tensor(input_tensor)
    except ValueError:
      try:
        is_input_t_tensor = backend.is_keras_tensor(
            layer_utils.get_source_inputs(input_tensor))
      except ValueError:
        raise ValueError('input_tensor: ', input_tensor,
                         'is not a valid input tensor')
    if is_input_t_tensor:
      if backend.image_data_format() == 'channels_first':
        # Batched tensor shape is (batch, channels, rows, cols); compare the
        # rows dimension against `input_shape` (channels, rows, cols).
        if backend.int_shape(input_tensor)[2] != input_shape[1]:
          raise ValueError('input_shape: ', input_shape, 'and input_tensor: ',
                           input_tensor,
                           'do not meet the same shape requirements')
      else:
        # Batched tensor shape is (batch, rows, cols, channels); compare the
        # cols dimension against `input_shape` (rows, cols, channels).
        if backend.int_shape(input_tensor)[2] != input_shape[1]:
          raise ValueError('input_shape: ', input_shape, 'and input_tensor: ',
                           input_tensor,
                           'do not meet the same shape requirements')
    else:
      raise ValueError('input_tensor specified: ', input_tensor,
                       'is not a keras tensor')

  # If input_shape is None, infer shape from input_tensor.
  if input_shape is None and input_tensor is not None:
    try:
      backend.is_keras_tensor(input_tensor)
    except ValueError:
      raise ValueError('input_tensor: ', input_tensor, 'is type: ',
                       type(input_tensor), 'which is not a valid type')

    if input_shape is None and not backend.is_keras_tensor(input_tensor):
      default_size = 224
    elif input_shape is None and backend.is_keras_tensor(input_tensor):
      if backend.image_data_format() == 'channels_first':
        rows = backend.int_shape(input_tensor)[2]
        cols = backend.int_shape(input_tensor)[3]
      else:
        rows = backend.int_shape(input_tensor)[1]
        cols = backend.int_shape(input_tensor)[2]

      if rows == cols and rows in [96, 128, 160, 192, 224]:
        default_size = rows
      else:
        default_size = 224

  # If input_shape is None and there is no input_tensor.
  elif input_shape is None:
    default_size = 224

  # If input_shape is not None, assume the default size.
  else:
    if backend.image_data_format() == 'channels_first':
      rows = input_shape[1]
      cols = input_shape[2]
    else:
      rows = input_shape[0]
      cols = input_shape[1]

    if rows == cols and rows in [96, 128, 160, 192, 224]:
      default_size = rows
    else:
      default_size = 224

  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=default_size,
      min_size=32,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if backend.image_data_format() == 'channels_last':
    row_axis, col_axis = (0, 1)
  else:
    row_axis, col_axis = (1, 2)
  rows = input_shape[row_axis]
  cols = input_shape[col_axis]

  if weights == 'imagenet':
    if alpha not in [0.35, 0.50, 0.75, 1.0, 1.3, 1.4]:
      raise ValueError('If imagenet weights are being loaded, '
                       'alpha can be one of `0.35`, `0.50`, `0.75`, '
                       '`1.0`, `1.3` or `1.4` only.')

    if rows != cols or rows not in [96, 128, 160, 192, 224]:
      rows = 224
      logging.warning('`input_shape` is undefined or non-square, '
                      'or `rows` is not in [96, 128, 160, 192, 224].'
                      ' Weights for input shape (224, 224) will be'
                      ' loaded as the default.')
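
    # Illustrative note: an `input_shape` such as (300, 300, 3) passes the
    # checks above, but `rows` falls back to 224 here, so the 224x224
    # ImageNet checkpoint is what gets loaded.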

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1

  first_block_filters = _make_divisible(32 * alpha, 8)
  x = layers.Conv2D(
      first_block_filters,
      kernel_size=3,
      strides=(2, 2),
      padding='same',
      use_bias=False,
      name='Conv1')(img_input)
  x = layers.BatchNormalization(
      axis=channel_axis, epsilon=1e-3, momentum=0.999, name='bn_Conv1')(x)
  x = layers.ReLU(6., name='Conv1_relu')(x)

  x = _inverted_res_block(
      x, filters=16, alpha=alpha, stride=1, expansion=1, block_id=0)

  x = _inverted_res_block(
      x, filters=24, alpha=alpha, stride=2, expansion=6, block_id=1)
  x = _inverted_res_block(
      x, filters=24, alpha=alpha, stride=1, expansion=6, block_id=2)

  x = _inverted_res_block(
      x, filters=32, alpha=alpha, stride=2, expansion=6, block_id=3)
  x = _inverted_res_block(
      x, filters=32, alpha=alpha, stride=1, expansion=6, block_id=4)
  x = _inverted_res_block(
      x, filters=32, alpha=alpha, stride=1, expansion=6, block_id=5)

  x = _inverted_res_block(
      x, filters=64, alpha=alpha, stride=2, expansion=6, block_id=6)
  x = _inverted_res_block(
      x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=7)
  x = _inverted_res_block(
      x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=8)
  x = _inverted_res_block(
      x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=9)

  x = _inverted_res_block(
      x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=10)
  x = _inverted_res_block(
      x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=11)
  x = _inverted_res_block(
      x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=12)

  x = _inverted_res_block(
      x, filters=160, alpha=alpha, stride=2, expansion=6, block_id=13)
  x = _inverted_res_block(
      x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=14)
  x = _inverted_res_block(
      x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=15)

  x = _inverted_res_block(
      x, filters=320, alpha=alpha, stride=1, expansion=6, block_id=16)

  # No alpha applied to the last conv for alpha <= 1.0, as stated in the
  # paper: only if the width multiplier is greater than 1 do we increase the
  # number of output channels.
  if alpha > 1.0:
    last_block_filters = _make_divisible(1280 * alpha, 8)
  else:
    last_block_filters = 1280

  x = layers.Conv2D(
      last_block_filters, kernel_size=1, use_bias=False, name='Conv_1')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, epsilon=1e-3, momentum=0.999, name='Conv_1_bn')(x)
  x = layers.ReLU(6., name='out_relu')(x)

  if include_top:
    x = layers.GlobalAveragePooling2D()(x)
    imagenet_utils.validate_activation(classifier_activation, weights)
    x = layers.Dense(classes, activation=classifier_activation,
                     name='predictions')(x)
  else:
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D()(x)
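
  # Shape note (illustrative): for alpha <= 1.0 the tensor at this point has
  # 1280 channels, so `include_top=False, pooling='avg'` on a 224x224 input
  # produces a 2D feature tensor of shape (batch_size, 1280).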

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input

  # Create model.
  model = training.Model(inputs, x, name='mobilenetv2_%0.2f_%s' % (alpha, rows))

  # Load weights.
  if weights == 'imagenet':
    if include_top:
      model_name = ('mobilenet_v2_weights_tf_dim_ordering_tf_kernels_' +
                    str(alpha) + '_' + str(rows) + '.h5')
    else:
      model_name = ('mobilenet_v2_weights_tf_dim_ordering_tf_kernels_' +
                    str(alpha) + '_' + str(rows) + '_no_top' + '.h5')
    weight_path = BASE_WEIGHT_PATH + model_name
    weights_path = data_utils.get_file(
        model_name, weight_path, cache_subdir='models')
    model.load_weights(weights_path)
  elif weights is not None:
    model.load_weights(weights)

  return model


def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id):
  """Inverted residual block (expand, depthwise conv, linear project)."""
  channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1

  in_channels = backend.int_shape(inputs)[channel_axis]
  pointwise_conv_filters = int(filters * alpha)
  pointwise_filters = _make_divisible(pointwise_conv_filters, 8)
  x = inputs
  prefix = 'block_{}_'.format(block_id)

  if block_id:
    # Expand with a pointwise 1x1 convolution.
    x = layers.Conv2D(
        expansion * in_channels,
        kernel_size=1,
        padding='same',
        use_bias=False,
        activation=None,
        name=prefix + 'expand')(x)
    x = layers.BatchNormalization(
        axis=channel_axis,
        epsilon=1e-3,
        momentum=0.999,
        name=prefix + 'expand_BN')(x)
    x = layers.ReLU(6., name=prefix + 'expand_relu')(x)
  else:
    prefix = 'expanded_conv_'

  # Depthwise 3x3 convolution.
  if stride == 2:
    x = layers.ZeroPadding2D(
        padding=imagenet_utils.correct_pad(x, 3),
        name=prefix + 'pad')(x)
  x = layers.DepthwiseConv2D(
      kernel_size=3,
      strides=stride,
      activation=None,
      use_bias=False,
      padding='same' if stride == 1 else 'valid',
      name=prefix + 'depthwise')(x)
  x = layers.BatchNormalization(
      axis=channel_axis,
      epsilon=1e-3,
      momentum=0.999,
      name=prefix + 'depthwise_BN')(x)
  x = layers.ReLU(6., name=prefix + 'depthwise_relu')(x)

  # Project back down with a linear pointwise 1x1 convolution.
  x = layers.Conv2D(
      pointwise_filters,
      kernel_size=1,
      padding='same',
      use_bias=False,
      activation=None,
      name=prefix + 'project')(x)
  x = layers.BatchNormalization(
      axis=channel_axis,
      epsilon=1e-3,
      momentum=0.999,
      name=prefix + 'project_BN')(x)

  # Residual connection when input and output have the same shape.
  if in_channels == pointwise_filters and stride == 1:
    return layers.Add(name=prefix + 'add')([inputs, x])
  return x
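

# A worked example of the block above, with illustrative numbers: for
# `filters=24, alpha=1.0, stride=2, expansion=6` and 16 input channels
# (block_id=1 in MobileNetV2), the block expands to 16 * 6 = 96 channels,
# applies a stride-2 depthwise convolution, and projects down to
# _make_divisible(24, 8) = 24 output channels; no residual connection is
# added because the stride is not 1.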


def _make_divisible(v, divisor, min_value=None):
  """Rounds `v` to the nearest multiple of `divisor`, with a lower bound.

  For example, `_make_divisible(24 * 0.35, 8)` returns 8.
  """
  if min_value is None:
    min_value = divisor
  new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  # Make sure that rounding down does not reduce `v` by more than 10%.
  if new_v < 0.9 * v:
    new_v += divisor
  return new_v


@keras_export('keras.applications.mobilenet_v2.preprocess_input')
def preprocess_input(x, data_format=None):
  return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')


@keras_export('keras.applications.mobilenet_v2.decode_predictions')
def decode_predictions(preds, top=5):
  return imagenet_utils.decode_predictions(preds, top=top)


preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
    mode='',
    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
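

# A minimal end-to-end sketch of this module's public API (illustrative
# only; assumes internet access for the ImageNet weights and that NumPy is
# available for the dummy input batch):
if __name__ == '__main__':
  import numpy as np

  # Build the default 224x224, alpha=1.0 ImageNet classifier.
  demo_model = MobileNetV2(weights='imagenet')
  # `preprocess_input` (mode='tf') scales [0, 255] inputs to [-1, 1].
  demo_batch = preprocess_input(np.random.uniform(0, 255, (1, 224, 224, 3)))
  # Decode the top-5 ImageNet class predictions for the random batch.
  print(decode_predictions(demo_model.predict(demo_batch), top=5))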