1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Apply graph_transforms tool to MetaGraphDefs.
17
18@@meta_graph_transform
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25
26import re as _re
27
28from tensorflow.core.framework import graph_pb2 as _graph_pb2
29from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
30from tensorflow.python.client import session as _session
31from tensorflow.python.framework import graph_util as _graph_util
32from tensorflow.python.framework import importer as _importer
33from tensorflow.python.framework import ops as _ops
34from tensorflow.python.platform import tf_logging as _logging
35from tensorflow.python.saved_model import constants as _saved_model_constants
36from tensorflow.python.training import saver as _saver_lib
37from tensorflow.python.util import compat as _compat
38from tensorflow.tools import graph_transforms as _graph_transforms
39
40
# Names of the transforms that _do_transforms handles itself (freezing and
# sparsify-gather bookkeeping) instead of simply batching them up with the
# other Graph Transform Tool transforms.
_FREEZE_GRAPH_TRANSFORM = 'freeze_graph'
_SPARSIFY_GATHER_TRANSFORM = 'sparsify_gather'
43
44
45def _op_name(tensor_name):
46  """Get the op name from a tensor name."""
47  # control dependency inputs start with ^
48  if tensor_name[0] == '^':
49    tensor_name = tensor_name[1:]
50  if ':' in tensor_name:
51    op_name, _ = tensor_name.split(':')
52    return op_name
53  return tensor_name
54
55
def _get_shared_init_op(initializer_names):
  """Look up the name of the shared init op, if one was collected.

  The main op collection takes precedence; the legacy init op collection is
  only consulted when the main op entry is missing or empty.

  Args:
   initializer_names: Dictionary of the "infrastructural" nodes (initializers,
     save and restore ops, etc.). The keys in this dictionary
     indicate the collection where these nodes were obtained from.

  Returns:
    The shared init op name as a string, or None if neither collection holds
    one.
  """
  shared_init_ops = (
      initializer_names.get(_saved_model_constants.MAIN_OP_KEY, None) or
      initializer_names.get(_saved_model_constants.LEGACY_INIT_OP_KEY, None))
  if not shared_init_ops:
    return None
  # Only the first entry of the collection is used.
  return str(shared_init_ops[0])
72
73
def _gtt_transforms(graph_def, input_names, output_names, initializer_names,
                    transforms):
  """Run a batch of Graph Transform Tool transforms over a GraphDef.

  Args:
    graph_def: A GraphDef proto to be transformed.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    initializer_names: Dictionary of the "infrastructural" nodes (initializers,
      save and restore ops, etc.) that should be retained even if they are not
      transitively reachable from output nodes. The keys in this dictionary
      indicate the collection where these nodes were obtained from.
    transforms: A list of strings naming the graph transforms to be applied in
      order.
  Returns:
    The transformed GraphDef. When there are no transforms to apply this is a
    copy of graph_def.
  """
  if transforms:
    # The infrastructural nodes are passed to the tool as additional outputs
    # so that they are retained; sorting keeps the ordering deterministic.
    retained_initializers = sorted(
        name for names in initializer_names.values() for name in names)
    return _graph_transforms.TransformGraph(
        graph_def, input_names, output_names + retained_initializers,
        transforms)

  # Nothing to apply: return a copy rather than the input proto itself.
  graph_def_copy = _graph_pb2.GraphDef()
  graph_def_copy.CopyFrom(graph_def)
  return graph_def_copy
