1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15r"""Converts checkpoint variables into Const ops in a standalone GraphDef file. 16 17This script is designed to take a GraphDef proto, a SaverDef proto, and a set of 18variable values stored in a checkpoint file, and output a GraphDef with all of 19the variable ops converted into const ops containing the values of the 20variables. 21 22It's useful to do this when we need to load a single file in C++, especially in 23environments like mobile or embedded where we may not have access to the 24RestoreTensor ops and file loading calls that they rely on. 25 26An example of command-line usage is: 27bazel build tensorflow/python/tools:freeze_graph && \ 28bazel-bin/tensorflow/python/tools/freeze_graph \ 29--input_graph=some_graph_def.pb \ 30--input_checkpoint=model.ckpt-8361242 \ 31--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax 32 33You can also look at freeze_graph_test.py for an example of how to use it. 34 35""" 36from __future__ import absolute_import 37from __future__ import division 38from __future__ import print_function 39 40import argparse 41import re 42import sys 43 44from google.protobuf import text_format 45 46from tensorflow.core.framework import graph_pb2 47from tensorflow.core.protobuf import saver_pb2 48from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef 49from tensorflow.python import pywrap_tensorflow 50from tensorflow.python.client import session 51from tensorflow.python.framework import graph_util 52from tensorflow.python.framework import importer 53from tensorflow.python.platform import app 54from tensorflow.python.platform import gfile 55from tensorflow.python.saved_model import loader 56from tensorflow.python.saved_model import tag_constants 57from tensorflow.python.tools import saved_model_utils 58from tensorflow.python.training import checkpoint_management 59from tensorflow.python.training import saver as saver_lib 60 61 62def _has_no_variables(sess): 63 """Determines if the graph has any variables. 64 65 Args: 66 sess: TensorFlow Session. 67 68 Returns: 69 Bool. 70 """ 71 for op in sess.graph.get_operations(): 72 if op.type.startswith("Variable") or op.type.endswith("VariableOp"): 73 return False 74 return True 75 76 77def freeze_graph_with_def_protos(input_graph_def, 78 input_saver_def, 79 input_checkpoint, 80 output_node_names, 81 restore_op_name, 82 filename_tensor_name, 83 output_graph, 84 clear_devices, 85 initializer_nodes, 86 variable_names_whitelist="", 87 variable_names_blacklist="", 88 input_meta_graph_def=None, 89 input_saved_model_dir=None, 90 saved_model_tags=None, 91 checkpoint_version=saver_pb2.SaverDef.V2): 92 """Converts all variables in a graph and checkpoint into constants. 93 94 Args: 95 input_graph_def: A `GraphDef`. 96 input_saver_def: A `SaverDef` (optional). 97 input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking 98 priority. Typically the result of `Saver.save()` or that of 99 `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or 100 V1/V2. 101 output_node_names: The name(s) of the output nodes, comma separated. 102 restore_op_name: Unused. 103 filename_tensor_name: Unused. 104 output_graph: String where to write the frozen `GraphDef`. 105 clear_devices: A Bool whether to remove device specifications. 106 initializer_nodes: Comma separated string of initializer nodes to run before 107 freezing. 108 variable_names_whitelist: The set of variable names to convert (optional, by 109 default, all variables are converted). 110 variable_names_blacklist: The set of variable names to omit converting 111 to constants (optional). 112 input_meta_graph_def: A `MetaGraphDef` (optional), 113 input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file 114 and variables (optional). 115 saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to 116 load, in string format (optional). 117 checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1 118 or saver_pb2.SaverDef.V2) 119 120 Returns: 121 Location of the output_graph_def. 122 """ 123 del restore_op_name, filename_tensor_name # Unused by updated loading code. 124 125 # 'input_checkpoint' may be a prefix if we're using Saver V2 format 126 if (not input_saved_model_dir and 127 not checkpoint_management.checkpoint_exists(input_checkpoint)): 128 print("Input checkpoint '" + input_checkpoint + "' doesn't exist!") 129 return -1 130 131 if not output_node_names: 132 print("You need to supply the name of a node to --output_node_names.") 133 return -1 134 135 # Remove all the explicit device specifications for this node. This helps to 136 # make the graph more portable. 137 if clear_devices: 138 if input_meta_graph_def: 139 for node in input_meta_graph_def.graph_def.node: 140 node.device = "" 141 elif input_graph_def: 142 for node in input_graph_def.node: 143 node.device = "" 144 145 if input_graph_def: 146 _ = importer.import_graph_def(input_graph_def, name="") 147 with session.Session() as sess: 148 if input_saver_def: 149 saver = saver_lib.Saver( 150 saver_def=input_saver_def, write_version=checkpoint_version) 151 saver.restore(sess, input_checkpoint) 152 elif input_meta_graph_def: 153 restorer = saver_lib.import_meta_graph( 154 input_meta_graph_def, clear_devices=True) 155 restorer.restore(sess, input_checkpoint) 156 if initializer_nodes: 157 sess.run(initializer_nodes.replace(" ", "").split(",")) 158 elif input_saved_model_dir: 159 if saved_model_tags is None: 160 saved_model_tags = [] 161 loader.load(sess, saved_model_tags, input_saved_model_dir) 162 else: 163 var_list = {} 164 reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint) 165 var_to_shape_map = reader.get_variable_to_shape_map() 166 167 # List of all partition variables. Because the condition is heuristic 168 # based, the list could include false positives. 169 all_parition_variable_names = [ 170 tensor.name.split(":")[0] 171 for op in sess.graph.get_operations() 172 for tensor in op.values() 173 if re.search(r"/part_\d+/", tensor.name) 174 ] 175 has_partition_var = False 176 177 for key in var_to_shape_map: 178 try: 179 tensor = sess.graph.get_tensor_by_name(key + ":0") 180 if any(key in name for name in all_parition_variable_names): 181 has_partition_var = True 182 except KeyError: 183 # This tensor doesn't exist in the graph (for example it's 184 # 'global_step' or a similar housekeeping element) so skip it. 185 continue 186 var_list[key] = tensor 187 188 try: 189 saver = saver_lib.Saver( 190 var_list=var_list, write_version=checkpoint_version) 191 except TypeError as e: 192 # `var_list` is required to be a map of variable names to Variable 193 # tensors. Partition variables are Identity tensors that cannot be 194 # handled by Saver. 195 if has_partition_var: 196 print("Models containing partition variables cannot be converted " 197 "from checkpoint files. Please pass in a SavedModel using " 198 "the flag --input_saved_model_dir.") 199 return -1 200 # Models that have been frozen previously do not contain Variables. 201 elif _has_no_variables(sess): 202 print("No variables were found in this model. It is likely the model " 203 "was frozen previously. You cannot freeze a graph twice.") 204 return 0 205 else: 206 raise e 207 208 saver.restore(sess, input_checkpoint) 209 if initializer_nodes: 210 sess.run(initializer_nodes.replace(" ", "").split(",")) 211 212 variable_names_whitelist = ( 213 variable_names_whitelist.replace(" ", "").split(",") 214 if variable_names_whitelist else None) 215 variable_names_blacklist = ( 216 variable_names_blacklist.replace(" ", "").split(",") 217 if variable_names_blacklist else None) 218 219 if input_meta_graph_def: 220 output_graph_def = graph_util.convert_variables_to_constants( 221 sess, 222 input_meta_graph_def.graph_def, 223 output_node_names.replace(" ", "").split(","), 224 variable_names_whitelist=variable_names_whitelist, 225 variable_names_blacklist=variable_names_blacklist) 226 else: 227 output_graph_def = graph_util.convert_variables_to_constants( 228 sess, 229 input_graph_def, 230 output_node_names.replace(" ", "").split(","), 231 variable_names_whitelist=variable_names_whitelist, 232 variable_names_blacklist=variable_names_blacklist) 233 234 # Write GraphDef to file if output path has been given. 235 if output_graph: 236 with gfile.GFile(output_graph, "wb") as f: 237 f.write(output_graph_def.SerializeToString()) 238 239 return output_graph_def 240 241 242def _parse_input_graph_proto(input_graph, input_binary): 243 """Parses input tensorflow graph into GraphDef proto.""" 244 if not gfile.Exists(input_graph): 245 print("Input graph file '" + input_graph + "' does not exist!") 246 return -1 247 input_graph_def = graph_pb2.GraphDef() 248 mode = "rb" if input_binary else "r" 249 with gfile.GFile(input_graph, mode) as f: 250 if input_binary: 251 input_graph_def.ParseFromString(f.read()) 252 else: 253 text_format.Merge(f.read(), input_graph_def) 254 return input_graph_def 255 256 257def _parse_input_meta_graph_proto(input_graph, input_binary): 258 """Parses input tensorflow graph into MetaGraphDef proto.""" 259 if not gfile.Exists(input_graph): 260 print("Input meta graph file '" + input_graph + "' does not exist!") 261 return -1 262 input_meta_graph_def = MetaGraphDef() 263 mode = "rb" if input_binary else "r" 264 with gfile.GFile(input_graph, mode) as f: 265 if input_binary: 266 input_meta_graph_def.ParseFromString(f.read()) 267 else: 268 text_format.Merge(f.read(), input_meta_graph_def) 269 print("Loaded meta graph file '" + input_graph) 270 return input_meta_graph_def 271 272 273def _parse_input_saver_proto(input_saver, input_binary): 274 """Parses input tensorflow Saver into SaverDef proto.""" 275 if not gfile.Exists(input_saver): 276 print("Input saver file '" + input_saver + "' does not exist!") 277 return -1 278 mode = "rb" if input_binary else "r" 279 with gfile.GFile(input_saver, mode) as f: 280 saver_def = saver_pb2.SaverDef() 281 if input_binary: 282 saver_def.ParseFromString(f.read()) 283 else: 284 text_format.Merge(f.read(), saver_def) 285 return saver_def 286 287 288def freeze_graph(input_graph, 289 input_saver, 290 input_binary, 291 input_checkpoint, 292 output_node_names, 293 restore_op_name, 294 filename_tensor_name, 295 output_graph, 296 clear_devices, 297 initializer_nodes, 298 variable_names_whitelist="", 299 variable_names_blacklist="", 300 input_meta_graph=None, 301 input_saved_model_dir=None, 302 saved_model_tags=tag_constants.SERVING, 303 checkpoint_version=saver_pb2.SaverDef.V2): 304 """Converts all variables in a graph and checkpoint into constants. 305 306 Args: 307 input_graph: A `GraphDef` file to load. 308 input_saver: A TensorFlow Saver file. 309 input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt. 310 input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking 311 priority. Typically the result of `Saver.save()` or that of 312 `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or 313 V1/V2. 314 output_node_names: The name(s) of the output nodes, comma separated. 315 restore_op_name: Unused. 316 filename_tensor_name: Unused. 317 output_graph: String where to write the frozen `GraphDef`. 318 clear_devices: A Bool whether to remove device specifications. 319 initializer_nodes: Comma separated list of initializer nodes to run before 320 freezing. 321 variable_names_whitelist: The set of variable names to convert (optional, by 322 default, all variables are converted), 323 variable_names_blacklist: The set of variable names to omit converting 324 to constants (optional). 325 input_meta_graph: A `MetaGraphDef` file to load (optional). 326 input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and 327 variables (optional). 328 saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to 329 load, in string format. 330 checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1 331 or saver_pb2.SaverDef.V2). 332 Returns: 333 String that is the location of frozen GraphDef. 334 """ 335 input_graph_def = None 336 if input_saved_model_dir: 337 input_graph_def = saved_model_utils.get_meta_graph_def( 338 input_saved_model_dir, saved_model_tags).graph_def 339 elif input_graph: 340 input_graph_def = _parse_input_graph_proto(input_graph, input_binary) 341 input_meta_graph_def = None 342 if input_meta_graph: 343 input_meta_graph_def = _parse_input_meta_graph_proto( 344 input_meta_graph, input_binary) 345 input_saver_def = None 346 if input_saver: 347 input_saver_def = _parse_input_saver_proto(input_saver, input_binary) 348 freeze_graph_with_def_protos( 349 input_graph_def, 350 input_saver_def, 351 input_checkpoint, 352 output_node_names, 353 restore_op_name, 354 filename_tensor_name, 355 output_graph, 356 clear_devices, 357 initializer_nodes, 358 variable_names_whitelist, 359 variable_names_blacklist, 360 input_meta_graph_def, 361 input_saved_model_dir, 362 saved_model_tags.replace(" ", "").split(","), 363 checkpoint_version=checkpoint_version) 364 365 366def main(unused_args, flags): 367 if flags.checkpoint_version == 1: 368 checkpoint_version = saver_pb2.SaverDef.V1 369 elif flags.checkpoint_version == 2: 370 checkpoint_version = saver_pb2.SaverDef.V2 371 else: 372 print("Invalid checkpoint version (must be '1' or '2'): %d" % 373 flags.checkpoint_version) 374 return -1 375 freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary, 376 flags.input_checkpoint, flags.output_node_names, 377 flags.restore_op_name, flags.filename_tensor_name, 378 flags.output_graph, flags.clear_devices, flags.initializer_nodes, 379 flags.variable_names_whitelist, flags.variable_names_blacklist, 380 flags.input_meta_graph, flags.input_saved_model_dir, 381 flags.saved_model_tags, checkpoint_version) 382 383def run_main(): 384 parser = argparse.ArgumentParser() 385 parser.register("type", "bool", lambda v: v.lower() == "true") 386 parser.add_argument( 387 "--input_graph", 388 type=str, 389 default="", 390 help="TensorFlow \'GraphDef\' file to load.") 391 parser.add_argument( 392 "--input_saver", 393 type=str, 394 default="", 395 help="TensorFlow saver file to load.") 396 parser.add_argument( 397 "--input_checkpoint", 398 type=str, 399 default="", 400 help="TensorFlow variables file to load.") 401 parser.add_argument( 402 "--checkpoint_version", 403 type=int, 404 default=2, 405 help="Tensorflow variable file format") 406 parser.add_argument( 407 "--output_graph", 408 type=str, 409 default="", 410 help="Output \'GraphDef\' file name.") 411 parser.add_argument( 412 "--input_binary", 413 nargs="?", 414 const=True, 415 type="bool", 416 default=False, 417 help="Whether the input files are in binary format.") 418 parser.add_argument( 419 "--output_node_names", 420 type=str, 421 default="", 422 help="The name of the output nodes, comma separated.") 423 parser.add_argument( 424 "--restore_op_name", 425 type=str, 426 default="save/restore_all", 427 help="""\ 428 The name of the master restore operator. Deprecated, unused by updated \ 429 loading code. 430 """) 431 parser.add_argument( 432 "--filename_tensor_name", 433 type=str, 434 default="save/Const:0", 435 help="""\ 436 The name of the tensor holding the save path. Deprecated, unused by \ 437 updated loading code. 438 """) 439 parser.add_argument( 440 "--clear_devices", 441 nargs="?", 442 const=True, 443 type="bool", 444 default=True, 445 help="Whether to remove device specifications.") 446 parser.add_argument( 447 "--initializer_nodes", 448 type=str, 449 default="", 450 help="Comma separated list of initializer nodes to run before freezing.") 451 parser.add_argument( 452 "--variable_names_whitelist", 453 type=str, 454 default="", 455 help="""\ 456 Comma separated list of variables to convert to constants. If specified, \ 457 only those variables will be converted to constants.\ 458 """) 459 parser.add_argument( 460 "--variable_names_blacklist", 461 type=str, 462 default="", 463 help="""\ 464 Comma separated list of variables to skip converting to constants.\ 465 """) 466 parser.add_argument( 467 "--input_meta_graph", 468 type=str, 469 default="", 470 help="TensorFlow \'MetaGraphDef\' file to load.") 471 parser.add_argument( 472 "--input_saved_model_dir", 473 type=str, 474 default="", 475 help="Path to the dir with TensorFlow \'SavedModel\' file and variables.") 476 parser.add_argument( 477 "--saved_model_tags", 478 type=str, 479 default="serve", 480 help="""\ 481 Group of tag(s) of the MetaGraphDef to load, in string format,\ 482 separated by \',\'. For tag-set contains multiple tags, all tags \ 483 must be passed in.\ 484 """) 485 flags, unparsed = parser.parse_known_args() 486 487 my_main = lambda unused_args: main(unused_args, flags) 488 app.run(main=my_main, argv=[sys.argv[0]] + unparsed) 489 490if __name__ == '__main__': 491 run_main() 492