1 // shortest-path.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Functions to find shortest paths in a PDT.
20 
21 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
22 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
23 
24 #include <fst/shortest-path.h>
25 #include <fst/extensions/pdt/paren.h>
26 #include <fst/extensions/pdt/pdt.h>
27 
28 #include <tr1/unordered_map>
29 using std::tr1::unordered_map;
30 using std::tr1::unordered_multimap;
31 #include <tr1/unordered_set>
32 using std::tr1::unordered_set;
33 using std::tr1::unordered_multiset;
34 #include <stack>
35 #include <vector>
36 using std::vector;
37 
38 namespace fst {
39 
// Options for PdtShortestPath.
//   keep_parentheses: when false (the default) parenthesis labels on the
//     output path are replaced by epsilon (0); when true they are kept.
//   path_gc: when true (the default) the search data of each sub-graph is
//     garbage collected once that sub-graph has been fully explored.
template <class Arc, class Queue>
struct PdtShortestPathOptions {
  PdtShortestPathOptions(bool kp = false, bool gc = true)
      : keep_parentheses(kp),
        path_gc(gc) {}

  bool keep_parentheses;  // Keep paren labels on the output path?
  bool path_gc;           // Garbage-collect per-sub-graph search data?
};
48 
49 
50 // Class to store PDT shortest path results. Stores shortest path
51 // tree info 'Distance()', Parent(), and ArcParent() information keyed
52 // on two types:
53 // (1) By SearchState: This is a usual node in a shortest path tree but:
54 //    (a) is w.r.t a PDT search state - a pair of a PDT state and
55 //        a 'start' state, which is either the PDT start state or
56 //        the destination state of an open parenthesis.
57 //    (b) the Distance() is from this 'start' state to the search state.
58 //    (c) Parent().state is kNoLabel for the 'start' state.
59 //
60 // (2) By ParenSpec: This connects shortest path trees depending on the
61 // the parenthesis taken. Given the parenthesis spec:
62 //    (a) the Distance() is from the Parent() 'start' state to the
63 //     parenthesis destination state.
64 //    (b) the ArcParent() is the parenthesis arc.
65 template <class Arc>
66 class PdtShortestPathData {
67  public:
68   static const uint8 kFinal;
69 
70   typedef typename Arc::StateId StateId;
71   typedef typename Arc::Weight Weight;
72   typedef typename Arc::Label Label;
73 
74   struct SearchState {
SearchStateSearchState75     SearchState() : state(kNoStateId), start(kNoStateId) {}
76 
SearchStateSearchState77     SearchState(StateId s, StateId t) : state(s), start(t) {}
78 
79     bool operator==(const SearchState &s) const {
80       if (&s == this)
81         return true;
82       return s.state == this->state && s.start == this->start;
83     }
84 
85     StateId state;  // PDT state
86     StateId start;  // PDT paren 'source' state
87   };
88 
89 
90   // Specifies paren id, source and dest 'start' states of a paren.
91   // These are the 'start' states of the respective sub-graphs.
92   struct ParenSpec {
ParenSpecParenSpec93     ParenSpec()
94         : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {}
95 
ParenSpecParenSpec96     ParenSpec(Label id, StateId s, StateId d)
97         : paren_id(id), src_start(s), dest_start(d) {}
98 
99     Label paren_id;        // Id of parenthesis
100     StateId src_start;     // sub-graph 'start' state for paren source.
101     StateId dest_start;    // sub-graph 'start' state for paren dest.
102 
103     bool operator==(const ParenSpec &x) const {
104       if (&x == this)
105         return true;
106       return x.paren_id == this->paren_id &&
107           x.src_start == this->src_start &&
108           x.dest_start == this->dest_start;
109     }
110   };
111 
112   struct SearchData {
SearchDataSearchData113     SearchData() : distance(Weight::Zero()),
114                    parent(kNoStateId, kNoStateId),
115                    paren_id(kNoLabel),
116                    flags(0) {}
117 
118     Weight distance;     // Distance to this state from PDT 'start' state
119     SearchState parent;  // Parent state in shortest path tree
120     int16 paren_id;      // If parent arc has paren, paren ID, o.w. kNoLabel
121     uint8 flags;         // First byte reserved for PdtShortestPathData use
122   };
123 
PdtShortestPathData(bool gc)124   PdtShortestPathData(bool gc)
125       : state_(kNoStateId, kNoStateId),
126         paren_(kNoLabel, kNoStateId, kNoStateId),
127         gc_(gc),
128         nstates_(0),
129         ngc_(0),
130         finished_(false) {}
131 
~PdtShortestPathData()132   ~PdtShortestPathData() {
133     VLOG(1) << "opm size: " << paren_map_.size();
134     VLOG(1) << "# of search states: " << nstates_;
135     if (gc_)
136       VLOG(1) << "# of GC'd search states: " << ngc_;
137   }
138 
Clear()139   void Clear() {
140     search_map_.clear();
141     search_multimap_.clear();
142     paren_map_.clear();
143     state_ = SearchState(kNoStateId, kNoStateId);
144     nstates_ = 0;
145     ngc_ = 0;
146   }
147 
Distance(SearchState s)148   Weight Distance(SearchState s) const {
149     SearchData *data = GetSearchData(s);
150     return data->distance;
151   }
152 
Distance(const ParenSpec & paren)153   Weight Distance(const ParenSpec &paren) const {
154     SearchData *data = GetSearchData(paren);
155     return data->distance;
156   }
157 
Parent(SearchState s)158   SearchState Parent(SearchState s) const {
159     SearchData *data = GetSearchData(s);
160     return data->parent;
161   }
162 
Parent(const ParenSpec & paren)163   SearchState Parent(const ParenSpec &paren) const {
164     SearchData *data = GetSearchData(paren);
165     return data->parent;
166   }
167 
ParenId(SearchState s)168   Label ParenId(SearchState s) const {
169     SearchData *data = GetSearchData(s);
170     return data->paren_id;
171   }
172 
Flags(SearchState s)173   uint8 Flags(SearchState s) const {
174     SearchData *data = GetSearchData(s);
175     return data->flags;
176   }
177 
SetDistance(SearchState s,Weight w)178   void SetDistance(SearchState s, Weight w) {
179     SearchData *data = GetSearchData(s);
180     data->distance = w;
181   }
182 
SetDistance(const ParenSpec & paren,Weight w)183   void SetDistance(const ParenSpec &paren, Weight w) {
184     SearchData *data = GetSearchData(paren);
185     data->distance = w;
186   }
187 
SetParent(SearchState s,SearchState p)188   void SetParent(SearchState s, SearchState p) {
189     SearchData *data = GetSearchData(s);
190     data->parent = p;
191   }
192 
SetParent(const ParenSpec & paren,SearchState p)193   void SetParent(const ParenSpec &paren, SearchState p) {
194     SearchData *data = GetSearchData(paren);
195     data->parent = p;
196   }
197 
SetParenId(SearchState s,Label p)198   void SetParenId(SearchState s, Label p) {
199     if (p >= 32768)
200       FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16";
201     SearchData *data = GetSearchData(s);
202     data->paren_id = p;
203   }
204 
SetFlags(SearchState s,uint8 f,uint8 mask)205   void SetFlags(SearchState s, uint8 f, uint8 mask) {
206     SearchData *data = GetSearchData(s);
207     data->flags &= ~mask;
208     data->flags |= f & mask;
209   }
210 
211   void GC(StateId s);
212 
Finish()213   void Finish() { finished_ = true; }
214 
215  private:
216   static const Arc kNoArc;
217   static const size_t kPrime0;
218   static const size_t kPrime1;
219   static const uint8 kInited;
220   static const uint8 kMarked;
221 
222   // Hash for search state
223   struct SearchStateHash {
operatorSearchStateHash224     size_t operator()(const SearchState &s) const {
225       return s.state + s.start * kPrime0;
226     }
227   };
228 
229   // Hash for paren map
230   struct ParenHash {
operatorParenHash231     size_t operator()(const ParenSpec &paren) const {
232       return paren.paren_id + paren.src_start * kPrime0 +
233           paren.dest_start * kPrime1;
234     }
235   };
236 
237   typedef unordered_map<SearchState, SearchData, SearchStateHash> SearchMap;
238 
239   typedef unordered_multimap<StateId, StateId> SearchMultimap;
240 
241   // Hash map from paren spec to open paren data
242   typedef unordered_map<ParenSpec, SearchData, ParenHash> ParenMap;
243 
GetSearchData(SearchState s)244   SearchData *GetSearchData(SearchState s) const {
245     if (s == state_)
246       return state_data_;
247     if (finished_) {
248       typename SearchMap::iterator it = search_map_.find(s);
249       if (it == search_map_.end())
250         return &null_search_data_;
251       state_ = s;
252       return state_data_ = &(it->second);
253     } else {
254       state_ = s;
255       state_data_ = &search_map_[s];
256       if (!(state_data_->flags & kInited)) {
257         ++nstates_;
258         if (gc_)
259           search_multimap_.insert(make_pair(s.start, s.state));
260         state_data_->flags = kInited;
261       }
262       return state_data_;
263     }
264   }
265 
GetSearchData(ParenSpec paren)266   SearchData *GetSearchData(ParenSpec paren) const {
267     if (paren == paren_)
268       return paren_data_;
269     if (finished_) {
270       typename ParenMap::iterator it = paren_map_.find(paren);
271       if (it == paren_map_.end())
272         return &null_search_data_;
273       paren_ = paren;
274       return state_data_ = &(it->second);
275     } else {
276       paren_ = paren;
277       return paren_data_ = &paren_map_[paren];
278     }
279   }
280 
281   mutable SearchMap search_map_;            // Maps from search state to data
282   mutable SearchMultimap search_multimap_;  // Maps from 'start' to subgraph
283   mutable ParenMap paren_map_;              // Maps paren spec to search data
284   mutable SearchState state_;               // Last state accessed
285   mutable SearchData *state_data_;          // Last state data accessed
286   mutable ParenSpec paren_;                 // Last paren spec accessed
287   mutable SearchData *paren_data_;          // Last paren data accessed
288   bool gc_;                                 // Allow GC?
289   mutable size_t nstates_;                  // Total number of search states
290   size_t ngc_;                              // Number of GC'd search states
291   mutable SearchData null_search_data_;     // Null search data
292   bool finished_;                           // Read-only access when true
293 
294   DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData);
295 };
296 
297 // Deletes inaccessible search data from a given 'start' (open paren dest)
298 // state. Assumes 'final' (close paren source or PDT final) states have
299 // been flagged 'kFinal'.
300 template<class Arc>
GC(StateId start)301 void  PdtShortestPathData<Arc>::GC(StateId start) {
302   if (!gc_)
303     return;
304   vector<StateId> final;
305   for (typename SearchMultimap::iterator mmit = search_multimap_.find(start);
306        mmit != search_multimap_.end() && mmit->first == start;
307        ++mmit) {
308     SearchState s(mmit->second, start);
309     const SearchData &data = search_map_[s];
310     if (data.flags & kFinal)
311       final.push_back(s.state);
312   }
313 
314   // Mark phase
315   for (size_t i = 0; i < final.size(); ++i) {
316     SearchState s(final[i], start);
317     while (s.state != kNoLabel) {
318       SearchData *sdata = &search_map_[s];
319       if (sdata->flags & kMarked)
320         break;
321       sdata->flags |= kMarked;
322       SearchState p = sdata->parent;
323       if (p.start != start && p.start != kNoLabel) {  // entering sub-subgraph
324         ParenSpec paren(sdata->paren_id, s.start, p.start);
325         SearchData *pdata = &paren_map_[paren];
326         s = pdata->parent;
327       } else {
328         s = p;
329       }
330     }
331   }
332 
333   // Sweep phase
334   typename SearchMultimap::iterator mmit = search_multimap_.find(start);
335   while (mmit != search_multimap_.end() && mmit->first == start) {
336     SearchState s(mmit->second, start);
337     typename SearchMap::iterator mit = search_map_.find(s);
338     const SearchData &data = mit->second;
339     if (!(data.flags & kMarked)) {
340       search_map_.erase(mit);
341       ++ngc_;
342     }
343     search_multimap_.erase(mmit++);
344   }
345 }
346 
347 template<class Arc> const Arc PdtShortestPathData<Arc>::kNoArc
348     = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);
349 
350 template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime0 = 7853;
351 
352 template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime1 = 7867;
353 
354 template<class Arc> const uint8 PdtShortestPathData<Arc>::kInited = 0x01;
355 
356 template<class Arc> const uint8 PdtShortestPathData<Arc>::kFinal =  0x02;
357 
358 template<class Arc> const uint8 PdtShortestPathData<Arc>::kMarked = 0x04;
359 
360 
361 // This computes the single source shortest (balanced) path (SSSP)
362 // through a weighted PDT that has a bounded stack (i.e. is expandable
363 // as an FST). It is a generalization of the classic SSSP graph
364 // algorithm that removes a state s from a queue (defined by a
365 // user-provided queue type) and relaxes the destination states of
366 // transitions leaving s. In this PDT version, states that have
367 // entering open parentheses are treated as source states for a
368 // sub-graph SSSP problem with the shortest path up to the open
369 // parenthesis being first saved. When a close parenthesis is then
370 // encountered any balancing open parenthesis is examined for this
371 // saved information and multiplied back. In this way, each sub-graph
372 // is entered only once rather than repeatedly.  If every state in the
373 // input PDT has the property that there is a unique 'start' state for
374 // it with entering open parentheses, then this algorithm is quite
375 // straight-forward. In general, this will not be the case, so the
376 // algorithm (implicitly) creates a new graph where each state is a
377 // pair of an original state and a possible parenthesis 'start' state
378 // for that state.
379 template<class Arc, class Queue>
380 class PdtShortestPath {
381  public:
382   typedef typename Arc::StateId StateId;
383   typedef typename Arc::Weight Weight;
384   typedef typename Arc::Label Label;
385 
386   typedef PdtShortestPathData<Arc> SpData;
387   typedef typename SpData::SearchState SearchState;
388   typedef typename SpData::ParenSpec ParenSpec;
389 
390   typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator;
391 
PdtShortestPath(const Fst<Arc> & ifst,const vector<pair<Label,Label>> & parens,const PdtShortestPathOptions<Arc,Queue> & opts)392   PdtShortestPath(const Fst<Arc> &ifst,
393                   const vector<pair<Label, Label> > &parens,
394                   const PdtShortestPathOptions<Arc, Queue> &opts)
395       : kFinal(SpData::kFinal),
396         ifst_(ifst.Copy()),
397         parens_(parens),
398         keep_parens_(opts.keep_parentheses),
399         start_(ifst.Start()),
400         sp_data_(opts.path_gc),
401         error_(false) {
402 
403     if ((Weight::Properties() & (kPath | kRightSemiring))
404         != (kPath | kRightSemiring)) {
405       FSTERROR() << "PdtShortestPath: Weight needs to have the path"
406                  << " property and be right distributive: " << Weight::Type();
407       error_ = true;
408     }
409 
410     for (Label i = 0; i < parens.size(); ++i) {
411       const pair<Label, Label>  &p = parens[i];
412       paren_id_map_[p.first] = i;
413       paren_id_map_[p.second] = i;
414     }
415   };
416 
~PdtShortestPath()417   ~PdtShortestPath() {
418     VLOG(1) << "# of input states: " << CountStates(*ifst_);
419     VLOG(1) << "# of enqueued: " << nenqueued_;
420     VLOG(1) << "cpmm size: " << close_paren_multimap_.size();
421     delete ifst_;
422   }
423 
ShortestPath(MutableFst<Arc> * ofst)424   void ShortestPath(MutableFst<Arc> *ofst) {
425     Init(ofst);
426     GetDistance(start_);
427     GetPath();
428     sp_data_.Finish();
429     if (error_) ofst->SetProperties(kError, kError);
430   }
431 
GetShortestPathData()432   const PdtShortestPathData<Arc> &GetShortestPathData() const {
433     return sp_data_;
434   }
435 
GetBalanceData()436   PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; }
437 
438  private:
439   static const Arc kNoArc;
440   static const uint8 kEnqueued;
441   static const uint8 kExpanded;
442   static const uint8 kFinished;
443   const uint8 kFinal;
444 
445  public:
446   // Hash multimap from close paren label to an paren arc.
447   typedef unordered_multimap<ParenState<Arc>, Arc,
448                         typename ParenState<Arc>::Hash> CloseParenMultimap;
449 
GetCloseParenMultimap()450   const CloseParenMultimap &GetCloseParenMultimap() const {
451     return close_paren_multimap_;
452   }
453 
454  private:
455   void Init(MutableFst<Arc> *ofst);
456   void GetDistance(StateId start);
457   void ProcFinal(SearchState s);
458   void ProcArcs(SearchState s);
459   void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w);
460   void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w);
461   void ProcNonParen(SearchState s, const Arc &arc, Weight w);
462   void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id);
463   void Enqueue(SearchState d);
464   void GetPath();
465   Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open);
466 
467   Fst<Arc> *ifst_;
468   MutableFst<Arc> *ofst_;
469   const vector<pair<Label, Label> > &parens_;
470   bool keep_parens_;
471   Queue *state_queue_;                   // current state queue
472   StateId start_;
473   Weight f_distance_;
474   SearchState f_parent_;
475   SpData sp_data_;
476   unordered_map<Label, Label> paren_id_map_;
477   CloseParenMultimap close_paren_multimap_;
478   PdtBalanceData<Arc> balance_data_;
479   ssize_t nenqueued_;
480   bool error_;
481 
482   DISALLOW_COPY_AND_ASSIGN(PdtShortestPath);
483 };
484 
485 template<class Arc, class Queue>
Init(MutableFst<Arc> * ofst)486 void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) {
487   ofst_ = ofst;
488   ofst->DeleteStates();
489   ofst->SetInputSymbols(ifst_->InputSymbols());
490   ofst->SetOutputSymbols(ifst_->OutputSymbols());
491 
492   if (ifst_->Start() == kNoStateId)
493     return;
494 
495   f_distance_ = Weight::Zero();
496   f_parent_ = SearchState(kNoStateId, kNoStateId);
497 
498   sp_data_.Clear();
499   close_paren_multimap_.clear();
500   balance_data_.Clear();
501   nenqueued_ = 0;
502 
503   // Find open parens per destination state and close parens per source state.
504   for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
505     StateId s = siter.Value();
506     for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
507          !aiter.Done(); aiter.Next()) {
508       const Arc &arc = aiter.Value();
509       typename unordered_map<Label, Label>::const_iterator pit
510           = paren_id_map_.find(arc.ilabel);
511       if (pit != paren_id_map_.end()) {               // Is a paren?
512         Label paren_id = pit->second;
513         if (arc.ilabel == parens_[paren_id].first) {  // Open paren
514           balance_data_.OpenInsert(paren_id, arc.nextstate);
515         } else {                                      // Close paren
516           ParenState<Arc> paren_state(paren_id, s);
517           close_paren_multimap_.insert(make_pair(paren_state, arc));
518         }
519       }
520     }
521   }
522 }
523 
524 // Computes the shortest distance stored in a recursive way. Each
525 // sub-graph (i.e. different paren 'start' state) begins with weight One().
526 template<class Arc, class Queue>
GetDistance(StateId start)527 void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) {
528   if (start == kNoStateId)
529     return;
530 
531   Queue state_queue;
532   state_queue_ = &state_queue;
533   SearchState q(start, start);
534   Enqueue(q);
535   sp_data_.SetDistance(q, Weight::One());
536 
537   while (!state_queue_->Empty()) {
538     StateId state = state_queue_->Head();
539     state_queue_->Dequeue();
540     SearchState s(state, start);
541     sp_data_.SetFlags(s, 0, kEnqueued);
542     ProcFinal(s);
543     ProcArcs(s);
544     sp_data_.SetFlags(s, kExpanded, kExpanded);
545   }
546   sp_data_.SetFlags(q, kFinished, kFinished);
547   balance_data_.FinishInsert(start);
548   sp_data_.GC(start);
549 }
550 
551 // Updates best complete path.
552 template<class Arc, class Queue>
ProcFinal(SearchState s)553 void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) {
554   if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
555     Weight w = Times(sp_data_.Distance(s),
556                      ifst_->Final(s.state));
557     if (f_distance_ != Plus(f_distance_, w)) {
558       if (f_parent_.state != kNoStateId)
559         sp_data_.SetFlags(f_parent_, 0, kFinal);
560       sp_data_.SetFlags(s, kFinal, kFinal);
561 
562       f_distance_ = Plus(f_distance_, w);
563       f_parent_ = s;
564     }
565   }
566 }
567 
568 // Processes all arcs leaving the state s.
569 template<class Arc, class Queue>
ProcArcs(SearchState s)570 void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) {
571   for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
572        !aiter.Done();
573        aiter.Next()) {
574     Arc arc = aiter.Value();
575     Weight w = Times(sp_data_.Distance(s), arc.weight);
576 
577     typename unordered_map<Label, Label>::const_iterator pit
578         = paren_id_map_.find(arc.ilabel);
579     if (pit != paren_id_map_.end()) {  // Is a paren?
580       Label paren_id = pit->second;
581       if (arc.ilabel == parens_[paren_id].first)
582         ProcOpenParen(paren_id, s, arc, w);
583       else
584         ProcCloseParen(paren_id, s, arc, w);
585     } else {
586       ProcNonParen(s, arc, w);
587     }
588   }
589 }
590 
591 // Saves the shortest path info for reaching this parenthesis
592 // and starts a new SSSP in the sub-graph pointed to by the parenthesis
593 // if previously unvisited. Otherwise it finds any previously encountered
594 // closing parentheses and relaxes them using the recursively stored
595 // shortest distance to them.
596 template<class Arc, class Queue> inline
ProcOpenParen(Label paren_id,SearchState s,Arc arc,Weight w)597 void PdtShortestPath<Arc, Queue>::ProcOpenParen(
598     Label paren_id, SearchState s, Arc arc, Weight w) {
599 
600   SearchState d(arc.nextstate, arc.nextstate);
601   ParenSpec paren(paren_id, s.start, d.start);
602   Weight pdist = sp_data_.Distance(paren);
603   if (pdist != Plus(pdist, w)) {
604     sp_data_.SetDistance(paren, w);
605     sp_data_.SetParent(paren, s);
606     Weight dist = sp_data_.Distance(d);
607     if (dist == Weight::Zero()) {
608       Queue *state_queue = state_queue_;
609       GetDistance(d.start);
610       state_queue_ = state_queue;
611     } else if (!(sp_data_.Flags(d) & kFinished)) {
612       FSTERROR() << "PdtShortestPath: open parenthesis recursion: not bounded stack";
613       error_ = true;
614     }
615 
616     for (CloseSourceIterator set_iter =
617              balance_data_.Find(paren_id, arc.nextstate);
618          !set_iter.Done(); set_iter.Next()) {
619       SearchState cpstate(set_iter.Element(), d.start);
620       ParenState<Arc> paren_state(paren_id, cpstate.state);
621       for (typename CloseParenMultimap::const_iterator cpit =
622                close_paren_multimap_.find(paren_state);
623            cpit != close_paren_multimap_.end() && paren_state == cpit->first;
624            ++cpit) {
625         const Arc &cparc = cpit->second;
626         Weight cpw = Times(w, Times(sp_data_.Distance(cpstate),
627                                     cparc.weight));
628         Relax(cpstate, s, cparc, cpw, paren_id);
629       }
630     }
631   }
632 }
633 
634 // Saves the correspondence between each closing parenthesis and its
635 // balancing open parenthesis info. Relaxes any close parenthesis
636 // destination state that has a balancing previously encountered open
637 // parenthesis.
638 template<class Arc, class Queue> inline
ProcCloseParen(Label paren_id,SearchState s,const Arc & arc,Weight w)639 void PdtShortestPath<Arc, Queue>::ProcCloseParen(
640     Label paren_id, SearchState s, const Arc &arc, Weight w) {
641   ParenState<Arc> paren_state(paren_id, s.start);
642   if (!(sp_data_.Flags(s) & kExpanded)) {
643     balance_data_.CloseInsert(paren_id, s.start, s.state);
644     sp_data_.SetFlags(s, kFinal, kFinal);
645   }
646 }
647 
648 // For non-parentheses, classical relaxation.
649 template<class Arc, class Queue> inline
ProcNonParen(SearchState s,const Arc & arc,Weight w)650 void PdtShortestPath<Arc, Queue>::ProcNonParen(
651     SearchState s, const Arc &arc, Weight w) {
652   Relax(s, s, arc, w, kNoLabel);
653 }
654 
655 // Classical relaxation on the search graph for 'arc' from state 's'.
656 // State 't' is in the same sub-graph as the nextstate should be (i.e.
657 // has the same paren 'start'.
658 template<class Arc, class Queue> inline
Relax(SearchState s,SearchState t,Arc arc,Weight w,Label paren_id)659 void PdtShortestPath<Arc, Queue>::Relax(
660     SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) {
661   SearchState d(arc.nextstate, t.start);
662   Weight dist = sp_data_.Distance(d);
663   if (dist != Plus(dist, w)) {
664     sp_data_.SetParent(d, s);
665     sp_data_.SetParenId(d, paren_id);
666     sp_data_.SetDistance(d, Plus(dist, w));
667     Enqueue(d);
668   }
669 }
670 
671 template<class Arc, class Queue> inline
Enqueue(SearchState s)672 void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) {
673   if (!(sp_data_.Flags(s) & kEnqueued)) {
674     state_queue_->Enqueue(s.state);
675     sp_data_.SetFlags(s, kEnqueued, kEnqueued);
676     ++nenqueued_;
677   } else {
678     state_queue_->Update(s.state);
679   }
680 }
681 
682 // Follows parent pointers to find the shortest path. Uses a stack
683 // since the shortest distance is stored recursively.
684 template<class Arc, class Queue>
GetPath()685 void PdtShortestPath<Arc, Queue>::GetPath() {
686   SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId);
687   StateId s_p = kNoStateId, d_p = kNoStateId;
688   Arc arc(kNoArc);
689   Label paren_id = kNoLabel;
690   stack<ParenSpec> paren_stack;
691   while (s.state != kNoStateId) {
692     d_p = s_p;
693     s_p = ofst_->AddState();
694     if (d.state == kNoStateId) {
695       ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
696     } else {
697       if (paren_id != kNoLabel) {                     // paren?
698         if (arc.ilabel == parens_[paren_id].first) {  // open paren
699           paren_stack.pop();
700         } else {                                      // close paren
701           ParenSpec paren(paren_id, d.start, s.start);
702           paren_stack.push(paren);
703         }
704         if (!keep_parens_)
705           arc.ilabel = arc.olabel = 0;
706       }
707       arc.nextstate = d_p;
708       ofst_->AddArc(s_p, arc);
709     }
710     d = s;
711     s = sp_data_.Parent(d);
712     paren_id = sp_data_.ParenId(d);
713     if (s.state != kNoStateId) {
714       arc = GetPathArc(s, d, paren_id, false);
715     } else if (!paren_stack.empty()) {
716       ParenSpec paren = paren_stack.top();
717       s = sp_data_.Parent(paren);
718       paren_id = paren.paren_id;
719       arc = GetPathArc(s, d, paren_id, true);
720     }
721   }
722   ofst_->SetStart(s_p);
723   ofst_->SetProperties(
724       ShortestPathProperties(ofst_->Properties(kFstProperties, false)),
725       kFstProperties);
726 }
727 
728 
729 // Finds transition with least weight between two states with label matching
730 // paren_id and open/close paren type or a non-paren if kNoLabel.
731 template<class Arc, class Queue>
GetPathArc(SearchState s,SearchState d,Label paren_id,bool open_paren)732 Arc PdtShortestPath<Arc, Queue>::GetPathArc(
733     SearchState s, SearchState d, Label paren_id, bool open_paren) {
734   Arc path_arc = kNoArc;
735   for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
736        !aiter.Done();
737        aiter.Next()) {
738     const Arc &arc = aiter.Value();
739     if (arc.nextstate != d.state)
740       continue;
741     Label arc_paren_id = kNoLabel;
742     typename unordered_map<Label, Label>::const_iterator pit
743         = paren_id_map_.find(arc.ilabel);
744     if (pit != paren_id_map_.end()) {
745       arc_paren_id = pit->second;
746       bool arc_open_paren = arc.ilabel == parens_[arc_paren_id].first;
747       if (arc_open_paren != open_paren)
748         continue;
749     }
750     if (arc_paren_id != paren_id)
751       continue;
752     if (arc.weight == Plus(arc.weight, path_arc.weight))
753       path_arc = arc;
754   }
755   if (path_arc.nextstate == kNoStateId) {
756     FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc";
757     error_ = true;
758   }
759   return path_arc;
760 }
761 
762 template<class Arc, class Queue>
763 const Arc PdtShortestPath<Arc, Queue>::kNoArc
764     = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);
765 
766 template<class Arc, class Queue>
767 const uint8 PdtShortestPath<Arc, Queue>::kEnqueued = 0x10;
768 
769 template<class Arc, class Queue>
770 const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20;
771 
772 template<class Arc, class Queue>
773 const uint8 PdtShortestPath<Arc, Queue>::kFinished = 0x40;
774 
775 template<class Arc, class Queue>
ShortestPath(const Fst<Arc> & ifst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens,MutableFst<Arc> * ofst,const PdtShortestPathOptions<Arc,Queue> & opts)776 void ShortestPath(const Fst<Arc> &ifst,
777                   const vector<pair<typename Arc::Label,
778                                     typename Arc::Label> > &parens,
779                   MutableFst<Arc> *ofst,
780                   const PdtShortestPathOptions<Arc, Queue> &opts) {
781   PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
782   psp.ShortestPath(ofst);
783 }
784 
785 template<class Arc>
ShortestPath(const Fst<Arc> & ifst,const vector<pair<typename Arc::Label,typename Arc::Label>> & parens,MutableFst<Arc> * ofst)786 void ShortestPath(const Fst<Arc> &ifst,
787                   const vector<pair<typename Arc::Label,
788                                     typename Arc::Label> > &parens,
789                   MutableFst<Arc> *ofst) {
790   typedef FifoQueue<typename Arc::StateId> Queue;
791   PdtShortestPathOptions<Arc, Queue> opts;
792   PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
793   psp.ShortestPath(ofst);
794 }
795 
796 }  // namespace fst
797 
798 #endif  // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
799