1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=invalid-name 16"""Xception V1 model for Keras. 17 18On ImageNet, this model gets to a top-1 validation accuracy of 0.790 19and a top-5 validation accuracy of 0.945. 20 21Reference: 22 - [Xception: Deep Learning with Depthwise Separable Convolutions]( 23 https://arxiv.org/abs/1610.02357) (CVPR 2017) 24 25""" 26from __future__ import absolute_import 27from __future__ import division 28from __future__ import print_function 29 30from tensorflow.python.keras import backend 31from tensorflow.python.keras.applications import imagenet_utils 32from tensorflow.python.keras.engine import training 33from tensorflow.python.keras.layers import VersionAwareLayers 34from tensorflow.python.keras.utils import data_utils 35from tensorflow.python.keras.utils import layer_utils 36from tensorflow.python.lib.io import file_io 37from tensorflow.python.util.tf_export import keras_export 38 39 40TF_WEIGHTS_PATH = ( 41 'https://storage.googleapis.com/tensorflow/keras-applications/' 42 'xception/xception_weights_tf_dim_ordering_tf_kernels.h5') 43TF_WEIGHTS_PATH_NO_TOP = ( 44 'https://storage.googleapis.com/tensorflow/keras-applications/' 45 'xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5') 46 47layers = VersionAwareLayers() 48 49 
@keras_export('keras.applications.xception.Xception',
              'keras.applications.Xception')
