1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/utils/traversal.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/grappler/graph_topology_view.h"
21 
22 namespace tensorflow {
23 namespace grappler {
24 
25 namespace {
26 
struct DfsStackElem {
  // Entry for a traversal start node: there is no predecessor (src == -1) and
  // none of its input/output nodes have been pushed yet.
  explicit DfsStackElem(int node)
      : DfsStackElem(node, /*children_visited=*/false, /*src=*/-1) {}

  // Fully specified entry.
  DfsStackElem(int node, bool children_visited, int src)
      : node(node), children_visited(children_visited), src(src) {}

  // Graph index of the node this entry refers to, in [0, num_nodes).
  int node;
  // True when every input/output node of `node` has already been pushed onto
  // the stack, so popping this entry means the node is finished.
  bool children_visited;
  // Graph index of the node from which `node` was reached; -1 for start nodes.
  int src;
};
40 
// Traversal status of a node. kNotVisited must remain the zero-valued
// enumerator: `node_state[idx]` value-initializes entries for unseen nodes.
enum class NodeState { kNotVisited = 0, kVisiting, kDone };
42 
43 }  // namespace
44 
DfsTraversal(const GraphTopologyView & graph_view,const absl::Span<const NodeDef * const> from,const TraversalDirection direction,const DfsPredicates & predicates,const DfsCallbacks & callbacks)45 void DfsTraversal(const GraphTopologyView& graph_view,
46                   const absl::Span<const NodeDef* const> from,
47                   const TraversalDirection direction,
48                   const DfsPredicates& predicates,
49                   const DfsCallbacks& callbacks) {
50   std::vector<DfsStackElem> stack;
51   stack.reserve(from.size());
52 
53   for (const NodeDef* node : from) {
54     const absl::optional<int> node_idx = graph_view.GetNodeIndex(*node);
55     DCHECK(node_idx.has_value()) << "Illegal start node: " << node->name();
56     if (node_idx.has_value()) {
57       stack.emplace_back(node_idx.value());
58     }
59   }
60 
61   absl::flat_hash_map<int, NodeState> node_state;
62   while (!stack.empty()) {
63     DfsStackElem w = stack.back();
64     stack.pop_back();
65 
66     NodeState& state = node_state[w.node];
67     if (state == NodeState::kDone) continue;
68 
69     // Skip nodes that we should not enter.
70     if (predicates.enter && !predicates.enter(graph_view.GetNode(w.node))) {
71       state = NodeState::kDone;
72       continue;
73     }
74 
75     // We've processed all the children of this node.
76     if (w.children_visited) {
77       state = NodeState::kDone;
78       if (callbacks.post_order) {
79         callbacks.post_order(graph_view.GetNode(w.node));
80       }
81       continue;
82     }
83 
84     // Loop detected.
85     if (state == NodeState::kVisiting) {
86       if (callbacks.on_back_edge) {
87         callbacks.on_back_edge(graph_view.GetNode(w.src),
88                                graph_view.GetNode(w.node));
89       }
90       continue;
91     }
92 
93     state = NodeState::kVisiting;
94     if (callbacks.pre_order) {
95       callbacks.pre_order(graph_view.GetNode(w.node));
96     }
97 
98     // Enqueue the node again with the children_visited flag set to true.
99     stack.emplace_back(w.node, true, w.src);
100 
101     // Check if we can continue traversal from the current node.
102     if (predicates.advance && !predicates.advance(graph_view.GetNode(w.node))) {
103       continue;
104     }
105 
106     // Now enqueue the fanin/fanout nodes.
107     if (direction == TraversalDirection::kFollowInputs) {
108       for (const int fanin : graph_view.GetFanin(w.node)) {
109         stack.emplace_back(fanin, false, w.node);
110       }
111     } else {
112       for (const int fanout : graph_view.GetFanout(w.node)) {
113         stack.emplace_back(fanout, false, w.node);
114       }
115     }
116   }
117 }
118 
DfsTraversal(const GraphTopologyView & graph_view,const absl::Span<const NodeDef * const> from,TraversalDirection direction,const DfsCallbacks & callbacks)119 void DfsTraversal(const GraphTopologyView& graph_view,
120                   const absl::Span<const NodeDef* const> from,
121                   TraversalDirection direction, const DfsCallbacks& callbacks) {
122   DfsTraversal(graph_view, from, direction, {}, callbacks);
123 }
124 
125 }  // namespace grappler
126 }  // namespace tensorflow
127