1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Logic to fold batch norm into preceding convolution or FC layers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import re
22from tensorflow.contrib.quantize.python import common
23from tensorflow.contrib.quantize.python import graph_matcher
24from tensorflow.contrib.quantize.python import input_to_ops
25from tensorflow.core.framework import attr_value_pb2
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.layers import utils
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import nn
32from tensorflow.python.ops import nn_ops
33from tensorflow.python.ops import variable_scope
34from tensorflow.python.util import compat
35
36
def FoldBatchNorms(graph, is_training, freeze_batch_norm_delay=None):
  """Folds every batch norm found in `graph` into the layer that feeds it.

  The fused batch norm pass runs first, followed by the unfused pass. Folding
  only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, true if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization. This value is used
      only when is_training is True.

  Raises:
    ValueError: When batch norm folding fails.
  """
  # Order matters: fused batch norms are folded before unfused ones.
  for fold_pass in (_FoldFusedBatchNorms, _FoldUnfusedBatchNorms):
    fold_pass(
        graph,
        is_training=is_training,
        freeze_batch_norm_delay=freeze_batch_norm_delay)
58
59
def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
  """Finds fused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  For each match returned by `_FindFusedBatchNorms`, the layer op is cloned
  with weights multiplied by gamma * rsqrt(variance + epsilon), a bias of
  beta - mean * multiplier is added to the clone's output, and
  `common.RerouteTensor` is asked to substitute that sum for the match's
  original output tensor. When `is_training` is true, the correction terms
  from `_ComputeBatchNormCorrections` are applied around the cloned layer.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, true if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails, i.e. rerouting the folded output
      modified no nodes.
  """
  for match in _FindFusedBatchNorms(graph):
    # `scope` is everything before the last '/' of the layer op name; both
    # `scope` and `sep` are empty strings when the name contains no '/'.
    scope, sep, _ = match.layer_op.name.rpartition('/')
    # Make sure new ops are added to `graph` and put on the same device as
    # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
    # named `scope`. Otherwise, TF creates a unique scope whose name starts with
    # `scope`.
    with graph.as_default(), graph.name_scope(scope + sep):
      # Ops created in this inner block live under `<scope>/BatchNorm_Fold/`.
      with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep):
        # new weights = old weights * gamma / sqrt(variance + epsilon)
        # new biases = -mean * gamma / sqrt(variance + epsilon) + beta
        multiplier_tensor = match.gamma_tensor * math_ops.rsqrt(
            match.variance_tensor + match.bn_op.get_attr('epsilon'))
        bias_tensor = math_ops.subtract(
            match.beta_tensor,
            match.mean_tensor * multiplier_tensor,
            name='bias')

        # The corrections stay None outside of training, which disables every
        # correction step below.
        correction_scale, correction_recip, correction_offset = None, None, None
        if is_training:
          correction_scale, correction_recip, correction_offset = (
              _ComputeBatchNormCorrections(
                  context='',
                  match=match,
                  freeze_batch_norm_delay=freeze_batch_norm_delay))
        # The shape of depthwise weights is different, so we need to reshape the
        # multiplier_tensor to ensure that the scaled_weight_tensor has the
        # expected shape.
        weights = match.weight_tensor
        if match.layer_op.type == 'DepthwiseConv2dNative':
          # Dimensions 2 and 3 of the weight tensor give the target shape.
          new_shape = [
              match.weight_tensor.get_shape().as_list()[2],
              match.weight_tensor.get_shape().as_list()[3]
          ]
          multiplier_tensor = array_ops.reshape(
              multiplier_tensor, new_shape, name='scale_reshape')

          if correction_scale is not None:
            correction_scale = array_ops.reshape(
                correction_scale, new_shape, name='correction_reshape')

      # From here on ops are created directly under `scope`, outside of the
      # 'BatchNorm_Fold' sub-scope.
      if correction_scale is not None:
        weights = math_ops.multiply(
            correction_scale, weights, name='correction_mult')

      scaled_weight_tensor = math_ops.multiply(
          weights, multiplier_tensor, name='mul_fold')

      # Rebuild the conv/matmul on the original input with the folded weights.
      new_layer_tensor = _CloneWithNewOperands(
          match.layer_op, match.input_tensor, scaled_weight_tensor,
          match.batch_to_space_op)

      # correction_recip and correction_offset are set together, so checking
      # one of them is enough.
      if correction_recip is not None:
        new_layer_tensor = math_ops.multiply(
            correction_recip, new_layer_tensor, name='post_conv_mul')
        new_layer_tensor = math_ops.add(new_layer_tensor, (correction_offset),
                                        'correction_add')

      bias_add_tensor = math_ops.add(
          new_layer_tensor, bias_tensor, name='add_fold')

      # Swap the folded result in for the original batch norm output. A count
      # of zero is treated as a folding failure.
      nodes_modified_count = common.RerouteTensor(bias_add_tensor,
                                                  match.output_tensor)
      if nodes_modified_count == 0:
        raise ValueError('Folding batch norms failed, %s had no outputs.' %
                         match.output_tensor.name)
