1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains definitions for the original form of Residual Networks.
16
17The 'v1' residual networks (ResNets) implemented in this module were proposed
18by:
19[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
20    Deep Residual Learning for Image Recognition. arXiv:1512.03385
21
22Other variants were introduced in:
23[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
24    Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
25
26The networks defined in this module utilize the bottleneck building block of
27[1] with projection shortcuts only for increasing depths. They employ batch
28normalization *after* every weight layer. This is the architecture used by
29MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and
30ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1'
31architecture and the alternative 'v2' architecture of [2] which uses batch
32normalization *before* every weight layer in the so-called full pre-activation
33units.
34
35Typical use:
36
   from tensorflow.contrib.slim.python.slim.nets import resnet_v1
39
40ResNet-101 for image classification into 1000 classes:
41
42   # inputs has shape [batch, 224, 224, 3]
43   with slim.arg_scope(resnet_v1.resnet_arg_scope()):
44      net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False)
45
46ResNet-101 for semantic segmentation into 21 classes:
47
48   # inputs has shape [batch, 513, 513, 3]
49   with slim.arg_scope(resnet_v1.resnet_arg_scope()):
50      net, end_points = resnet_v1.resnet_v1_101(inputs,
51                                                21,
52                                                is_training=False,
53                                                global_pool=False,
54                                                output_stride=16)
55"""
56
57from __future__ import absolute_import
58from __future__ import division
59from __future__ import print_function
60
61from tensorflow.contrib import layers
62from tensorflow.contrib.framework.python.ops import add_arg_scope
63from tensorflow.contrib.framework.python.ops import arg_scope
64from tensorflow.contrib.layers.python.layers import layers as layers_lib
65from tensorflow.contrib.layers.python.layers import utils
66from tensorflow.contrib.slim.python.slim.nets import resnet_utils
67from tensorflow.python.ops import math_ops
68from tensorflow.python.ops import nn_ops
69from tensorflow.python.ops import variable_scope
70
71resnet_arg_scope = resnet_utils.resnet_arg_scope
72
73
@add_arg_scope
def bottleneck(inputs,
               depth,
               depth_bottleneck,
               stride,
               rate=1,
               outputs_collections=None,
               scope=None):
  """Builds a single 'v1' bottleneck residual unit.

  This is the original residual unit of [1] in its bottleneck form; see
  Fig. 1(a) of [2]. The unit adds a shortcut branch to a residual branch made
  of three convolutions (1x1, 3x3, 1x1) and applies a ReLU to the sum.

  No normalizer is configured explicitly in this function; batch normalization
  after the convolutions is expected to come from the enclosing arg_scope.

  When two consecutive ResNet blocks are chained using this unit, the last
  unit of the first block should be given stride = 2.

  Args:
    inputs: A tensor of size [batch, height, width, channels].
    depth: Number of channels produced by the unit.
    depth_bottleneck: Number of channels used by the first two convolutions of
      the residual branch.
    stride: The unit's stride. Controls how much the unit's output is
      downsampled relative to its input.
    rate: An integer, the atrous convolution rate forwarded to the 3x3
      convolution.
    outputs_collections: Collection to which the unit output is added.
    scope: Optional variable_scope; defaults to 'bottleneck_v1'.

  Returns:
    The output of the residual unit.
  """
  with variable_scope.variable_scope(
      scope, 'bottleneck_v1', [inputs]) as unit_scope:
    input_depth = utils.last_dimension(inputs.get_shape(), min_rank=4)

    # Shortcut branch: a 1x1 projection is used only when the channel count
    # changes; otherwise the input is handed to subsample with the stride.
    if depth != input_depth:
      shortcut = layers.conv2d(
          inputs, depth, [1, 1],
          stride=stride, activation_fn=None, scope='shortcut')
    else:
      shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')

    # Residual branch: 1x1 conv to depth_bottleneck, 3x3 conv that carries the
    # stride and atrous rate, then 1x1 conv to depth without an activation.
    branch = layers.conv2d(
        inputs, depth_bottleneck, [1, 1], stride=1, scope='conv1')
    branch = resnet_utils.conv2d_same(
        branch, depth_bottleneck, 3, stride, rate=rate, scope='conv2')
    branch = layers.conv2d(
        branch, depth, [1, 1], stride=1, activation_fn=None, scope='conv3')

    summed = shortcut + branch
    unit_output = nn_ops.relu(summed)

    return utils.collect_named_outputs(
        outputs_collections, unit_scope.name, unit_output)
126
127
def resnet_v1(inputs,
              blocks,
              num_classes=None,
              is_training=True,
              global_pool=True,
              output_stride=None,
              include_root_block=True,
              reuse=None,
              scope=None):
  """Generator for v1 ResNet models.

  This function generates a family of ResNet v1 models. See the resnet_v1_*()
  methods for specific model instantiations, obtained by selecting different
  block instantiations that produce ResNets of various depths.

  Training for image classification on Imagenet is usually done with [224, 224]
  inputs, resulting in [7, 7] feature maps at the output of the last ResNet
  block for the ResNets defined in [1] that have nominal stride equal to 32.
  However, for dense prediction tasks we advise that one uses inputs with
  spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
  this case the feature maps at the ResNet output will have spatial shape
  [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
  and corners exactly aligned with the input image corners, which greatly
  facilitates alignment of the features to the image. Using as input [225, 225]
  images results in [8, 8] feature maps at the output of the last ResNet block.

  For dense prediction tasks, the ResNet needs to run in fully-convolutional
  (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
  have nominal stride equal to 32 and a good choice in FCN mode is to use
  output_stride=16 in order to increase the density of the computed features at
  small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.

  Args:
    inputs: A tensor of size [batch, height_in, width_in, channels].
    blocks: A list of length equal to the number of ResNet blocks. Each element
      is a resnet_utils.Block object describing the units in the block.
    num_classes: Number of predicted classes for classification tasks. If None
      we return the features before the logit layer.
    is_training: whether batch_norm layers are in training mode.
    global_pool: If True, we perform global average pooling before computing the
      logits. Set to True for image classification, False for dense prediction.
    output_stride: If None, then the output will be computed at the nominal
      network stride. If output_stride is not None, it specifies the requested
      ratio of input to output spatial resolution. When include_root_block is
      True it must be a multiple of 4, because the root convolution and the
      max-pooling layer each use stride 2.
    include_root_block: If True, include the initial convolution followed by
      max-pooling, if False excludes it.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse 'scope' must be given.
    scope: Optional variable_scope.

  Returns:
    net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
      If global_pool is False, then height_out and width_out are reduced by a
      factor of output_stride compared to the respective height_in and width_in,
      else both height_out and width_out equal one. If num_classes is None, then
      net is the output of the last ResNet block, potentially after global
      average pooling. If num_classes is not None, net contains the pre-softmax
      activations.
    end_points: A dictionary from components of the network to the corresponding
      activation. When num_classes is not None it also holds the softmax of net
      under the key 'predictions'.

  Raises:
    ValueError: If the target output_stride is not valid; in particular, if
      include_root_block is True and output_stride is not a multiple of 4.
  """
  with variable_scope.variable_scope(
      scope, 'resnet_v1', [inputs], reuse=reuse) as sc:
    end_points_collection = sc.original_name_scope + '_end_points'
    with arg_scope(
        [layers.conv2d, bottleneck, resnet_utils.stack_blocks_dense],
        outputs_collections=end_points_collection):
      with arg_scope([layers.batch_norm], is_training=is_training):
        net = inputs
        if include_root_block:
          if output_stride is not None:
            if output_stride % 4 != 0:
              raise ValueError('The output_stride needs to be a multiple of 4.')
            # The root block accounts for a stride of 4 (2 from conv1, 2 from
            # pool1). Floor division keeps an integer stride an integer even
            # though this module enables true division; the value is exact
            # because divisibility by 4 was checked above.
            output_stride //= 4
          net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
          net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope='pool1')
        net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
        if global_pool:
          # Global average pooling.
          net = math_ops.reduce_mean(net, [1, 2], name='pool5', keepdims=True)
        if num_classes is not None:
          # Logits are a 1x1 convolution with neither activation nor normalizer.
          net = layers.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              scope='logits')
        # Convert end_points_collection into a dictionary of end_points.
        end_points = utils.convert_collection_to_dict(end_points_collection)
        if num_classes is not None:
          end_points['predictions'] = layers_lib.softmax(
              net, scope='predictions')
        return net, end_points
resnet_v1.default_image_size = 224
225
226
def resnet_v1_block(scope, base_depth, num_units, stride):
  """Assembles the resnet_utils.Block describing one resnet_v1 stage.

  Every unit in the block is a `bottleneck` unit whose output depth is four
  times `base_depth`. The first `num_units - 1` units use stride 1 and the
  final unit uses the requested `stride`.

  Args:
    scope: The scope of the block.
    base_depth: The depth of the bottleneck layer for each unit.
    num_units: The number of units in the block.
    stride: The stride of the block, applied in the last unit only. All other
      units have stride=1.

  Returns:
    A resnet_v1 bottleneck block.
  """
  output_depth = base_depth * 4
  regular_unit = {
      'depth': output_depth,
      'depth_bottleneck': base_depth,
      'stride': 1,
  }
  # The closing unit differs from the others only in its stride.
  closing_unit = dict(regular_unit, stride=stride)

  unit_args = [regular_unit] * (num_units - 1)
  unit_args.append(closing_unit)
  return resnet_utils.Block(scope, bottleneck, unit_args)
249
250
def resnet_v1_50(inputs,
                 num_classes=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 reuse=None,
                 scope='resnet_v1_50'):
  """ResNet-50 model of [1].

  Uses four bottleneck blocks with 3, 4, 6 and 3 units and always includes the
  root block. See resnet_v1() for arg and return description.
  """
  # One row per block: (scope, base_depth, num_units, stride).
  stage_specs = (
      ('block1', 64, 3, 2),
      ('block2', 128, 4, 2),
      ('block3', 256, 6, 2),
      ('block4', 512, 3, 1),
  )
  blocks = [
      resnet_v1_block(name, base_depth=width, num_units=count, stride=step)
      for name, width, count, step in stage_specs
  ]
  return resnet_v1(
      inputs,
      blocks,
      num_classes=num_classes,
      is_training=is_training,
      global_pool=global_pool,
      output_stride=output_stride,
      include_root_block=True,
      reuse=reuse,
      scope=scope)
275
276
def resnet_v1_101(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  reuse=None,
                  scope='resnet_v1_101'):
  """ResNet-101 model of [1].

  Uses four bottleneck blocks with 3, 4, 23 and 3 units and always includes the
  root block. See resnet_v1() for arg and return description.
  """
  # One row per block: (scope, base_depth, num_units, stride).
  stage_specs = (
      ('block1', 64, 3, 2),
      ('block2', 128, 4, 2),
      ('block3', 256, 23, 2),
      ('block4', 512, 3, 1),
  )
  blocks = [
      resnet_v1_block(name, base_depth=width, num_units=count, stride=step)
      for name, width, count, step in stage_specs
  ]
  return resnet_v1(
      inputs,
      blocks,
      num_classes=num_classes,
      is_training=is_training,
      global_pool=global_pool,
      output_stride=output_stride,
      include_root_block=True,
      reuse=reuse,
      scope=scope)
301
302
def resnet_v1_152(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  reuse=None,
                  scope='resnet_v1_152'):
  """ResNet-152 model of [1].

  Uses four bottleneck blocks with 3, 8, 36 and 3 units and always includes the
  root block. See resnet_v1() for arg and return description.
  """
  # One row per block: (scope, base_depth, num_units, stride).
  stage_specs = (
      ('block1', 64, 3, 2),
      ('block2', 128, 8, 2),
      ('block3', 256, 36, 2),
      ('block4', 512, 3, 1),
  )
  blocks = [
      resnet_v1_block(name, base_depth=width, num_units=count, stride=step)
      for name, width, count, step in stage_specs
  ]
  return resnet_v1(
      inputs,
      blocks,
      num_classes=num_classes,
      is_training=is_training,
      global_pool=global_pool,
      output_stride=output_stride,
      include_root_block=True,
      reuse=reuse,
      scope=scope)
327
328
def resnet_v1_200(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  reuse=None,
                  scope='resnet_v1_200'):
  """ResNet-200 model of [2].

  Uses four bottleneck blocks with 3, 24, 36 and 3 units and always includes
  the root block. See resnet_v1() for arg and return description.
  """
  # One row per block: (scope, base_depth, num_units, stride).
  stage_specs = (
      ('block1', 64, 3, 2),
      ('block2', 128, 24, 2),
      ('block3', 256, 36, 2),
      ('block4', 512, 3, 1),
  )
  blocks = [
      resnet_v1_block(name, base_depth=width, num_units=count, stride=step)
      for name, width, count, step in stage_specs
  ]
  return resnet_v1(
      inputs,
      blocks,
      num_classes=num_classes,
      is_training=is_training,
      global_pool=global_pool,
      output_stride=output_stride,
      include_root_block=True,
      reuse=reuse,
      scope=scope)
353