1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helpers to convert variables to constants in TensorFlow 2.0.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import numpy as np 23 24from tensorflow.core.framework import attr_value_pb2 25from tensorflow.core.framework import graph_pb2 26from tensorflow.core.framework import tensor_shape_pb2 27from tensorflow.core.framework import variable_pb2 28from tensorflow.core.protobuf import config_pb2 29from tensorflow.core.protobuf import meta_graph_pb2 30from tensorflow.core.protobuf import rewriter_config_pb2 31from tensorflow.python.eager import context 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import graph_util 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import tensor_util 36from tensorflow.python.grappler import tf_optimizer 37from tensorflow.python.ops import array_ops 38from tensorflow.python.training.saver import export_meta_graph 39from tensorflow.python.util import lazy_loader 40from tensorflow.python.util import object_identity 41 42# Lazy load the single eager module to avoid introducing new dependencies for 43# graph_util:convert_variables_to_constants (eg in 44# tensorflow/contrib/session_bundle:session_bundle_py_test). 45wrap_function = lazy_loader.LazyLoader( 46 "wrap_function", globals(), 47 "tensorflow.python.eager.wrap_function") 48 49_CONDITIONAL_OPS = set(["If", "StatelessIf"]) 50_LOOP_OPS = set(["While", "StatelessWhile"]) 51_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS) 52 53 54class _TensorData( 55 collections.namedtuple("_TensorData", ["numpy", "dtype", "index"])): 56 """Data about a tensor that was converted to a constant.""" 57 __slots__ = () 58 59 @property 60 def dtype_attr(self): 61 return attr_value_pb2.AttrValue(type=self.dtype) 62 63 64class _EndPoint(collections.namedtuple("_EndPoint", ["convertible", "index"])): 65 """An endpoint in a graph.""" 66 __slots__ = () 67 68 def __str__(self): 69 return "{}[{}]".format(self.convertible, self.index) 70 71 72class _Edge(collections.namedtuple("_Edge", ["source", "destination"])): 73 """A directed graph edge.""" 74 __slots__ = () 75 76 def __str__(self): 77 return "{} -> {}".format(self.source, self.destination) 78 79 80class _Convertible(object): 81 """An entity that can have variables converted to constants.""" 82 83 def __init__(self, enclosing_graph): 84 self._enclosing_graph = enclosing_graph 85 self._outgoing_edges = [] 86 self._converted_self = None 87 88 def converted_self(self): 89 """A copy of this Convertible to be modified during conversion. 90 91 Returns: 92 Implementations should return the copied instance, which in turn should 93 be contained in converted_enclosing_graph(). This instance is the one that 94 will be modified during conversion. Its main use will be in the 95 implementations of convert_variable_to_constant(). 96 """ 97 raise NotImplementedError() 98 99 def convert_variable_to_constant(self, incoming_edge, tensor_data): 100 """Converts a variable in this Convertible and its dependencies. 101 102 This method should make sure that a converted copy of itself is present in 103 the converted graph, and that all Convertibles depending on this one also go 104 through the same process. 105 106 Args: 107 incoming_edge: The graph edge into this Convertible that is being 108 converted to a constant. 109 tensor_data: The tensor representing the constant. 110 """ 111 raise NotImplementedError() 112 113 def create_edges(self): 114 """Calls add_outgoing_edge for all edges known to this Convertible. 115 116 This is used to build the graph dependencies, so that conversion of 117 variables to constants can be properly propagated through the graph. Usually 118 this method will call add_outgoing_edge() to all the Convertible inputs. 119 """ 120 raise NotImplementedError() 121 122 def add_outgoing_edge(self, edge): 123 """Adds an outgoing edge to the Convertible's list of edges. 124 125 Args: 126 edge: The outgoing edge (its source should be 'self'). 127 """ 128 self._outgoing_edges.append(edge) 129 130 @property 131 def converted_enclosing_graph(self): 132 """The graph being converted.""" 133 return self._enclosing_graph.converted_self() 134 135 @property 136 def outgoing_edges(self): 137 """The list of edges starting at this Convertible.""" 138 return self._outgoing_edges 139 140 141class _Function(_Convertible): 142 """A library function Convertible. 143 144 Edges into functions are edges from node _inputs_ into function _inputs_: 145 Functions get their input from their callers, not from node outputs, and the 146 callers in turn get those values as inputs. 147 """ 148 149 def __init__(self, function, enclosing_graph): 150 super(_Function, self).__init__(enclosing_graph) 151 self._function = function 152 self._nodes = { 153 n.name: 154 _Node.new(node=n, function=self, enclosing_graph=enclosing_graph) 155 for n in function.node_def 156 } 157 158 def __str__(self): 159 return self.function.signature.name 160 161 @property 162 def function(self): 163 return self._function 164 165 @property 166 def nodes(self): 167 return self._nodes 168 169 def converted_self(self): 170 """The Function copy to be converted. 171 172 The copy will be renamed according to the graph's converted_function_name 173 map, to ensure the name does not match anything currently in TensorFlow's 174 function cache. 175 176 Returns: 177 The function instance to be converted. 178 """ 179 if self._converted_self is None: 180 old_name = self.function.signature.name 181 new_name = self._enclosing_graph.converted_function_names[old_name] 182 self.converted_enclosing_graph.rename_function(old_name, new_name) 183 self._converted_self = self.converted_enclosing_graph.functions[new_name] 184 return self._converted_self 185 186 def convert_variable_to_constant(self, incoming_edge, tensor_data): 187 """Converts one function argument into a constant. 188 189 Args: 190 incoming_edge: The edge into the argument to be converted. 191 tensor_data: The constant value. 192 """ 193 function = self.converted_self().function 194 index = incoming_edge.destination.index 195 function.signature.input_arg[index].type = tensor_data.dtype 196 197 for edge in self.outgoing_edges: 198 if edge.source.index == index: 199 edge.destination.convertible.convert_variable_to_constant( 200 edge, tensor_data) 201 202 def create_edges(self): 203 for n in self._nodes.values(): 204 n.create_edges() 205 206 207class _Node(_Convertible): 208 """A Convertible NodeDef.""" 209 210 def __init__(self, node, function, enclosing_graph): 211 super(_Node, self).__init__(enclosing_graph) 212 self._node = node 213 self._function = function 214 215 def __str__(self): 216 return self._node.name 217 218 @staticmethod 219 def new(node, function, enclosing_graph): 220 """Creates a new _Node base on its operation type.""" 221 if node.op in ["VariableV2", "VarHandleOp", "Placeholder"]: 222 return _VarHandle(node, function, enclosing_graph) 223 elif node.op == "Case": 224 return _Case(node, function, enclosing_graph) 225 elif node.op == "Merge": 226 return _Merge(node, function, enclosing_graph) 227 elif node.op == "PartitionedCall": 228 return _PartitionedCall(node, function, enclosing_graph) 229 elif node.op == "ReadVariableOp": 230 return _ReadVariable(node, function, enclosing_graph) 231 elif node.op == "ResourceGather": 232 return _ResourceGather(node, function, enclosing_graph) 233 elif node.op == "ResourceGatherNd": 234 return _ResourceGatherNd(node, function, enclosing_graph) 235 elif node.op in ["If", "StatelessIf"]: 236 return _If(node, function, enclosing_graph) 237 elif node.op in ["While", "StatelessWhile"]: 238 return _While(node, function, enclosing_graph) 239 elif node.op in [ 240 "Enter", "Exit", "Identity", "NextIteration", "Switch", "_SwitchN"]: 241 return _Intermediate(node, function, enclosing_graph) 242 else: 243 return _Node(node, function, enclosing_graph) 244 245 @property 246 def node(self): 247 return self._node 248 249 @property 250 def container(self): 251 """The node container (either a graph or a function).""" 252 if self._function is not None: 253 return self._function.function 254 return self._enclosing_graph.graph_def 255 256 def converted_self(self): 257 """The NodeDef to be converted. 258 259 Returns: 260 The NodeDef to be converted, which can come from either a graph for a 261 function. Derived classes should call this (via 'super') to make sure the 262 node is retrieved from the right place. 263 """ 264 if self._converted_self is None: 265 source = self._function or self._enclosing_graph 266 self._converted_self = source.converted_self().nodes[self._node.name] 267 return self._converted_self 268 269 def convert_variable_to_constant(self, incoming_edge, tensor_data): 270 pass 271 272 def create_edges(self): 273 for index, name in enumerate(self._node.input): 274 # Discard edges from control inputs. 275 if name[0] == "^": 276 continue 277 source = self.resolve_input(name) 278 source.convertible.add_outgoing_edge( 279 _Edge(source, _EndPoint(self, index))) 280 281 def resolve_input(self, input_name): 282 """Resolves an input into its _EndPoint. 283 284 A NodeDef's input name can refer to either global NodeDefs (in the 285 GraphDef's node list), a NodeDef in a function's node list, or a Function 286 (in the GraphDef's function library). The name can also carry semantic 287 information, depending on whether it starts with "^". This method handles 288 all that logic in order to find the object to which the input name refers 289 to. 290 291 Args: 292 input_name: The input name to resolve. 293 294 Returns: 295 The object referred to by 'input_name'. 296 """ 297 298 # The logic below oversimplifies the semantics, but is good enough for the 299 # purposes of converting to constants. The introduction of new types of 300 # operations may change this, forcing the code to be more generic. 301 # 302 # In particular, we are assuming that the lack of an index suffix means 303 # ":0", when it could mean "all the outputs of a node." This works now 304 # because converting to constants relies very little on output types, and 305 # when it does it specializes its treatment in dedicated classes. 306 name_elts = input_name.split(":") 307 source_name = name_elts[0] 308 if source_name[0] == "^": 309 source_name = source_name[1:] 310 source_index = 0 311 if len(name_elts) > 1 and name_elts[-1].isnumeric(): 312 source_index = int(name_elts[-1]) 313 314 if self._function is None: 315 return _EndPoint(self._enclosing_graph.nodes[source_name], source_index) 316 317 if source_index != 0 or source_name in self._function.nodes: 318 return _EndPoint(self._function.nodes[source_name], source_index) 319 320 inputs = [i.name for i in self._function.function.signature.input_arg] 321 return _EndPoint(self._function, inputs.index(source_name)) 322 323 def update_dtype(self, attr_name, index, dtype): 324 """Changes the type of a given input. 325 326 Args: 327 attr_name: The NodeDef attribute containing the type to change. 328 index: The index of the input type to change. 329 dtype: The type to change to. 330 """ 331 attr = self._node.attr[attr_name] 332 num_types = 0 333 # Check for various 'oneof' possibilities, and update the type if 334 # index in range. 335 if attr.HasField("list"): 336 types = attr.list.type 337 num_types = len(types) 338 if num_types > index: 339 types[index] = dtype 340 return 341 elif attr.HasField("type"): 342 num_types = 1 343 if index == 0: 344 attr.type = dtype 345 return 346 raise ValueError( 347 "Index %d out of range for node(%s).attr(%s), which has %d elements." % 348 (index, self._node.name, attr_name, num_types)) 349 350 351class _Intermediate(_Node): 352 """Specialization of _Node to intermediate ops.""" 353 354 def convert_variable_to_constant(self, incoming_edge, tensor_data): 355 node = self.converted_self() 356 node.update_dtype("T", incoming_edge.destination.index, tensor_data.dtype) 357 if "_output_shapes" in node.node.attr: 358 del node.node.attr["_output_shapes"] 359 for edge in self.outgoing_edges: 360 edge.destination.convertible.convert_variable_to_constant( 361 edge, tensor_data) 362 363 364class _Merge(_Node): 365 """Specialization of _Node to Merge ops.""" 366 367 def convert_variable_to_constant(self, incoming_edge, tensor_data): 368 # The Merge operation has a single type for all its inputs, the number of 369 # which is reflected in the "N" attribute. For the time being, we assume 370 # that unilaterally changing all of them at once is ok. 371 super(_Merge, self).convert_variable_to_constant( 372 _Edge(incoming_edge.source, 373 _Edge(incoming_edge.destination.convertible, 0)), tensor_data) 374 375 376class _VarHandle(_Node): 377 """Specialization of _Node to VarHandleOp.""" 378 379 def convert_variable_to_constant(self, incoming_edge, tensor_data): 380 tensor_proto = tensor_util.make_tensor_proto(tensor_data.numpy, 381 tensor_data.dtype, 382 tensor_data.numpy.shape) 383 384 node = self.converted_self().node 385 node.Clear() 386 node.name = self._node.name 387 node.op = "Const" 388 node.attr["dtype"].CopyFrom(tensor_data.dtype_attr) 389 node.attr["value"].tensor.CopyFrom(tensor_proto) 390 391 for edge in self.outgoing_edges: 392 edge.destination.convertible.convert_variable_to_constant( 393 edge, tensor_data) 394 395 396class _ResourceGather(_Node): 397 """Specialization of _Node to ResourceGather.""" 398 399 def convert_variable_to_constant(self, incoming_edge, tensor_data): 400 # We currently skip the conversion if this is inside a function. 401 if self._function is not None: 402 return 403 if self._node.attr["batch_dims"].i != 0: 404 raise ValueError("batch_dims != 0 is not supported by freeze_graph.") 405 axis_node_name = self._node.name + "/axis" 406 axis_dtype = self._node.attr["Tindices"] 407 axis_data = np.array(self._node.attr["batch_dims"].i) 408 output_axis_node = self.converted_self().container.node.add() 409 output_axis_node.name = axis_node_name 410 output_axis_node.op = "Const" 411 output_axis_node.attr["dtype"].CopyFrom(axis_dtype) 412 tensor = tensor_util.make_tensor_proto( 413 axis_data, dtype=axis_dtype.type, shape=axis_data.shape) 414 output_axis_node.attr["value"].tensor.CopyFrom(tensor) 415 416 output_node = self.converted_self().node 417 output_node.Clear() 418 output_node.name = self._node.name 419 output_node.op = "GatherV2" 420 output_node.input.extend( 421 [self._node.input[0], self._node.input[1], axis_node_name]) 422 output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"]) 423 output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"]) 424 output_node.attr["Taxis"].CopyFrom(axis_dtype) 425 if "_class" in self._node.attr: 426 output_node.attr["_class"].CopyFrom(self._node.attr["_class"]) 427 428 429class _ResourceGatherNd(_Node): 430 """Specialization of _Node to ResourceGatherNd.""" 431 432 def convert_variable_to_constant(self, incoming_edge, tensor_data): 433 output_node = self.converted_self().node 434 output_node.Clear() 435 output_node.name = self._node.name 436 output_node.op = "GatherNd" 437 output_node.input.extend([self._node.input[0], self._node.input[1]]) 438 output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"]) 439 output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"]) 440 if "_class" in self._node.attr: 441 output_node.attr["_class"].CopyFrom(self._node.attr["_class"]) 442 443 444class _ReadVariable(_Node): 445 """Specialization of _Node to ReadVariableOp.""" 446 447 def convert_variable_to_constant(self, incoming_edge, tensor_data): 448 node = self.converted_self().node 449 node.Clear() 450 node.name = self._node.name 451 node.op = "Identity" 452 453 node.input.append(self._node.input[0]) 454 node.attr["T"].CopyFrom(self._node.attr["dtype"]) 455 if "_class" in self._node.attr: 456 node.attr["_class"].CopyFrom(self._node.attr["_class"]) 457 458 # If the ReadVariableOp is part of a function, then every node having the 459 # ReadVariableOp one as its input will refer to it using a ":value" 460 # syntax. We need to change that to ":output". 461 if self._function is not None: 462 for edge in self.outgoing_edges: 463 index = edge.destination.index 464 dest = edge.destination.convertible.converted_self() 465 if isinstance(dest, _Node): 466 input_name_parts = dest.node.input[index].split(":") 467 if len(input_name_parts) > 1 and input_name_parts[1] == "value": 468 input_name_parts[1] = "output" 469 dest.node.input[index] = ":".join(input_name_parts) 470 471 472class _FunctionCaller(_Node): 473 """A base class for Convertibles that reference functions.""" 474 475 def __init__(self, node, function, enclosing_graph, first_function_input, 476 type_attribute, function_attributes): 477 """Initializes a _FunctionCaller. 478 479 Args: 480 node: As in _Node. 481 function: As in _Node. 482 enclosing_graph: As in _Node. 483 first_function_input: The index of the first NodeDef input that is tied to 484 the function inputs. It is assumed that the rest of the NodeDef inputs 485 map one to one to function inputs. 486 type_attribute: The name of the NodeDef attribute that defines the input 487 types. It is assumed that the types listed here map one-to-one with the 488 function inputs (that is, they do _not_ specify types for inputs that 489 are not passed to functions). 490 function_attributes: The names of the NodeDef attributes containing 491 references to functions. 492 """ 493 super(_FunctionCaller, self).__init__(node, function, enclosing_graph) 494 self._first_function_input = first_function_input 495 self._type_attribute = type_attribute 496 self._function_attributes = function_attributes 497 498 def converted_self(self): 499 if self._converted_self is None: 500 node = super(_FunctionCaller, self).converted_self().node 501 converted_names = self._enclosing_graph.converted_function_names 502 for attr_name in self._function_attributes: 503 attr = node.attr[attr_name] 504 if attr.HasField("func"): 505 attr.func.name = converted_names[attr.func.name] 506 elif attr.HasField("list"): 507 for func in attr.list.func: 508 func.name = converted_names[func.name] 509 return self._converted_self 510 511 def convert_variable_to_constant(self, incoming_edge, tensor_data): 512 node = self.converted_self() 513 index = incoming_edge.destination.index 514 if index >= self._first_function_input: 515 node.update_dtype(self._type_attribute, 516 index - self._first_function_input, tensor_data.dtype) 517 518 # The loop below is reasonable but not correct in general: 519 # The outgoing edges going into the functions are correct, because the 520 # inputs map to the function inputs. But the edges going into other nodes do 521 # not take into account the logic of the body function, which may do 522 # arbitrary things to the node's output: 523 # 524 # while x < 0: 525 # return y 526 # 527 # In this case, the node's ":0" output may map to its ":1 input". For the 528 # time being, then, we only process edges into functions. 529 for edge in self.outgoing_edges: 530 dest = edge.destination.convertible 531 if edge.source.index == index and isinstance(dest, _Function): 532 dest.convert_variable_to_constant(edge, tensor_data) 533 534 def create_edges(self): 535 """Creates edges related to a function caller. 536 537 Edges from a function caller to its called functions are always edges from 538 _inputs_ to _inputs_: a FunctionDef input is given by the caller, based on 539 its own inputs. 540 """ 541 super(_FunctionCaller, self).create_edges() 542 for attr_name in self._function_attributes: 543 attr = self._node.attr[attr_name] 544 if attr.HasField("func"): 545 function = self._enclosing_graph.functions[attr.func.name] 546 for index in range(len(self._node.input) - self._first_function_input): 547 self.add_outgoing_edge( 548 _Edge( 549 _EndPoint(self, index + self._first_function_input), 550 _EndPoint(function, index))) 551 elif attr.HasField("list"): 552 for func in attr.list.func: 553 function = self._enclosing_graph.functions[func.name] 554 for index in range( 555 len(self._node.input) - self._first_function_input): 556 self.add_outgoing_edge( 557 _Edge( 558 _EndPoint(self, index + self._first_function_input), 559 _EndPoint(function, index))) 560 561 562class _If(_FunctionCaller): 563 """Specialization of _Node to If-like operations.""" 564 565 def __init__(self, node, function, enclosing_graph): 566 super(_If, self).__init__( 567 node, 568 function, 569 enclosing_graph, 570 first_function_input=1, 571 type_attribute="Tin", 572 function_attributes=["then_branch", "else_branch"]) 573 574 575class _Case(_FunctionCaller): 576 """Specialization of _Node to Case-like operations.""" 577 578 def __init__(self, node, function, enclosing_graph): 579 super(_Case, self).__init__( 580 node, 581 function, 582 enclosing_graph, 583 first_function_input=1, 584 type_attribute="Tin", 585 function_attributes=["branches"]) 586 587 588class _PartitionedCall(_FunctionCaller): 589 """Specialization of _Node to PartitionedCall-like operations.""" 590 591 def __init__(self, node, function, enclosing_graph): 592 super(_PartitionedCall, self).__init__( 593 node, 594 function, 595 enclosing_graph, 596 first_function_input=0, 597 type_attribute="Tin", 598 function_attributes=["f"]) 599 600 601class _While(_FunctionCaller): 602 """Specialization of _Node to While-like operations.""" 603 604 def __init__(self, node, function, enclosing_graph): 605 super(_While, self).__init__( 606 node, 607 function, 608 enclosing_graph, 609 first_function_input=0, 610 type_attribute="T", 611 function_attributes=["body", "cond"]) 612 613 def convert_variable_to_constant(self, incoming_edge, tensor_data): 614 super(_While, self).convert_variable_to_constant(incoming_edge, tensor_data) 615 node = self.converted_self() 616 if node.node.attr["output_shapes"].list.shape: 617 node.node.attr["output_shapes"].list.shape[ 618 incoming_edge.destination.index].CopyFrom( 619 tensor_shape_pb2.TensorShapeProto(dim=[ 620 tensor_shape_pb2.TensorShapeProto.Dim(size=dim) 621 for dim in tensor_data.numpy.shape 622 ])) 623 624 # The while's body inputs and outputs have the same type, so here we can go 625 # ahead and change that function's output type. 626 body_name = self._node.attr["body"].func.name 627 body = self._enclosing_graph.functions[body_name].converted_self().function 628 body.signature.output_arg[ 629 incoming_edge.destination.index].type = tensor_data.dtype 630 631 632class _GraphDef(_Convertible): 633 """A convertible GraphDef.""" 634 635 def __init__(self, graph_def): 636 super(_GraphDef, self).__init__(enclosing_graph=None) 637 self._graph_def = graph_def 638 self._nodes = { 639 n.name: _Node.new(node=n, function=None, enclosing_graph=self) 640 for n in graph_def.node 641 } 642 self._functions = { 643 f.signature.name: _Function(f, enclosing_graph=self) 644 for f in graph_def.library.function 645 } 646 self.create_edges() 647 self._converted_function_names = None 648 649 @property 650 def graph_def(self): 651 return self._graph_def 652 653 @property 654 def nodes(self): 655 return self._nodes 656 657 @property 658 def functions(self): 659 return self._functions 660 661 @property 662 def converted_function_names(self): 663 """Map from original to new function names. 664 665 In order to avoid conflicts (two functions with the same name, one converted 666 and one not), we need to change the name of every converted function to 667 something that is hopefully unique. 668 669 Returns: 670 Map from original to new suggested function names. 671 """ 672 if self._converted_function_names is None: 673 parsed_names = [] # List of (id, base_name, original_name) 674 for name in self.functions: 675 elements = name.rsplit("_", 1) 676 if len(elements) == 2 and elements[1].isnumeric(): 677 parsed_names.append((int(elements[1]), elements[0], name)) 678 else: 679 parsed_names.append((-1, name, name)) 680 self._converted_function_names = { 681 name: "{}_frozen_{}".format(base_name, ops.uid()) 682 for (_, base_name, name) in sorted(parsed_names) 683 } 684 685 return self._converted_function_names 686 687 def rename_function(self, old_name, new_name): 688 func = self.functions.pop(old_name) 689 func.function.signature.name = new_name 690 self.functions[new_name] = func 691 692 def converted_self(self): 693 if self._converted_self is None: 694 copied_graph = graph_pb2.GraphDef() 695 copied_graph.CopyFrom(self._graph_def) 696 self._converted_self = _GraphDef(copied_graph) 697 return self._converted_self 698 699 def create_edges(self): 700 for n in self._nodes.values(): 701 n.create_edges() 702 for f in self._functions.values(): 703 f.create_edges() 704 705 706class _ConverterData(object): 707 """Container for constant conversion supporting data. 708 709 The data includes the graph being converted, and the pre-converted 710 tensors. This class will be specialized for ConcreteFunction and Session-based 711 conversions, as the means to obtain that data is different for each case. 712 """ 713 714 def __init__(self, 715 graph_def, 716 variable_names_allowlist=None, 717 variable_names_denylist=None): 718 self._graph_def = graph_def 719 self._tensor_data = {} 720 self._build_node_defs_list() 721 self._variable_names_allowlist = variable_names_allowlist 722 self._variable_names_denylist = variable_names_denylist 723 724 @property 725 def graph_def(self): 726 """The graph to be converted.""" 727 return self._graph_def 728 729 @property 730 def node_defs(self): 731 """All the node defs in the graph to be converted. 732 733 Returns: 734 A map from node name to the NodeDef for all NodeDefs in the graph, as well 735 as all control flow NodeDefs in the functions. 736 """ 737 return self._node_defs 738 739 @property 740 def tensor_data(self): 741 """A map from tensor name to its converted _TensorData.""" 742 return self._tensor_data 743 744 def _should_convert(self, name): 745 """Checks whether to convert the given variable name to a constant.""" 746 return (self._variable_names_allowlist is None or 747 name in self._variable_names_allowlist) and ( 748 self._variable_names_denylist is None or 749 name not in self._variable_names_denylist) 750 751 def _build_node_defs_list(self): 752 """Builds the list of NodeDefs in the GraphDef. 753 754 This list consists of all NodeDefs in the main graph as well as all control 755 flow NodeDefs in the functions. 756 757 The remaining NodeDefs in the functions are not included because the op 758 names 759 are not unique and the variables are handled differently than the main 760 graph. 761 The control flow ops need to be extracted because they are need their 762 attributes to be updated similar to the control flow ops in the main graph. 763 """ 764 self._node_defs = {node.name: node for node in self._graph_def.node} 765 766 if self._graph_def.library: 767 for func in self._graph_def.library.function: 768 self._node_defs.update({ 769 node.name: node 770 for node in func.node_def 771 if node.op in _CONTROL_FLOW_OPS 772 }) 773 774 775class _FunctionConverterData(_ConverterData): 776 """Container for ConcreteFunction-based conversion data.""" 777 778 def __init__(self, 779 func, 780 lower_control_flow, 781 aggressive_inlining, 782 variable_names_allowlist=None, 783 variable_names_denylist=None): 784 """Creates the conversion data for the given function. 785 786 Args: 787 func: ConcreteFunction. 788 lower_control_flow: Boolean indicating whether or not to lower control 789 flow ops such as If and While. 790 aggressive_inlining: Boolean indicating whether or not to do aggressive 791 function inlining (might be unsafe if function has stateful ops, not 792 properly connected to control outputs). 793 variable_names_allowlist: The set of variable names to convert (by 794 default, all variables are converted). 795 variable_names_denylist: The set of variable names to omit converting to 796 constants. 797 """ 798 799 self._func = func 800 # Inline the graph in order to remove functions when possible. 801 graph_def = _run_inline_graph_optimization(func, lower_control_flow, 802 aggressive_inlining) 803 super(_FunctionConverterData, self).__init__( 804 graph_def, 805 variable_names_allowlist=variable_names_allowlist, 806 variable_names_denylist=variable_names_denylist) 807 self._build_tensor_data() 808 809 def _build_tensor_data(self): 810 """Caches the tensor data for all Placeholders in the given function.""" 811 map_index_to_variable = {} 812 for var in self._func.graph.variables: 813 for idx, captured_input in enumerate(self._func.captured_inputs): 814 if var.handle is captured_input: # pylint: disable=protected-access 815 map_index_to_variable[idx] = var 816 break 817 818 # Iterates through all captures which are represented as Placeholders. 819 for idx, (val_tensor, name_tensor) in enumerate(self._func.graph.captures): 820 tensor_name = name_tensor.name.split(":")[0] 821 if not self._should_convert(tensor_name): 822 continue 823 if idx in map_index_to_variable: 824 data = map_index_to_variable[idx].numpy() 825 else: 826 data = np.array(val_tensor.numpy()) 827 self._tensor_data[tensor_name] = _TensorData( 828 numpy=data, 829 dtype=dtypes.as_dtype(data.dtype).as_datatype_enum, 830 index=idx) 831 832 # Get data for VariableV2 ops (reference variables) that cannot be lifted. 833 for node in self.node_defs.values(): 834 if node.op == "VariableV2": 835 if not self._should_convert(node.name): 836 continue 837 if node.name not in self.tensor_data: 838 with self._func.graph.as_default(): 839 identity_node = array_ops.identity( 840 self._func.graph.as_graph_element(node.name + ":0")) 841 pruned_graph = self._func.prune([], [identity_node.name])()[0] 842 self._tensor_data[node.name] = _TensorData( 843 numpy=pruned_graph.numpy(), 844 dtype=node.attr["dtype"].type, 845 index=None) 846 847 848class _SessionConverterData(_ConverterData): 849 """Container for Session-based conversion data.""" 850 851 def __init__(self, 852 session, 853 graph_def, 854 output_node_names, 855 variable_names_allowlist=None, 856 variable_names_denylist=None): 857 graph_def = graph_util.extract_sub_graph(graph_def, output_node_names) 858 super(_SessionConverterData, self).__init__( 859 graph_def, 860 variable_names_allowlist=variable_names_allowlist, 861 variable_names_denylist=variable_names_denylist) 862 863 nodes_to_convert = [] 864 tensor_names_to_convert = [] 865 for node in self.graph_def.node: 866 if node.op in ["Variable", "VariableV2", "VarHandleOp"]: 867 tensor_name = node.name 868 if not self._should_convert(tensor_name): 869 continue 870 if node.op == "VarHandleOp": 871 tensor_name = tensor_name + "/Read/ReadVariableOp" 872 nodes_to_convert.append(node) 873 tensor_names_to_convert.append(tensor_name + ":0") 874 875 if tensor_names_to_convert: 876 converted_tensors = session.run(tensor_names_to_convert) 877 for node, tensor_value in zip(nodes_to_convert, converted_tensors): 878 self._tensor_data[node.name] = _TensorData( 879 numpy=tensor_value, dtype=node.attr["dtype"].type, index=None) 880 881 882def disable_lower_using_switch_merge(graph_def): 883 """Set '_lower_using_switch_merge' attributes to False. 884 885 Sets the attribute to False in the NodeDefs in the main graph and the NodeDefs 886 in each function's graph. 887 888 Args: 889 graph_def: GraphDef proto. 890 891 Returns: 892 GraphDef 893 """ 894 output_graph_def = graph_pb2.GraphDef() 895 output_graph_def.CopyFrom(graph_def) 896 897 def disable_control_flow_lowering(node): 898 if node.op in _CONTROL_FLOW_OPS: 899 node.attr["_lower_using_switch_merge"].b = False 900 901 for node in output_graph_def.node: 902 disable_control_flow_lowering(node) 903 904 if output_graph_def.library: 905 for func in output_graph_def.library.function: 906 for node in func.node_def: 907 disable_control_flow_lowering(node) 908 return output_graph_def 909 910 911def _run_inline_graph_optimization(func, lower_control_flow, 912 aggressive_inlining): 913 """Apply function inline optimization to the graph. 914 915 Returns the GraphDef after Grappler's function inlining optimization is 916 applied. This optimization does not work on models with control flow. 917 918 Args: 919 func: ConcreteFunction. 920 lower_control_flow: Boolean indicating whether or not to lower control flow 921 ops such as If and While. (default True) 922 aggressive_inlining: Boolean indicating whether or not to do aggressive 923 function inlining (might be unsafe if function has stateful ops not 924 properly connected to control outputs). 925 926 Returns: 927 GraphDef 928 """ 929 graph_def = func.graph.as_graph_def() 930 if not lower_control_flow: 931 graph_def = disable_lower_using_switch_merge(graph_def) 932 933 # In some cases, a secondary implementation of the function (e.g. for GPU) is 934 # written to the "api_implements" attribute. (e.g. `tf.keras.layers.LSTM` in 935 # TF2 produces a CuDNN-based RNN for GPU). 936 # This function suppose to inline all functions calls, but "api_implements" 937 # prevents this from happening. Removing the attribute solves the problem. 938 # To learn more about "api_implements", see: 939 # tensorflow/core/grappler/optimizers/implementation_selector.h 940 for function in graph_def.library.function: 941 if "api_implements" in function.attr: 942 del function.attr["api_implements"] 943 944 meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph) 945 946 # Clear the initializer_name for the variables collections, since they are not 947 # needed after saved to saved_model. 948 for name in [ 949 "variables", "model_variables", "trainable_variables", "local_variables" 950 ]: 951 raw_list = [] 952 for raw in meta_graph.collection_def["variables"].bytes_list.value: 953 variable = variable_pb2.VariableDef() 954 variable.ParseFromString(raw) 955 variable.ClearField("initializer_name") 956 raw_list.append(variable.SerializeToString()) 957 meta_graph.collection_def[name].bytes_list.value[:] = raw_list 958 959 # Add a collection 'train_op' so that Grappler knows the outputs. 960 fetch_collection = meta_graph_pb2.CollectionDef() 961 for array in func.inputs + func.outputs: 962 fetch_collection.node_list.value.append(array.name) 963 meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) 964 965 # Initialize RewriterConfig with everything disabled except function inlining. 966 config = config_pb2.ConfigProto() 967 rewrite_options = config.graph_options.rewrite_options 968 rewrite_options.min_graph_nodes = -1 # do not skip small graphs 969 rewrite_options.optimizers.append("function") 970 if aggressive_inlining: 971 rewrite_options.function_optimization =\ 972 rewriter_config_pb2.RewriterConfig.AGGRESSIVE 973 return tf_optimizer.OptimizeGraph(config, meta_graph) 974 975 976def _construct_concrete_function(func, output_graph_def, 977 converted_input_indices): 978 """Constructs a concrete function from the `output_graph_def`. 979 980 Args: 981 func: ConcreteFunction 982 output_graph_def: GraphDef proto. 983 converted_input_indices: Set of integers of input indices that were 984 converted to constants. 985 986 Returns: 987 ConcreteFunction. 988 """ 989 # Create a ConcreteFunction from the new GraphDef. 990 input_tensors = func.graph.internal_captures 991 converted_inputs = object_identity.ObjectIdentitySet( 992 [input_tensors[index] for index in converted_input_indices]) 993 not_converted_inputs = [ 994 tensor for tensor in func.inputs if tensor not in converted_inputs 995 ] 996 not_converted_inputs_map = { 997 tensor.name: tensor for tensor in not_converted_inputs 998 } 999 1000 new_input_names = [tensor.name for tensor in not_converted_inputs] 1001 new_output_names = [tensor.name for tensor in func.outputs] 1002 1003 # Remove old functions to use updated functions from graph def. 1004 for f in output_graph_def.library.function: 1005 if context.context().has_function(f.signature.name): 1006 context.context().remove_function(f.signature.name) 1007 1008 new_func = wrap_function.function_from_graph_def(output_graph_def, 1009 new_input_names, 1010 new_output_names) 1011 1012 # Manually propagate shape for input tensors where the shape is not correctly 1013 # propagated. Scalars shapes are lost when wrapping the function. 1014 for input_tensor in new_func.inputs: 1015 input_tensor.set_shape(not_converted_inputs_map[input_tensor.name].shape) 1016 return new_func 1017 1018 1019def _replace_variables_by_constants(converter_data): 1020 """Replaces variables by constants on a given graph. 1021 1022 Given a _ConverterData instance with converted variables in its tensor_data 1023 field, create a new graph where the respective variables are replaced with the 1024 converted constants. 1025 1026 Args: 1027 converter_data: A pre-populated _ConverterData instance. 1028 1029 Returns: 1030 The converted graph. 1031 """ 1032 input_graph = _GraphDef(converter_data.graph_def) 1033 1034 for tensor_name, tensor_data in converter_data.tensor_data.items(): 1035 input_graph.nodes[tensor_name].convert_variable_to_constant( 1036 None, tensor_data) 1037 1038 converted_graph = input_graph.converted_self().graph_def 1039 1040 converted_input_indices = { 1041 t.index 1042 for t in converter_data.tensor_data.values() 1043 if t.index is not None 1044 } 1045 1046 return converted_graph, converted_input_indices 1047 1048 1049def convert_variables_to_constants_v2(func, 1050 lower_control_flow=True, 1051 aggressive_inlining=False): 1052 """Replaces all the variables in a graph with constants of the same values. 1053 1054 TensorFlow 2.0 function for converting all Variable ops into Const ops holding 1055 the same values. This makes it possible to describe the network fully with a 1056 single GraphDef file, and allows the removal of a lot of ops related to 1057 loading and saving the variables. This function runs Grappler's function 1058 inlining optimization in order to return a single subgraph. 1059 1060 The current implementation only works for graphs that do not contain any 1061 control flow or embedding related ops. 1062 1063 Args: 1064 func: ConcreteFunction. 1065 lower_control_flow: Boolean indicating whether or not to lower control flow 1066 ops such as If and While. (default True) 1067 aggressive_inlining: Boolean indicating whether or not to do aggressive 1068 function inlining (might be unsafe if function has stateful ops, not 1069 properly connected to control outputs). (default False) 1070 1071 Returns: 1072 ConcreteFunction containing a simplified version of the original. 1073 """ 1074 1075 converter_data = _FunctionConverterData( 1076 func=func, 1077 lower_control_flow=lower_control_flow, 1078 aggressive_inlining=aggressive_inlining) 1079 1080 output_graph_def, converted_input_indices = _replace_variables_by_constants( 1081 converter_data=converter_data) 1082 1083 return _construct_concrete_function(func, output_graph_def, 1084 converted_input_indices) 1085 1086 1087def convert_variables_to_constants_v2_as_graph(func, 1088 lower_control_flow=True, 1089 aggressive_inlining=False): 1090 """Replaces all the variables in a graph with constants of the same values. 1091 1092 This function works as same as convert_variables_to_constants_v2, but it 1093 returns the intermediate `GraphDef` as well. This `GraphDef` contains all the 1094 debug information after all the transformations in the frozen phase. 1095 1096 Args: 1097 func: ConcreteFunction. 1098 lower_control_flow: Boolean indicating whether or not to lower control flow 1099 ops such as If and While. (default True) 1100 aggressive_inlining: Boolean indicating whether or not to do aggressive 1101 function inlining (might be unsafe if function has stateful ops, not 1102 properly connected to control outputs). 1103 1104 Returns: 1105 ConcreteFunction containing a simplified version of the original, and also 1106 the intermediate GraphDef containing the node debug information for the 1107 transformations in the frozen phase. 1108 """ 1109 converter_data = _FunctionConverterData( 1110 func=func, 1111 lower_control_flow=lower_control_flow, 1112 aggressive_inlining=aggressive_inlining) 1113 1114 output_graph_def, converted_input_indices = _replace_variables_by_constants( 1115 converter_data=converter_data) 1116 1117 frozen_func = _construct_concrete_function(func, output_graph_def, 1118 converted_input_indices) 1119 return frozen_func, output_graph_def 1120 1121 1122def convert_variables_to_constants_from_session_graph( 1123 session, 1124 graph_def, 1125 output_node_names, 1126 variable_names_allowlist=None, 1127 variable_names_denylist=None): 1128 """Replaces all the variables in a graph with constants of the same values. 1129 1130 This function works similarly to convert_variables_to_constants_v2, but it 1131 retrieves the constant values from a Session instead of from a 1132 ConcreteFunction. This is useful when converting graphs generated from 1133 TensorFlow V1, where ConcreteFunctions are not available. This also differs 1134 from graph_util.convert_variables_to_constants in that it supports resource 1135 variables when V2 control flow constructions are present. 1136 1137 Args: 1138 session: Active TensorFlow session containing the variables. 1139 graph_def: A GraphDef to convert. 1140 output_node_names: List of name strings for the result nodes of the graph. 1141 variable_names_allowlist: The set of variable names to convert (by default, 1142 all variables are converted). 1143 variable_names_denylist: The set of variable names to omit converting to 1144 constants. 1145 1146 Returns: 1147 An optimized GraphDef. 1148 """ 1149 # TODO(b/176982859): Find a more satisfying way to update shape information 1150 # than clearing it, or migrate users to a workflow that does not require 1151 # freezing. 1152 for function in graph_def.library.function: 1153 if "_input_shapes" in function.attr: 1154 for input_arg, shape_attribute in zip( 1155 function.signature.input_arg, 1156 function.attr["_input_shapes"].list.shape): 1157 if dtypes.as_dtype(input_arg.type) == dtypes.resource: 1158 shape_attribute.unknown_rank = True 1159 graph_def, _ = _replace_variables_by_constants( 1160 converter_data=_SessionConverterData( 1161 session=session, 1162 graph_def=graph_def, 1163 output_node_names=output_node_names, 1164 variable_names_allowlist=variable_names_allowlist, 1165 variable_names_denylist=variable_names_denylist)) 1166 return graph_def 1167