1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/utils/traversal.h"
17
18 #include "absl/container/flat_hash_map.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/grappler/graph_topology_view.h"
21
22 namespace tensorflow {
23 namespace grappler {
24
25 namespace {
26
// One entry of the explicit DFS stack. A node is pushed twice while it is
// being traversed: once to enter it (`children_visited == false`) and once
// more, below its children, to emit its post-order event
// (`children_visited == true`).
struct DfsStackElem {
  DfsStackElem(int node, bool children_visited, int src)
      : node{node}, children_visited{children_visited}, src{src} {}

  // Start-of-traversal entry: not yet expanded and reached from no node.
  explicit DfsStackElem(int node)
      : DfsStackElem{node, /*children_visited=*/false, /*src=*/-1} {}

  // Index of the node in the graph, within [0, num_nodes).
  int node;
  // True once every input/output neighbour of `node` has been pushed onto the
  // stack, i.e. this entry stands for the node's post-order visit.
  bool children_visited;
  // Index of the node from which `node` was reached; -1 for start nodes.
  int src;
};
40
// DFS colouring of a node. `kNotVisited` must stay the zero value: hash-map
// `operator[]` value-initializes missing entries, which must read as
// "not visited yet".
enum class NodeState {
  kNotVisited = 0,  // Not entered yet.
  kVisiting = 1,    // Entered; its post-order event has not fired yet.
  kDone = 2,        // Finished or skipped; never processed again.
};
42
43 } // namespace
44
DfsTraversal(const GraphTopologyView & graph_view,const absl::Span<const NodeDef * const> from,const TraversalDirection direction,const DfsPredicates & predicates,const DfsCallbacks & callbacks)45 void DfsTraversal(const GraphTopologyView& graph_view,
46 const absl::Span<const NodeDef* const> from,
47 const TraversalDirection direction,
48 const DfsPredicates& predicates,
49 const DfsCallbacks& callbacks) {
50 std::vector<DfsStackElem> stack;
51 stack.reserve(from.size());
52
53 for (const NodeDef* node : from) {
54 const absl::optional<int> node_idx = graph_view.GetNodeIndex(*node);
55 DCHECK(node_idx.has_value()) << "Illegal start node: " << node->name();
56 if (node_idx.has_value()) {
57 stack.emplace_back(node_idx.value());
58 }
59 }
60
61 absl::flat_hash_map<int, NodeState> node_state;
62 while (!stack.empty()) {
63 DfsStackElem w = stack.back();
64 stack.pop_back();
65
66 NodeState& state = node_state[w.node];
67 if (state == NodeState::kDone) continue;
68
69 // Skip nodes that we should not enter.
70 if (predicates.enter && !predicates.enter(graph_view.GetNode(w.node))) {
71 state = NodeState::kDone;
72 continue;
73 }
74
75 // We've processed all the children of this node.
76 if (w.children_visited) {
77 state = NodeState::kDone;
78 if (callbacks.post_order) {
79 callbacks.post_order(graph_view.GetNode(w.node));
80 }
81 continue;
82 }
83
84 // Loop detected.
85 if (state == NodeState::kVisiting) {
86 if (callbacks.on_back_edge) {
87 callbacks.on_back_edge(graph_view.GetNode(w.src),
88 graph_view.GetNode(w.node));
89 }
90 continue;
91 }
92
93 state = NodeState::kVisiting;
94 if (callbacks.pre_order) {
95 callbacks.pre_order(graph_view.GetNode(w.node));
96 }
97
98 // Enqueue the node again with the children_visited flag set to true.
99 stack.emplace_back(w.node, true, w.src);
100
101 // Check if we can continue traversal from the current node.
102 if (predicates.advance && !predicates.advance(graph_view.GetNode(w.node))) {
103 continue;
104 }
105
106 // Now enqueue the fanin/fanout nodes.
107 if (direction == TraversalDirection::kFollowInputs) {
108 for (const int fanin : graph_view.GetFanin(w.node)) {
109 stack.emplace_back(fanin, false, w.node);
110 }
111 } else {
112 for (const int fanout : graph_view.GetFanout(w.node)) {
113 stack.emplace_back(fanout, false, w.node);
114 }
115 }
116 }
117 }
118
DfsTraversal(const GraphTopologyView & graph_view,const absl::Span<const NodeDef * const> from,TraversalDirection direction,const DfsCallbacks & callbacks)119 void DfsTraversal(const GraphTopologyView& graph_view,
120 const absl::Span<const NodeDef* const> from,
121 TraversalDirection direction, const DfsCallbacks& callbacks) {
122 DfsTraversal(graph_view, from, direction, {}, callbacks);
123 }
124
125 } // namespace grappler
126 } // namespace tensorflow
127