1 // expand.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Expand a PDT to an FST.
20
21 #ifndef FST_EXTENSIONS_PDT_EXPAND_H__
22 #define FST_EXTENSIONS_PDT_EXPAND_H__
23
24 #include <vector>
25 using std::vector;
26
27 #include <fst/extensions/pdt/pdt.h>
28 #include <fst/extensions/pdt/paren.h>
29 #include <fst/extensions/pdt/shortest-path.h>
30 #include <fst/extensions/pdt/reverse.h>
31 #include <fst/cache.h>
32 #include <fst/mutable-fst.h>
33 #include <fst/queue.h>
34 #include <fst/state-table.h>
35 #include <fst/test-properties.h>
36
37 namespace fst {
38
39 template <class Arc>
40 struct ExpandFstOptions : public CacheOptions {
41 bool keep_parentheses;
42 PdtStack<typename Arc::StateId, typename Arc::Label> *stack;
43 PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table;
44
45 ExpandFstOptions(
46 const CacheOptions &opts = CacheOptions(),
47 bool kp = false,
48 PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0,
49 PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0)
CacheOptionsExpandFstOptions50 : CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {}
51 };
52
53 // Properties for an expanded PDT.
ExpandProperties(uint64 inprops)54 inline uint64 ExpandProperties(uint64 inprops) {
55 return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted);
56 }
57
58
59 // Implementation class for ExpandFst
60 template <class A>
61 class ExpandFstImpl
62 : public CacheImpl<A> {
63 public:
64 using FstImpl<A>::SetType;
65 using FstImpl<A>::SetProperties;
66 using FstImpl<A>::Properties;
67 using FstImpl<A>::SetInputSymbols;
68 using FstImpl<A>::SetOutputSymbols;
69
70 using CacheBaseImpl< CacheState<A> >::PushArc;
71 using CacheBaseImpl< CacheState<A> >::HasArcs;
72 using CacheBaseImpl< CacheState<A> >::HasFinal;
73 using CacheBaseImpl< CacheState<A> >::HasStart;
74 using CacheBaseImpl< CacheState<A> >::SetArcs;
75 using CacheBaseImpl< CacheState<A> >::SetFinal;
76 using CacheBaseImpl< CacheState<A> >::SetStart;
77
78 typedef A Arc;
79 typedef typename A::Label Label;
80 typedef typename A::Weight Weight;
81 typedef typename A::StateId StateId;
82 typedef StateId StackId;
83 typedef PdtStateTuple<StateId, StackId> StateTuple;
84
ExpandFstImpl(const Fst<A> & fst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens,const ExpandFstOptions<A> & opts)85 ExpandFstImpl(const Fst<A> &fst,
86 const vector<pair<typename Arc::Label,
87 typename Arc::Label> > &parens,
88 const ExpandFstOptions<A> &opts)
89 : CacheImpl<A>(opts), fst_(fst.Copy()),
90 stack_(opts.stack ? opts.stack: new PdtStack<StateId, Label>(parens)),
91 state_table_(opts.state_table ? opts.state_table :
92 new PdtStateTable<StateId, StackId>()),
93 own_stack_(opts.stack == 0), own_state_table_(opts.state_table == 0),
94 keep_parentheses_(opts.keep_parentheses) {
95 SetType("expand");
96
97 uint64 props = fst.Properties(kFstProperties, false);
98 SetProperties(ExpandProperties(props), kCopyProperties);
99
100 SetInputSymbols(fst.InputSymbols());
101 SetOutputSymbols(fst.OutputSymbols());
102 }
103
ExpandFstImpl(const ExpandFstImpl & impl)104 ExpandFstImpl(const ExpandFstImpl &impl)
105 : CacheImpl<A>(impl),
106 fst_(impl.fst_->Copy(true)),
107 stack_(new PdtStack<StateId, Label>(*impl.stack_)),
108 state_table_(new PdtStateTable<StateId, StackId>()),
109 own_stack_(true), own_state_table_(true),
110 keep_parentheses_(impl.keep_parentheses_) {
111 SetType("expand");
112 SetProperties(impl.Properties(), kCopyProperties);
113 SetInputSymbols(impl.InputSymbols());
114 SetOutputSymbols(impl.OutputSymbols());
115 }
116
~ExpandFstImpl()117 ~ExpandFstImpl() {
118 delete fst_;
119 if (own_stack_)
120 delete stack_;
121 if (own_state_table_)
122 delete state_table_;
123 }
124
Start()125 StateId Start() {
126 if (!HasStart()) {
127 StateId s = fst_->Start();
128 if (s == kNoStateId)
129 return kNoStateId;
130 StateTuple tuple(s, 0);
131 StateId start = state_table_->FindState(tuple);
132 SetStart(start);
133 }
134 return CacheImpl<A>::Start();
135 }
136
Final(StateId s)137 Weight Final(StateId s) {
138 if (!HasFinal(s)) {
139 const StateTuple &tuple = state_table_->Tuple(s);
140 Weight w = fst_->Final(tuple.state_id);
141 if (w != Weight::Zero() && tuple.stack_id == 0)
142 SetFinal(s, w);
143 else
144 SetFinal(s, Weight::Zero());
145 }
146 return CacheImpl<A>::Final(s);
147 }
148
NumArcs(StateId s)149 size_t NumArcs(StateId s) {
150 if (!HasArcs(s)) {
151 ExpandState(s);
152 }
153 return CacheImpl<A>::NumArcs(s);
154 }
155
NumInputEpsilons(StateId s)156 size_t NumInputEpsilons(StateId s) {
157 if (!HasArcs(s))
158 ExpandState(s);
159 return CacheImpl<A>::NumInputEpsilons(s);
160 }
161
NumOutputEpsilons(StateId s)162 size_t NumOutputEpsilons(StateId s) {
163 if (!HasArcs(s))
164 ExpandState(s);
165 return CacheImpl<A>::NumOutputEpsilons(s);
166 }
167
InitArcIterator(StateId s,ArcIteratorData<A> * data)168 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
169 if (!HasArcs(s))
170 ExpandState(s);
171 CacheImpl<A>::InitArcIterator(s, data);
172 }
173
174 // Computes the outgoing transitions from a state, creating new destination
175 // states as needed.
ExpandState(StateId s)176 void ExpandState(StateId s) {
177 StateTuple tuple = state_table_->Tuple(s);
178 for (ArcIterator< Fst<A> > aiter(*fst_, tuple.state_id);
179 !aiter.Done(); aiter.Next()) {
180 Arc arc = aiter.Value();
181 StackId stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
182 if (stack_id == -1) {
183 // Non-matching close parenthesis
184 continue;
185 } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
186 // Stack push/pop
187 arc.ilabel = arc.olabel = 0;
188 }
189
190 StateTuple ntuple(arc.nextstate, stack_id);
191 arc.nextstate = state_table_->FindState(ntuple);
192 PushArc(s, arc);
193 }
194 SetArcs(s);
195 }
196
GetStack()197 const PdtStack<StackId, Label> &GetStack() const { return *stack_; }
198
GetStateTable()199 const PdtStateTable<StateId, StackId> &GetStateTable() const {
200 return *state_table_;
201 }
202
203 private:
204 const Fst<A> *fst_;
205
206 PdtStack<StackId, Label> *stack_;
207 PdtStateTable<StateId, StackId> *state_table_;
208 bool own_stack_;
209 bool own_state_table_;
210 bool keep_parentheses_;
211
212 void operator=(const ExpandFstImpl<A> &); // disallow
213 };
214
215 // Expands a pushdown transducer (PDT) encoded as an FST into an FST.
216 // This version is a delayed Fst. In the PDT, some transitions are
217 // labeled with open or close parentheses. To be interpreted as a PDT,
218 // the parens must balance on a path. The open-close parenthesis label
219 // pairs are passed in 'parens'. The expansion enforces the
220 // parenthesis constraints. The PDT must be expandable as an FST.
221 //
222 // This class attaches interface to implementation and handles
223 // reference counting, delegating most methods to ImplToFst.
224 template <class A>
225 class ExpandFst : public ImplToFst< ExpandFstImpl<A> > {
226 public:
227 friend class ArcIterator< ExpandFst<A> >;
228 friend class StateIterator< ExpandFst<A> >;
229
230 typedef A Arc;
231 typedef typename A::Label Label;
232 typedef typename A::Weight Weight;
233 typedef typename A::StateId StateId;
234 typedef StateId StackId;
235 typedef CacheState<A> State;
236 typedef ExpandFstImpl<A> Impl;
237
ExpandFst(const Fst<A> & fst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens)238 ExpandFst(const Fst<A> &fst,
239 const vector<pair<typename Arc::Label,
240 typename Arc::Label> > &parens)
241 : ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {}
242
ExpandFst(const Fst<A> & fst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens,const ExpandFstOptions<A> & opts)243 ExpandFst(const Fst<A> &fst,
244 const vector<pair<typename Arc::Label,
245 typename Arc::Label> > &parens,
246 const ExpandFstOptions<A> &opts)
247 : ImplToFst<Impl>(new Impl(fst, parens, opts)) {}
248
249 // See Fst<>::Copy() for doc.
250 ExpandFst(const ExpandFst<A> &fst, bool safe = false)
251 : ImplToFst<Impl>(fst, safe) {}
252
253 // Get a copy of this ExpandFst. See Fst<>::Copy() for further doc.
254 virtual ExpandFst<A> *Copy(bool safe = false) const {
255 return new ExpandFst<A>(*this, safe);
256 }
257
258 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
259
InitArcIterator(StateId s,ArcIteratorData<A> * data)260 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
261 GetImpl()->InitArcIterator(s, data);
262 }
263
GetStack()264 const PdtStack<StackId, Label> &GetStack() const {
265 return GetImpl()->GetStack();
266 }
267
GetStateTable()268 const PdtStateTable<StateId, StackId> &GetStateTable() const {
269 return GetImpl()->GetStateTable();
270 }
271
272 private:
273 // Makes visible to friends.
GetImpl()274 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
275
276 void operator=(const ExpandFst<A> &fst); // Disallow
277 };
278
279
280 // Specialization for ExpandFst.
281 template<class A>
282 class StateIterator< ExpandFst<A> >
283 : public CacheStateIterator< ExpandFst<A> > {
284 public:
StateIterator(const ExpandFst<A> & fst)285 explicit StateIterator(const ExpandFst<A> &fst)
286 : CacheStateIterator< ExpandFst<A> >(fst, fst.GetImpl()) {}
287 };
288
289
290 // Specialization for ExpandFst.
291 template <class A>
292 class ArcIterator< ExpandFst<A> >
293 : public CacheArcIterator< ExpandFst<A> > {
294 public:
295 typedef typename A::StateId StateId;
296
ArcIterator(const ExpandFst<A> & fst,StateId s)297 ArcIterator(const ExpandFst<A> &fst, StateId s)
298 : CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) {
299 if (!fst.GetImpl()->HasArcs(s))
300 fst.GetImpl()->ExpandState(s);
301 }
302
303 private:
304 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
305 };
306
307
308 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)309 void ExpandFst<A>::InitStateIterator(StateIteratorData<A> *data) const
310 {
311 data->base = new StateIterator< ExpandFst<A> >(*this);
312 }
313
314 //
315 // PrunedExpand Class
316 //
317
318 // Prunes the delayed expansion of a pushdown transducer (PDT) encoded
319 // as an FST into an FST. In the PDT, some transitions are labeled
320 // with open or close parentheses. To be interpreted as a PDT, the
321 // parens must balance on a path. The open-close parenthesis label
322 // pairs are passed in 'parens'. The expansion enforces the
323 // parenthesis constraints.
324 //
325 // The algorithm works by visiting the delayed ExpandFst using a
326 // shortest-stack first queue discipline and relies on the
327 // shortest-distance information computed using a reverse
328 // shortest-path call to perform the pruning.
329 //
330 // The algorithm maintains the same state ordering between the ExpandFst
331 // being visited 'efst_' and the result of pruning written into the
332 // MutableFst 'ofst_' to improve readability of the code.
333 //
334 template <class A>
335 class PrunedExpand {
336 public:
337 typedef A Arc;
338 typedef typename A::Label Label;
339 typedef typename A::StateId StateId;
340 typedef typename A::Weight Weight;
341 typedef StateId StackId;
342 typedef PdtStack<StackId, Label> Stack;
343 typedef PdtStateTable<StateId, StackId> StateTable;
344 typedef typename PdtBalanceData<Arc>::SetIterator SetIterator;
345
346 // Constructor taking as input a PDT specified by 'ifst' and 'parens'.
347 // 'keep_parentheses' specifies whether parentheses are replaced by
348 // epsilons or not during the expansion. 'opts' is the cache options
349 // used to instantiate the underlying ExpandFst.
350 PrunedExpand(const Fst<A> &ifst,
351 const vector<pair<Label, Label> > &parens,
352 bool keep_parentheses = false,
353 const CacheOptions &opts = CacheOptions())
354 : ifst_(ifst.Copy()),
355 keep_parentheses_(keep_parentheses),
356 stack_(parens),
357 efst_(ifst, parens,
358 ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)),
359 queue_(state_table_, stack_, stack_length_, distance_, fdistance_) {
360 Reverse(*ifst_, parens, &rfst_);
361 VectorFst<Arc> path;
362 reverse_shortest_path_ = new SP(
363 rfst_, parens,
364 PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false));
365 reverse_shortest_path_->ShortestPath(&path);
366 balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse(
367 rfst_.NumStates(), 10, -1);
368
369 InitCloseParenMultimap(parens);
370 }
371
~PrunedExpand()372 ~PrunedExpand() {
373 delete ifst_;
374 delete reverse_shortest_path_;
375 delete balance_data_;
376 }
377
378 // Expands and prunes with weight threshold 'threshold' the input PDT.
379 // Writes the result in 'ofst'.
380 void Expand(MutableFst<A> *ofst, const Weight &threshold);
381
382 private:
383 static const uint8 kEnqueued;
384 static const uint8 kExpanded;
385 static const uint8 kSourceState;
386
387 // Comparison functor used by the queue:
388 // 1. states corresponding to shortest stack first,
389 // 2. among stacks of the same length, reverse lexicographic order is used,
390 // 3. among states with the same stack, shortest-first order is used.
391 class StackCompare {
392 public:
StackCompare(const StateTable & st,const Stack & s,const vector<StackId> & sl,const vector<Weight> & d,const vector<Weight> & fd)393 StackCompare(const StateTable &st,
394 const Stack &s, const vector<StackId> &sl,
395 const vector<Weight> &d, const vector<Weight> &fd)
396 : state_table_(st), stack_(s), stack_length_(sl),
397 distance_(d), fdistance_(fd) {}
398
operator()399 bool operator()(StateId s1, StateId s2) const {
400 StackId si1 = state_table_.Tuple(s1).stack_id;
401 StackId si2 = state_table_.Tuple(s2).stack_id;
402 if (stack_length_[si1] < stack_length_[si2])
403 return true;
404 if (stack_length_[si1] > stack_length_[si2])
405 return false;
406 // If stack id equal, use A*
407 if (si1 == si2) {
408 Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ?
409 Times(distance_[s1], fdistance_[s1]) : Weight::Zero();
410 Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ?
411 Times(distance_[s2], fdistance_[s2]) : Weight::Zero();
412 return less_(w1, w2);
413 }
414 // If lenghts are equal, use reverse lexico.
415 for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
416 if (stack_.Top(si1) < stack_.Top(si2)) return true;
417 if (stack_.Top(si1) > stack_.Top(si2)) return false;
418 }
419 return false;
420 }
421
422 private:
423 const StateTable &state_table_;
424 const Stack &stack_;
425 const vector<StackId> &stack_length_;
426 const vector<Weight> &distance_;
427 const vector<Weight> &fdistance_;
428 NaturalLess<Weight> less_;
429 };
430
431 class ShortestStackFirstQueue
432 : public ShortestFirstQueue<StateId, StackCompare> {
433 public:
ShortestStackFirstQueue(const PdtStateTable<StateId,StackId> & st,const Stack & s,const vector<StackId> & sl,const vector<Weight> & d,const vector<Weight> & fd)434 ShortestStackFirstQueue(
435 const PdtStateTable<StateId, StackId> &st,
436 const Stack &s,
437 const vector<StackId> &sl,
438 const vector<Weight> &d, const vector<Weight> &fd)
439 : ShortestFirstQueue<StateId, StackCompare>(
440 StackCompare(st, s, sl, d, fd)) {}
441 };
442
443
444 void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens);
445 Weight DistanceToDest(StateId state, StateId source) const;
446 uint8 Flags(StateId s) const;
447 void SetFlags(StateId s, uint8 flags, uint8 mask);
448 Weight Distance(StateId s) const;
449 void SetDistance(StateId s, Weight w);
450 Weight FinalDistance(StateId s) const;
451 void SetFinalDistance(StateId s, Weight w);
452 StateId SourceState(StateId s) const;
453 void SetSourceState(StateId s, StateId p);
454 void AddStateAndEnqueue(StateId s);
455 void Relax(StateId s, const A &arc, Weight w);
456 bool PruneArc(StateId s, const A &arc);
457 void ProcStart();
458 void ProcFinal(StateId s);
459 bool ProcNonParen(StateId s, const A &arc, bool add_arc);
460 bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi);
461 bool ProcCloseParen(StateId s, const A &arc);
462 void ProcDestStates(StateId s, StackId si);
463
464 Fst<A> *ifst_; // Input PDT
465 VectorFst<Arc> rfst_; // Reversed PDT
466 bool keep_parentheses_; // Keep parentheses in ofst?
467 StateTable state_table_; // State table for efst_
468 Stack stack_; // Stack trie
469 ExpandFst<Arc> efst_; // Expanded PDT
470 vector<StackId> stack_length_; // Length of stack for given stack id
471 vector<Weight> distance_; // Distance from initial state in efst_/ofst
472 vector<Weight> fdistance_; // Distance to final states in efst_/ofst
473 ShortestStackFirstQueue queue_; // Queue used to visit efst_
474 vector<uint8> flags_; // Status flags for states in efst_/ofst
475 vector<StateId> sources_; // PDT source state for each expanded state
476
477 typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP;
478 typedef typename SP::CloseParenMultimap ParenMultimap;
479 SP *reverse_shortest_path_; // Shortest path for rfst_
480 PdtBalanceData<Arc> *balance_data_; // Not owned by shortest_path_
481 ParenMultimap close_paren_multimap_; // Maps open paren arcs to
482 // balancing close paren arcs.
483
484 MutableFst<Arc> *ofst_; // Output fst
485 Weight limit_; // Weight limit
486
487 typedef unordered_map<StateId, Weight> DestMap;
488 DestMap dest_map_;
489 StackId current_stack_id_;
490 // 'current_stack_id_' is the stack id of the states currently at the top
491 // of queue, i.e., the states currently being popped and processed.
492 // 'dest_map_' maps a state 's' in 'ifst_' that is the source
493 // of a close parentheses matching the top of 'current_stack_id_; to
494 // the shortest-distance from '(s, current_stack_id_)' to the final
495 // states in 'efst_'.
496 ssize_t current_paren_id_; // Paren id at top of current stack
497 ssize_t cached_stack_id_;
498 StateId cached_source_;
499 slist<pair<StateId, Weight> > cached_dest_list_;
500 // 'cached_dest_list_' contains the set of pair of destination
501 // states and weight to final states for source state
502 // 'cached_source_' and paren id 'cached_paren_id': the set of
503 // source state of a close parenthesis with paren id
504 // 'cached_paren_id' balancing an incoming open parenthesis with
505 // paren id 'cached_paren_id' in state 'cached_source_'.
506
507 NaturalLess<Weight> less_;
508 };
509
510 template <class A> const uint8 PrunedExpand<A>::kEnqueued = 0x01;
511 template <class A> const uint8 PrunedExpand<A>::kExpanded = 0x02;
512 template <class A> const uint8 PrunedExpand<A>::kSourceState = 0x04;
513
514
515 // Initializes close paren multimap, mapping pairs (s,paren_id) to
516 // all the arcs out of s labeled with close parenthese for paren_id.
517 template <class A>
InitCloseParenMultimap(const vector<pair<Label,Label>> & parens)518 void PrunedExpand<A>::InitCloseParenMultimap(
519 const vector<pair<Label, Label> > &parens) {
520 unordered_map<Label, Label> paren_id_map;
521 for (Label i = 0; i < parens.size(); ++i) {
522 const pair<Label, Label> &p = parens[i];
523 paren_id_map[p.first] = i;
524 paren_id_map[p.second] = i;
525 }
526
527 for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
528 StateId s = siter.Value();
529 for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
530 !aiter.Done(); aiter.Next()) {
531 const Arc &arc = aiter.Value();
532 typename unordered_map<Label, Label>::const_iterator pit
533 = paren_id_map.find(arc.ilabel);
534 if (pit == paren_id_map.end()) continue;
535 if (arc.ilabel == parens[pit->second].second) { // Close paren
536 ParenState<Arc> paren_state(pit->second, s);
537 close_paren_multimap_.insert(make_pair(paren_state, arc));
538 }
539 }
540 }
541 }
542
543
544 // Returns the weight of the shortest balanced path from 'source' to 'dest'
545 // in 'ifst_', 'dest' must be the source state of a close paren arc.
546 template <class A>
DistanceToDest(StateId source,StateId dest)547 typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source,
548 StateId dest) const {
549 typename SP::SearchState s(source + 1, dest + 1);
550 VLOG(2) << "D(" << source << ", " << dest << ") ="
551 << reverse_shortest_path_->GetShortestPathData().Distance(s);
552 return reverse_shortest_path_->GetShortestPathData().Distance(s);
553 }
554
555 // Returns the flags for state 's' in 'ofst_'.
556 template <class A>
Flags(StateId s)557 uint8 PrunedExpand<A>::Flags(StateId s) const {
558 return s < flags_.size() ? flags_[s] : 0;
559 }
560
561 // Modifies the flags for state 's' in 'ofst_'.
562 template <class A>
SetFlags(StateId s,uint8 flags,uint8 mask)563 void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) {
564 while (flags_.size() <= s) flags_.push_back(0);
565 flags_[s] &= ~mask;
566 flags_[s] |= flags & mask;
567 }
568
569
570 // Returns the shortest distance from the initial state to 's' in 'ofst_'.
571 template <class A>
Distance(StateId s)572 typename A::Weight PrunedExpand<A>::Distance(StateId s) const {
573 return s < distance_.size() ? distance_[s] : Weight::Zero();
574 }
575
576 // Sets the shortest distance from the initial state to 's' in 'ofst_' to 'w'.
577 template <class A>
SetDistance(StateId s,Weight w)578 void PrunedExpand<A>::SetDistance(StateId s, Weight w) {
579 while (distance_.size() <= s ) distance_.push_back(Weight::Zero());
580 distance_[s] = w;
581 }
582
583
584 // Returns the shortest distance from 's' to the final states in 'ofst_'.
585 template <class A>
FinalDistance(StateId s)586 typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const {
587 return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
588 }
589
590 // Sets the shortest distance from 's' to the final states in 'ofst_' to 'w'.
591 template <class A>
SetFinalDistance(StateId s,Weight w)592 void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) {
593 while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
594 fdistance_[s] = w;
595 }
596
597 // Returns the PDT "source" state of state 's' in 'ofst_'.
598 template <class A>
SourceState(StateId s)599 typename A::StateId PrunedExpand<A>::SourceState(StateId s) const {
600 return s < sources_.size() ? sources_[s] : kNoStateId;
601 }
602
603 // Sets the PDT "source" state of state 's' in 'ofst_' to state 'p' in 'ifst_'.
604 template <class A>
SetSourceState(StateId s,StateId p)605 void PrunedExpand<A>::SetSourceState(StateId s, StateId p) {
606 while (sources_.size() <= s) sources_.push_back(kNoStateId);
607 sources_[s] = p;
608 }
609
610 // Adds state 's' of 'efst_' to 'ofst_' and inserts it in the queue,
611 // modifying the flags for 's' accordingly.
612 template <class A>
AddStateAndEnqueue(StateId s)613 void PrunedExpand<A>::AddStateAndEnqueue(StateId s) {
614 if (!(Flags(s) & (kEnqueued | kExpanded))) {
615 while (ofst_->NumStates() <= s) ofst_->AddState();
616 queue_.Enqueue(s);
617 SetFlags(s, kEnqueued, kEnqueued);
618 } else if (Flags(s) & kEnqueued) {
619 queue_.Update(s);
620 }
621 // TODO(allauzen): Check everything is fine when kExpanded?
622 }
623
624 // Relaxes arc 'arc' out of state 's' in 'ofst_':
625 // * if the distance to 's' times the weight of 'arc' is smaller than
626 // the currently stored distance for 'arc.nextstate',
627 // updates 'Distance(arc.nextstate)' with new estimate;
628 // * if 'fd' is less than the currently stored distance from 'arc.nextstate'
629 // to the final state, updates with new estimate.
630 template <class A>
Relax(StateId s,const A & arc,Weight fd)631 void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) {
632 Weight nd = Times(Distance(s), arc.weight);
633 if (less_(nd, Distance(arc.nextstate))) {
634 SetDistance(arc.nextstate, nd);
635 SetSourceState(arc.nextstate, SourceState(s));
636 }
637 if (less_(fd, FinalDistance(arc.nextstate)))
638 SetFinalDistance(arc.nextstate, fd);
639 VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to "
640 << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate)
641 << ", nd = " << nd;
642 }
643
644 // Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to
645 // be pruned.
646 template <class A>
PruneArc(StateId s,const A & arc)647 bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) {
648 VLOG(2) << "Prune ?";
649 Weight fd = Weight::Zero();
650
651 if ((cached_source_ != SourceState(s)) ||
652 (cached_stack_id_ != current_stack_id_)) {
653 cached_source_ = SourceState(s);
654 cached_stack_id_ = current_stack_id_;
655 cached_dest_list_.clear();
656 if (cached_source_ != ifst_->Start()) {
657 for (SetIterator set_iter =
658 balance_data_->Find(current_paren_id_, cached_source_);
659 !set_iter.Done(); set_iter.Next()) {
660 StateId dest = set_iter.Element();
661 typename DestMap::const_iterator iter = dest_map_.find(dest);
662 cached_dest_list_.push_front(*iter);
663 }
664 } else {
665 // TODO(allauzen): queue discipline should prevent this never
666 // from happening; replace by a check.
667 cached_dest_list_.push_front(
668 make_pair(rfst_.Start() -1, Weight::One()));
669 }
670 }
671
672 for (typename slist<pair<StateId, Weight> >::const_iterator iter =
673 cached_dest_list_.begin();
674 iter != cached_dest_list_.end();
675 ++iter) {
676 fd = Plus(fd,
677 Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id,
678 iter->first),
679 iter->second));
680 }
681 Relax(s, arc, fd);
682 Weight w = Times(Distance(s), Times(arc.weight, fd));
683 return less_(limit_, w);
684 }
685
686 // Adds start state of 'efst_' to 'ofst_', enqueues it and initializes
687 // the distance data structures.
688 template <class A>
ProcStart()689 void PrunedExpand<A>::ProcStart() {
690 StateId s = efst_.Start();
691 AddStateAndEnqueue(s);
692 ofst_->SetStart(s);
693 SetSourceState(s, ifst_->Start());
694
695 current_stack_id_ = 0;
696 current_paren_id_ = -1;
697 stack_length_.push_back(0);
698 dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed
699
700 cached_source_ = ifst_->Start();
701 cached_stack_id_ = 0;
702 cached_dest_list_.push_front(
703 make_pair(rfst_.Start() -1, Weight::One()));
704
705 PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0);
706 SetFinalDistance(state_table_.FindState(tuple), Weight::One());
707 SetDistance(s, Weight::One());
708 SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1));
709 VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1);
710 }
711
712 // Makes 's' final in 'ofst_' if shortest accepting path ending in 's'
713 // is below threshold.
714 template <class A>
ProcFinal(StateId s)715 void PrunedExpand<A>::ProcFinal(StateId s) {
716 Weight final = efst_.Final(s);
717 if ((final == Weight::Zero()) || less_(limit_, Times(Distance(s), final)))
718 return;
719 ofst_->SetFinal(s, final);
720 }
721
722 // Returns true when arc (or meta-arc) 'arc' out of 's' in 'efst_' is
723 // below the threshold. When 'add_arc' is true, 'arc' is added to 'ofst_'.
724 template <class A>
ProcNonParen(StateId s,const A & arc,bool add_arc)725 bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) {
726 VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate
727 << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight
728 << ", add_arc = " << (add_arc ? "true" : "false");
729 if (PruneArc(s, arc)) return false;
730 if(add_arc) ofst_->AddArc(s, arc);
731 AddStateAndEnqueue(arc.nextstate);
732 return true;
733 }
734
735 // Processes an open paren arc 'arc' out of state 's' in 'ofst_'.
736 // When 'arc' is labeled with an open paren,
737 // 1. considers each (shortest) balanced path starting in 's' by
738 // taking 'arc' and ending by a close paren balancing the open
739 // paren of 'arc' as a meta-arc, processes and prunes each meta-arc
740 // as a non-paren arc, inserting its destination to the queue;
741 // 2. if at least one of these meta-arcs has not been pruned,
742 // adds the destination of 'arc' to 'ofst_' as a new source state
743 // for the stack id 'nsi' and inserts it in the queue.
744 template <class A>
ProcOpenParen(StateId s,const A & arc,StackId si,StackId nsi)745 bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si,
746 StackId nsi) {
747 // Update the stack lenght when needed: |nsi| = |si| + 1.
748 while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
749 if (stack_length_[nsi] == -1)
750 stack_length_[nsi] = stack_length_[si] + 1;
751
752 StateId ns = arc.nextstate;
753 VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id
754 << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")";
755 bool proc_arc = false;
756 Weight fd = Weight::Zero();
757 ssize_t paren_id = stack_.ParenId(arc.ilabel);
758 slist<StateId> sources;
759 for (SetIterator set_iter =
760 balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
761 !set_iter.Done(); set_iter.Next()) {
762 sources.push_front(set_iter.Element());
763 }
764 for (typename slist<StateId>::const_iterator sources_iter = sources.begin();
765 sources_iter != sources.end();
766 ++ sources_iter) {
767 StateId source = *sources_iter;
768 VLOG(2) << "Close paren source: " << source;
769 ParenState<Arc> paren_state(paren_id, source);
770 for (typename ParenMultimap::const_iterator iter =
771 close_paren_multimap_.find(paren_state);
772 iter != close_paren_multimap_.end() && paren_state == iter->first;
773 ++iter) {
774 Arc meta_arc = iter->second;
775 PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si);
776 meta_arc.nextstate = state_table_.FindState(tuple);
777 VLOG(2) << state_table_.Tuple(ns).state_id << ", " << source;
778 VLOG(2) << "Meta arc weight = " << arc.weight << " Times "
779 << DistanceToDest(state_table_.Tuple(ns).state_id, source)
780 << " Times " << meta_arc.weight;
781 meta_arc.weight = Times(
782 arc.weight,
783 Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
784 meta_arc.weight));
785 proc_arc |= ProcNonParen(s, meta_arc, false);
786 fd = Plus(fd, Times(
787 Times(
788 DistanceToDest(state_table_.Tuple(ns).state_id, source),
789 iter->second.weight),
790 FinalDistance(meta_arc.nextstate)));
791 }
792 }
793 if (proc_arc) {
794 VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate;
795 ofst_->AddArc(
796 s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
797 AddStateAndEnqueue(arc.nextstate);
798 Weight nd = Times(Distance(s), arc.weight);
799 if(less_(nd, Distance(arc.nextstate)))
800 SetDistance(arc.nextstate, nd);
801 // FinalDistance not necessary for source state since pruning
802 // decided using the meta-arcs above. But this is a problem with
803 // A*, hence:
804 if (less_(fd, FinalDistance(arc.nextstate)))
805 SetFinalDistance(arc.nextstate, fd);
806 SetFlags(arc.nextstate, kSourceState, kSourceState);
807 }
808 return proc_arc;
809 }
810
811 // Checks that shortest path through close paren arc in 'efst_' is
812 // below threshold, if so adds it to 'ofst_'.
813 template <class A>
ProcCloseParen(StateId s,const A & arc)814 bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) {
815 Weight w = Times(Distance(s),
816 Times(arc.weight, FinalDistance(arc.nextstate)));
817 if (less_(limit_, w))
818 return false;
819 ofst_->AddArc(
820 s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
821 return true;
822 }
823
824 // When 's' in 'ofst_' is a source state for stack id 'si', identifies
825 // all the corresponding possible destination states, that is, all the
826 // states in 'ifst_' that have an outgoing close paren arc balancing
827 // the incoming open paren taken to get to 's', and for each such
828 // state 't', computes the shortest distance from (t, si) to the final
829 // states in 'ofst_'. Stores this information in 'dest_map_'.
830 template <class A>
ProcDestStates(StateId s,StackId si)831 void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) {
832 if (!(Flags(s) & kSourceState)) return;
833 if (si != current_stack_id_) {
834 dest_map_.clear();
835 current_stack_id_ = si;
836 current_paren_id_ = stack_.Top(current_stack_id_);
837 VLOG(2) << "StackID " << si << " dequeued for first time";
838 }
839 // TODO(allauzen): clean up source state business; rename current function to
840 // ProcSourceState.
841 SetSourceState(s, state_table_.Tuple(s).state_id);
842
843 ssize_t paren_id = stack_.Top(si);
844 for (SetIterator set_iter =
845 balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
846 !set_iter.Done(); set_iter.Next()) {
847 StateId dest_state = set_iter.Element();
848 if (dest_map_.find(dest_state) != dest_map_.end())
849 continue;
850 Weight dest_weight = Weight::Zero();
851 ParenState<Arc> paren_state(paren_id, dest_state);
852 for (typename ParenMultimap::const_iterator iter =
853 close_paren_multimap_.find(paren_state);
854 iter != close_paren_multimap_.end() && paren_state == iter->first;
855 ++iter) {
856 const Arc &arc = iter->second;
857 PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si));
858 dest_weight = Plus(dest_weight,
859 Times(arc.weight,
860 FinalDistance(state_table_.FindState(tuple))));
861 }
862 dest_map_[dest_state] = dest_weight;
863 VLOG(2) << "State " << dest_state << " is a dest state for stack id "
864 << si << " with weight " << dest_weight;
865 }
866 }
867
868 // Expands and prunes with weight threshold 'threshold' the input PDT.
869 // Writes the result in 'ofst'.
870 template <class A>
Expand(MutableFst<A> * ofst,const typename A::Weight & threshold)871 void PrunedExpand<A>::Expand(
872 MutableFst<A> *ofst, const typename A::Weight &threshold) {
873 ofst_ = ofst;
874 ofst_->DeleteStates();
875 ofst_->SetInputSymbols(ifst_->InputSymbols());
876 ofst_->SetOutputSymbols(ifst_->OutputSymbols());
877
878 limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
879 flags_.clear();
880
881 ProcStart();
882
883 while (!queue_.Empty()) {
884 StateId s = queue_.Head();
885 queue_.Dequeue();
886 SetFlags(s, kExpanded, kExpanded | kEnqueued);
887 VLOG(2) << s << " dequeued!";
888
889 ProcFinal(s);
890 StackId stack_id = state_table_.Tuple(s).stack_id;
891 ProcDestStates(s, stack_id);
892
893 for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s);
894 !aiter.Done();
895 aiter.Next()) {
896 Arc arc = aiter.Value();
897 StackId nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
898 if (stack_id == nextstack_id)
899 ProcNonParen(s, arc, true);
900 else if (stack_id == stack_.Pop(nextstack_id))
901 ProcOpenParen(s, arc, stack_id, nextstack_id);
902 else
903 ProcCloseParen(s, arc);
904 }
905 VLOG(2) << "d[" << s << "] = " << Distance(s)
906 << ", fd[" << s << "] = " << FinalDistance(s);
907 }
908 }
909
910 //
911 // Expand() Functions
912 //
913
// Options for the Expand() functions.
template <class Arc>
struct ExpandOptions {
  // When true, Connect() is applied to the expanded result.
  bool connect;
  // When true, parenthesis labels are kept on the output arcs.
  bool keep_parentheses;
  // Pruning threshold. Weight::Zero() requests an unpruned expansion.
  typename Arc::Weight weight_threshold;