101
102
def _freeze_transform(graph_def, output_names, initializer_names, saver_def,
                      checkpoint_path):
  """Freeze the graph and report which initializer nodes survived.

  Freezing prunes all initializers and shared init nodes that are not
  explicitly maintained. Only the table initializers are maintained here, so
  the returned dictionary holds the table initializers plus the shared init op
  entries, and only when table initializers exist at all.

  Args:
    graph_def: A GraphDef proto to be transformed.
    output_names: Names of output nodes.
    initializer_names: Dictionary of the "infrastructural" nodes (initializers,
      save and restore ops, etc.). The keys in this dictionary
      indicate the collection where these nodes were obtained from.
    saver_def: A SaverDef proto used for restoring a checkpoint during freezing.
    checkpoint_path:  A path to a checkpoint to restore during freezing.

  Returns:
    A tuple containing the GraphDef and a Dict of pruned initializer nodes.
  """
  table_initializers = initializer_names.get(
      _ops.GraphKeys.TABLE_INITIALIZERS, [])
  frozen_graph_def = _freeze_graph_with_def_protos(
      graph_def, output_names, table_initializers,
      _get_shared_init_op(initializer_names), saver_def, checkpoint_path)

  if not table_initializers:
    # Without table initializers nothing infrastructural is kept.
    return (frozen_graph_def, {})

  retained_initializer_names = {
      _ops.GraphKeys.TABLE_INITIALIZERS: table_initializers
  }
  shared_init_keys = (_saved_model_constants.LEGACY_INIT_OP_KEY,
                      _saved_model_constants.MAIN_OP_KEY)
  for init_key in shared_init_keys:
    if init_key in initializer_names:
      retained_initializer_names[init_key] = initializer_names[init_key]
  return (frozen_graph_def, retained_initializer_names)
145
146
def _clean_save_and_restore(graph_def, op, removed_op_names):
  """Drop removed variables from a save / restore op and its constant inputs.

  Looks up the '<op.name>/tensor_names' and '<op.name>/shape_and_slices' nodes
  in graph_def and filters their string values, together with the 'dtypes'
  attribute of op, down to the entries whose tensor name does not belong to a
  removed op. The recorded lengths (the value tensor shape and the
  '_output_shapes' attribute of both nodes) are updated to match. All changes
  are made in place.

  Args:
    graph_def: A GraphDef proto to be transformed.
    op: The save or restore op to update.
    removed_op_names: List of op names that have been removed.
  """

  def _set_output_length(node, length):
    """Store `length` as the first dim of the node's first output shape."""
    output_shapes = node.attr['_output_shapes'].list.shape
    if not output_shapes:
      # No shape recorded yet: create a one-dimensional one to fill in.
      output_shapes.add()
      output_shapes[0].dim.add()
    output_shapes[0].dim[0].size = length

  tensor_names_node = _find_op(graph_def, op.name + '/tensor_names')
  shape_and_slices_node = _find_op(graph_def, op.name + '/shape_and_slices')
  tensor_names_value = tensor_names_node.attr['value'].tensor
  shape_and_slices_value = shape_and_slices_node.attr['value'].tensor

  # The three sequences are parallel, so one index selects matching entries.
  kept_names = []
  kept_slices = []
  kept_dtypes = []
  for position, raw_name in enumerate(tensor_names_value.string_val):
    if _is_removed(_compat.as_str(raw_name), removed_op_names):
      continue
    kept_names.append(raw_name)
    kept_slices.append(shape_and_slices_value.string_val[position])
    kept_dtypes.append(op.attr['dtypes'].list.type[position])

  tensor_names_value.string_val[:] = kept_names
  tensor_names_value.tensor_shape.dim[0].size = len(kept_names)
  shape_and_slices_value.string_val[:] = kept_slices
  shape_and_slices_value.tensor_shape.dim[0].size = len(kept_slices)
  op.attr['dtypes'].list.type[:] = kept_dtypes

  _set_output_length(tensor_names_node, len(kept_names))
  _set_output_length(shape_and_slices_node, len(kept_slices))