140
141
def _FindFusedBatchNorms(graph):
  """Finds all ops and tensors related to found FusedBatchNorms.

  Note that this is not a pure query: for each processed match whose
  FusedBatchNorm op has `is_training` set, the op's `_gradient_op_type`
  attribute is overridden with 'FoldFusedBatchNormGrad' and an
  'Undo_Bessel_Correction' multiply is added to `graph` under the batch norm's
  name scope.

  Args:
    graph: Graph to inspect.

  Returns:
    A list of _BatchNormMatch objects, at most one per matched layer op.
  """
  input_pattern = graph_matcher.OpTypePattern('*')
  # In practice, the weight pattern can match a Variable or a SpaceToBatchND
  # operation that follows a variable for atrous convolutions.
  weight_pattern = graph_matcher.OpTypePattern('*')
  gamma_pattern = graph_matcher.OpTypePattern('*')
  beta_pattern = graph_matcher.OpTypePattern('*')
  mean_pattern = graph_matcher.OpTypePattern('*')
  variance_pattern = graph_matcher.OpTypePattern('*')

  moving_average_pattern = graph_matcher.OpTypePattern('*')
  bn_decay_pattern = graph_matcher.OpTypePattern('*')
  layer_pattern = graph_matcher.OpTypePattern(
      'Conv2D|DepthwiseConv2dNative|MatMul',
      inputs=[input_pattern, weight_pattern])
  batch_to_space_pattern = graph_matcher.OpTypePattern(
      'BatchToSpaceND',
      inputs=[
          layer_pattern,
          graph_matcher.OpTypePattern('*'),
          graph_matcher.OpTypePattern('*')
      ])
  # Identity between conv/matmul and bn
  layer_pattern_with_identity = graph_matcher.OpTypePattern(
      'Identity',
      inputs=[
          graph_matcher.OneofPattern([batch_to_space_pattern, layer_pattern])
      ])
  layer_output_pattern = graph_matcher.OneofPattern(
      [layer_pattern_with_identity, layer_pattern, batch_to_space_pattern])

  # MatMul has a Reshape between it and FusedBatchNorm.
  matmul_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape',
      inputs=[layer_output_pattern,
              graph_matcher.OpTypePattern('*')])

  batch_norm_pattern = graph_matcher.OpTypePattern(
      'FusedBatchNorm',
      inputs=[
          graph_matcher.OneofPattern(
              [matmul_reshape_pattern, layer_output_pattern]), gamma_pattern,
          beta_pattern, mean_pattern, variance_pattern
      ])
  matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape', inputs=[batch_norm_pattern,
                         graph_matcher.OpTypePattern('*')])

  # NOTE(review): this Identity pattern lists two inputs. If the matcher
  # requires the op's input count to equal the pattern's, it may never match;
  # confirm whether a OneofPattern over the two was intended. Left unchanged
  # here because altering it would change which matches are produced.
  batch_norm_identity_pattern = graph_matcher.OpTypePattern(
      'Identity', inputs=[batch_norm_pattern, matmul_bn_output_reshape_pattern])

  bn_identity_matcher = graph_matcher.GraphMatcher(batch_norm_identity_pattern)

  bn_matcher = graph_matcher.GraphMatcher(
      graph_matcher.OneofPattern(
          [matmul_bn_output_reshape_pattern, batch_norm_pattern]))

  # Pattern for the moving average update: (moving_average - batch_stat) * decay
  moving_average_sub_pattern = graph_matcher.OpTypePattern(
      'Sub', inputs=[moving_average_pattern, batch_norm_pattern])
  moving_average_mul_pattern = graph_matcher.OpTypePattern(
      'Mul', inputs=[moving_average_sub_pattern, bn_decay_pattern])

  moving_avg_mul_matcher = graph_matcher.GraphMatcher(
      moving_average_mul_pattern)

  def _GetLayerMatch(match_result):
    """Populates a layer match object containing ops/tensors for folding BNs.

    Args:
      match_result: Matched result from graph matcher

    Returns:
      layer_op: Matching conv/fc op prior to batch norm
      BatchNormMatch: _BatchNormMatch containing all required batch norm
      parameters.
      Both are None when the match should be skipped (a MatMul without the
      trailing Reshape, or an output tensor without consumers).
    """
    # These stay None unless the batch norm is in training mode and a matching
    # moving average update is found in the graph.
    moving_mean_tensor = None
    moving_variance_tensor = None
    bn_decay_mean_tensor = None
    bn_decay_var_tensor = None
    layer_op = match_result.get_op(layer_pattern)
    layer_tensor = match_result.get_tensor(layer_pattern)
    bn_op = match_result.get_op(batch_norm_pattern)

    batch_epsilon = bn_op.get_attr('epsilon')

    # In the MatMul case, the output of batch norm is reshaped back into a
    # 2D tensor, so the output_tensor is the output of the Reshape op.
    output_tensor = bn_op.outputs[0]
    if layer_op.type == 'MatMul':
      output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
      # If the matcher didn't match matmul_bn_output_reshape, there will be
      # another match for this 'MatMul' later, so we can skip this one.
      if output_reshape_op is None:
        return None, None
      output_tensor = output_reshape_op.outputs[0]

    # Ensure that the output tensor has consumers, otherwise this is a dangling
    # node and not a match.
    if not output_tensor.consumers():
      return None, None

    batch_to_space_op = match_result.get_op(batch_to_space_pattern)
    input_tensor = match_result.get_tensor(input_pattern)
    weight_tensor = match_result.get_tensor(weight_pattern)
    gamma_tensor = match_result.get_tensor(gamma_pattern)
    beta_tensor = match_result.get_tensor(beta_pattern)
    # FusedBatchNorm in training is different from that in inference. It takes
    # empty 'mean' and empty 'variance', and produces the mean and the variance
    # of the batch. Therefore, when is_training is true, mean_tensor and
    # variance_tensor point to 1st and 2nd (0-based) output of bn_op,
    # respectively; when is_training is false, they point to bn_op's inputs.
    is_training = bn_op.get_attr('is_training')
    if is_training:
      # FusedBatchNormGrad doesn't compute gradients of the batch_mean and
      # batch_variance outputs, so we need to substitute our own custom
      # gradient.
      # TODO(suharshs, raghuramank): Find a way to avoid needing this hack.
      # pylint: disable=protected-access
      bn_op._set_attr(
          '_gradient_op_type',
          attr_value_pb2.AttrValue(s=compat.as_bytes('FoldFusedBatchNormGrad')))
      # pylint: enable=protected-access
      mean_tensor = bn_op.outputs[1]
      # The batch variance used during forward and backward prop is biased,
      # i.e it is calculated as: V=sum(x(k)-mu)^2/N. For the moving average
      # calculation, the variance is corrected by the term N/N-1 (Bessel's
      # correction). The variance tensor read from FuseBatchNorm has Bessel's
      # correction applied, so we undo it here.
      scope, sep, _ = bn_op.name.rpartition('/')
      # Open the scope on the graph being inspected (not on whatever graph is
      # currently the default) so the new ops are named under the batch norm's
      # scope even when `graph` is not the default graph.
      with graph.as_default(), graph.name_scope(scope + sep):
        # n is the number of layer output elements per element of the batch
        # mean.
        n = math_ops.cast(
            array_ops.size(layer_tensor) / array_ops.size(mean_tensor),
            dtypes.float32)
        variance_tensor = math_ops.multiply(
            bn_op.outputs[2], (n - 1) / n, name='Undo_Bessel_Correction')
      # TODO(suharshs): Find a way to get rid of this inner match.
      for mul_match_result in moving_avg_mul_matcher.match_graph(graph):
        sub_op = mul_match_result.get_op(moving_average_sub_pattern)
        if sub_op.inputs[1].name == bn_op.outputs[1].name:
          # During training: Batch Mean is bn_op.outputs[1]
          moving_mean_tensor = sub_op.inputs[0]
          bn_decay_mean_tensor = mul_match_result.get_tensor(bn_decay_pattern)
        if sub_op.inputs[1].name == bn_op.outputs[2].name:
          # During training: Batch Var is bn_op.outputs[2]
          moving_variance_tensor = sub_op.inputs[0]
          bn_decay_var_tensor = mul_match_result.get_tensor(bn_decay_pattern)
    else:
      mean_tensor = match_result.get_tensor(mean_pattern)
      variance_tensor = match_result.get_tensor(variance_pattern)

    return layer_op, _BatchNormMatch(
        layer_op=layer_op,
        bn_op=bn_op,
        output_tensor=output_tensor,
        input_tensor=input_tensor,
        weight_tensor=weight_tensor,
        gamma_tensor=gamma_tensor,
        beta_tensor=beta_tensor,
        mean_tensor=mean_tensor,
        variance_tensor=variance_tensor,
        moving_mean_tensor=moving_mean_tensor,
        moving_variance_tensor=moving_variance_tensor,
        bn_decay_mean_tensor=bn_decay_mean_tensor,
        bn_decay_var_tensor=bn_decay_var_tensor,
        batch_epsilon=batch_epsilon,
        batch_to_space_op=batch_to_space_op)

  layer_matches = []
  # We use matched_layer_set to ensure that layers aren't matched multiple
  # times.
  matched_layer_set = set()
  # The identity matcher is consulted first, then the plain batch norm matcher;
  # the first match seen for a given layer op wins.
  for matcher in (bn_identity_matcher, bn_matcher):
    for match_result in matcher.match_graph(graph):
      layer_op, layer_match = _GetLayerMatch(match_result)
      if layer_op is not None and layer_op not in matched_layer_set:
        matched_layer_set.add(layer_op)
        layer_matches.append(layer_match)

  return layer_matches
