• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains building blocks for various versions of Residual Networks.
16
17Residual networks (ResNets) were proposed in:
18  Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19  Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015
20
21More variants were introduced in:
22  Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
23  Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016
24
25We can obtain different ResNet variants by changing the network depth, width,
26and form of residual unit. This module implements the infrastructure for
27building them. Concrete ResNet units and full ResNet networks are implemented in
28the accompanying resnet_v1.py and resnet_v2.py modules.
29
30Compared to https://github.com/KaimingHe/deep-residual-networks, in the current
31implementation we subsample the output activations in the last residual unit of
32each block, instead of subsampling the input activations in the first residual
33unit of each block. The two implementations give identical results but our
34implementation is more memory efficient.
35"""
36
37from __future__ import absolute_import
38from __future__ import division
39from __future__ import print_function
40
41import collections
42
43from tensorflow.contrib import layers as layers_lib
44from tensorflow.contrib.framework.python.ops import add_arg_scope
45from tensorflow.contrib.framework.python.ops import arg_scope
46from tensorflow.contrib.layers.python.layers import initializers
47from tensorflow.contrib.layers.python.layers import layers
48from tensorflow.contrib.layers.python.layers import regularizers
49from tensorflow.contrib.layers.python.layers import utils
50from tensorflow.python.framework import ops
51from tensorflow.python.ops import array_ops
52from tensorflow.python.ops import nn_ops
53from tensorflow.python.ops import variable_scope
54
55
class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
  """A named tuple describing a ResNet block.

  Its parts are:
    scope: The scope of the `Block`.
    unit_fn: The ResNet unit function which takes as input a `Tensor` and
      returns another `Tensor` with the output of the ResNet unit. It is
      invoked as `unit_fn(net, rate=..., **unit_args)`, so it must accept a
      `rate` keyword (the atrous rate) in addition to the per-unit arguments.
    args: A list of length equal to the number of units in the `Block`. The list
      contains one dict of keyword arguments for each unit in the block (e.g.
      {'depth': ..., 'depth_bottleneck': ..., 'stride': ...}); each dict is
      expanded into keyword arguments of `unit_fn`. A unit without a 'stride'
      entry is treated as having stride 1.
  """
67
68
def subsample(inputs, factor, scope=None):
  """Spatially subsamples `inputs` by an integer factor.

  The subsampling is done with a 1x1 max-pool of stride `factor`. With a 1x1
  window no values are combined, so the result simply keeps every
  `factor`-th row and column of the input.

  Args:
    inputs: A `Tensor` of size [batch, height_in, width_in, channels].
    factor: The subsampling factor. A factor of 1 returns `inputs` unchanged.
    scope: Optional variable_scope for the pooling op.

  Returns:
    output: A `Tensor` of size [batch, height_out, width_out, channels]. It is
      `inputs` itself when factor == 1, otherwise the subsampled tensor.
  """
  if factor != 1:
    return layers.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
  # Nothing to do: hand back the very same object.
  return inputs
85
86
def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
  """Strided 2-D convolution with 'SAME' padding.

  For stride == 1 this is a plain 'SAME' convolution. For stride > 1, the
  spatial dimensions are first zero-padded explicitly and then convolved with
  'VALID' padding. The total padding is computed from the dilated (atrous)
  kernel footprint and split so that any odd remainder goes to the end. As a
  result,

     net = conv2d_same(inputs, num_outputs, 3, stride=stride)

  is equivalent to

     net = tf.contrib.layers.conv2d(inputs, num_outputs, 3, stride=1,
                                    padding='SAME')
     net = subsample(net, factor=stride)

  whereas

     net = tf.contrib.layers.conv2d(inputs, num_outputs, 3, stride=stride,
                                    padding='SAME')

  gives a different result when the input's height or width is even. That
  difference is why this function exists. For more details, see
  ResnetUtilsTest.testConv2DSameEven().

  Args:
    inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
    num_outputs: An integer, the number of output filters.
    kernel_size: An int with the kernel_size of the filters.
    stride: An integer, the output stride.
    rate: An integer, rate for atrous convolution.
    scope: Scope.

  Returns:
    output: A 4-D tensor of size [batch, height_out, width_out, channels] with
      the convolution output.
  """
  if stride != 1:
    # Spatial extent covered by the kernel once dilated by `rate`.
    dilated_kernel = kernel_size + (kernel_size - 1) * (rate - 1)
    total_pad = dilated_kernel - 1
    pad_before = total_pad // 2
    pad_after = total_pad - pad_before
    # Pad height and width only; batch and channel axes are left untouched.
    inputs = array_ops.pad(
        inputs,
        [[0, 0], [pad_before, pad_after], [pad_before, pad_after], [0, 0]])
    conv_stride = stride
    conv_padding = 'VALID'
  else:
    conv_stride = 1
    conv_padding = 'SAME'

  return layers_lib.conv2d(
      inputs,
      num_outputs,
      kernel_size,
      stride=conv_stride,
      rate=rate,
      padding=conv_padding,
      scope=scope)
147
148
@add_arg_scope
def stack_blocks_dense(net,
                       blocks,
                       output_stride=None,
                       outputs_collections=None):
  """Stacks ResNet `Blocks` and controls output feature density.

  First, this function creates scopes for the ResNet in the form of
  'block_name/unit_1', 'block_name/unit_2', etc.

  Second, this function allows the user to explicitly control the ResNet
  output_stride, which is the ratio of the input to output spatial resolution.
  This is useful for dense prediction tasks such as semantic segmentation or
  object detection.

  Most ResNets consist of 4 ResNet blocks and subsample the activations by a
  factor of 2 when transitioning between consecutive ResNet blocks. This results
  in a nominal ResNet output_stride equal to 8. If we set the output_stride to
  half the nominal network stride (e.g., output_stride=4), then we compute
  responses at twice the spatial density.

  Control of the output feature density is implemented by atrous convolution.
  Once the accumulated stride equals the requested output_stride, every
  remaining unit is run with stride 1. Its nominal stride is then multiplied
  into the atrous rate that is passed to the units that follow.

  Args:
    net: A `Tensor` of size [batch, height, width, channels].
    blocks: A list of length equal to the number of ResNet `Blocks`. Each
      element is a ResNet `Block` object describing the units in the `Block`.
    output_stride: If `None`, then the output will be computed at the nominal
      network stride. If output_stride is not `None`, it specifies the requested
      ratio of input to output spatial resolution, which needs to be equal to
      the product of unit strides from the start up to some level of the ResNet.
      For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1,
      then valid values for the output_stride are 1, 2, 6, 24 or None (which
      is equivalent to output_stride=24).
    outputs_collections: Collection to add the ResNet block outputs.

  Returns:
    net: Output tensor with stride equal to the specified output_stride.

  Raises:
    ValueError: If the target output_stride is not valid.
  """
  # Product of the strides applied so far, i.e. the effective stride of `net`.
  current_stride = 1
  # Atrous rate handed to units that run after the target stride was reached.
  rate = 1

  for block in blocks:
    with variable_scope.variable_scope(block.scope, 'block', [net]) as sc:
      for unit_index, unit in enumerate(block.args):
        # Overshooting the target means it lies between two unit strides.
        if output_stride is not None and current_stride > output_stride:
          raise ValueError('The target output_stride cannot be reached.')

        unit_scope_name = 'unit_%d' % (unit_index + 1)
        with variable_scope.variable_scope(unit_scope_name, values=[net]):
          if output_stride is None or current_stride != output_stride:
            # No target, or target not reached yet: apply the unit's stride.
            net = block.unit_fn(net, rate=1, **unit)
            current_stride *= unit.get('stride', 1)
          else:
            # Target reached: keep resolution and dilate later units instead.
            net = block.unit_fn(net, rate=rate, **dict(unit, stride=1))
            rate *= unit.get('stride', 1)
      net = utils.collect_named_outputs(outputs_collections, sc.name, net)

  # The product of all strides may still fall short of the requested value.
  if output_stride is not None and current_stride != output_stride:
    raise ValueError('The target output_stride cannot be reached.')

  return net
223
224
def resnet_arg_scope(weight_decay=0.0001,
                     batch_norm_decay=0.997,
                     batch_norm_epsilon=1e-5,
                     batch_norm_scale=True):
  """Builds the default `arg_scope` for the ResNet models.

  Inside the returned scope, `conv2d` layers use an L2 weight regularizer, a
  variance-scaling initializer, ReLU activation and batch normalization.
  `batch_norm` is configured with the same parameters, and `max_pool2d`
  defaults to 'SAME' padding.

  TODO(gpapan): The batch-normalization related default values above are
    appropriate for use in conjunction with the reference ResNet models
    released at https://github.com/KaimingHe/deep-residual-networks. When
    training ResNets from scratch, they might need to be tuned.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_decay: The moving average decay when estimating layer activation
      statistics in batch normalization.
    batch_norm_epsilon: Small constant to prevent division by zero when
      normalizing activations by their variance in batch normalization.
    batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
      activations in the batch normalization layer.

  Returns:
    An `arg_scope` to use for the resnet models.
  """
  batch_norm_params = {
      'decay': batch_norm_decay,
      'epsilon': batch_norm_epsilon,
      'scale': batch_norm_scale,
      'updates_collections': ops.GraphKeys.UPDATE_OPS,
  }
  # Defaults applied to every conv2d created inside the returned scope.
  conv_defaults = dict(
      weights_regularizer=regularizers.l2_regularizer(weight_decay),
      weights_initializer=initializers.variance_scaling_initializer(),
      activation_fn=nn_ops.relu,
      normalizer_fn=layers.batch_norm,
      normalizer_params=batch_norm_params)

  with arg_scope([layers_lib.conv2d], **conv_defaults):
    with arg_scope([layers.batch_norm], **batch_norm_params):
      # The following implies padding='SAME' for pool1, which makes feature
      # alignment easier for dense prediction tasks. This is also used in
      # https://github.com/facebook/fb.resnet.torch. However the accompanying
      # code of 'Deep Residual Learning for Image Recognition' uses
      # padding='VALID' for pool1. You can switch to that choice by setting
      # tf.contrib.framework.arg_scope([tf.contrib.layers.max_pool2d],
      #                                padding='VALID').
      with arg_scope([layers.max_pool2d], padding='SAME') as scope_dict:
        return scope_dict
271