187
188
def _sparsify_gather_transform(graph_def, input_names, output_names,
                               initializer_names, checkpoint_path):
  """Handle the sparsify gather transform.

  Provides the transform the checkpoint and keeps track of the newly created
  initializer nodes: the temporary group init node emitted by the transform is
  removed again, its inputs are recorded as table initializers and, when a
  shared init op exists, attached to it as control dependencies. Save ops are
  then cleaned of entries that refer to removed ops.

  Args:
    graph_def: A GraphDef proto to be transformed.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    initializer_names: Dictionary of the "infrastructural" nodes (initializers,
      save and restore ops, etc.). The keys in this dictionary
      indicate the collection where these nodes were obtained from. The
      table initializers entry of this dictionary is replaced when new
      initializers are found; the list previously stored there is left
      untouched.
    checkpoint_path:  A path to a checkpoint.

  Returns:
    A tuple containing the GraphDef and a Dict of updated initializer nodes
    (the same dictionary object that was passed in).
  """
  # Ensure that sparsify_shared_init_op is unique.
  sparsify_shared_init_op = 'sparify_gather_init_op'
  while _find_op(graph_def, sparsify_shared_init_op):
    sparsify_shared_init_op += '_1'

  input_flag = ''
  if checkpoint_path:
    input_flag = 'input_checkpoint="%s", ' % checkpoint_path

  sparsify_cmd = [
      'sparsify_gather(%sgroup_init_node="%s")' % (input_flag,
                                                   sparsify_shared_init_op)
  ]

  starting_op_names = {node.name for node in graph_def.node}
  graph_def = _gtt_transforms(graph_def, input_names, output_names,
                              initializer_names, sparsify_cmd)
  ending_op_names = {node.name for node in graph_def.node}

  removed_op_names = []
  for op_name in sorted(starting_op_names - ending_op_names):
    prefix, separator, last_part = op_name.rpartition('/')
    # Remove part to get the checkpoint names used by the saver.
    if separator and last_part.startswith('part_'):
      removed_op_names.append(prefix)
    else:
      removed_op_names.append(op_name)

  # Obtain newly created table inits from gtt sparsify transform.
  added_table_inits = []
  table_key = _ops.GraphKeys.TABLE_INITIALIZERS
  for index, node in enumerate(graph_def.node):
    if node.name == sparsify_shared_init_op:
      added_table_inits = [n.lstrip('^') for n in node.input]
      # Build a new list instead of extending the stored one: callers may
      # hand in a shallow copy of their dictionary, in which case the stored
      # list is still shared with the caller's original.
      initializer_names[table_key] = (
          list(initializer_names.get(table_key, [])) + added_table_inits)
      # The group node only served to report the new inits.
      del graph_def.node[index]
      break

  # Add inits to existing shared init op.
  if added_table_inits:
    shared_init_node = _find_op(graph_def,
                                _get_shared_init_op(initializer_names))
    if shared_init_node is None:
      # There is no shared init op to attach to; the new initializers are
      # still reported through initializer_names.
      _logging.warning(
          'No shared init op found; new table initializers were not '
          'connected to one: %s', added_table_inits)
    else:
      for init in added_table_inits:
        shared_init_node.input.append('^' + init)

  # Update saver.
  for node in graph_def.node:
    if node.name.endswith('SaveV2'):
      _clean_save_and_restore(graph_def, node, removed_op_names)

  return (graph_def, initializer_names)
