1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Utility to convert FunctionDef to GraphDef and Graph."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22
23
24from tensorflow.core.framework import function_pb2
25from tensorflow.core.framework import graph_pb2
26from tensorflow.core.framework import tensor_shape_pb2
27from tensorflow.core.framework import types_pb2
28from tensorflow.core.framework import versions_pb2
29from tensorflow.python.eager import context
30from tensorflow.python.framework import cpp_shape_inference_pb2
31from tensorflow.python.framework import importer
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import versions
34from tensorflow.python.framework.func_graph import FuncGraph
35from tensorflow.python.ops import resource_variable_ops
36
37
def function_def_to_graph(fdef, input_shapes=None):
  """Converts a FunctionDef to a FuncGraph (sub-class Graph).

  The returned FuncGraph's `name`, `inputs` and `outputs` fields will be set.
  The input tensors are represented as placeholders.

  Note: `FuncGraph.captures` is not set and may be set by the caller.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. Defaults to the function's "_input_shapes" attribute. If
      specified, its length must match length of `fdef.signature.input_arg`. If
      a shape is None, the corresponding input placeholder will have unknown
      shape.

  Returns:
    A FuncGraph.
  """
  func_graph = FuncGraph(fdef.signature.name)
  if input_shapes is None:
    # Fall back to the shapes serialized on the FunctionDef itself, if any.
    input_shapes_attr = fdef.attr.get("_input_shapes", None)
    if input_shapes_attr is not None:
      input_shapes = input_shapes_attr.list.shape
  # The GraphDef contains Placeholder nodes for the signature inputs plus the
  # function body; the map translates nested (FunctionDef) tensor names to
  # flat (GraphDef) tensor names.
  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
      fdef, input_shapes)

  with func_graph.as_default():
    # Add all function nodes to the graph.
    importer.import_graph_def_for_function(graph_def, name="")

    # Initialize fields specific to FuncGraph.

    # inputs: each signature input arg was turned into a Placeholder; resolve
    # its flattened tensor name back to the imported graph's tensor.
    input_tensor_names = [
        nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
    ]
    func_graph.inputs = [
        func_graph.get_tensor_by_name(name) for name in input_tensor_names
    ]

    # outputs: `fdef.ret` maps each signature output arg name to the nested
    # tensor name that produces it.
    output_tensor_names = [
        nested_to_flat_tensor_name[fdef.ret[arg.name]]
        for arg in fdef.signature.output_arg
    ]
    func_graph.outputs = [
        func_graph.get_tensor_by_name(name) for name in output_tensor_names
    ]
    # control outputs: operations named in `fdef.signature.control_output`,
    # resolved through the `fdef.control_ret` name map.
    func_graph.control_outputs = [
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in fdef.signature.control_output
    ]

    # Propagate resource handle data (shape/dtype) onto input/output tensors.
    _set_handle_data(func_graph, fdef)

    # Restore per-op output shapes from the serialized "_output_shapes" attrs.
    for node in graph_def.node:
      output_shapes = node.attr.get("_output_shapes", None)
      if output_shapes is not None:
        op = func_graph.get_operation_by_name(node.name)
        # _output_shapes for functions can sometimes be too long because the
        # output-intermediates-for-gradients version of the function was
        # substituted before saving. We'll accept that here. (See b/133666530).
        for output_index, shape in enumerate(
            output_shapes.list.shape[:len(op.outputs)]):
          op.outputs[output_index].set_shape(shape)
    # Record the signature's output-arg name for each output tensor, keyed by
    # tensor id, on the private `_output_names` field.
    output_names = {}
    for ret_arg_def, tensor_name in zip(
        fdef.signature.output_arg, output_tensor_names):
      output_names[ops.tensor_id(
          func_graph.get_tensor_by_name(tensor_name))] = (
              ret_arg_def.name)
    func_graph._output_names = output_names  # pylint: disable=protected-access
  return func_graph
113
114
def is_function(fname):
  """Checks for a function definition with `fname` in the current context.

  In eager mode the eager context's function registry is consulted; otherwise
  the default graph and all of its enclosing (outer) graphs are searched.

  Args:
    fname: String name of the function to look up.

  Returns:
    True if a function with that name is registered, False otherwise.
  """
  if context.executing_eagerly():
    return context.context().has_function(fname)
  graph = ops.get_default_graph()
  while graph is not None:
    if graph._is_function(fname):  # pylint: disable=protected-access
      return True
    # FuncGraphs expose `outer_graph` (which may itself be None); a plain
    # Graph does not, in which case the search stops here.
    graph = getattr(graph, "outer_graph", None)
  # Fix: the original fell off the loop and implicitly returned None when the
  # outer_graph chain ended in None; always return an explicit bool.
  return False
