# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to convert variables to constants in TensorFlow 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.eager import function
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.saver import export_meta_graph


def _run_inline_graph_optimization(func):
  """Apply function inline optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.

  Returns:
    GraphDef
  """
  meta_graph = export_meta_graph(
      graph_def=func.graph.as_graph_def(), graph=func.graph)

  # Add a collection 'train_op' so that Grappler knows the outputs. Both the
  # inputs and the outputs of `func` are listed in the collection.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for array in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function inlining.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.optimizers.append("function")
  return tf_optimizer.OptimizeGraph(config, meta_graph)


def _get_tensors_from_graph(graph, tensors):
  """Gets the Tensors in `graph` with the name of the tensors in `tensors`.

  Each tensor is looked up in `graph` by name. When the tensor found in
  `graph` has an unknown rank, the shape of the corresponding tensor in
  `tensors` is applied to it.

  Args:
    graph: TensorFlow Graph.
    tensors: List of Tensors.

  Returns:
    List of Tensors from `graph`, in the same order as `tensors`.
  """
  new_tensors = []
  for orig_tensor in tensors:
    new_tensor = graph.get_tensor_by_name(orig_tensor.name)
    # Carry the original shape over only when the looked-up tensor has none.
    if new_tensor.shape.rank is None:
      new_tensor.set_shape(orig_tensor.shape)
    new_tensors.append(new_tensor)
  return new_tensors


def _construct_concrete_function(input_func, graph_def):
  """Creates a ConcreteFunction from the input function and frozen graph.

  The new function's inputs, outputs, structured signature, captured inputs,
  variables and argument metadata are all taken from `input_func`; only the
  graph body comes from `graph_def`.

  Args:
    input_func: ConcreteFunction.
    graph_def: TensorFlow GraphDef.

  Returns:
    ConcreteFunction containing the graph_def.
  """
  output_graph = func_graph.FuncGraph(input_func.graph.name)
  with output_graph.as_default():
    # Import with an empty name scope so that tensor names match the names
    # used by `input_func` and can be looked up below.
    importer.import_graph_def(graph_def, name="")
    output_graph.inputs = _get_tensors_from_graph(output_graph,
                                                  input_func.inputs)
    output_graph.outputs = _get_tensors_from_graph(output_graph,
                                                   input_func.outputs)

  output_graph.structured_outputs = input_func.graph.structured_outputs
  output_graph.structured_input_signature = (
      input_func.graph.structured_input_signature)

  # Create the ConcreteFunction and add it to the global context.
  output_func = function.ConcreteFunction(output_graph)
  output_func.add_to_graph()

  # Inject the captured inputs into the ConcreteFunction.
  output_func._captured_inputs = input_func.captured_inputs  # pylint: disable=protected-access
  output_func.graph.variables = input_func.graph.variables

  # Copy the calling convention. The attribute written must be
  # `_num_positional_args` (the same one that is read from `input_func`);
  # writing to a misspelled name would leave the real attribute unset.
  output_func._arg_keywords = input_func._arg_keywords  # pylint: disable=protected-access
  output_func._num_positional_args = input_func._num_positional_args  # pylint: disable=protected-access

  # Register the gradients in the current root context.
  with ops.init_scope():
    output_func._register_gradient()  # pylint: disable=protected-access
  return output_func


def convert_variables_to_constants_v2(func):
  """Replaces all the variables in a graph with constants of the same values.

  TensorFlow 2.0 function for converting all Variable ops into Const ops holding
  the same values. This makes it possible to describe the network fully with a
  single GraphDef file, and allows the removal of a lot of ops related to
  loading and saving the variables. This function runs Grappler's function
  inlining optimization in order to return a single subgraph.

  The current implementation only works for graphs that do not contain any
  control flow or embedding related ops.

  Args:
    func: ConcreteFunction.

  Returns:
    ConcreteFunction containing a simplified version of the original.

  Raises:
    ValueError: If the input of a ReadVariableOp, after following any chain of
      Identity ops, is not a Placeholder op.
  """
  # TODO(nupurgarg): Replace ResourceGather with Gather.
  # TODO(nupurgarg): Change attr for Variables in control flow and functions.
  graph_def = _run_inline_graph_optimization(func)

  def get_name(name):
    """Returns `name` with everything from the first ':' onwards removed."""
    return name.split(":")[0]

  # Identify the ReadVariableOps.
  map_name_to_node = {get_name(node.name): node for node in graph_def.node}

  # TODO(b/125838789): Use `func.graph.captures`.
  # Get mapping from input name to variable value. The captured inputs are
  # taken to be the trailing entries of `func.inputs`.
  tensor_data = {}
  num_captured = len(func.captured_inputs)
  # Guard the zero case explicitly: `func.inputs[-0:]` would return every
  # input instead of an empty list.
  input_tensors = func.inputs[-num_captured:] if num_captured else []
  for var in func.graph.variables:
    index = func.captured_inputs.index(var.handle)
    tensor = input_tensors[index]
    tensor_data[get_name(tensor.name)] = var.numpy()

  resource_identities = {}
  resource_placeholders = {}
  for node in graph_def.node:
    if node.op == "ReadVariableOp":
      # Get name of Placeholder op associated with ReadVariableOp. There can be
      # an Identity in between the ReadVariableOp and Placeholder. Store the
      # Identity ops with the associated dtypes.
      input_name = get_name(node.input[0])
      while map_name_to_node[input_name].op == "Identity":
        resource_identities[input_name] = node.attr["dtype"]
        input_name = get_name(map_name_to_node[input_name].input[0])
      if map_name_to_node[input_name].op != "Placeholder":
        raise ValueError("Cannot find the Placeholder op that is an input "
                         "to the ReadVariableOp.")
      # Build a map of Placeholder ops that are inputs to ReadVariableOps to the
      # variable's dtype and data.
      resource_placeholders[input_name] = {
          "dtype": node.attr["dtype"],
          "data": tensor_data[input_name],
      }

  # Reconstruct the graph with constants in place of variables.
  output_graph_def = graph_pb2.GraphDef()
  how_many_converted = 0

  for input_node in graph_def.node:
    output_node = output_graph_def.node.add()
    # Convert Placeholder ops that are inputs to ReadVariableOps into Const ops.
    if input_node.name in resource_placeholders:
      dtype = resource_placeholders[input_node.name]["dtype"]
      data = resource_placeholders[input_node.name]["data"]

      output_node.op = "Const"
      output_node.name = input_node.name
      output_node.attr["dtype"].CopyFrom(dtype)
      output_node.attr["value"].tensor.CopyFrom(
          tensor_util.make_tensor_proto(
              data, dtype=dtype.type, shape=data.shape))
      how_many_converted += 1
    # Change the dtype for Identity ops that are inputs to ReadVariableOps.
    elif input_node.name in resource_identities:
      output_node.CopyFrom(input_node)
      output_node.attr["T"].CopyFrom(resource_identities[input_node.name])
    # Convert ReadVariableOps into Identity ops.
    elif input_node.op == "ReadVariableOp":
      output_node.op = "Identity"
      output_node.name = input_node.name
      output_node.input.extend([input_node.input[0]])
      output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
      if "_class" in input_node.attr:
        output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
    # Every other node is carried over unchanged.
    else:
      output_node.CopyFrom(input_node)

  logging.info("Converted %d variables to const ops.", how_many_converted)
  # TODO(b/126613403): Use wrap_function.function_from_graph_def.
  return _construct_concrete_function(func, output_graph_def)