1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Apply graph_transforms tool to MetaGraphDefs. 17 18@@meta_graph_transform 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25 26import re as _re 27 28from tensorflow.core.framework import graph_pb2 as _graph_pb2 29from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 30from tensorflow.python.client import session as _session 31from tensorflow.python.framework import graph_util as _graph_util 32from tensorflow.python.framework import importer as _importer 33from tensorflow.python.framework import ops as _ops 34from tensorflow.python.platform import tf_logging as _logging 35from tensorflow.python.saved_model import constants as _saved_model_constants 36from tensorflow.python.training import saver as _saver_lib 37from tensorflow.python.util import compat as _compat 38from tensorflow.tools import graph_transforms as _graph_transforms 39 40 41_FREEZE_GRAPH_TRANSFORM = 'freeze_graph' 42_SPARSIFY_GATHER_TRANSFORM = 'sparsify_gather' 43 44 45def _op_name(tensor_name): 46 """Get the op name from a tensor name.""" 47 # control dependency inputs start with ^ 48 if tensor_name[0] == '^': 49 tensor_name = tensor_name[1:] 50 if ':' in tensor_name: 51 op_name, _ = tensor_name.split(':') 52 return op_name 53 return tensor_name 54 55 56def _get_shared_init_op(initializer_names): 57 """Obtain the shared init op name, if it exists. 58 59 Args: 60 initializer_names: Dictionary of the "infrastructural" nodes (initializers, 61 save and restore ops, etc.). The keys in this dictionary 62 indicate the collection where these nodes were obtained from. 63 64 Returns: 65 A string indicating the shared init op name or none if None if none exists. 66 """ 67 return_value = initializer_names.get(_saved_model_constants.MAIN_OP_KEY, None) 68 if not return_value: 69 return_value = initializer_names.get( 70 _saved_model_constants.LEGACY_INIT_OP_KEY, None) 71 return str(return_value[0]) if return_value else None 72 73 74def _gtt_transforms(graph_def, input_names, output_names, initializer_names, 75 transforms): 76 """Pass through gtt transforms, applying them to the graph_def. 77 78 Args: 79 graph_def: A GraphDef proto to be transformed. 80 input_names: Names of input nodes. 81 output_names: Names of output nodes. 82 initializer_names: Dictionary of the "infrastructural" nodes (initializers, 83 save and restore ops, etc.) that should be retained even if they are not 84 transitively reachable from output nodes. The keys in this dictionary 85 indicate the collection where these nodes were obtained from. 86 transforms: A list of strings naming the graph transforms to be applied in 87 order. 88 Returns: 89 The transformed GraphDef. 90 """ 91 if not transforms: 92 transformed_graph_def = _graph_pb2.GraphDef() 93 transformed_graph_def.CopyFrom(graph_def) 94 return transformed_graph_def 95 96 initializer_names_flat = sorted( 97 [k for l in initializer_names.values() for k in l]) 98 all_output_names = output_names + initializer_names_flat 99 return _graph_transforms.TransformGraph(graph_def, input_names, 100 all_output_names, transforms) 101 102 103def _freeze_transform(graph_def, output_names, initializer_names, saver_def, 104 checkpoint_path): 105 """Handle the freeze transform. 106 107 Determine which initializer nodes should be retained by the freeze transform. 108 Retain those nodes and return an updated dictionary containing them. 109 110 Args: 111 graph_def: A GraphDef proto to be transformed. 112 output_names: Names of output nodes. 113 initializer_names: Dictionary of the "infrastructural" nodes (initializers, 114 save and restore ops, etc.). The keys in this dictionary 115 indicate the collection where these nodes were obtained from. 116 saver_def: A SaverDef proto used for restoring a checkpoint during freezing, 117 if needed (default None). 118 checkpoint_path: A path to a checkpoint to restore during freezing, 119 if needed (default None). 120 121 Returns: 122 A tuple containing the GraphDef and a Dict of pruned initializer nodes. 123 """ 124 table_initializers = initializer_names.get(_ops.GraphKeys.TABLE_INITIALIZERS, 125 []) 126 shared_init_op = _get_shared_init_op(initializer_names) 127 128 graph_def = _freeze_graph_with_def_protos(graph_def, output_names, 129 table_initializers, shared_init_op, 130 saver_def, checkpoint_path) 131 pruned_initializer_names = {} 132 # Freeze graph prunes all initializers and shared init nodes that are not 133 # explicitly maintained. Create new initializer_names dictionary to reflect 134 # this. 135 if table_initializers: 136 pruned_initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = ( 137 table_initializers) 138 if _saved_model_constants.LEGACY_INIT_OP_KEY in initializer_names: 139 pruned_initializer_names[_saved_model_constants.LEGACY_INIT_OP_KEY] = ( 140 initializer_names[_saved_model_constants.LEGACY_INIT_OP_KEY]) 141 if _saved_model_constants.MAIN_OP_KEY in initializer_names: 142 pruned_initializer_names[_saved_model_constants.MAIN_OP_KEY] = ( 143 initializer_names[_saved_model_constants.MAIN_OP_KEY]) 144 return (graph_def, pruned_initializer_names) 145 146 147def _clean_save_and_restore(graph_def, op, removed_op_names): 148 """Clean the specified save and restore op. 149 150 Updates the dtypes attribute of the save / restore op and the associated name 151 and shape tensors to remove entries for variables that have been removed. 152 153 Args: 154 graph_def: A GraphDef proto to be transformed. 155 op: The save or restore op to update. 156 removed_op_names: List of op names that have been removed. 157 """ 158 name = op.name + '/tensor_names' 159 shape = op.name + '/shape_and_slices' 160 name_op = _find_op(graph_def, name) 161 shape_op = _find_op(graph_def, shape) 162 name_op_value_tensor = name_op.attr['value'].tensor 163 shape_op_value_tensor = shape_op.attr['value'].tensor 164 names = [] 165 shapes = [] 166 dtypes = [] 167 for index, value in enumerate(name_op_value_tensor.string_val): 168 if not _is_removed(_compat.as_str(value), removed_op_names): 169 names.append(value) 170 shapes.append(shape_op_value_tensor.string_val[index]) 171 dtypes.append(op.attr['dtypes'].list.type[index]) 172 name_op_value_tensor.string_val[:] = names 173 name_op_value_tensor.tensor_shape.dim[0].size = len(names) 174 shape_op_value_tensor.string_val[:] = shapes 175 shape_op_value_tensor.tensor_shape.dim[0].size = len(shapes) 176 op.attr['dtypes'].list.type[:] = dtypes 177 178 if not name_op.attr['_output_shapes'].list.shape: 179 name_op.attr['_output_shapes'].list.shape.add() 180 name_op.attr['_output_shapes'].list.shape[0].dim.add() 181 name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(names) 182 183 if not shape_op.attr['_output_shapes'].list.shape: 184 shape_op.attr['_output_shapes'].list.shape.add() 185 shape_op.attr['_output_shapes'].list.shape[0].dim.add() 186 shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(shapes) 187 188 189def _sparsify_gather_transform(graph_def, input_names, output_names, 190 initializer_names, checkpoint_path): 191 """Handle the sparsify gather transform. 192 193 Provides the transform the checkpoint and keeps track of the newly created 194 initializer nodes. 195 196 Args: 197 graph_def: A GraphDef proto to be transformed. 198 input_names: Names of input nodes. 199 output_names: Names of output nodes. 200 initializer_names: Dictionary of the "infrastructural" nodes (initializers, 201 save and restore ops, etc.). The keys in this dictionary 202 indicate the collection where these nodes were obtained from. 203 checkpoint_path: A path to a checkpoint. 204 205 Returns: 206 A tuple containing the GraphDef and a Dict of updated initializer nodes. 207 Raises: 208 ValueError: if the restore_op_name does not have the expected format. 209 """ 210 # Ensure that sparsify_shared_init_op is unique. 211 sparsify_shared_init_op = 'sparify_gather_init_op' 212 while _find_op(graph_def, sparsify_shared_init_op): 213 sparsify_shared_init_op += '_1' 214 215 input_flag = '' 216 if checkpoint_path: 217 input_flag = 'input_checkpoint="%s", ' % checkpoint_path 218 219 sparsify_cmd = [ 220 'sparsify_gather(%sgroup_init_node="%s")' % (input_flag, 221 sparsify_shared_init_op) 222 ] 223 224 starting_op_names = [node.name for node in graph_def.node] 225 226 graph_def = _gtt_transforms(graph_def, input_names, output_names, 227 initializer_names, sparsify_cmd) 228 ending_op_names = [node.name for node in graph_def.node] 229 removed_op_names = list(set(starting_op_names) - set(ending_op_names)) 230 removed_op_names.sort() 231 232 for op_index, op_name in enumerate(removed_op_names): 233 op_name_parts = op_name.rsplit('/', 1) 234 # Remove part to get the checkpoint names used by the saver. 235 if len(op_name_parts) == 2 and op_name_parts[1].startswith('part_'): 236 removed_op_names[op_index] = op_name_parts[0] 237 else: 238 removed_op_names[op_index] = op_name 239 240 # Obtain newly created table inits from gtt sparsify transform. 241 added_table_inits = [] 242 for index, node in enumerate(graph_def.node): 243 if node.name == sparsify_shared_init_op: 244 added_table_inits = [n.lstrip('^') for n in node.input] 245 246 table_initializers = initializer_names.get( 247 _ops.GraphKeys.TABLE_INITIALIZERS, []) 248 table_initializers.extend(added_table_inits) 249 initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = table_initializers 250 251 del graph_def.node[index] 252 break 253 254 # Add inits to existing shared init op. 255 node = _find_op(graph_def, _get_shared_init_op(initializer_names)) 256 for init in added_table_inits: 257 node.input.append('^' + init) 258 259 # Update saver. 260 for node in graph_def.node: 261 if node.name.endswith('SaveV2'): 262 _clean_save_and_restore(graph_def, node, removed_op_names) 263 264 return (graph_def, initializer_names) 265 266 267def _do_transforms(graph_def, 268 input_names, 269 output_names, 270 initializer_names, 271 transforms, 272 saver_def=None, 273 checkpoint_path=None): 274 """Apply requested transforms to a GraphDef, including freezing. 275 276 Args: 277 graph_def: A GraphDef proto to be transformed. 278 input_names: Names of input nodes. 279 output_names: Names of output nodes. 280 initializer_names: Dictionary of the "infrastructural" nodes (initializers, 281 save and restore ops, etc.) that should be retained even if they are not 282 transitively reachable from output nodes. The keys in this dictionary 283 indicate the collection where these nodes were obtained from. 284 transforms: A list of strings naming the graph transforms to be applied in 285 order. These transform names are exactly those supported by the Graph 286 Transform Tool, with the addition of the 'freeze_graph' and 287 'sparsify_gather' transforms. 288 saver_def: A SaverDef proto used for restoring a checkpoint during freezing, 289 if needed (default None). 290 checkpoint_path: A path to a checkpoint to restore during freezing, 291 if needed (default None). 292 Returns: 293 A tuple containing the GraphDef and a Dict of updated initializer nodes. 294 """ 295 transformed_graph_def = _graph_pb2.GraphDef() 296 transformed_graph_def.CopyFrom(graph_def) 297 transformed_initializer_names = initializer_names.copy() 298 299 if not transforms: 300 return transformed_graph_def, transformed_initializer_names 301 302 current_gtt_transforms = [] 303 for t in transforms: 304 if t == _FREEZE_GRAPH_TRANSFORM: 305 transformed_graph_def = _gtt_transforms( 306 transformed_graph_def, input_names, output_names, 307 transformed_initializer_names, current_gtt_transforms) 308 output_node_names = [_op_name(x) for x in output_names] 309 transformed_graph_def, transformed_initializer_names = _freeze_transform( 310 transformed_graph_def, output_node_names, 311 transformed_initializer_names, saver_def, checkpoint_path) 312 current_gtt_transforms = [] 313 elif t == _SPARSIFY_GATHER_TRANSFORM: 314 transformed_graph_def = _gtt_transforms( 315 transformed_graph_def, input_names, output_names, 316 transformed_initializer_names, current_gtt_transforms) 317 transformed_graph_def, transformed_initializer_names = ( 318 _sparsify_gather_transform( 319 transformed_graph_def, input_names, output_names, 320 transformed_initializer_names, checkpoint_path)) 321 current_gtt_transforms = [] 322 else: 323 current_gtt_transforms.append(t) 324 325 transformed_graph_def = _gtt_transforms( 326 transformed_graph_def, input_names, output_names, 327 transformed_initializer_names, current_gtt_transforms) 328 return transformed_graph_def, transformed_initializer_names 329 330 331def _connect_to_shared_init_op(graph_def, shared_init_op_name, 332 nodes_to_connect): 333 """Creates a new shared init node that is connected to via control deps. 334 335 Args: 336 graph_def: The GraphDef proto to add the shared init node to. 337 shared_init_op_name: A string specifying the name of the shared init node to 338 create. 339 nodes_to_connect: A list of strings specifying the names of nodes to connect 340 to the shared node via control dependencies. 341 """ 342 if nodes_to_connect: 343 init_op = graph_def.node.add() 344 init_op.name = shared_init_op_name 345 init_op.op = 'NoOp' 346 init_op.input.extend(['^' + i for i in nodes_to_connect]) 347 348 349# forked and modified from freeze_graph.py 350def _freeze_graph_with_def_protos(input_graph_def, output_node_names, 351 initializer_names, shared_init_op_name, 352 input_saver_def, input_checkpoint): 353 """Converts all variables in a graph and checkpoint into constants. 354 355 During this process, we need to retain certain initializer nodes (e.g. table 356 initializer nodes). Instead of determining which dependencies 357 of the shared initializer node (e.g. group_deps) to keep, we 358 reconstruct the connections between the individual initializer nodes and 359 the shared node after freezing the graph. 360 361 Args: 362 input_graph_def: A GraphDef proto to be frozen. 363 output_node_names: Names of output nodes. 364 initializer_names: Names of initializer nodes to keep. 365 shared_init_op_name: The name of the shared initializer node to connect the 366 nodes in initializer names to. 367 input_saver_def: A SaverDef proto used for restoring a checkpoint. 368 input_checkpoint: A path to a checkpoint to restore. 369 370 Returns: 371 A frozen GraphDef. 372 """ 373 374 with _ops.Graph().as_default(): 375 _ = _importer.import_graph_def(input_graph_def, name='') 376 377 with _session.Session() as sess: 378 saver = _saver_lib.Saver(saver_def=input_saver_def) 379 saver.restore(sess, input_checkpoint) 380 output_graph_def = _graph_util.convert_variables_to_constants( 381 sess, input_graph_def, output_node_names + initializer_names) 382 _connect_to_shared_init_op(output_graph_def, shared_init_op_name, 383 initializer_names) 384 return output_graph_def 385 386 387def _find_all_mandatory_retain_ops(base_meta_graph_def): 388 """Identify all infrastructural Ops, to ensure that they are retained. 389 390 We need to retain infrastructural Ops (init and saver stuff), in addition 391 to the desired outputs. 392 393 For now we retain *all* save and restore ops, variable initializers, 394 table initializers, and main init ops. 395 This means that strip_unused_nodes will not remove unused variables. 396 397 Args: 398 base_meta_graph_def: a GraphDef proto in which to identify nodes to retain. 399 400 Returns: 401 A dictionary corresponding to the nodes associated with each collection 402 that are to be retained. 403 """ 404 # TODO(b/63447631): implement variable stripping. 405 406 initializer_names = {} 407 408 # Primary SaverDef and SAVERS collection 409 saver_defs = [] 410 if base_meta_graph_def.HasField('saver_def'): 411 saver_defs.append(base_meta_graph_def.saver_def) 412 saver_defs.extend(_get_all_protos_from_collection( 413 base_meta_graph_def, _ops.GraphKeys.SAVERS)) 414 for saver_def in saver_defs: 415 savers = initializer_names.get(_ops.GraphKeys.SAVERS, []) 416 savers.extend([ 417 saver_def.filename_tensor_name, saver_def.save_tensor_name, 418 saver_def.restore_op_name 419 ]) 420 initializer_names[_ops.GraphKeys.SAVERS] = savers 421 422 # Variable initializers 423 variable_collections = [ 424 _ops.GraphKeys.GLOBAL_VARIABLES, 425 _ops.GraphKeys.TRAINABLE_VARIABLES, 426 _ops.GraphKeys.MOVING_AVERAGE_VARIABLES, 427 _ops.GraphKeys.LOCAL_VARIABLES, 428 _ops.GraphKeys.MODEL_VARIABLES] 429 for var_coll in variable_collections: 430 variables = _get_all_protos_from_collection(base_meta_graph_def, var_coll) 431 var_init_names = [v.initializer_name for v in variables] 432 if var_init_names: 433 # Sanity check to ensure we don't overwrite dictionary entries. 434 assert var_coll not in initializer_names 435 initializer_names[var_coll] = var_init_names 436 437 # Table initializers 438 op_names = _get_all_node_names_from_collection( 439 base_meta_graph_def, _ops.GraphKeys.TABLE_INITIALIZERS) 440 if op_names: 441 # Sanity check to ensure we don't overwrite dictionary entries. 442 assert _ops.GraphKeys.TABLE_INITIALIZERS not in initializer_names 443 table_initializers = [t for t in op_names] 444 initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = table_initializers 445 446 # Various init ops 447 various_init_op_collections = [_saved_model_constants.LEGACY_INIT_OP_KEY, 448 _saved_model_constants.MAIN_OP_KEY, 449 _ops.GraphKeys.INIT_OP, 450 _ops.GraphKeys.LOCAL_INIT_OP, 451 _ops.GraphKeys.READY_OP, 452 _ops.GraphKeys.READY_FOR_LOCAL_INIT_OP] 453 for op_coll in various_init_op_collections: 454 op_name = _get_single_node_name_from_collection( 455 base_meta_graph_def, op_coll) 456 if op_name: 457 # Sanity check to ensure we don't overwrite dictionary entries. 458 assert op_coll not in initializer_names 459 initializer_names[op_coll] = [op_name] 460 return initializer_names 461 462 463def _add_pruned_collection(base_meta_graph_def, meta_graph_def, 464 collection_name, removed_op_names): 465 """Copy collection to the transformed MetaGraphDef, omitting removed items.""" 466 467 base_collection = base_meta_graph_def.collection_def[collection_name] 468 collection = meta_graph_def.collection_def[collection_name] 469 470 if base_collection.HasField('any_list'): 471 for any_value in base_collection.any_list.value: 472 # just search the serialized proto as a string 473 if not _is_removed_mentioned(any_value.value, removed_op_names): 474 copied_any = collection.any_list.value.add() 475 copied_any.CopyFrom(any_value) 476 elif base_collection.HasField('bytes_list'): 477 collection.bytes_list.value[:] = [ 478 s for s in base_collection.bytes_list.value 479 if not _is_removed_mentioned(s, removed_op_names)] 480 _logging.info( 481 'In collection %s, nodes excluded are: %s', collection_name, 482 sorted([ 483 s for s in base_collection.bytes_list.value 484 if _is_removed_mentioned(s, removed_op_names) 485 ])) 486 elif base_collection.HasField('node_list'): 487 collection.node_list.value[:] = [ 488 s for s in base_collection.node_list.value 489 if not _is_removed(s, removed_op_names)] 490 else: 491 collection.CopyFrom(base_collection) 492 493 494def _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names): 495 """Copy the Saver into the transformed MetaGraphDef, if valid. 496 497 Currently this copies the Saver as is, after verifying that none of the 498 referenced Save & Restore ops were removed. A future version will modify 499 the Save and Restore ops themselves as needed to account for removed 500 Variables. 501 502 Args: 503 base_meta_graph_def: The untransformed MetaGraphDef. 504 meta_graph_def: The transformed MetaGraphDef being built. 505 removed_op_names: An iterable of names of ops that were removed. 506 """ 507 508 # Note this does surgery on meta_graph_def.graph_def too, so that should have 509 # been copied already. 510 if base_meta_graph_def.HasField('saver_def'): 511 filename_tensor_name = base_meta_graph_def.saver_def.filename_tensor_name 512 save_tensor_name = base_meta_graph_def.saver_def.save_tensor_name 513 restore_op_name = base_meta_graph_def.saver_def.restore_op_name 514 515 _check_tensor_not_removed(filename_tensor_name, removed_op_names) 516 _check_tensor_not_removed(save_tensor_name, removed_op_names) 517 _check_tensor_not_removed(restore_op_name, removed_op_names) 518 519 # TODO(b/63447631): Once we strip unused variables, remove references to 520 # them from save and restore ops. Retain those ops only if they also refer 521 # to retained Variables. See if we can use _clean_save_and_restore() for 522 # this. 523 524 # saver_name, restore_all = restore_op_name.rsplit('/', 1) 525 # if restore_all != 'restore_all': 526 # raise ValueError( 527 # 'SaverDef restore_op_name did not have expected form */restore_all') 528 529 # save_tensor_names_op_name = '{}/SaveV2/tensor_names'.format(saver_name) 530 # restore_tensor_names_op_name = ( 531 # '{}/RestoreV2/tensor_names'.format(saver_name)) 532 533 # save_tensor_names_op = _find_op(meta_graph_def.graph_def, 534 # save_tensor_names_op_name) 535 # save_tensor_names_value_tensor = save_tensor_names_op.attr['value'].tensor 536 # save_tensor_names_value_tensor.string_val[:] = [ 537 # s for s in save_tensor_names_value_tensor.string_val 538 # if not _is_removed(s, removed_op_names)] 539 540 # restore_tensor_names_op = _find_op( 541 # meta_graph_def.graph_def, restore_tensor_names_op_name) 542 # restore_tensor_names_value_tensor = ( 543 # restore_tensor_names_op.attr['value'].tensor) 544 # restore_tensor_names_value_tensor.string_val[:] = [ 545 # s for s in restore_tensor_names_value_tensor.string_val 546 # if not _is_removed(s, removed_op_names)] 547 548 # if (save_tensor_names_value_tensor.string_val 549 # or restore_tensor_names_value_tensor.string_val): 550 meta_graph_def.saver_def.CopyFrom(base_meta_graph_def.saver_def) 551 552 553def _find_op(graph_def, op_name): 554 """Fetch a node from a GraphDef proto by name.""" 555 for node_def in graph_def.node: 556 if node_def.name == op_name: 557 return node_def 558 return None 559 560 561def _add_pruned_signature(base_meta_graph_def, meta_graph_def, 562 signature_name, removed_op_names): 563 """Copy the named signature into the transformed MetaGraphDef, if valid. 564 565 If any input or output mentioned in the signature was removed by the graph 566 transform, the signature is silently omitted from the transformed 567 MetaGraphDef. 568 569 Args: 570 base_meta_graph_def: The untransformed MetaGraphDef. 571 meta_graph_def: The transformed MetaGraphDef being built. 572 signature_name: The name of the signature to copy. 573 removed_op_names: An iterable of names of ops that were removed. 574 """ 575 try: 576 base_signature = base_meta_graph_def.signature_def[signature_name] 577 for key in base_signature.inputs: 578 _check_tensor_not_removed(base_signature.inputs[key].name, 579 removed_op_names) 580 for key in base_signature.outputs: 581 _check_tensor_not_removed(base_signature.outputs[key].name, 582 removed_op_names) 583 meta_graph_def.signature_def[signature_name].CopyFrom(base_signature) 584 except ValueError: 585 # exclude any signature that mentions a removed node 586 pass 587 588 589def _get_single_node_name_from_collection(meta_graph_def, collection_key): 590 """Obtain a node name that is the single element of a collection.""" 591 if collection_key not in meta_graph_def.collection_def: 592 return None 593 collection = meta_graph_def.collection_def[collection_key] 594 if not collection.node_list.value: 595 raise ValueError( 596 'Collection {} is present but type is not node_list.'.format( 597 collection_key)) 598 if len(collection.node_list.value) != 1: 599 raise ValueError( 600 'Collection {} is has {} elements; expected exactly one.'.format( 601 collection_key, collection.bytes_list)) 602 return collection.node_list.value[0] 603 604 605def _get_all_node_names_from_collection(meta_graph_def, collection_key): 606 """Obtain node names from a collection.""" 607 if collection_key not in meta_graph_def.collection_def: 608 return None 609 collection = meta_graph_def.collection_def[collection_key] 610 if not collection.node_list.value: 611 raise ValueError( 612 'Collection {} is present but type is not node_list.'.format( 613 collection_key)) 614 return collection.node_list.value 615 616 617def _get_all_protos_from_collection(meta_graph_def, collection_key): 618 """Obtain node names from a collection.""" 619 if collection_key not in meta_graph_def.collection_def: 620 return [] 621 collection = meta_graph_def.collection_def[collection_key] 622 if not collection.bytes_list.value: 623 raise ValueError( 624 'Collection {} is present but type is not bytes_list.'.format( 625 collection_key)) 626 proto_type = _ops.get_collection_proto_type(collection_key) 627 result = [] 628 for value in collection.bytes_list.value: 629 proto = proto_type() 630 proto.ParseFromString(value) 631 result.append(proto) 632 return result 633 634 635def _is_removed(tensor_name, removed_op_names): 636 """Determine whether the named tensor is an output of a removed op.""" 637 for removed_op_name in removed_op_names: 638 if tensor_name.split(':')[0] == removed_op_name: 639 return True 640 return False 641 642 643def _is_removed_mentioned(s, removed_op_names): 644 """Determine whether any removed op is mentioned in the given object. 645 646 This relies on the string representation of the object. This is used for 647 proto messages that may mention ops by name in nested fields. The string 648 representation of the proto includes those field values, so this string 649 search approach is sufficient. 650 651 Args: 652 s: an object to search for removed op names. 653 removed_op_names: An iterable of names of ops that were removed. 654 655 Returns: 656 True if any removed op is mentioned in the given object, False otherwise. 657 """ 658 # A common approach taken by some of the transforms in gtt is to add new nodes 659 # that have the same prefix as the node they are removing. For example, if 660 # the original node name was /foo, they may remove that node and add in 661 # /foo/bar. This regex ensures that we handle these two nodes 662 # as separate entities. It matches on nodes having names in the form of 663 # '/foo/bar_x' as well as nodes having names in the form of 'foo.' 664 s_names = _re.findall(r'((?:[\/]?[a-zA-Z0-9\_]*)*)', _compat.as_str_any(s)) 665 for removed_op_name in removed_op_names: 666 for s_name in s_names: 667 if s_name.endswith(removed_op_name): 668 return True 669 return False 670 671 672def _check_tensor_not_removed(tensor_name, removed_op_names): 673 """Verify that the named tensor was not removed. 674 675 Args: 676 tensor_name: the name of a tensor to check. 677 removed_op_names: An iterable of names of ops that were removed. 678 679 Raises: 680 ValueError: if the tensor was removed. 681 """ 682 if not tensor_name: 683 raise ValueError('Tensor name should not be empty') 684 if _is_removed(tensor_name, removed_op_names): 685 raise ValueError( 686 'Expected Tensor, but it was removed: {}'.format(tensor_name)) 687 688 689def _add_new_inits_to_collection(meta_graph_def, updated_initializer_names): 690 """Add new inits to collection. 691 692 Args: 693 meta_graph_def: The MetaGraphDef protocol buffer to update. 694 updated_initializer_names: Dictionary of the updated "infrastructural" nodes 695 (initializers, save and restore ops, etc.). The keys in this dictionary 696 indicate the collection where these nodes were obtained from. 697 698 Raises: 699 ValueError: if the tensor was removed. 700 """ 701 # TODO(dzats): Extend this to support all collections. 702 if _ops.GraphKeys.TABLE_INITIALIZERS in updated_initializer_names: 703 orig_table_inits = _get_all_node_names_from_collection( 704 meta_graph_def, _ops.GraphKeys.TABLE_INITIALIZERS) 705 orig_table_inits = orig_table_inits if orig_table_inits else [] 706 updated_table_inits = updated_initializer_names[ 707 _ops.GraphKeys.TABLE_INITIALIZERS] 708 new_table_inits = list(set(updated_table_inits) - set(orig_table_inits)) 709 new_table_inits.sort() 710 meta_graph_def.collection_def[ 711 _ops.GraphKeys.TABLE_INITIALIZERS].node_list.value.extend( 712 new_table_inits) 713 714 715def meta_graph_transform( 716 base_meta_graph_def, input_names, output_names, transforms, tags, 717 checkpoint_path=None): 718 """Apply the Graph Transform tool to a MetaGraphDef. 719 720 Args: 721 base_meta_graph_def: A MetaGraphDef protocol buffer to transform. 722 input_names: Names of input nodes. 723 output_names: Names of output nodes. 724 transforms: A list of strings naming the graph transforms to be applied in 725 order. These transform names are exactly those supported by the Graph 726 Transform Tool, with the addition of the 'freeze_graph' and 727 'sparsify_gather' transforms. 728 tags: A list of tags with which to annotate the transformed MetaGraphDef. 729 checkpoint_path: A path to a checkpoint to restore during freezing, 730 if needed (default None). 731 732 Returns: 733 A new transformed MetaGraphDef protocol buffer. 734 """ 735 meta_graph_def = _meta_graph_pb2.MetaGraphDef() 736 737 initializer_names = _find_all_mandatory_retain_ops(base_meta_graph_def) 738 739 transformed_graph_def, updated_initializer_names = _do_transforms( 740 base_meta_graph_def.graph_def, input_names, output_names, 741 initializer_names, transforms, base_meta_graph_def.saver_def, 742 checkpoint_path) 743 744 meta_graph_def.graph_def.CopyFrom(transformed_graph_def) 745 meta_graph_def.meta_info_def.CopyFrom(base_meta_graph_def.meta_info_def) 746 meta_graph_def.meta_info_def.ClearField('tags') 747 for tag in tags: 748 meta_graph_def.meta_info_def.tags.append(tag) 749 750 base_op_names = [_compat.as_str(node.name) 751 for node in base_meta_graph_def.graph_def.node] 752 retained_op_names = [_compat.as_str(node.name) 753 for node in meta_graph_def.graph_def.node] 754 removed_op_names = set(base_op_names) - set(retained_op_names) 755 _logging.info('Node names in base graph: %s', sorted(base_op_names)) 756 _logging.info('Node names retained: %s', sorted(retained_op_names)) 757 _logging.info('Node names removed: %s', sorted(removed_op_names)) 758 759 # Copy saver, excluding any pruned nodes if graph was not frozen. 760 # TODO(b/63447631): Revisit this once the problem is addressed. Currently 761 # _add_pruned_saver assumes that the save and restore nodes have not been 762 # removed but freeze_graph (correctly) removes them. 763 if _FREEZE_GRAPH_TRANSFORM not in transforms: 764 _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names) 765 766 # Copy collections, excluding any pruned nodes 767 for collection_name in base_meta_graph_def.collection_def: 768 _add_pruned_collection( 769 base_meta_graph_def, meta_graph_def, collection_name, 770 removed_op_names) 771 772 # Append newly added initializers to collection. 773 _add_new_inits_to_collection(meta_graph_def, updated_initializer_names) 774 775 # Copy signature_defs, excluding any pruned nodes 776 for signature_name in base_meta_graph_def.signature_def: 777 _add_pruned_signature( 778 base_meta_graph_def, meta_graph_def, signature_name, 779 removed_op_names) 780 781 return meta_graph_def 782