265
266
def _do_transforms(graph_def,
                   input_names,
                   output_names,
                   initializer_names,
                   transforms,
                   saver_def=None,
                   checkpoint_path=None):
  """Apply the requested transforms, in order, to a copy of a GraphDef.

  Ordinary Graph Transform Tool transforms are collected into batches. Each
  'freeze_graph' or 'sparsify_gather' entry first flushes the batch collected
  so far and is then handled by its dedicated helper; whatever is still
  pending at the end is applied last.

  Args:
    graph_def: A GraphDef proto to be transformed.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    initializer_names: Dictionary of the "infrastructural" nodes (initializers,
      save and restore ops, etc.) that should be retained even if they are not
      transitively reachable from output nodes. The keys in this dictionary
      indicate the collection where these nodes were obtained from.
    transforms: A list of strings naming the graph transforms to be applied in
      order.  These transform names are exactly those supported by the Graph
      Transform Tool, with the addition of the 'freeze_graph' and
      'sparsify_gather' transforms.
    saver_def: A SaverDef proto used for restoring a checkpoint during freezing,
      if needed (default None).
    checkpoint_path:  A path to a checkpoint to restore during freezing,
      if needed (default None).
  Returns:
     A tuple containing the GraphDef and a Dict of updated initializer nodes.
  """
  result_graph_def = _graph_pb2.GraphDef()
  result_graph_def.CopyFrom(graph_def)
  result_initializer_names = initializer_names.copy()

  if not transforms:
    return result_graph_def, result_initializer_names

  special_transforms = (_FREEZE_GRAPH_TRANSFORM, _SPARSIFY_GATHER_TRANSFORM)
  pending_gtt_transforms = []
  for transform in transforms:
    if transform not in special_transforms:
      pending_gtt_transforms.append(transform)
      continue

    # Flush everything queued so far before running the special transform.
    result_graph_def = _gtt_transforms(
        result_graph_def, input_names, output_names,
        result_initializer_names, pending_gtt_transforms)
    pending_gtt_transforms = []

    if transform == _FREEZE_GRAPH_TRANSFORM:
      # Freezing is given op names rather than tensor names.
      output_node_names = [_op_name(name) for name in output_names]
      result_graph_def, result_initializer_names = _freeze_transform(
          result_graph_def, output_node_names, result_initializer_names,
          saver_def, checkpoint_path)
    else:
      result_graph_def, result_initializer_names = (
          _sparsify_gather_transform(
              result_graph_def, input_names, output_names,
              result_initializer_names, checkpoint_path))

  result_graph_def = _gtt_transforms(
      result_graph_def, input_names, output_names, result_initializer_names,
      pending_gtt_transforms)
  return result_graph_def, result_initializer_names
329
330
331def _connect_to_shared_init_op(graph_def, shared_init_op_name,
332                               nodes_to_connect):
333  """Creates a new shared init node that is connected to via control deps.
334
335  Args:
336    graph_def: The GraphDef proto to add the shared init node to.
337    shared_init_op_name: A string specifying the name of the shared init node to
338      create.
339    nodes_to_connect: A list of strings specifying the names of nodes to connect
340      to the shared node via control dependencies.
341  """
342  if nodes_to_connect:
343    init_op = graph_def.node.add()
344    init_op.name = shared_init_op_name
345    init_op.op = 'NoOp'
346    init_op.input.extend(['^' + i for i in nodes_to_connect])
347
348
# forked and modified from freeze_graph.py
def _freeze_graph_with_def_protos(input_graph_def, output_node_names,
                                  initializer_names, shared_init_op_name,
                                  input_saver_def, input_checkpoint):
  """Restore a checkpoint and turn the graph's variables into constants.

  Certain initializer nodes (e.g. table initializer nodes) have to survive
  freezing. Rather than working out which dependencies of the shared
  initializer node (e.g. group_deps) to keep, the initializers are kept as
  additional outputs and the shared node is rebuilt afterwards with control
  dependencies on them.

  Args:
    input_graph_def: A GraphDef proto to be frozen.
    output_node_names: Names of output nodes.
    initializer_names: Names of initializer nodes to keep.
    shared_init_op_name: The name of the shared initializer node to connect the
      nodes in initializer names to.
    input_saver_def: A SaverDef proto used for restoring a checkpoint.
    input_checkpoint: A path to a checkpoint to restore.

  Returns:
    A frozen GraphDef.
  """
  with _ops.Graph().as_default():
    _importer.import_graph_def(input_graph_def, name='')

    with _session.Session() as sess:
      checkpoint_saver = _saver_lib.Saver(saver_def=input_saver_def)
      checkpoint_saver.restore(sess, input_checkpoint)
      # Initializers are listed as outputs so that freezing retains them.
      nodes_to_keep = output_node_names + initializer_names
      frozen_graph_def = _graph_util.convert_variables_to_constants(
          sess, input_graph_def, nodes_to_keep)
      _connect_to_shared_init_op(frozen_graph_def, shared_init_op_name,
                                 initializer_names)
  return frozen_graph_def
