1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15r"""Converts checkpoint variables into Const ops in a standalone GraphDef file.
16
17This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
18variable values stored in a checkpoint file, and output a GraphDef with all of
19the variable ops converted into const ops containing the values of the
20variables.
21
22It's useful to do this when we need to load a single file in C++, especially in
23environments like mobile or embedded where we may not have access to the
24RestoreTensor ops and file loading calls that they rely on.
25
26An example of command-line usage is:
27bazel build tensorflow/python/tools:freeze_graph && \
28bazel-bin/tensorflow/python/tools/freeze_graph \
29--input_graph=some_graph_def.pb \
30--input_checkpoint=model.ckpt-8361242 \
31--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
32
33You can also look at freeze_graph_test.py for an example of how to use it.
34
35"""
36from __future__ import absolute_import
37from __future__ import division
38from __future__ import print_function
39
40import argparse
41import re
42import sys
43
44from google.protobuf import text_format
45
46from tensorflow.core.framework import graph_pb2
47from tensorflow.core.protobuf import saver_pb2
48from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
49from tensorflow.python import pywrap_tensorflow
50from tensorflow.python.client import session
51from tensorflow.python.framework import graph_util
52from tensorflow.python.framework import importer
53from tensorflow.python.platform import app
54from tensorflow.python.platform import gfile
55from tensorflow.python.saved_model import loader
56from tensorflow.python.saved_model import tag_constants
57from tensorflow.python.tools import saved_model_utils
58from tensorflow.python.training import checkpoint_management
59from tensorflow.python.training import saver as saver_lib
60
61
62def _has_no_variables(sess):
63  """Determines if the graph has any variables.
64
65  Args:
66    sess: TensorFlow Session.
67
68  Returns:
69    Bool.
70  """
71  for op in sess.graph.get_operations():
72    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
73      return False
74  return True
75
76
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Variable values are restored into a session from exactly one source, chosen
  in this order: `input_saver_def`, `input_meta_graph_def`,
  `input_saved_model_dir`, and otherwise directly from the variables listed in
  `input_checkpoint`.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`. Nothing is
      written when it is empty.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing. Only run on the `input_meta_graph_def` and
                       plain-checkpoint loading paths.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: List of tag(s) of the MetaGraphDef to load (optional,
                      `None` is treated as an empty list).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2)

  Returns:
    The frozen `GraphDef` on success. Failure paths print a message and return
    an integer instead: -1 when the checkpoint does not exist, when no output
    node names were supplied, or when building the Saver fails and partition
    variables were detected; 0 when building the Saver fails and the graph
    holds no variable ops (e.g. it was frozen previously).

  Raises:
    TypeError: Re-raised from `Saver` construction when neither of the two
      recognised causes above applies.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format. The
  # checkpoint is not consulted at all when loading from a SavedModel dir.
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_partition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      # Map every checkpoint key that also names a tensor in the graph.
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_partition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        # Models that have been frozen previously do not contain Variables.
        elif _has_no_variables(sess):
          print("No variables were found in this model. It is likely the model "
                "was frozen previously. You cannot freeze a graph twice.")
          return 0
        else:
          # Bare `raise` keeps the original traceback intact.
          raise

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    # The meta graph's GraphDef takes precedence over `input_graph_def` as the
    # graph whose variables get converted.
    if input_meta_graph_def:
      graph_def_to_freeze = input_meta_graph_def.graph_def
    else:
      graph_def_to_freeze = input_graph_def

    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def_to_freeze,
        output_node_names.replace(" ", "").split(","),
        variable_names_whitelist=variable_names_whitelist,
        variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
240
241
def _parse_input_graph_proto(input_graph, input_binary):
  """Loads a `GraphDef` proto from a binary or text-format file.

  Args:
    input_graph: Path of the graph file to read.
    input_binary: Truthy to parse the file as a serialized binary proto,
      falsy to parse it as text format.

  Returns:
    The parsed `GraphDef`, or -1 (after printing a message) when the file
    does not exist.
  """
  if not gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  with gfile.GFile(input_graph, "rb" if input_binary else "r") as graph_file:
    serialized = graph_file.read()

  graph_def = graph_pb2.GraphDef()
  if input_binary:
    graph_def.ParseFromString(serialized)
  else:
    text_format.Merge(serialized, graph_def)
  return graph_def
255
256
def _parse_input_meta_graph_proto(input_graph, input_binary):
  """Parses input tensorflow graph into MetaGraphDef proto.

  Args:
    input_graph: Path of the meta graph file to read.
    input_binary: Truthy to parse the file as a serialized binary proto,
      falsy to parse it as text format.

  Returns:
    The parsed `MetaGraphDef`, or -1 (after printing a message) when the file
    does not exist.
  """
  if not gfile.Exists(input_graph):
    print("Input meta graph file '" + input_graph + "' does not exist!")
    return -1
  input_meta_graph_def = MetaGraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_graph, mode) as f:
    if input_binary:
      input_meta_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_meta_graph_def)
  # Close the quote around the file name (the message used to end unbalanced).
  print("Loaded meta graph file '" + input_graph + "'")
  return input_meta_graph_def
271
272
def _parse_input_saver_proto(input_saver, input_binary):
  """Loads a `SaverDef` proto from a binary or text-format file.

  Args:
    input_saver: Path of the saver file to read.
    input_binary: Truthy to parse the file as a serialized binary proto,
      falsy to parse it as text format.

  Returns:
    The parsed `SaverDef`, or -1 (after printing a message) when the file
    does not exist.
  """
  if not gfile.Exists(input_saver):
    print("Input saver file '" + input_saver + "' does not exist!")
    return -1

  with gfile.GFile(input_saver, "rb" if input_binary else "r") as saver_file:
    serialized = saver_file.read()

  parsed_saver = saver_pb2.SaverDef()
  if input_binary:
    parsed_saver.ParseFromString(serialized)
  else:
    text_format.Merge(serialized, parsed_saver)
  return parsed_saver
286
287
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Loads the requested input files into protos and delegates the actual
  freezing to `freeze_graph_with_def_protos`.

  Args:
    input_graph: A `GraphDef` file to load. Ignored when
      `input_saved_model_dir` is given.
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
      The same setting is applied to `input_meta_graph` and `input_saver`.
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
                           variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2).

  Returns:
    Whatever `freeze_graph_with_def_protos` returns (the frozen `GraphDef` on
    success, or its integer status code on failure), or -1 when one of the
    input graph / meta graph / saver files does not exist.
  """
  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)

  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)

  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)

  # The parse helpers report a missing file by printing a message and
  # returning the integer -1 instead of a proto. Stop here rather than handing
  # that sentinel on, where it would fail later with an unrelated error.
  for parsed in (input_graph_def, input_meta_graph_def, input_saver_def):
    if isinstance(parsed, int):
      return parsed

  # Tolerate a missing tag string; the callee treats `None` as "no tags".
  saved_model_tag_list = (
      saved_model_tags.replace(" ", "").split(",")
      if saved_model_tags else None)

  return freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_blacklist,
      input_meta_graph_def,
      input_saved_model_dir,
      saved_model_tag_list,
      checkpoint_version=checkpoint_version)
364
365
def main(unused_args, flags):
  """Runs `freeze_graph` with the values held by the parsed `flags`.

  Args:
    unused_args: Ignored.
    flags: Parsed command-line flags (an argparse namespace or equivalent).

  Returns:
    -1 (after printing a message) when `flags.checkpoint_version` is neither
    1 nor 2; otherwise None.
  """
  requested_version = flags.checkpoint_version
  if requested_version not in (1, 2):
    print("Invalid checkpoint version (must be '1' or '2'): %d" %
          requested_version)
    return -1

  # Translate the integer flag into the matching SaverDef enum value.
  if requested_version == 1:
    checkpoint_version = saver_pb2.SaverDef.V1
  else:
    checkpoint_version = saver_pb2.SaverDef.V2

  freeze_graph(
      input_graph=flags.input_graph,
      input_saver=flags.input_saver,
      input_binary=flags.input_binary,
      input_checkpoint=flags.input_checkpoint,
      output_node_names=flags.output_node_names,
      restore_op_name=flags.restore_op_name,
      filename_tensor_name=flags.filename_tensor_name,
      output_graph=flags.output_graph,
      clear_devices=flags.clear_devices,
      initializer_nodes=flags.initializer_nodes,
      variable_names_whitelist=flags.variable_names_whitelist,
      variable_names_blacklist=flags.variable_names_blacklist,
      input_meta_graph=flags.input_meta_graph,
      input_saved_model_dir=flags.input_saved_model_dir,
      saved_model_tags=flags.saved_model_tags,
      checkpoint_version=checkpoint_version)
382
def run_main():
  """Builds the command-line parser, parses the flags and runs `main`.

  Arguments the parser does not recognise are forwarded to `app.run` together
  with the program name.
  """
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")

  def add_string_flag(flag_name, default_value, help_text):
    """Registers a string-valued flag."""
    parser.add_argument(
        flag_name, type=str, default=default_value, help=help_text)

  def add_bool_flag(flag_name, default_value, help_text):
    """Registers a boolean flag; the bare flag without a value means True."""
    parser.add_argument(
        flag_name,
        nargs="?",
        const=True,
        type="bool",
        default=default_value,
        help=help_text)

  # Flags are registered in the same order as they appear in the help output.
  add_string_flag("--input_graph", "", "TensorFlow \'GraphDef\' file to load.")
  add_string_flag("--input_saver", "", "TensorFlow saver file to load.")
  add_string_flag(
      "--input_checkpoint", "", "TensorFlow variables file to load.")
  parser.add_argument(
      "--checkpoint_version",
      type=int,
      default=2,
      help="Tensorflow variable file format")
  add_string_flag("--output_graph", "", "Output \'GraphDef\' file name.")
  add_bool_flag(
      "--input_binary", False,
      "Whether the input files are in binary format.")
  add_string_flag(
      "--output_node_names", "",
      "The name of the output nodes, comma separated.")
  add_string_flag(
      "--restore_op_name", "save/restore_all", """\
      The name of the master restore operator. Deprecated, unused by updated \
      loading code.
      """)
  add_string_flag(
      "--filename_tensor_name", "save/Const:0", """\
      The name of the tensor holding the save path. Deprecated, unused by \
      updated loading code.
      """)
  add_bool_flag(
      "--clear_devices", True, "Whether to remove device specifications.")
  add_string_flag(
      "--initializer_nodes", "",
      "Comma separated list of initializer nodes to run before freezing.")
  add_string_flag(
      "--variable_names_whitelist", "", """\
      Comma separated list of variables to convert to constants. If specified, \
      only those variables will be converted to constants.\
      """)
  add_string_flag(
      "--variable_names_blacklist", "", """\
      Comma separated list of variables to skip converting to constants.\
      """)
  add_string_flag(
      "--input_meta_graph", "", "TensorFlow \'MetaGraphDef\' file to load.")
  add_string_flag(
      "--input_saved_model_dir", "",
      "Path to the dir with TensorFlow \'SavedModel\' file and variables.")
  add_string_flag(
      "--saved_model_tags", "serve", """\
      Group of tag(s) of the MetaGraphDef to load, in string format,\
      separated by \',\'. For tag-set contains multiple tags, all tags \
      must be passed in.\
      """)

  flags, unparsed = parser.parse_known_args()

  def main_with_flags(unused_args):
    """Adapts `main` to the single-argument callable that `app.run` invokes."""
    return main(unused_args, flags)

  app.run(main=main_with_flags, argv=[sys.argv[0]] + unparsed)
489
# Run the command-line tool only when this file is executed as a script.
if __name__ == "__main__":
  run_main()
492