# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""ResNet v2 models for Keras.

Reference:
  - [Identity Mappings in Deep Residual Networks]
    (https://arxiv.org/abs/1603.05027) (CVPR 2016)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.applications import resnet
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.applications.resnet_v2.ResNet50V2',
              'keras.applications.ResNet50V2')
def ResNet50V2(
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the ResNet50V2 architecture."""
  # (filters, number of blocks, stage name) for the stages that are built
  # with resnet.stack2 without an explicit `stride1` argument.
  leading_stages = (
      (64, 3, 'conv2'),
      (128, 4, 'conv3'),
      (256, 6, 'conv4'),
  )

  def stack_fn(x):
    for filters, num_blocks, stage_name in leading_stages:
      x = resnet.stack2(x, filters, num_blocks, name=stage_name)
    # The last stage is the only one that passes stride1=1 explicitly.
    return resnet.stack2(x, 512, 3, stride1=1, name='conv5')

  # NOTE(review): the two positional booleans are, per upstream resnet.py,
  # `preact` and `use_bias` -- confirm against resnet.ResNet's signature.
  return resnet.ResNet(
      stack_fn,
      True,
      True,
      'resnet50v2',
      include_top,
      weights,
      input_tensor,
      input_shape,
      pooling,
      classes,
      classifier_activation=classifier_activation)
@keras_export('keras.applications.resnet_v2.ResNet101V2',
              'keras.applications.ResNet101V2')
def ResNet101V2(
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the ResNet101V2 architecture."""
  def stack_fn(x):
    x = resnet.stack2(x, 64, 3, name='conv2')
    x = resnet.stack2(x, 128, 4, name='conv3')
    x = resnet.stack2(x, 256, 23, name='conv4')
    return resnet.stack2(x, 512, 3, stride1=1, name='conv5')

  # NOTE(review): the two positional booleans are, per upstream resnet.py,
  # `preact` and `use_bias` -- confirm against resnet.ResNet's signature.
  return resnet.ResNet(
      stack_fn,
      True,
      True,
      'resnet101v2',
      include_top,
      weights,
      input_tensor,
      input_shape,
      pooling,
      classes,
      classifier_activation=classifier_activation)


@keras_export('keras.applications.resnet_v2.ResNet152V2',
              'keras.applications.ResNet152V2')
def ResNet152V2(
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the ResNet152V2 architecture."""
  def stack_fn(x):
    x = resnet.stack2(x, 64, 3, name='conv2')
    x = resnet.stack2(x, 128, 8, name='conv3')
    x = resnet.stack2(x, 256, 36, name='conv4')
    return resnet.stack2(x, 512, 3, stride1=1, name='conv5')

  # NOTE(review): the two positional booleans are, per upstream resnet.py,
  # `preact` and `use_bias` -- confirm against resnet.ResNet's signature.
  return resnet.ResNet(
      stack_fn,
      True,
      True,
      'resnet152v2',
      include_top,
      weights,
      input_tensor,
      input_shape,
      pooling,
      classes,
      classifier_activation=classifier_activation)


@keras_export('keras.applications.resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None):
  # The public docstring is assigned at module level below, built from
  # imagenet_utils.PREPROCESS_INPUT_DOC.
  return imagenet_utils.preprocess_input(
      x, data_format=data_format, mode='tf')


@keras_export('keras.applications.resnet_v2.decode_predictions')
def decode_predictions(preds, top=5):
  # The public docstring is copied from imagenet_utils.decode_predictions
  # at module level below.
  return imagenet_utils.decode_predictions(preds, top=top)


preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
    mode='',
    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__

# Shared documentation appended to each ResNetV2 constructor's one-line
# docstring.
DOC = """

  Reference:
  - [Identity Mappings in Deep Residual Networks]
    (https://arxiv.org/abs/1603.05027) (CVPR 2016)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For ResNetV2, call `tf.keras.applications.resnet_v2.preprocess_input` on your
  inputs before passing them to the model.

  Args:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False; otherwise the input shape
      has to be `(224, 224, 3)` (with `'channels_last'` data format)
      or `(3, 224, 224)` (with `'channels_first'` data format).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 32.
      E.g. `(200, 200, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True, and
      if no `weights` argument is specified.
    classifier_activation: A `str` or callable. The activation function to use
      on the "top" layer. Ignored unless `include_top=True`. Set
      `classifier_activation=None` to return the logits of the "top" layer.

  Returns:
    A `keras.Model` instance.
"""

# Under `python -OO` function docstrings are stripped and `__doc__` is None;
# fall back to '' so that importing this module does not raise TypeError.
ResNet50V2.__doc__ = (ResNet50V2.__doc__ or '') + DOC
ResNet101V2.__doc__ = (ResNet101V2.__doc__ or '') + DOC
ResNet152V2.__doc__ = (ResNet152V2.__doc__ or '') + DOC