385
386
def _find_all_mandatory_retain_ops(base_meta_graph_def):
  """Collect the infrastructural Ops that must survive the transforms.

  Besides the desired outputs, the init and saver machinery has to be kept.
  For now *all* save and restore ops, variable initializers, table
  initializers and main init ops are retained, which means that
  strip_unused_nodes will not remove unused variables.

  Args:
    base_meta_graph_def: a MetaGraphDef proto in which to identify nodes to
      retain.

  Returns:
    A dictionary corresponding to the nodes associated with each collection
    that are to be retained.
  """
  # TODO(b/63447631): implement variable stripping.

  initializer_names = {}

  # Primary SaverDef and SAVERS collection
  saver_defs = []
  if base_meta_graph_def.HasField('saver_def'):
    saver_defs.append(base_meta_graph_def.saver_def)
  saver_defs += _get_all_protos_from_collection(
      base_meta_graph_def, _ops.GraphKeys.SAVERS)
  if saver_defs:
    saver_node_names = []
    for saver_def in saver_defs:
      saver_node_names.append(saver_def.filename_tensor_name)
      saver_node_names.append(saver_def.save_tensor_name)
      saver_node_names.append(saver_def.restore_op_name)
    initializer_names[_ops.GraphKeys.SAVERS] = saver_node_names

  # Variable initializers
  for var_coll in (_ops.GraphKeys.GLOBAL_VARIABLES,
                   _ops.GraphKeys.TRAINABLE_VARIABLES,
                   _ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                   _ops.GraphKeys.LOCAL_VARIABLES,
                   _ops.GraphKeys.MODEL_VARIABLES):
    var_init_names = [
        variable.initializer_name
        for variable in _get_all_protos_from_collection(
            base_meta_graph_def, var_coll)
    ]
    if not var_init_names:
      continue
    # Sanity check to ensure we don't overwrite dictionary entries.
    assert var_coll not in initializer_names
    initializer_names[var_coll] = var_init_names

  # Table initializers
  table_init_names = _get_all_node_names_from_collection(
      base_meta_graph_def, _ops.GraphKeys.TABLE_INITIALIZERS)
  if table_init_names:
    # Sanity check to ensure we don't overwrite dictionary entries.
    assert _ops.GraphKeys.TABLE_INITIALIZERS not in initializer_names
    # Store a plain list rather than the collection's own container.
    initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = list(
        table_init_names)

  # Various init ops
  for op_coll in (_saved_model_constants.LEGACY_INIT_OP_KEY,
                  _saved_model_constants.MAIN_OP_KEY,
                  _ops.GraphKeys.INIT_OP,
                  _ops.GraphKeys.LOCAL_INIT_OP,
                  _ops.GraphKeys.READY_OP,
                  _ops.GraphKeys.READY_FOR_LOCAL_INIT_OP):
    op_name = _get_single_node_name_from_collection(
        base_meta_graph_def, op_coll)
    if not op_name:
      continue
    # Sanity check to ensure we don't overwrite dictionary entries.
    assert op_coll not in initializer_names
    initializer_names[op_coll] = [op_name]
  return initializer_names
461
462
def _add_pruned_collection(base_meta_graph_def, meta_graph_def,
                           collection_name, removed_op_names):
  """Copy collection to the transformed MetaGraphDef, omitting removed items.

  Entries of any_list and bytes_list collections are dropped when their
  content mentions a removed op; entries of node_list collections are dropped
  when they name a removed op. Collections of any other kind are copied
  unchanged.

  Args:
    base_meta_graph_def: The untransformed MetaGraphDef.
    meta_graph_def: The transformed MetaGraphDef being built.
    collection_name: The name of the collection to copy.
    removed_op_names: An iterable of names of ops that were removed.
  """

  base_collection = base_meta_graph_def.collection_def[collection_name]
  collection = meta_graph_def.collection_def[collection_name]

  if base_collection.HasField('any_list'):
    for any_value in base_collection.any_list.value:
      # just search the serialized proto as a string
      if not _is_removed_mentioned(any_value.value, removed_op_names):
        copied_any = collection.any_list.value.add()
        copied_any.CopyFrom(any_value)
  elif base_collection.HasField('bytes_list'):
    # Classify every entry once; the check scans the whole serialized value,
    # so it should not be repeated just to produce the log message.
    retained = []
    excluded = []
    for serialized in base_collection.bytes_list.value:
      if _is_removed_mentioned(serialized, removed_op_names):
        excluded.append(serialized)
      else:
        retained.append(serialized)
    collection.bytes_list.value[:] = retained
    _logging.info(
        'In collection %s, nodes excluded are: %s', collection_name,
        sorted(excluded))
  elif base_collection.HasField('node_list'):
    collection.node_list.value[:] = [
        s for s in base_collection.node_list.value
        if not _is_removed(s, removed_op_names)]
  else:
    collection.CopyFrom(base_collection)
