1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""Xception V1 model for Keras.
17
18On ImageNet, this model gets to a top-1 validation accuracy of 0.790
19and a top-5 validation accuracy of 0.945.
20
21Reference:
22  - [Xception: Deep Learning with Depthwise Separable Convolutions](
23      https://arxiv.org/abs/1610.02357) (CVPR 2017)
24
25"""
26from __future__ import absolute_import
27from __future__ import division
28from __future__ import print_function
29
30from tensorflow.python.keras import backend
31from tensorflow.python.keras.applications import imagenet_utils
32from tensorflow.python.keras.engine import training
33from tensorflow.python.keras.layers import VersionAwareLayers
34from tensorflow.python.keras.utils import data_utils
35from tensorflow.python.keras.utils import layer_utils
36from tensorflow.python.lib.io import file_io
37from tensorflow.python.util.tf_export import keras_export
38
39
# Download locations for the pretrained ImageNet weights. `Xception` fetches
# TF_WEIGHTS_PATH when `include_top` is True and TF_WEIGHTS_PATH_NO_TOP when
# it builds the headless feature extractor.
TF_WEIGHTS_PATH = (
    'https://storage.googleapis.com/tensorflow/keras-applications/xception/'
    'xception_weights_tf_dim_ordering_tf_kernels.h5')
TF_WEIGHTS_PATH_NO_TOP = (
    'https://storage.googleapis.com/tensorflow/keras-applications/xception/'
    'xception_weights_tf_dim_ordering_tf_kernels_notop.h5')

# Layer namespace through which every layer in this module is constructed
# (`layers.Conv2D`, `layers.SeparableConv2D`, `layers.Dense`, ...).
layers = VersionAwareLayers()
48
49
@keras_export('keras.applications.xception.Xception',
              'keras.applications.Xception')
def Xception(
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the Xception architecture.

  Reference:
  - [Xception: Deep Learning with Depthwise Separable Convolutions](
      https://arxiv.org/abs/1610.02357) (CVPR 2017)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.
  Note that the default input image size for this model is 299x299.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For Xception, call `tf.keras.applications.xception.preprocess_input` on your
  inputs before passing them to the model.

  Args:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
      (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(299, 299, 3)`).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 71.
      E.g. `(150, 150, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False` (ignored when `include_top` is `True`).
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
      Any other value is treated like `None`.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True.
      Must be 1000 when `weights='imagenet'` and `include_top` is True.
    classifier_activation: A `str` or callable. The activation function to use
      on the "top" layer. Ignored unless `include_top=True`. Set
      `classifier_activation=None` to return the logits of the "top" layer.

  Returns:
    A `keras.Model` instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape.
    ValueError: if `weights='imagenet'`, `include_top` is True and
      `classes` is not 1000.
    ValueError: if `classifier_activation` is not `softmax` or `None` when
      using a pretrained top layer.
  """
  # `weights` must be 'imagenet', None, or a path to a file that exists.
  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  # The pretrained ImageNet top layer only fits a 1000-class output.
  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Determine proper input shape (299 default size, 71 minimum side).
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=299,
      min_size=71,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  # Build a fresh Input, wrap a raw (non-Keras) tensor in one, or use a
  # Keras tensor directly as the graph input.
  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  # BatchNormalization must normalize over the channel axis, whose position
  # depends on the configured image data format.
  channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1

  # Block 1: two plain 3x3 convolutions (the first with stride 2), each
  # followed by BN and ReLU.
  x = layers.Conv2D(
      32, (3, 3),
      strides=(2, 2),
      use_bias=False,
      name='block1_conv1')(img_input)
  x = layers.BatchNormalization(axis=channel_axis, name='block1_conv1_bn')(x)
  x = layers.Activation('relu', name='block1_conv1_act')(x)
  x = layers.Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block1_conv2_bn')(x)
  x = layers.Activation('relu', name='block1_conv2_act')(x)

  # Blocks 2-4 share one pattern: a strided 1x1 Conv2D + BN projection
  # shortcut, two SeparableConv2D + BN stages, then a stride-2 max pool so the
  # main path matches the shortcut's spatial size before the add.
  # NOTE(review): the shortcut Conv2D/BatchNormalization layers have no
  # `name=`, so Keras presumably auto-names them in construction order;
  # reordering layer construction here may change those names.
  residual = layers.Conv2D(
      128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  # Block 2 has no leading ReLU: its input already went through
  # block1_conv2_act.
  x = layers.SeparableConv2D(
      128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block2_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block2_pool')(x)
  x = layers.add([x, residual])

  # Block 3: 256 filters.
  residual = layers.Conv2D(
      256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  x = layers.Activation('relu', name='block3_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block3_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block3_pool')(x)
  x = layers.add([x, residual])

  # Block 4: 728 filters, the width used throughout the middle blocks.
  residual = layers.Conv2D(
      728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  x = layers.Activation('relu', name='block4_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block4_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block4_pool')(x)
  x = layers.add([x, residual])

  # Blocks 5-12: eight identical blocks with an identity shortcut (no
  # downsampling), each made of three ReLU -> SeparableConv2D(728) -> BN
  # stages.
  for i in range(8):
    residual = x
    prefix = 'block' + str(i + 5)

    x = layers.Activation('relu', name=prefix + '_sepconv1_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv1')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv1_bn')(x)
    x = layers.Activation('relu', name=prefix + '_sepconv2_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv2')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv2_bn')(x)
    x = layers.Activation('relu', name=prefix + '_sepconv3_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv3')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv3_bn')(x)

    x = layers.add([x, residual])

  # Block 13: final downsampling block; the projection shortcut widens to
  # 1024 channels to match the main path's second separable conv.
  residual = layers.Conv2D(
      1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  x = layers.Activation('relu', name='block13_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block13_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block13_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block13_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block13_pool')(x)
  x = layers.add([x, residual])

  # Block 14: no shortcut; unlike earlier blocks, ReLU comes after each BN
  # rather than before each separable conv.
  x = layers.SeparableConv2D(
      1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block14_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block14_sepconv1_act')(x)

  x = layers.SeparableConv2D(
      2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block14_sepconv2_bn')(x)
  x = layers.Activation('relu', name='block14_sepconv2_act')(x)

  # Head: either the classifier (global average pool + Dense) or optional
  # global pooling for feature extraction. `pooling` is only consulted when
  # there is no top, and values other than 'avg'/'max' leave x unpooled.
  if include_top:
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    imagenet_utils.validate_activation(classifier_activation, weights)
    x = layers.Dense(classes, activation=classifier_activation,
                     name='predictions')(x)
  else:
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D()(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input
  # Create model.
  model = training.Model(inputs, x, name='xception')

  # Load weights. ImageNet weights are downloaded (with the matching top/notop
  # variant) into the 'models' cache subdirectory and checked against the
  # expected file hash; a user-supplied path was already verified to exist
  # by the `weights` check at the top of this function.
  if weights == 'imagenet':
    if include_top:
      weights_path = data_utils.get_file(
          'xception_weights_tf_dim_ordering_tf_kernels.h5',
          TF_WEIGHTS_PATH,
          cache_subdir='models',
          file_hash='0a58e3b7378bc2990ea3b43d5981f1f6')
    else:
      weights_path = data_utils.get_file(
          'xception_weights_tf_dim_ordering_tf_kernels_notop.h5',
          TF_WEIGHTS_PATH_NO_TOP,
          cache_subdir='models',
          file_hash='b0042744bf5b25fce3cb969f33bebb97')
    model.load_weights(weights_path)
  elif weights is not None:
    model.load_weights(weights)

  return model
316
317
@keras_export('keras.applications.xception.preprocess_input')
def preprocess_input(x, data_format=None):
  # Thin wrapper that pins the shared ImageNet preprocessing helper to 'tf'
  # mode. Its public docstring is assigned at module level after the wrapper
  # definitions.
  return imagenet_utils.preprocess_input(
      x, mode='tf', data_format=data_format)
321
322
@keras_export('keras.applications.xception.decode_predictions')
def decode_predictions(preds, top=5):
  # Thin wrapper: forwards to the shared ImageNet decoder unchanged.
  decoded = imagenet_utils.decode_predictions(preds, top=top)
  return decoded
326
327
# Attach the shared ImageNet documentation to the thin public wrappers, which
# define no inline docstrings of their own.
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
    # Presumably fills the template's `mode` description; left empty because
    # this wrapper exposes no `mode` argument.
    mode='',
    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,
)
333