1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""MetaGraph and related functions."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22from distutils import version as distutils_version  # pylint: disable=g-bad-import-order
23import os.path
24import re
25
26import six
27from google.protobuf.any_pb2 import Any
28from google.protobuf import text_format
29
30from tensorflow.core.framework import attr_value_pb2
31from tensorflow.core.framework import graph_pb2
32from tensorflow.core.framework import op_def_pb2
33from tensorflow.core.protobuf import meta_graph_pb2
34from tensorflow.core.protobuf import saver_pb2
35from tensorflow.python.client import pywrap_tf_session as c_api
36from tensorflow.python.eager import context
37from tensorflow.python.framework import error_interpolation
38from tensorflow.python.framework import graph_io
39from tensorflow.python.framework import importer
40from tensorflow.python.framework import op_def_registry
41from tensorflow.python.framework import ops
42from tensorflow.python.framework import versions
43from tensorflow.python.lib.io import file_io
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.util import compat
46
47
# Prefix added to the names of unbound inputs (node inputs that lie outside the
# export scope) so they are easily identifiable.
_UNBOUND_INPUT_PREFIX = "$unbound_inputs_"

# Collections that did not register proto conversion functions. As a result,
# the items of these collections in a previously exported meta_graph may be
# stored with a different data type.
_COMPAT_COLLECTION_LIST = [ops.GraphKeys.LOCAL_VARIABLES,
                           ops.GraphKeys.MODEL_VARIABLES,
                           ops.GraphKeys.METRIC_VARIABLES]
56
57
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Inputs that lie outside `export_scope` are renamed with the
  `_UNBOUND_INPUT_PREFIX` prefix (after any leading "^" control-dependency
  marker) and recorded in `unbound_inputs`; all other inputs, the node name,
  `_class` colocation entries and in-scope `Enter`/`RefEnter` frame names have
  `export_scope` stripped.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer. Not modified.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list to which the prefixed names of unbound inputs are
      appended (mutated in place).
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if export_scope and not v.lstrip("^").startswith(export_scope):
      # Adds "$unbound_inputs_" prefix to the unbound name (keeping a leading
      # "^" in front) so they are easily identifiable.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      # Keep only colocation entries ("...@name") that fall inside the scope.
      new_s = [compat.as_bytes(ops.strip_name_scope(s, export_scope))
               for s in v.list.s
               if not export_scope or
               compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_frame = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_frame))
      else:
        # The frame lies outside export_scope: keep its name unchanged.
        # (Previously the stripped name was unbound here, or stale from an
        # earlier "_class" attr.)
        node_def.attr[k].CopyFrom(v)
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def
104
105
def _read_file(filename):
  """Reads a file containing `GraphDef` and returns the protocol buffer.

  The content is first parsed as a binary `GraphDef`; if that fails it is
  decoded as UTF-8 and parsed as a text-format `GraphDef`.

  Args:
    filename: `graph_def` filename including the path.

  Returns:
    A `GraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  graph_def = graph_pb2.GraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # Close the handle deterministically instead of relying on GC.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()

  # First try to read it as a binary file.
  try:
    graph_def.ParseFromString(file_content)
    return graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # A failed binary parse may leave partial data behind; start clean.
  graph_def.Clear()

  # Next try to read it as a text file.
  try:
    text_format.Merge(file_content.decode("utf-8"), graph_def)
  except (text_format.ParseError, UnicodeDecodeError) as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return graph_def
136
137
def ops_used_by_graph_def(graph_def):
  """Lists the primitive op types referenced by a graph.

  Function calls are followed transitively: a node whose op names a function
  from `graph_def.library`, or a `PartitionedCall`/`StatefulPartitionedCall`
  node whose `f` attr names one, causes that function's body to be scanned
  too. Library function names themselves are not part of the result.

  Does not validate that the ops are all registered.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    A list of strings, each naming an op used by the graph (in no particular
    order).
  """
  functions_by_name = {
      fdef.signature.name: fdef for fdef in graph_def.library.function}

  seen = set()  # Primitive ops and function names encountered so far.
  pending = []  # Function bodies still to be scanned (LIFO).

  def _visit(op_name):
    # Queue a function body the first time its name is encountered.
    if op_name in functions_by_name and op_name not in seen:
      pending.append(functions_by_name[op_name])
    seen.add(op_name)

  def _visit_node(node_def):
    _visit(node_def.op)
    if node_def.op in ["PartitionedCall", "StatefulPartitionedCall"]:
      _visit(node_def.attr["f"].func.name)

  for node_def in graph_def.node:
    _visit_node(node_def)
  while pending:
    for node_def in pending.pop().node_def:
      _visit_node(node_def)

  return [name for name in seen if name not in functions_by_name]
177
178
def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  The result is suitable for the `stripped_op_list` field of `MetaGraphDef`
  and similar protos. It can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Unlike the C++ `StrippedOpListForGraph`, this does not require every op to
  be registered: ops missing from the registry are simply left out. This is
  done to support Prelu fusion in tfjs.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of the registered ops used by the graph, sorted by op name.
  """
  op_defs = []
  for op_name in sorted(ops_used_by_graph_def(graph_def)):
    registered = op_def_registry.get(op_name)
    if registered is None:
      # Unregistered op: skip it rather than failing.
      continue
    op_defs.append(registered)
  return op_def_pb2.OpList(op=op_defs)
