1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for strip_pruning_vars."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import re
21
22from tensorflow.contrib.model_pruning.python import pruning
23from tensorflow.contrib.model_pruning.python import strip_pruning_vars_lib
24from tensorflow.contrib.model_pruning.python.layers import layers
25from tensorflow.contrib.model_pruning.python.layers import rnn_cells
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import graph_util
28from tensorflow.python.framework import importer
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import random_ops
32from tensorflow.python.ops import rnn
33from tensorflow.python.ops import rnn_cell as tf_rnn_cells
34from tensorflow.python.ops import state_ops
35from tensorflow.python.ops import variable_scope
36from tensorflow.python.ops import variables
37from tensorflow.python.platform import test
38from tensorflow.python.training import training_util
39
40
def _get_number_pruning_vars(graph_def):
  """Counts nodes whose names end in "mask" or "threshold".

  Args:
    graph_def: A `GraphDef`, or any object exposing an iterable `node` whose
      elements have a string `name` attribute.

  Returns:
    The number of nodes in `graph_def.node` whose name ends with "mask" or
    "threshold".
  """
  # The alternation is grouped so that the `$` anchor applies to both
  # suffixes. An ungrouped `^.*(mask$)|(threshold$)` under `re.match` only
  # accepts a node named exactly "threshold" for the second branch.
  return sum(
      1 for node in graph_def.node
      if re.search(r"(?:mask|threshold)$", node.name))
47
48
def _get_node_names(tensor_names):
  """Converts a sequence of tensor names into graph node names.

  Each name is passed through `strip_pruning_vars_lib._node_name`, and the
  results are returned in input order.

  Args:
    tensor_names: Iterable of tensor names.

  Returns:
    A list holding one node name per entry of `tensor_names`.
  """
  node_names = []
  for name in tensor_names:
    node_names.append(strip_pruning_vars_lib._node_name(name))
  return node_names