492
493
def _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names):
  """Copy the Saver into the transformed MetaGraphDef, if valid.

  The SaverDef is currently copied as is, once it has been verified that none
  of the ops and tensors it refers to were removed. Nothing happens when the
  base MetaGraphDef has no SaverDef. A future version will modify the Save
  and Restore ops themselves as needed to account for removed Variables.

  Args:
    base_meta_graph_def: The untransformed MetaGraphDef.
    meta_graph_def: The transformed MetaGraphDef being built.
    removed_op_names: An iterable of names of ops that were removed.

  Raises:
    ValueError: if a name referenced by the SaverDef is empty or belongs to a
      removed op.
  """
  if not base_meta_graph_def.HasField('saver_def'):
    return

  base_saver_def = base_meta_graph_def.saver_def
  for referenced_name in (base_saver_def.filename_tensor_name,
                          base_saver_def.save_tensor_name,
                          base_saver_def.restore_op_name):
    _check_tensor_not_removed(referenced_name, removed_op_names)

  # TODO(b/63447631): Once we strip unused variables, remove references to
  # them from save and restore ops.  Retain those ops only if they also refer
  # to retained Variables. See if we can use _clean_save_and_restore() for
  # this.
  #
  # The sketched approach: derive the saver name from a restore_op_name of
  # the form */restore_all (raising ValueError otherwise), filter the string
  # values of the '<saver>/SaveV2/tensor_names' and
  # '<saver>/RestoreV2/tensor_names' ops in meta_graph_def.graph_def with
  # _is_removed(), and copy the SaverDef only if any names remain. That would
  # edit meta_graph_def.graph_def, so the graph has to be copied first.
  meta_graph_def.saver_def.CopyFrom(base_saver_def)
551
552
553def _find_op(graph_def, op_name):
554  """Fetch a node from a GraphDef proto by name."""
555  for node_def in graph_def.node:
556    if node_def.name == op_name:
557      return node_def
558  return None
559
560
def _add_pruned_signature(base_meta_graph_def, meta_graph_def,
                          signature_name, removed_op_names):
  """Copy the named signature into the transformed MetaGraphDef, if valid.

  A signature is only copied when every one of its inputs and outputs still
  exists. If any of them was removed by the graph transform, the signature is
  silently left out of the transformed MetaGraphDef.

  Args:
    base_meta_graph_def: The untransformed MetaGraphDef.
    meta_graph_def: The transformed MetaGraphDef being built.
    signature_name: The name of the signature to copy.
    removed_op_names: An iterable of names of ops that were removed.
  """
  try:
    base_signature = base_meta_graph_def.signature_def[signature_name]
    # Inputs are verified before outputs; the first failure aborts the copy.
    for tensor_infos in (base_signature.inputs, base_signature.outputs):
      for key in tensor_infos:
        _check_tensor_not_removed(tensor_infos[key].name, removed_op_names)
    meta_graph_def.signature_def[signature_name].CopyFrom(base_signature)
  except ValueError:
    # exclude any signature that mentions a removed node
    pass
