1 // dfs-visit.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Depth-first search visitation. See visit.h for more general
20 // search queue disciplines.
21
22 #ifndef FST_LIB_DFS_VISIT_H__
23 #define FST_LIB_DFS_VISIT_H__
24
25 #include <stack>
26 #include <vector>
27 using std::vector;
28
29 #include <fst/arcfilter.h>
30 #include <fst/fst.h>
31
32
33 namespace fst {
34
35 // Visitor Interface - class determines actions taken during a Dfs.
36 // If any of the boolean member functions return false, the DFS is
37 // aborted by first calling FinishState() on all currently grey states
38 // and then calling FinishVisit().
39 //
// Note this is similar to the more general visitor interface in visit.h
// except that FinishState receives additional information (the parent state
// and arc) appropriate only for a DFS, and some method names here are better
// suited to a DFS.
43 //
44 // template <class Arc>
45 // class Visitor {
46 // public:
47 // typedef typename Arc::StateId StateId;
48 //
49 // Visitor(T *return_data);
50 // // Invoked before DFS visit
51 // void InitVisit(const Fst<Arc> &fst);
52 // // Invoked when state discovered (2nd arg is DFS tree root)
53 // bool InitState(StateId s, StateId root);
54 // // Invoked when tree arc examined (to white/undiscovered state)
55 // bool TreeArc(StateId s, const Arc &a);
56 // // Invoked when back arc examined (to grey/unfinished state)
57 // bool BackArc(StateId s, const Arc &a);
58 // // Invoked when forward or cross arc examined (to black/finished state)
59 // bool ForwardOrCrossArc(StateId s, const Arc &a);
// // Invoked when state finished (PARENT is kNoStateId and PARENT_ARC is
// // NULL when S is a DFS tree root)
62 // void FinishState(StateId s, StateId parent, const Arc *parent_arc);
63 // // Invoked after DFS visit
64 // void FinishVisit();
65 // };
66
// Colors recording how far the DFS has progressed on a given FST state.
const int kDfsWhite = 0;  // Not yet discovered.
const int kDfsGrey = 1;   // Discovered; still on the DFS stack (unfinished).
const int kDfsBlack = 2;  // Finished and popped from the DFS stack.
71
72 // An Fst state's DFS stack state
73 template <class Arc>
74 struct DfsState {
75 typedef typename Arc::StateId StateId;
76
DfsStateDfsState77 DfsState(const Fst<Arc> &fst, StateId s): state_id(s), arc_iter(fst, s) {}
78
79 StateId state_id; // Fst state ...
80 ArcIterator< Fst<Arc> > arc_iter; // and its corresponding arcs
81 };
82
83
84 // Performs depth-first visitation. Visitor class argument determines
85 // actions and contains any return data. ArcFilter determines arcs
86 // that are considered.
87 //
88 // Note this is similar to Visit() in visit.h called with a LIFO
89 // queue except this version has a Visitor class specialized and
90 // augmented for a DFS.
91 template <class Arc, class V, class ArcFilter>
DfsVisit(const Fst<Arc> & fst,V * visitor,ArcFilter filter)92 void DfsVisit(const Fst<Arc> &fst, V *visitor, ArcFilter filter) {
93 typedef typename Arc::StateId StateId;
94
95 visitor->InitVisit(fst);
96
97 StateId start = fst.Start();
98 if (start == kNoStateId) {
99 visitor->FinishVisit();
100 return;
101 }
102
103 vector<char> state_color; // Fst state DFS status
104 stack<DfsState<Arc> *> state_stack; // DFS execution stack
105
106 StateId nstates = start + 1; // # of known states in general case
107 bool expanded = false;
108 if (fst.Properties(kExpanded, false)) { // tests if expanded case, then
109 nstates = CountStates(fst); // uses ExpandedFst::NumStates().
110 expanded = true;
111 }
112
113 state_color.resize(nstates, kDfsWhite);
114 StateIterator< Fst<Arc> > siter(fst);
115
116 // Continue DFS while true
117 bool dfs = true;
118
119 // Iterate over trees in DFS forest.
120 for (StateId root = start; dfs && root < nstates;) {
121 state_color[root] = kDfsGrey;
122 state_stack.push(new DfsState<Arc>(fst, root));
123 dfs = visitor->InitState(root, root);
124 while (!state_stack.empty()) {
125 DfsState<Arc> *dfs_state = state_stack.top();
126 StateId s = dfs_state->state_id;
127 if (s >= state_color.size()) {
128 nstates = s + 1;
129 state_color.resize(nstates, kDfsWhite);
130 }
131 ArcIterator< Fst<Arc> > &aiter = dfs_state->arc_iter;
132 if (!dfs || aiter.Done()) {
133 state_color[s] = kDfsBlack;
134 delete dfs_state;
135 state_stack.pop();
136 if (!state_stack.empty()) {
137 DfsState<Arc> *parent_state = state_stack.top();
138 StateId p = parent_state->state_id;
139 ArcIterator< Fst<Arc> > &piter = parent_state->arc_iter;
140 visitor->FinishState(s, p, &piter.Value());
141 piter.Next();
142 } else {
143 visitor->FinishState(s, kNoStateId, 0);
144 }
145 continue;
146 }
147 const Arc &arc = aiter.Value();
148 if (arc.nextstate >= state_color.size()) {
149 nstates = arc.nextstate + 1;
150 state_color.resize(nstates, kDfsWhite);
151 }
152 if (!filter(arc)) {
153 aiter.Next();
154 continue;
155 }
156 int next_color = state_color[arc.nextstate];
157 switch (next_color) {
158 default:
159 case kDfsWhite:
160 dfs = visitor->TreeArc(s, arc);
161 if (!dfs) break;
162 state_color[arc.nextstate] = kDfsGrey;
163 state_stack.push(new DfsState<Arc>(fst, arc.nextstate));
164 dfs = visitor->InitState(arc.nextstate, root);
165 break;
166 case kDfsGrey:
167 dfs = visitor->BackArc(s, arc);
168 aiter.Next();
169 break;
170 case kDfsBlack:
171 dfs = visitor->ForwardOrCrossArc(s, arc);
172 aiter.Next();
173 break;
174 }
175 }
176
177 // Find next tree root
178 for (root = root == start ? 0 : root + 1;
179 root < nstates && state_color[root] != kDfsWhite;
180 ++root) {
181 }
182
183 // Check for a state beyond the largest known state
184 if (!expanded && root == nstates) {
185 for (; !siter.Done(); siter.Next()) {
186 if (siter.Value() == nstates) {
187 ++nstates;
188 state_color.push_back(kDfsWhite);
189 break;
190 }
191 }
192 }
193 }
194 visitor->FinishVisit();
195 }
196
197
198 template <class Arc, class V>
DfsVisit(const Fst<Arc> & fst,V * visitor)199 void DfsVisit(const Fst<Arc> &fst, V *visitor) {
200 DfsVisit(fst, visitor, AnyArcFilter<Arc>());
201 }
202
203 } // namespace fst
204
205 #endif // FST_LIB_DFS_VISIT_H__
206