1 // lookahead-matcher.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Classes to add lookahead to FST matchers, useful e.g. for improving
20 // composition efficiency with certain inputs.
21 
22 #ifndef FST_LIB_LOOKAHEAD_MATCHER_H__
23 #define FST_LIB_LOOKAHEAD_MATCHER_H__
24 
25 #include <fst/add-on.h>
26 #include <fst/const-fst.h>
27 #include <fst/fst.h>
28 #include <fst/label-reachable.h>
29 #include <fst/matcher.h>
30 
31 
32 DECLARE_string(save_relabel_ipairs);
33 DECLARE_string(save_relabel_opairs);
34 
35 namespace fst {
36 
37 // LOOKAHEAD MATCHERS - these have the interface of Matchers (see
38 // matcher.h) and these additional methods:
39 //
40 // template <class F>
41 // class LookAheadMatcher {
42 //  public:
43 //   typedef F FST;
44 //   typedef F::Arc Arc;
45 //   typedef typename Arc::StateId StateId;
46 //   typedef typename Arc::Label Label;
47 //   typedef typename Arc::Weight Weight;
48 //
49 //  // Required constructors.
50 //  LookAheadMatcher(const F &fst, MatchType match_type);
//   // If safe=true, the copy is thread-safe (except the lookahead Fst is
//   // preserved). See Fst<>::Copy() for further doc.
53 //  LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
54 //
55 //  Below are methods for looking ahead for a match to a label and
56 //  more generally, to a rational set. Each returns false if there is
57 //  definitely not a match and returns true if there possibly is a
58 //  match.
59 
60 //  // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state
61 //  // after possibly following epsilon transitions?
62 //  bool LookAheadLabel(Label label) const;
63 //
64 //  // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an
65 //  // arbitrary rational set of strings, specified by an FST and a state
66 //  // from which to begin the matching. If the lookahead FST is a
67 //  // transducer, this looks on the side different from the matcher
68 //  // 'match_type' (cf. composition).
69 //
70 //  // Are there paths P from 's' in the lookahead FST that can be read from
71 //  // the cur. matcher state?
72 //  bool LookAheadFst(const Fst<Arc>& fst, StateId s);
73 //
74 //  // Gives an estimate of the combined weight of the paths P in the
75 //  // lookahead and matcher FSTs for the last call to LookAheadFst.
76 //  // A trivial implementation returns Weight::One(). Non-trivial
77 //  // implementations are useful for weight-pushing in composition.
78 //  Weight LookAheadWeight() const;
79 //
//  // Is there a single non-epsilon arc found in the lookahead FST
//  // that begins P (after possibly following any epsilons) in the last
//  // call to LookAheadFst? If so, return true and copy it to '*arc', o.w.
83 //  // return false. A trivial implementation returns false. Non-trivial
84 //  // implementations are useful for label-pushing in composition.
85 //  bool LookAheadPrefix(Arc *arc);
86 //
87 //  // Optionally pre-specifies the lookahead FST that will be passed
88 //  // to LookAheadFst() for possible precomputation. If copy is true,
89 //  // then 'fst' is a copy of the FST used in the previous call to
90 //  // this method (useful to avoid unnecessary updates).
91 //  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false);
92 //
93 // };
94 
95 //
96 // LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h):
97 //
98 // Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT.
99 const uint32 kInputLookAheadMatcher =     0x00000010;
100 
101 // Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT.
102 const uint32 kOutputLookAheadMatcher =    0x00000020;
103 
104 // A non-trivial implementation of LookAheadWeight() method defined and
105 // should be used?
106 const uint32 kLookAheadWeight =           0x00000040;
107 
108 // A non-trivial implementation of LookAheadPrefix() method defined and
109 // should be used?
110 const uint32 kLookAheadPrefix =           0x00000080;
111 
112 // Look-ahead of matcher FST non-epsilon arcs?
113 const uint32 kLookAheadNonEpsilons =      0x00000100;
114 
115 // Look-ahead of matcher FST epsilon arcs?
116 const uint32 kLookAheadEpsilons =         0x00000200;
117 
118 // Ignore epsilon paths for the lookahead prefix? Note this gives
119 // correct results in composition only with an appropriate composition
120 // filter since it depends on the filter blocking the ignored paths.
121 const uint32 kLookAheadNonEpsilonPrefix = 0x00000400;
122 
123 // For LabelLookAheadMatcher, save relabeling data to file
124 const uint32 kLookAheadKeepRelabelData =  0x00000800;
125 
126 // Flags used for lookahead matchers.
127 const uint32 kLookAheadFlags =            0x00000ff0;
128 
129 // LookAhead Matcher interface, templated on the Arc definition; used
130 // for lookahead matcher specializations that are returned by the
131 // InitMatcher() Fst method.
132 template <class A>
133 class LookAheadMatcherBase : public MatcherBase<A> {
134  public:
135   typedef A Arc;
136   typedef typename A::StateId StateId;
137   typedef typename A::Label Label;
138   typedef typename A::Weight Weight;
139 
LookAheadMatcherBase()140   LookAheadMatcherBase()
141   : weight_(Weight::One()),
142     prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {}
143 
~LookAheadMatcherBase()144   virtual ~LookAheadMatcherBase() {}
145 
LookAheadLabel(Label label)146   bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); }
147 
LookAheadFst(const Fst<Arc> & fst,StateId s)148   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
149     return LookAheadFst_(fst, s);
150   }
151 
LookAheadWeight()152   Weight LookAheadWeight() const { return weight_; }
153 
LookAheadPrefix(Arc * arc)154   bool LookAheadPrefix(Arc *arc) const {
155     if (prefix_arc_.nextstate != kNoStateId) {
156       *arc = prefix_arc_;
157       return true;
158     } else {
159       return false;
160     }
161   }
162 
163   virtual void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) = 0;
164 
165  protected:
SetLookAheadWeight(const Weight & w)166   void SetLookAheadWeight(const Weight &w) { weight_ = w; }
167 
SetLookAheadPrefix(const Arc & arc)168   void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; }
169 
ClearLookAheadPrefix()170   void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }
171 
172  private:
173   virtual bool LookAheadLabel_(Label label) const = 0;
174   virtual bool LookAheadFst_(const Fst<Arc> &fst,
175                              StateId s) = 0;  // This must set l.a. weight and
176                                               // prefix if non-trivial.
177   Weight weight_;                             // Look-ahead weight
178   Arc prefix_arc_;                            // Look-ahead prefix arc
179 };
180 
181 
182 // Don't really lookahead, just declare future looks good regardless.
183 template <class M>
184 class TrivialLookAheadMatcher
185     : public LookAheadMatcherBase<typename M::FST::Arc> {
186  public:
187   typedef typename M::FST FST;
188   typedef typename M::Arc Arc;
189   typedef typename Arc::StateId StateId;
190   typedef typename Arc::Label Label;
191   typedef typename Arc::Weight Weight;
192 
TrivialLookAheadMatcher(const FST & fst,MatchType match_type)193   TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
194       : matcher_(fst, match_type) {}
195 
196   TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher,
197                           bool safe = false)
198       : matcher_(lmatcher.matcher_, safe) {}
199 
200   // General matcher methods
201   TrivialLookAheadMatcher<M> *Copy(bool safe = false) const {
202     return new TrivialLookAheadMatcher<M>(*this, safe);
203   }
204 
Type(bool test)205   MatchType Type(bool test) const { return matcher_.Type(test); }
SetState(StateId s)206   void SetState(StateId s) { return matcher_.SetState(s); }
Find(Label label)207   bool Find(Label label) { return matcher_.Find(label); }
Done()208   bool Done() const { return matcher_.Done(); }
Value()209   const Arc& Value() const { return matcher_.Value(); }
Next()210   void Next() { matcher_.Next(); }
GetFst()211   virtual const FST &GetFst() const { return matcher_.GetFst(); }
Properties(uint64 props)212   uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
Flags()213   uint32 Flags() const {
214     return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
215   }
216 
217   // Look-ahead methods.
LookAheadLabel(Label label)218   bool LookAheadLabel(Label label) const { return true;  }
LookAheadFst(const Fst<Arc> & fst,StateId s)219   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {return true; }
LookAheadWeight()220   Weight LookAheadWeight() const { return Weight::One(); }
LookAheadPrefix(Arc * arc)221   bool LookAheadPrefix(Arc *arc) const { return false; }
222   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {}
223 
224  private:
225   // This allows base class virtual access to non-virtual derived-
226   // class members of the same name. It makes the derived class more
227   // efficient to use but unsafe to further derive.
SetState_(StateId s)228   virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)229   virtual bool Find_(Label label) { return Find(label); }
Done_()230   virtual bool Done_() const { return Done(); }
Value_()231   virtual const Arc& Value_() const { return Value(); }
Next_()232   virtual void Next_() { Next(); }
233 
LookAheadLabel_(Label l)234   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
235 
LookAheadFst_(const Fst<Arc> & fst,StateId s)236   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
237     return LookAheadFst(fst, s);
238   }
239 
LookAheadWeight_()240   Weight LookAheadWeight_() const { return LookAheadWeight(); }
LookAheadPrefix_(Arc * arc)241   bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); }
242 
243   M matcher_;
244 };
245 
246 // Look-ahead of one transition. Template argument F accepts flags to
247 // control behavior.
248 template <class M, uint32 F = kLookAheadNonEpsilons | kLookAheadEpsilons |
249           kLookAheadWeight | kLookAheadPrefix>
250 class ArcLookAheadMatcher
251     : public LookAheadMatcherBase<typename M::FST::Arc> {
252  public:
253   typedef typename M::FST FST;
254   typedef typename M::Arc Arc;
255   typedef typename Arc::StateId StateId;
256   typedef typename Arc::Label Label;
257   typedef typename Arc::Weight Weight;
258   typedef NullAddOn MatcherData;
259 
260   using LookAheadMatcherBase<Arc>::LookAheadWeight;
261   using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
262   using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
263   using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
264 
265   ArcLookAheadMatcher(const FST &fst, MatchType match_type,
266                       MatcherData *data = 0)
matcher_(fst,match_type)267       : matcher_(fst, match_type),
268         fst_(matcher_.GetFst()),
269         lfst_(0),
270         s_(kNoStateId) {}
271 
272   ArcLookAheadMatcher(const ArcLookAheadMatcher<M, F> &lmatcher,
273                       bool safe = false)
274       : matcher_(lmatcher.matcher_, safe),
275         fst_(matcher_.GetFst()),
276         lfst_(lmatcher.lfst_),
277         s_(kNoStateId) {}
278 
279   // General matcher methods
280   ArcLookAheadMatcher<M, F> *Copy(bool safe = false) const {
281     return new ArcLookAheadMatcher<M, F>(*this, safe);
282   }
283 
Type(bool test)284   MatchType Type(bool test) const { return matcher_.Type(test); }
285 
SetState(StateId s)286   void SetState(StateId s) {
287     s_ = s;
288     matcher_.SetState(s);
289   }
290 
Find(Label label)291   bool Find(Label label) { return matcher_.Find(label); }
Done()292   bool Done() const { return matcher_.Done(); }
Value()293   const Arc& Value() const { return matcher_.Value(); }
Next()294   void Next() { matcher_.Next(); }
GetFst()295   const FST &GetFst() const { return fst_; }
Properties(uint64 props)296   uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
Flags()297   uint32 Flags() const {
298     return matcher_.Flags() | kInputLookAheadMatcher |
299         kOutputLookAheadMatcher | F;
300   }
301 
302   // Writable matcher methods
GetData()303   MatcherData *GetData() const { return 0; }
304 
305   // Look-ahead methods.
LookAheadLabel(Label label)306   bool LookAheadLabel(Label label) const { return matcher_.Find(label); }
307 
308   // Checks if there is a matching (possibly super-final) transition
309   // at (s_, s).
310   bool LookAheadFst(const Fst<Arc> &fst, StateId s);
311 
312   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
313     lfst_ = &fst;
314   }
315 
316  private:
317   // This allows base class virtual access to non-virtual derived-
318   // class members of the same name. It makes the derived class more
319   // efficient to use but unsafe to further derive.
SetState_(StateId s)320   virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)321   virtual bool Find_(Label label) { return Find(label); }
Done_()322   virtual bool Done_() const { return Done(); }
Value_()323   virtual const Arc& Value_() const { return Value(); }
Next_()324   virtual void Next_() { Next(); }
325 
LookAheadLabel_(Label l)326   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
LookAheadFst_(const Fst<Arc> & fst,StateId s)327   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
328     return LookAheadFst(fst, s);
329   }
330 
331   mutable M matcher_;
332   const FST &fst_;         // Matcher FST
333   const Fst<Arc> *lfst_;   // Look-ahead FST
334   StateId s_;              // Matcher state
335 };
336 
337 template <class M, uint32 F>
LookAheadFst(const Fst<Arc> & fst,StateId s)338 bool ArcLookAheadMatcher<M, F>::LookAheadFst(const Fst<Arc> &fst, StateId s) {
339   if (&fst != lfst_)
340     InitLookAheadFst(fst);
341 
342   bool ret = false;
343   ssize_t nprefix = 0;
344   if (F & kLookAheadWeight)
345     SetLookAheadWeight(Weight::Zero());
346   if (F & kLookAheadPrefix)
347     ClearLookAheadPrefix();
348   if (fst_.Final(s_) != Weight::Zero() &&
349       lfst_->Final(s) != Weight::Zero()) {
350     if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
351       return true;
352     ++nprefix;
353     if (F & kLookAheadWeight)
354       SetLookAheadWeight(Plus(LookAheadWeight(),
355                               Times(fst_.Final(s_), lfst_->Final(s))));
356     ret = true;
357   }
358   if (matcher_.Find(kNoLabel)) {
359     if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
360       return true;
361     ++nprefix;
362     if (F & kLookAheadWeight)
363       for (; !matcher_.Done(); matcher_.Next())
364         SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
365     ret = true;
366   }
367   for (ArcIterator< Fst<Arc> > aiter(*lfst_, s);
368        !aiter.Done();
369        aiter.Next()) {
370     const Arc &arc = aiter.Value();
371     Label label = kNoLabel;
372     switch (matcher_.Type(false)) {
373       case MATCH_INPUT:
374         label = arc.olabel;
375         break;
376       case MATCH_OUTPUT:
377         label = arc.ilabel;
378         break;
379       default:
380         FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type";
381         return true;
382     }
383     if (label == 0) {
384       if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
385         return true;
386       if (!(F & kLookAheadNonEpsilonPrefix))
387         ++nprefix;
388       if (F & kLookAheadWeight)
389         SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
390       ret = true;
391     } else if (matcher_.Find(label)) {
392       if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
393         return true;
394       for (; !matcher_.Done(); matcher_.Next()) {
395         ++nprefix;
396         if (F & kLookAheadWeight)
397           SetLookAheadWeight(Plus(LookAheadWeight(),
398                                   Times(arc.weight,
399                                         matcher_.Value().weight)));
400         if ((F & kLookAheadPrefix) && nprefix == 1)
401           SetLookAheadPrefix(arc);
402       }
403       ret = true;
404     }
405   }
406   if (F & kLookAheadPrefix) {
407     if (nprefix == 1)
408       SetLookAheadWeight(Weight::One());  // Avoids double counting.
409     else
410       ClearLookAheadPrefix();
411   }
412   return ret;
413 }
414 
415 
416 // Template argument F accepts flags to control behavior.
417 // It must include precisely one of KInputLookAheadMatcher or
418 // KOutputLookAheadMatcher.
419 template <class M, uint32 F = kLookAheadEpsilons | kLookAheadWeight |
420           kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
421           kLookAheadKeepRelabelData,
422           class S = DefaultAccumulator<typename M::Arc> >
423 class LabelLookAheadMatcher
424     : public LookAheadMatcherBase<typename M::FST::Arc> {
425  public:
426   typedef typename M::FST FST;
427   typedef typename M::Arc Arc;
428   typedef typename Arc::StateId StateId;
429   typedef typename Arc::Label Label;
430   typedef typename Arc::Weight Weight;
431   typedef LabelReachableData<Label> MatcherData;
432 
433   using LookAheadMatcherBase<Arc>::LookAheadWeight;
434   using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
435   using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
436   using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
437 
438   LabelLookAheadMatcher(const FST &fst, MatchType match_type,
439                         MatcherData *data = 0, S *s = 0)
matcher_(fst,match_type)440       : matcher_(fst, match_type),
441         lfst_(0),
442         label_reachable_(0),
443         s_(kNoStateId),
444         error_(false) {
445     if (!(F & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) {
446       FSTERROR() << "LabelLookaheadMatcher: bad matcher flags: " << F;
447       error_ = true;
448     }
449     bool reach_input = match_type == MATCH_INPUT;
450     if (data) {
451       if (reach_input == data->ReachInput())
452         label_reachable_ = new LabelReachable<Arc, S>(data, s);
453     } else if ((reach_input && (F & kInputLookAheadMatcher)) ||
454                (!reach_input && (F & kOutputLookAheadMatcher))) {
455       label_reachable_ = new LabelReachable<Arc, S>(
456           fst, reach_input, s, F & kLookAheadKeepRelabelData);
457     }
458   }
459 
460   LabelLookAheadMatcher(const LabelLookAheadMatcher<M, F, S> &lmatcher,
461                         bool safe = false)
462       : matcher_(lmatcher.matcher_, safe),
463         lfst_(lmatcher.lfst_),
464         label_reachable_(
465             lmatcher.label_reachable_ ?
466             new LabelReachable<Arc, S>(*lmatcher.label_reachable_) : 0),
467         s_(kNoStateId),
468         error_(lmatcher.error_) {}
469 
~LabelLookAheadMatcher()470   ~LabelLookAheadMatcher() {
471     delete label_reachable_;
472   }
473 
474   // General matcher methods
475   LabelLookAheadMatcher<M, F, S> *Copy(bool safe = false) const {
476     return new LabelLookAheadMatcher<M, F, S>(*this, safe);
477   }
478 
Type(bool test)479   MatchType Type(bool test) const { return matcher_.Type(test); }
480 
SetState(StateId s)481   void SetState(StateId s) {
482     if (s_ == s)
483       return;
484     s_ = s;
485     match_set_state_ = false;
486     reach_set_state_ = false;
487   }
488 
Find(Label label)489   bool Find(Label label) {
490     if (!match_set_state_) {
491       matcher_.SetState(s_);
492       match_set_state_ = true;
493     }
494     return matcher_.Find(label);
495   }
496 
Done()497   bool Done() const { return matcher_.Done(); }
Value()498   const Arc& Value() const { return matcher_.Value(); }
Next()499   void Next() { matcher_.Next(); }
GetFst()500   const FST &GetFst() const { return matcher_.GetFst(); }
501 
Properties(uint64 inprops)502   uint64 Properties(uint64 inprops) const {
503     uint64 outprops = matcher_.Properties(inprops);
504     if (error_ || (label_reachable_ && label_reachable_->Error()))
505       outprops |= kError;
506     return outprops;
507   }
508 
Flags()509   uint32 Flags() const {
510     if (label_reachable_ && label_reachable_->GetData()->ReachInput())
511       return matcher_.Flags() | F | kInputLookAheadMatcher;
512     else if (label_reachable_ && !label_reachable_->GetData()->ReachInput())
513       return matcher_.Flags() | F | kOutputLookAheadMatcher;
514     else
515       return matcher_.Flags();
516   }
517 
518   // Writable matcher methods
GetData()519   MatcherData *GetData() const {
520     return label_reachable_ ? label_reachable_->GetData() : 0;
521   };
522 
523   // Look-ahead methods.
LookAheadLabel(Label label)524   bool LookAheadLabel(Label label) const {
525     if (label == 0)
526       return true;
527 
528     if (label_reachable_) {
529       if (!reach_set_state_) {
530         label_reachable_->SetState(s_);
531         reach_set_state_ = true;
532       }
533       return label_reachable_->Reach(label);
534     } else {
535       return true;
536     }
537   }
538 
539   // Checks if there is a matching (possibly super-final) transition
540   // at (s_, s).
541   template <class L>
542   bool LookAheadFst(const L &fst, StateId s);
543 
544   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
545     lfst_ = &fst;
546     if (label_reachable_)
547       label_reachable_->ReachInit(fst, copy);
548   }
549 
550   template <class L>
551   void InitLookAheadFst(const L& fst, bool copy = false) {
552     lfst_ = static_cast<const Fst<Arc> *>(&fst);
553     if (label_reachable_)
554       label_reachable_->ReachInit(fst, copy);
555   }
556 
557  private:
558   // This allows base class virtual access to non-virtual derived-
559   // class members of the same name. It makes the derived class more
560   // efficient to use but unsafe to further derive.
SetState_(StateId s)561   virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)562   virtual bool Find_(Label label) { return Find(label); }
Done_()563   virtual bool Done_() const { return Done(); }
Value_()564   virtual const Arc& Value_() const { return Value(); }
Next_()565   virtual void Next_() { Next(); }
566 
LookAheadLabel_(Label l)567   bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
LookAheadFst_(const Fst<Arc> & fst,StateId s)568   bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
569     return LookAheadFst(fst, s);
570   }
571 
572   mutable M matcher_;
573   const Fst<Arc> *lfst_;                     // Look-ahead FST
574   LabelReachable<Arc, S> *label_reachable_;  // Label reachability info
575   StateId s_;                                // Matcher state
576   bool match_set_state_;                     // matcher_.SetState called?
577   mutable bool reach_set_state_;             // reachable_.SetState called?
578   bool error_;
579 };
580 
581 template <class M, uint32 F, class S>
582 template <class L> inline
LookAheadFst(const L & fst,StateId s)583 bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) {
584   if (static_cast<const Fst<Arc> *>(&fst) != lfst_)
585     InitLookAheadFst(fst);
586 
587   SetLookAheadWeight(Weight::One());
588   ClearLookAheadPrefix();
589 
590   if (!label_reachable_)
591     return true;
592 
593   label_reachable_->SetState(s_, s);
594   reach_set_state_ = true;
595 
596   bool compute_weight = F & kLookAheadWeight;
597   bool compute_prefix = F & kLookAheadPrefix;
598 
599   bool reach_input = Type(false) == MATCH_OUTPUT;
600   ArcIterator<L> aiter(fst, s);
601   bool reach_arc = label_reachable_->Reach(&aiter, 0,
602                                            internal::NumArcs(*lfst_, s),
603                                            reach_input, compute_weight);
604   Weight lfinal = internal::Final(*lfst_, s);
605   bool reach_final = lfinal != Weight::Zero() && label_reachable_->ReachFinal();
606   if (reach_arc) {
607     ssize_t begin = label_reachable_->ReachBegin();
608     ssize_t end = label_reachable_->ReachEnd();
609     if (compute_prefix && end - begin == 1 && !reach_final) {
610       aiter.Seek(begin);
611       SetLookAheadPrefix(aiter.Value());
612       compute_weight = false;
613     } else if (compute_weight) {
614       SetLookAheadWeight(label_reachable_->ReachWeight());
615     }
616   }
617   if (reach_final && compute_weight)
618     SetLookAheadWeight(reach_arc ?
619                        Plus(LookAheadWeight(), lfinal) : lfinal);
620 
621   return reach_arc || reach_final;
622 }
623 
624 
625 // Label-lookahead relabeling class.
626 template <class A>
627 class LabelLookAheadRelabeler {
628  public:
629   typedef typename A::Label Label;
630   typedef LabelReachableData<Label> MatcherData;
631   typedef AddOnPair<MatcherData, MatcherData> D;
632 
633   // Relabels matcher Fst - initialization function object.
634   template <typename I>
635   LabelLookAheadRelabeler(I **impl);
636 
637   // Relabels arbitrary Fst. Class L should be a label-lookahead Fst.
638   template <class L>
Relabel(MutableFst<A> * fst,const L & mfst,bool relabel_input)639   static void Relabel(MutableFst<A> *fst, const L &mfst,
640                       bool relabel_input) {
641     typename L::Impl *impl = mfst.GetImpl();
642     D *data = impl->GetAddOn();
643     LabelReachable<A> reachable(data->First() ?
644                                   data->First() : data->Second());
645     reachable.Relabel(fst, relabel_input);
646   }
647 
648   // Returns relabeling pairs (cf. relabel.h::Relabel()).
649   // Class L should be a label-lookahead Fst.
650   // If 'avoid_collisions' is true, extra pairs are added to
651   // ensure no collisions when relabeling automata that have
652   // labels unseen here.
653   template <class L>
654   static void RelabelPairs(const L &mfst, vector<pair<Label, Label> > *pairs,
655                            bool avoid_collisions = false) {
656     typename L::Impl *impl = mfst.GetImpl();
657     D *data = impl->GetAddOn();
658     LabelReachable<A> reachable(data->First() ?
659                                   data->First() : data->Second());
660     reachable.RelabelPairs(pairs, avoid_collisions);
661   }
662 };
663 
664 template <class A>
665 template <typename I> inline
LabelLookAheadRelabeler(I ** impl)666 LabelLookAheadRelabeler<A>::LabelLookAheadRelabeler(I **impl) {
667   Fst<A> &fst = (*impl)->GetFst();
668   D *data = (*impl)->GetAddOn();
669   const string name = (*impl)->Type();
670   bool is_mutable = fst.Properties(kMutable, false);
671   MutableFst<A> *mfst = 0;
672   if (is_mutable) {
673     mfst = static_cast<MutableFst<A> *>(&fst);
674   } else {
675     mfst = new VectorFst<A>(fst);
676     data->IncrRefCount();
677     delete *impl;
678   }
679   if (data->First()) {  // reach_input
680     LabelReachable<A> reachable(data->First());
681     reachable.Relabel(mfst, true);
682     if (!FLAGS_save_relabel_ipairs.empty()) {
683       vector<pair<Label, Label> > pairs;
684       reachable.RelabelPairs(&pairs, true);
685       WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs);
686     }
687   } else {
688     LabelReachable<A> reachable(data->Second());
689     reachable.Relabel(mfst, false);
690     if (!FLAGS_save_relabel_opairs.empty()) {
691       vector<pair<Label, Label> > pairs;
692       reachable.RelabelPairs(&pairs, true);
693       WriteLabelPairs(FLAGS_save_relabel_opairs, pairs);
694     }
695   }
696   if (!is_mutable) {
697     *impl = new I(*mfst, name);
698     (*impl)->SetAddOn(data);
699     delete mfst;
700     data->DecrRefCount();
701   }
702 }
703 
704 
705 // Generic lookahead matcher, templated on the FST definition
706 // - a wrapper around pointer to specific one.
707 template <class F>
708 class LookAheadMatcher {
709  public:
710   typedef F FST;
711   typedef typename F::Arc Arc;
712   typedef typename Arc::StateId StateId;
713   typedef typename Arc::Label Label;
714   typedef typename Arc::Weight Weight;
715   typedef LookAheadMatcherBase<Arc> LBase;
716 
LookAheadMatcher(const F & fst,MatchType match_type)717   LookAheadMatcher(const F &fst, MatchType match_type) {
718     base_ = fst.InitMatcher(match_type);
719     if (!base_)
720       base_ = new SortedMatcher<F>(fst, match_type);
721     lookahead_ = false;
722   }
723 
724   LookAheadMatcher(const LookAheadMatcher<F> &matcher, bool safe = false) {
725     base_ = matcher.base_->Copy(safe);
726     lookahead_ = matcher.lookahead_;
727   }
728 
~LookAheadMatcher()729   ~LookAheadMatcher() { delete base_; }
730 
731   // General matcher methods
732   LookAheadMatcher<F> *Copy(bool safe = false) const {
733       return new LookAheadMatcher<F>(*this, safe);
734   }
735 
Type(bool test)736   MatchType Type(bool test) const { return base_->Type(test); }
SetState(StateId s)737   void SetState(StateId s) { base_->SetState(s); }
Find(Label label)738   bool Find(Label label) { return base_->Find(label); }
Done()739   bool Done() const { return base_->Done(); }
Value()740   const Arc& Value() const { return base_->Value(); }
Next()741   void Next() { base_->Next(); }
GetFst()742   const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }
743 
Properties(uint64 props)744   uint64 Properties(uint64 props) const { return base_->Properties(props); }
745 
Flags()746   uint32 Flags() const { return base_->Flags(); }
747 
748   // Look-ahead methods
LookAheadLabel(Label label)749   bool LookAheadLabel(Label label) const {
750     if (LookAheadCheck()) {
751       LBase *lbase = static_cast<LBase *>(base_);
752       return lbase->LookAheadLabel(label);
753     } else {
754       return true;
755     }
756   }
757 
LookAheadFst(const Fst<Arc> & fst,StateId s)758   bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
759     if (LookAheadCheck()) {
760       LBase *lbase = static_cast<LBase *>(base_);
761       return lbase->LookAheadFst(fst, s);
762     } else {
763       return true;
764     }
765   }
766 
LookAheadWeight()767   Weight LookAheadWeight() const {
768     if (LookAheadCheck()) {
769       LBase *lbase = static_cast<LBase *>(base_);
770       return lbase->LookAheadWeight();
771     } else {
772       return Weight::One();
773     }
774   }
775 
LookAheadPrefix(Arc * arc)776   bool LookAheadPrefix(Arc *arc) const {
777     if (LookAheadCheck()) {
778       LBase *lbase = static_cast<LBase *>(base_);
779       return lbase->LookAheadPrefix(arc);
780     } else {
781       return false;
782     }
783   }
784 
785   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
786     if (LookAheadCheck()) {
787       LBase *lbase = static_cast<LBase *>(base_);
788       lbase->InitLookAheadFst(fst, copy);
789     }
790   }
791 
792  private:
LookAheadCheck()793   bool LookAheadCheck() const {
794     if (!lookahead_) {
795       lookahead_ = base_->Flags() &
796           (kInputLookAheadMatcher | kOutputLookAheadMatcher);
797       if (!lookahead_) {
798         FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
799       }
800     }
801     return lookahead_;
802   }
803 
804   MatcherBase<Arc> *base_;
805   mutable bool lookahead_;
806 
807   void operator=(const LookAheadMatcher<Arc> &);  // disallow
808 };
809 
810 }  // namespace fst
811 
812 #endif  // FST_LIB_LOOKAHEAD_MATCHER_H__
813