# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic to update a TensorFlow model graph with quantization operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging

# Quantizable operation types that are supported by the quantization rewrite.
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}

# Activations that are supported by the quantization rewrite.
_ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'}

_RELU_TYPES = {'Relu', 'Relu6'}

_QUANTIZATION_OP = {'FakeQuantWithMinMaxVars'}
_VALID_SRC_OP = {'Add', 'Mul'}
_INTERMEDIATE_OP = {'Add', 'Mul'}
_PASS_THROUGH_OP = {'Reshape', 'Identity', 'BatchToSpaceND', 'SpaceToBatchND'}
_VALID_ACTIVATION_OP = {'Relu', 'Relu6'}


def Quantize(graph,
             is_training,
             weight_bits=8,
             activation_bits=8,
             symmetric=False,
             ema_decay=0.999,
             quant_delay=None,
             vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
             scope=None):
  """Updates graph with quantization operations.

  Currently we quantize the following tensors:
  * Conv/MatMul: Quantize the weights if the layer is matched.
  * Activation: Quantize the output if the layer is matched.
  * Bypass/Post-activation Bypass: Quantize both input and output
    if the pattern is matched.

  Args:
    graph: Graph to modify.
    is_training: Whether quantizing training graph or eval graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    symmetric: (Optional) If true, use symmetric quantization limits instead of
      training the minimum and maximum of each quantization range separately.
    ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization. This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    scope: The scope to be transformed. If it's not None, only the ops which
      are in this scope will be transformed.
  Raises:
    ValueError: When quantization fails.
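
  Example:
    A minimal sketch, for illustration only; the wrappers in quantize_graph
    are the usual entry points, and batch norms are typically folded before
    this is called. It assumes `g` already holds a conv/fc model built with
    contrib layers:

      with g.as_default():
        Quantize(g, is_training=True, weight_bits=8, activation_bits=8,
                 quant_delay=2000)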
  """
  if scope and not scope.endswith('/'):
    scope += '/'

  input_to_ops_map = input_to_ops.InputToOps(graph)
  quantized_ops = set()
  for layer_match in _FindLayersToQuantize(graph):
    # Quantize the weights.
    context = _GetContextFromOp(layer_match.layer_op)

    # If `scope` is given, only quantize it if the consumer of weights
    # (the layer op) is in the right scope.
    if layer_match.weight_tensor is not None:
      _InsertQuantOp(
          context,
          'weights_quant',
          layer_match.weight_tensor.op,
          input_to_ops_map.ConsumerOperations(layer_match.weight_tensor.op),
          is_training,
          moving_avg=False,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          narrow_range=True,
          vars_collection=vars_collection,
          bits=weight_bits,
          symmetric=symmetric,
          consumer_scope=scope)

    # Quantize the activations.
    if layer_match.activation_op is not None:
      consumer_ops = input_to_ops_map.ConsumerOperations(
          layer_match.activation_op)
      add_context = context
      if layer_match.bypass_op:
        pattern_match_result = re.search(r'^(.*)/([^/]+)', context)
        if pattern_match_result is not None:
          add_context = pattern_match_result.group(1)
        else:
          add_context = ''
      # If `scope` is given, only quantize it if the producer of weights
      # (usually it's the layer op) is in the right scope.
      _InsertQuantOp(
          add_context,
          'act_quant',
          layer_match.activation_op,
          consumer_ops,
          is_training,
          moving_avg=True,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          vars_collection=vars_collection,
          bits=activation_bits,
          symmetric=symmetric,
          init_min=0.0,
          producer_scope=scope)
      quantized_ops.add(layer_match.activation_op)

    # Quantize the inputs and output to the bypass (if it exists). The input to
    # the bypass is the bias add, and the output is the activation.
    if layer_match.bypass_op is not None:
      # If `scope` is given, only quantize it if both the producer and the
      # consumer are in the right scope.
      _InsertQuantOp(
          context,
          'conv_quant',
          layer_match.bias_add_op,
          input_to_ops_map.ConsumerOperations(layer_match.bias_add_op),
          is_training,
          moving_avg=True,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          vars_collection=vars_collection,
          bits=activation_bits,
          symmetric=symmetric,
          producer_scope=scope,
          consumer_scope=scope)
      quantized_ops.add(layer_match.bias_add_op)
      # Make sure the op following this isn't an activation; if it is, we
      # shouldn't quantize it, since the activation will be fused into the
      # Add at inference time.
      consumers = input_to_ops_map.ConsumerOperations(layer_match.bypass_op)
      if any(consumer.type in _ACTIVATION_TYPES for consumer in consumers):
        logging.info('Skipping %s, because it is followed by an activation.',
                     layer_match.bypass_op.name)
      else:
        _InsertQuantOp(
            add_context,
            'add_quant',
            layer_match.bypass_op,
            input_to_ops_map.ConsumerOperations(layer_match.bypass_op),
            is_training,
            moving_avg=True,
            ema_decay=ema_decay,
            quant_delay=quant_delay,
            vars_collection=vars_collection,
            bits=activation_bits,
            symmetric=symmetric,
            producer_scope=scope,
            consumer_scope=scope)
        quantized_ops.add(layer_match.bypass_op)

    # Quantize bypass ops that occur after the activation.
    if layer_match.post_activation_bypass_op is not None:
      pattern_match_result = re.search(
          r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name)
      if pattern_match_result is not None:
        post_activation_bypass_context = pattern_match_result.group(1)
      else:
        post_activation_bypass_context = ''
      # If `scope` is given, only quantize it if the producer is in the right
      # scope.
      # Make sure the op following this isn't an activation; if it is, we
      # shouldn't quantize it, since the activation will be fused into the
      # Add at inference time.
      consumers = input_to_ops_map.ConsumerOperations(
          layer_match.post_activation_bypass_op)
      if any(consumer.type in _RELU_TYPES for consumer in consumers):
        logging.info('Skipping %s, because it is followed by an activation.',
                     layer_match.post_activation_bypass_op.name)
      else:
        _InsertQuantOp(
            post_activation_bypass_context,
            'post_activation_bypass_quant',
            layer_match.post_activation_bypass_op,
            consumers,
            is_training,
            moving_avg=True,
            ema_decay=ema_decay,
            quant_delay=quant_delay,
            vars_collection=vars_collection,
            bits=activation_bits,
            symmetric=symmetric,
            producer_scope=scope)
        quantized_ops.add(layer_match.post_activation_bypass_op)

  _QuantizeActivationLayers(
      quantized_ops,
      graph,
      is_training,
      activation_bits,
      ema_decay,
      quant_delay,
      vars_collection,
      scope=scope)


def _QuantizeActivationLayers(quantized_ops,
                              graph,
                              is_training,
                              activation_bits=8,
                              ema_decay=0.999,
                              quant_delay=None,
                              vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                              scope=None):
  """Quantize intermediate activation tensors after addition and multiplication.

  Args:
    quantized_ops: Set of previously quantized activation ops.
    graph: Graph to modify.
    is_training: Whether quantizing training graph or eval graph.
    activation_bits: Number of bits to use for quantizing activations.
    ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization. This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    scope: The scope to be transformed. If it's not None, only the ops which are
      in this scope will be transformed.

  Raises:
    ValueError: When quantization fails.
  """
  input_to_ops_map = input_to_ops.InputToOps(graph)
  for op in graph.get_operations():
    if _CheckIfQuantizableOp(op, quantized_ops):
      logging.info('Inserting fake quant op activation_%s_quant after %s',
                   op.type, op.name)
      consumers = input_to_ops_map.ConsumerOperations(op)
      _InsertQuantOp(
          op.name,
          'activation_' + op.type + '_quant',
          op,
          consumers,
          is_training,
          moving_avg=True,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          vars_collection=vars_collection,
          bits=activation_bits,
          producer_scope=scope)


def _CheckIfQuantizableOp(src_op, quantized_ops):
  """Check if the output of an op should be quantized.

  Args:
    src_op: Op to be checked.
    quantized_ops: Set of previously quantized activation ops.

  Returns:
    Boolean specifying if the output should be quantized or not.
  """
  src_op_name = set([src_op.type])
  if src_op in quantized_ops:
    return False
  if not src_op_name.intersection(_VALID_SRC_OP):
    return False

  # If src op is an Add or a Mul and its output is immediately followed by an
  # activation, skip quantization.
  if len(src_op.outputs) == 1 and len(src_op.outputs[0].consumers()) == 1:
    op_consumers = src_op.outputs[0].consumers()
    if set([op_consumers[0].type]).intersection(_VALID_ACTIVATION_OP):
      logging.info('Skipping quant after %s', src_op.name)
      return False
  # The op is an Add or a Mul.
  input_ops = src_op.inputs

  for op in input_ops:
    curr_op = op.op
    curr_op_type = set([curr_op.type])
    while curr_op_type.intersection(_PASS_THROUGH_OP):
      # Walk back through pass-through ops.
      curr_op = curr_op.inputs[0].op
      curr_op_type = set([curr_op.type])
    # Now at a valid or quantizable op; we need to check whether at least one
    # of the inputs to the valid op is connected to a quantizable op via
    # pass-through ops.

    if (curr_op_type.intersection(_QUANTIZATION_OP) or
        curr_op.name.find('delayed_quant/Merge') > 0):
      return True

    if curr_op_type.intersection(_INTERMEDIATE_OP):
      # Check if at least one input to the intermediate op is quantizable.
      for input_op in curr_op.inputs:
        if _CheckIfQuantizableOp(input_op.op, quantized_ops):
          return True
  return False


def _FindLayersToQuantize(graph):
  """Matches layers in graph to quantize.

  The following patterns get matched. Nodes surrounded by [] will be
  optionally matched:

      weight|folded_weight
            /
         conv|fc
            |
    [batch_to_space_nd]
            |
    [post_conv_correction]
            |
    [biasadd|folded_bias]
            |
        [bypass]
            |
       activation
            |
    [post_activation_bypass]

  Match replacements:
    If weight|folded_weight is found, FakeQuant is added afterwards.
    If bypass is found, FakeQuant is added before and after.
    If activation is found, FakeQuant is added afterwards.
    If post_activation_bypass is found, FakeQuant is added afterwards.

  Args:
    graph: Graph to perform match on.

  Returns:
    list of _LayerMatches.
  """
  input_pattern = graph_matcher.OpTypePattern('*')
  weight_var_pattern = graph_matcher.OpTypePattern('Variable|VariableV2')
  weight_partition_identity_pattern = graph_matcher.OpTypePattern(
      'Identity', inputs=[weight_var_pattern])
  weight_partition_concat_pattern = graph_matcher.OpTypePattern(
      'ConcatV2', inputs=[weight_partition_identity_pattern, '*', '*'])
  weight_identity_pattern = graph_matcher.OpTypePattern(
      'Identity',
      inputs=[
          graph_matcher.OneofPattern([
              weight_partition_identity_pattern,
              weight_partition_concat_pattern,
              weight_var_pattern,
          ])
      ])
  weight_resource_var_pattern = graph_matcher.OpTypePattern('ReadVariableOp')
  folded_weight_pattern = graph_matcher.OpTypePattern('Mul')

  # The weight input to the layer operation can come either from the Variable
  # or from the folded weight (Mul).
  layer_pattern = graph_matcher.OpTypePattern(
      '|'.join(_QUANTIZABLE_TYPES),
      inputs=[
          input_pattern,
          graph_matcher.OneofPattern([
              weight_identity_pattern, weight_resource_var_pattern,
              folded_weight_pattern
          ])
      ],
      ordered_inputs=False)

  # For atrous convolutions, a BatchToSpaceND will occur after the depthwise
  # convolution.
  batch_to_space_pattern = graph_matcher.OpTypePattern(
      'BatchToSpaceND',
      inputs=[
          layer_pattern,
          graph_matcher.OpTypePattern('*'),
          graph_matcher.OpTypePattern('*')
      ])

  layer_output_pattern = graph_matcher.OneofPattern(
      [batch_to_space_pattern, layer_pattern])

  # For separable convolutions, we are looking for a conv followed by a conv
  # with no activations between the two.
  sep_conv_pattern = graph_matcher.OpTypePattern(
      '|'.join(_QUANTIZABLE_TYPES),
      inputs=[
          graph_matcher.OneofPattern([layer_output_pattern]),
          graph_matcher.OpTypePattern('*')
      ],
      ordered_inputs=False)
  folded_bias_mul_pattern = graph_matcher.OpTypePattern(
      'Mul',
      inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern],
      ordered_inputs=False)
  post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[folded_bias_mul_pattern,
              graph_matcher.OpTypePattern('*')],
      ordered_inputs=False)
  folded_bias_add_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[
          post_layer_op_correction_pattern,
          graph_matcher.OpTypePattern('*')
      ],
      ordered_inputs=False)

  # batch_norms with forced updates have an Identity operation at the end.
  # TODO(suharshs): Find a way to easily skip extra Identity operations. The
  # current issue is that doing so can often match patterns across many layers
  # incorrectly.
  batch_norm_identity = graph_matcher.OpTypePattern(
      'Identity', inputs=[folded_bias_add_pattern])

  bias_add_pattern = graph_matcher.OpTypePattern(
      'Add|BiasAdd', inputs=[layer_output_pattern, '*'], ordered_inputs=False)

  # The bias can come from the bias add or the folded bias add.
  bypass_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[
          graph_matcher.OneofPattern(
              [bias_add_pattern, folded_bias_add_pattern, batch_norm_identity]),
          '*'
      ],
      ordered_inputs=False)

  # The input to the activation can come from the bias add, the folded bias
  # add, or the bypasses.
  # TODO(suharshs): We should ideally skip Identity operations instead of
  # treating them as activations.
  activation_pattern = graph_matcher.OpTypePattern(
      '|'.join(_ACTIVATION_TYPES) + '|Identity',
      inputs=[
          graph_matcher.OneofPattern([
              bias_add_pattern,
              folded_bias_add_pattern,
              batch_norm_identity,
              bypass_pattern,
              layer_pattern,
          ])
      ])

  post_activation_bypass_pattern = graph_matcher.OpTypePattern(
      'Add', inputs=['*', activation_pattern], ordered_inputs=False)

  # The order of the following matching blocks is very important. Since matches
  # aren't guaranteed to be disjoint, we structure matches from largest to
  # smallest to guarantee that the largest match always wins. Additionally, we
  # ensure that we don't match layers multiple times.

  layer_matches = []
  # We use matched_layer_set to ensure that layers aren't matched multiple
  # times.
  matched_layer_set = set()

  # First, we match layers that have a post activation bypass. We do this first
  # to ensure we don't match only the first part of this layer, missing the
  # post activation bypass node.
  post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher(
      post_activation_bypass_pattern)
  for match_result in post_activation_bypass_layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(activation_pattern)
    bias_add_op = match_result.get_op(bias_add_pattern)
    if bias_add_op is None:
      bias_add_op = match_result.get_op(folded_bias_add_pattern)
    bypass_op = match_result.get_op(bypass_pattern)
    post_activation_bypass_op = match_result.get_op(
        post_activation_bypass_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op,
                      post_activation_bypass_op, bias_add_op))

  # Now, we match the basic layer ending at an activation. We may get duplicate
  # matches from above, but we don't add them to layer_matches.
  layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
  for match_result in layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(activation_pattern)
    bias_add_op = match_result.get_op(bias_add_pattern)
    if bias_add_op is None:
      bias_add_op = match_result.get_op(folded_bias_add_pattern)
    bypass_op = match_result.get_op(bypass_pattern)
    if layer_op not in matched_layer_set:
      if not _IsSkipLayer(activation_op):
        matched_layer_set.add(layer_op)
        layer_matches.append(
            _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, None,
                        bias_add_op))

  # Match the final layer, where there may not be an activation and instead
  # the output of the final BiasAdd must be quantized. So we treat the BiasAdd
  # as the 'activation_op' in the _LayerMatch, to ensure that its output is
  # quantized.
  final_layer_matcher = graph_matcher.GraphMatcher(
      graph_matcher.OneofPattern([bias_add_pattern, folded_bias_add_pattern]))
  for match_result in final_layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(bias_add_pattern)
    if activation_op is None:
      activation_op = match_result.get_op(folded_bias_add_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))

  # Look for separable convolutions here.
  sep_conv_matcher = graph_matcher.GraphMatcher(sep_conv_pattern)
  for match_result in sep_conv_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    activation_op = match_result.get_op(layer_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))

  return layer_matches


def _IsSkipLayer(activation_op):
  """Skip quantizing conv->identity->batch norm layers.

  Args:
    activation_op: Activation op detected by the layer matching pattern.

  Returns:
    skip_layer: boolean, true when conv->identity->batch norm is detected.
  """

  # Exclude quantization of conv->identity->BN. After folding, this part
  # corresponds to estimation of the mean and variance and should not be
  # quantized.
  skip_layer = False
  if activation_op.type == 'Identity' and len(activation_op.outputs) == 1:
    if len(activation_op.outputs[0].consumers()) == 1:
      consumer = activation_op.outputs[0].consumers()[0]
      if consumer.type == 'FusedBatchNorm':
        skip_layer = True
        logging.info(
            'Skipping quantizing %s, because it is the output of a conv/fc '
            'followed by an identity, feeding a fused batch norm.',
            activation_op.name)
  return skip_layer


class _LayerMatch(object):
  """Contains all information related to a matched layer."""

  def __init__(self, layer_op, weight_tensor, activation_op, bypass_op,
               post_activation_bypass_op, bias_add_op):
    self._layer_op = layer_op
    self._weight_tensor = weight_tensor
    self._activation_op = activation_op
    self._bypass_op = bypass_op
    self._post_activation_bypass_op = post_activation_bypass_op
    self._bias_add_op = bias_add_op

  @property
  def layer_op(self):
    return self._layer_op

  @property
  def weight_tensor(self):
    return self._weight_tensor

  @property
  def activation_op(self):
    return self._activation_op

  @property
  def bypass_op(self):
    return self._bypass_op

  @property
  def post_activation_bypass_op(self):
    return self._post_activation_bypass_op

  @property
  def bias_add_op(self):
    return self._bias_add_op


def _FollowedByFakeQuant(tensor):
  """Returns True if the tensor is followed by a FakeQuant."""
  fake_quant_ops = set([
      'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
      'FakeQuantWithMinMaxVarsPerChannel'
  ])
  pass_through_ops = set(['Reshape', 'Identity'])
  consumers = tensor.consumers()
  while consumers:
    c = consumers.pop()
    if c.type in fake_quant_ops:
      return True
    elif c.type in pass_through_ops:
      for output in c.outputs:
        consumers.extend(output.consumers())
  return False


def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   symmetric=False,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
  """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    symmetric: (Optional) If true, use symmetric quantization limits instead of
      training the minimum and maximum of each quantization range separately.
    ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization. This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new op
      will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new op
      will be inserted only when all the consumers are in this scope.
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  if producer_scope and not producer.name.startswith(producer_scope):
    logging.info(
        '_InsertQuantOp ignores context="%s" name="%s" '
        'because producer "%s" is not in scope "%s"',
        context, name, producer.name, producer_scope)
    return

  if consumer_scope:
    consumers_in_scope = []
    for consumer in consumers:
      if consumer.name.startswith(consumer_scope):
        consumers_in_scope.append(consumer)
      else:
        logging.info(
            '_InsertQuantOp context="%s" name="%s" ignores '
            'consumer "%s" because it is not in scope "%s"',
            context, name, consumer.name, consumer_scope)
        return
    consumers = consumers_in_scope

  name_prefix = _AddContextToName(context, name)
  # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
  # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
  # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
  # breaks things later.
  name_scope = ops.get_name_scope()
  if name_scope:
    name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

  inputs = producer.outputs[0]
  # Prevent ops from being quantized multiple times. Bypass ops can sometimes
  # overlap between multiple matches, so we need to ensure that we don't
  # add duplicate FakeQuant operations.
  if _FollowedByFakeQuant(inputs):
    return

  if moving_avg:
    quant = (
        quant_ops.MovingAvgQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            ema_decay=ema_decay,
            is_training=is_training,
            num_bits=bits,
            symmetric=symmetric,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))
  else:
    quant = (
        quant_ops.LastValueQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            is_training=is_training,
            num_bits=bits,
            symmetric=symmetric,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))

  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  if consumers:
    tensors_modified_count = common.RerouteTensor(
        quant, inputs, can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer. Since consumers is a set, we need to ensure that
    # tensors_modified_count is greater than or equal to the length of the set
    # of consumers.
    if tensors_modified_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))


def _GetContextFromOp(op):
  """Gets the root context name from the op name."""
  context_re = re.search(r'^(.*)/([^/]+)', op.name)
  if context_re:
    return context_re.group(1)
  return ''


def _AddContextToName(context, name):
  """Adds the context to the name if it exists."""
  if not context:
    return name
  return context + '/' + name
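
# Example (a minimal sketch; `g` below is assumed to be a graph that has
# already been rewritten by Quantize above):
#
#   fake_quants = [op for op in g.get_operations()
#                  if op.type.startswith('FakeQuantWithMinMaxVars')]
#   for op in fake_quants:
#     print(op.name)
#
# Weight quantizers appear under a `weights_quant` name scope and activation
# quantizers under `act_quant`, matching the name prefixes passed to
# _InsertQuantOp above.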