1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Utility to convert FunctionDef to GraphDef and Graph.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import itertools 22 23 24from tensorflow.core.framework import function_pb2 25from tensorflow.core.framework import graph_pb2 26from tensorflow.core.framework import tensor_shape_pb2 27from tensorflow.core.framework import types_pb2 28from tensorflow.core.framework import versions_pb2 29from tensorflow.python.eager import context 30from tensorflow.python.framework import cpp_shape_inference_pb2 31from tensorflow.python.framework import importer 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import versions 34from tensorflow.python.framework.func_graph import FuncGraph 35from tensorflow.python.ops import resource_variable_ops 36 37 38def function_def_to_graph(fdef, input_shapes=None): 39 """Converts a FunctionDef to a FuncGraph (sub-class Graph). 40 41 The returned FuncGraph's `name`, `inputs` and `outputs` fields will be set. 42 The input tensors are represented as placeholders. 43 44 Note: `FuncGraph.inputs` and `FuncGraph.captures` are not set and may be set 45 by the caller. 46 47 Args: 48 fdef: FunctionDef. 49 input_shapes: Optional. A list of TensorShape objects of the shapes of 50 function inputs. Defaults to the function's "_input_shapes" attribute. If 51 specified, its length must match length of `fdef.signature.input_arg`. If 52 a shape is None, the corresponding input placeholder will have unknown 53 shape. 54 55 Returns: 56 A FuncGraph. 57 """ 58 func_graph = FuncGraph(fdef.signature.name) 59 if input_shapes is None: 60 input_shapes_attr = fdef.attr.get("_input_shapes", None) 61 if input_shapes_attr is not None: 62 input_shapes = input_shapes_attr.list.shape 63 graph_def, nested_to_flat_tensor_name = function_def_to_graph_def( 64 fdef, input_shapes) 65 66 with func_graph.as_default(): 67 # Add all function nodes to the graph. 68 importer.import_graph_def_for_function(graph_def, name="") 69 70 # Initialize fields specific to FuncGraph. 71 72 # inputs 73 input_tensor_names = [ 74 nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg 75 ] 76 func_graph.inputs = [ 77 func_graph.get_tensor_by_name(name) for name in input_tensor_names 78 ] 79 80 # outputs 81 output_tensor_names = [ 82 nested_to_flat_tensor_name[fdef.ret[arg.name]] 83 for arg in fdef.signature.output_arg 84 ] 85 func_graph.outputs = [ 86 func_graph.get_tensor_by_name(name) for name in output_tensor_names 87 ] 88 func_graph.control_outputs = [ 89 func_graph.get_operation_by_name(fdef.control_ret[ret_name]) 90 for ret_name in fdef.signature.control_output 91 ] 92 93 _set_handle_data(func_graph, fdef) 94 95 for node in graph_def.node: 96 output_shapes = node.attr.get("_output_shapes", None) 97 if output_shapes is not None: 98 op = func_graph.get_operation_by_name(node.name) 99 # _output_shapes for functions can sometimes be too long because the 100 # output-intermediates-for-gradients version of the function was 101 # substituted before saving. We'll accept that here. (See b/133666530). 102 for output_index, shape in enumerate( 103 output_shapes.list.shape[:len(op.outputs)]): 104 op.outputs[output_index].set_shape(shape) 105 output_names = {} 106 for ret_arg_def, tensor_name in zip( 107 fdef.signature.output_arg, output_tensor_names): 108 output_names[ops.tensor_id( 109 func_graph.get_tensor_by_name(tensor_name))] = ( 110 ret_arg_def.name) 111 func_graph._output_names = output_names # pylint: disable=protected-access 112 return func_graph 113 114 115def is_function(fname): 116 """Checks for a function definition with `fname` in the current context.""" 117 if context.executing_eagerly(): 118 return context.context().has_function(fname) 119 else: 120 graph = ops.get_default_graph() 121 while graph is not None: 122 if graph._is_function(fname): # pylint: disable=protected-access 123 return True 124 if hasattr(graph, "outer_graph"): 125 graph = graph.outer_graph 126 else: 127 return False 128 129 130def function_def_to_graph_def(fdef, input_shapes=None): 131 """Convert a FunctionDef to a GraphDef. 132 133 Steps: 134 1. Creates placeholder nodes corresponding to inputs in 135 `FunctionDef.signature.input_arg`. 136 2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`. 137 3. Renames inputs of all nodes to use the convention of GraphDef instead of 138 FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming 139 in FunctionDefs is different from GraphDefs. 140 141 Args: 142 fdef: FunctionDef. 143 input_shapes: Optional. A list of TensorShape objects of the shapes of 144 function inputs. If specified, its length must match length of 145 `fdef.signature.input_arg`. If a shape is None, the corresponding input 146 placeholder will have unknown shape. 147 148 Returns: 149 A tuple of (GraphDef, dict<string, string>). The dict contains a mapping 150 from nested tensor names (in FunctionDef) to flattened names (in GraphDef). 151 152 Raises: 153 ValueError: If the length of input_shapes does not match the number of 154 input_args or if the FunctionDef is invalid. 155 """ 156 graph_def = graph_pb2.GraphDef() 157 graph_def.versions.CopyFrom( 158 versions_pb2.VersionDef( 159 producer=versions.GRAPH_DEF_VERSION, 160 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)) 161 162 default_graph = ops.get_default_graph() 163 164 copied_functions = set() 165 166 if input_shapes and len(input_shapes) != len(fdef.signature.input_arg): 167 raise ValueError("Length of input_shapes must match the number of " + 168 "input_args. len(input_shapes): {} len(input_arg): {}". 169 format(len(input_shapes), len(fdef.signature.input_arg))) 170 171 # 1. Create placeholders for input nodes. 172 for i, arg_def in enumerate(fdef.signature.input_arg): 173 node_def = graph_def.node.add() 174 node_def.name = arg_def.name 175 node_def.op = "Placeholder" 176 node_def.attr["dtype"].type = arg_def.type 177 if input_shapes and input_shapes[i] is not None: 178 input_shape = input_shapes[i] 179 if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto): 180 input_shape = input_shape.as_proto() 181 node_def.attr["shape"].shape.CopyFrom(input_shape) 182 arg_attrs = fdef.arg_attr[i].attr 183 for k in arg_attrs: 184 # Only copy internal attributes. Normal attributes for nodes cannot be 185 # applied to these Placeholder nodes. 186 if k == "_output_shapes": 187 node_def.attr["shape"].shape.CopyFrom(arg_attrs[k].list.shape[0]) 188 elif k.startswith("_"): 189 node_def.attr[k].CopyFrom(arg_attrs[k]) 190 191 # 2. Copy all body NodeDefs to the GraphDef. 192 graph_def.node.extend(fdef.node_def) 193 194 # 3. Perform the renaming. 195 196 # Build the tensor name mapping then flatten the tensor names. 197 # See comment on `FunctionDef.node_def` on how the tensor naming in 198 # FunctionDefs is different from GraphDefs. 199 nested_to_flat_tensor_name = {} 200 201 for arg_def in fdef.signature.input_arg: 202 nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name) 203 control_name = "^" + arg_def.name 204 nested_to_flat_tensor_name[control_name] = control_name 205 206 for node_def in fdef.node_def: 207 graph = default_graph 208 while True: 209 f = graph._functions.get(node_def.op, None) # pylint: disable=protected-access 210 if f is not None or not hasattr(graph, "outer_graph"): 211 break 212 graph = graph.outer_graph 213 214 if f is not None: 215 op_def = f.definition.signature 216 if node_def.op not in copied_functions: 217 # Since this function is referenced as an op type, we have no choice but 218 # to copy it into the GraphDef if we want downstream tools to process 219 # it. 220 graph_def.library.function.add().CopyFrom(f.definition) 221 copied_functions.add(node_def.op) 222 if f.grad_func_name: 223 grad_def = function_pb2.GradientDef() 224 grad_def.function_name = f.name 225 grad_def.gradient_func = f.grad_func_name 226 graph_def.library.gradient.extend([grad_def]) 227 else: 228 op_def = default_graph._get_op_def(node_def.op) # pylint: disable=protected-access 229 230 for attr in op_def.attr: 231 if attr.type == "func": 232 fname = node_def.attr[attr.name].func.name 233 if not is_function(fname): 234 raise ValueError("%s function not found." % fname) 235 elif attr.type == "list(func)": 236 for fn in node_def.attr[attr.name].list.func: 237 fname = fn.name 238 if not is_function(fname): 239 raise ValueError("%s function not found." % fname) 240 241 # Iterate over output_args in op_def to build the map. 242 # Index of the output tensor in the flattened list of *all* output 243 # tensors of the op. 244 flattened_index = 0 245 for arg_def in op_def.output_arg: 246 num_args = _get_num_args(arg_def, node_def) 247 for i in range(num_args): 248 # Map tensor names from "node_name:output_arg_name:index" to 249 # "node_name:flattened_index". 250 nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i) 251 flat_name = "{}:{}".format(node_def.name, flattened_index) 252 nested_to_flat_tensor_name[nested_name] = flat_name 253 flattened_index += 1 254 control_name = "^" + node_def.name 255 nested_to_flat_tensor_name[control_name] = control_name 256 257 # Update inputs of all nodes in graph. 258 for node_def in graph_def.node: 259 for i in range(len(node_def.input)): 260 node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]] 261 262 return graph_def, nested_to_flat_tensor_name 263 264 265# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange. 266def _get_num_args(arg_def, node_def): 267 if arg_def.number_attr: 268 return node_def.attr[arg_def.number_attr].i 269 elif arg_def.type_list_attr: 270 return len(node_def.attr[arg_def.type_list_attr].list.type) 271 elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: 272 return 1 273 else: 274 raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) 275 276 277def _set_handle_data(func_graph, fdef): 278 """Adds handle data for resource type inputs and outputs.""" 279 for tensor, arg_def in itertools.chain( 280 zip(func_graph.inputs, fdef.signature.input_arg), 281 zip(func_graph.outputs, fdef.signature.output_arg)): 282 if arg_def.handle_data: 283 shape_and_dtype = arg_def.handle_data[0] 284 handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData() 285 handle_data.is_set = True 286 handle_data.shape_and_type.append( 287 cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType( 288 shape=shape_and_dtype.shape, dtype=shape_and_dtype.dtype)) 289 resource_variable_ops._set_handle_shapes_and_types( # pylint: disable=protected-access 290 tensor, handle_data, True) 291