587
588
589def _get_single_node_name_from_collection(meta_graph_def, collection_key):
590  """Obtain a node name that is the single element of a collection."""
591  if collection_key not in meta_graph_def.collection_def:
592    return None
593  collection = meta_graph_def.collection_def[collection_key]
594  if not collection.node_list.value:
595    raise ValueError(
596        'Collection {} is present but type is not node_list.'.format(
597            collection_key))
598  if len(collection.node_list.value) != 1:
599    raise ValueError(
600        'Collection {} is has {} elements; expected exactly one.'.format(
601            collection_key, collection.bytes_list))
602  return collection.node_list.value[0]
603
604
605def _get_all_node_names_from_collection(meta_graph_def, collection_key):
606  """Obtain node names from a collection."""
607  if collection_key not in meta_graph_def.collection_def:
608    return None
609  collection = meta_graph_def.collection_def[collection_key]
610  if not collection.node_list.value:
611    raise ValueError(
612        'Collection {} is present but type is not node_list.'.format(
613            collection_key))
614  return collection.node_list.value
615
616
def _get_all_protos_from_collection(meta_graph_def, collection_key):
  """Parse and return the protos stored in a bytes_list collection.

  Args:
    meta_graph_def: The MetaGraphDef whose collection_def is searched.
    collection_key: The name of the collection to read. The proto type used
      for parsing is the one registered for this key.

  Returns:
    A list with one parsed proto per bytes_list value, or an empty list if
    the collection is not present.

  Raises:
    ValueError: if the collection is present but holds no bytes_list values.
  """
  if collection_key not in meta_graph_def.collection_def:
    return []
  serialized_protos = (
      meta_graph_def.collection_def[collection_key].bytes_list.value)
  if not serialized_protos:
    raise ValueError(
        'Collection {} is present but type is not bytes_list.'.format(
            collection_key))
  proto_type = _ops.get_collection_proto_type(collection_key)

  def _parse(serialized):
    """Deserialize one collection entry into a new proto instance."""
    proto = proto_type()
    proto.ParseFromString(serialized)
    return proto

  return [_parse(serialized) for serialized in serialized_protos]
633
634
635def _is_removed(tensor_name, removed_op_names):
636  """Determine whether the named tensor is an output of a removed op."""
637  for removed_op_name in removed_op_names:
638    if tensor_name.split(':')[0] == removed_op_name:
639      return True
640  return False
641
642
def _is_removed_mentioned(s, removed_op_names):
  """Determine whether any removed op is mentioned in the given object.

  The check works on the string form of the object. It is meant for proto
  messages that may mention ops by name in nested fields: their string form
  contains those field values, so searching the string is sufficient.

  Args:
    s: an object to search for removed op names.
    removed_op_names: An iterable of names of ops that were removed.

  Returns:
    True if any removed op is mentioned in the given object, False otherwise.
  """
  # Some gtt transforms remove a node and add new nodes that share its name
  # as a prefix: an original /foo may be replaced by /foo/bar. Candidate
  # names are therefore extracted as whole '/'-separated paths of letters,
  # digits and underscores, and a removed name only counts as mentioned when
  # a candidate ends with it, so /foo/bar is not mistaken for /foo.
  mentioned_names = _re.findall(
      r'((?:[\/]?[a-zA-Z0-9\_]*)*)', _compat.as_str_any(s))
  return any(
      mentioned.endswith(removed_op_name)
      for removed_op_name in removed_op_names
      for mentioned in mentioned_names)
670
671
def _check_tensor_not_removed(tensor_name, removed_op_names):
  """Raise unless the named tensor is non-empty and still present.

  Args:
    tensor_name: the name of a tensor to check.
    removed_op_names: An iterable of names of ops that were removed.

  Raises:
    ValueError: if the tensor name is empty or the tensor was removed.
  """
  problem = None
  if not tensor_name:
    problem = 'Tensor name should not be empty'
  elif _is_removed(tensor_name, removed_op_names):
    problem = 'Expected Tensor, but it was removed: {}'.format(tensor_name)
  if problem is not None:
    raise ValueError(problem)
