1 // paren.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // Common classes for PDT parentheses
19
20 // \file
21
22 #ifndef FST_EXTENSIONS_PDT_PAREN_H_
23 #define FST_EXTENSIONS_PDT_PAREN_H_
24
25 #include <algorithm>
26 #include <tr1/unordered_map>
27 using std::tr1::unordered_map;
28 using std::tr1::unordered_multimap;
29 #include <tr1/unordered_set>
30 using std::tr1::unordered_set;
31 using std::tr1::unordered_multiset;
32 #include <set>
33
34 #include <fst/extensions/pdt/pdt.h>
35 #include <fst/extensions/pdt/collection.h>
36 #include <fst/fst.h>
37 #include <fst/dfs-visit.h>
38
39
40 namespace fst {
41
42 //
43 // ParenState: Pair of an open (close) parenthesis and
44 // its destination (source) state.
45 //
46
47 template <class A>
48 class ParenState {
49 public:
50 typedef typename A::Label Label;
51 typedef typename A::StateId StateId;
52
53 struct Hash {
operatorHash54 size_t operator()(const ParenState<A> &p) const {
55 return p.paren_id + p.state_id * kPrime;
56 }
57 };
58
59 Label paren_id; // ID of open (close) paren
60 StateId state_id; // destination (source) state of open (close) paren
61
ParenState()62 ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {}
63
ParenState(Label p,StateId s)64 ParenState(Label p, StateId s) : paren_id(p), state_id(s) {}
65
66 bool operator==(const ParenState<A> &p) const {
67 if (&p == this)
68 return true;
69 return p.paren_id == this->paren_id && p.state_id == this->state_id;
70 }
71
72 bool operator!=(const ParenState<A> &p) const { return !(p == *this); }
73
74 bool operator<(const ParenState<A> &p) const {
75 return paren_id < this->paren.id ||
76 (p.paren_id == this->paren.id && p.state_id < this->state_id);
77 }
78
79 private:
80 static const size_t kPrime;
81 };
82
83 template <class A>
84 const size_t ParenState<A>::kPrime = 7853;
85
86
// Wraps an STL map together with one of its iterators into an FST-style
// iterator (Done()/Value()/Next()/Reset()).
//
// Iteration begins at the iterator handed to the constructor and is over
// as soon as the end of the map is reached or an entry with a key different
// from the key of the starting entry is met. The map is held by reference.
template <class M>
class MapIterator {
 public:
  typedef typename M::const_iterator StlIterator;
  typedef typename M::value_type PairType;
  typedef typename PairType::second_type ValueType;

  MapIterator(const M &m, StlIterator iter)
      : map_(m), first_(iter), cur_(iter) {}

  // True when there is nothing (more) to visit.
  bool Done() const {
    if (cur_ == map_.end())
      return true;
    return cur_->first != first_->first;
  }

  // Mapped value of the current entry (returned by value).
  ValueType Value() const { return cur_->second; }

  // Moves on to the next entry of the map.
  void Next() { ++cur_; }

  // Goes back to the entry iteration started from.
  void Reset() { cur_ = first_; }

 private:
  const M &map_;       // map being iterated over
  StlIterator first_;  // starting position; its key delimits the range
  StlIterator cur_;    // current position
};
111
112 //
113 // PdtParenReachable: Provides various parenthesis reachability information
114 // on a PDT.
115 //
116
117 template <class A>
118 class PdtParenReachable {
119 public:
120 typedef typename A::StateId StateId;
121 typedef typename A::Label Label;
122 public:
123 // Maps from state ID to reachable paren IDs from (to) that state.
124 typedef unordered_multimap<StateId, Label> ParenMultiMap;
125
126 // Maps from paren ID and state ID to reachable state set ID
127 typedef unordered_map<ParenState<A>, ssize_t,
128 typename ParenState<A>::Hash> StateSetMap;
129
130 // Maps from paren ID and state ID to arcs exiting that state with that
131 // Label.
132 typedef unordered_multimap<ParenState<A>, A,
133 typename ParenState<A>::Hash> ParenArcMultiMap;
134
135 typedef MapIterator<ParenMultiMap> ParenIterator;
136
137 typedef MapIterator<ParenArcMultiMap> ParenArcIterator;
138
139 typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
140
141 // Computes close (open) parenthesis reachabilty information for
142 // a PDT with bounded stack.
PdtParenReachable(const Fst<A> & fst,const vector<pair<Label,Label>> & parens,bool close)143 PdtParenReachable(const Fst<A> &fst,
144 const vector<pair<Label, Label> > &parens, bool close)
145 : fst_(fst),
146 parens_(parens),
147 close_(close),
148 error_(false) {
149 for (Label i = 0; i < parens.size(); ++i) {
150 const pair<Label, Label> &p = parens[i];
151 paren_id_map_[p.first] = i;
152 paren_id_map_[p.second] = i;
153 }
154
155 if (close_) {
156 StateId start = fst.Start();
157 if (start == kNoStateId)
158 return;
159 if (!DFSearch(start)) {
160 FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
161 error_ = true;
162 }
163 } else {
164 FSTERROR() << "PdtParenReachable: open paren info not implemented";
165 error_ = true;
166 }
167 }
168
Error()169 bool const Error() { return error_; }
170
171 // Given a state ID, returns an iterator over paren IDs
172 // for close (open) parens reachable from that state along balanced
173 // paths.
FindParens(StateId s)174 ParenIterator FindParens(StateId s) const {
175 return ParenIterator(paren_multimap_, paren_multimap_.find(s));
176 }
177
178 // Given a paren ID and a state ID s, returns an iterator over
179 // states that can be reached along balanced paths from (to) s that
180 // have have close (open) parentheses matching the paren ID exiting
181 // (entering) those states.
FindStates(Label paren_id,StateId s)182 SetIterator FindStates(Label paren_id, StateId s) const {
183 ParenState<A> paren_state(paren_id, s);
184 typename StateSetMap::const_iterator id_it = set_map_.find(paren_state);
185 if (id_it == set_map_.end()) {
186 return state_sets_.FindSet(-1);
187 } else {
188 return state_sets_.FindSet(id_it->second);
189 }
190 }
191
192 // Given a paren Id and a state ID s, return an iterator over
193 // arcs that exit (enter) s and are labeled with a close (open)
194 // parenthesis matching the paren ID.
FindParenArcs(Label paren_id,StateId s)195 ParenArcIterator FindParenArcs(Label paren_id, StateId s) const {
196 ParenState<A> paren_state(paren_id, s);
197 return ParenArcIterator(paren_arc_multimap_,
198 paren_arc_multimap_.find(paren_state));
199 }
200
201 private:
202 // DFS that gathers paren and state set information.
203 // Bool returns false when cycle detected.
204 bool DFSearch(StateId s);
205
206 // Unions state sets together gathered by the DFS.
207 void ComputeStateSet(StateId s);
208
209 // Gather state set(s) from state 'nexts'.
210 void UpdateStateSet(StateId nexts, set<Label> *paren_set,
211 vector< set<StateId> > *state_sets) const;
212
213 const Fst<A> &fst_;
214 const vector<pair<Label, Label> > &parens_; // Paren ID -> Labels
215 bool close_; // Close/open paren info?
216 unordered_map<Label, Label> paren_id_map_; // Paren labels -> ID
217 ParenMultiMap paren_multimap_; // Paren reachability
218 ParenArcMultiMap paren_arc_multimap_; // Paren Arcs
219 vector<char> state_color_; // DFS state
220 mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID
221 StateSetMap set_map_; // ID -> Reachable states
222 bool error_;
223 DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
224 };
225
226 // DFS that gathers paren and state set information.
227 template <class A>
DFSearch(StateId s)228 bool PdtParenReachable<A>::DFSearch(StateId s) {
229 if (s >= state_color_.size())
230 state_color_.resize(s + 1, kDfsWhite);
231
232 if (state_color_[s] == kDfsBlack)
233 return true;
234
235 if (state_color_[s] == kDfsGrey)
236 return false;
237
238 state_color_[s] = kDfsGrey;
239
240 for (ArcIterator<Fst<A> > aiter(fst_, s);
241 !aiter.Done();
242 aiter.Next()) {
243 const A &arc = aiter.Value();
244
245 typename unordered_map<Label, Label>::const_iterator pit
246 = paren_id_map_.find(arc.ilabel);
247 if (pit != paren_id_map_.end()) { // paren?
248 Label paren_id = pit->second;
249 if (arc.ilabel == parens_[paren_id].first) { // open paren
250 if (!DFSearch(arc.nextstate))
251 return false;
252 for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
253 !set_iter.Done(); set_iter.Next()) {
254 for (ParenArcIterator paren_arc_iter =
255 FindParenArcs(paren_id, set_iter.Element());
256 !paren_arc_iter.Done();
257 paren_arc_iter.Next()) {
258 const A &cparc = paren_arc_iter.Value();
259 if (!DFSearch(cparc.nextstate))
260 return false;
261 }
262 }
263 }
264 } else { // non-paren
265 if(!DFSearch(arc.nextstate))
266 return false;
267 }
268 }
269 ComputeStateSet(s);
270 state_color_[s] = kDfsBlack;
271 return true;
272 }
273
274 // Unions state sets together gathered by the DFS.
275 template <class A>
ComputeStateSet(StateId s)276 void PdtParenReachable<A>::ComputeStateSet(StateId s) {
277 set<Label> paren_set;
278 vector< set<StateId> > state_sets(parens_.size());
279 for (ArcIterator< Fst<A> > aiter(fst_, s);
280 !aiter.Done();
281 aiter.Next()) {
282 const A &arc = aiter.Value();
283
284 typename unordered_map<Label, Label>::const_iterator pit
285 = paren_id_map_.find(arc.ilabel);
286 if (pit != paren_id_map_.end()) { // paren?
287 Label paren_id = pit->second;
288 if (arc.ilabel == parens_[paren_id].first) { // open paren
289 for (SetIterator set_iter =
290 FindStates(paren_id, arc.nextstate);
291 !set_iter.Done(); set_iter.Next()) {
292 for (ParenArcIterator paren_arc_iter =
293 FindParenArcs(paren_id, set_iter.Element());
294 !paren_arc_iter.Done();
295 paren_arc_iter.Next()) {
296 const A &cparc = paren_arc_iter.Value();
297 UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
298 }
299 }
300 } else { // close paren
301 paren_set.insert(paren_id);
302 state_sets[paren_id].insert(s);
303 ParenState<A> paren_state(paren_id, s);
304 paren_arc_multimap_.insert(make_pair(paren_state, arc));
305 }
306 } else { // non-paren
307 UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
308 }
309 }
310
311 vector<StateId> state_set;
312 for (typename set<Label>::iterator paren_iter = paren_set.begin();
313 paren_iter != paren_set.end(); ++paren_iter) {
314 state_set.clear();
315 Label paren_id = *paren_iter;
316 paren_multimap_.insert(make_pair(s, paren_id));
317 for (typename set<StateId>::iterator state_iter
318 = state_sets[paren_id].begin();
319 state_iter != state_sets[paren_id].end();
320 ++state_iter) {
321 state_set.push_back(*state_iter);
322 }
323 ParenState<A> paren_state(paren_id, s);
324 set_map_[paren_state] = state_sets_.FindId(state_set);
325 }
326 }
327
328 // Gather state set(s) from state 'nexts'.
329 template <class A>
UpdateStateSet(StateId nexts,set<Label> * paren_set,vector<set<StateId>> * state_sets)330 void PdtParenReachable<A>::UpdateStateSet(
331 StateId nexts, set<Label> *paren_set,
332 vector< set<StateId> > *state_sets) const {
333 for(ParenIterator paren_iter = FindParens(nexts);
334 !paren_iter.Done(); paren_iter.Next()) {
335 Label paren_id = paren_iter.Value();
336 paren_set->insert(paren_id);
337 for (SetIterator set_iter = FindStates(paren_id, nexts);
338 !set_iter.Done(); set_iter.Next()) {
339 (*state_sets)[paren_id].insert(set_iter.Element());
340 }
341 }
342 }
343
344
345 // Store balancing parenthesis data for a PDT. Allows on-the-fly
346 // construction (e.g. in PdtShortestPath) unlike PdtParenReachable above.
347 template <class A>
348 class PdtBalanceData {
349 public:
350 typedef typename A::StateId StateId;
351 typedef typename A::Label Label;
352
353 // Hash set for open parens
354 typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet;
355
356 // Maps from open paren destination state to parenthesis ID.
357 typedef unordered_multimap<StateId, Label> OpenParenMap;
358
359 // Maps from open paren state to source states of matching close parens
360 typedef unordered_multimap<ParenState<A>, StateId,
361 typename ParenState<A>::Hash> CloseParenMap;
362
363 // Maps from open paren state to close source set ID
364 typedef unordered_map<ParenState<A>, ssize_t,
365 typename ParenState<A>::Hash> CloseSourceMap;
366
367 typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
368
PdtBalanceData()369 PdtBalanceData() {}
370
Clear()371 void Clear() {
372 open_paren_map_.clear();
373 close_paren_map_.clear();
374 }
375
376 // Adds an open parenthesis with destination state 'open_dest'.
OpenInsert(Label paren_id,StateId open_dest)377 void OpenInsert(Label paren_id, StateId open_dest) {
378 ParenState<A> key(paren_id, open_dest);
379 if (!open_paren_set_.count(key)) {
380 open_paren_set_.insert(key);
381 open_paren_map_.insert(make_pair(open_dest, paren_id));
382 }
383 }
384
385 // Adds a matching closing parenthesis with source state
386 // 'close_source' that balances an open_parenthesis with destination
387 // state 'open_dest' if OpenInsert() previously called
388 // (o.w. CloseInsert() does nothing).
CloseInsert(Label paren_id,StateId open_dest,StateId close_source)389 void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
390 ParenState<A> key(paren_id, open_dest);
391 if (open_paren_set_.count(key))
392 close_paren_map_.insert(make_pair(key, close_source));
393 }
394
395 // Find close paren source states matching an open parenthesis.
396 // Methods that follow, iterate through those matching states.
397 // Should be called only after FinishInsert(open_dest).
Find(Label paren_id,StateId open_dest)398 SetIterator Find(Label paren_id, StateId open_dest) {
399 ParenState<A> close_key(paren_id, open_dest);
400 typename CloseSourceMap::const_iterator id_it =
401 close_source_map_.find(close_key);
402 if (id_it == close_source_map_.end()) {
403 return close_source_sets_.FindSet(-1);
404 } else {
405 return close_source_sets_.FindSet(id_it->second);
406 }
407 }
408
409 // Call when all open and close parenthesis insertions wrt open
410 // parentheses entering 'open_dest' are finished. Must be called
411 // before Find(open_dest). Stores close paren source state sets
412 // efficiently.
FinishInsert(StateId open_dest)413 void FinishInsert(StateId open_dest) {
414 vector<StateId> close_sources;
415 for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest);
416 oit != open_paren_map_.end() && oit->first == open_dest;) {
417 Label paren_id = oit->second;
418 close_sources.clear();
419 ParenState<A> okey(paren_id, open_dest);
420 open_paren_set_.erase(open_paren_set_.find(okey));
421 for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey);
422 cit != close_paren_map_.end() && cit->first == okey;) {
423 close_sources.push_back(cit->second);
424 close_paren_map_.erase(cit++);
425 }
426 sort(close_sources.begin(), close_sources.end());
427 typename vector<StateId>::iterator unique_end =
428 unique(close_sources.begin(), close_sources.end());
429 close_sources.resize(unique_end - close_sources.begin());
430
431 if (!close_sources.empty())
432 close_source_map_[okey] = close_source_sets_.FindId(close_sources);
433 open_paren_map_.erase(oit++);
434 }
435 }
436
437 // Return a new balance data object representing the reversed balance
438 // information.
439 PdtBalanceData<A> *Reverse(StateId num_states,
440 StateId num_split,
441 StateId state_id_shift) const;
442
443 private:
444 OpenParenSet open_paren_set_; // open par. at dest?
445
446 OpenParenMap open_paren_map_; // open parens per state
447 ParenState<A> open_dest_; // cur open dest. state
448 typename OpenParenMap::const_iterator open_iter_; // cur open parens/state
449
450 CloseParenMap close_paren_map_; // close states/open
451 // paren and state
452
453 CloseSourceMap close_source_map_; // paren, state to set ID
454 mutable Collection<ssize_t, StateId> close_source_sets_;
455 };
456
457 // Return a new balance data object representing the reversed balance
458 // information.
459 template <class A>
Reverse(StateId num_states,StateId num_split,StateId state_id_shift)460 PdtBalanceData<A> *PdtBalanceData<A>::Reverse(
461 StateId num_states,
462 StateId num_split,
463 StateId state_id_shift) const {
464 PdtBalanceData<A> *bd = new PdtBalanceData<A>;
465 unordered_set<StateId> close_sources;
466 StateId split_size = num_states / num_split;
467
468 for (StateId i = 0; i < num_states; i+= split_size) {
469 close_sources.clear();
470
471 for (typename CloseSourceMap::const_iterator
472 sit = close_source_map_.begin();
473 sit != close_source_map_.end();
474 ++sit) {
475 ParenState<A> okey = sit->first;
476 StateId open_dest = okey.state_id;
477 Label paren_id = okey.paren_id;
478 for (SetIterator set_iter = close_source_sets_.FindSet(sit->second);
479 !set_iter.Done(); set_iter.Next()) {
480 StateId close_source = set_iter.Element();
481 if ((close_source < i) || (close_source >= i + split_size))
482 continue;
483 close_sources.insert(close_source + state_id_shift);
484 bd->OpenInsert(paren_id, close_source + state_id_shift);
485 bd->CloseInsert(paren_id, close_source + state_id_shift,
486 open_dest + state_id_shift);
487 }
488 }
489
490 for (typename unordered_set<StateId>::const_iterator it
491 = close_sources.begin();
492 it != close_sources.end();
493 ++it) {
494 bd->FinishInsert(*it);
495 }
496
497 }
498 return bd;
499 }
500
501
502 } // namespace fst
503
504 #endif // FST_EXTENSIONS_PDT_PAREN_H_
505