  ExpandOptions(bool do_connect = true, bool keep_parens = false,
                typename Arc::Weight threshold = Arc::Weight::Zero())
      : connect(do_connect),
        keep_parentheses(keep_parens),
        weight_threshold(threshold) {}
};
924
925 // Expands a pushdown transducer (PDT) encoded as an FST into an FST.
926 // This version writes the expanded PDT result to a MutableFst.
927 // In the PDT, some transitions are labeled with open or close
928 // parentheses. To be interpreted as a PDT, the parens must balance on
929 // a path. The open-close parenthesis label pairs are passed in
930 // 'parens'. The expansion enforces the parenthesis constraints. The
931 // PDT must be expandable as an FST.
932 template <class Arc>
Expand(const Fst<Arc> & ifst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens,MutableFst<Arc> * ofst,const ExpandOptions<Arc> & opts)933 void Expand(
934 const Fst<Arc> &ifst,
935 const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
936 MutableFst<Arc> *ofst,
937 const ExpandOptions<Arc> &opts) {
938 typedef typename Arc::Label Label;
939 typedef typename Arc::StateId StateId;
940 typedef typename Arc::Weight Weight;
941 typedef typename ExpandFst<Arc>::StackId StackId;
942
943 ExpandFstOptions<Arc> eopts;
944 eopts.gc_limit = 0;
945 if (opts.weight_threshold == Weight::Zero()) {
946 eopts.keep_parentheses = opts.keep_parentheses;
947 *ofst = ExpandFst<Arc>(ifst, parens, eopts);
948 } else {
949 PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses);
950 pruned_expand.Expand(ofst, opts.weight_threshold);
951 }
952
953 if (opts.connect)
954 Connect(ofst);
955 }
956
957 // Expands a pushdown transducer (PDT) encoded as an FST into an FST.
958 // This version writes the expanded PDT result to a MutableFst.
959 // In the PDT, some transitions are labeled with open or close
960 // parentheses. To be interpreted as a PDT, the parens must balance on
961 // a path. The open-close parenthesis label pairs are passed in
962 // 'parens'. The expansion enforces the parenthesis constraints. The
963 // PDT must be expandable as an FST.
964 template<class Arc>
965 void Expand(
966 const Fst<Arc> &ifst,
967 const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
968 MutableFst<Arc> *ofst,
969 bool connect = true, bool keep_parentheses = false) {
970 Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses));
971 }
972
973 } // namespace fst
974
975 #endif // FST_EXTENSIONS_PDT_EXPAND_H__
976