1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains definitions for the preactivation form of Residual Networks.
16
17Residual networks (ResNets) were originally proposed in:
18[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19    Deep Residual Learning for Image Recognition. arXiv:1512.03385
20
21The full preactivation 'v2' ResNet variant implemented in this module was
22introduced by:
23[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
25
26The key difference of the full preactivation 'v2' variant compared to the
27'v1' variant in [1] is the use of batch normalization before every weight layer.
28
29Typical use:
30
   from tensorflow.contrib.slim.python.slim.nets import resnet_v2
33
34ResNet-101 for image classification into 1000 classes:
35
36   # inputs has shape [batch, 224, 224, 3]
37   with slim.arg_scope(resnet_v2.resnet_arg_scope()):
38      net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False)
39
40ResNet-101 for semantic segmentation into 21 classes:
41
42   # inputs has shape [batch, 513, 513, 3]
43   with slim.arg_scope(resnet_v2.resnet_arg_scope()):
44      net, end_points = resnet_v2.resnet_v2_101(inputs,
45                                                21,
46                                                is_training=False,
47                                                global_pool=False,
48                                                output_stride=16)
49"""
50
51from __future__ import absolute_import
52from __future__ import division
53from __future__ import print_function
54
55from tensorflow.contrib import layers as layers_lib
56from tensorflow.contrib.framework.python.ops import add_arg_scope
57from tensorflow.contrib.framework.python.ops import arg_scope
58from tensorflow.contrib.layers.python.layers import layers
59from tensorflow.contrib.layers.python.layers import utils
60from tensorflow.contrib.slim.python.slim.nets import resnet_utils
61from tensorflow.python.ops import math_ops
62from tensorflow.python.ops import nn_ops
63from tensorflow.python.ops import variable_scope
64
# Re-export resnet_utils.resnet_arg_scope under this module so that callers
# can write `resnet_v2.resnet_arg_scope()`, as in the module docstring examples.
resnet_arg_scope = resnet_utils.resnet_arg_scope
66
67
@add_arg_scope
def bottleneck(inputs,
               depth,
               depth_bottleneck,
               stride,
               rate=1,
               outputs_collections=None,
               scope=None):
  """Full-preactivation bottleneck residual unit (BN + ReLU before convs).

  Implements the residual unit of Fig. 1(b) in [2] in its bottleneck form:
  a 1x1 reduction conv, a 3x3 conv (optionally atrous) and a 1x1 expansion
  conv. The input is batch-normalized and ReLU-activated once ('preact'); that
  activation feeds the residual branch and, when a projection is required,
  the shortcut as well.

  When chaining two consecutive ResNet blocks built from this unit, the last
  unit of the first block should use stride = 2.

  Args:
    inputs: A tensor of size [batch, height, width, channels].
    depth: Number of output channels of the unit.
    depth_bottleneck: Number of channels in the bottleneck convolutions.
    stride: Spatial stride of the unit; controls how much its output is
      downsampled relative to its input.
    rate: An integer, the atrous rate used by the 3x3 convolution.
    outputs_collections: Collection to which the unit output is added.
    scope: Optional variable_scope.

  Returns:
    The output tensor of the residual unit.
  """
  with variable_scope.variable_scope(
      scope, 'bottleneck_v2', [inputs]) as unit_scope:
    in_channels = utils.last_dimension(inputs.get_shape(), min_rank=4)
    activated = layers.batch_norm(
        inputs, activation_fn=nn_ops.relu, scope='preact')

    if depth != in_channels:
      # Channel count changes: project the pre-activated input with a strided
      # 1x1 conv that has neither normalization nor activation.
      shortcut = layers_lib.conv2d(
          activated,
          depth, [1, 1],
          stride=stride,
          normalizer_fn=None,
          activation_fn=None,
          scope='shortcut')
    else:
      # Channel count already matches: the shortcut is the raw (not
      # pre-activated) input, passed through resnet_utils.subsample with
      # `stride`.
      shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')

    # Residual branch: 1x1 reduce -> 3x3 (strided / atrous) -> 1x1 expand.
    reduced = layers_lib.conv2d(
        activated, depth_bottleneck, [1, 1], stride=1, scope='conv1')
    spatial = resnet_utils.conv2d_same(
        reduced, depth_bottleneck, 3, stride, rate=rate, scope='conv2')
    expanded = layers_lib.conv2d(
        spatial,
        depth, [1, 1],
        stride=1,
        normalizer_fn=None,
        activation_fn=None,
        scope='conv3')

    return utils.collect_named_outputs(
        outputs_collections, unit_scope.name, shortcut + expanded)
128
129
def resnet_v2(inputs,
              blocks,
              num_classes=None,
              is_training=True,
              global_pool=True,
              output_stride=None,
              include_root_block=True,
              reuse=None,
              scope=None):
  """Generator for v2 (preactivation) ResNet models.

  This function generates a family of ResNet v2 models. See the resnet_v2_*()
  methods for specific model instantiations, obtained by selecting different
  block instantiations that produce ResNets of various depths.

  Training for image classification on Imagenet is usually done with [224, 224]
  inputs, resulting in [7, 7] feature maps at the output of the last ResNet
  block for the ResNets defined in [1] that have nominal stride equal to 32.
  However, for dense prediction tasks we advise that one uses inputs with
  spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
  this case the feature maps at the ResNet output will have spatial shape
  [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
  and corners exactly aligned with the input image corners, which greatly
  facilitates alignment of the features to the image. Using as input [225, 225]
  images results in [8, 8] feature maps at the output of the last ResNet block.

  For dense prediction tasks, the ResNet needs to run in fully-convolutional
  (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
  have nominal stride equal to 32 and a good choice in FCN mode is to use
  output_stride=16 in order to increase the density of the computed features at
  small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.

  Args:
    inputs: A tensor of size [batch, height_in, width_in, channels].
    blocks: A list of length equal to the number of ResNet blocks. Each element
      is a resnet_utils.Block object describing the units in the block.
    num_classes: Number of predicted classes for classification tasks. If None
      we return the features before the logit layer.
    is_training: whether batch_norm layers are in training mode.
    global_pool: If True, we perform global average pooling before computing the
      logits. Set to True for image classification, False for dense prediction.
    output_stride: If None, then the output will be computed at the nominal
      network stride. If output_stride is not None, it specifies the requested
      ratio of input to output spatial resolution. When include_root_block is
      True it must be a multiple of 4, because the root block (stride-2 conv1
      followed by stride-2 pool1) already downsamples by 4.
    include_root_block: If True, include the initial convolution followed by
      max-pooling, if False excludes it. If excluded, `inputs` should be the
      results of an activation-less convolution.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse 'scope' must be given.
    scope: Optional variable_scope.

  Returns:
    net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
      If global_pool is False, then height_out and width_out are reduced by a
      factor of output_stride compared to the respective height_in and width_in,
      else both height_out and width_out equal one. If num_classes is None, then
      net is the output of the last ResNet block, potentially after global
      average pooling. If num_classes is not None, net contains the pre-softmax
      activations.
    end_points: A dictionary from components of the network to the corresponding
      activation.

  Raises:
    ValueError: If the target output_stride is not valid.
  """
  with variable_scope.variable_scope(
      scope, 'resnet_v2', [inputs], reuse=reuse) as sc:
    end_points_collection = sc.original_name_scope + '_end_points'
    with arg_scope(
        [layers_lib.conv2d, bottleneck, resnet_utils.stack_blocks_dense],
        outputs_collections=end_points_collection):
      with arg_scope([layers.batch_norm], is_training=is_training):
        net = inputs
        if include_root_block:
          if output_stride is not None:
            if output_stride % 4 != 0:
              raise ValueError('The output_stride needs to be a multiple of 4.')
            # The root block downsamples by 4, so the residual blocks only need
            # to supply the remaining factor. Floor division keeps the stride
            # an integer: under `from __future__ import division`, `/=` would
            # silently turn it into a float (e.g. 16 -> 4.0).
            output_stride //= 4
          # We do not include batch normalization or activation functions in
          # conv1 because the first ResNet unit will perform these. Cf.
          # Appendix of [2].
          with arg_scope(
              [layers_lib.conv2d], activation_fn=None, normalizer_fn=None):
            net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
          net = layers.max_pool2d(net, [3, 3], stride=2, scope='pool1')
        net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
        # This is needed because the pre-activation variant does not have batch
        # normalization or activation functions in the residual unit output. See
        # Appendix of [2].
        net = layers.batch_norm(
            net, activation_fn=nn_ops.relu, scope='postnorm')
        if global_pool:
          # Global average pooling over the spatial axes, keeping rank 4.
          net = math_ops.reduce_mean(net, [1, 2], name='pool5', keepdims=True)
        if num_classes is not None:
          net = layers_lib.conv2d(
              net,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              scope='logits')
        # Convert end_points_collection into a dictionary of end_points.
        end_points = utils.convert_collection_to_dict(end_points_collection)
        if num_classes is not None:
          end_points['predictions'] = layers.softmax(net, scope='predictions')
        return net, end_points
resnet_v2.default_image_size = 224
238
239
def resnet_v2_block(scope, base_depth, num_units, stride):
  """Builds a resnet_v2 block made of bottleneck units.

  Every unit outputs `base_depth * 4` channels around a `base_depth`-channel
  bottleneck. Only the final unit uses `stride`; the preceding
  `num_units - 1` units use stride 1 (so `num_units <= 1` yields just the
  final unit).

  Args:
    scope: The scope of the block.
    base_depth: The depth of the bottleneck layer for each unit.
    num_units: The number of units in the block.
    stride: The stride of the block, applied by its last unit only.

  Returns:
    A resnet_utils.Block whose unit function is `bottleneck`.
  """
  plain_unit = {
      'depth': base_depth * 4,
      'depth_bottleneck': base_depth,
      'stride': 1,
  }
  # The final unit differs from the others only in its stride.
  final_unit = dict(plain_unit, stride=stride)
  # NOTE: list repetition reuses the same dict object for every non-final unit.
  unit_args = [plain_unit] * (num_units - 1)
  unit_args.append(final_unit)
  return resnet_utils.Block(scope, bottleneck, unit_args)
262
263
def resnet_v2_50(inputs,
                 num_classes=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 reuse=None,
                 scope='resnet_v2_50'):
  """ResNet-50 model of [1]. See resnet_v2() for arg and return description.

  Uses (3, 4, 6, 3) bottleneck units in block1..block4.
  """
  # (base_depth, num_units, stride) for block1..block4, in order.
  stage_specs = ((64, 3, 2), (128, 4, 2), (256, 6, 2), (512, 3, 1))
  blocks = []
  for index, (base_depth, num_units, stride) in enumerate(stage_specs, 1):
    blocks.append(
        resnet_v2_block(
            'block%d' % index,
            base_depth=base_depth,
            num_units=num_units,
            stride=stride))
  return resnet_v2(
      inputs,
      blocks,
      num_classes=num_classes,
      is_training=is_training,
      global_pool=global_pool,
      output_stride=output_stride,
      include_root_block=True,
      reuse=reuse,
      scope=scope)
288
289
def resnet_v2_101(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  reuse=None,
                  scope='resnet_v2_101'):
  """ResNet-101 model of [1]. See resnet_v2() for arg and return description.

  Uses (3, 4, 23, 3) bottleneck units in block1..block4.
  """
  # (base_depth, num_units, stride) for block1..block4, in order.
  stage_specs = ((64, 3, 2), (128, 4, 2), (256, 23, 2), (512, 3, 1))
  blocks = []
  for index, (base_depth, num_units, stride) in enumerate(stage_specs, 1):
    blocks.append(
        resnet_v2_block(
            'block%d' % index,
            base_depth=base_depth,
            num_units=num_units,
            stride=stride))
  return resnet_v2(
      inputs,
      blocks,
      num_classes=num_classes,
      is_training=is_training,
      global_pool=global_pool,
      output_stride=output_stride,
      include_root_block=True,
      reuse=reuse,
      scope=scope)
314
315
def resnet_v2_152(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  reuse=None,
                  scope='resnet_v2_152'):
  """ResNet-152 model of [1]. See resnet_v2() for arg and return description.

  Uses (3, 8, 36, 3) bottleneck units in block1..block4.
  """
  # (base_depth, num_units, stride) for block1..block4, in order.
  stage_specs = ((64, 3, 2), (128, 8, 2), (256, 36, 2), (512, 3, 1))
  blocks = []
  for index, (base_depth, num_units, stride) in enumerate(stage_specs, 1):
    blocks.append(
        resnet_v2_block(
            'block%d' % index,
            base_depth=base_depth,
            num_units=num_units,
            stride=stride))
  return resnet_v2(
      inputs,
      blocks,
      num_classes=num_classes,
      is_training=is_training,
      global_pool=global_pool,
      output_stride=output_stride,
      include_root_block=True,
      reuse=reuse,
      scope=scope)
340
341
def resnet_v2_200(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  reuse=None,
                  scope='resnet_v2_200'):
  """ResNet-200 model of [2]. See resnet_v2() for arg and return description.

  Uses (3, 24, 36, 3) bottleneck units in block1..block4.
  """
  # (base_depth, num_units, stride) for block1..block4, in order.
  stage_specs = ((64, 3, 2), (128, 24, 2), (256, 36, 2), (512, 3, 1))
  blocks = []
  for index, (base_depth, num_units, stride) in enumerate(stage_specs, 1):
    blocks.append(
        resnet_v2_block(
            'block%d' % index,
            base_depth=base_depth,
            num_units=num_units,
            stride=stride))
  return resnet_v2(
      inputs,
      blocks,
      num_classes=num_classes,
      is_training=is_training,
      global_pool=global_pool,
      output_stride=output_stride,
      include_root_block=True,
      reuse=reuse,
      scope=scope)
366