1 // expand.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Expand a PDT to an FST.
20 
21 #ifndef FST_EXTENSIONS_PDT_EXPAND_H__
22 #define FST_EXTENSIONS_PDT_EXPAND_H__
23 
24 #include <vector>
25 using std::vector;
26 
27 #include <fst/extensions/pdt/pdt.h>
28 #include <fst/extensions/pdt/paren.h>
29 #include <fst/extensions/pdt/shortest-path.h>
30 #include <fst/extensions/pdt/reverse.h>
31 #include <fst/cache.h>
32 #include <fst/mutable-fst.h>
33 #include <fst/queue.h>
34 #include <fst/state-table.h>
35 #include <fst/test-properties.h>
36 
37 namespace fst {
38 
39 template <class Arc>
40 struct ExpandFstOptions : public CacheOptions {
41   bool keep_parentheses;
42   PdtStack<typename Arc::StateId, typename Arc::Label> *stack;
43   PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table;
44 
45   ExpandFstOptions(
46       const CacheOptions &opts = CacheOptions(),
47       bool kp = false,
48       PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0,
49       PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0)
CacheOptionsExpandFstOptions50       : CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {}
51 };
52 
53 // Properties for an expanded PDT.
ExpandProperties(uint64 inprops)54 inline uint64 ExpandProperties(uint64 inprops) {
55   return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted);
56 }
57 
58 
59 // Implementation class for ExpandFst
60 template <class A>
61 class ExpandFstImpl
62     : public CacheImpl<A> {
63  public:
64   using FstImpl<A>::SetType;
65   using FstImpl<A>::SetProperties;
66   using FstImpl<A>::Properties;
67   using FstImpl<A>::SetInputSymbols;
68   using FstImpl<A>::SetOutputSymbols;
69 
70   using CacheBaseImpl< CacheState<A> >::PushArc;
71   using CacheBaseImpl< CacheState<A> >::HasArcs;
72   using CacheBaseImpl< CacheState<A> >::HasFinal;
73   using CacheBaseImpl< CacheState<A> >::HasStart;
74   using CacheBaseImpl< CacheState<A> >::SetArcs;
75   using CacheBaseImpl< CacheState<A> >::SetFinal;
76   using CacheBaseImpl< CacheState<A> >::SetStart;
77 
78   typedef A Arc;
79   typedef typename A::Label Label;
80   typedef typename A::Weight Weight;
81   typedef typename A::StateId StateId;
82   typedef StateId StackId;
83   typedef PdtStateTuple<StateId, StackId> StateTuple;
84 
ExpandFstImpl(const Fst<A> & fst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens,const ExpandFstOptions<A> & opts)85   ExpandFstImpl(const Fst<A> &fst,
86                 const vector<pair<typename Arc::Label,
87                                   typename Arc::Label> > &parens,
88                 const ExpandFstOptions<A> &opts)
89       : CacheImpl<A>(opts), fst_(fst.Copy()),
90         stack_(opts.stack ? opts.stack: new PdtStack<StateId, Label>(parens)),
91         state_table_(opts.state_table ? opts.state_table :
92                      new PdtStateTable<StateId, StackId>()),
93         own_stack_(opts.stack == 0), own_state_table_(opts.state_table == 0),
94         keep_parentheses_(opts.keep_parentheses) {
95     SetType("expand");
96 
97     uint64 props = fst.Properties(kFstProperties, false);
98     SetProperties(ExpandProperties(props), kCopyProperties);
99 
100     SetInputSymbols(fst.InputSymbols());
101     SetOutputSymbols(fst.OutputSymbols());
102   }
103 
ExpandFstImpl(const ExpandFstImpl & impl)104   ExpandFstImpl(const ExpandFstImpl &impl)
105       : CacheImpl<A>(impl),
106         fst_(impl.fst_->Copy(true)),
107         stack_(new PdtStack<StateId, Label>(*impl.stack_)),
108         state_table_(new PdtStateTable<StateId, StackId>()),
109         own_stack_(true), own_state_table_(true),
110         keep_parentheses_(impl.keep_parentheses_) {
111     SetType("expand");
112     SetProperties(impl.Properties(), kCopyProperties);
113     SetInputSymbols(impl.InputSymbols());
114     SetOutputSymbols(impl.OutputSymbols());
115   }
116 
~ExpandFstImpl()117   ~ExpandFstImpl() {
118     delete fst_;
119     if (own_stack_)
120       delete stack_;
121     if (own_state_table_)
122       delete state_table_;
123   }
124 
Start()125   StateId Start() {
126     if (!HasStart()) {
127       StateId s = fst_->Start();
128       if (s == kNoStateId)
129         return kNoStateId;
130       StateTuple tuple(s, 0);
131       StateId start = state_table_->FindState(tuple);
132       SetStart(start);
133     }
134     return CacheImpl<A>::Start();
135   }
136 
Final(StateId s)137   Weight Final(StateId s) {
138     if (!HasFinal(s)) {
139       const StateTuple &tuple = state_table_->Tuple(s);
140       Weight w = fst_->Final(tuple.state_id);
141       if (w != Weight::Zero() && tuple.stack_id == 0)
142         SetFinal(s, w);
143       else
144         SetFinal(s, Weight::Zero());
145     }
146     return CacheImpl<A>::Final(s);
147   }
148 
NumArcs(StateId s)149   size_t NumArcs(StateId s) {
150     if (!HasArcs(s)) {
151       ExpandState(s);
152     }
153     return CacheImpl<A>::NumArcs(s);
154   }
155 
NumInputEpsilons(StateId s)156   size_t NumInputEpsilons(StateId s) {
157     if (!HasArcs(s))
158       ExpandState(s);
159     return CacheImpl<A>::NumInputEpsilons(s);
160   }
161 
NumOutputEpsilons(StateId s)162   size_t NumOutputEpsilons(StateId s) {
163     if (!HasArcs(s))
164       ExpandState(s);
165     return CacheImpl<A>::NumOutputEpsilons(s);
166   }
167 
InitArcIterator(StateId s,ArcIteratorData<A> * data)168   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
169     if (!HasArcs(s))
170       ExpandState(s);
171     CacheImpl<A>::InitArcIterator(s, data);
172   }
173 
174   // Computes the outgoing transitions from a state, creating new destination
175   // states as needed.
ExpandState(StateId s)176   void ExpandState(StateId s) {
177     StateTuple tuple = state_table_->Tuple(s);
178     for (ArcIterator< Fst<A> > aiter(*fst_, tuple.state_id);
179          !aiter.Done(); aiter.Next()) {
180       Arc arc = aiter.Value();
181       StackId stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
182       if (stack_id == -1) {
183         // Non-matching close parenthesis
184         continue;
185       } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
186         // Stack push/pop
187         arc.ilabel = arc.olabel = 0;
188       }
189 
190       StateTuple ntuple(arc.nextstate, stack_id);
191       arc.nextstate = state_table_->FindState(ntuple);
192       PushArc(s, arc);
193     }
194     SetArcs(s);
195   }
196 
GetStack()197   const PdtStack<StackId, Label> &GetStack() const { return *stack_; }
198 
GetStateTable()199   const PdtStateTable<StateId, StackId> &GetStateTable() const {
200     return *state_table_;
201   }
202 
203  private:
204   const Fst<A> *fst_;
205 
206   PdtStack<StackId, Label> *stack_;
207   PdtStateTable<StateId, StackId> *state_table_;
208   bool own_stack_;
209   bool own_state_table_;
210   bool keep_parentheses_;
211 
212   void operator=(const ExpandFstImpl<A> &);  // disallow
213 };
214 
215 // Expands a pushdown transducer (PDT) encoded as an FST into an FST.
216 // This version is a delayed Fst.  In the PDT, some transitions are
217 // labeled with open or close parentheses. To be interpreted as a PDT,
218 // the parens must balance on a path. The open-close parenthesis label
219 // pairs are passed in 'parens'. The expansion enforces the
220 // parenthesis constraints. The PDT must be expandable as an FST.
221 //
222 // This class attaches interface to implementation and handles
223 // reference counting, delegating most methods to ImplToFst.
224 template <class A>
225 class ExpandFst : public ImplToFst< ExpandFstImpl<A> > {
226  public:
227   friend class ArcIterator< ExpandFst<A> >;
228   friend class StateIterator< ExpandFst<A> >;
229 
230   typedef A Arc;
231   typedef typename A::Label Label;
232   typedef typename A::Weight Weight;
233   typedef typename A::StateId StateId;
234   typedef StateId StackId;
235   typedef CacheState<A> State;
236   typedef ExpandFstImpl<A> Impl;
237 
ExpandFst(const Fst<A> & fst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens)238   ExpandFst(const Fst<A> &fst,
239             const vector<pair<typename Arc::Label,
240                               typename Arc::Label> > &parens)
241       : ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {}
242 
ExpandFst(const Fst<A> & fst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens,const ExpandFstOptions<A> & opts)243   ExpandFst(const Fst<A> &fst,
244             const vector<pair<typename Arc::Label,
245                               typename Arc::Label> > &parens,
246             const ExpandFstOptions<A> &opts)
247       : ImplToFst<Impl>(new Impl(fst, parens, opts)) {}
248 
249   // See Fst<>::Copy() for doc.
250   ExpandFst(const ExpandFst<A> &fst, bool safe = false)
251       : ImplToFst<Impl>(fst, safe) {}
252 
253   // Get a copy of this ExpandFst. See Fst<>::Copy() for further doc.
254   virtual ExpandFst<A> *Copy(bool safe = false) const {
255     return new ExpandFst<A>(*this, safe);
256   }
257 
258   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
259 
InitArcIterator(StateId s,ArcIteratorData<A> * data)260   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
261     GetImpl()->InitArcIterator(s, data);
262   }
263 
GetStack()264   const PdtStack<StackId, Label> &GetStack() const {
265     return GetImpl()->GetStack();
266   }
267 
GetStateTable()268   const PdtStateTable<StateId, StackId> &GetStateTable() const {
269     return GetImpl()->GetStateTable();
270   }
271 
272  private:
273   // Makes visible to friends.
GetImpl()274   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
275 
276   void operator=(const ExpandFst<A> &fst);  // Disallow
277 };
278 
279 
280 // Specialization for ExpandFst.
281 template<class A>
282 class StateIterator< ExpandFst<A> >
283     : public CacheStateIterator< ExpandFst<A> > {
284  public:
StateIterator(const ExpandFst<A> & fst)285   explicit StateIterator(const ExpandFst<A> &fst)
286       : CacheStateIterator< ExpandFst<A> >(fst, fst.GetImpl()) {}
287 };
288 
289 
290 // Specialization for ExpandFst.
291 template <class A>
292 class ArcIterator< ExpandFst<A> >
293     : public CacheArcIterator< ExpandFst<A> > {
294  public:
295   typedef typename A::StateId StateId;
296 
ArcIterator(const ExpandFst<A> & fst,StateId s)297   ArcIterator(const ExpandFst<A> &fst, StateId s)
298       : CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) {
299     if (!fst.GetImpl()->HasArcs(s))
300       fst.GetImpl()->ExpandState(s);
301   }
302 
303  private:
304   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
305 };
306 
307 
308 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)309 void ExpandFst<A>::InitStateIterator(StateIteratorData<A> *data) const
310 {
311   data->base = new StateIterator< ExpandFst<A> >(*this);
312 }
313 
314 //
315 // PrunedExpand Class
316 //
317 
318 // Prunes the delayed expansion of a pushdown transducer (PDT) encoded
319 // as an FST into an FST.  In the PDT, some transitions are labeled
320 // with open or close parentheses. To be interpreted as a PDT, the
321 // parens must balance on a path. The open-close parenthesis label
322 // pairs are passed in 'parens'. The expansion enforces the
323 // parenthesis constraints.
324 //
325 // The algorithm works by visiting the delayed ExpandFst using a
326 // shortest-stack first queue discipline and relies on the
327 // shortest-distance information computed using a reverse
328 // shortest-path call to perform the pruning.
329 //
330 // The algorithm maintains the same state ordering between the ExpandFst
331 // being visited 'efst_' and the result of pruning written into the
332 // MutableFst 'ofst_' to improve readability of the code.
333 //
334 template <class A>
335 class PrunedExpand {
336  public:
337   typedef A Arc;
338   typedef typename A::Label Label;
339   typedef typename A::StateId StateId;
340   typedef typename A::Weight Weight;
341   typedef StateId StackId;
342   typedef PdtStack<StackId, Label> Stack;
343   typedef PdtStateTable<StateId, StackId> StateTable;
344   typedef typename PdtBalanceData<Arc>::SetIterator SetIterator;
345 
346   // Constructor taking as input a PDT specified by 'ifst' and 'parens'.
347   // 'keep_parentheses' specifies whether parentheses are replaced by
348   // epsilons or not during the expansion. 'opts' is the cache options
349   // used to instantiate the underlying ExpandFst.
350   PrunedExpand(const Fst<A> &ifst,
351                const vector<pair<Label, Label> > &parens,
352                bool keep_parentheses = false,
353                const CacheOptions &opts = CacheOptions())
354       : ifst_(ifst.Copy()),
355         keep_parentheses_(keep_parentheses),
356         stack_(parens),
357         efst_(ifst, parens,
358               ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)),
359         queue_(state_table_, stack_, stack_length_, distance_, fdistance_) {
360     Reverse(*ifst_, parens, &rfst_);
361     VectorFst<Arc> path;
362     reverse_shortest_path_ = new SP(
363         rfst_, parens,
364         PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false));
365     reverse_shortest_path_->ShortestPath(&path);
366     balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse(
367         rfst_.NumStates(), 10, -1);
368 
369     InitCloseParenMultimap(parens);
370   }
371 
~PrunedExpand()372   ~PrunedExpand() {
373     delete ifst_;
374     delete reverse_shortest_path_;
375     delete balance_data_;
376   }
377 
378   // Expands and prunes with weight threshold 'threshold' the input PDT.
379   // Writes the result in 'ofst'.
380   void Expand(MutableFst<A> *ofst, const Weight &threshold);
381 
382  private:
383   static const uint8 kEnqueued;
384   static const uint8 kExpanded;
385   static const uint8 kSourceState;
386 
387   // Comparison functor used by the queue:
388   // 1. states corresponding to shortest stack first,
389   // 2. among stacks of the same length, reverse lexicographic order is used,
390   // 3. among states with the same stack, shortest-first order is used.
391   class StackCompare {
392    public:
StackCompare(const StateTable & st,const Stack & s,const vector<StackId> & sl,const vector<Weight> & d,const vector<Weight> & fd)393     StackCompare(const StateTable &st,
394                  const Stack &s, const vector<StackId> &sl,
395                  const vector<Weight> &d, const vector<Weight> &fd)
396         : state_table_(st), stack_(s), stack_length_(sl),
397           distance_(d), fdistance_(fd) {}
398 
operator()399     bool operator()(StateId s1, StateId s2) const {
400       StackId si1 = state_table_.Tuple(s1).stack_id;
401       StackId si2 = state_table_.Tuple(s2).stack_id;
402       if (stack_length_[si1] < stack_length_[si2])
403         return true;
404       if  (stack_length_[si1] > stack_length_[si2])
405         return false;
406       // If stack id equal, use A*
407       if (si1 == si2) {
408         Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ?
409             Times(distance_[s1], fdistance_[s1]) : Weight::Zero();
410         Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ?
411             Times(distance_[s2], fdistance_[s2]) : Weight::Zero();
412         return less_(w1, w2);
413       }
414       // If lenghts are equal, use reverse lexico.
415       for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
416         if (stack_.Top(si1) < stack_.Top(si2)) return true;
417         if (stack_.Top(si1) > stack_.Top(si2)) return false;
418       }
419       return false;
420     }
421 
422    private:
423     const StateTable &state_table_;
424     const Stack &stack_;
425     const vector<StackId> &stack_length_;
426     const vector<Weight> &distance_;
427     const vector<Weight> &fdistance_;
428     NaturalLess<Weight> less_;
429   };
430 
431   class ShortestStackFirstQueue
432       : public ShortestFirstQueue<StateId, StackCompare> {
433    public:
ShortestStackFirstQueue(const PdtStateTable<StateId,StackId> & st,const Stack & s,const vector<StackId> & sl,const vector<Weight> & d,const vector<Weight> & fd)434     ShortestStackFirstQueue(
435         const PdtStateTable<StateId, StackId> &st,
436         const Stack &s,
437         const vector<StackId> &sl,
438         const vector<Weight> &d, const vector<Weight> &fd)
439         : ShortestFirstQueue<StateId, StackCompare>(
440             StackCompare(st, s, sl, d, fd)) {}
441   };
442 
443 
444   void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens);
445   Weight DistanceToDest(StateId state, StateId source) const;
446   uint8 Flags(StateId s) const;
447   void SetFlags(StateId s, uint8 flags, uint8 mask);
448   Weight Distance(StateId s) const;
449   void SetDistance(StateId s, Weight w);
450   Weight FinalDistance(StateId s) const;
451   void SetFinalDistance(StateId s, Weight w);
452   StateId SourceState(StateId s) const;
453   void SetSourceState(StateId s, StateId p);
454   void AddStateAndEnqueue(StateId s);
455   void Relax(StateId s, const A &arc, Weight w);
456   bool PruneArc(StateId s, const A &arc);
457   void ProcStart();
458   void ProcFinal(StateId s);
459   bool ProcNonParen(StateId s, const A &arc, bool add_arc);
460   bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi);
461   bool ProcCloseParen(StateId s, const A &arc);
462   void ProcDestStates(StateId s, StackId si);
463 
464   Fst<A> *ifst_;                   // Input PDT
465   VectorFst<Arc> rfst_;            // Reversed PDT
466   bool keep_parentheses_;          // Keep parentheses in ofst?
467   StateTable state_table_;         // State table for efst_
468   Stack stack_;                    // Stack trie
469   ExpandFst<Arc> efst_;            // Expanded PDT
470   vector<StackId> stack_length_;   // Length of stack for given stack id
471   vector<Weight> distance_;        // Distance from initial state in efst_/ofst
472   vector<Weight> fdistance_;       // Distance to final states in efst_/ofst
473   ShortestStackFirstQueue queue_;  // Queue used to visit efst_
474   vector<uint8> flags_;            // Status flags for states in efst_/ofst
475   vector<StateId> sources_;        // PDT source state for each expanded state
476 
477   typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP;
478   typedef typename SP::CloseParenMultimap ParenMultimap;
479   SP *reverse_shortest_path_;  // Shortest path for rfst_
480   PdtBalanceData<Arc> *balance_data_;   // Not owned by shortest_path_
481   ParenMultimap close_paren_multimap_;  // Maps open paren arcs to
482   // balancing close paren arcs.
483 
484   MutableFst<Arc> *ofst_;  // Output fst
485   Weight limit_;           // Weight limit
486 
487   typedef unordered_map<StateId, Weight> DestMap;
488   DestMap dest_map_;
489   StackId current_stack_id_;
490   // 'current_stack_id_' is the stack id of the states currently at the top
491   // of queue, i.e., the states currently being popped and processed.
492   // 'dest_map_' maps a state 's' in 'ifst_' that is the source
493   // of a close parentheses matching the top of 'current_stack_id_; to
494   // the shortest-distance from '(s, current_stack_id_)' to the final
495   // states in 'efst_'.
496   ssize_t current_paren_id_;  // Paren id at top of current stack
497   ssize_t cached_stack_id_;
498   StateId cached_source_;
499   slist<pair<StateId, Weight> > cached_dest_list_;
500   // 'cached_dest_list_' contains the set of pair of destination
501   // states and weight to final states for source state
502   // 'cached_source_' and paren id 'cached_paren_id': the set of
503   // source state of a close parenthesis with paren id
504   // 'cached_paren_id' balancing an incoming open parenthesis with
505   // paren id 'cached_paren_id' in state 'cached_source_'.
506 
507   NaturalLess<Weight> less_;
508 };
509 
510 template <class A> const uint8 PrunedExpand<A>::kEnqueued = 0x01;
511 template <class A> const uint8 PrunedExpand<A>::kExpanded = 0x02;
512 template <class A> const uint8 PrunedExpand<A>::kSourceState = 0x04;
513 
514 
515 // Initializes close paren multimap, mapping pairs (s,paren_id) to
516 // all the arcs out of s labeled with close parenthese for paren_id.
517 template <class A>
InitCloseParenMultimap(const vector<pair<Label,Label>> & parens)518 void PrunedExpand<A>::InitCloseParenMultimap(
519     const vector<pair<Label, Label> > &parens) {
520   unordered_map<Label, Label> paren_id_map;
521   for (Label i = 0; i < parens.size(); ++i) {
522     const pair<Label, Label>  &p = parens[i];
523     paren_id_map[p.first] = i;
524     paren_id_map[p.second] = i;
525   }
526 
527   for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
528     StateId s = siter.Value();
529     for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
530          !aiter.Done(); aiter.Next()) {
531       const Arc &arc = aiter.Value();
532       typename unordered_map<Label, Label>::const_iterator pit
533           = paren_id_map.find(arc.ilabel);
534       if (pit == paren_id_map.end()) continue;
535       if (arc.ilabel == parens[pit->second].second) {  // Close paren
536         ParenState<Arc> paren_state(pit->second, s);
537         close_paren_multimap_.insert(make_pair(paren_state, arc));
538       }
539     }
540   }
541 }
542 
543 
544 // Returns the weight of the shortest balanced path from 'source' to 'dest'
545 // in 'ifst_', 'dest' must be the source state of a close paren arc.
546 template <class A>
DistanceToDest(StateId source,StateId dest)547 typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source,
548                                                    StateId dest) const {
549   typename SP::SearchState s(source + 1, dest + 1);
550   VLOG(2) << "D(" << source << ", " << dest << ") ="
551             << reverse_shortest_path_->GetShortestPathData().Distance(s);
552   return reverse_shortest_path_->GetShortestPathData().Distance(s);
553 }
554 
555 // Returns the flags for state 's' in 'ofst_'.
556 template <class A>
Flags(StateId s)557 uint8 PrunedExpand<A>::Flags(StateId s) const {
558   return s < flags_.size() ? flags_[s] : 0;
559 }
560 
561 // Modifies the flags for state 's' in 'ofst_'.
562 template <class A>
SetFlags(StateId s,uint8 flags,uint8 mask)563 void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) {
564   while (flags_.size() <= s) flags_.push_back(0);
565   flags_[s] &= ~mask;
566   flags_[s] |= flags & mask;
567 }
568 
569 
570 // Returns the shortest distance from the initial state to 's' in 'ofst_'.
571 template <class A>
Distance(StateId s)572 typename A::Weight PrunedExpand<A>::Distance(StateId s) const {
573   return s < distance_.size() ? distance_[s] : Weight::Zero();
574 }
575 
576 // Sets the shortest distance from the initial state to 's' in 'ofst_' to 'w'.
577 template <class A>
SetDistance(StateId s,Weight w)578 void PrunedExpand<A>::SetDistance(StateId s, Weight w) {
579   while (distance_.size() <= s ) distance_.push_back(Weight::Zero());
580   distance_[s] = w;
581 }
582 
583 
584 // Returns the shortest distance from 's' to the final states in 'ofst_'.
585 template <class A>
FinalDistance(StateId s)586 typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const {
587   return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
588 }
589 
590 // Sets the shortest distance from 's' to the final states in 'ofst_' to 'w'.
591 template <class A>
SetFinalDistance(StateId s,Weight w)592 void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) {
593   while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
594   fdistance_[s] = w;
595 }
596 
597 // Returns the PDT "source" state of state 's' in 'ofst_'.
598 template <class A>
SourceState(StateId s)599 typename A::StateId PrunedExpand<A>::SourceState(StateId s) const {
600   return s < sources_.size() ? sources_[s] : kNoStateId;
601 }
602 
603 // Sets the PDT "source" state of state 's' in 'ofst_' to state 'p' in 'ifst_'.
604 template <class A>
SetSourceState(StateId s,StateId p)605 void PrunedExpand<A>::SetSourceState(StateId s, StateId p) {
606   while (sources_.size() <= s) sources_.push_back(kNoStateId);
607   sources_[s] = p;
608 }
609 
610 // Adds state 's' of 'efst_' to 'ofst_' and inserts it in the queue,
611 // modifying the flags for 's' accordingly.
612 template <class A>
AddStateAndEnqueue(StateId s)613 void PrunedExpand<A>::AddStateAndEnqueue(StateId s) {
614   if (!(Flags(s) & (kEnqueued | kExpanded))) {
615     while (ofst_->NumStates() <= s) ofst_->AddState();
616     queue_.Enqueue(s);
617     SetFlags(s, kEnqueued, kEnqueued);
618   } else if (Flags(s) & kEnqueued) {
619     queue_.Update(s);
620   }
621   // TODO(allauzen): Check everything is fine when kExpanded?
622 }
623 
624 // Relaxes arc 'arc' out of state 's' in 'ofst_':
625 // * if the distance to 's' times the weight of 'arc' is smaller than
626 //   the currently stored distance for 'arc.nextstate',
627 //   updates 'Distance(arc.nextstate)' with new estimate;
628 // * if 'fd' is less than the currently stored distance from 'arc.nextstate'
629 //   to the final state, updates with new estimate.
630 template <class A>
Relax(StateId s,const A & arc,Weight fd)631 void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) {
632   Weight nd = Times(Distance(s), arc.weight);
633   if (less_(nd, Distance(arc.nextstate))) {
634     SetDistance(arc.nextstate, nd);
635     SetSourceState(arc.nextstate, SourceState(s));
636   }
637   if (less_(fd, FinalDistance(arc.nextstate)))
638     SetFinalDistance(arc.nextstate, fd);
639   VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to "
640             << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate)
641             << ", nd = " << nd;
642 }
643 
644 // Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to
645 // be pruned.
646 template <class A>
PruneArc(StateId s,const A & arc)647 bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) {
648   VLOG(2) << "Prune ?";
649   Weight fd = Weight::Zero();
650 
651   if ((cached_source_ != SourceState(s)) ||
652       (cached_stack_id_ != current_stack_id_)) {
653     cached_source_ = SourceState(s);
654     cached_stack_id_ = current_stack_id_;
655     cached_dest_list_.clear();
656     if (cached_source_ != ifst_->Start()) {
657       for (SetIterator set_iter =
658                balance_data_->Find(current_paren_id_, cached_source_);
659            !set_iter.Done(); set_iter.Next()) {
660         StateId dest = set_iter.Element();
661         typename DestMap::const_iterator iter = dest_map_.find(dest);
662         cached_dest_list_.push_front(*iter);
663       }
664     } else {
665       // TODO(allauzen): queue discipline should prevent this never
666       // from happening; replace by a check.
667       cached_dest_list_.push_front(
668           make_pair(rfst_.Start() -1, Weight::One()));
669     }
670   }
671 
672   for (typename slist<pair<StateId, Weight> >::const_iterator iter =
673            cached_dest_list_.begin();
674        iter != cached_dest_list_.end();
675        ++iter) {
676     fd = Plus(fd,
677               Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id,
678                                    iter->first),
679                     iter->second));
680   }
681   Relax(s, arc, fd);
682   Weight w = Times(Distance(s), Times(arc.weight, fd));
683   return less_(limit_, w);
684 }
685 
// Adds start state of 'efst_' to 'ofst_', enqueues it and initializes
// the distance data structures, the stack-length table and the
// destination cache used by PruneArc().
//
// NOTE(review): statement order matters here — efst_.Start() and
// state_table_.FindState() both intern tuples in 'state_table_', so
// reordering them could change state numbering.
template <class A>
void PrunedExpand<A>::ProcStart() {
  StateId s = efst_.Start();
  AddStateAndEnqueue(s);
  ofst_->SetStart(s);
  SetSourceState(s, ifst_->Start());

  // The start state has the empty stack (id 0), of length 0, and no paren
  // on top of it.
  current_stack_id_ = 0;
  current_paren_id_ = -1;
  stack_length_.push_back(0);
  // 'rfst_.Start() - 1' appears to act as a sentinel destination for the
  // empty stack (TODO confirm against Reverse()).
  dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed

  // Seed the PruneArc() cache for the start source state with the single
  // sentinel destination of weight One().
  cached_source_ = ifst_->Start();
  cached_stack_id_ = 0;
  cached_dest_list_.push_front(
          make_pair(rfst_.Start() -1, Weight::One()));

  // The sentinel tuple (rfst_.Start() - 1, empty stack) has final
  // distance One(); the start state's final distance is its distance to
  // that sentinel.
  PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0);
  SetFinalDistance(state_table_.FindState(tuple), Weight::One());
  SetDistance(s, Weight::One());
  SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1));
  VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1);
}
711 
712 // Makes 's' final in 'ofst_' if shortest accepting path ending in 's'
713 // is below threshold.
714 template <class A>
ProcFinal(StateId s)715 void PrunedExpand<A>::ProcFinal(StateId s) {
716   Weight final = efst_.Final(s);
717   if ((final == Weight::Zero()) || less_(limit_, Times(Distance(s), final)))
718     return;
719   ofst_->SetFinal(s, final);
720 }
721 
722 // Returns true when arc (or meta-arc) 'arc' out of 's' in 'efst_' is
723 // below the threshold.  When 'add_arc' is true, 'arc' is added to 'ofst_'.
724 template <class A>
ProcNonParen(StateId s,const A & arc,bool add_arc)725 bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) {
726   VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate
727           << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight
728           << ", add_arc = " << (add_arc ? "true" : "false");
729   if (PruneArc(s, arc)) return false;
730   if(add_arc) ofst_->AddArc(s, arc);
731   AddStateAndEnqueue(arc.nextstate);
732   return true;
733 }
734 
735 // Processes an open paren arc 'arc' out of state 's' in 'ofst_'.
736 // When 'arc' is labeled with an open paren,
737 // 1. considers each (shortest) balanced path starting in 's' by
738 //    taking 'arc' and ending by a close paren balancing the open
739 //    paren of 'arc' as a meta-arc, processes and prunes each meta-arc
740 //    as a non-paren arc, inserting its destination to the queue;
741 // 2. if at least one of these meta-arcs has not been pruned,
742 //    adds the destination of 'arc' to 'ofst_' as a new source state
743 //    for the stack id 'nsi' and inserts it in the queue.
744 template <class A>
ProcOpenParen(StateId s,const A & arc,StackId si,StackId nsi)745 bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si,
746                                     StackId nsi) {
747   // Update the stack lenght when needed: |nsi| = |si| + 1.
748   while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
749   if (stack_length_[nsi] == -1)
750     stack_length_[nsi] = stack_length_[si] + 1;
751 
752   StateId ns = arc.nextstate;
753   VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id
754             << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")";
755   bool proc_arc = false;
756   Weight fd = Weight::Zero();
757   ssize_t paren_id = stack_.ParenId(arc.ilabel);
758   slist<StateId> sources;
759   for (SetIterator set_iter =
760            balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
761        !set_iter.Done(); set_iter.Next()) {
762     sources.push_front(set_iter.Element());
763   }
764   for (typename slist<StateId>::const_iterator sources_iter = sources.begin();
765        sources_iter != sources.end();
766        ++ sources_iter) {
767     StateId source = *sources_iter;
768     VLOG(2) << "Close paren source: " << source;
769     ParenState<Arc> paren_state(paren_id, source);
770     for (typename ParenMultimap::const_iterator iter =
771              close_paren_multimap_.find(paren_state);
772          iter != close_paren_multimap_.end() && paren_state == iter->first;
773          ++iter) {
774       Arc meta_arc = iter->second;
775       PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si);
776       meta_arc.nextstate =  state_table_.FindState(tuple);
777       VLOG(2) << state_table_.Tuple(ns).state_id << ", " << source;
778       VLOG(2) << "Meta arc weight = " << arc.weight << " Times "
779                 << DistanceToDest(state_table_.Tuple(ns).state_id, source)
780                 << " Times " << meta_arc.weight;
781       meta_arc.weight = Times(
782           arc.weight,
783           Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
784                 meta_arc.weight));
785       proc_arc |= ProcNonParen(s, meta_arc, false);
786       fd = Plus(fd, Times(
787           Times(
788               DistanceToDest(state_table_.Tuple(ns).state_id, source),
789               iter->second.weight),
790           FinalDistance(meta_arc.nextstate)));
791     }
792   }
793   if (proc_arc) {
794     VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate;
795     ofst_->AddArc(
796       s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
797     AddStateAndEnqueue(arc.nextstate);
798     Weight nd = Times(Distance(s), arc.weight);
799     if(less_(nd, Distance(arc.nextstate)))
800       SetDistance(arc.nextstate, nd);
801     // FinalDistance not necessary for source state since pruning
802     // decided using the meta-arcs above.  But this is a problem with
803     // A*, hence:
804     if (less_(fd, FinalDistance(arc.nextstate)))
805       SetFinalDistance(arc.nextstate, fd);
806     SetFlags(arc.nextstate, kSourceState, kSourceState);
807   }
808   return proc_arc;
809 }
810 
811 // Checks that shortest path through close paren arc in 'efst_' is
812 // below threshold, if so adds it to 'ofst_'.
813 template <class A>
ProcCloseParen(StateId s,const A & arc)814 bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) {
815   Weight w = Times(Distance(s),
816                    Times(arc.weight, FinalDistance(arc.nextstate)));
817   if (less_(limit_, w))
818     return false;
819   ofst_->AddArc(
820       s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
821   return true;
822 }
823 
824 // When 's' in 'ofst_' is a source state for stack id 'si', identifies
825 // all the corresponding possible destination states, that is, all the
826 // states in 'ifst_' that have an outgoing close paren arc balancing
827 // the incoming open paren taken to get to 's', and for each such
828 // state 't', computes the shortest distance from (t, si) to the final
829 // states in 'ofst_'. Stores this information in 'dest_map_'.
830 template <class A>
ProcDestStates(StateId s,StackId si)831 void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) {
832   if (!(Flags(s) & kSourceState)) return;
833   if (si != current_stack_id_) {
834     dest_map_.clear();
835     current_stack_id_ = si;
836     current_paren_id_ = stack_.Top(current_stack_id_);
837     VLOG(2) << "StackID " << si << " dequeued for first time";
838   }
839   // TODO(allauzen): clean up source state business; rename current function to
840   // ProcSourceState.
841   SetSourceState(s, state_table_.Tuple(s).state_id);
842 
843   ssize_t paren_id = stack_.Top(si);
844   for (SetIterator set_iter =
845            balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
846        !set_iter.Done(); set_iter.Next()) {
847     StateId dest_state = set_iter.Element();
848     if (dest_map_.find(dest_state) != dest_map_.end())
849       continue;
850     Weight dest_weight = Weight::Zero();
851     ParenState<Arc> paren_state(paren_id, dest_state);
852     for (typename ParenMultimap::const_iterator iter =
853              close_paren_multimap_.find(paren_state);
854          iter != close_paren_multimap_.end() && paren_state == iter->first;
855          ++iter) {
856       const Arc &arc = iter->second;
857       PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si));
858       dest_weight = Plus(dest_weight,
859                          Times(arc.weight,
860                                FinalDistance(state_table_.FindState(tuple))));
861     }
862     dest_map_[dest_state] = dest_weight;
863     VLOG(2) << "State " << dest_state << " is a dest state for stack id "
864               << si << " with weight " << dest_weight;
865   }
866 }
867 
868 // Expands and prunes with weight threshold 'threshold' the input PDT.
869 // Writes the result in 'ofst'.
870 template <class A>
Expand(MutableFst<A> * ofst,const typename A::Weight & threshold)871 void PrunedExpand<A>::Expand(
872     MutableFst<A> *ofst, const typename A::Weight &threshold) {
873   ofst_ = ofst;
874   ofst_->DeleteStates();
875   ofst_->SetInputSymbols(ifst_->InputSymbols());
876   ofst_->SetOutputSymbols(ifst_->OutputSymbols());
877 
878   limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
879   flags_.clear();
880 
881   ProcStart();
882 
883   while (!queue_.Empty()) {
884     StateId s = queue_.Head();
885     queue_.Dequeue();
886     SetFlags(s, kExpanded, kExpanded | kEnqueued);
887     VLOG(2) << s << " dequeued!";
888 
889     ProcFinal(s);
890     StackId stack_id = state_table_.Tuple(s).stack_id;
891     ProcDestStates(s, stack_id);
892 
893     for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s);
894          !aiter.Done();
895          aiter.Next()) {
896       Arc arc = aiter.Value();
897       StackId nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
898       if (stack_id == nextstack_id)
899         ProcNonParen(s, arc, true);
900       else if (stack_id == stack_.Pop(nextstack_id))
901         ProcOpenParen(s, arc, stack_id, nextstack_id);
902       else
903         ProcCloseParen(s, arc);
904     }
905     VLOG(2) << "d[" << s << "] = " << Distance(s)
906             << ", fd[" << s << "] = " << FinalDistance(s);
907   }
908 }
909 
910 //
911 // Expand() Functions
912 //
913 
// Options controlling the Expand() functions.
template <class Arc>
struct ExpandOptions {
  // If true, Expand() calls Connect() on its output.
  bool connect;
  // Passed through to the expansion. When false, paren-labeled arcs are
  // emitted with 0:0 labels.
  bool keep_parentheses;
  // Pruning threshold. Weight::Zero() (the default) selects the
  // unpruned expansion.
  typename Arc::Weight weight_threshold;

