# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""VGG16 model for Keras.

Reference:
  - [Very Deep Convolutional Networks for Large-Scale Image Recognition]
    (https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras import backend
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.lib.io import file_io
from tensorflow.python.util.tf_export import keras_export


# Pretrained ImageNet weights: full network (with the 3 dense layers) and
# convolutional base only ("notop").
WEIGHTS_PATH = ('https://storage.googleapis.com/tensorflow/keras-applications/'
                'vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5')
WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/'
                       'keras-applications/vgg16/'
                       'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5')

layers = VersionAwareLayers()


@keras_export('keras.applications.vgg16.VGG16', 'keras.applications.VGG16')
def VGG16(
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the VGG16 model.

  Reference:
  - [Very Deep Convolutional Networks for Large-Scale Image Recognition](
  https://arxiv.org/abs/1409.1556) (ICLR 2015)

  By default, it loads weights pre-trained on ImageNet. Check 'weights' for
  other options.

  This model can be built both with 'channels_first' data format
  (channels, height, width) or 'channels_last' data format
  (height, width, channels).

  The default input size for this model is 224x224.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For VGG16, call `tf.keras.applications.vgg16.preprocess_input` on your
  inputs before passing them to the model.

  Args:
    include_top: whether to include the 3 fully-connected
      layers at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
      (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(224, 224, 3)` (with `channels_last` data format)
      or `(3, 224, 224)` (with `channels_first` data format)).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 32.
      E.g. `(200, 200, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True, and
      if no `weights` argument is specified.
    classifier_activation: A `str` or callable. The activation function to use
      on the "top" layer. Ignored unless `include_top=True`. Set
      `classifier_activation=None` to return the logits of the "top" layer.

  Returns:
    A `keras.Model` instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape.
    ValueError: if `classifier_activation` is not `softmax` or `None` when
      using a pretrained top layer.
  """
  # `weights` must be one of the two sentinels or an existing file path.
  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.  '
                     'Received: weights={}'.format(weights))

  # The pretrained ImageNet classifier head has exactly 1000 outputs.
  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000.  '
                     'Received: classes={}'.format(classes))

  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=224,
      min_size=32,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    # A raw (non-Keras) tensor is wrapped in an `Input` layer so it can
    # serve as the model's entry point; Keras tensors are used directly.
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  # Block 1
  x = layers.Conv2D(
      64, (3, 3), activation='relu', padding='same', name='block1_conv1')(
          img_input)
  x = layers.Conv2D(
      64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

  # Block 2
  x = layers.Conv2D(
      128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
  x = layers.Conv2D(
      128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

  # Block 3
  x = layers.Conv2D(
      256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
  x = layers.Conv2D(
      256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
  x = layers.Conv2D(
      256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

  # Block 4
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

  # Block 5
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)

  if include_top:
    # Classification block
    x = layers.Flatten(name='flatten')(x)
    x = layers.Dense(4096, activation='relu', name='fc1')(x)
    x = layers.Dense(4096, activation='relu', name='fc2')(x)

    # Rejects activations that are incompatible with the pretrained top
    # layer (see the `Raises` section above).
    imagenet_utils.validate_activation(classifier_activation, weights)
    x = layers.Dense(classes, activation=classifier_activation,
                     name='predictions')(x)
  else:
    # Any other `pooling` value (including `None`) leaves the 4D output of
    # the last convolutional block unchanged.
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D()(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input
  # Create model.
  model = training.Model(inputs, x, name='vgg16')

  # Load weights.
  if weights == 'imagenet':
    # The "notop" file holds only the convolutional base, so the file must
    # match `include_top`.
    if include_top:
      weights_path = data_utils.get_file(
          'vgg16_weights_tf_dim_ordering_tf_kernels.h5',
          WEIGHTS_PATH,
          cache_subdir='models',
          file_hash='64373286793e3c8b2b4e3219cbf3544b')
    else:
      weights_path = data_utils.get_file(
          'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
          WEIGHTS_PATH_NO_TOP,
          cache_subdir='models',
          file_hash='6d6bbae143d832006294945121d1f1fc')
    model.load_weights(weights_path)
  elif weights is not None:
    model.load_weights(weights)

  return model


@keras_export('keras.applications.vgg16.preprocess_input')
def preprocess_input(x, data_format=None):
  # Docstring is assigned at module level below from the shared template.
  return imagenet_utils.preprocess_input(
      x, data_format=data_format, mode='caffe')


@keras_export('keras.applications.vgg16.decode_predictions')
def decode_predictions(preds, top=5):
  # Docstring is copied from `imagenet_utils.decode_predictions` below.
  return imagenet_utils.decode_predictions(preds, top=top)


preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
    mode='',
    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE,
    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC)
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__