# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to manipulate a tensor graph in python.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import six

# pylint: disable=unused-import
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.framework.graph_util_impl import _node_name


__all__ = ["fuse_op", "get_placeholders"]


def _check_fused_node_inputs(start, name_to_input_name, reachable_by_input,
                             input_nodes_set, checked):
  """Raises TypeError if fused node `start` depends on an unlisted input.

  Walks the ancestors of `start` breadth-first, without expanding past nodes
  in `input_nodes_set`. An ancestor that lies in `reachable_by_input` but is
  not itself listed in `input_nodes_set` is rejected, because the fused op
  would only be wired to the listed input nodes.

  Every node enqueued is added to `checked`. Nodes already in `checked` are
  skipped: a previous walk that finished without raising visited them and
  their whole ancestry (up to the input nodes) and found nothing to reject.

  Args:
    start: name of a node strictly between the inputs and outputs.
    name_to_input_name: dict mapping a node name to its input node names.
    reachable_by_input: set of nodes up to and including the input nodes.
    input_nodes_set: set of the names in `input_nodes`.
    checked: set shared across calls; updated in place.

  Raises:
    TypeError: If an ancestor of `start` is upstream of (or equal to) an
      input path but not in `input_nodes_set`.
  """
  queue = collections.deque([start])
  checked.add(start)
  while queue:
    cur_node = queue.popleft()
    if cur_node in input_nodes_set:
      # Listed inputs are the boundary of the fused subgraph; stop here.
      continue
    if cur_node in reachable_by_input:
      raise TypeError("Node %s uses input %s not in input_nodes." %
                      (start, cur_node))
    for input_node in name_to_input_name[cur_node]:
      if input_node not in checked:
        checked.add(input_node)
        queue.append(input_node)


def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
  """Fuse subgraph between input_nodes and output_nodes into a single custom op.

  The returned graph contains, in order:
    1. every node up to and including `input_nodes`, sorted by the sequence
       number from `_extract_graph_summary` (presumably the original GraphDef
       order -- not verified here);
    2. one new node named `op_name` of type `op_type`, whose inputs are
       `input_nodes`;
    3. a copy of each output node, rewired to read output `i` of the new node
       (where `i` is its position in `output_nodes`);
    4. every remaining node that is neither upstream of the outputs nor
       upstream of the inputs, in GraphDef order.
  Nodes strictly between the inputs and the outputs are dropped.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    input_nodes: list (or tuple) of names of the input nodes to the subgraph
      to be fused.
    output_nodes: list (or tuple) of names of the output nodes of the subgraph
      to be fused. Each must have exactly one input.
    output_dtypes: A list of output datatypes for the custom op, stored in its
      `_output_types` attr.
    output_quantized: A boolean flag that indicates if output is quantized,
      stored in the custom op's `_output_quantized` attr.
    op_name: fused op name.
    op_type: fused op type.
  Returns:
    The GraphDef of the new graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
      `input_nodes` or `output_nodes` is a string, or if a node inside the
      fused subgraph depends on a node on the input side that is not listed
      in `input_nodes`.
    ValueError: If an output node does not have exactly one input.
  """

  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

  if isinstance(input_nodes, six.string_types):
    raise TypeError("input_nodes must be a list.")

  if isinstance(output_nodes, six.string_types):
    raise TypeError("output_nodes must be a list.")

  # Normalize so tuples are accepted and mixed sequence types concatenate.
  input_nodes = list(input_nodes)
  output_nodes = list(output_nodes)

  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

  # Nodes up to and including input_nodes.
  reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name)
  # Nodes up to and including output_nodes.
  reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                 name_to_input_name)

  input_nodes_set = set(input_nodes)
  output_nodes_set = set(output_nodes)

  # Shared across all fused nodes so each ancestor is walked at most once.
  checked = set()
  nodes_post_output = []
  for node in graph_def.node:
    n = _node_name(node.name)
    if n in reachable_by_output:
      if (n not in reachable_by_input and n not in output_nodes_set and
          n not in checked):
        # n is between input and output, i.e., part of the fused op.
        _check_fused_node_inputs(n, name_to_input_name, reachable_by_input,
                                 input_nodes_set, checked)
    elif n not in reachable_by_input:
      nodes_post_output.append(n)

  out = graph_pb2.GraphDef()

  # Add all nodes up to and including the input nodes.
  for n in sorted(reachable_by_input, key=lambda name: name_to_seq_num[name]):
    out.node.extend([copy.deepcopy(name_to_node[n])])

  # Add the custom op, consuming every input node.
  new_node = node_def_pb2.NodeDef()
  new_node.input.extend(input_nodes)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  new_node.attr["_output_quantized"].b = output_quantized
  new_node.op = op_type
  new_node.name = op_name
  out.node.extend([new_node])

  # Add the output nodes, each rewired to one output of the custom op.
  for index, n in enumerate(output_nodes):
    num_inputs = len(name_to_node[n].input)
    if num_inputs != 1:
      # Explicit check (not assert) so it survives `python -O`; otherwise the
      # input list below would be silently truncated.
      raise ValueError("Output node %s must have exactly one input, has %d." %
                       (n, num_inputs))
    new_node = copy.deepcopy(name_to_node[n])
    del new_node.input[:]
    # Output 0 is referenced by the bare op name, output i > 0 as "name:i".
    new_node.input.append(op_name + (":" + str(index) if index != 0 else ""))
    out.node.extend([new_node])

  # Add the nodes post output_nodes.
  for n in nodes_post_output:
    out.node.extend([copy.deepcopy(name_to_node[n])])

  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  return out


def get_placeholders(graph):
  """Get placeholders of a graph.

  For example:

  ```python
  a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a')
  b = tf.placeholder(dtype=tf.int32, shape=[3, 2], name='b')

  tf.contrib.framework.get_placeholders(tf.get_default_graph())
  # Returns:
  #  [<tf.Tensor 'a:0' shape=(2, 2) dtype=float32>,
  #   <tf.Tensor 'b:0' shape=(3, 2) dtype=int32>]
  ```

  Args:
    graph: A tf.Graph.
  Returns:
    A list containing the first output of every operation of type
    "Placeholder" in `graph`, in the order `graph.get_operations()` yields
    them.

  Raises:
    TypeError: If `graph` is not a tensorflow graph.
  """

  if not isinstance(graph, ops.Graph):
    # Wrap in a 1-tuple so a tuple-valued `graph` cannot break the formatting.
    raise TypeError("Input graph needs to be a Graph: %s" % (graph,))

  # For each placeholder() call, there is a corresponding operation of type
  # 'Placeholder' registered to the graph; the Tensor placeholder() returns
  # is the first output of that operation.
  return [op.outputs[0] for op in graph.get_operations()
          if op.type == "Placeholder"]