1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
"""Utility to convert FunctionDef to GraphDef and Graph."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.framework import graph_pb2
22from tensorflow.core.framework import types_pb2
23from tensorflow.core.framework import versions_pb2
24from tensorflow.python.eager import context
25from tensorflow.python.framework import importer
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import versions
28from tensorflow.python.framework.func_graph import FuncGraph
29
30
def function_def_to_graph(fdef, input_shapes=None):
  """Converts a FunctionDef to a FuncGraph (sub-class Graph).

  The returned FuncGraph's `name`, `inputs`, `outputs` and `control_outputs`
  fields will be set. The input tensors are represented as placeholders.

  Note: `FuncGraph.captures` is not set and may be set by the caller.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A FuncGraph.

  Raises:
    ValueError: Propagated from `function_def_to_graph_def`, e.g. when the
      length of `input_shapes` does not match the number of input args.
  """
  func_graph = FuncGraph(fdef.signature.name)
  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
      fdef, input_shapes)

  with func_graph.as_default():
    # Add all function nodes to the graph.
    importer.import_graph_def(graph_def, name="")

    # Initialize fields specific to FuncGraph.

    # inputs: one placeholder tensor per signature input arg, in signature
    # order. Names are resolved first, then looked up in the imported graph.
    input_tensor_names = [
        nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
    ]
    func_graph.inputs = [
        func_graph.get_tensor_by_name(name) for name in input_tensor_names
    ]

    # outputs: `fdef.ret` maps each output arg name to a FunctionDef-style
    # (nested) tensor name, which is translated to its flat GraphDef name.
    output_tensor_names = [
        nested_to_flat_tensor_name[fdef.ret[arg.name]]
        for arg in fdef.signature.output_arg
    ]
    func_graph.outputs = [
        func_graph.get_tensor_by_name(name) for name in output_tensor_names
    ]
    # control_outputs: `fdef.control_ret` maps each control output name to
    # the name of an operation in the function body.
    func_graph.control_outputs = [
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in fdef.signature.control_output
    ]

  return func_graph
82
83
def _is_function(fname):
  """Returns whether a function named `fname` is defined in the current context.

  When executing eagerly, the eager context is asked; otherwise the lookup is
  delegated to the current default graph.
  """
  if not context.executing_eagerly():
    default_graph = ops.get_default_graph()
    return default_graph._is_function(fname)  # pylint: disable=protected-access
  return context.context().has_function(fname)
90
91
def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args, if a function referenced through a `func` or `list(func)`
      attr is not defined in the current context, or if an op's output arg
      definition is invalid.
  """
  # The default graph does not change while this function runs, so fetch it
  # once instead of once per node.
  default_graph = ops.get_default_graph()

  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  # Copy *all* functions from outer graph to `graph_def` so that both direct
  # and indirect references are safely handled.
  default_graph._copy_functions_to_graph_def(graph_def, 0)  # pylint: disable=protected-access

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of input_shapes must match the number of " +
                     "input_args. len(input_shapes): {} len(input_arg): {}".
                     format(len(input_shapes), len(fdef.signature.input_arg)))

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto())

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  # Function inputs become placeholders, whose only output is index 0.
  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
    control_name = "^" + arg_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  for node_def in fdef.node_def:
    op_def = default_graph._get_op_def(node_def.op)  # pylint: disable=protected-access
    _check_referenced_functions(node_def, op_def)

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      for i in range(_get_num_args(arg_def, node_def)):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1
    control_name = "^" + node_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  # Update inputs of all nodes in graph. Iterate over a snapshot of the names
  # because entries are overwritten in place.
  for node_def in graph_def.node:
    for i, input_name in enumerate(list(node_def.input)):
      node_def.input[i] = nested_to_flat_tensor_name[input_name]

  return graph_def, nested_to_flat_tensor_name


def _check_referenced_functions(node_def, op_def):
  """Raises ValueError if a func/list(func) attr names an undefined function."""
  for attr in op_def.attr:
    # Test membership instead of indexing `node_def.attr` directly: indexing a
    # protobuf message map with a missing key inserts an empty AttrValue,
    # which would mutate the caller's FunctionDef.
    present = attr.name in node_def.attr
    if attr.type == "func":
      # A missing attr still checks the empty name "", exactly as reading
      # the default (empty) AttrValue did.
      fnames = [node_def.attr[attr.name].func.name if present else ""]
    elif attr.type == "list(func)":
      fnames = ([fn.name for fn in node_def.attr[attr.name].list.func]
                if present else [])
    else:
      continue
    for fname in fnames:
      if not _is_function(fname):
        raise ValueError("%s function not found." % fname)
193
194
195# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange.
196def _get_num_args(arg_def, node_def):
197  if arg_def.number_attr:
198    return node_def.attr[arg_def.number_attr].i
199  elif arg_def.type_list_attr:
200    return len(node_def.attr[arg_def.type_list_attr].list.type)
201  elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
202    return 1
203  else:
204    raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
205