1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers to manipulate a tensor graph in python.
16"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21import copy
22import six
23
24# pylint: disable=unused-import
25from tensorflow.core.framework import graph_pb2
26from tensorflow.core.framework import node_def_pb2
27from tensorflow.python.framework import ops
28from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present
29from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
30from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
31from tensorflow.python.framework.graph_util_impl import _node_name
32
33
34__all__ = ["fuse_op", "get_placeholders"]
35
36
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
  """Fuse subgraph between input_nodes and output_nodes into a single custom op.

  The nodes strictly between `input_nodes` and `output_nodes` are removed and
  replaced by a single node named `op_name` of type `op_type` that consumes
  `input_nodes` directly. Each node in `output_nodes` is kept but rewired so
  its sole input is the corresponding output slot of the new fused op.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    input_nodes: input nodes to the subgraph to be fused.
    output_nodes: output nodes to the subgraph to be fused.
    output_dtypes: A list of output datatypes for the custom op
    output_quantized: A boolean flag that indicates if output is quantized
    op_name: fused op name.
    op_type: fused op type.
  Returns:
    The GraphDef of the new graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
      `input_nodes`/`output_nodes` are passed as single strings instead of
      lists, or if a node inside the fused region consumes a node that is
      reachable from the inputs but not listed in `input_nodes`.
  """

  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

  # A bare string would silently iterate per-character below, so reject it.
  if isinstance(input_nodes, six.string_types):
    raise TypeError("input_nodes must be a list.")

  if isinstance(output_nodes, six.string_types):
    raise TypeError("output_nodes must be a list.")

  # name_to_input_name: node name -> list of its input node names.
  # name_to_node: node name -> NodeDef.
  # name_to_seq_num: node name -> original position in graph_def.node.
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

  # Nodes up to and including input_nodes (everything the inputs depend on).
  reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name)
  # Nodes up to and including output_nodes (everything the outputs depend on).
  reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                 name_to_input_name)

  # Set of nodes in the list input_nodes
  input_nodes_set = set(input_nodes)

  # Set of nodes in the list output_nodes
  output_nodes_set = set(output_nodes)

  # Nodes that depend on neither inputs nor outputs; they are copied through
  # unchanged after the fused op is emitted.
  nodes_post_output = []
  for node in graph_def.node:
    n = _node_name(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        # n is between input and output, i.e., part of the fused op.
        # Walk n's transitive inputs (stopping at declared input_nodes) to
        # verify the fused region is sealed: it must not consume any
        # input-side node other than the declared input_nodes.
        next_to_visit = [n]
        visited = set()
        while next_to_visit:
          cur_node = next_to_visit[0]
          visited.add(cur_node)
          del next_to_visit[0]
          if cur_node in reachable_by_input and cur_node not in input_nodes_set:
            raise TypeError("Node %s uses input %s not in input_nodes." %
                            (n, cur_node))
          if cur_node not in input_nodes_set:
            next_to_visit += [
                input_node for input_node in name_to_input_name[cur_node]
                if input_node not in visited
            ]
    elif n not in reachable_by_input:
      nodes_post_output.append(n)

  # Add all nodes up to the input nodes, preserving their original order.
  out = graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      list(reachable_by_input), key=lambda n: name_to_seq_num[n])
  for node in reachable_by_input_sorted:
    out.node.extend([copy.deepcopy(name_to_node[node])])

  # Add the custom op
  new_node = node_def_pb2.NodeDef()
  for node in input_nodes:
    new_node.input.append(node)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  new_node.attr["_output_quantized"].b = output_quantized
  new_node.op = op_type
  new_node.name = op_name
  out.node.extend([new_node])

  # Add the nodes in the output of the custom op, each rewired to read the
  # corresponding output slot of the fused op ("op_name" for slot 0,
  # "op_name:i" otherwise).
  for index, n in enumerate(output_nodes):
    assert len(name_to_node[n].input) == 1
    new_node = copy.deepcopy(name_to_node[n])
    del new_node.input[:]
    new_node.input.append(op_name + (":" + str(index) if index != 0 else ""))
    out.node.extend([new_node])

  # Add the nodes post output_nodes
  for n in nodes_post_output:
    out.node.extend([copy.deepcopy(name_to_node[n])])

  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  return out
136
137
def get_placeholders(graph):
  """Return the placeholder tensors of a graph.

  For example:

  ```python
  a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a')
  a = tf.placeholder(dtype=tf.int32, shape=[3, 2], name='b')

  tf.contrib.framework.get_placeholders(tf.get_default_graph())
  # Returns:
  #  [<tf.Tensor 'a:0' shape=(2, 2) dtype=float32>,
  #   <tf.Tensor 'b:0' shape=(3, 2) dtype=int32>]
  ```

  Args:
    graph: A tf.Graph.
  Returns:
    A list contains all placeholders of given graph.

  Raises:
    TypeError: If `graph` is not a tensorflow graph.
  """

  if not isinstance(graph, ops.Graph):
    raise TypeError("Input graph needs to be a Graph: %s" % graph)

  # Every tf.placeholder() call registers one operation of type
  # 'Placeholder' in the graph; the placeholder tensor itself is that
  # operation's first (and only) output.
  return [
      op.outputs[0]
      for op in graph.get_operations()
      if op.type == "Placeholder"
  ]
172