1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains the core layer classes for model pruning and its functional aliases.
16"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.keras.engine import input_spec
25from tensorflow.python.layers import base
26from tensorflow.python.layers import utils
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import init_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import nn
31from tensorflow.python.ops import standard_ops
32
# Names of the graph collections that the masked layers in this file populate
# from their `build()` methods via `ops.add_to_collection`.
# Receives each layer's `mask` variable.
MASK_COLLECTION = 'masks'
# Receives each layer's scalar `threshold` variable.
THRESHOLD_COLLECTION = 'thresholds'
# Receives each layer's `mask * kernel` product tensor.
MASKED_WEIGHT_COLLECTION = 'masked_weights'
# Receives each layer's unmasked `kernel` variable.
WEIGHT_COLLECTION = 'kernel'
# The 'weights' part of the name is needed for the quantization library
# to recognize that the kernel should be quantized.
MASKED_WEIGHT_NAME = 'weights/masked_weight'
40
41
class _MaskedConv(base.Layer):
  """Abstract nD convolution layer with a masked kernel (private base class).

  This layer creates a convolution kernel that is convolved
  (actually cross-correlated) with the layer input to produce a tensor of
  outputs. The kernel is multiplied element-wise by a non-trainable `mask`
  variable of the same shape before it is used in the convolution.
  If `use_bias` is True, a bias vector is created and added to the outputs.
  Finally, if `activation` is not `None`, it is applied to the outputs as well.

  `build()` also registers the layer's tensors in graph collections:
  `mask` in `MASK_COLLECTION`, the masked kernel in
  `MASKED_WEIGHT_COLLECTION`, `threshold` in `THRESHOLD_COLLECTION` and the
  unmasked `kernel` in `WEIGHT_COLLECTION`.

  Arguments:
    rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution.
    filters: Integer, the dimensionality of the output space (i.e. the number
      of filters in the convolution).
    kernel_size: An integer or tuple/list of n integers, specifying the
      length of the convolution window.
    strides: An integer or tuple/list of n integers,
      specifying the stride length of the convolution.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
    data_format: A string, one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, ..., channels)` while `channels_first` corresponds to
      inputs with shape `(batch, channels, ...)`.
    dilation_rate: An integer or tuple/list of n integers, specifying
      the dilation rate to use for dilated convolution.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any `strides` value != 1.
    activation: Activation function. Set it to None to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    kernel_initializer: An initializer for the convolution kernel.
    bias_initializer: An initializer for the bias vector. If None, the default
      initializer will be used.
    kernel_regularizer: Optional regularizer for the convolution kernel.
    bias_regularizer: Optional regularizer for the bias vector.
    activity_regularizer: Regularizer function for the output.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: A string, the name of the layer.
    **kwargs: Additional keyword arguments, forwarded to `base.Layer`.
  """

  def __init__(self,
               rank,
               filters,
               kernel_size,
               strides=1,
               padding='valid',
               data_format='channels_last',
               dilation_rate=1,
               activation=None,
               use_bias=True,
               kernel_initializer=None,
               bias_initializer=init_ops.zeros_initializer(),
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               trainable=True,
               name=None,
               **kwargs):
    super(_MaskedConv, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.rank = rank
    self.filters = filters
    # Scalars are expanded to `rank`-length tuples so that later code can
    # index them per spatial dimension.
    self.kernel_size = utils.normalize_tuple(kernel_size, rank, 'kernel_size')
    self.strides = utils.normalize_tuple(strides, rank, 'strides')
    self.padding = utils.normalize_padding(padding)
    self.data_format = utils.normalize_data_format(data_format)
    self.dilation_rate = utils.normalize_tuple(dilation_rate, rank,
                                               'dilation_rate')
    self.activation = activation
    self.use_bias = use_bias
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    # Inputs carry a batch axis and a channel axis on top of the `rank`
    # spatial axes.
    self.input_spec = input_spec.InputSpec(ndim=self.rank + 2)

  def build(self, input_shape):
    """Creates the mask, kernel, threshold and (optional) bias variables.

    Arguments:
      input_shape: Shape of the layer input; anything accepted by
        `tensor_shape.TensorShape`.

    Raises:
      ValueError: If the channel dimension of `input_shape` is not defined.
    """
    input_shape = tensor_shape.TensorShape(input_shape)
    channel_axis = 1 if self.data_format == 'channels_first' else -1
    input_dim = tensor_shape.dimension_value(input_shape[channel_axis])
    if input_dim is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found `None`.')
    kernel_shape = self.kernel_size + (input_dim, self.filters)

    # The mask starts as all ones, so the masked kernel initially equals the
    # kernel itself.
    self.mask = self.add_variable(
        name='mask',
        shape=kernel_shape,
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=self.dtype)

    self.kernel = self.add_variable(
        name='kernel',
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        trainable=True,
        dtype=self.dtype)

    # Scalar, non-trainable, initialized to zero.
    self.threshold = self.add_variable(
        name='threshold',
        shape=[],
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        dtype=self.dtype)

    # Add masked_weights in the weights namescope so as to make it easier
    # for the quantization library to add quant ops.
    self.masked_kernel = math_ops.multiply(self.mask, self.kernel,
                                           MASKED_WEIGHT_NAME)

    ops.add_to_collection(MASK_COLLECTION, self.mask)
    ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel)
    ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold)
    ops.add_to_collection(WEIGHT_COLLECTION, self.kernel)

    if self.use_bias:
      self.bias = self.add_variable(
          name='bias',
          shape=(self.filters,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    # Now that the channel count is known, pin it in the input spec.
    self.input_spec = input_spec.InputSpec(
        ndim=self.rank + 2, axes={channel_axis: input_dim})
    self.built = True

  def call(self, inputs):
    """Convolves `inputs` with the masked kernel, then adds bias/activation.

    Arguments:
      inputs: Input tensor, laid out according to `self.data_format`.

    Returns:
      The convolution output, with the bias added when the layer has one and
      `self.activation` applied when it is not `None`.
    """
    outputs = nn.convolution(
        input=inputs,
        filter=self.masked_kernel,
        dilation_rate=self.dilation_rate,
        strides=self.strides,
        padding=self.padding.upper(),
        data_format=utils.convert_data_format(self.data_format, self.rank + 2))

    if self.bias is not None:
      if self.data_format == 'channels_first':
        if self.rank == 1:
          # nn.bias_add does not accept a 1D input tensor.
          bias = array_ops.reshape(self.bias, (1, self.filters, 1))
          outputs += bias
        elif self.rank == 2:
          outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
        elif self.rank == 3:
          # As of Mar 2017, direct addition is significantly slower than
          # bias_add when computing gradients. To use bias_add, we collapse Z
          # and Y into a single dimension to obtain a 4D input tensor.
          outputs_shape = outputs.shape.as_list()
          # A statically unknown batch size shows up as `None`, which reshape
          # rejects; -1 asks reshape to infer that dimension instead. The
          # remaining dimensions are multiplied as Python ints below, so they
          # still have to be statically known.
          if outputs_shape[0] is None:
            outputs_shape[0] = -1
          outputs_4d = array_ops.reshape(outputs, [
              outputs_shape[0], outputs_shape[1],
              outputs_shape[2] * outputs_shape[3], outputs_shape[4]
          ])
          outputs_4d = nn.bias_add(outputs_4d, self.bias, data_format='NCHW')
          outputs = array_ops.reshape(outputs_4d, outputs_shape)
      else:
        outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')

    if self.activation is not None:
      return self.activation(outputs)
    return outputs

  def compute_output_shape(self, input_shape):
    """Computes the static output shape for a given input shape.

    Arguments:
      input_shape: Shape of the layer input; anything accepted by
        `tensor_shape.TensorShape`.

    Returns:
      A `TensorShape` with the batch dimension of `input_shape`, the spatial
      dimensions given by `utils.conv_output_length`, and `self.filters`
      channels placed according to `self.data_format`.
    """
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    channels_last = self.data_format == 'channels_last'
    # Spatial dimensions sit between batch and channels (channels_last) or
    # after both of them (channels_first).
    space = input_shape[1:-1] if channels_last else input_shape[2:]
    new_space = [
        utils.conv_output_length(
            dim,
            self.kernel_size[i],
            padding=self.padding,
            stride=self.strides[i],
            dilation=self.dilation_rate[i]) for i, dim in enumerate(space)
    ]
    if channels_last:
      return tensor_shape.TensorShape([input_shape[0]] + new_space +
                                      [self.filters])
    return tensor_shape.TensorShape([input_shape[0], self.filters] +
                                    new_space)
243
class MaskedConv2D(_MaskedConv):
  """Masked 2D convolution layer (e.g. spatial convolution over images).

  This is `_MaskedConv` with `rank` fixed to 2: a convolution kernel is
  convolved (actually cross-correlated) with the layer input to produce a
  tensor of outputs. If `use_bias` is True, a bias vector is created and added
  to the outputs. Finally, if `activation` is not `None`, it is applied to the
  outputs as well.

  Arguments:
    filters: Integer, the dimensionality of the output space (i.e. the number
      of filters in the convolution).
    kernel_size: An integer or tuple/list of 2 integers, specifying the
      height and width of the 2D convolution window. A single integer
      specifies the same value for all spatial dimensions.
    strides: An integer or tuple/list of 2 integers, specifying the strides
      of the convolution along the height and width. A single integer
      specifies the same value for all spatial dimensions.
      Specifying any stride value != 1 is incompatible with specifying
      any `dilation_rate` value != 1.
    padding: One of `"valid"` or `"same"` (case-insensitive).
    data_format: A string, one of `channels_last` (default) or
      `channels_first`, giving the ordering of the dimensions in the inputs.
      `channels_last` corresponds to inputs with shape
      `(batch, height, width, channels)` while `channels_first` corresponds to
      inputs with shape `(batch, channels, height, width)`.
    dilation_rate: An integer or tuple/list of 2 integers, specifying
      the dilation rate to use for dilated convolution. A single integer
      specifies the same value for all spatial dimensions.
      Currently, specifying any `dilation_rate` value != 1 is
      incompatible with specifying any stride value != 1.
    activation: Activation function. Set it to None to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    kernel_initializer: An initializer for the convolution kernel.
    bias_initializer: An initializer for the bias vector. If None, the default
      initializer will be used.
    kernel_regularizer: Optional regularizer for the convolution kernel.
    bias_regularizer: Optional regularizer for the bias vector.
    activity_regularizer: Regularizer function for the output.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: A string, the name of the layer.
    **kwargs: Additional keyword arguments, passed through to `_MaskedConv`.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=(1, 1),
               padding='valid',
               data_format='channels_last',
               dilation_rate=(1, 1),
               activation=None,
               use_bias=True,
               kernel_initializer=None,
               bias_initializer=init_ops.zeros_initializer(),
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               trainable=True,
               name=None,
               **kwargs):
    # Only the rank is fixed here; every other option is handed to the shared
    # nD implementation unchanged.
    super(MaskedConv2D, self).__init__(
        rank=2,
        # Layer identity.
        name=name,
        trainable=trainable,
        # Convolution geometry.
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        dilation_rate=dilation_rate,
        padding=padding,
        data_format=data_format,
        # Kernel and bias variables.
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_initializer=bias_initializer,
        bias_regularizer=bias_regularizer,
        # Output post-processing.
        activation=activation,
        activity_regularizer=activity_regularizer,
        **kwargs)
329
class MaskedFullyConnected(base.Layer):
  """Fully-connected layer class with masked weights.

  This layer implements the operation:
  `outputs = activation(inputs . (mask * kernel) + bias)`
  Where `activation` is the activation function passed as the `activation`
  argument (if not `None`), `kernel` is a weights matrix created by the layer,
  `mask` is a non-trainable variable of the same shape as `kernel`, and `bias`
  is a bias vector created by the layer (only if `use_bias` is `True`).

  Note: if the input to the layer has a rank greater than 2, the masked kernel
  is applied along the last axis of the input (via `tensordot`); the leading
  axes are preserved in the output.

  `build()` also registers the layer's tensors in graph collections:
  `mask` in `MASK_COLLECTION`, the masked kernel in
  `MASKED_WEIGHT_COLLECTION`, `threshold` in `THRESHOLD_COLLECTION` and the
  unmasked `kernel` in `WEIGHT_COLLECTION`.

  Arguments:
    units: Integer or Long, dimensionality of the output space.
    activation: Activation function (callable). Set it to None to maintain a
      linear activation.
    use_bias: Boolean, whether the layer uses a bias.
    kernel_initializer: Initializer function for the weight matrix.
    bias_initializer: Initializer function for the bias.
    kernel_regularizer: Regularizer function for the weight matrix.
    bias_regularizer: Regularizer function for the bias.
    activity_regularizer: Regularizer function for the output.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: String, the name of the layer.
    **kwargs: Additional keyword arguments, forwarded to `base.Layer`.

  Properties:
    units: Python integer, dimensionality of the output space.
    activation: Activation function (callable).
    use_bias: Boolean, whether the layer uses a bias.
    kernel_initializer: Initializer instance (or name) for the weight matrix.
    bias_initializer: Initializer instance (or name) for the bias.
    kernel_regularizer: Regularizer instance for the weight matrix (callable)
    bias_regularizer: Regularizer instance for the bias (callable).
    kernel: Weight matrix variable (created in `build`).
    mask: Non-trainable variable shaped like `kernel`, initialized to ones
      (created in `build`).
    threshold: Non-trainable scalar variable, initialized to zero (created in
      `build`).
    masked_kernel: The tensor `mask * kernel` used by `call`.
    bias: Bias vector variable, or `None` when `use_bias` is False.
  """

  def __init__(self,
               units,
               activation=None,
               use_bias=True,
               kernel_initializer=None,
               bias_initializer=init_ops.zeros_initializer(),
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               trainable=True,
               name=None,
               **kwargs):
    super(MaskedFullyConnected, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.units = units
    self.activation = activation
    self.use_bias = use_bias
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.input_spec = input_spec.InputSpec(min_ndim=2)

  def build(self, input_shape):
    """Creates the kernel, mask, threshold and (optional) bias variables.

    Arguments:
      input_shape: Shape of the layer input; anything accepted by
        `tensor_shape.TensorShape`.

    Raises:
      ValueError: If the last dimension of `input_shape` is not defined.
    """
    input_shape = tensor_shape.TensorShape(input_shape)
    last_dim = tensor_shape.dimension_value(input_shape[-1])
    if last_dim is None:
      raise ValueError('The last dimension of the inputs to '
                       '`MaskedFullyConnected` should be defined. '
                       'Found `None`.')
    # Now that the input width is known, pin it in the input spec.
    self.input_spec = input_spec.InputSpec(min_ndim=2, axes={-1: last_dim})

    self.kernel = self.add_variable(
        'kernel',
        shape=[last_dim, self.units],
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        dtype=self.dtype,
        trainable=True)

    # The mask starts as all ones, so the masked kernel initially equals the
    # kernel itself.
    self.mask = self.add_variable(
        name='mask',
        shape=[last_dim, self.units],
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=self.dtype)

    # Scalar, non-trainable, initialized to zero.
    self.threshold = self.add_variable(
        name='threshold',
        shape=[],
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        dtype=self.dtype)

    # Add masked_weights in the weights namescope so as to make it easier
    # for the quantization library to add quant ops.
    self.masked_kernel = math_ops.multiply(self.mask, self.kernel,
                                           MASKED_WEIGHT_NAME)

    ops.add_to_collection(MASK_COLLECTION, self.mask)
    ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel)
    ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold)
    ops.add_to_collection(WEIGHT_COLLECTION, self.kernel)

    if self.use_bias:
      self.bias = self.add_variable(
          'bias',
          shape=[
              self.units,
          ],
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    """Multiplies `inputs` by the masked kernel, then adds bias/activation.

    Arguments:
      inputs: Input value; it is converted to a tensor of dtype `self.dtype`.

    Returns:
      A tensor whose static shape is the input shape with the last dimension
      replaced by `self.units`, with the bias added when `use_bias` is True
      and `self.activation` applied when it is not `None`.
    """
    inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
    shape = inputs.get_shape().as_list()
    output_shape = shape[:-1] + [self.units]
    if len(output_shape) > 2:
      # Rank > 2: contract the last input axis with the first kernel axis,
      # leaving all leading input axes in place.
      outputs = standard_ops.tensordot(inputs, self.masked_kernel,
                                       [[len(shape) - 1], [0]])
      # Attach the static shape derived from the input to the result.
      outputs.set_shape(output_shape)
    else:
      outputs = standard_ops.matmul(inputs, self.masked_kernel)
    if self.use_bias:
      outputs = nn.bias_add(outputs, self.bias)
    if self.activation is not None:
      return self.activation(outputs)  # pylint: disable=not-callable
    return outputs

  def compute_output_shape(self, input_shape):
    """Computes the static output shape for a given input shape.

    Arguments:
      input_shape: Shape of the layer input; anything accepted by
        `tensor_shape.TensorShape`. It must have rank at least 2.

    Returns:
      `input_shape` with its last dimension replaced by `self.units`.

    Raises:
      ValueError: If the innermost dimension of `input_shape` is not defined.
    """
    input_shape = tensor_shape.TensorShape(input_shape)
    input_shape = input_shape.with_rank_at_least(2)
    if tensor_shape.dimension_value(input_shape[-1]) is None:
      raise ValueError(
          'The innermost dimension of input_shape must be defined, but saw: %s'
          % input_shape)
    return input_shape[:-1].concatenate(self.units)