# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic to update a TensorFlow model graph with quantization operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging

# Quantizable operation types that are supported by the quantization rewrite.
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}

# Activations that are supported by the quantization rewrite.
_ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'}

_RELU_TYPES = {'Relu', 'Relu6'}

_QUANTIZATION_OP = {'FakeQuantWithMinMaxVars'}
_VALID_SRC_OP = {'Add', 'Mul'}
_INTERMEDIATE_OP = {'Add', 'Mul'}
_PASS_THROUGH_OP = {'Reshape', 'Identity', 'BatchToSpaceND', 'SpaceToBatchND'}
_VALID_ACTIVATION_OP = {'Relu', 'Relu6'}


def Quantize(graph,
             is_training,
             weight_bits=8,
             activation_bits=8,
             symmetric=False,
             ema_decay=0.999,
             quant_delay=None,
             vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
             scope=None):
  """Updates graph with quantization operations.

  Currently we quantize the following tensors:
  * Conv/MatMul: Quantize the weights if the layer pattern matches.
  * Activation: Quantize the output if the layer pattern matches.
  * Bypass/Post-activation Bypass: Quantize both the input and the output
    if the layer pattern matches.

  Args:
    graph: Graph to modify.
    is_training: Whether quantizing training graph or eval graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    symmetric: (Optional) If true, use symmetric quantization limits instead of
      training the minimum and maximum of each quantization range separately.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    scope: The scope to be transformed. If it's not None, only the ops which
      are in this scope will be transformed.
  Raises:
    ValueError: When quantization fails.
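
  Example:
    A minimal usage sketch (the model construction below is an illustrative
    assumption, not part of this module):

      g = ops.Graph()
      with g.as_default():
        # Build a float model first, e.g. conv / bias-add / relu layers whose
        # op types appear in _QUANTIZABLE_TYPES and _ACTIVATION_TYPES.
        build_model()  # hypothetical helper defined by the caller
        # Rewrite the training graph in place; FakeQuant ops then simulate
        # 8-bit quantization of weights and activations during training.
        Quantize(g, is_training=True, weight_bits=8, activation_bits=8)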
  """
  if scope and not scope.endswith('/'):
    scope += '/'

  input_to_ops_map = input_to_ops.InputToOps(graph)
  quantized_ops = set()
  for layer_match in _FindLayersToQuantize(graph):
    # Quantize the weights.
    context = _GetContextFromOp(layer_match.layer_op)

    # If `scope` is given, only quantize it if the consumer of weights
    # (the layer op) is in the right scope.
    if layer_match.weight_tensor is not None:
      _InsertQuantOp(
          context,
          'weights_quant',
          layer_match.weight_tensor.op,
          input_to_ops_map.ConsumerOperations(layer_match.weight_tensor.op),
          is_training,
          moving_avg=False,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          narrow_range=True,
          vars_collection=vars_collection,
          bits=weight_bits,
          symmetric=symmetric,
          consumer_scope=scope)

    # Quantize the activations.
    if layer_match.activation_op is not None:
      consumer_ops = input_to_ops_map.ConsumerOperations(
          layer_match.activation_op)
      add_context = context
      if layer_match.bypass_op:
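        # The bypass (residual) Add typically lives one scope level above the
        # layer op's context, e.g. (hypothetically) 'net/conv_1' gives
        # add_context 'net', so its quant ops are created in the parent scope.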
        pattern_match_result = re.search(r'^(.*)/([^/]+)', context)
        if pattern_match_result is not None:
          add_context = pattern_match_result.group(1)
        else:
          add_context = ''
      # If `scope` is given, only quantize it if the producer of weights
      # (usually it's the layer op) is in the right scope.
      _InsertQuantOp(
          add_context,
          'act_quant',
          layer_match.activation_op,
          consumer_ops,
          is_training,
          moving_avg=True,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          vars_collection=vars_collection,
          bits=activation_bits,
          symmetric=symmetric,
          init_min=0.0,
          producer_scope=scope)
      quantized_ops.add(layer_match.activation_op)

    # Quantize the inputs and output to the bypass (if it exists). The input to
    # the bypass is the bias add, and the output is the activation.
    if layer_match.bypass_op is not None:
      # If `scope` is given, only quantize it if both the producer and the
      # consumer are in the right scope.
      _InsertQuantOp(
          context,
          'conv_quant',
          layer_match.bias_add_op,
          input_to_ops_map.ConsumerOperations(layer_match.bias_add_op),
          is_training,
          moving_avg=True,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          vars_collection=vars_collection,
          bits=activation_bits,
          symmetric=symmetric,
          producer_scope=scope,
          consumer_scope=scope)
      quantized_ops.add(layer_match.bias_add_op)
      # Make sure the op following this isn't an activation. If it is, we
      # shouldn't quantize it, since the activation will be fused into the
      # Add at inference time.
      consumers = input_to_ops_map.ConsumerOperations(layer_match.bypass_op)
      if any(consumer.type in _ACTIVATION_TYPES for consumer in consumers):
        logging.info('Skipping %s, because it is followed by an activation.',
                     layer_match.bypass_op.name)
      else:
        _InsertQuantOp(
            add_context,
            'add_quant',
            layer_match.bypass_op,
            input_to_ops_map.ConsumerOperations(layer_match.bypass_op),
            is_training,
            moving_avg=True,
            ema_decay=ema_decay,
            quant_delay=quant_delay,
            vars_collection=vars_collection,
            bits=activation_bits,
            symmetric=symmetric,
            producer_scope=scope,
            consumer_scope=scope)
        quantized_ops.add(layer_match.bypass_op)

    # Quantize bypass ops that occur after the activation.
    if layer_match.post_activation_bypass_op is not None:
      pattern_match_result = re.search(
          r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name)
      if pattern_match_result is not None:
        post_activation_bypass_context = pattern_match_result.group(1)
      else:
        post_activation_bypass_context = ''
      # If `scope` is given, only quantize it if the producer is in the right
      # scope.
      # Make sure the op following this isn't an activation. If it is, we
      # shouldn't quantize it, since the activation will be fused into the
      # Add at inference time.
      consumers = input_to_ops_map.ConsumerOperations(
          layer_match.post_activation_bypass_op)
      if any(consumer.type in _RELU_TYPES for consumer in consumers):
        logging.info('Skipping %s, because it is followed by an activation.',
                     layer_match.post_activation_bypass_op.name)
      else:
        _InsertQuantOp(
            post_activation_bypass_context,
            'post_activation_bypass_quant',
            layer_match.post_activation_bypass_op,
            consumers,
            is_training,
            moving_avg=True,
            ema_decay=ema_decay,
            quant_delay=quant_delay,
            vars_collection=vars_collection,
            bits=activation_bits,
            symmetric=symmetric,
            producer_scope=scope)
        quantized_ops.add(layer_match.post_activation_bypass_op)

  _QuantizeActivationLayers(
      quantized_ops,
      graph,
      is_training,
      activation_bits,
      ema_decay,
      quant_delay,
      vars_collection,
      scope=scope)


def _QuantizeActivationLayers(quantized_ops,
                              graph,
                              is_training,
                              activation_bits=8,
                              ema_decay=0.999,
                              quant_delay=None,
                              vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                              scope=None):
  """Quantize intermediate activation tensors after addition and multiplication.

  Args:
    quantized_ops: Set of previously quantized activation ops.
    graph: Graph to modify.
    is_training: Whether quantizing training graph or eval graph.
    activation_bits: Number of bits to use for quantizing activations.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    scope: The scope to be transformed. If it's not None, only the ops which are
      in this scope will be transformed.

  Raises:
    ValueError: When quantization fails.
  """
  input_to_ops_map = input_to_ops.InputToOps(graph)
  for op in graph.get_operations():
    if _CheckIfQuantizableOp(op, quantized_ops):
      logging.info('Inserting fake quant op activation_%s_quant after %s',
                   op.type, op.name)
      consumers = input_to_ops_map.ConsumerOperations(op)
      _InsertQuantOp(
          op.name,
          'activation_' + op.type + '_quant',
          op,
          consumers,
          is_training,
          moving_avg=True,
          ema_decay=ema_decay,
          quant_delay=quant_delay,
          vars_collection=vars_collection,
          bits=activation_bits,
          producer_scope=scope)


def _CheckIfQuantizableOp(src_op, quantized_ops):
  """Check if the output of an op should be quantized.

  Args:
    src_op: op to be checked
    quantized_ops: Set of previously quantized activation ops.

  Returns:
    Boolean specifying if output should be quantized or not.
  """
  src_op_name = set([src_op.type])
  if src_op in quantized_ops:
    return False
  if not src_op_name.intersection(_VALID_SRC_OP):
    return False

  # If the src op is an Add or a Mul and its output is immediately
  # followed by an activation, skip quantization.
  if len(src_op.outputs) == 1 and len(src_op.outputs[0].consumers()) == 1:
    op_consumers = src_op.outputs[0].consumers()
    if set([op_consumers[0].type]).intersection(_VALID_ACTIVATION_OP):
      logging.info('Skipping quant after %s', src_op.name)
      return False
  # At this point src_op is an Add or a Mul.
  input_ops = src_op.inputs

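  # Walk each input backwards through pass-through ops (Reshape, Identity,
  # BatchToSpaceND, SpaceToBatchND) to see whether it ultimately comes from a
  # FakeQuant (or a delayed_quant Merge), or from another Add/Mul that is
  # itself quantizable.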
  for op in input_ops:
    curr_op = op.op
    curr_op_type = set([curr_op.type])
    while curr_op_type.intersection(_PASS_THROUGH_OP):
      # Walk back through pass-through ops.
      curr_op = curr_op.inputs[0].op
      curr_op_type = set([curr_op.type])

    # Now at a valid or quantizable op; check whether at least one input of
    # the valid op is connected to a quantizable op via pass-through ops.
    if (curr_op_type.intersection(_QUANTIZATION_OP) or
        curr_op.name.find('delayed_quant/Merge') > 0):
      return True

    if curr_op_type.intersection(_INTERMEDIATE_OP):
      # Check whether at least one input to the intermediate op is quantizable.
      for input_op in curr_op.inputs:
        if _CheckIfQuantizableOp(input_op.op, quantized_ops):
          return True
  return False


def _FindLayersToQuantize(graph):
  """Matches layers in graph to quantize.

  The following patterns get matched. Nodes surrounded by [] will be
  optionally matched:

          weight|folded_weight
                /
         conv|fc
            |
      [batch_to_space_nd]
            |
    [post_conv_correction]
            |
     [biasadd|folded_bias]
            |
         [bypass]
            |
        activation
            |
   [post_activation_bypass]

  Match replacements:
    If weight|folded_weight is found, FakeQuant is added afterwards.
    If bypass is found, FakeQuant is added before and after.
    If activation is found, FakeQuant is added afterwards.
    If post_activation_bypass is found, FakeQuant is added afterwards.
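
  For example, a plain conv layer would (hypothetically) match as the op chain

      weights/read (Identity) -> Conv2D -> BiasAdd -> Relu6

  yielding a _LayerMatch whose weight_tensor is the Identity output, layer_op
  is the Conv2D, bias_add_op is the BiasAdd and activation_op is the Relu6,
  with no bypass ops.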

  Args:
    graph: Graph to perform match on.

  Returns:
    list of _LayerMatches.
  """
  input_pattern = graph_matcher.OpTypePattern('*')
  weight_var_pattern = graph_matcher.OpTypePattern('Variable|VariableV2')
  weight_partition_identity_pattern = graph_matcher.OpTypePattern(
      'Identity', inputs=[weight_var_pattern])
  weight_partition_concat_pattern = graph_matcher.OpTypePattern(
      'ConcatV2', inputs=[weight_partition_identity_pattern, '*', '*'])
  weight_identity_pattern = graph_matcher.OpTypePattern(
      'Identity',
      inputs=[
          graph_matcher.OneofPattern([
              weight_partition_identity_pattern,
              weight_partition_concat_pattern,
              weight_var_pattern,
          ])
      ])
  weight_resource_var_pattern = graph_matcher.OpTypePattern('ReadVariableOp')
  folded_weight_pattern = graph_matcher.OpTypePattern('Mul')

  # The weight input to the layer operation can come either from the Variable
  # or from the folded weight (Mul).
  layer_pattern = graph_matcher.OpTypePattern(
      '|'.join(_QUANTIZABLE_TYPES),
      inputs=[
          input_pattern,
          graph_matcher.OneofPattern([
              weight_identity_pattern, weight_resource_var_pattern,
              folded_weight_pattern
          ])
      ],
      ordered_inputs=False)

  # For atrous convolutions a BatchToSpaceND will occur after the depthwise
  # convolution.
  batch_to_space_pattern = graph_matcher.OpTypePattern(
      'BatchToSpaceND',
      inputs=[
          layer_pattern,
          graph_matcher.OpTypePattern('*'),
          graph_matcher.OpTypePattern('*')
      ])

  layer_output_pattern = graph_matcher.OneofPattern(
      [batch_to_space_pattern, layer_pattern])

  # For separable convolutions, we are looking for a conv, followed by a conv
  # with no activations between the two.
  sep_conv_pattern = graph_matcher.OpTypePattern(
      '|'.join(_QUANTIZABLE_TYPES),
      inputs=[
          graph_matcher.OneofPattern([layer_output_pattern]),
          graph_matcher.OpTypePattern('*')
      ],
      ordered_inputs=False)
  folded_bias_mul_pattern = graph_matcher.OpTypePattern(
      'Mul',
      inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern],
      ordered_inputs=False)
  post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[folded_bias_mul_pattern,
              graph_matcher.OpTypePattern('*')],
      ordered_inputs=False)
  folded_bias_add_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[
          post_layer_op_correction_pattern,
          graph_matcher.OpTypePattern('*')
      ],
      ordered_inputs=False)

  # batch_norms with forced updates have an Identity operation at the end.
  # TODO(suharshs): Find a way to easily skip extra Identity operations. The
  # current issue is that doing so can often match patterns across many layers
  # incorrectly.
  batch_norm_identity = graph_matcher.OpTypePattern(
      'Identity', inputs=[folded_bias_add_pattern])

  bias_add_pattern = graph_matcher.OpTypePattern(
      'Add|BiasAdd', inputs=[layer_output_pattern, '*'], ordered_inputs=False)

  # The bias can come from the bias add or the folded bias add.
  bypass_pattern = graph_matcher.OpTypePattern(
      'Add',
      inputs=[
          graph_matcher.OneofPattern(
              [bias_add_pattern, folded_bias_add_pattern, batch_norm_identity]),
          '*'
      ],
      ordered_inputs=False)

  # The input to the activation can come from the bias add, the folded bias
  # add, or the bypasses.
  # TODO(suharshs): We should ideally skip Identity operations instead of
  # treating them as activations.
  activation_pattern = graph_matcher.OpTypePattern(
      '|'.join(_ACTIVATION_TYPES) + '|Identity',
      inputs=[
          graph_matcher.OneofPattern([
              bias_add_pattern,
              folded_bias_add_pattern,
              batch_norm_identity,
              bypass_pattern,
              layer_pattern,
          ])
      ])

  post_activation_bypass_pattern = graph_matcher.OpTypePattern(
      'Add', inputs=['*', activation_pattern], ordered_inputs=False)

  # The order of the following matching blocks is very important. Since matches
  # aren't guaranteed to be disjoint, we structure matches from largest to
  # smallest to guarantee that the largest match always wins. Additionally, we
  # ensure that we don't match layers multiple times.

  layer_matches = []
  # We use matched_layer_set to ensure that layers aren't matched multiple
  # times.
  matched_layer_set = set()

  # First, we match layers that have a post activation bypass. We do this first
  # to ensure we don't match only the first part of this layer, missing the
  # post activation bypass node.
  post_activation_bypass_layer_matcher = graph_matcher.GraphMatcher(
      post_activation_bypass_pattern)
  for match_result in post_activation_bypass_layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(activation_pattern)
    bias_add_op = match_result.get_op(bias_add_pattern)
    if bias_add_op is None:
      bias_add_op = match_result.get_op(folded_bias_add_pattern)
    bypass_op = match_result.get_op(bypass_pattern)
    post_activation_bypass_op = match_result.get_op(
        post_activation_bypass_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op,
                      post_activation_bypass_op, bias_add_op))

  # Now, we match the basic layer ending at an activation. We may get duplicate
  # matches from above, but we don't add them to layer_matches.
  layer_matcher = graph_matcher.GraphMatcher(activation_pattern)
  for match_result in layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(activation_pattern)
    bias_add_op = match_result.get_op(bias_add_pattern)
    if bias_add_op is None:
      bias_add_op = match_result.get_op(folded_bias_add_pattern)
    bypass_op = match_result.get_op(bypass_pattern)
    if layer_op not in matched_layer_set:
      if not _IsSkipLayer(activation_op):
        matched_layer_set.add(layer_op)
        layer_matches.append(
            _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, None,
                        bias_add_op))

  # Match the final layer, where there may not be an activation and instead
  # the output of the final BiasAdd must be quantized. So we treat the BiasAdd
  # as the 'activation_op' in the _LayerMatch, to ensure that its output is
  # quantized.
  final_layer_matcher = graph_matcher.GraphMatcher(
      graph_matcher.OneofPattern([bias_add_pattern, folded_bias_add_pattern]))
  for match_result in final_layer_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(folded_weight_pattern)
    activation_op = match_result.get_op(bias_add_pattern)
    if activation_op is None:
      activation_op = match_result.get_op(folded_bias_add_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))

  # Look for separable convolutions here
  sep_conv_matcher = graph_matcher.GraphMatcher(sep_conv_pattern)
  for match_result in sep_conv_matcher.match_graph(graph):
    layer_op = match_result.get_op(layer_pattern)
    weight_tensor = match_result.get_tensor(weight_identity_pattern)
    if weight_tensor is None:
      weight_tensor = match_result.get_tensor(weight_resource_var_pattern)
    activation_op = match_result.get_op(layer_pattern)
    if layer_op not in matched_layer_set:
      matched_layer_set.add(layer_op)
      layer_matches.append(
          _LayerMatch(layer_op, weight_tensor, activation_op, None, None, None))

  return layer_matches


def _IsSkipLayer(activation_op):
  """Skip quantizing conv->identity->Batch norm layers.

  Args:
    activation_op: Activation op detected by layer matching pattern

  Returns:
    skip_layer: boolean, true when conv->identity->batch norm is detected.
  """

  # Exclude quantization of conv->identity->BN. After folding, this part
  # corresponds to estimation of mean and variance and should not be
  # quantized.
  skip_layer = False
  if activation_op.type == 'Identity' and len(activation_op.outputs) == 1:
    if len(activation_op.outputs[0].consumers()) == 1:
      consumer = activation_op.outputs[0].consumers()[0]
      if consumer.type == 'FusedBatchNorm':
        skip_layer = True
        logging.info(
            'Skipping quantizing %s, because it is the output of a conv/fc '
            'followed by an identity, feeding a fused batch norm.',
            activation_op.name)
  return skip_layer


class _LayerMatch(object):
  """Contains all information related to a matched Layer."""

  def __init__(self, layer_op, weight_tensor, activation_op, bypass_op,
               post_activation_bypass_op, bias_add_op):
    self._layer_op = layer_op
    self._weight_tensor = weight_tensor
    self._activation_op = activation_op
    self._bypass_op = bypass_op
    self._post_activation_bypass_op = post_activation_bypass_op
    self._bias_add_op = bias_add_op

  @property
  def layer_op(self):
    return self._layer_op

  @property
  def weight_tensor(self):
    return self._weight_tensor

  @property
  def activation_op(self):
    return self._activation_op

  @property
  def bypass_op(self):
    return self._bypass_op

  @property
  def post_activation_bypass_op(self):
    return self._post_activation_bypass_op

  @property
  def bias_add_op(self):
    return self._bias_add_op


def _FollowedByFakeQuant(tensor):
  """Returns True if the tensor is followed by a FakeQuant."""
  fake_quant_ops = set([
      'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
      'FakeQuantWithMinMaxVarsPerChannel'
  ])
  pass_through_ops = set(['Reshape', 'Identity'])
  consumers = tensor.consumers()
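  # Walk forward through the consumers, following pass-through ops, so that a
  # FakeQuant reached via e.g. a Reshape or an Identity still counts.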
  while consumers:
    c = consumers.pop()
    if c.type in fake_quant_ops:
      return True
    elif c.type in pass_through_ops:
      for output in c.outputs:
        consumers.extend(output.consumers())
  return False


def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   symmetric=False,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
  """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    symmetric: (Optional) If true, use symmetric quantization limits instead of
      training the minimum and maximum of each quantization range separately.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new op
      will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new op
      will be inserted only when all the consumers are in this scope.
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  if producer_scope and not producer.name.startswith(producer_scope):
    logging.info(
        '_InsertQuantOp ignores context="%s" name="%s" '
        'because producer "%s" is not in scope "%s"',
        context, name, producer.name, producer_scope)
    return

  if consumer_scope:
    consumers_in_scope = []
    for consumer in consumers:
      if consumer.name.startswith(consumer_scope):
        consumers_in_scope.append(consumer)
      else:
        logging.info(
            '_InsertQuantOp context="%s" name="%s" ignores '
            'consumer "%s" because it is not in scope "%s"',
            context, name, consumer.name, consumer_scope)
        return
    consumers = consumers_in_scope

  name_prefix = _AddContextToName(context, name)
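  # E.g. context 'net/conv_1' and name 'weights_quant' give the (hypothetical)
  # prefix 'net/conv_1/weights_quant' under which the quant ops are created.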
  # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
  # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
  # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
  # breaks things later.
  name_scope = ops.get_name_scope()
  if name_scope:
    name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

  inputs = producer.outputs[0]
  # Prevent ops from being quantized multiple times. Bypass ops can sometimes
  # overlap between multiple matches, so we need to ensure that we don't
  # add duplicate FakeQuant operations.
  if _FollowedByFakeQuant(inputs):
    return

  if moving_avg:
    quant = (
        quant_ops.MovingAvgQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            ema_decay=ema_decay,
            is_training=is_training,
            num_bits=bits,
            symmetric=symmetric,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))
  else:
    quant = (
        quant_ops.LastValueQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            is_training=is_training,
            num_bits=bits,
            symmetric=symmetric,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))

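  # When a quant_delay is requested, gate the quantized path on the global
  # step: until quant_delay steps have passed, the original float tensor is
  # passed through unchanged.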
  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  if consumers:
    tensors_modified_count = common.RerouteTensor(
        quant, inputs, can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer. Since consumers is a set, we need to ensure that
    # tensors_modified_count is greater than or equal to the length of the set
    # of consumers.
    if tensors_modified_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))


def _GetContextFromOp(op):
  """Gets the root context name from the op name."""
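  # E.g. an op named 'network/conv_1/Conv2D' yields context 'network/conv_1'.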
  context_re = re.search(r'^(.*)/([^/]+)', op.name)
  if context_re:
    return context_re.group(1)
  return ''


def _AddContextToName(context, name):
  """Adds the context to the name if it exists."""
  if not context:
    return name
  return context + '/' + name