# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python support for quantization operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import moving_averages


def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
  """Adds a fake quantize layer with fixed quantization interval.

  Args:
    inputs: a tensor containing values to be quantized.
    init_min: the lower end of quantization interval.
    init_max: the upper end of quantization interval.
    scope: Optional scope for name_scope.
  Returns:
    a tensor containing quantized values.
  """
  with ops.name_scope(scope, 'FixedQuantize', values=[inputs]):
    return array_ops.fake_quant_with_min_max_args(
        inputs, min=init_min, max=init_max)
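
# Example usage of FixedQuantize (an illustrative sketch, not part of this
# module's API; assumes a TF 1.x graph and that `dtypes` is imported from
# tensorflow.python.framework):
#
#   x = array_ops.placeholder(dtypes.float32, shape=[None, 10])
#   # Clamp and fake-quantize x to the fixed interval [-1.0, 1.0].
#   y = FixedQuantize(x, init_min=-1.0, init_max=1.0)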


def _ModelVariable(name,
                   shape=None,
                   initializer=None,
                   collections=None,
                   trainable=None):
  """Creates a variable registered in GLOBAL_VARIABLES and the given collections."""
  collections = list(collections or [])
  collections += [ops.GraphKeys.GLOBAL_VARIABLES]
  return variable_scope.get_variable(
      name,
      shape=shape,
      initializer=initializer,
      collections=collections,
      trainable=trainable)


def LastValueQuantize(inputs,
                      per_channel=False,
                      init_min=-6.0,
                      init_max=6.0,
                      vars_collection=None,
                      name_prefix='LastValueQuant',
                      reuse=None,
                      is_training=True,
                      num_bits=8,
                      narrow_range=False,
                      symmetric=False):
  """Adds a layer that collects quantization ranges as last input ranges.

  LastValueQuantize creates variables called 'min' and 'max', representing the
  interval used for quantization and clamping.

  Args:
    inputs: a tensor containing values to be quantized.
    per_channel: (Optional) a boolean specifying whether to use different
      quantization ranges per output channel.
    init_min: a float scalar, the initial value for variable min.
    init_max: a float scalar, the initial value for variable max.
    vars_collection: (Optional) collection where to store variables for
      quantization interval ends.
    name_prefix: name_prefix for created nodes.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer, scope must be given.
    is_training: Whether the op is applied to a training or eval graph.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
    symmetric: If true, use symmetric quantization limits instead of training
      the minimum and maximum of each quantization range separately.
  Returns:
    a tensor containing quantized values.
  """
  with variable_scope.variable_scope(
      None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope:
    scope.set_partitioner(None)
    input_shape = inputs.get_shape()
    input_dim = len(input_shape)
    if per_channel:
      # Only support quantizing 1-, 2- and 4-dimensional tensors.
      assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
                                      ' scope: %s' % (input_shape, name_prefix))
      min_max_shape = [input_shape[-1]]
    else:
      min_max_shape = []

    vars_collections = [vars_collection] if vars_collection else []
    min_var = _ModelVariable(
        'min',
        shape=min_max_shape,
        initializer=init_ops.constant_initializer(init_min),
        collections=vars_collections,
        trainable=False)
    max_var = _ModelVariable(
        'max',
        shape=min_max_shape,
        initializer=init_ops.constant_initializer(init_max),
        collections=vars_collections,
        trainable=False)
    if not is_training:
      return _FakeQuantWithMinMaxVars(
          inputs,
          min_var,
          max_var,
          per_channel=per_channel,
          num_bits=num_bits,
          narrow_range=narrow_range)

    if per_channel:
      if input_dim >= 2:
        # Reduce over every dimension except the last (channel) one.
        reduce_dims = [0] if input_dim == 2 else [0, 1, 2]
        batch_min = math_ops.reduce_min(
            inputs, axis=reduce_dims, name='BatchMin')
        batch_max = math_ops.reduce_max(
            inputs, axis=reduce_dims, name='BatchMax')
      else:
        batch_min = inputs
        batch_max = inputs
    else:
      batch_min = math_ops.reduce_min(inputs, name='BatchMin')
      batch_max = math_ops.reduce_max(inputs, name='BatchMax')

    if symmetric:
      if narrow_range:
        min_max_ratio = -1
      else:
        # In two's complement notation, the negative range is slightly larger
        # than the positive range.
        min_max_ratio = -((1 << num_bits) - 2) / (1 << num_bits)

      # TFLite requires that 0.0 is always in the [min; max] range. Because
      # batch_min <= batch_max, it follows that range_min <= 0 <= range_max.
      range_min = math_ops.minimum(batch_min, batch_max / min_max_ratio)
      range_max = math_ops.maximum(batch_max, batch_min * min_max_ratio)
    else:
      # TFLite requires that 0.0 is always in the [min; max] range.
      range_min = math_ops.minimum(batch_min, 0.0)
      range_max = math_ops.maximum(batch_max, 0.0)

    assign_min = state_ops.assign(min_var, range_min, name='AssignMinLast')
    assign_max = state_ops.assign(max_var, range_max, name='AssignMaxLast')

    return _FakeQuantWithMinMaxVars(
        inputs,
        assign_min,
        assign_max,
        per_channel=per_channel,
        num_bits=num_bits,
        narrow_range=narrow_range)
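
# Worked example of the symmetric-range arithmetic in LastValueQuantize
# (hypothetical batch statistics, shown only to illustrate the formulas):
# with num_bits=8 and narrow_range=False,
#   min_max_ratio = -((1 << 8) - 2) / (1 << 8) = -254/256 = -0.9921875,
# and for batch_min=-2.0, batch_max=3.0:
#   range_min = min(-2.0, 3.0 / -0.9921875)  ~= -3.024
#   range_max = max(3.0, -2.0 * -0.9921875)  == 3.0
# so the stored interval [-3.024, 3.0] reproduces the slightly larger negative
# range of two's complement and always contains 0.0.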


def MovingAvgQuantize(inputs,
                      per_channel=False,
                      init_min=-6.0,
                      init_max=6.0,
                      ema_decay=0.999,
                      vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                      name_prefix='MovingAvgQuantize',
                      reuse=None,
                      is_training=True,
                      num_bits=8,
                      narrow_range=False,
                      symmetric=False):
  """Adds a layer that collects quantization ranges as EMAs of input ranges.

  MovingAvgQuantize creates variables called 'min' and 'max', representing the
  interval used for quantization and clamping.

  Args:
    inputs: a tensor containing values to be quantized.
    per_channel: (default False) a boolean specifying whether to use different
      quantization ranges per output channel.
    init_min: a float scalar, the initial value for variable min.
    init_max: a float scalar, the initial value for variable max.
    ema_decay: EMA decay parameter.
    vars_collection: (Optional) collection where to store variables for
      quantization interval ends.
    name_prefix: name_prefix for created nodes.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer, scope must be given.
    is_training: Whether the op is applied to a training or eval graph.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
    symmetric: If true, use symmetric quantization limits instead of training
      the minimum and maximum of each quantization range separately.
  Returns:
    a tensor containing quantized values.
  """
  with variable_scope.variable_scope(
      None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope:
    scope.set_partitioner(None)
    input_shape = inputs.get_shape()
    input_dim = len(input_shape)
    if per_channel:
      # Only support quantizing 1-, 2- and 4-dimensional tensors.
      assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
                                      ' scope: %s' % (input_shape, name_prefix))
      min_max_shape = [input_shape[-1]]
    else:
      min_max_shape = []

    vars_collections = [vars_collection] if vars_collection else []
    min_var = _ModelVariable(
        'min',
        shape=min_max_shape,
        initializer=init_ops.constant_initializer(init_min),
        collections=vars_collections,
        trainable=False)
    max_var = _ModelVariable(
        'max',
        shape=min_max_shape,
        initializer=init_ops.constant_initializer(init_max),
        collections=vars_collections,
        trainable=False)
    if not is_training:
      return _FakeQuantWithMinMaxVars(
          inputs,
          min_var,
          max_var,
          per_channel=per_channel,
          num_bits=num_bits,
          narrow_range=narrow_range)

    if per_channel:
      if input_dim >= 2:
        # Reduce over every dimension except the last (channel) one.
        reduce_dims = [0] if input_dim == 2 else [0, 1, 2]
        batch_min = math_ops.reduce_min(
            inputs, axis=reduce_dims, name='BatchMin')
        batch_max = math_ops.reduce_max(
            inputs, axis=reduce_dims, name='BatchMax')
      else:
        batch_min = inputs
        batch_max = inputs
    else:
      batch_min = math_ops.reduce_min(inputs, name='BatchMin')
      batch_max = math_ops.reduce_max(inputs, name='BatchMax')

    if symmetric:
      if narrow_range:
        min_max_ratio = -1
      else:
        # In two's complement notation, the negative range is slightly larger
        # than the positive range.
        min_max_ratio = -((1 << num_bits) - 2) / (1 << num_bits)

      # TFLite requires that 0.0 is always in the [min; max] range. Because
      # batch_min <= batch_max, it follows that range_min <= 0 <= range_max.
      range_min = math_ops.minimum(batch_min, batch_max / min_max_ratio)
      range_max = math_ops.maximum(batch_max, batch_min * min_max_ratio)
    else:
      # TFLite requires that 0.0 is always in the [min; max] range.
      range_min = math_ops.minimum(batch_min, 0.0)
      range_max = math_ops.maximum(batch_max, 0.0)

    assign_min = moving_averages.assign_moving_average(
        min_var, range_min, ema_decay, name='AssignMinEma')
    assign_max = moving_averages.assign_moving_average(
        max_var, range_max, ema_decay, name='AssignMaxEma')

    return _FakeQuantWithMinMaxVars(
        inputs,
        assign_min,
        assign_max,
        per_channel=per_channel,
        num_bits=num_bits,
        narrow_range=narrow_range)
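
# Note on the EMA update in MovingAvgQuantize: up to the zero-debiasing that
# moving_averages.assign_moving_average applies by default, each call updates
#
#   min_var <- min_var - (1 - ema_decay) * (min_var - range_min)
#
# so with ema_decay=0.999 a single batch moves the stored range end only 0.1%
# of the way toward the current batch statistic. This smooths over outlier
# batches, whereas LastValueQuantize tracks them exactly.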


def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits,
                             narrow_range):
  """Adds a fake quantization operation.

  Depending on the value of per_channel, this operation may do global
  quantization or per channel quantization. min_var and max_var should have
  corresponding shapes: scalar ([]) when per_channel == False and [d] when
  per_channel == True.

  Args:
    inputs: a tensor containing values to be quantized.
    min_var: a variable containing quantization range lower end(s).
    max_var: a variable containing quantization range upper end(s).
    per_channel: a boolean specifying whether to use per-channel quantization.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
  Returns:
    a tensor containing quantized values.
  """

  if per_channel:
    assert len(min_var.get_shape()) == 1
    assert len(max_var.get_shape()) == 1
    return array_ops.fake_quant_with_min_max_vars_per_channel(
        inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)
  else:
    assert min_var.get_shape() == []  # pylint: disable=g-explicit-bool-comparison
    assert max_var.get_shape() == []  # pylint: disable=g-explicit-bool-comparison
    return array_ops.fake_quant_with_min_max_vars(
        inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)
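
# Minimal sketch of how these helpers compose (illustrative only; the tensor
# names and shapes below are hypothetical):
#
#   weights = variable_scope.get_variable('w', shape=[3, 3, 16, 32])
#   # Per-channel: the 'min'/'max' variables get shape [32], one quantization
#   # range per output channel of the 4D weight tensor.
#   quant_w = LastValueQuantize(weights, per_channel=True, is_training=True)
#
#   # Global: 'min'/'max' are scalars tracking EMAs of an activation range.
#   quant_act = MovingAvgQuantize(activations, per_channel=False)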