def Xception(
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the Xception architecture.

  Reference:
  - [Xception: Deep Learning with Depthwise Separable Convolutions](
      https://arxiv.org/abs/1610.02357) (CVPR 2017)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.
  Note that the default input image size for this model is 299x299.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For Xception, call `tf.keras.applications.xception.preprocess_input` on your
  inputs before passing them to the model.

  Args:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
      (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(299, 299, 3)`).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 71.
      E.g. `(150, 150, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
      Any other value is silently treated like `None`, and the argument
      is ignored entirely when `include_top` is True.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True,
      and if no `weights` argument is specified.
    classifier_activation: A `str` or callable. The activation function to use
      on the "top" layer. Ignored unless `include_top=True`. Set
      `classifier_activation=None` to return the logits of the "top" layer.

  Returns:
    A `keras.Model` instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape.
    ValueError: if `classifier_activation` is not `softmax` or `None` when
      using a pretrained top layer.
  """
  # `weights` must be None, 'imagenet', or a path that exists on disk.
  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  # The pretrained ImageNet head is only usable with 1000 output classes.
  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Determine proper input shape
  # (299x299 by default; spatial dims may not go below 71).
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=299,
      min_size=71,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  # Use a fresh Input, wrap a raw (non-Keras) tensor in an Input, or use the
  # caller's Keras tensor directly.
  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  # BatchNormalization normalizes over the channel axis, whose position depends
  # on the configured data format.
  channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1

  # NOTE(review): the shortcut ('residual') Conv2D/BatchNormalization layers
  # below have no explicit name, so Keras auto-generates their names from
  # creation order. Keep the order of layer construction unchanged.

  # Block 1: two plain 3x3 convolutions (the first one strided by 2),
  # each followed by BN + ReLU.
  x = layers.Conv2D(
      32, (3, 3),
      strides=(2, 2),
      use_bias=False,
      name='block1_conv1')(img_input)
  x = layers.BatchNormalization(axis=channel_axis, name='block1_conv1_bn')(x)
  x = layers.Activation('relu', name='block1_conv1_act')(x)
  x = layers.Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block1_conv2_bn')(x)
  x = layers.Activation('relu', name='block1_conv2_act')(x)

  # Blocks 2-4 share one pattern. A strided 1x1 conv + BN shortcut is added to
  # a main path of two separable convs followed by a stride-2 max-pool, so
  # both branches are downsampled by 2 before the sum.
  residual = layers.Conv2D(
      128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  # Block 2 has no leading ReLU: its input is already the output of
  # 'block1_conv2_act'.
  x = layers.SeparableConv2D(
      128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block2_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block2_pool')(x)
  x = layers.add([x, residual])

  # Block 3 (256 filters).
  residual = layers.Conv2D(
      256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  x = layers.Activation('relu', name='block3_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block3_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block3_pool')(x)
  x = layers.add([x, residual])

  # Block 4 (728 filters).
  residual = layers.Conv2D(
      728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  x = layers.Activation('relu', name='block4_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block4_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block4_pool')(x)
  x = layers.add([x, residual])

  # Blocks 5-12: eight identical blocks. Each has three ReLU -> SeparableConv
  # (728) -> BN stages and an identity shortcut, so the shape is unchanged.
  for i in range(8):
    residual = x
    prefix = 'block' + str(i + 5)

    x = layers.Activation('relu', name=prefix + '_sepconv1_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv1')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv1_bn')(x)
    x = layers.Activation('relu', name=prefix + '_sepconv2_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv2')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv2_bn')(x)
    x = layers.Activation('relu', name=prefix + '_sepconv3_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv3')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv3_bn')(x)

    x = layers.add([x, residual])

  # Block 13: the last downsampling block. The shortcut projects to 1024
  # channels; the main path goes 728 -> 1024 before the stride-2 max-pool.
  residual = layers.Conv2D(
      1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  x = layers.Activation('relu', name='block13_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block13_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block13_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block13_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block13_pool')(x)
  x = layers.add([x, residual])

  # Block 14: no shortcut. Unlike the earlier blocks, each SeparableConv is
  # followed (not preceded) by BN + ReLU.
  x = layers.SeparableConv2D(
      1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block14_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block14_sepconv1_act')(x)

  x = layers.SeparableConv2D(
      2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block14_sepconv2_bn')(x)
  x = layers.Activation('relu', name='block14_sepconv2_act')(x)

  if include_top:
    # Classification head: global average pooling + Dense over `classes`.
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    # Rejects activations that are incompatible with the requested `weights`.
    imagenet_utils.validate_activation(classifier_activation, weights)
    x = layers.Dense(classes, activation=classifier_activation,
                     name='predictions')(x)
  else:
    # Optional pooling for feature extraction; any other `pooling` value
    # leaves the 4D feature map as-is.
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D()(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input
  # Create model.
  model = training.Model(inputs, x, name='xception')

  # Load weights.
  # For 'imagenet', the matching pretrained file (with or without the top) is
  # fetched via `data_utils.get_file` into the 'models' cache subdirectory,
  # with the expected `file_hash` passed along.
  # NOTE(review): get_file presumably reuses a cached copy and validates the
  # hash; see its documentation.
  if weights == 'imagenet':
    if include_top:
      weights_path = data_utils.get_file(
          'xception_weights_tf_dim_ordering_tf_kernels.h5',
          TF_WEIGHTS_PATH,
          cache_subdir='models',
          file_hash='0a58e3b7378bc2990ea3b43d5981f1f6')
    else:
      weights_path = data_utils.get_file(
          'xception_weights_tf_dim_ordering_tf_kernels_notop.h5',
          TF_WEIGHTS_PATH_NO_TOP,
          cache_subdir='models',
          file_hash='b0042744bf5b25fce3cb969f33bebb97')
    model.load_weights(weights_path)
  elif weights is not None:
    # A user-supplied weights file (its existence was checked above).
    model.load_weights(weights)

  return model


# Preprocesses inputs for Xception by delegating to imagenet_utils in 'tf'
# mode. The public docstring is attached after the definition via the
# `preprocess_input.__doc__` assignment below.
@keras_export('keras.applications.xception.preprocess_input')
def preprocess_input(x, data_format=None):
  return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')


# Decodes model predictions by delegating to imagenet_utils. Its docstring is
# copied from `imagenet_utils.decode_predictions` below.
@keras_export('keras.applications.xception.decode_predictions')
def decode_predictions(preds, top=5):
  return imagenet_utils.decode_predictions(preds, top=top)


# Attach the shared ImageNet preprocessing/decoding docstrings, specialized
# for the 'tf' preprocessing mode used by preprocess_input above.
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
    mode='',
    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__