1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for adding pruning related ops to the graph.
16"""
17# pylint: disable=missing-docstring
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import numpy as np
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import init_ops
30from tensorflow.python.ops import nn_ops
31from tensorflow.python.ops import state_ops
32from tensorflow.python.ops import variable_scope
33
34
def weight_mask_variable(var, scope):
  """Adds a non-trainable 'mask' variable matching a weight variable.

  The variable is obtained through `get_variable` inside the given variable
  scope, under the name 'mask'.

  Args:
    var: The weight variable to be masked. Its static shape and dtype are
      reused for the mask.
    scope: The variable scope in which the 'mask' variable is created.

  Returns:
    The mask variable, with the same shape and dtype as var and a ones
    initializer.
  """
  mask_shape = var.get_shape()
  mask_dtype = var.dtype
  with variable_scope.variable_scope(scope):
    # The mask is marked non-trainable so it is left out of the trainable
    # variables collection.
    return variable_scope.get_variable(
        'mask',
        mask_shape,
        dtype=mask_dtype,
        initializer=init_ops.ones_initializer(),
        trainable=False)
55
56
def weight_threshold_variable(var, scope):
  """Adds a non-trainable scalar 'threshold' variable for a weight variable.

  The variable is obtained through `get_variable` inside the given variable
  scope, under the name 'threshold'.

  Args:
    var: The weight variable that needs to be masked. Only its dtype is used
      here.
    scope: The variable scope in which the 'threshold' variable is created.

  Returns:
    A scalar (shape []) threshold variable with the dtype of var and a zeros
    initializer.
  """
  with variable_scope.variable_scope(scope):
    threshold_var = variable_scope.get_variable(
        name='threshold',
        shape=[],
        dtype=var.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False)
  return threshold_var
77
78
def kronecker_product(mat1, mat2):
  """Builds the Kronecker product of two rank-2 tensors.

  Both inputs are reshaped to rank 4 so that their elementwise product
  broadcasts every entry of mat1 against every entry of mat2; the result is
  then flattened back to rank 2.

  Args:
    mat1: A matrix of size m x n. Its static shape is read with
      `get_shape().as_list()`.
    mat2: A matrix of size p x q. Its static shape is read with
      `get_shape().as_list()`.

  Returns:
    Kronecker product of matrices mat1 and mat2 of size mp x nq
  """
  rows1, cols1 = mat1.get_shape().as_list()
  lhs = array_ops.reshape(mat1, [rows1, 1, cols1, 1])
  rows2, cols2 = mat2.get_shape().as_list()
  rhs = array_ops.reshape(mat2, [1, rows2, 1, cols2])
  # Broadcasts to [rows1, rows2, cols1, cols2].
  outer = lhs * rhs
  return array_ops.reshape(outer, [rows1 * rows2, cols1 * cols2])
94
95
def expand_tensor(tensor, block_dims):
  """Expands a 2D tensor by replicating the tensor values.

  This is equivalent to the kronecker product of the tensor and a matrix of
  ones of size block_dims.

  Example:

  tensor = [[1,2]
            [3,4]]
  block_dims = [2,2]

  result = [[1 1 2 2]
            [1 1 2 2]
            [3 3 4 4]
            [3 3 4 4]]

  Args:
    tensor: A 2D tensor that needs to be expanded.
    block_dims: List of two integers [block_height, block_width] specifying
      the expansion factor along rows and columns respectively.

  Returns:
    The expanded tensor. When both factors are <= 1 the input tensor itself
    is returned.

  Raises:
    ValueError: if tensor is not rank-2 or block_dims does not have 2
    elements.
  """
  if tensor.get_shape().ndims != 2:
    raise ValueError('Input tensor must be rank 2')

  if len(block_dims) != 2:
    raise ValueError('block_dims must have 2 elements')

  row_factor, col_factor = block_dims

  def _destination_rows(num_rows, factor):
    """Returns the output row for every row of the row-tiled matrix."""
    # Row (k * num_rows + r) of the tiled matrix is the k-th copy of source
    # row r and belongs at output row (r * factor + k). Laying out
    # 0..num_rows*factor-1 as a [num_rows, factor] grid and reading it
    # column by column yields exactly that sequence.
    grid = np.arange(num_rows * factor, dtype=np.int32).reshape(
        num_rows, factor)
    return grid.T.reshape(num_rows * factor, 1)

  def _repeat_each_row(matrix, factor):
    """Repeats every row of `matrix` `factor` times, keeping copies adjacent."""
    num_rows, num_cols = matrix.shape.as_list()
    out_shape = [num_rows * factor, num_cols]
    scatter_rows = constant_op.constant(_destination_rows(num_rows, factor))
    stacked_copies = array_ops.tile(matrix, [factor, 1])
    return array_ops.scatter_nd(scatter_rows, stacked_copies, out_shape)

  result = tensor

  # Expand along rows by row_factor.
  if row_factor > 1:
    result = _repeat_each_row(result, row_factor)

  # Columns are expanded by transposing, repeating rows, and transposing back.
  if col_factor > 1:
    flipped = array_ops.transpose(result)
    result = array_ops.transpose(_repeat_each_row(flipped, col_factor))

  return result
162
163
def factorized_pool(input_tensor,
                    window_shape,
                    pooling_type,
                    strides,
                    padding,
                    name=None):
  """Performs m x n pooling through a combination of 1xm and 1xn pooling.

  The rank-2 input is reshaped to [1, 1, height, width] and pooled with a
  [1, window_shape[0]] window. The last two axes are then swapped and the
  result is pooled again with a [1, window_shape[1]] window. Finally the axes
  are swapped back and the two leading singleton dimensions are squeezed out.

  Args:
    input_tensor: Input tensor. Must be rank 2
    window_shape: Pooling window shape; elements 0 and 1 are used for the
      first and second pooling pass respectively.
    pooling_type: Either 'MAX' or 'AVG'
    strides: The stride of the pooling window; elements 0 and 1 are used for
      the first and second pooling pass respectively.
    padding: 'SAME' or 'VALID'.
    name: Name of the op

  Returns:
    A rank 2 tensor containing the pooled output

  Raises:
    ValueError: if the input tensor is not rank 2
  """
  if input_tensor.get_shape().ndims != 2:
    raise ValueError('factorized_pool() accepts tensors of rank 2 only')

  def _pool_1d(values, window, stride):
    """Applies a [1, window] pooling with [1, stride] strides to `values`."""
    return nn_ops.pool(
        values,
        window_shape=[1, window],
        pooling_type=pooling_type,
        strides=[1, stride],
        padding=padding)

  height, width = input_tensor.get_shape()
  with ops.name_scope(name, 'factorized_pool'):
    aligned = array_ops.reshape(
        input_tensor, [1, 1, height, width],
        name=input_tensor.op.name + '_aligned')

    pooled_first = _pool_1d(aligned, window_shape[0], strides[0])
    # Swap the last two axes so the second pass pools the other dimension.
    swapped = array_ops.transpose(pooled_first, perm=[0, 1, 3, 2])
    pooled_second = _pool_1d(swapped, window_shape[1], strides[1])

  # Undo the axis swap and drop the two leading singleton dimensions.
  restored = array_ops.transpose(pooled_second, perm=[0, 1, 3, 2])
  return array_ops.squeeze(restored, axis=[0, 1])
212
213
def determine_partitioned_axis(partitioned_variable):
  """Finds the axis along which a partitioned variable was split.

  Each partition's shape is compared elementwise against the shape of the
  whole variable; the partitioned axis is the single dimension in which the
  partition is strictly smaller.

  Args:
    partitioned_variable: An iterable of partitions that also provides
      `get_shape()` for the full (concatenated) shape. Every partition must
      provide `get_shape()` as well.

  Returns:
    The index of the partitioned axis, as determined from the last partition
    visited; 0 if there are no partitions to iterate over.

  Raises:
    ValueError: if any partition is smaller than the full variable along a
      number of axes other than exactly one.
  """
  full_shape = partitioned_variable.get_shape()
  axis = 0
  for piece in partitioned_variable:
    is_smaller = np.less(piece.get_shape(), full_shape)
    num_smaller = np.count_nonzero(is_smaller)
    # Sanity check: exactly one axis may differ from the full shape.
    if num_smaller != 1:
      raise ValueError(
          'Number of partitioned axes %s not equal to 1' % num_smaller)
    axis = np.nonzero(is_smaller)[0][0]
  return axis
227
228
def variable_assign(var, new_value):
  """Creates an op that assigns new_value to var.

  Args:
    var: The variable to assign to. The assign op is named after var's op,
      with an '_assign' suffix.
    new_value: Value to be assigned to the variable var

  Returns:
    The result of `state_ops.assign` for var and new_value.
  """
  assign_name = var.op.name + '_assign'
  return state_ops.assign(var, new_value, name=assign_name)
231
232
def partitioned_variable_assign(partitioned_var, new_value):
  """Assign op for partitioned variables.

  new_value is split along the partitioned axis into pieces whose sizes match
  the partitions, and each piece is assigned to its partition.

  Args:
    partitioned_var: A partitioned tensorflow variable
    new_value: Value to be assigned to the variable var

  Returns:
    A tensorflow op that groups the assign ops for each of the variable slices
  """
  # Determine which axis was used to partition the variable. Currently
  # tensorflow allows partitioning variable only along 1 axis. With a single
  # partition no axis differs from the full shape, so axis 0 is used.
  if len(partitioned_var) == 1:
    axis = 0
  else:
    axis = determine_partitioned_axis(partitioned_var)

  sizes_along_axis = np.array(
      [piece.get_shape()[axis] for piece in partitioned_var])
  split_sizes = ops.convert_to_tensor(sizes_along_axis, dtype=dtypes.int32)
  value_pieces = array_ops.split(new_value, split_sizes, axis=axis)

  # Pair the i-th partition with the i-th piece of the new value.
  assign_ops = [
      variable_assign(piece, value_pieces[index])
      for index, piece in enumerate(partitioned_var)
  ]
  group_name = partitioned_var.name + '_group_assign'
  return control_flow_ops.group(*assign_ops, name=group_name)
260