1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""ResNet v2 models for Keras.
17
18Reference:
19  - [Identity Mappings in Deep Residual Networks]
20    (https://arxiv.org/abs/1603.05027) (CVPR 2016)
21"""
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26from tensorflow.python.keras.applications import imagenet_utils
27from tensorflow.python.keras.applications import resnet
28from tensorflow.python.util.tf_export import keras_export
29
30
@keras_export('keras.applications.resnet_v2.ResNet50V2',
              'keras.applications.ResNet50V2')
def ResNet50V2(
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the ResNet50V2 architecture."""
  # (filters, number of blocks, stage name) for every stage that is built
  # without an explicit `stride1`; the last stage is handled separately below.
  leading_stages = ((64, 3, 'conv2'), (128, 4, 'conv3'), (256, 6, 'conv4'))

  def stack_fn(x):
    for filters, num_blocks, stage_name in leading_stages:
      x = resnet.stack2(x, filters, num_blocks, name=stage_name)
    # Final stage is stacked with stride1=1.
    return resnet.stack2(x, 512, 3, stride1=1, name='conv5')

  # The two positional flags handed to resnet.ResNet, named for readability.
  # NOTE(review): presumably these are ResNet's `preact` and `use_bias`
  # parameters; confirm against resnet.ResNet's signature.
  preact = True
  use_bias = True
  return resnet.ResNet(
      stack_fn,
      preact,
      use_bias,
      'resnet50v2',
      include_top,
      weights,
      input_tensor,
      input_shape,
      pooling,
      classes,
      classifier_activation=classifier_activation)
60
61
@keras_export('keras.applications.resnet_v2.ResNet101V2',
              'keras.applications.ResNet101V2')
def ResNet101V2(
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the ResNet101V2 architecture."""
  # (filters, number of blocks, stage name) for every stage that is built
  # without an explicit `stride1`; the last stage is handled separately below.
  leading_stages = ((64, 3, 'conv2'), (128, 4, 'conv3'), (256, 23, 'conv4'))

  def stack_fn(x):
    for filters, num_blocks, stage_name in leading_stages:
      x = resnet.stack2(x, filters, num_blocks, name=stage_name)
    # Final stage is stacked with stride1=1.
    return resnet.stack2(x, 512, 3, stride1=1, name='conv5')

  # The two positional flags handed to resnet.ResNet, named for readability.
  # NOTE(review): presumably these are ResNet's `preact` and `use_bias`
  # parameters; confirm against resnet.ResNet's signature.
  preact = True
  use_bias = True
  return resnet.ResNet(
      stack_fn,
      preact,
      use_bias,
      'resnet101v2',
      include_top,
      weights,
      input_tensor,
      input_shape,
      pooling,
      classes,
      classifier_activation=classifier_activation)
91
92
@keras_export('keras.applications.resnet_v2.ResNet152V2',
              'keras.applications.ResNet152V2')
def ResNet152V2(
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the ResNet152V2 architecture."""
  # (filters, number of blocks, stage name) for every stage that is built
  # without an explicit `stride1`; the last stage is handled separately below.
  leading_stages = ((64, 3, 'conv2'), (128, 8, 'conv3'), (256, 36, 'conv4'))

  def stack_fn(x):
    for filters, num_blocks, stage_name in leading_stages:
      x = resnet.stack2(x, filters, num_blocks, name=stage_name)
    # Final stage is stacked with stride1=1.
    return resnet.stack2(x, 512, 3, stride1=1, name='conv5')

  # The two positional flags handed to resnet.ResNet, named for readability.
  # NOTE(review): presumably these are ResNet's `preact` and `use_bias`
  # parameters; confirm against resnet.ResNet's signature.
  preact = True
  use_bias = True
  return resnet.ResNet(
      stack_fn,
      preact,
      use_bias,
      'resnet152v2',
      include_top,
      weights,
      input_tensor,
      input_shape,
      pooling,
      classes,
      classifier_activation=classifier_activation)
122
123
@keras_export('keras.applications.resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None):
  # ResNetV2 delegates to the shared ImageNet helper using its 'tf' mode.
  # The user-facing docstring is assigned to this function later in the module.
  scaled = imagenet_utils.preprocess_input(
      x, mode='tf', data_format=data_format)
  return scaled
128
129
@keras_export('keras.applications.resnet_v2.decode_predictions')
def decode_predictions(preds, top=5):
  # Nothing ResNetV2-specific: forward straight to the shared ImageNet decoder.
  # The docstring is copied from imagenet_utils later in the module.
  decoded = imagenet_utils.decode_predictions(preds, top=top)
  return decoded
133
134
# Both thin wrappers above borrow their public docstrings from imagenet_utils:
# preprocess_input gets the shared template filled with the `_TF` return-value
# text and an empty mode description; decode_predictions reuses the helper's
# docstring verbatim.
setattr(
    preprocess_input, '__doc__',
    imagenet_utils.PREPROCESS_INPUT_DOC.format(
        error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,
        mode='',
        ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF))
setattr(decode_predictions, '__doc__',
        imagenet_utils.decode_predictions.__doc__)
140
# Shared argument/return documentation appended to the one-line docstring of
# each ResNetV2 builder in this module.
DOC = """

  Reference:
  - [Identity Mappings in Deep Residual Networks]
    (https://arxiv.org/abs/1603.05027) (CVPR 2016)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For ResNetV2, call `tf.keras.applications.resnet_v2.preprocess_input` on your
  inputs before passing them to the model.

  Args:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(224, 224, 3)` (with `'channels_last'` data format)
      or `(3, 224, 224)` (with `'channels_first'` data format)).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 32.
      E.g. `(200, 200, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True, and
      if no `weights` argument is specified.
    classifier_activation: A `str` or callable. The activation function to use
      on the "top" layer. Ignored unless `include_top=True`. Set
      `classifier_activation=None` to return the logits of the "top" layer.

  Returns:
    A `keras.Model` instance.
"""
191
# Append the shared argument documentation to each builder's one-line summary.
# Under `python -OO` function docstrings are stripped and `__doc__` is None;
# fall back to an empty summary so importing the module does not raise
# TypeError (None + str).
setattr(ResNet50V2, '__doc__', (ResNet50V2.__doc__ or '') + DOC)
setattr(ResNet101V2, '__doc__', (ResNet101V2.__doc__ or '') + DOC)
setattr(ResNet152V2, '__doc__', (ResNet152V2.__doc__ or '') + DOC)
195