342
343
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
  """Computes batch norm correction params.

     Before batch normalization is frozen:
     We use batch statistics for batch norm.
       correction_scale = sigma_b/sigma_mv
       correction_recip = 1/correction_scale
       correction_offset = 0

     After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).

     Batch norm is frozen once the quantization step is greater than or equal
     to freeze_batch_norm_delay; it is never frozen when the delay is None.
     The corrections ensure that:
     a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather than
     jump across mini-batches
     b) Changing the values of the corrections allows for one to switch between
     using batch statistics to using moving mean and variance, without
     requiring changes to batch_norm

     Besides computing the corrections, this also reroutes the existing
     consumers of the moving average decay tensors so that they read a decay
     of zero once batch norm is frozen.

  Args:
    context: The scope under which we look for batch norm params
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.


  Returns:
    A tuple of correction_scale, correction_recip, correction_offset
  """

  g = ops.get_default_graph()
  prefix = '' if not context else context
  with g.name_scope(prefix + 'batch_norm_correction'):
    recip_sigma_mv = math_ops.rsqrt(
        match.moving_variance_tensor + match.batch_epsilon)
    recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon)
    # (1/sigma_mv) / (1/sigma_b) == sigma_b / sigma_mv
    correction_scale = math_ops.divide(
        recip_sigma_mv, recip_sigma, name='scale_compute')
    correction_scale = array_ops.identity(
        correction_scale, name='correction_scale')
    correction_recip = math_ops.reciprocal(
        correction_scale, name='reciprocal_compute')
    correction_offset = math_ops.multiply(
        match.gamma_tensor,
        match.mean_tensor * recip_sigma -
        match.moving_mean_tensor * recip_sigma_mv,
        name='offset_compute')

    if freeze_batch_norm_delay is not None:
      use_mv_avg = math_ops.greater_equal(
          common.CreateOrGetQuantizationStep(),
          freeze_batch_norm_delay,
          name='use_moving_average')
    else:
      use_mv_avg = False

    bn_decay_zero = 0.0

    # Snapshot the consumers before the cond below is built, so that only ops
    # that already read the decay tensor are rerouted (presumably so the cond
    # itself keeps reading the original decay -- confirm against
    # common.RerouteTensor).
    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
    bn_decay_mean_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_mean_tensor,
        name='freeze_moving_mean')

    common.RerouteTensor(
        bn_decay_mean_out,
        match.bn_decay_mean_tensor,
        can_modify=bn_decay_mean_consumers)

    # Same treatment for the variance decay. The snapshot is deliberately taken
    # here, after the mean decay has been rerouted, as in the original code.
    bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
    bn_decay_var_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_var_tensor,
        name='freeze_moving_var')
    common.RerouteTensor(
        bn_decay_var_out,
        match.bn_decay_var_tensor,
        can_modify=bn_decay_var_consumers)

    # Frozen: no rescaling after the layer. Not frozen: undo correction_scale.
    correction_recip = utils.smart_cond(
        use_mv_avg,
        lambda: array_ops.ones(correction_scale.shape),
        lambda: correction_recip,
        name='correction_recip')

    # Frozen: apply the offset computed above. Not frozen: no offset.
    correction_offset = utils.smart_cond(
        use_mv_avg,
        lambda: correction_offset,
        lambda: array_ops.zeros(correction_offset.shape),
        name='correction_offset')
  return correction_scale, correction_recip, correction_offset