203
204
def _get_kind_name(item):
  """Maps a Python value to the `CollectionDef` field that can hold it.

  Args:
    item: A data item.

  Returns:
    "bytes_list" for text/bytes, "int64_list" for integers (including bool),
    "float_list" for floats, "any_list" for protobuf `Any` messages and
    "node_list" for anything else.
  """
  if isinstance(item, (six.string_types, six.binary_type)):
    return "bytes_list"
  if isinstance(item, six.integer_types):
    return "int64_list"
  if isinstance(item, float):
    return "float_list"
  if isinstance(item, Any):
    return "any_list"
  # Fallback: assume a graph element (identified by its name).
  return "node_list"
225
226
# Op types that mark Saver save/restore nodes. `_find_extraneous_saver_nodes`
# uses this list to locate the name scopes that contain Savers.
SAVE_AND_RESTORE_OPS = ["SaveV2",
                        "Save", "SaveSlice",
                        "LegacySave", "LegacySaveSlice",
                        "RestoreV2",
                        "Restore", "RestoreSlice",
                        "LegacyRestore", "LegacyRestoreSlice"]
233
234
235def _op_name(tensor_name):
236  """Extract the Op name from a Tensor name.
237
238  The Op name is everything before a colon, if present,
239  not including any ^ prefix denoting a control dependency.
240
241  Args:
242    tensor_name: the full name of a Tensor in the graph.
243  Returns:
244    The name of the Op of which the given Tensor is an output.
245  Raises:
246    ValueError: if tensor_name is None or empty.
247  """
248  if not tensor_name:
249    raise ValueError("Tensor name cannot be empty or None.")
250
251  # Control dependency inputs start with ^.
252  if tensor_name.startswith("^"):
253    tensor_name = tensor_name[1:]
254  if ":" in tensor_name:
255    op_name, _ = tensor_name.split(":")
256    return op_name
257  return tensor_name
258
259
260def _get_scope(node_name):
261  """Extract the scope name from a node name.
262
263  The scope name is everything before the final slash,
264  not including any ^ prefix denoting a control dependency.
265
266  Args:
267    node_name: the full name of an Op or a Tensor in the graph.
268  Returns:
269    The deepest named scope containing the node.
270  Raises:
271    ValueError: if tensor_name is None or empty
272  """
273  if not node_name:
274    raise ValueError("Node name cannot be empty or None.")
275
276  # Control dependency inputs start with ^.
277  if node_name.startswith("^"):
278    node_name = node_name[1:]
279  if "/" in node_name:
280    scope, _ = node_name.rsplit("/", 1)
281    return scope
282
283  return ""
284
285
def _find_extraneous_saver_nodes(graph_def, saver_def):
  """Identifies any nodes in the graph_def related to unused Savers.

  This approach assumes that each Saver is cleanly isolated in its own name
  scope, so we need only identify the scopes associated with extraneous Savers
  and return all the nodes in those scopes.

  Args:
    graph_def: a GraphDef proto to evaluate.
    saver_def: a SaverDef proto referencing Save/Restore ops to be retained,
      or None if there is no Saver to retain.
  Returns:
    A set of node names that may be safely omitted.
  """
  # TODO(soergel): confirm that the assumption of scope isolation is valid.
  # If not, we need to walk up the graph from any restore_all nodes, and walk
  # down the graph from any Save/Restore nodes.  I drafted that approach too,
  # but it seems unnecessarily complex given the name scope solution.

  # Map node name -> op type. Node inputs are not needed to find Saver scopes,
  # so they are not parsed.
  node_ops = {node_def.name: node_def.op for node_def in graph_def.node}

  # Scopes (with trailing "/") of the Saver referenced by saver_def.
  retain_scopes = set()
  # It's possible to have no saver if the graph has no Variables
  if saver_def is not None:
    save_op_name = _op_name(saver_def.save_tensor_name)
    restore_op_name = _op_name(saver_def.restore_op_name)

    # The save and restore scopes should always be the same, but if they differ
    # for some reason, we retain them both to be safe.
    retain_scopes.add(_get_scope(restore_op_name) + "/")
    retain_scopes.add(_get_scope(save_op_name) + "/")

  all_saver_node_names = set(
      name for name, op in node_ops.items() if op in SAVE_AND_RESTORE_OPS)

  # Scopes that contain Save/Restore nodes (excluding scope names that are
  # themselves Save/Restore node names).
  all_saver_scopes = (
      set(_get_scope(x) for x in all_saver_node_names) - all_saver_node_names)
  extraneous_scopes = set(x + "/" for x in all_saver_scopes) - retain_scopes

  # str.startswith accepts a tuple of prefixes; an empty tuple matches nothing.
  extraneous_prefixes = tuple(extraneous_scopes)
  return set(name for name in node_ops
             if name.startswith(extraneous_prefixes))