54
55
class StripPruningVarsTest(test.TestCase):
  """Tests that stripping pruning variables leaves model outputs unchanged.

  Each test builds a masked model in `self.initial_graph`, prunes it, and
  freezes it into `self.initial_graph_def`. It then strips the pruning
  variables into `self.final_graph_def`, imports that into `self.final_graph`
  under the "final" prefix, and compares the outputs of both graphs.
  """

  def setUp(self):
    # Let the base TestCase run its own per-test setup before this test's
    # graphs are created.
    super(StripPruningVarsTest, self).setUp()
    param_list = [
        "pruning_frequency=1", "begin_pruning_step=1", "end_pruning_step=10",
        "nbins=2048", "threshold_decay=0.0"
    ]
    self.initial_graph = ops.Graph()
    self.initial_graph_def = None
    self.final_graph = ops.Graph()
    self.final_graph_def = None
    self.pruning_spec = ",".join(param_list)
    with self.initial_graph.as_default():
      # Pruning state lives only in the initial graph. The final graph is
      # filled later by importing the stripped GraphDef.
      self.sparsity = variables.Variable(0.5, name="sparsity")
      self.global_step = training_util.get_or_create_global_step()
      self.increment_global_step = state_ops.assign_add(self.global_step, 1)
      self.mask_update_op = None

  def _build_convolutional_model(self, number_of_layers):
    """Stacks `number_of_layers` masked conv2d layers on an all-ones input.

    Args:
      number_of_layers: How many `layers.masked_conv2d` layers to create.

    Returns:
      The output tensor of the last layer, or the (8, 7, 9, 4) all-ones input
      if `number_of_layers` is 0.
    """
    kernel_size = 3
    base_depth = 4
    depth_step = 7
    height, width = 7, 9
    with variable_scope.variable_scope("conv_model"):
      input_tensor = array_ops.ones((8, height, width, base_depth))
      top_layer = input_tensor
      for ix in range(number_of_layers):
        top_layer = layers.masked_conv2d(
            top_layer,
            base_depth + (ix + 1) * depth_step,
            kernel_size,
            scope="Conv_" + str(ix))

    return top_layer

  def _build_fully_connected_model(self, number_of_layers):
    """Stacks `number_of_layers` masked fully connected layers.

    Args:
      number_of_layers: How many `layers.masked_fully_connected` layers to
        create under the "fc_model" variable scope.

    Returns:
      The output tensor of the last layer, or the (8, 4) all-ones input if
      `number_of_layers` is 0.
    """
    base_depth = 4
    depth_step = 7

    input_tensor = array_ops.ones((8, base_depth))

    top_layer = input_tensor

    with variable_scope.variable_scope("fc_model"):
      for ix in range(number_of_layers):
        top_layer = layers.masked_fully_connected(
            top_layer, base_depth + (ix + 1) * depth_step)

    return top_layer

  def _build_lstm_model(self, number_of_layers):
    """Builds a multi-layer masked LSTM and runs it over one input step.

    Args:
      number_of_layers: How many `MaskedBasicLSTMCell`s to stack in a
        `MultiRNNCell`.

    Returns:
      Whatever `rnn.static_rnn` returns for a single-element input list.
    """
    batch_size = 8
    dim = 10
    inputs = variables.Variable(random_ops.random_normal([batch_size, dim]))

    def lstm_cell():
      return rnn_cells.MaskedBasicLSTMCell(
          dim, forget_bias=0.0, state_is_tuple=True, reuse=False)

    cell = tf_rnn_cells.MultiRNNCell(
        [lstm_cell() for _ in range(number_of_layers)], state_is_tuple=True)

    outputs = rnn.static_rnn(
        cell, [inputs],
        initial_state=cell.zero_state(batch_size, dtypes.float32))

    return outputs

  def _prune_model(self, session):
    """Creates the mask update op and runs 20 prune/increment steps.

    Args:
      session: Session over `self.initial_graph`. It is used to initialize
        all global variables and to run the update ops.
    """
    pruning_hparams = pruning.get_pruning_hparams().parse(self.pruning_spec)
    p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity)
    self.mask_update_op = p.conditional_mask_update_op()

    # Initialize through the session we were given rather than relying on
    # the implicit default session.
    session.run(variables.global_variables_initializer())
    for _ in range(20):
      session.run(self.mask_update_op)
      session.run(self.increment_global_step)

  def _get_outputs(self, session, input_graph, tensors_list, graph_prefix=None):
    """Evaluates each named tensor in `session`.

    Args:
      session: Session whose graph contains the tensors.
      input_graph: Unused; tensors are resolved through `session.graph`.
        Kept for signature compatibility.
      tensors_list: Names of the tensors to evaluate.
      graph_prefix: Optional import prefix, prepended as "<prefix>/<name>".

    Returns:
      A list with one evaluated value per entry of `tensors_list`.
    """
    del input_graph  # Unused.
    if graph_prefix:
      tensors_list = [graph_prefix + "/" + name for name in tensors_list]
    return [
        session.run(session.graph.get_tensor_by_name(name))
        for name in tensors_list
    ]

  def _get_initial_outputs(self, output_tensor_names_list):
    """Prunes the initial model, records its outputs and freezes it.

    Sets `self.initial_graph_def` to the variables-to-constants conversion of
    the initial graph, restricted to the given outputs.

    Args:
      output_tensor_names_list: Names of the tensors to evaluate and keep.

    Returns:
      The evaluated outputs of the pruned initial model.
    """
    with self.session(graph=self.initial_graph) as sess1:
      self._prune_model(sess1)
      reference_outputs = self._get_outputs(sess1, self.initial_graph,
                                            output_tensor_names_list)

      self.initial_graph_def = graph_util.convert_variables_to_constants(
          sess1, sess1.graph.as_graph_def(),
          _get_node_names(output_tensor_names_list))
    return reference_outputs

  def _get_final_outputs(self, output_tensor_names_list):
    """Strips pruning vars from the frozen graph and evaluates it.

    Sets `self.final_graph_def` and imports it into the current default graph
    under the "final" name prefix.

    Args:
      output_tensor_names_list: Names of the tensors to evaluate.

    Returns:
      The evaluated outputs of the stripped model.
    """
    self.final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn(
        self.initial_graph_def, _get_node_names(output_tensor_names_list))
    importer.import_graph_def(self.final_graph_def, name="final")

    # Same session helper as _get_initial_outputs (test_session is
    # deprecated).
    with self.session(graph=self.final_graph) as sess2:
      final_outputs = self._get_outputs(
          sess2,
          self.final_graph,
          output_tensor_names_list,
          graph_prefix="final")
    return final_outputs

  def _check_removal_of_pruning_vars(self, number_masked_layers):
    """Asserts the expected pruning-node count before and none after strip."""
    self.assertEqual(
        _get_number_pruning_vars(self.initial_graph_def), number_masked_layers)
    self.assertEqual(_get_number_pruning_vars(self.final_graph_def), 0)

  def _check_output_equivalence(self, initial_outputs, final_outputs):
    """Asserts both output lists have equal length and equal elements."""
    # zip() would silently drop unmatched trailing outputs.
    self.assertEqual(len(initial_outputs), len(final_outputs))
    for initial_output, final_output in zip(initial_outputs, final_outputs):
      self.assertAllEqual(initial_output, final_output)

  def _check_strip_pruning_vars(self, build_model_fn, number_masked_layers):
    """Builds, prunes, freezes and strips a model, then checks the result.

    Args:
      build_model_fn: Zero-argument callable. It builds the model in the
        current default graph and returns the list of output tensor names.
      number_masked_layers: Expected number of pruning nodes in the frozen
        initial graph.
    """
    with self.initial_graph.as_default():
      output_tensor_names = build_model_fn()
      initial_outputs = self._get_initial_outputs(output_tensor_names)

    # Remove pruning-related nodes.
    with self.final_graph.as_default():
      final_outputs = self._get_final_outputs(output_tensor_names)

    # Check that the final graph has no pruning-related vars.
    self._check_removal_of_pruning_vars(number_masked_layers)

    # Check that outputs remain the same after removal of pruning-related
    # nodes.
    self._check_output_equivalence(initial_outputs, final_outputs)

  def testConvolutionalModel(self):
    number_masked_conv_layers = 5

    def build():
      top_layer = self._build_convolutional_model(number_masked_conv_layers)
      return [top_layer.name]

    self._check_strip_pruning_vars(build, number_masked_conv_layers)

  def testFullyConnectedModel(self):
    number_masked_fc_layers = 3

    def build():
      top_layer = self._build_fully_connected_model(number_masked_fc_layers)
      return [top_layer.name]

    self._check_strip_pruning_vars(build, number_masked_fc_layers)

  def testLSTMModel(self):
    number_masked_lstm_layers = 2

    def build():
      outputs = self._build_lstm_model(number_masked_lstm_layers)
      # [0][0]: presumably the output of the single time step fed to
      # static_rnn.
      return [outputs[0][0].name]

    self._check_strip_pruning_vars(build, number_masked_lstm_layers)
229
230
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  test.main()
233