1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_
18 
#include <functional>
#include <utility>

#include "tensorflow/core/grappler/graph_topology_view.h"
22 
23 namespace tensorflow {
24 namespace grappler {
25 
// Selects which edges DfsTraversal follows out of each visited node.
enum class TraversalDirection {
  kFollowInputs,   // Follow each node's inputs.
  kFollowOutputs,  // Follow each node's outputs.
};
27 
28 // Encapsulate DFS callbacks that will be called during the graph traversal.
29 //
30 // If non-empty, the `pre_order` and `post_order` functors will be called on
31 // each reachable node (including the `from` nodes) in pre and post order. If
32 // loops are found, the `on_back_edge` functor will be called on the
33 // corresponding back edges. Moreover, the pre and post order will assume that
34 // these back edges will be cut.
35 struct DfsCallbacks {
36   DfsCallbacks() = default;
DfsCallbacksDfsCallbacks37   DfsCallbacks(std::function<void(const NodeDef*)> pre,
38                std::function<void(const NodeDef*)> post,
39                std::function<void(const NodeDef*, const NodeDef*)> back_edge)
40       : pre_order(std::move(pre)),
41         post_order(std::move(post)),
42         on_back_edge(std::move(back_edge)) {}
43 
PreOrderDfsCallbacks44   static DfsCallbacks PreOrder(std::function<void(const NodeDef*)> pre) {
45     return DfsCallbacks(std::move(pre), nullptr, nullptr);
46   }
47 
PostOrderDfsCallbacks48   static DfsCallbacks PostOrder(std::function<void(const NodeDef*)> post) {
49     return DfsCallbacks(nullptr, std::move(post), nullptr);
50   }
51 
52   std::function<void(const NodeDef*)> pre_order;
53   std::function<void(const NodeDef*)> post_order;
54   std::function<void(const NodeDef*, const NodeDef*)> on_back_edge;
55 };
56 
57 // Encapsulate DFS predicates for traversing the graph.
58 //
59 // The `enter` predicate decides if traversal should enter the node, and the
60 // `advance` predicate decides if the traversal should follow inputs/outputs
61 // from the node.
62 //
63 // If predicates are empty (default initialized), it's assumed that we can enter
64 // into any node and advance from any node respectively.
65 struct DfsPredicates {
66   DfsPredicates() = default;
DfsPredicatesDfsPredicates67   DfsPredicates(std::function<bool(const NodeDef*)> enter,
68                 std::function<bool(const NodeDef*)> advance)
69       : enter(std::move(enter)), advance(std::move(advance)) {}
70 
EnterDfsPredicates71   static DfsPredicates Enter(std::function<bool(const NodeDef*)> enter) {
72     return DfsPredicates(std::move(enter), nullptr);
73   }
74 
AdvanceDfsPredicates75   static DfsPredicates Advance(std::function<bool(const NodeDef*)> advance) {
76     return DfsPredicates(nullptr, std::move(advance));
77   }
78 
79   std::function<bool(const NodeDef*)> enter;
80   std::function<bool(const NodeDef*)> advance;
81 };
82 
83 // Traverse the graph in DFS order in the given direction, starting from the
84 // list of nodes specified in the `from` argument. Use `predicates` to decide if
85 // traversal should enter/advance to/from the graph node. These predicates also
86 // applied to the `from` nodes. Call corresponding callbacks for each visited
87 // node.
88 void DfsTraversal(const GraphTopologyView& graph_view,
89                   absl::Span<const NodeDef* const> from,
90                   TraversalDirection direction, const DfsPredicates& predicates,
91                   const DfsCallbacks& callbacks);
92 
93 // Traverse the graph in DFS order in the given direction, starting from the
94 // list of nodes specified in the `from` argument. Call corresponding callbacks
95 // for each visited node.
96 void DfsTraversal(const GraphTopologyView& graph_view,
97                   absl::Span<const NodeDef* const> from,
98                   TraversalDirection direction, const DfsCallbacks& callbacks);
99 
100 }  // namespace grappler
101 }  // namespace tensorflow
102 
103 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_
104