340
341
def _should_include_node(node_or_node_name, export_scope, exclude_nodes):
  """Decides whether a node (or node name) belongs in an export.

  Objects that have no `name` attribute are always kept. Otherwise a node is
  dropped if it (or its name) appears in `exclude_nodes`; it is kept if its
  name carries the unbound-input prefix or lies under `export_scope` (any name
  qualifies when `export_scope` is empty).

  Args:
    node_or_node_name: A node or `string` node name.
    export_scope: `string`. Name scope under which to extract the subgraph. The
      scope name will be stripped from the node definitions for easy import
      later into new name scopes.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      export, or None.  Note no sanity-checking is done, so this list must be
      carefully constructed to avoid producing an invalid graph.

  Returns:
    `True` if the node should be included.
  """
  if isinstance(node_or_node_name, six.string_types):
    node_name = node_or_node_name
  else:
    try:
      node_name = node_or_node_name.name
    except AttributeError:
      # Keep the object that we don't know how to process.
      return True

  if exclude_nodes:
    if node_or_node_name in exclude_nodes or node_name in exclude_nodes:
      return False

  if node_name.startswith(_UNBOUND_INPUT_PREFIX):
    return True
  return not export_scope or node_name.startswith(export_scope)
372
373
def add_collection_def(meta_graph_def, key, graph=None,
                       export_scope=None, exclude_nodes=None,
                       override_contents=None):
  """Adds a collection to MetaGraphDef protocol buffer.

  Collections whose key is not a string, that end up empty after filtering,
  or whose items cannot be serialized are skipped (the latter two with a
  logged warning / silently, respectively) rather than raising.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
    graph: The `Graph` from which to get collections. Defaults to the default
      graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from the
      collection, or None.
    override_contents: An iterable of values to place in the collection,
      ignoring the current values (if set). An empty value falls back to the
      graph's collection.

  Raises:
    TypeError: If `graph` is given and is not a `Graph`.
  """
  if graph and not isinstance(graph, ops.Graph):
    # Format the message (previously the type was passed as a second arg and
    # never interpolated).
    raise TypeError("graph must be of type Graph, not %s" % type(graph))

  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s", type(key))
    return

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  if override_contents:
    collection_list = override_contents
  else:
    collection_list = graph.get_collection(key)

  # Remove nodes that should not be exported from the collection list.
  collection_list = [x for x in collection_list if
                     _should_include_node(x, export_scope, exclude_nodes)]
  if not collection_list:
    return

  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        proto = to_proto(x, export_scope=export_scope)
        if proto:
          # Explicit type check (not `assert`) so it still runs under
          # `python -O`; a mismatch is reported via the warning below.
          if not isinstance(proto, proto_type):
            raise TypeError("Expected proto of type %s, got %s" %
                            (proto_type, type(proto)))
          getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        for x in collection_list:
          if not export_scope or x.name.startswith(export_scope):
            getattr(col_def, kind).value.append(
                ops.strip_name_scope(x.name, export_scope))
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend(list(collection_list))
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Issue encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef. Note this is a warning "
                    "and probably safe to ignore.\n%s", key, str(e))
    # Drop the partially-populated entry created by the map access above.
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
446
447
def _is_default_attr_value(op_def, attr_name, attr_value):
  """Checks if given attribute matches the default value in the op def.

  Returns False when `op_def` declares no attr named `attr_name` or when that
  attr has no default value. Only the first attr definition with a matching
  name is consulted.
  """
  attr_def = next((a for a in op_def.attr if a.name == attr_name), None)
  if attr_def is None or not attr_def.HasField("default_value"):
    return False
  # c_api.EqualAttrValueWrapper returns an empty string when both serialized
  # arguments represent an equivalent AttrValue.
  difference = c_api.EqualAttrValueWrapper(
      attr_value.SerializeToString(),
      attr_def.default_value.SerializeToString())
  return not difference