687
688
def _add_new_inits_to_collection(meta_graph_def, updated_initializer_names):
  """Append initializers created by the transforms to their collection.

  Only the table initializers collection is handled: every updated table
  initializer that is not yet listed in the MetaGraphDef's collection is
  appended to it, in sorted order.

  Args:
    meta_graph_def: The MetaGraphDef protocol buffer to update.
    updated_initializer_names: Dictionary of the updated "infrastructural" nodes
      (initializers, save and restore ops, etc.). The keys in this dictionary
      indicate the collection where these nodes were obtained from.

  Raises:
    ValueError: if the table initializers collection is present in
      meta_graph_def but holds no node names.
  """
  # TODO(dzats): Extend this to support all collections.
  table_key = _ops.GraphKeys.TABLE_INITIALIZERS
  if table_key not in updated_initializer_names:
    return

  existing_table_inits = _get_all_node_names_from_collection(
      meta_graph_def, table_key) or []
  updated_table_inits = updated_initializer_names[table_key]
  added_table_inits = sorted(
      set(updated_table_inits) - set(existing_table_inits))
  meta_graph_def.collection_def[table_key].node_list.value.extend(
      added_table_inits)
713
714
def meta_graph_transform(
    base_meta_graph_def, input_names, output_names, transforms, tags,
    checkpoint_path=None):
  """Apply the Graph Transform tool to a MetaGraphDef.

  The graph is transformed first; the saver, the collections and the
  signatures of the base MetaGraphDef are then copied over, leaving out
  whatever refers to ops that the transforms removed.

  Args:
    base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: A list of strings naming the graph transforms to be applied in
      order.  These transform names are exactly those supported by the Graph
      Transform Tool, with the addition of the 'freeze_graph' and
      'sparsify_gather' transforms. None is treated like an empty list.
    tags: A list of tags with which to annotate the transformed MetaGraphDef.
    checkpoint_path: A path to a checkpoint to restore during freezing,
      if needed (default None).

  Returns:
    A new transformed MetaGraphDef protocol buffer.

  Raises:
    ValueError: if the graph was not frozen and a name referenced by the base
      SaverDef is empty or belongs to an op that was removed.
  """
  # The transform helpers accept a missing list of transforms, but the
  # membership test further down does not; normalize once up front.
  transforms = transforms or []

  meta_graph_def = _meta_graph_pb2.MetaGraphDef()

  initializer_names = _find_all_mandatory_retain_ops(base_meta_graph_def)

  transformed_graph_def, updated_initializer_names = _do_transforms(
      base_meta_graph_def.graph_def, input_names, output_names,
      initializer_names, transforms, base_meta_graph_def.saver_def,
      checkpoint_path)

  meta_graph_def.graph_def.CopyFrom(transformed_graph_def)
  meta_graph_def.meta_info_def.CopyFrom(base_meta_graph_def.meta_info_def)
  # The tags of the base MetaGraphDef are replaced, not extended.
  meta_graph_def.meta_info_def.ClearField('tags')
  for tag in tags:
    meta_graph_def.meta_info_def.tags.append(tag)

  base_op_names = [_compat.as_str(node.name)
                   for node in base_meta_graph_def.graph_def.node]
  retained_op_names = [_compat.as_str(node.name)
                       for node in meta_graph_def.graph_def.node]
  removed_op_names = set(base_op_names) - set(retained_op_names)
  _logging.info('Node names in base graph: %s', sorted(base_op_names))
  _logging.info('Node names retained: %s', sorted(retained_op_names))
  _logging.info('Node names removed: %s', sorted(removed_op_names))

  # Copy saver, excluding any pruned nodes if graph was not frozen.
  # TODO(b/63447631): Revisit this once the problem is addressed. Currently
  # _add_pruned_saver assumes that the save and restore nodes have not been
  # removed but freeze_graph (correctly) removes them.
  if _FREEZE_GRAPH_TRANSFORM not in transforms:
    _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names)

  # Copy collections, excluding any pruned nodes
  for collection_name in base_meta_graph_def.collection_def:
    _add_pruned_collection(
        base_meta_graph_def, meta_graph_def, collection_name,
        removed_op_names)

  # Append newly added initializers to collection.
  _add_new_inits_to_collection(meta_graph_def, updated_initializer_names)

  # Copy signature_defs, excluding any pruned nodes
  for signature_name in base_meta_graph_def.signature_def:
    _add_pruned_signature(
        base_meta_graph_def, meta_graph_def, signature_name,
        removed_op_names)

  return meta_graph_def
782