# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility to convert FunctionDef to GraphDef and Graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
from tensorflow.python.framework.func_graph import FuncGraph


def function_def_to_graph(fdef, input_shapes=None):
  """Converts a FunctionDef to a FuncGraph (sub-class Graph).

  The returned FuncGraph's `name`, `inputs`, `outputs` and `control_outputs`
  fields will be set. The input tensors are represented as placeholders.

  Note: `FuncGraph.captures` is not set and may be set by the caller.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A FuncGraph.

  Raises:
    ValueError: If `function_def_to_graph_def` rejects `fdef` or
      `input_shapes` (see that function for the conditions).
  """
  func_graph = FuncGraph(fdef.signature.name)
  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
      fdef, input_shapes)

  with func_graph.as_default():
    # Add all function nodes to the graph.
    importer.import_graph_def(graph_def, name="")

    # Initialize fields specific to FuncGraph.

    # inputs: the placeholders created for `fdef.signature.input_arg`, in
    # signature order.
    input_tensor_names = [
        nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
    ]
    func_graph.inputs = [
        func_graph.get_tensor_by_name(name) for name in input_tensor_names
    ]

    # outputs: `fdef.ret` maps each output arg name to a FunctionDef-style
    # (nested) tensor name, which must be flattened before graph lookup.
    output_tensor_names = [
        nested_to_flat_tensor_name[fdef.ret[arg.name]]
        for arg in fdef.signature.output_arg
    ]
    func_graph.outputs = [
        func_graph.get_tensor_by_name(name) for name in output_tensor_names
    ]
    # control_outputs: `fdef.control_ret` maps each control output name to the
    # name of an operation in the imported graph.
    func_graph.control_outputs = [
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in fdef.signature.control_output
    ]

  return func_graph


def _is_function(fname):
  """Checks for a function definition with `fname` in the current context.

  When executing eagerly, the eager context is asked via `has_function`;
  otherwise the current default graph is asked via its `_is_function`.

  Args:
    fname: Name of the function to look up.

  Returns:
    The result of the lookup. Callers in this module treat a falsy result as
    "function not defined".
  """
  if context.executing_eagerly():
    return context.context().has_function(fname)
  else:
    return ops.get_default_graph()._is_function(fname)  # pylint: disable=protected-access


def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs.

  In addition, every function known to the current default graph is copied
  into the returned GraphDef, so both direct and indirect function references
  can be resolved.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape. A None or empty list is treated as
      "not specified".

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args, if a `func` / `list(func)` attr names a function that is not
      defined in the current context, or if an op's output arg definition is
      invalid.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  # Fetch the default graph once. It supplies the function library copied
  # below and the OpDefs looked up for every body node.
  default_graph = ops.get_default_graph()

  # Copy *all* functions from outer graph to `graph_def` so that both direct
  # and indirect references are safely handled.
  default_graph._copy_functions_to_graph_def(graph_def, 0)  # pylint: disable=protected-access

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    # Adjacent literals are joined at compile time, so `.format` applies to
    # the whole message.
    raise ValueError("Length of input_shapes must match the number of "
                     "input_args. len(input_shapes): {} len(input_arg): {}"
                     .format(len(input_shapes), len(fdef.signature.input_arg)))

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto())

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  # Inside the function body an input arg `x` is referenced as "x" (data) or
  # "^x" (control). Its placeholder's only output is "x:0", and control edges
  # keep their name unchanged.
  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
    control_name = "^" + arg_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  for node_def in fdef.node_def:
    op_def = default_graph._get_op_def(node_def.op)  # pylint: disable=protected-access

    # Reject nodes whose function-valued attrs name a function that
    # `_is_function` cannot find in the current context.
    for attr in op_def.attr:
      if attr.type == "func":
        fname = node_def.attr[attr.name].func.name
        if not _is_function(fname):
          raise ValueError("%s function not found." % fname)
      elif attr.type == "list(func)":
        for fn in node_def.attr[attr.name].list.func:
          fname = fn.name
          if not _is_function(fname):
            raise ValueError("%s function not found." % fname)

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1
    control_name = "^" + node_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  # Update inputs of all nodes in graph (placeholders have none). An input
  # name missing from the map raises KeyError here.
  for node_def in graph_def.node:
    for i, input_name in enumerate(node_def.input):
      node_def.input[i] = nested_to_flat_tensor_name[input_name]

  return graph_def, nested_to_flat_tensor_name


# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange.
196def _get_num_args(arg_def, node_def): 197 if arg_def.number_attr: 198 return node_def.attr[arg_def.number_attr].i 199 elif arg_def.type_list_attr: 200 return len(node_def.attr[arg_def.type_list_attr].list.type) 201 elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: 202 return 1 203 else: 204 raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) 205