1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/algorithm.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <vector>
21 
22 #include "tensorflow/core/platform/logging.h"
23 
24 namespace tensorflow {
25 namespace {
26 template <typename T>
DFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)27 void DFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
28                    const std::function<void(T)>& enter,
29                    const std::function<void(T)>& leave,
30                    const NodeComparator& stable_comparator,
31                    const EdgeFilter& edge_filter) {
32   // Stack of work to do.
33   struct Work {
34     T node;
35     bool leave;  // Are we entering or leaving n?
36   };
37   std::vector<Work> stack(start.size());
38   for (int i = 0; i < start.size(); ++i) {
39     stack[i] = Work{start[i], false};
40   }
41 
42   std::vector<bool> visited(g.num_node_ids(), false);
43   while (!stack.empty()) {
44     Work w = stack.back();
45     stack.pop_back();
46 
47     T n = w.node;
48     if (w.leave) {
49       leave(n);
50       continue;
51     }
52 
53     if (visited[n->id()]) continue;
54     visited[n->id()] = true;
55     if (enter) enter(n);
56 
57     // Arrange to call leave(n) when all done with descendants.
58     if (leave) stack.push_back(Work{n, true});
59 
60     auto add_work = [&visited, &stack](Node* out) {
61       if (!visited[out->id()]) {
62         // Note; we must not mark as visited until we actually process it.
63         stack.push_back(Work{out, false});
64       }
65     };
66 
67     if (stable_comparator) {
68       std::vector<Node*> nodes_sorted;
69       for (const Edge* out_edge : n->out_edges()) {
70         if (!edge_filter || edge_filter(*out_edge)) {
71           nodes_sorted.emplace_back(out_edge->dst());
72         }
73       }
74       std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
75       for (Node* out : nodes_sorted) {
76         add_work(out);
77       }
78     } else {
79       for (const Edge* out_edge : n->out_edges()) {
80         if (!edge_filter || edge_filter(*out_edge)) {
81           add_work(out_edge->dst());
82         }
83       }
84     }
85   }
86 }
87 }  // namespace
88 
// Depth-first traversal of the whole graph along out-edges, rooted at the
// graph's source node. See DFSFromHelper for callback/ordering semantics.
void DFS(const Graph& g, const std::function<void(Node*)>& enter,
         const std::function<void(Node*)>& leave,
         const NodeComparator& stable_comparator,
         const EdgeFilter& edge_filter) {
  // The single root of a forward whole-graph traversal.
  Node* const root = g.source_node();
  DFSFromHelper(g, {root}, enter, leave, stable_comparator, edge_filter);
}
96 
DFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)97 void DFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
98              const std::function<void(Node*)>& enter,
99              const std::function<void(Node*)>& leave,
100              const NodeComparator& stable_comparator,
101              const EdgeFilter& edge_filter) {
102   DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
103 }
104 
DFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)105 void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
106              const std::function<void(const Node*)>& enter,
107              const std::function<void(const Node*)>& leave,
108              const NodeComparator& stable_comparator,
109              const EdgeFilter& edge_filter) {
110   DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
111 }
112 
// Depth-first traversal of the whole graph along in-edges, rooted at the
// graph's sink node.
void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
                const std::function<void(Node*)>& leave,
                const NodeComparator& stable_comparator,
                const EdgeFilter& edge_filter) {
  // The single root of a backward whole-graph traversal.
  Node* const root = g.sink_node();
  ReverseDFSFrom(g, {root}, enter, leave, stable_comparator, edge_filter);
}
120 
121 namespace {
122 
123 template <typename T>
ReverseDFSFromHelper(const Graph & g,gtl::ArraySlice<T> start,const std::function<void (T)> & enter,const std::function<void (T)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)124 void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
125                           const std::function<void(T)>& enter,
126                           const std::function<void(T)>& leave,
127                           const NodeComparator& stable_comparator,
128                           const EdgeFilter& edge_filter) {
129   // Stack of work to do.
130   struct Work {
131     T node;
132     bool leave;  // Are we entering or leaving n?
133   };
134   std::vector<Work> stack(start.size());
135   for (int i = 0; i < start.size(); ++i) {
136     stack[i] = Work{start[i], false};
137   }
138 
139   std::vector<bool> visited(g.num_node_ids(), false);
140   while (!stack.empty()) {
141     Work w = stack.back();
142     stack.pop_back();
143 
144     T n = w.node;
145     if (w.leave) {
146       leave(n);
147       continue;
148     }
149 
150     if (visited[n->id()]) continue;
151     visited[n->id()] = true;
152     if (enter) enter(n);
153 
154     // Arrange to call leave(n) when all done with descendants.
155     if (leave) stack.push_back(Work{n, true});
156 
157     auto add_work = [&visited, &stack](T out) {
158       if (!visited[out->id()]) {
159         // Note; we must not mark as visited until we actually process it.
160         stack.push_back(Work{out, false});
161       }
162     };
163 
164     if (stable_comparator) {
165       std::vector<T> nodes_sorted;
166       for (const Edge* in_edge : n->in_edges()) {
167         if (!edge_filter || edge_filter(*in_edge)) {
168           nodes_sorted.emplace_back(in_edge->src());
169         }
170       }
171       std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
172       for (T in : nodes_sorted) {
173         add_work(in);
174       }
175     } else {
176       for (const Edge* in_edge : n->in_edges()) {
177         if (!edge_filter || edge_filter(*in_edge)) {
178           add_work(in_edge->src());
179         }
180       }
181     }
182   }
183 }
184 
185 }  // namespace
186 
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<const Node * > start,const std::function<void (const Node *)> & enter,const std::function<void (const Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)187 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
188                     const std::function<void(const Node*)>& enter,
189                     const std::function<void(const Node*)>& leave,
190                     const NodeComparator& stable_comparator,
191                     const EdgeFilter& edge_filter) {
192   ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
193 }
194 
ReverseDFSFrom(const Graph & g,gtl::ArraySlice<Node * > start,const std::function<void (Node *)> & enter,const std::function<void (Node *)> & leave,const NodeComparator & stable_comparator,const EdgeFilter & edge_filter)195 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
196                     const std::function<void(Node*)>& enter,
197                     const std::function<void(Node*)>& leave,
198                     const NodeComparator& stable_comparator,
199                     const EdgeFilter& edge_filter) {
200   ReverseDFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
201 }
202 
// Fills `*order` (after clearing it) with the graph's nodes in DFS post-order,
// starting from the source node and following out-edges.
void GetPostOrder(const Graph& g, std::vector<Node*>* order,
                  const NodeComparator& stable_comparator,
                  const EdgeFilter& edge_filter) {
  order->clear();
  // Post-order means a node is recorded when the traversal leaves it, so only
  // the `leave` callback is supplied.
  auto record_on_leave = [order](Node* node) { order->push_back(node); };
  DFS(g, /*enter=*/nullptr, /*leave=*/record_on_leave, stable_comparator,
      edge_filter);
}
210 
// Fills `*order` with the graph's nodes in reverse DFS post-order: the
// post-order computed by GetPostOrder, read back to front.
void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
                         const NodeComparator& stable_comparator,
                         const EdgeFilter& edge_filter) {
  GetPostOrder(g, order, stable_comparator, edge_filter);
  std::vector<Node*>& nodes = *order;
  std::reverse(nodes.begin(), nodes.end());
}
217 
PruneForReverseReachability(Graph * g,std::unordered_set<const Node * > start)218 bool PruneForReverseReachability(Graph* g,
219                                  std::unordered_set<const Node*> start) {
220   // Compute set of nodes that we need to traverse in order to reach
221   // the nodes in "start" by performing a breadth-first search from those
222   // nodes, and accumulating the visited nodes.
223   std::vector<bool> visited(g->num_node_ids());
224   for (auto node : start) {
225     visited[node->id()] = true;
226   }
227   std::deque<const Node*> queue(start.begin(), start.end());
228   while (!queue.empty()) {
229     const Node* n = queue.front();
230     queue.pop_front();
231     for (const Node* in : n->in_nodes()) {
232       if (!visited[in->id()]) {
233         visited[in->id()] = true;
234         queue.push_back(in);
235         VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name();
236       }
237     }
238   }
239 
240   // Make a pass over the graph to remove nodes not in "visited".
241   bool any_removed = false;
242   for (int i = 0; i < visited.size(); ++i) {
243     if (!visited[i]) {
244       Node* n = g->FindNodeId(i);
245       if (n != nullptr && !n->IsSource() && !n->IsSink()) {
246         g->RemoveNode(n);
247         any_removed = true;
248       }
249     }
250   }
251   return any_removed;
252 }
253 
FixupSourceAndSinkEdges(Graph * g)254 bool FixupSourceAndSinkEdges(Graph* g) {
255   // Connect all nodes with no incoming edges to source.
256   // Connect all nodes with no outgoing edges to sink.
257   bool changed = false;
258   for (Node* n : g->nodes()) {
259     if (!n->IsSource() && n->in_edges().empty()) {
260       g->AddControlEdge(g->source_node(), n,
261                         true /* skip test for duplicates */);
262       changed = true;
263     }
264     if (!n->IsSink() && n->out_edges().empty()) {
265       g->AddControlEdge(n, g->sink_node(), true /* skip test for duplicates */);
266       changed = true;
267     }
268   }
269   return changed;
270 }
271 
272 }  // namespace tensorflow
273