1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper functions to add support for magnitude-based model pruning.
16
17  # Adds variables and ops to the graph to enable
18  # elementwise masking of weights
19  apply_mask(weights)
20
21  # Returns a list containing the sparsity of each of the weight tensors
22  get_weight_sparsity()
23
  # Returns a list of all the masked weight tensors
  get_masked_weights()
26
27  # Returns a list of all the mask tensorflow variables
28  get_masks()
29
30  # Returns a list of all the thresholds
31  get_thresholds()
32
33  # Returns a list of all the weight tensors that have been masked
34  get_weights()
35
  The Pruning class uses a tf.HParams object to set up the
  parameters for model pruning. Here's a typical usage:
38
39  # Parse pruning hyperparameters
40  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
41
42  # Create a pruning object using the pruning_hparams
43  p = pruning.Pruning(pruning_hparams)
44
45  # Add mask update ops to the graph
46  mask_update_op = p.conditional_mask_update_op()
47
48  # Add the summaries
49  p.add_pruning_summaries()
50
51  # Run the op
52  session.run(mask_update_op)
53
54  # An object of the pruning also accepts externally defined sparsity:
  sparsity = tf.Variable(0.5, name="ConstantSparsity")
56  p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
57"""
58# pylint: disable=missing-docstring
59from __future__ import absolute_import
60from __future__ import division
61from __future__ import print_function
62
63from tensorflow.contrib.model_pruning.python import pruning_utils
64from tensorflow.contrib.model_pruning.python.layers import core_layers as core
65from tensorflow.contrib.training.python.training import hparam
66from tensorflow.python.framework import dtypes
67from tensorflow.python.framework import ops
68from tensorflow.python.ops import array_ops
69from tensorflow.python.ops import control_flow_ops
70from tensorflow.python.ops import init_ops
71from tensorflow.python.ops import math_ops
72from tensorflow.python.ops import nn_impl
73from tensorflow.python.ops import nn_ops
74from tensorflow.python.ops import state_ops
75from tensorflow.python.ops import variable_scope
76from tensorflow.python.ops import variables
77from tensorflow.python.platform import tf_logging as logging
78from tensorflow.python.summary import summary
79from tensorflow.python.training import training_util
80
# Collection keys and the masked-weight op name, aliased from core_layers so
# that this module and those layers agree on the same values.
(_MASK_COLLECTION,
 _THRESHOLD_COLLECTION,
 _MASKED_WEIGHT_COLLECTION,
 _WEIGHT_COLLECTION,
 _MASKED_WEIGHT_NAME) = (core.MASK_COLLECTION,
                         core.THRESHOLD_COLLECTION,
                         core.MASKED_WEIGHT_COLLECTION,
                         core.WEIGHT_COLLECTION,
                         core.MASKED_WEIGHT_NAME)
86
87
def apply_mask(x, scope=''):
  """Multiply a weight tensor by its pruning mask and register it.

  The mask and threshold variables for `x` are obtained from pruning_utils.
  The first time a given mask is seen, the threshold, the mask, the masked
  weights and `x` itself are added to their respective graph collections.

  Args:
    x: Input weight tensor.
    scope: The current variable scope. Defaults to "".

  Returns:
    Tensor holding mask * x.
  """
  mask = pruning_utils.weight_mask_variable(x, scope)
  threshold = pruning_utils.weight_threshold_variable(x, scope)
  # The product carries the shared masked-weight op name; the intent is to make
  # it easier for the quantization library to add quant ops.
  masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)

  # Register each mask only once. Masking the same weight repeatedly (notably
  # RNN weight variables) would otherwise duplicate collection entries.
  already_registered = mask in ops.get_collection_ref(_MASK_COLLECTION)
  if not already_registered:
    registrations = (
        (_THRESHOLD_COLLECTION, threshold),
        (_MASK_COLLECTION, mask),
        (_MASKED_WEIGHT_COLLECTION, masked_weights),
        (_WEIGHT_COLLECTION, x),
    )
    for collection_key, value in registrations:
      ops.add_to_collection(collection_key, value)
  return masked_weights
113
114
def get_masked_weights():
  """Returns a copy of the masked-weight collection (mask * weight tensors)."""
  masked = ops.get_collection(_MASKED_WEIGHT_COLLECTION)
  return masked
117
118
def get_masks():
  """Returns a copy of the mask collection."""
  mask_list = ops.get_collection(_MASK_COLLECTION)
  return mask_list
121
122
def get_thresholds():
  """Returns a copy of the threshold collection."""
  threshold_list = ops.get_collection(_THRESHOLD_COLLECTION)
  return threshold_list
125
126
def get_weights():
  """Returns a copy of the collection of weights that have been masked."""
  weight_list = ops.get_collection(_WEIGHT_COLLECTION)
  return weight_list
129
130
def get_weight_sparsity():
  """Compute the fraction of zero entries in every registered mask.

  Returns:
    A list with one zero-fraction tensor per mask, in mask-collection order.
  """
  return list(map(nn_impl.zero_fraction, get_masks()))
142
143
def get_pruning_hparams():
  """Get a tf.HParams object with the default values for the hyperparameters.

    name: string
      name of the pruning specification. Used for adding summaries and ops under
      a common tensorflow name_scope
    begin_pruning_step: integer
      the global step at which to begin pruning
    end_pruning_step: integer
      the global step at which to terminate pruning. Defaults to -1 implying
      that pruning continues till the training stops
    weight_sparsity_map: list of strings
      comma separated list of weight variable name:target sparsity pairs.
      For layers/weights not in this list, sparsity as specified by the
      target_sparsity hyperparameter is used. Names are matched as substrings
      of the weight op name; a matching entry scales the scheduled sparsity by
      (entry sparsity / target_sparsity).
      Eg. [conv1:0.9,conv2/kernel:0.8]
    threshold_decay: float
      the decay factor to use for exponential decay of the thresholds
    pruning_frequency: integer
      How often should the masks be updated? (in # of global_steps)
    nbins: integer
      number of bins to use for histogram computation
    block_height: integer
      number of rows in a block (defaults to 1)
    block_width: integer
      number of cols in a block (defaults to 1)
    block_pooling_function: string
      Whether to perform average (AVG) or max (MAX) pooling in the block
      (default: AVG)
    initial_sparsity: float
      initial sparsity value
    target_sparsity: float
      target sparsity value
    sparsity_function_begin_step: integer
      the global step at which the gradual sparsity function begins to
      take effect
    sparsity_function_end_step: integer
      the global step used as the end point for the gradual sparsity function
    sparsity_function_exponent: float
      exponent = 1 is linearly varying sparsity between initial and final.
      exponent > 1 varies more slowly towards the end than the beginning
    use_tpu: boolean
      Indicates whether to use TPU (defaults to False)

    We use the following sparsity function (built by Pruning._setup_sparsity
    unless a sparsity tensor is passed to Pruning):

    p = clip((global_step - sparsity_function_begin_step) /
             (sparsity_function_end_step - sparsity_function_begin_step),
             0, 1)
    sparsity(global_step) = (initial_sparsity - target_sparsity) *
                            (1 - p)**sparsity_function_exponent +
                            target_sparsity

    Masks are recomputed at most once every pruning_frequency global steps,
    and only while begin_pruning_step <= global_step <= end_pruning_step (or
    indefinitely when end_pruning_step is negative).

  Args:
    None

  Returns:
    tf.HParams object initialized to default values

  """
  return hparam.HParams(
      name='model_pruning',
      begin_pruning_step=0,
      end_pruning_step=-1,
      weight_sparsity_map=[''],
      threshold_decay=0.0,
      pruning_frequency=10,
      nbins=256,
      block_height=1,
      block_width=1,
      block_pooling_function='AVG',
      initial_sparsity=0.0,
      target_sparsity=0.5,
      sparsity_function_begin_step=0,
      sparsity_function_end_step=100,
      sparsity_function_exponent=3.0,
      use_tpu=False)
219
220
class Pruning(object):
  """Builds the ops that gradually prune weights by magnitude.

  The target sparsity at each global step comes either from the gradual
  sparsity function described by the spec or from a sparsity tensor passed to
  the constructor. mask_update_op() recomputes the threshold and mask of every
  entry in the mask collection; conditional_mask_update_op() wraps that update
  so it only runs at the configured pruning steps.
  """

  def __init__(self, spec=None, global_step=None, sparsity=None):
    """Set up the specification for model pruning.

    If a spec is provided, the sparsity is set up based on the sparsity_function
    in the spec. The effect of sparsity_function is overridden if the sparsity
    variable is passed to the constructor. This enables setting up arbitrary
    sparsity profiles externally and passing them to the pruning functions.

    Args:
      spec: A tf.HParams pruning specification with the fields produced by
        get_pruning_hparams(). If falsy, get_pruning_hparams() is used.
      global_step: A tensorflow variable that is used while setting up the
        sparsity function. If None, the graph's global step is used.
      sparsity: A tensorflow scalar variable storing the sparsity. If None,
        the sparsity schedule described by the spec is built.

    Raises:
      ValueError: if the spec is invalid or no global step can be found.
    """
    # Pruning specification
    self._spec = spec if spec else get_pruning_hparams()

    # Sanity check for pruning hparams
    self._validate_spec()

    # A tensorflow variable that tracks the sparsity function.
    # If not provided as input, the graph must already contain the global_step
    # variable before calling this constructor.
    self._global_step = self._setup_global_step(global_step)

    # Stores the tensorflow sparsity variable.
    # Built using self._setup_sparsity() or provided externally
    self._sparsity = (sparsity
                      if sparsity is not None else self._setup_sparsity())

    # List of tensorflow assignments ops for new masks and thresholds
    self._assign_ops = []

    # Tensorflow variable keeping track of the last global step when the masks
    # were updated
    self._last_update_step = self._setup_last_update_step()

    # Block dimensions
    self._block_dim = [self._spec.block_height, self._spec.block_width]

    # Block pooling function
    self._block_pooling_function = self._spec.block_pooling_function

    # Mapping of weight names and target sparsity
    self._weight_sparsity_map = self._get_weight_sparsity_map()

  def _validate_spec(self):
    """Checks step ranges and [0, 1) bounds of the spec's hyperparameters.

    Raises:
      ValueError: if any checked hyperparameter is out of range.
    """
    spec = self._spec
    if spec.begin_pruning_step < 0:
      raise ValueError('Illegal value for begin_pruning_step')

    # end_pruning_step == -1 is the sentinel for "prune until training stops".
    if spec.begin_pruning_step >= spec.end_pruning_step:
      if spec.end_pruning_step != -1:
        raise ValueError(
            'Pruning must begin before it can end. begin_step=%d, end_step=%d. '
            'Set end_pruning_step to -1 if pruning is required till training '
            'stops' % (spec.begin_pruning_step, spec.end_pruning_step))

    if spec.sparsity_function_begin_step < 0:
      raise ValueError('Illegal value for sparsity_function_begin_step')

    if spec.sparsity_function_begin_step >= spec.sparsity_function_end_step:
      raise ValueError(
          'Sparsity function requires begin_step < end_step')

    if not 0.0 <= spec.threshold_decay < 1.0:
      raise ValueError('threshold_decay must be in range [0,1)')

    if not 0.0 <= spec.initial_sparsity < 1.0:
      raise ValueError('initial_sparsity must be in range [0,1)')

    if not 0.0 <= spec.target_sparsity < 1.0:
      raise ValueError('target_sparsity must be in range [0,1)')

  def _setup_global_step(self, global_step):
    """Returns the global step (given, or looked up in the graph) as int32.

    Raises:
      ValueError: if global_step is None and the graph has no global step.
    """
    graph_global_step = global_step
    if graph_global_step is None:
      graph_global_step = training_util.get_global_step()
    if graph_global_step is None:
      raise ValueError(
          'No global_step was passed to Pruning and the graph does not '
          'contain a global step variable. Create one (e.g. with '
          'tf.train.get_or_create_global_step()) before building Pruning.')

    return math_ops.cast(graph_global_step, dtypes.int32)

  def _setup_sparsity(self):
    """Builds the scalar sparsity tensor for the current global step.

    With p = clip((global_step - begin_step) / (end_step - begin_step), 0, 1)
    the result is (initial_sparsity - target_sparsity) * (1 - p)**exponent +
    target_sparsity, i.e. it moves from initial_sparsity to target_sparsity as
    p goes from 0 to 1.
    """
    begin_step = self._spec.sparsity_function_begin_step
    end_step = self._spec.sparsity_function_end_step
    initial_sparsity = self._spec.initial_sparsity
    target_sparsity = self._spec.target_sparsity
    exponent = self._spec.sparsity_function_exponent

    with ops.name_scope(self._spec.name):
      p = math_ops.minimum(
          1.0,
          math_ops.maximum(
              0.0,
              math_ops.div(
                  math_ops.cast(self._global_step - begin_step, dtypes.float32),
                  end_step - begin_step)))
      sparsity = math_ops.add(
          math_ops.multiply(initial_sparsity - target_sparsity,
                            math_ops.pow(1 - p, exponent)),
          target_sparsity,
          name='sparsity')

    return sparsity

  def _setup_last_update_step(self):
    """Gets or creates the int32 'last_mask_update_step' scalar variable.

    The variable lives under the spec's variable scope. If creating it raises
    ValueError, the scope is switched to reuse mode and the existing variable
    is fetched instead.
    """
    with variable_scope.variable_scope(
        self._spec.name, use_resource=self._spec.use_tpu) as scope:
      try:
        last_update_step = variable_scope.get_variable(
            'last_mask_update_step', [],
            initializer=init_ops.zeros_initializer(),
            trainable=False,
            dtype=dtypes.int32)
      except ValueError:
        scope.reuse_variables()
        last_update_step = variable_scope.get_variable(
            'last_mask_update_step', dtype=dtypes.int32)
    return last_update_step

  def _get_weight_sparsity_map(self):
    """Return the map of weight_name:sparsity parsed from the hparams.

    Each non-empty entry of spec.weight_sparsity_map has the form
    '<weight name>:<sparsity>'. Entries are split on their LAST ':' so that
    names which themselves contain ':' (e.g. 'conv1/weights:0') are accepted.

    Returns:
      A dict mapping weight-name substrings to float sparsities.

    Raises:
      ValueError: if an entry has no ':', its sparsity is not a number, or the
        sparsity is outside [0, 1).
    """
    weight_sparsity_map = {}
    for val in self._spec.weight_sparsity_map:
      if not val:
        # The default hparams value is [''], so empty entries are skipped.
        continue
      weight_name, separator, sparsity_str = val.rpartition(':')
      if not separator:
        raise ValueError(
            'Malformed weight_sparsity_map entry %r; expected '
            '<weight_name>:<sparsity>' % (val,))
      try:
        sparsity = float(sparsity_str)
      except ValueError:
        raise ValueError(
            'Invalid sparsity %r in weight_sparsity_map entry %r' %
            (sparsity_str, val))
      if not 0.0 <= sparsity < 1.0:
        raise ValueError(
            'Weight sparsity must be in range [0,1), got %s for weight %s' %
            (sparsity, weight_name))
      weight_sparsity_map[weight_name] = sparsity

    return weight_sparsity_map

  def _get_sparsity(self, weight_name):
    """Return target sparsity for the given layer/weight name.

    Args:
      weight_name: Op name of the weight. Keys of the weight sparsity map are
        matched as substrings of this name.

    Returns:
      self._sparsity if no key matches; otherwise self._sparsity scaled by
      (matched sparsity / spec.target_sparsity).

    Raises:
      ValueError: if more than one key matches, or a key matches while
        spec.target_sparsity is 0 (the scaling ratio would be undefined).
    """
    target_sparsity = [
        sparsity for name, sparsity in self._weight_sparsity_map.items()
        if name in weight_name
    ]
    if not target_sparsity:
      return self._sparsity

    if len(target_sparsity) > 1:
      raise ValueError(
          'Multiple matches in weight_sparsity_map for weight %s' % weight_name)
    if self._spec.target_sparsity == 0:
      raise ValueError(
          'weight_sparsity_map entry for weight %s requires a non-zero '
          'target_sparsity, since the per-weight sparsity is obtained by '
          'scaling the global schedule' % weight_name)
    # TODO(suyoggupta): This will work when initial_sparsity = 0. Generalize
    # to handle other cases as well.
    return math_ops.multiply(
        self._sparsity,
        target_sparsity[0] / self._spec.target_sparsity)

  def _update_mask(self, weights, threshold):
    """Updates the mask for a given weight tensor.

    This function selects the threshold value such that (approximately) the
    'desired_sparsity' fraction of weights have magnitude below it, smooths it
    with the previous threshold, and derives the new mask.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. The function will compute a new
        threshold and return the exponential moving average using the current
        value of threshold

    Returns:
      new_threshold: The new value of the threshold based on weights, and
        sparsity at the current global_step
      new_mask: A float32 tensor with the same shape as weights, holding 1
        where |weights| >= new_threshold and 0 elsewhere

    Raises:
      ValueError: if sparsity is not defined
    """
    if self._sparsity is None:
      raise ValueError('Sparsity variable undefined')

    sparsity = self._get_sparsity(weights.op.name)
    with ops.name_scope(weights.op.name + '_pruning_ops'):
      abs_weights = math_ops.abs(weights)
      num_weights = array_ops.size(abs_weights)
      # Number of weights to keep.
      k = math_ops.cast(
          math_ops.round(
              math_ops.cast(num_weights, dtypes.float32) * (1 - sparsity)),
          dtypes.int32)
      # Clamp k to [1, num_weights] so the gather index k - 1 stays in range
      # when the rounded count is 0 (small tensor, high sparsity) or exceeds
      # the tensor size (e.g. an externally supplied negative sparsity).
      k = math_ops.minimum(math_ops.maximum(k, 1), num_weights)
      # Sort the entire array by magnitude, largest first
      values, _ = nn_ops.top_k(
          array_ops.reshape(abs_weights, [-1]), k=num_weights)
      # Grab the (k-1) th value
      current_threshold = array_ops.gather(values, k - 1)
      smoothed_threshold = math_ops.add_n([
          math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay),
          math_ops.multiply(threshold, self._spec.threshold_decay)
      ])

      new_mask = math_ops.cast(
          math_ops.greater_equal(abs_weights, smoothed_threshold),
          dtypes.float32)

    return smoothed_threshold, new_mask

  def _maybe_update_block_mask(self, weights, threshold):
    """Performs block-granular masking of the weights.

    Block pruning occurs only if the block_height or block_width is > 1 and
    if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise
    pruning occurs.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. The function will compute a new
        threshold and return the exponential moving average using the current
        value of threshold

    Returns:
      new_threshold: The new value of the threshold based on weights, and
        sparsity at the current global_step
      new_mask: A tensor with the same shape as weights, holding 1 for kept
        entries and 0 for pruned ones

    Raises:
      ValueError: if block pooling function is not AVG or MAX
    """
    squeezed_weights = array_ops.squeeze(weights)
    if squeezed_weights.get_shape().ndims != 2 or self._block_dim == [1, 1]:
      return self._update_mask(weights, threshold)

    if self._block_pooling_function not in ['AVG', 'MAX']:
      raise ValueError('Unknown pooling function for block sparsity: %s' %
                       self._block_pooling_function)

    with ops.name_scope(weights.op.name + '_pruning_ops'):
      abs_weights = math_ops.abs(squeezed_weights)

      pool_window = [self._block_dim[0], self._block_dim[1]]
      pool_fn = pruning_utils.factorized_pool
      squeeze_axis = None
      if not self._spec.use_tpu:
        # nn_ops.pool expects a 4-D input, so add batch and channel dims.
        pool_fn = nn_ops.pool
        abs_weights = array_ops.reshape(
            abs_weights,
            [1, abs_weights.get_shape()[0],
             abs_weights.get_shape()[1], 1])
        squeeze_axis = [0, 3]

      pooled_weights = pool_fn(
          abs_weights,
          window_shape=pool_window,
          pooling_type=self._block_pooling_function,
          strides=pool_window,
          padding='SAME',
          name=weights.op.name + '_pooled')

      if pooled_weights.get_shape().ndims != 2:
        pooled_weights = array_ops.squeeze(pooled_weights, axis=squeeze_axis)

      smoothed_threshold, new_mask = self._update_mask(pooled_weights,
                                                       threshold)

      # Expand the per-block mask back to element granularity and crop the
      # padding introduced by 'SAME' pooling.
      updated_mask = pruning_utils.expand_tensor(new_mask, self._block_dim)
      sliced_mask = array_ops.slice(
          updated_mask, [0, 0],
          [squeezed_weights.get_shape()[0],
           squeezed_weights.get_shape()[1]])

    return smoothed_threshold, array_ops.reshape(sliced_mask,
                                                 array_ops.shape(weights))

  def _get_mask_assign_ops(self):
    """Fills self._assign_ops with threshold and mask assignment ops.

    Raises:
      ValueError: if called twice, or if the mask, threshold and weight
        collections have different lengths.
    """
    # Make sure the assignment ops have not already been added to the list
    if self._assign_ops:
      raise ValueError(
          'Assign op list not empty. _get_mask_assign_ops() called twice?')

    masks = get_masks()
    weights = get_weights()
    thresholds = get_thresholds()

    if len(masks) != len(thresholds):
      raise ValueError(
          'Number of masks %s and number of thresholds %s mismatch' %
          (len(masks), len(thresholds)))
    if len(masks) != len(weights):
      raise ValueError(
          'Number of masks %s and number of weights %s mismatch' %
          (len(masks), len(weights)))

    for mask, threshold, weight in zip(masks, thresholds, weights):
      is_partitioned = isinstance(weight, variables.PartitionedVariable)
      if is_partitioned:
        weight = weight.as_tensor()

      new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold)
      self._assign_ops.append(
          pruning_utils.variable_assign(threshold, new_threshold))

      self._assign_ops.append(
          pruning_utils.partitioned_variable_assign(mask, new_mask)
          if is_partitioned else pruning_utils.variable_assign(mask, new_mask))

  def mask_update_op(self):
    """Returns an op that recomputes every threshold and mask.

    The assignment ops are created on the first call and reused afterwards.
    The returned no-op depends on the threshold/mask assignments and on
    recording the current global step in last_mask_update_step.
    """
    with ops.name_scope(self._spec.name):
      if not self._assign_ops:
        self._get_mask_assign_ops()
      with ops.control_dependencies([
          state_ops.assign(
              self._last_update_step,
              self._global_step,
              name='last_mask_update_step_assign')
      ]):
        with ops.control_dependencies(self._assign_ops):
          # NOTE: this is a Python call, so it logs when the op is built, not
          # each time the masks are updated at run time.
          logging.info('Updating masks.')
          return control_flow_ops.no_op('mask_update')

  def conditional_mask_update_op(self):
    """Returns an op that runs mask_update_op() only at pruning steps.

    Masks are updated when begin_pruning_step <= global_step and either
    global_step <= end_pruning_step or end_pruning_step < 0, and at least
    pruning_frequency steps have elapsed since the last mask update.
    """

    def maybe_update_masks():
      with ops.name_scope(self._spec.name):
        is_step_within_pruning_range = math_ops.logical_and(
            math_ops.greater_equal(self._global_step,
                                   self._spec.begin_pruning_step),
            # If end_pruning_step is negative, keep pruning forever!
            math_ops.logical_or(
                math_ops.less_equal(self._global_step,
                                    self._spec.end_pruning_step),
                math_ops.less(self._spec.end_pruning_step, 0)))
        is_pruning_step = math_ops.less_equal(
            math_ops.add(self._last_update_step, self._spec.pruning_frequency),
            self._global_step)
        return math_ops.logical_and(is_step_within_pruning_range,
                                    is_pruning_step)

    def mask_update_op():
      return self.mask_update_op()

    def no_update_op():
      return control_flow_ops.no_op()

    return control_flow_ops.cond(maybe_update_masks(), mask_update_op,
                                 no_update_op)

  def add_pruning_summaries(self):
    """Adds summaries of weight sparsities and thresholds."""
    with ops.name_scope(self._spec.name + '_summaries'):
      summary.scalar('sparsity', self._sparsity)
      summary.scalar('last_mask_update_step', self._last_update_step)
      masks = get_masks()
      thresholds = get_thresholds()
      for mask, threshold in zip(masks, thresholds):
        summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask))
        summary.scalar(threshold.op.name + '/threshold', threshold)

  def print_hparams(self):
    """Logs the pruning spec as JSON."""
    logging.info(self._spec.to_json())
573