444
445
def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor,
                          batch_to_space_op):
  """Clones layer_op with input_tensor and weight_tensor as new inputs.

  The clone is named after the last path component of `layer_op.name` with a
  '_Fold' suffix. Conv2D, MatMul and DepthwiseConv2dNative are supported; for
  a depthwise convolution with a truthy `batch_to_space_op`, the op consuming
  the original layer output is re-created on top of the clone as well.

  Raises:
    ValueError: If `layer_op` has any other type.
  """
  folded_name = layer_op.name.split('/')[-1] + '_Fold'
  op_type = layer_op.type

  if op_type == 'Conv2D':
    return nn_ops.conv2d(
        input_tensor,
        weight_tensor,
        strides=layer_op.get_attr('strides'),
        padding=layer_op.get_attr('padding'),
        use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'),
        data_format=layer_op.get_attr('data_format').decode(),
        name=folded_name)

  if op_type == 'MatMul':
    return math_ops.matmul(
        input_tensor,
        weight_tensor,
        transpose_a=layer_op.get_attr('transpose_a'),
        transpose_b=layer_op.get_attr('transpose_b'),
        name=folded_name)

  if op_type != 'DepthwiseConv2dNative':
    raise ValueError('Cannot handle operation of type: %s' % op_type)

  # We don't copy dilation rate because we reuse the input SpaceToBatch
  # and create our own BatchToSpace operation below.
  folded_conv = nn.depthwise_conv2d(
      input_tensor,
      weight_tensor,
      strides=layer_op.get_attr('strides'),
      padding=layer_op.get_attr('padding'),
      name=folded_name)
  if not batch_to_space_op:
    return folded_conv

  # Atrous convolution: copy the batch to space operation, taking it from the
  # first consumer of the original layer output.
  original_b2s = layer_op.outputs[0].consumers()[0]
  # TODO(suharshs): It's hard to make this name match with the unfused name.
  # Restructure this code to not rely on scope at all.
  folded_b2s_name = original_b2s.name.split('/')[-1] + '_Fold'
  return array_ops.batch_to_space_nd(
      folded_conv,
      original_b2s.inputs[1],
      original_b2s.inputs[2],
      name=folded_b2s_name)
489
490
@ops.RegisterGradient('FoldFusedBatchNormGrad')
def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1,
                            unused_2):
  """Gradient registered under the name 'FoldFusedBatchNormGrad'.

  Only the incoming gradients `grad_mean` and `grad_var` are used; the
  gradient for the op's first output and the last two gradients are ignored.
  With n = size(op.inputs[0]) / size(grad_mean), the returned gradient for
  input 0 is
    grad_mean / n + 2 * grad_var * (x - op.outputs[1]) / (n - 1).

  Returns:
    A 5-tuple holding the gradient for input 0 followed by None for the
    remaining four inputs.
  """
  layer_output = op.inputs[0]
  elements_per_stat = math_ops.cast(
      array_ops.size(layer_output) / array_ops.size(grad_mean), dtypes.float32)
  grad_through_mean = grad_mean / elements_per_stat
  # Built step by step in the same left-to-right order as a single expression.
  doubled_grad_var = 2 * grad_var
  centered = layer_output - op.outputs[1]
  grad_through_var = doubled_grad_var * centered / (elements_per_stat - 1)
  input_grad = grad_through_mean + grad_through_var
  return input_grad, None, None, None, None
