1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Utility to convert a Graph to a FunctionDef.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import re 22 23from tensorflow.core.framework import function_pb2 24from tensorflow.core.framework import op_def_pb2 25from tensorflow.python.framework import errors_impl 26from tensorflow.python.framework import op_def_registry 27 28 29def _make_argname_from_tensor_name(name): 30 return re.sub(":0$", "", name).replace(":", "_o") 31 32 33def _tensor_to_argdef(t, name=None, used_names=None): 34 """Convert tensor t to an argdef, with a specified name or a unique name.""" 35 arg = op_def_pb2.OpDef.ArgDef() 36 if name is None: 37 arg.name = _make_argname_from_tensor_name(t.name) 38 if used_names is not None: 39 if arg.name in used_names: 40 i = 0 41 while True: 42 new_name = "%s_U%d" % (arg.name, i) 43 if new_name not in used_names: 44 arg.name = new_name 45 break 46 i += 1 47 used_names.add(arg.name) 48 else: 49 arg.name = name 50 arg.type = t.dtype.as_datatype_enum 51 return arg 52 53 54def _is_in_placeholders(op, func_arg_placeholders): 55 """Checks whether any output of this op is in func_arg_placeholders.""" 56 return op.values() and any(x.name in func_arg_placeholders 57 for x in op.values()) 58 59 
def _get_node_def(op):
  """Returns the `NodeDef` protocol buffer of `op`."""
  return op.node_def


def _get_op_def(op):
  """Returns the `OpDef` for `op`.

  Uses `op.op_def` when it is truthy; otherwise looks `op.type` up in the
  mapping returned by `op_def_registry.get_registered_ops()`.
  """
  return op.op_def or op_def_registry.get_registered_ops()[op.type]


def _create_input_dict(function_graph,
                       func_arg_placeholders,
                       initial_value=None):
  """Create a mapping from graph tensor names to function tensor names.

  For every non-placeholder op, each output tensor name (from `op.values()`)
  maps to "<op name>:<output arg name>:<index>". The bare op name maps to the
  same value as the op's first output. An op that produces a function input
  maps its op name to itself.

  Args:
    function_graph: Graph whose operations are scanned.
    func_arg_placeholders: Names of tensors that are function inputs.
    initial_value: Optional mapping used to seed the result. It is copied,
      not mutated.

  Returns:
    A dict from graph tensor/op names to function tensor names.
  """
  input_dict = {} if initial_value is None else dict(initial_value)
  for op in function_graph.get_operations():
    if _is_in_placeholders(op, func_arg_placeholders):
      input_dict[op.name] = op.name
      continue
    op_def = _get_op_def(op)
    attrs = _get_node_def(op).attr
    o = 0  # Flat index into op.values() across all output args.
    for arg_def in op_def.output_arg:
      # One output arg can expand to several tensors. The count comes from an
      # integer attr (`number_attr`) or from the length of a type-list attr
      # (`type_list_attr`); otherwise it is a single tensor.
      if arg_def.number_attr:
        num = attrs[arg_def.number_attr].i
      elif arg_def.type_list_attr:
        num = len(attrs[arg_def.type_list_attr].list.type)
      else:
        num = 1
      for i in range(num):
        result = "%s:%s:%d" % (op.name, arg_def.name, i)
        input_dict[op.values()[o].name] = result
        if o == 0:
          # The bare op name is an alias for its first output.
          input_dict[op.name] = result
        o += 1
  return input_dict


def _add_op_node(op, func, input_dict):
  """Converts an op to a function def node and adds it to `func`.

  Data inputs of the copied `NodeDef` are rewritten through `input_dict`.
  Control inputs (prefixed with "^") are left unchanged.

  Args:
    op: The operation to add.
    func: The `FunctionDef` being built; mutated in place.
    input_dict: Mapping from graph tensor names to function tensor names.
  """
  # Note that extend() makes a copy in this case, so rewriting the inputs
  # below does not modify `op.node_def`. See:
  # https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
  func.node_def.extend([_get_node_def(op)])
  node_def = func.node_def[-1]
  for i in range(len(node_def.input)):
    input_name = node_def.input[i]
    if input_name.startswith("^"):
      continue  # Control dependency: kept verbatim.
    assert input_name in input_dict, ("%s missing from %s" %
                                      (input_name, input_dict.items()))
    node_def.input[i] = input_dict[input_name]
  # The function is stateful if any of its operations are stateful.
  # NOTE(mrry): The "Const" node typically does not have an `OpDef` associated
  # with it, so we assume any nodes without an `OpDef` are stateless.
  # TODO(skyewm): Remove the `is not None` test after we transition to the C
  # API.
  if op.op_def is not None and op.op_def.is_stateful:
    func.signature.is_stateful = True


def graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
  """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer that contains all the ops in `operations`. The
  operations become the body of the function.

  The arguments `inputs` and `outputs` will be listed as the inputs
  and outputs tensors of the function. They must be lists of
  tensors present in the graph. The lists can optionally be empty.

  Args:
    graph: Graph.
    operations: the operations to put in the function. Must be a subset of
     the operations in the graph. Ops that produce an input tensor are not
     added to the function body.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.
    out_names: Optional list of string names for the outputs. When None,
      output names are derived from the tensor names and made unique among
      the outputs with a "_U<i>" suffix.

  Returns:
    A FunctionDef protocol buffer. Its signature name is set to "_";
    presumably the caller assigns the real name.

  Raises:
    errors_impl.InvalidArgumentError: if `out_names` is specified and its
      length differs from the length of `outputs`.
    ValueError: if `out_names` contains duplicate names.
  """
  func = function_pb2.FunctionDef()
  func.signature.name = "_"

  input_names = set()
  func.signature.input_arg.extend(
      [_tensor_to_argdef(t, used_names=input_names) for t in inputs])
  # Initializes the input map with all placeholder input tensors.
  initial_dict = {
      t.name: arg.name for t, arg in zip(inputs, func.signature.input_arg)
  }

  if out_names is None:
    # Output names are only deduplicated among the outputs themselves.
    output_names = set()
    func.signature.output_arg.extend(
        [_tensor_to_argdef(t, used_names=output_names) for t in outputs])
  elif len(outputs) != len(out_names):
    raise errors_impl.InvalidArgumentError(
        None, None,
        "output names must be either empty or equal in size to outputs. "
        "output names size = %d outputs size = %d" %
        (len(out_names), len(outputs)))
  elif len(out_names) != len(set(out_names)):
    raise ValueError(
        "Must not have duplicates in out_names: %s" % ", ".join(out_names))
  else:
    func.signature.output_arg.extend(
        [_tensor_to_argdef(t, name=n) for t, n in zip(outputs, out_names)])

  func_arg_placeholders = {t.name for t in inputs}
  input_dict = _create_input_dict(graph, func_arg_placeholders,
                                  initial_value=initial_dict)

  # Ops that produce function inputs are represented by the signature's
  # input args, so they are not copied into the body.
  for op in operations:
    if not _is_in_placeholders(op, func_arg_placeholders):
      _add_op_node(op, func, input_dict)

  # Bind each output arg name to the function tensor that produces it.
  if out_names is None:
    ret_names = [arg.name for arg in func.signature.output_arg]
  else:
    ret_names = out_names
  for ret_name, t in zip(ret_names, outputs):
    func.ret[ret_name] = input_dict[t.name]

  return func