1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Utility to convert a Graph to a FunctionDef."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import re
22
23from tensorflow.core.framework import function_pb2
24from tensorflow.core.framework import op_def_pb2
25from tensorflow.python.framework import errors_impl
26from tensorflow.python.framework import op_def_registry
27
28
29def _make_argname_from_tensor_name(name):
30  return re.sub(":0$", "", name).replace(":", "_o")
31
32
def _tensor_to_argdef(t, name=None, used_names=None):
  """Build an `OpDef.ArgDef` describing tensor `t`.

  Args:
    t: Tensor-like object; its `.name` and `.dtype.as_datatype_enum` are read.
    name: Explicit argument name. When given it is used verbatim and
      `used_names` is neither consulted nor updated.
    used_names: Optional set of names already taken. When `name` is None and
      this set is provided, a derived name that collides gets the suffix
      "_U<k>" (smallest non-negative k not already taken). The final name is
      then added to the set.

  Returns:
    The newly created ArgDef.
  """
  arg = op_def_pb2.OpDef.ArgDef()
  if name is not None:
    arg.name = name
  else:
    candidate = _make_argname_from_tensor_name(t.name)
    if used_names is not None:
      if candidate in used_names:
        suffix = 0
        while "%s_U%d" % (candidate, suffix) in used_names:
          suffix += 1
        candidate = "%s_U%d" % (candidate, suffix)
      used_names.add(candidate)
    arg.name = candidate
  arg.type = t.dtype.as_datatype_enum
  return arg
52
53
54def _is_in_placeholders(op, func_arg_placeholders):
55  """Checks whether any output of this op is in func_arg_placeholders."""
56  return op.values() and any(x.name in func_arg_placeholders
57                             for x in op.values())
58
59
60def _get_node_def(op):
61  return op.node_def  # pylint: disable=protected-access
62
63
64def _get_op_def(op):
65  return op.op_def or op_def_registry.get_registered_ops()[op.type]
66
67
def _create_input_dict(function_graph,
                       func_arg_placeholders,
                       initial_value=None):
  """Create a mapping from graph tensor names to function tensor names.

  Args:
    function_graph: Graph whose operations are enumerated with
      `get_operations()`.
    func_arg_placeholders: Container of tensor names that are function inputs.
    initial_value: Optional mapping that is copied into the result before
      any operations are processed. It is not mutated.

  Returns:
    A dict keyed by graph names:
      * For an op with an output in `func_arg_placeholders`, the op name
        maps to itself.
      * For every other op, each output tensor name maps to
        "<op name>:<output arg name>:<index within that arg>". The bare op
        name maps to the entry for the op's first output.
  """
  input_dict = {} if initial_value is None else dict(initial_value)
  for op in function_graph.get_operations():
    if _is_in_placeholders(op, func_arg_placeholders):
      input_dict[op.name] = op.name
      continue

    op_def = _get_op_def(op)
    attrs = _get_node_def(op).attr
    outputs = op.values()
    # Position of the current output within the op's flat output list; each
    # output_arg may expand to several tensors.
    flat_index = 0
    for arg_def in op_def.output_arg:
      # An output arg expands to N tensors (number_attr), one tensor per
      # listed type (type_list_attr), or a single tensor.
      if arg_def.number_attr:
        count = attrs[arg_def.number_attr].i
      elif arg_def.type_list_attr:
        count = len(attrs[arg_def.type_list_attr].list.type)
      else:
        count = 1
      for k in range(count):
        function_name = "%s:%s:%d" % (op.name, arg_def.name, k)
        input_dict[outputs[flat_index].name] = function_name
        if flat_index == 0:
          # A bare op name in a NodeDef input refers to its first output.
          input_dict[op.name] = function_name
        flat_index += 1
  return input_dict