500
501
def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
  """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Each batch norm group whose output has consumers gets a folded replacement.
  The replacement is then rerouted into exactly one consumer: the endpoint
  activation op if there is one, otherwise the bypass 'Add' op.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
  consumers_by_input = input_to_ops.InputToOps(graph)

  for bn_context in common.BatchNormGroups(graph):
    scaled = _HasScaling(graph, consumers_by_input, bn_context)

    if not _IsValidUnfusedBatchNorm(graph, bn_context):
      continue

    # The mangling code intimately depends on BatchNorm node's internals.
    original_op, folded_op = _CreateFoldedOp(
        graph,
        bn_context,
        has_scaling=scaled,
        freeze_batch_norm_delay=freeze_batch_norm_delay,
        is_training=is_training)

    consumer = common.GetEndpointActivationOp(graph, bn_context)
    if not consumer:
      # Consumer ops in bypass modules are Add operations instead of Relu*.
      # The bypass add lives one scope level above the batch norm: for a batch
      # norm at scope str1/str2 it is at scope str1; for a batch norm at scope
      # str1 it is at scope ''.
      parent = re.search(r'^(.*)/([^/]+)', bn_context) if bn_context else None
      bypass_scope = parent.group(1) if parent else ''
      if bypass_scope:
        bypass_scope += '/'
      consumer = graph.get_operation_by_name(bypass_scope + 'Add')

    rerouted_count = common.RerouteTensor(
        folded_op.outputs[0], original_op.outputs[0], can_modify=[consumer])
    if rerouted_count != 1:
      raise ValueError('Unexpected inputs to op: %s' % consumer.name)
564
565
def _IsValidUnfusedBatchNorm(graph, context):
  """Returns True iff the unfused batch norm under `context` has consumers.

  Looks up the op named `context + 'BatchNorm/batchnorm_1/add_1'` and checks
  whether its first output is read by anything.
  """
  bn_output = graph.get_operation_by_name(
      context + 'BatchNorm/batchnorm_1/add_1').outputs[0]
  # A batch norm output that nobody reads is a dangling node, not a match.
  if bn_output.consumers():
    return True
  return False
573
574
def _FindMatchingTensor(graph, match_pattern, scope):
  """Returns output 0 of the op that best matches match_pattern and scope.

     An op is a candidate when its name ends with match_pattern. Candidates are
     scored by how many '/'-separated parts of their name also appear among the
     parts of scope; with a non-empty scope, candidates scoring zero are
     dropped. The first candidate with the highest score wins.

     Example: _FindMatchingTensor(graph,'/BatchNorm/moments/Squeeze',
     'MobilenetV1/MobilenetV1/Conv2d_0/') returns:
      Tensor('MobilenetV1/Conv2d_0/BatchNorm/moments/Squeeze')

  Args:
    graph: Graph to inspect.
    match_pattern: Suffix that the op's name has to end with.
    scope: The scope of the op. All the elements of the scope need not be
    present in the op's name.

  Returns:
    Tensor from graph that provides the best match to the match_pattern and
    scope, or None if no op qualifies.
  """
  scope_parts = set(scope.split('/'))
  best_name = None
  best_score = -1
  for candidate in graph.get_operations():
    name = candidate.name
    if not name.endswith(match_pattern):
      continue
    score = len(scope_parts.intersection(name.split('/')))
    if score == 0 and scope:
      # A scope was requested but this op shares nothing with it.
      continue
    # Strict comparison keeps the earliest op among equally good candidates.
    if score > best_score:
      best_name = name
      best_score = score

  if best_name is None:
    return None
  return graph.get_tensor_by_name(best_name + ':0')
611
612
def _GetBatchNormParams(graph, context, has_scaling):
  """Extracts relevant tensors for folding batch norms.

  Tensors are located by op-name suffix via `_FindMatchingTensor`; which
  suffixes are used for gamma and the moving statistics depends on whether the
  current variable scope uses resource variables.

  Args:
    graph: Graph to inspect.
    context: The scope under which we look for batch norm params
    has_scaling: Bool that specifies if scaling is done as part of batch norm.

  Returns:
    _BatchNormMatch containing all required batch norm parameters. Only the
    gamma, batch/moving mean and variance, decay and epsilon fields are
    populated; the remaining fields are None.

  Raises:
    ValueError: If neither a batch mean nor a moving mean tensor can be found.
  """
  # TODO(raghuramank) This code relies on string matching and needs to be
  # updated if unfused batch norm continues to be widely used
  # Matching variable names is brittle and relies on scoping
  # conventions. Fused batch norm folding is more robust. Support for unfused
  # batch norms will be deprecated as we move forward. Fused batch norms allow
  # for faster training and should be used whenever possible.
  # context contains part of the names of the tensors we are interested in:
  # For MobilenetV1, the context has repetitions:
  # MobilenetV1/MobilenetV1/Conv2d_3_depthwise
  # when the moving_mean tensor has the name:
  # MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/read
  # To pick the correct variable name, it is necessary to ignore the repeating
  # header.

  # For MobilenetV2, this problem does not exist:
  # The context is: MobilenetV2/expanded_conv_3/depthwise
  # and the names of the tensors start with a single MobilenetV2
  # The moving mean for example, has the name:
  # MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
  # We identify the best match for an op by checking for
  # 1. The suffix of the op is exactly matched
  # 2. Maximum number of matches with the context.The matching
  # score is given by the number of parts of context (split by /) that
  # are present in the parts of the tensor name (again split by /).
  # For example: scope= MobilenetV2/MobilenetV2/expanded_conv_3 and
  # op.name =  MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
  # will have 2 matches,scope with a different conv layer will have one match.

  op_suffix_mean = 'BatchNorm/moments/Squeeze'
  op_suffix_variance = 'BatchNorm/moments/Squeeze_1'
  op_suffix_epsilon = 'BatchNorm/batchnorm_1/add/y'
  op_suffix_bn_decay_mean = 'BatchNorm/AssignMovingAvg/decay'
  op_suffix_bn_decay_var = 'BatchNorm/AssignMovingAvg_1/decay'

  if variable_scope.get_variable_scope().use_resource:
    op_suffix_gamma = 'BatchNorm/gamma/Read/ReadVariableOp'
    op_suffix_moving_variance = (
        'BatchNorm/moving_variance/Read/ReadVariableOp')
    op_suffix_moving_mean = ('BatchNorm/moving_mean/Read/ReadVariableOp')
  else:
    op_suffix_gamma = 'BatchNorm/gamma'
    op_suffix_moving_variance = 'BatchNorm/moving_variance/read'
    op_suffix_moving_mean = 'BatchNorm/moving_mean/read'

  # Parse through list of ops to find relevant ops
  batch_mean_tensor = _FindMatchingTensor(graph, op_suffix_mean, context)
  batch_variance_tensor = _FindMatchingTensor(graph, op_suffix_variance,
                                              context)
  moving_mean_tensor = _FindMatchingTensor(graph, op_suffix_moving_mean,
                                           context)
  moving_variance_tensor = _FindMatchingTensor(graph, op_suffix_moving_variance,
                                               context)
  batch_epsilon = _FindMatchingTensor(graph, op_suffix_epsilon, context)
  bn_decay_mean_tensor = _FindMatchingTensor(graph, op_suffix_bn_decay_mean,
                                             context)
  bn_decay_var_tensor = _FindMatchingTensor(graph, op_suffix_bn_decay_var,
                                            context)

  # Without any mean tensor there is nothing to fold. The exception has to be
  # raised, not merely constructed, for this check to have any effect.
  if batch_mean_tensor is None and moving_mean_tensor is None:
    raise ValueError('Error folding unfused batch norms')

  if has_scaling:
    gamma_tensor = _FindMatchingTensor(graph, op_suffix_gamma, context)
  else:
    # No learned scale: use ones with the shape of the moving mean.
    gamma_tensor = array_ops.ones(moving_mean_tensor.shape)

  return _BatchNormMatch(
      layer_op=None,
      bn_op=None,
      output_tensor=None,
      input_tensor=None,
      weight_tensor=None,
      gamma_tensor=gamma_tensor,
      beta_tensor=None,
      mean_tensor=batch_mean_tensor,
      variance_tensor=batch_variance_tensor,
      moving_mean_tensor=moving_mean_tensor,
      moving_variance_tensor=moving_variance_tensor,
      bn_decay_mean_tensor=bn_decay_mean_tensor,
      bn_decay_var_tensor=bn_decay_var_tensor,
      batch_epsilon=batch_epsilon,
      batch_to_space_op=None)
714
715
def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
                    is_training):
  """Folds in batch norm layer into preceding convolution or FC layer.

  Creates 3 new nodes, connects their inputs and adds them to the graph:
  mul is cloned into mul_fold, Conv2D or MatMul, or DepthwiseConv2d is cloned
  into respective *_Fold, add is cloned into add_fold.

  When `is_training` is true, additional correction ops are inserted: the
  weights are multiplied by a correction scale before folding, and the output
  of the folded layer is multiplied by a reciprocal correction and shifted by
  a correction offset before it reaches add_fold. If the layer output passes
  through a BatchToSpaceND op, that op is re-created on the folded path too.

  Args:
    graph: Graph to modify.
    context: String, batch norm context, i.e. node into which BatchNorm is
      nested. It is concatenated directly with 'BatchNorm/batchnorm_1/...', so
      it presumably has to end with the scope separator (verify against
      caller).
    has_scaling: Whether the batch norm has scaling enabled.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.
    is_training: Bool, true if training.

  Raises:
    ValueError: When operation type is not supported, or input and output tensor
      shapes mismatch for created operations: mul_fold, add_fold.

  Returns:
    A pair of Operations, the first is the original consumer node of the batch
      norm (../BatchNorm/batchnorm_1/add_1), the second is the consumer node of
      the folded graph (add_fold).
  """
  # The multiply that consumes the layer output is named 'mul_1' when scaling
  # is enabled and 'mul' otherwise.
  mul_scale_name = 'mul_1' if has_scaling else 'mul'
  mul_scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
                                          mul_scale_name)
  # The op producing the first input of that multiply is the layer to fold.
  op_below = mul_scale.inputs[0].op
  # Skip over the BatchToSpace operation in the case of atrous convolutions.
  batch_to_space_op = None
  if op_below.type == 'BatchToSpaceND':
    batch_to_space_op = op_below
    op_below = op_below.inputs[0].op
  # The second input of the layer op is treated as its weights.
  weights = op_below.inputs[1]
  match = _GetBatchNormParams(
      graph=graph, context=context, has_scaling=has_scaling)
  # Corrections are only computed while training; when they stay None every
  # correction branch below is skipped.
  correction_scale, correction_recip, correction_offset = None, None, None
  if is_training:
    correction_scale, correction_recip, correction_offset = (
        _ComputeBatchNormCorrections(
            context=context,
            match=match,
            freeze_batch_norm_delay=freeze_batch_norm_delay))
  # Special handling for weights of depthwise convolution.
  if op_below.type == 'DepthwiseConv2dNative':
    # The scale (and correction) are reshaped to dimensions 2 and 3 of the
    # weights shape, presumably (in_channels, channel_multiplier) - TODO
    # confirm.
    new_shape = [
        weights.get_shape().as_list()[2],
        weights.get_shape().as_list()[3]
    ]
    # The per-channel scale is the output of 'mul' with scaling, or of 'Rsqrt'
    # directly without it.
    scale_name = 'mul' if has_scaling else 'Rsqrt'
    scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
                                        scale_name)
    scale = array_ops.reshape(scale.outputs[0], new_shape,
                              context + 'scale_reshape')

    if correction_scale is not None:
      correction_scale = array_ops.reshape(correction_scale, new_shape,
                                           context + 'correction_reshape')
      # Place the correction multiply on the same device as the original mul.
      with ops.device(mul_scale.device):
        weights = math_ops.multiply(correction_scale, weights,
                                    context + 'correction_mult')

    # Both inputs of the cloned multiply are replaced: weights and the
    # reshaped scale.
    mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights),
                                                          (1, scale)])
  elif op_below.type in ['Conv2D', 'MatMul']:

    if correction_scale is not None:
      # Place the correction multiply on the same device as the original mul.
      with ops.device(mul_scale.device):
        weights = math_ops.multiply(correction_scale, weights,
                                    context + 'correction_mult')
    # Only input 0 is replaced; the cloned multiply keeps the original scale
    # input of mul_scale.
    mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights)])
  else:
    raise ValueError('Cannot handle operation of type: %s' % op_below.type)
  _AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0])

  # Clone the layer op so that it consumes the scaled (folded) weights.
  conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold',
                               [(1, mul_fold.outputs[0])])

  add_shift = graph.get_operation_by_name(context +
                                          'BatchNorm/batchnorm_1/add_1')

  corrected_output = conv_or_fc_folded.outputs[0]
  # Copy the batch to space operation if we have an atrous convolution. The
  # second and third inputs of the original op are reused unchanged.
  if batch_to_space_op:
    corrected_output = array_ops.batch_to_space_nd(
        corrected_output,
        batch_to_space_op.inputs[1],
        batch_to_space_op.inputs[2],
        name=batch_to_space_op.name + '_Fold')
  if correction_offset is not None:
    # Apply the post-layer corrections on the device of the folded layer op:
    # first the reciprocal multiply, then the offset add.
    with ops.device(conv_or_fc_folded.device):
      corrected_output = math_ops.multiply(correction_recip, corrected_output,
                                           context + 'post_conv_mul')
      corrected_output = math_ops.add(corrected_output, (correction_offset),
                                      context + 'correction_add')
  # Clone the final add so that it consumes the folded (and corrected) output.
  add_fold = _CloneOp(add_shift, context + 'add_fold', [(0, corrected_output)])
  _AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0])
  return add_shift, add_fold
816
817
def _CloneOp(op, new_name, new_inputs):
  """Creates a renamed copy of `op` with selected inputs substituted.

  Args:
    op: Operation to copy; it is not modified.
    new_name: String, name given to the copy.
    new_inputs: A list of tuples (idx, tensor). The copy receives `tensor` as
      its input at position `idx`; all other inputs are taken from `op`.

  Returns:
    Operation, the cloned op.

  Raises:
    TypeError: When Operation type is not supported.
    ValueError: When input shapes are incompatible.
  """
  # Work on a copy of the input list so the substitutions stay local.
  cloned_inputs = list(op.inputs)
  for replacement in new_inputs:
    position, tensor = replacement[0], replacement[1]
    cloned_inputs[position] = tensor
  return _OP_CLONER.Clone(op, cloned_inputs, new_name)
838
839
class _OpCloner(object):
  """Helper class that clones tf.Operations based on their type.

  Cloning is dispatched on `op.type` through `op_type_to_action`. Each action
  rebuilds the op from the supplied inputs and, where needed, from attributes
  read off the original op, and returns the newly created Operation.
  """

  def __init__(self):
    # Maps an op type string to the bound method that re-creates that op.
    self.op_type_to_action = {
        'Mul': self._CloneMul,
        'Add': self._CloneAdd,
        'Conv2D': self._CloneConv2d,
        'DepthwiseConv2dNative': self._CloneDepthwiseConv2d,
        'MatMul': self._CloneMatMul,
    }

  def _CloneMul(self, op, inputs, new_name):
    """Returns a new multiply op of inputs[0] and inputs[1]."""
    del op  # Unused.
    return math_ops.multiply(inputs[0], inputs[1], name=new_name).op

  def _CloneAdd(self, op, inputs, new_name):
    """Returns a new add op of inputs[0] and inputs[1]."""
    del op  # Unused.
    return math_ops.add(inputs[0], inputs[1], name=new_name).op

  def _CloneConv2d(self, op, inputs, new_name):
    """Returns a new conv2d op that copies the attributes of `op`.

    Raises:
      ValueError: When input shapes are incompatible.
    """
    input_tensor = inputs[0]
    weights = inputs[1]
    self._AssertConvShapes(op.name, input_tensor, weights)
    return nn_ops.conv2d(
        input_tensor,
        weights,
        strides=op.get_attr('strides'),
        padding=op.get_attr('padding'),
        use_cudnn_on_gpu=op.get_attr('use_cudnn_on_gpu'),
        # The attribute is decoded before being passed on as a string.
        data_format=op.get_attr('data_format').decode(),
        name=new_name).op

  def _CloneDepthwiseConv2d(self, op, inputs, new_name):
    """Returns a new depthwise conv op with the strides and padding of `op`.

    Raises:
      ValueError: When input shapes are incompatible.
    """
    input_tensor = inputs[0]
    weights = inputs[1]
    self._AssertConvShapes(op.name, input_tensor, weights)
    return nn.depthwise_conv2d(
        input_tensor,
        weights,
        strides=op.get_attr('strides'),
        padding=op.get_attr('padding'),
        name=new_name).op

  def _CloneMatMul(self, op, inputs, new_name):
    """Returns a new matmul op with the transpose attributes of `op`.

    Raises:
      ValueError: When input shapes are incompatible.
    """
    # Operands are kept in their original order. _CreateFoldedOp substitutes
    # the folded weights at index 1, so neutral operand names are used here.
    first_operand = inputs[0]
    second_operand = inputs[1]
    self._AssertFCShapes(op.name, first_operand, second_operand)
    return math_ops.matmul(
        first_operand,
        second_operand,
        transpose_a=op.get_attr('transpose_a'),
        transpose_b=op.get_attr('transpose_b'),
        name=new_name).op

  def Clone(self, op, inputs, new_name):
    """Clones `op` using `inputs` and names the result `new_name`.

    Args:
      op: Operation to clone; its `type` selects the clone action.
      inputs: List of tensors to use as the inputs of the cloned op.
      new_name: String, name for the cloned op.

    Returns:
      Operation, the cloned op.

    Raises:
      TypeError: When `op.type` has no registered clone action.
      ValueError: When input shapes are incompatible.
    """
    # Guard only the table lookup. A KeyError raised inside a clone action
    # must propagate unchanged instead of being reported as an unsupported
    # operation type.
    try:
      clone_action = self.op_type_to_action[op.type]
    except KeyError:
      raise TypeError('Unsupported operation type: %s' % op.type)
    return clone_action(op, inputs, new_name)

  def _AssertConvShapes(self, op_name, input_tensor, weights):
    """Makes sure that convolution inputs have compatible shapes.

    Both shapes must have rank 4 and dimension 3 of the input must equal
    dimension 2 of the weights.

    Args:
      op_name: Operation name, only used in error message.
      input_tensor: Input that is convolved.
      weights: Weights of the convolution filter.

    Raises:
      ValueError: When input shapes are incompatible.
    """
    input_shape = input_tensor.get_shape()
    weights_shape = weights.get_shape()
    if (len(input_shape) != 4 or len(weights_shape) != 4 or
        input_shape[3] != weights_shape[2]):
      raise ValueError('Incompatible shapes for op %s inputs: %s and %s' %
                       (op_name, input_shape, weights_shape))

  def _AssertFCShapes(self, op_name, weights, input_tensor):
    """Makes sure that FC layer inputs have compatible shapes.

    Both shapes must have rank 2 and dimension 1 of the first operand must
    equal dimension 0 of the second. Transpose attributes are not considered.

    Args:
      op_name: Operation name, only used in error message.
      weights: First matmul operand (_CloneMatMul passes the op's input 0).
      input_tensor: Second matmul operand (_CloneMatMul passes the op's
        input 1).

    Raises:
      ValueError: When input shapes are incompatible.
    """
    weights_shape = weights.get_shape()
    input_shape = input_tensor.get_shape()
    if (len(weights_shape) != 2 or len(input_shape) != 2 or
        weights_shape[1] != input_shape[0]):
      raise ValueError('Incompatible shapes for op %s inputs: %s and %s' %
                       (op_name, weights_shape, input_shape))
936
# Shared module-level instance; _CloneOp delegates to its Clone method.
_OP_CLONER = _OpCloner()
938
939
def _AssertShapesMatch(op_name, in_tensor, out_tensor):
  """Verifies that an op left the shape of its tensor unchanged.

  Args:
    op_name: String, operation name, only used in error message.
    in_tensor: Tensor fed into the op.
    out_tensor: Tensor produced by the op.

  Raises:
    ValueError: When the input shape is not compatible with the output shape.
  """
  input_shape, output_shape = in_tensor.get_shape(), out_tensor.get_shape()
  if input_shape.is_compatible_with(output_shape):
    return
  raise ValueError('%s should not change tensor shape: input %s, '
                   'output %s' % (op_name, input_shape, output_shape))
957
958
def _HasScaling(graph, input_to_ops_map, bn):
  r"""Reports whether the batch norm under prefix `bn` uses scaling (gamma).

  The decision is based on how many Mul ops consume the output of the Rsqrt
  op of the batch norm. With scaling the subgraph looks like

  Rsqrt -> mul -> mul_1
              \-> mul_2

  where
    mul multiplies gamma by inverse square root of EMA of batch variance,
    mul_1 multiplies output of mul with output from the base operation
      (convolution, FC or depthwise convolution),
    mul_2 multiplies output of mul with EMA of batch mean,
  so Rsqrt feeds exactly one Mul. Without scaling it looks like

  Rsqrt -> mul
       \-> mul_1

  where
    mul multiplies the inverse square root of EMA of batch variance with output
      from the base operation,
    mul_1 multiplies inverse square root of EMA of batch variance with EMA
      of batch mean,
  so Rsqrt feeds two Mul ops.

  Args:
    graph: Graph to inspect.
    input_to_ops_map: InputToOps object containing mapping from tensor's name
      to ops that take it as input.
    bn: Batch norm layer prefix string.

  Returns:
    A boolean, True when exactly one Mul op consumes the Rsqrt output.
  """
  rsqrt_op = graph.get_operation_by_name(bn + 'BatchNorm/batchnorm_1/Rsqrt')

  mul_consumer_count = 0
  for consumer in input_to_ops_map.ConsumerOperations(rsqrt_op):
    if consumer.type == 'Mul':
      mul_consumer_count += 1
  return mul_consumer_count == 1
996
997
class _BatchNormMatch(object):
  """Read-only record of everything found for a Fused/UnfusedBatchNorm match.

  Each constructor argument is stored unchanged on a private attribute with
  the same name prefixed by an underscore, and is exposed through a property
  of the same name that has no setter.
  """

  def __init__(self, layer_op, bn_op, output_tensor, input_tensor,
               weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
               variance_tensor, moving_mean_tensor, moving_variance_tensor,
               bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon,
               batch_to_space_op):
    # Matched ops and the tensors around the layer.
    self._layer_op, self._bn_op = layer_op, bn_op
    self._output_tensor, self._input_tensor = output_tensor, input_tensor
    self._weight_tensor = weight_tensor
    # Batch norm parameters and statistics.
    self._gamma_tensor, self._beta_tensor = gamma_tensor, beta_tensor
    self._mean_tensor, self._variance_tensor = mean_tensor, variance_tensor
    self._moving_mean_tensor = moving_mean_tensor
    self._moving_variance_tensor = moving_variance_tensor
    self._bn_decay_mean_tensor = bn_decay_mean_tensor
    self._bn_decay_var_tensor = bn_decay_var_tensor
    self._batch_epsilon = batch_epsilon
    # Optional BatchToSpaceND op; callers may pass None.
    self._batch_to_space_op = batch_to_space_op

  # Read-only accessors, one per constructor argument.
  layer_op = property(lambda self: self._layer_op)
  bn_op = property(lambda self: self._bn_op)
  output_tensor = property(lambda self: self._output_tensor)
  input_tensor = property(lambda self: self._input_tensor)
  weight_tensor = property(lambda self: self._weight_tensor)
  gamma_tensor = property(lambda self: self._gamma_tensor)
  beta_tensor = property(lambda self: self._beta_tensor)
  mean_tensor = property(lambda self: self._mean_tensor)
  variance_tensor = property(lambda self: self._variance_tensor)
  moving_mean_tensor = property(lambda self: self._moving_mean_tensor)
  moving_variance_tensor = property(lambda self: self._moving_variance_tensor)
  batch_epsilon = property(lambda self: self._batch_epsilon)
  bn_decay_mean_tensor = property(lambda self: self._bn_decay_mean_tensor)
  bn_decay_var_tensor = property(lambda self: self._bn_decay_var_tensor)
  batch_to_space_op = property(lambda self: self._batch_to_space_op)
1081