1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""DenseNet models for Keras.
17
18Reference:
19  - [Densely Connected Convolutional Networks](
20      https://arxiv.org/abs/1608.06993) (CVPR 2017)
21"""
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26from tensorflow.python.keras import backend
27from tensorflow.python.keras.applications import imagenet_utils
28from tensorflow.python.keras.engine import training
29from tensorflow.python.keras.layers import VersionAwareLayers
30from tensorflow.python.keras.utils import data_utils
31from tensorflow.python.keras.utils import layer_utils
32from tensorflow.python.lib.io import file_io
33from tensorflow.python.util.tf_export import keras_export
34
35
# Common URL prefix under which all pre-trained DenseNet weight files live.
BASE_WEIGHTS_PATH = (
    'https://storage.googleapis.com/tensorflow/keras-applications/densenet/')

# One pair of URLs per published variant: the plain file is used when the
# classification head is included, the `_notop` file when it is not.
DENSENET121_WEIGHT_PATH = BASE_WEIGHTS_PATH + (
    'densenet121_weights_tf_dim_ordering_tf_kernels.h5')
DENSENET121_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + (
    'densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5')

DENSENET169_WEIGHT_PATH = BASE_WEIGHTS_PATH + (
    'densenet169_weights_tf_dim_ordering_tf_kernels.h5')
DENSENET169_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + (
    'densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5')

DENSENET201_WEIGHT_PATH = BASE_WEIGHTS_PATH + (
    'densenet201_weights_tf_dim_ordering_tf_kernels.h5')
DENSENET201_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + (
    'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5')
53
# Layer namespace used by every builder function in this module
# (`layers.Conv2D`, `layers.BatchNormalization`, ...).
# NOTE(review): presumably `VersionAwareLayers` resolves each attribute to the
# layer implementation matching the active Keras/TF version -- confirm.
layers = VersionAwareLayers()
55
56
def dense_block(x, blocks, name, growth_rate=32):
  """Stacks `blocks` densely connected building blocks on top of `x`.

  Each building block is created by `conv_block` and is named
  `<name>_block<i>`, with `i` counting from 1.

  Args:
    x: input tensor.
    blocks: integer, the number of building blocks.
    name: string, block label.
    growth_rate: growth rate handed to every `conv_block`. Defaults to 32,
      the value that was previously hard-coded here.

  Returns:
    Output tensor for the block. When `blocks` is 0 the input is returned
    unchanged.
  """
  for index in range(1, blocks + 1):
    x = conv_block(x, growth_rate, name=name + '_block' + str(index))
  return x
71
72
def transition_block(x, reduction, name):
  """Builds a transition block: BN -> ReLU -> 1x1 conv -> 2x2 average pool.

  The 1x1 convolution outputs `int(channels * reduction)` filters, where
  `channels` is the size of the channel axis of its input.

  Args:
    x: input tensor.
    reduction: float, compression rate at transition layers.
    name: string, block label.

  Returns:
    output tensor for the block.
  """
  # Channel axis depends on the globally configured image data format.
  channels_last = backend.image_data_format() == 'channels_last'
  channel_axis = 3 if channels_last else 1

  normalized = layers.BatchNormalization(
      axis=channel_axis, epsilon=1.001e-5, name=name + '_bn')(x)
  activated = layers.Activation('relu', name=name + '_relu')(normalized)

  # Compress the channel count by `reduction` (truncated to an integer).
  compressed_filters = int(
      backend.int_shape(activated)[channel_axis] * reduction)
  compressed = layers.Conv2D(
      compressed_filters, 1, use_bias=False, name=name + '_conv')(activated)

  return layers.AveragePooling2D(
      2, strides=2, name=name + '_pool')(compressed)
97
98
def conv_block(x, growth_rate, name):
  """Builds one building block of a dense block.

  The input goes through BN -> ReLU -> 1x1 conv (`4 * growth_rate` filters)
  -> BN -> ReLU -> 3x3 'same' conv (`growth_rate` filters); the result is
  concatenated with the input along the channel axis.

  Args:
    x: input tensor.
    growth_rate: growth rate at dense layers; used as the filter count of the
      3x3 convolution.
    name: string, block label.

  Returns:
    Output tensor for the block.
  """
  channels_last = backend.image_data_format() == 'channels_last'
  channel_axis = 3 if channels_last else 1

  def bn_relu(tensor, tag):
    # Batch normalization followed by ReLU, named `<name><tag>_bn/_relu`.
    normalized = layers.BatchNormalization(
        axis=channel_axis, epsilon=1.001e-5, name=name + tag + '_bn')(tensor)
    return layers.Activation('relu', name=name + tag + '_relu')(normalized)

  # 1x1 bottleneck convolution.
  hidden = bn_relu(x, '_0')
  hidden = layers.Conv2D(
      4 * growth_rate, 1, use_bias=False, name=name + '_1_conv')(hidden)

  # 3x3 convolution producing the new feature maps.
  hidden = bn_relu(hidden, '_1')
  new_features = layers.Conv2D(
      growth_rate, 3, padding='same', use_bias=False,
      name=name + '_2_conv')(hidden)

  return layers.Concatenate(
      axis=channel_axis, name=name + '_concat')([x, new_features])
127
128
def _densenet_variant(blocks):
  """Returns the canonical model name for `blocks`, or `None` if non-standard.

  Args:
    blocks: sequence (list or tuple) of per-stage building-block counts.

  Returns:
    `'densenet121'`, `'densenet169'` or `'densenet201'` when `blocks` matches
    one of the three configurations that have ImageNet weights, else `None`.
  """
  known_variants = {
      (6, 12, 24, 16): 'densenet121',
      (6, 12, 32, 32): 'densenet169',
      (6, 12, 48, 32): 'densenet201',
  }
  # Compare as a tuple so that lists and tuples are treated alike.
  return known_variants.get(tuple(blocks))


def _fetch_imagenet_weights(variant, include_top):
  """Downloads (or reuses the cached) ImageNet weights file for `variant`.

  Args:
    variant: one of the names returned by `_densenet_variant`.
    include_top: whether the weights for the full model (with the
      classification layer) or the `notop` weights are wanted.

  Returns:
    The value returned by `data_utils.get_file` for the weights file.
  """
  # (weights URL, expected file hash), keyed by (variant, include_top).
  sources = {
      ('densenet121', True):
          (DENSENET121_WEIGHT_PATH, '9d60b8095a5708f2dcce2bca79d332c7'),
      ('densenet169', True):
          (DENSENET169_WEIGHT_PATH, 'd699b8f76981ab1b30698df4c175e90b'),
      ('densenet201', True):
          (DENSENET201_WEIGHT_PATH, '1ceb130c1ea1b78c3bf6114dbdfd8807'),
      ('densenet121', False):
          (DENSENET121_WEIGHT_PATH_NO_TOP, '30ee3e1110167f948a6b9946edeeb738'),
      ('densenet169', False):
          (DENSENET169_WEIGHT_PATH_NO_TOP, 'b8c4d4c20dd625c148057b9ff1c1176b'),
      ('densenet201', False):
          (DENSENET201_WEIGHT_PATH_NO_TOP, 'c13680b51ded0fb44dff2d8f86ac8bb1'),
  }
  origin, file_hash = sources[(variant, bool(include_top))]
  suffix = '' if include_top else '_notop'
  file_name = variant + '_weights_tf_dim_ordering_tf_kernels' + suffix + '.h5'
  return data_utils.get_file(
      file_name, origin, cache_subdir='models', file_hash=file_hash)


def DenseNet(
    blocks,
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the DenseNet architecture.

  Reference:
  - [Densely Connected Convolutional Networks](
      https://arxiv.org/abs/1608.06993) (CVPR 2017)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For DenseNet, call `tf.keras.applications.densenet.preprocess_input` on your
  inputs before passing them to the model.

  Args:
    blocks: list or tuple with the numbers of building blocks for the four
      dense blocks. ImageNet weights are only available for
      `[6, 12, 24, 16]`, `[6, 12, 32, 32]` and `[6, 12, 48, 32]`.
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
      (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(224, 224, 3)` (with `'channels_last'` data format)
      or `(3, 224, 224)` (with `'channels_first'` data format)).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 32.
      E.g. `(200, 200, 3)` would be one valid value.
    pooling: optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True, and
      if no `weights` argument is specified.
    classifier_activation: A `str` or callable. The activation function to use
      on the "top" layer. Ignored unless `include_top=True`. Set
      `classifier_activation=None` to return the logits of the "top" layer.

  Returns:
    A `keras.Model` instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape.
    ValueError: if `weights='imagenet'` is requested for a `blocks`
      configuration that has no pre-trained weights.
    ValueError: if `classifier_activation` is not `softmax` or `None` when
      using a pretrained top layer.
  """
  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Fail fast: without this check an unsupported configuration would only
  # surface as an unbound `weights_path` after the whole graph was built.
  variant = _densenet_variant(blocks)
  if weights == 'imagenet' and variant is None:
    raise ValueError('ImageNet weights are only available for `blocks` equal '
                     'to [6, 12, 24, 16], [6, 12, 32, 32] or '
                     '[6, 12, 48, 32]; got `blocks`=%s. Use `weights=None` '
                     'or pass the path to a weights file instead.' %
                     (list(blocks),))

  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=224,
      min_size=32,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  elif not backend.is_keras_tensor(input_tensor):
    # Wrap a raw tensor so that it can serve as the model input.
    img_input = layers.Input(tensor=input_tensor, shape=input_shape)
  else:
    img_input = input_tensor

  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

  # Stem: 7x7/2 convolution followed by 3x3/2 max pooling.
  x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)
  x = layers.Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)
  x = layers.BatchNormalization(
      axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(
          x)
  x = layers.Activation('relu', name='conv1/relu')(x)
  x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
  x = layers.MaxPooling2D(3, strides=2, name='pool1')(x)

  # Four dense blocks separated by transition blocks (compression 0.5).
  x = dense_block(x, blocks[0], name='conv2')
  x = transition_block(x, 0.5, name='pool2')
  x = dense_block(x, blocks[1], name='conv3')
  x = transition_block(x, 0.5, name='pool3')
  x = dense_block(x, blocks[2], name='conv4')
  x = transition_block(x, 0.5, name='pool4')
  x = dense_block(x, blocks[3], name='conv5')

  x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)
  x = layers.Activation('relu', name='relu')(x)

  if include_top:
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)

    imagenet_utils.validate_activation(classifier_activation, weights)
    x = layers.Dense(classes, activation=classifier_activation,
                     name='predictions')(x)
  else:
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D(name='max_pool')(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input

  # Create model; non-standard configurations get the generic name.
  model = training.Model(inputs, x, name=variant or 'densenet')

  # Load weights.
  if weights == 'imagenet':
    model.load_weights(_fetch_imagenet_weights(variant, include_top))
  elif weights is not None:
    model.load_weights(weights)

  return model
319
320
@keras_export('keras.applications.densenet.DenseNet121',
              'keras.applications.DenseNet121')
def DenseNet121(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                classifier_activation='softmax'):
  """Instantiates the Densenet121 architecture.

  Besides the arguments documented below, `classifier_activation` (a `str` or
  callable, default `'softmax'`) is forwarded to `DenseNet` and sets the
  activation of the "top" layer. It is ignored unless `include_top=True`;
  pass `None` to get the logits of the "top" layer.
  """
  # Forward by keyword so the call cannot drift out of sync with the
  # parameter order of `DenseNet`.
  return DenseNet(
      [6, 12, 24, 16],
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation)
332
333
@keras_export('keras.applications.densenet.DenseNet169',
              'keras.applications.DenseNet169')
def DenseNet169(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                classifier_activation='softmax'):
  """Instantiates the Densenet169 architecture.

  Besides the arguments documented below, `classifier_activation` (a `str` or
  callable, default `'softmax'`) is forwarded to `DenseNet` and sets the
  activation of the "top" layer. It is ignored unless `include_top=True`;
  pass `None` to get the logits of the "top" layer.
  """
  # Forward by keyword so the call cannot drift out of sync with the
  # parameter order of `DenseNet`.
  return DenseNet(
      [6, 12, 32, 32],
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation)
345
346
@keras_export('keras.applications.densenet.DenseNet201',
              'keras.applications.DenseNet201')
def DenseNet201(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                classifier_activation='softmax'):
  """Instantiates the Densenet201 architecture.

  Besides the arguments documented below, `classifier_activation` (a `str` or
  callable, default `'softmax'`) is forwarded to `DenseNet` and sets the
  activation of the "top" layer. It is ignored unless `include_top=True`;
  pass `None` to get the logits of the "top" layer.
  """
  # Forward by keyword so the call cannot drift out of sync with the
  # parameter order of `DenseNet`.
  return DenseNet(
      [6, 12, 48, 32],
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation)
358
359
@keras_export('keras.applications.densenet.preprocess_input')
def preprocess_input(x, data_format=None):
  # The public docstring is assigned to `__doc__` at module level, so only a
  # comment is used here. Delegates to the shared ImageNet preprocessing
  # helper with the mode fixed to 'torch'.
  options = dict(data_format=data_format, mode='torch')
  return imagenet_utils.preprocess_input(x, **options)
364
365
@keras_export('keras.applications.densenet.decode_predictions')
def decode_predictions(preds, top=5):
  # The public docstring is copied from `imagenet_utils.decode_predictions`
  # at module level; this is a plain pass-through to that helper.
  decoded = imagenet_utils.decode_predictions(preds, top=top)
  return decoded
369
370
# Build the public docstrings from the shared ImageNet templates: the
# preprocessing doc is rendered with an empty `mode` section and the
# torch-mode return description.
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,
    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TORCH,
    mode='')
# `decode_predictions` reuses the helper's docstring verbatim.
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
376
# Documentation shared by the DenseNet121/169/201 constructors. It is appended
# to each function's own summary docstring at the bottom of this module, so it
# deliberately starts with blank lines and uses the docstring indentation.
DOC = """

  Reference:
  - [Densely Connected Convolutional Networks](
      https://arxiv.org/abs/1608.06993) (CVPR 2017)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For DenseNet, call `tf.keras.applications.densenet.preprocess_input` on your
  inputs before passing them to the model.

  Args:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(224, 224, 3)` (with `'channels_last'` data format)
      or `(3, 224, 224)` (with `'channels_first'` data format).
      It should have exactly 3 inputs channels,
      and width and height should be no smaller than 32.
      E.g. `(200, 200, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True, and
      if no `weights` argument is specified.

  Returns:
    A Keras model instance.
"""
424
# Extend each constructor's one-line summary with the shared argument and
# return documentation held in `DOC`.
DenseNet121.__doc__ = DenseNet121.__doc__ + DOC
DenseNet169.__doc__ = DenseNet169.__doc__ + DOC
DenseNet201.__doc__ = DenseNet201.__doc__ + DOC
428