128
129
def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args or if the FunctionDef is invalid.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  default_graph = ops.get_default_graph()

  # Names of functions already copied into graph_def.library, so each
  # referenced function is copied at most once.
  copied_functions = set()

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of input_shapes must match the number of " +
                     "input_args. len(input_shapes): {} len(input_arg): {}".
                     format(len(input_shapes), len(fdef.signature.input_arg)))

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      input_shape = input_shapes[i]
      # Accept either a TensorShapeProto directly or a TensorShape-like
      # object that can convert itself via as_proto().
      if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto):
        input_shape = input_shape.as_proto()
      node_def.attr["shape"].shape.CopyFrom(input_shape)
    arg_attrs = fdef.arg_attr[i].attr
    for k in arg_attrs:
      # Only copy internal attributes. Normal attributes for nodes cannot be
      # applied to these Placeholder nodes.
      if k == "_output_shapes":
        node_def.attr["shape"].shape.CopyFrom(arg_attrs[k].list.shape[0])
      elif k.startswith("_"):
        node_def.attr[k].CopyFrom(arg_attrs[k])

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  # Signature inputs: "arg_name" -> "arg_name:0" (output 0 of the
  # Placeholder created above); control inputs map to themselves.
  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
    control_name = "^" + arg_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  for node_def in fdef.node_def:
    # Look for the op type as a function registered on the default graph or
    # any of its enclosing (outer) graphs.
    graph = default_graph
    while True:
      f = graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
      if f is not None or not hasattr(graph, "outer_graph"):
        break
      graph = graph.outer_graph

    if f is not None:
      op_def = f.definition.signature
      if node_def.op not in copied_functions:
        # Since this function is referenced as an op type, we have no choice but
        # to copy it into the GraphDef if we want downstream tools to process
        # it.
        graph_def.library.function.add().CopyFrom(f.definition)
        copied_functions.add(node_def.op)
        if f.grad_func_name:
          grad_def = function_pb2.GradientDef()
          grad_def.function_name = f.name
          grad_def.gradient_func = f.grad_func_name
          graph_def.library.gradient.extend([grad_def])
    else:
      # Not a function: resolve the regular op definition from the registry.
      op_def = default_graph._get_op_def(node_def.op)  # pylint: disable=protected-access

    # Validate that every function-valued attr references a known function.
    for attr in op_def.attr:
      if attr.type == "func":
        fname = node_def.attr[attr.name].func.name
        if not is_function(fname):
          raise ValueError("%s function not found." % fname)
      elif attr.type == "list(func)":
        for fn in node_def.attr[attr.name].list.func:
          fname = fn.name
          if not is_function(fname):
            raise ValueError("%s function not found." % fname)

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1
    control_name = "^" + node_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  # Update inputs of all nodes in graph.
  for node_def in graph_def.node:
    for i in range(len(node_def.input)):
      node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]

  return graph_def, nested_to_flat_tensor_name
263
264
265# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange.
266def _get_num_args(arg_def, node_def):
267  if arg_def.number_attr:
268    return node_def.attr[arg_def.number_attr].i
269  elif arg_def.type_list_attr:
270    return len(node_def.attr[arg_def.type_list_attr].list.type)
271  elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
272    return 1
273  else:
274    raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
275
276
def _set_handle_data(func_graph, fdef):
  """Adds handle data for resource type inputs and outputs."""
  input_pairs = zip(func_graph.inputs, fdef.signature.input_arg)
  output_pairs = zip(func_graph.outputs, fdef.signature.output_arg)
  for tensor, arg_def in itertools.chain(input_pairs, output_pairs):
    if not arg_def.handle_data:
      continue
    # Only the first (shape, dtype) entry of the serialized handle data is
    # propagated onto the tensor.
    first = arg_def.handle_data[0]
    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData(
        is_set=True)
    handle_data.shape_and_type.append(
        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
            shape=first.shape, dtype=first.dtype))
    resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
        tensor, handle_data, True)
291