  ExpandOptions(bool connect_result = true, bool keep_parens = false,
                typename Arc::Weight threshold = Arc::Weight::Zero())
      : connect(connect_result),
        keep_parentheses(keep_parens),
        weight_threshold(threshold) {}
};
924 
925 // Expands a pushdown transducer (PDT) encoded as an FST into an FST.
926 // This version writes the expanded PDT result to a MutableFst.
927 // In the PDT, some transitions are labeled with open or close
928 // parentheses. To be interpreted as a PDT, the parens must balance on
929 // a path. The open-close parenthesis label pairs are passed in
930 // 'parens'. The expansion enforces the parenthesis constraints. The
931 // PDT must be expandable as an FST.
932 template <class Arc>
Expand(const Fst<Arc> & ifst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens,MutableFst<Arc> * ofst,const ExpandOptions<Arc> & opts)933 void Expand(
934     const Fst<Arc> &ifst,
935     const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
936     MutableFst<Arc> *ofst,
937     const ExpandOptions<Arc> &opts) {
938   typedef typename Arc::Label Label;
939   typedef typename Arc::StateId StateId;
940   typedef typename Arc::Weight Weight;
941   typedef typename ExpandFst<Arc>::StackId StackId;
942 
943   ExpandFstOptions<Arc> eopts;
944   eopts.gc_limit = 0;
945   if (opts.weight_threshold == Weight::Zero()) {
946     eopts.keep_parentheses = opts.keep_parentheses;
947     *ofst = ExpandFst<Arc>(ifst, parens, eopts);
948   } else {
949     PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses);
950     pruned_expand.Expand(ofst, opts.weight_threshold);
951   }
952 
953   if (opts.connect)
954     Connect(ofst);
955 }
956 
957 // Expands a pushdown transducer (PDT) encoded as an FST into an FST.
958 // This version writes the expanded PDT result to a MutableFst.
959 // In the PDT, some transitions are labeled with open or close
960 // parentheses. To be interpreted as a PDT, the parens must balance on
961 // a path. The open-close parenthesis label pairs are passed in
962 // 'parens'. The expansion enforces the parenthesis constraints. The
963 // PDT must be expandable as an FST.
964 template<class Arc>
965 void Expand(
966     const Fst<Arc> &ifst,
967     const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
968     MutableFst<Arc> *ofst,
969     bool connect = true, bool keep_parentheses = false) {
970   Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses));
971 }
972 
973 }  // namespace fst
974 
975 #endif  // FST_EXTENSIONS_PDT_EXPAND_H__
976