460
461
def strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  An attribute is removed when its value equals the default declared by the
  op's registered `OpDef`. Nodes whose op is a function from the graph's
  library, or whose op is not registered, are left untouched. Both top-level
  nodes and nodes inside library functions are processed.

  This method also sets `meta_info_def.stripped_default_attrs` in the given
  `MetaGraphDef` proto to True.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer (modified in place).

  Returns:
    None.
  """
  graph_def = meta_graph_def.graph_def
  library_functions = graph_def.library.function
  # Only membership matters: nodes that call library functions are skipped.
  function_names = set(fdef.signature.name for fdef in library_functions)

  def _strip_defaults(node):
    """Removes default valued attributes from a single node def in place."""
    if node.op in function_names:
      return
    op_def = op_def_registry.get(node.op)
    if op_def is None:
      return
    # Collect first, then delete, so the map is not mutated while iterating.
    defaulted = [name for name, value in node.attr.items()
                 if _is_default_attr_value(op_def, name, value)]
    for name in defaulted:
      del node.attr[name]

  for node in graph_def.node:
    _strip_defaults(node)
  for fdef in library_functions:
    for node in fdef.node_def:
      _strip_defaults(node)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
507
508
def create_meta_graph_def(meta_info_def=None,
                          graph_def=None,
                          saver_def=None,
                          collection_list=None,
                          graph=None,
                          export_scope=None,
                          exclude_nodes=None,
                          clear_extraneous_savers=False,
                          strip_default_attrs=False):
  # pylint: disable=line-too-long
  """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer. It is copied into the result
      and is not modified; the copy's TensorFlow version fields are set to the
      current build's versions.
    graph_def: `GraphDef` protocol buffer. If not given, `graph` is serialized
      with shapes.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect. If None, all collection
      keys of `graph` are used.
    graph: The `Graph` to create `MetaGraphDef` out of. Defaults to the
      default graph.
    export_scope: Optional `string`. Name scope to remove.
    exclude_nodes: An iterable of nodes or `string` node names to omit from all
      collection, or None.
    clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
        collection.  Note this method does not alter the graph, so any
        extraneous Save/Restore ops should have been removed already, as needed.
    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # pylint: enable=line-too-long
  # Type check. Messages are %-formatted (previously the type was passed as a
  # separate exception argument and never interpolated).
  if graph and not isinstance(graph, ops.Graph):
    raise TypeError("graph must be of type Graph, not %s" % type(graph))
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError("meta_info_def must be of type MetaInfoDef, not %s" %
                    type(meta_info_def))
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be of type GraphDef, not %s" %
                    type(graph_def))
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError("saver_def must be of type SaverDef, not %s" %
                    type(saver_def))

  # Sets graph to default graph if it's not passed in.
  graph = graph or ops.get_default_graph()

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def. Merge the caller's proto first and stamp the version
  # strings on the copy, so the caller's meta_info_def is not mutated.
  info = meta_graph_def.meta_info_def
  if meta_info_def:
    info.MergeFrom(meta_info_def)
  # Set the tf version strings to the current tf build.
  info.tensorflow_version = versions.__version__
  info.tensorflow_git_version = versions.__git_version__

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(graph.as_graph_def(add_shapes=True))
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Strip default valued attributes in graph_def.
  if strip_default_attrs:
    strip_graph_default_valued_attrs(meta_graph_def)

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list is not None:
    clist = collection_list
  else:
    clist = graph.get_all_collection_keys()

  for ctype in clist:
    if clear_extraneous_savers and ctype == ops.GraphKeys.SAVERS:
      # Avoid importing Saver here
      from_proto = ops.get_from_proto_function(ctype)
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes,
                         override_contents=[from_proto(saver_def)])
    else:
      add_collection_def(meta_graph_def, ctype,
                         graph=graph,
                         export_scope=export_scope,
                         exclude_nodes=exclude_nodes)
  return meta_graph_def