97
98
def _add_op_node(op, func, input_dict):
  """Converts an op to a function def node and add it to `func`.

  Appends a copy of `op`'s NodeDef to `func.node_def`. Each data input (any
  input not prefixed with "^") of the copy is rewritten through
  `input_dict`. The function signature is marked stateful if `op.op_def`
  exists and is stateful.

  Args:
    op: The operation to convert.
    func: The `FunctionDef` being built; mutated in place.
    input_dict: Mapping from graph tensor names to function tensor names,
      such as the one returned by `_create_input_dict`.

  Raises:
    ValueError: if a data input of `op` has no entry in `input_dict`. This
      is an explicit check rather than an `assert`, so it still applies
      when Python runs with -O.
  """
  # Note that extend() makes a copy in this case, see:
  # https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
  # Hence the stored copy is fetched back from the repeated field for editing.
  func.node_def.extend([_get_node_def(op)])
  node_def = func.node_def[-1]
  # Iterate over a snapshot because entries are reassigned inside the loop.
  for i, graph_input in enumerate(list(node_def.input)):
    if graph_input.startswith("^"):
      # Control inputs are copied through unchanged.
      continue
    if graph_input not in input_dict:
      raise ValueError("%s missing from %s" %
                       (graph_input, input_dict.items()))
    node_def.input[i] = input_dict[graph_input]
  # The function is stateful if any of its operations are stateful.
  # NOTE(mrry): The "Const" node typically does not have an `OpDef` associated
  # with it, so we assume any nodes without an `OpDef` are stateless.
  # TODO(skyewm): Remove the `is not None` test after we transition to the C
  # API.
  if op.op_def is not None and op.op_def.is_stateful:
    func.signature.is_stateful = True
120
121
def graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
  """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer that contains all the ops in `operations`.  The
  operations become the body of the function.

  The arguments `inputs` and `outputs` will be listed as the inputs
  and outputs tensors of the function.  They must be lists of
  tensors present in the graph.  The lists can optionally be empty.

  Args:
    graph: Graph.
    operations: the operations to put in the function. Must be a subset of
     the operations in the graph.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.
    out_names: Optional list of string names for the outputs. If None or
      empty, output names are derived from the output tensor names.

  Returns:
    A FunctionDef protocol buffer.

  Raises:
    InvalidArgumentError: if `out_names` is non-empty and its length differs
      from the length of `outputs`.
    ValueError: if `out_names` contains duplicate names.
  """
  # The size-mismatch error below promises that empty output names are
  # accepted, so an empty `out_names` is treated like None (generated names).
  if out_names is not None and len(out_names) == 0:
    out_names = None

  func = function_pb2.FunctionDef()
  func.signature.name = "_"
  used_names = set()
  func.signature.input_arg.extend(
      [_tensor_to_argdef(i, used_names=used_names) for i in inputs])
  # Initializes the input map with all placeholder input tensors.
  initial_dict = {t.name: arg.name
                  for t, arg in zip(inputs, func.signature.input_arg)}
  if out_names is None:
    # Generated output names are uniquified independently of input names.
    used_names = set()
    func.signature.output_arg.extend(
        [_tensor_to_argdef(o, used_names=used_names) for o in outputs])
  elif len(outputs) != len(out_names):
    raise errors_impl.InvalidArgumentError(
        None, None,
        "output names must be either empty or equal in size to outputs. "
        "output names size = %d outputs size = %d" %
        (len(out_names), len(outputs)))
  elif len(out_names) != len(set(out_names)):
    raise ValueError(
        "Must not have duplicates in out_names: %s" % ", ".join(out_names))
  else:
    func.signature.output_arg.extend(
        [_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
  func_arg_placeholders = {t.name for t in inputs}
  input_dict = _create_input_dict(graph, func_arg_placeholders,
                                  initial_value=initial_dict)

  for op in operations:
    # Ops producing a function input are not copied into the function body.
    if _is_in_placeholders(op, func_arg_placeholders):
      continue
    _add_op_node(op, func, input_dict)

  # Every output arg now carries its final name (generated or taken from
  # out_names), so both cases map arg name -> function-internal tensor name.
  for arg, o in zip(func.signature.output_arg, outputs):
    func.ret[arg.name] = input_dict[o.name]

  return func
191