614
615
def read_meta_graph_file(filename):
  """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

  The content is first parsed as a binary `MetaGraphDef`; if that fails it is
  decoded as UTF-8 and parsed as a text-format `MetaGraphDef`.

  Args:
    filename: `meta_graph_def` filename including the path.

  Returns:
    A `MetaGraphDef` protocol buffer.

  Raises:
    IOError: If the file doesn't exist, or cannot be successfully parsed.
  """
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  if not file_io.file_exists(filename):
    raise IOError("File %s does not exist." % filename)
  # Close the handle deterministically instead of relying on GC.
  with file_io.FileIO(filename, "rb") as f:
    file_content = f.read()

  # First try to read it as a binary file.
  try:
    meta_graph_def.ParseFromString(file_content)
    return meta_graph_def
  except Exception:  # pylint: disable=broad-except
    pass

  # A failed binary parse may leave partial data behind; start clean.
  meta_graph_def.Clear()

  # Next try to read it as a text file. Non-UTF-8 content is reported as an
  # IOError, like any other parse failure.
  try:
    text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
  except (text_format.ParseError, UnicodeDecodeError) as e:
    raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

  return meta_graph_def
646
647
def import_scoped_meta_graph(meta_graph_or_file,
                             clear_devices=False,
                             graph=None,
                             import_scope=None,
                             input_map=None,
                             unbound_inputs_col_name="unbound_inputs",
                             restore_collections_predicate=(lambda key: True)):
  """Recreates a `Graph` saved in a `MetaGraphDef` proto.

  Accepts either a `MetaGraphDef` protocol buffer or the name of a file that
  contains one. All nodes from its `graph_def` field are added to the target
  graph, the desired collections are recreated, and a dictionary of all the
  Variables imported into the name scope is returned.

  Together with `export_scoped_meta_graph()`, this function can be used to

  * Serialize a graph along with other Python objects such as `QueueRunner`,
    `Variable` into a `MetaGraphDef`.

  * Restart training from a saved graph and checkpoints.

  * Run inference from a saved graph and checkpoints.

  This is a thin wrapper around
  `import_scoped_meta_graph_with_return_elements` that requests no return
  elements and returns only the imported variables.

  Args:
    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
      the path) containing a `MetaGraphDef`.
    clear_devices: Boolean which controls whether to clear device information
      from graph_def. Default false.
    graph: The `Graph` to import into. If `None`, use the default graph.
    import_scope: Optional `string`. Name scope into which to import the
      subgraph. If `None`, the graph is imported to the root name scope.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    unbound_inputs_col_name: Collection name for looking up unbound inputs.
    restore_collections_predicate: a predicate on collection names. A collection
      named c (i.e whose key is c) will be restored iff
      1) `restore_collections_predicate(c)` is True, and
      2) `c != unbound_inputs_col_name`.

  Returns:
    A dictionary of all the `Variables` imported into the name scope.

  Raises:
    ValueError: If the graph_def contains unbound inputs.
  """
  results = import_scoped_meta_graph_with_return_elements(
      meta_graph_or_file,
      clear_devices=clear_devices,
      graph=graph,
      import_scope=import_scope,
      input_map=input_map,
      unbound_inputs_col_name=unbound_inputs_col_name,
      restore_collections_predicate=restore_collections_predicate)
  # First element: imported variables; the return-element list is unused.
  return results[0]
699
700
701def import_scoped_meta_graph_with_return_elements(
702    meta_graph_or_file,
703    clear_devices=False,
704    graph=None,
705    import_scope=None,
706    input_map=None,
707    unbound_inputs_col_name="unbound_inputs",
708    restore_collections_predicate=(lambda key: True),
709    return_elements=None):
710  """Imports graph from `MetaGraphDef` and returns vars and return elements.
711
712  This function takes a `MetaGraphDef` protocol buffer as input. If
713  the argument is a file containing a `MetaGraphDef` protocol buffer ,
714  it constructs a protocol buffer from the file content. The function
715  then adds all the nodes from the `graph_def` field to the
716  current graph, recreates the desired collections, and returns a dictionary of
717  all the Variables imported into the name scope.
718
719  In combination with `export_scoped_meta_graph()`, this function can be used to
720
721  * Serialize a graph along with other Python objects such as `QueueRunner`,
722    `Variable` into a `MetaGraphDef`.
723
724  * Restart training from a saved graph and checkpoints.
725
726  * Run inference from a saved graph and checkpoints.
727
728  Args:
729    meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
730      the path) containing a `MetaGraphDef`.
731    clear_devices: Boolean which controls whether to clear device information
732      from graph_def. Default false.
733    graph: The `Graph` to import into. If `None`, use the default graph.
734    import_scope: Optional `string`. Name scope into which to import the
735      subgraph. If `None`, the graph is imported to the root name scope.
736    input_map: A dictionary mapping input names (as strings) in `graph_def` to
737      `Tensor` objects. The values of the named input tensors in the imported
738      graph will be re-mapped to the respective `Tensor` values.
739    unbound_inputs_col_name: Collection name for looking up unbound inputs.
740    restore_collections_predicate: a predicate on collection names. A collection
741      named c (i.e whose key is c) will be restored iff
742      1) `restore_collections_predicate(c)` is True, and
743      2) `c != unbound_inputs_col_name`.
744    return_elements:  A list of strings containing operation names in the
745      `MetaGraphDef` that will be returned as `Operation` objects; and/or
746      tensor names in `MetaGraphDef` that will be returned as `Tensor` objects.
747
748  Returns:
749    A tuple of (
750      dictionary of all the `Variables` imported into the name scope,
751      list of `Operation` or `Tensor` objects from the `return_elements` list).
752
753  Raises:
754    ValueError: If the graph_def contains unbound inputs.
755
756  """
757  if context.executing_eagerly():
758    raise ValueError("Exporting/importing meta graphs is not supported when "
759                     "eager execution is enabled.")
760  if isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
761    meta_graph_def = meta_graph_or_file
762  else:
763    meta_graph_def = read_meta_graph_file(meta_graph_or_file)
764
765  if unbound_inputs_col_name:
766    for key, col_def in meta_graph_def.collection_def.items():
767      if key == unbound_inputs_col_name:
768        kind = col_def.WhichOneof("kind")
769        field = getattr(col_def, kind)
770        if field.value and (
771            not input_map or
772            sorted([compat.as_str(v) for v in field.value]) !=
773            sorted(input_map)):
774          raise ValueError("Graph contains unbound inputs: %s. Must "
775                           "provide these inputs through input_map." % ",".join(
776                               compat.as_str(v)
777                               for v in field.value
778                               if not input_map or v not in input_map))
779        break
780
781  # Sets graph to default graph if it's not passed in.
782  graph = graph or ops.get_default_graph()
783
784  # Gathers the list of nodes we are interested in.
785  with graph.as_default():
786    producer_op_list = None
787    if meta_graph_def.meta_info_def.HasField("stripped_op_list"):
788      producer_op_list = meta_graph_def.meta_info_def.stripped_op_list
789    input_graph_def = meta_graph_def.graph_def
790    # Remove all the explicit device specifications for this node. This helps to
791    # make the graph more portable.
792    if clear_devices:
793      for node in input_graph_def.node:
794        node.device = ""
795
796    scope_to_prepend_to_names = graph.unique_name(
797        import_scope or "", mark_as_used=False)
798
799    imported_return_elements = importer.import_graph_def(
800        input_graph_def,
801        name=(import_scope or scope_to_prepend_to_names),
802        input_map=input_map,
803        producer_op_list=producer_op_list,
804        return_elements=return_elements)
805
806    # TensorFlow versions before 1.9 (not inclusive) exported SavedModels
807    # without a VariableDef.trainable field set.
808    tf_version = meta_graph_def.meta_info_def.tensorflow_version
809    if not tf_version:
810      variables_have_trainable = True
811    else:
812      variables_have_trainable = (
813          distutils_version.LooseVersion(tf_version)
814          >= distutils_version.LooseVersion("1.9"))
815
816    # Sort collections so we see TRAINABLE_VARIABLES first and can default these
817    # variables to trainable if the value is not set in their VariableDef.
818    sorted_collections = []
819    if ops.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def:
820      sorted_collections.append(
821          (ops.GraphKeys.TRAINABLE_VARIABLES,
822           meta_graph_def.collection_def[ops.GraphKeys.TRAINABLE_VARIABLES]))
823    for key, value in sorted(meta_graph_def.collection_def.items()):
824      if key != ops.GraphKeys.TRAINABLE_VARIABLES:
825        sorted_collections.append((key, value))
826
827    # Restores all the other collections.
828    variable_objects = {}
829    for key, col_def in sorted_collections:
830      # Don't add unbound_inputs to the new graph.
831      if key == unbound_inputs_col_name:
832        continue
833      if not restore_collections_predicate(key):
834        continue
835
836      kind = col_def.WhichOneof("kind")
837      if kind is None:
838        logging.error("Cannot identify data type for collection %s. Skipping.",
839                      key)
840        continue
841      from_proto = ops.get_from_proto_function(key)
842
843      # Temporary change to allow the TFMA evaluator to read metric variables
844      # saved as a bytes list.
845      # TODO(kathywu): Remove this hack once cl/248406059 has been submitted.
846      if key == ops.GraphKeys.METRIC_VARIABLES:
847        # Metric variables will use the same proto functions as GLOBAL_VARIABLES
848        from_proto = ops.get_from_proto_function(ops.GraphKeys.GLOBAL_VARIABLES)
849      if from_proto and kind == "bytes_list":
850        proto_type = ops.get_collection_proto_type(key)
851        if key in ops.GraphKeys._VARIABLE_COLLECTIONS:  # pylint: disable=protected-access
852          for value in col_def.bytes_list.value:
853            variable = variable_objects.get(value, None)
854            if variable is None:
855              proto = proto_type()
856              proto.ParseFromString(value)
857              if not variables_have_trainable:
858                # If the VariableDef proto does not contain a "trainable"
859                # property because it was exported before that property was
860                # added, we default it to whether the variable is in the
861                # TRAINABLE_VARIABLES collection. We've sorted
862                # TRAINABLE_VARIABLES to be first, so trainable variables will
863                # be created from that collection.
864                proto.trainable = (key == ops.GraphKeys.TRAINABLE_VARIABLES)
865              variable = from_proto(
866                  proto, import_scope=scope_to_prepend_to_names)
867              variable_objects[value] = variable
868            graph.add_to_collection(key, variable)
869        else:
870          for value in col_def.bytes_list.value:
871            proto = proto_type()
872            proto.ParseFromString(value)
873            graph.add_to_collection(
874                key, from_proto(
875                    proto, import_scope=scope_to_prepend_to_names))
876      else:
877        field = getattr(col_def, kind)
878        if key in _COMPAT_COLLECTION_LIST:
879          logging.warning(
880              "The saved meta_graph is possibly from an older release:\n"
881              "'%s' collection should be of type 'byte_list', but instead "
882              "is of type '%s'.", key, kind)
883        if kind == "node_list":
884          for value in field.value:
885            col_op = graph.as_graph_element(
886                ops.prepend_name_scope(value, scope_to_prepend_to_names))
887            graph.add_to_collection(key, col_op)
888        elif kind == "int64_list":
889          # NOTE(opensource): This force conversion is to work around the fact
890          # that Python2 distinguishes between int and long, while Python3 has
891          # only int.
892          for value in field.value:
893            graph.add_to_collection(key, int(value))
894        else:
895          for value in field.value:
896            graph.add_to_collection(
897                key, ops.prepend_name_scope(value, scope_to_prepend_to_names))
898
899    var_list = {}
900    variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
901                                     scope=scope_to_prepend_to_names)
902    for v in variables:
903      var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v
904
905  return var_list, imported_return_elements
906
907
def export_scoped_meta_graph(filename=None,
                             graph_def=None,
                             graph=None,
                             export_scope=None,
                             as_text=False,
                             unbound_inputs_col_name="unbound_inputs",
                             clear_devices=False,
                             saver_def=None,
                             clear_extraneous_savers=False,
                             strip_default_attrs=False,
                             save_debug_info=False,
                             **kwargs):
  """Returns `MetaGraphDef` proto. Optionally writes it to filename.

  This function exports the graph, saver, and collection objects into
  `MetaGraphDef` protocol buffer with the intention of it being imported
  at a later time or location to restart training, run inference, or be
  a subgraph.

  Args:
    filename: Optional filename including the path for writing the
      generated `MetaGraphDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer. When given together with
      `export_scope`, `clear_extraneous_savers` or `clear_devices`, its nodes
      (rather than those of `graph`) are filtered and rewritten.
    graph: The `Graph` to export. If `None`, use the default graph.
    export_scope: Optional `string`. Name scope under which to extract
      the subgraph. The scope name will be stripped from the node definitions
      for easy import later into new name scopes. If `None`, the whole graph
      is exported.
    as_text: If `True`, writes the `MetaGraphDef` (and, with
      `save_debug_info`, the `GraphDebugInfo`) as an ASCII proto.
    unbound_inputs_col_name: Optional `string`. If provided, a string collection
      with the given name will be added to the returned `MetaGraphDef`,
      containing the names of tensors that must be remapped when importing the
      `MetaGraphDef`. Note: when `export_scope`, `clear_extraneous_savers` or
      `clear_devices` is set, this collection is cleared and repopulated in
      `graph` itself.
    clear_devices: Boolean which controls whether to clear device information
      before exporting the graph.
    saver_def: `SaverDef` protocol buffer.
    clear_extraneous_savers: Remove any Saver-related information from the
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
      removed while exporting the GraphDef.
    save_debug_info: If `True` and `filename` is given, also save the
      `GraphDebugInfo` to a separate file in the same directory as
      `filename`. Its name is `filename` with the extension (if any) replaced
      by `.debug`, e.g. `model.meta` -> `model.debug`.
    **kwargs: Optional keyed arguments, including meta_info_def and
        collection_list.

  Returns:
    A `MetaGraphDef` proto and dictionary of `Variables` in the exported
    name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
    ValueError: When executing in Eager mode and either `graph_def` or `graph`
      is undefined.
  """
  if context.executing_eagerly() and not (graph_def is not None and
                                          graph is not None):
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      # Filter and rewrite the caller-supplied GraphDef into a fresh proto.
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # No GraphDef supplied: build one from `graph`, keeping only the nodes
      # selected by `export_scope` / `exclude_nodes`.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      for key in sorted(graph._nodes_by_id):
        value = graph._nodes_by_id[key]
        # pylint: enable=protected-access
        if not _should_include_node(value.name, export_scope, exclude_nodes):
          continue
        node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                             clear_devices=clear_devices)
        graph_def.node.extend([node_def])
        if value.outputs:
          # Record the static shape of every output in `_output_shapes`.
          assert "_output_shapes" not in graph_def.node[-1].attr
          graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
              output.get_shape().as_proto() for output in value.outputs])
        bytesize += value.node_def.ByteSize()
        if bytesize >= (1 << 31) or bytesize < 0:
          raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)
    if save_debug_info:
      # The debug file replaces the extension of `filename` with ".debug".
      # `compat.as_str` keeps a bytes `filename` from being rendered as
      # "b'...'" in the resulting path.
      name, _ = os.path.splitext(filename)
      debug_filename = compat.as_str(name) + ".debug"

      # Gets the operation from the graph by the name. Excludes variable nodes,
      # so only the nodes in the frozen models are included.
      # TODO(liufengdb): fix this for functions.
      ops_to_export = []
      for node in scoped_meta_graph_def.graph_def.node:
        scoped_op_name = ops.prepend_name_scope(node.name, export_scope)
        ops_to_export.append(("", graph.get_operation_by_name(scoped_op_name)))

      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          ops_to_export)

      graph_io.write_graph(
          graph_debug_info,
          os.path.dirname(debug_filename),
          os.path.basename(debug_filename),
          as_text=as_text)

  return scoped_meta_graph_def, var_list
1071
1072
def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  The subgraph under `from_scope` is exported with `export_scoped_meta_graph`
  and then imported under `to_scope` with `import_scoped_meta_graph`.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`, the
      default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  from_graph = from_graph or ops.get_default_graph()
  to_graph = to_graph or ops.get_default_graph()

  # Copying a scope onto itself within a single graph is rejected.
  if from_graph is to_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph.")

  # Only the exported MetaGraphDef is needed; the returned variables are the
  # ones created by the import into `to_scope`.
  orig_meta_graph, _ = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  return import_scoped_meta_graph(orig_meta_graph,
                                  graph=to_graph,
                                  import_scope=to_scope)
1105