1 // matcher.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Classes to allow matching labels leaving FST states.
20 
21 #ifndef FST_LIB_MATCHER_H__
22 #define FST_LIB_MATCHER_H__
23 
#include <algorithm>
#include <set>
#include <utility>

#include <fst/mutable-fst.h>  // for all internal FST accessors
28 
29 
30 namespace fst {
31 
32 // MATCHERS - these can find and iterate through requested labels at
33 // FST states. In the simplest form, these are just some associative
34 // map or search keyed on labels. More generally, they may
35 // implement matching special labels that represent sets of labels
36 // such as 'sigma' (all), 'rho' (rest), or 'phi' (fail).
37 // The Matcher interface is:
38 //
39 // template <class F>
40 // class Matcher {
41 //  public:
42 //   typedef F FST;
//   typedef typename F::Arc Arc;
44 //   typedef typename Arc::StateId StateId;
45 //   typedef typename Arc::Label Label;
46 //   typedef typename Arc::Weight Weight;
47 //
48 //   // Required constructors.
49 //   Matcher(const F &fst, MatchType type);
50 //   // If safe=true, the copy is thread-safe. See Fst<>::Copy()
51 //   // for further doc.
52 //   Matcher(const Matcher &matcher, bool safe = false);
53 //
54 //   // If safe=true, the copy is thread-safe. See Fst<>::Copy()
55 //   // for further doc.
56 //   Matcher<F> *Copy(bool safe = false) const;
57 //
58 //   // Returns the match type that can be provided (depending on
59 //   // compatibility of the input FST). It is either
60 //   // the requested match type, MATCH_NONE, or MATCH_UNKNOWN.
61 //   // If 'test' is false, a constant time test is performed, but
62 //   // MATCH_UNKNOWN may be returned. If 'test' is true,
63 //   // a definite answer is returned, but may involve more costly
64 //   // computation (e.g., visiting the Fst).
65 //   MatchType Type(bool test) const;
66 //   // Specifies the current state.
67 //   void SetState(StateId s);
68 //
69 //   // This finds matches to a label at the current state.
//   // Returns true if a match is found. kNoLabel matches any
71 //   // 'non-consuming' transitions, e.g., epsilon transitions,
72 //   // which do not require a matching symbol.
73 //   bool Find(Label label);
74 //   // These iterate through any matches found:
75 //   bool Done() const;         // No more matches.
//   const Arc& Value() const;  // Current arc (when !Done)
77 //   void Next();               // Advance to next arc (when !Done)
78 //   // Initially and after SetState() the iterator methods
79 //   // have undefined behavior until Find() is called.
80 //
81 //   // Return matcher FST.
82 //   const F& GetFst() const;
83 //   // This specifies the known Fst properties as viewed from this
84 //   // matcher. It takes as argument the input Fst's known properties.
85 //   uint64 Properties(uint64 props) const;
86 // };
87 
88 //
89 // MATCHER FLAGS (see also kLookAheadFlags in lookahead-matcher.h)
90 //
91 // Matcher prefers being used as the matching side in composition.
92 const uint32 kPreferMatch  = 0x00000001;
93 
94 // Matcher needs to be used as the matching side in composition.
95 const uint32 kRequireMatch = 0x00000002;
96 
97 // Flags used for basic matchers (see also lookahead.h).
98 const uint32 kMatcherFlags = kPreferMatch | kRequireMatch;
99 
100 // Matcher interface, templated on the Arc definition; used
101 // for matcher specializations that are returned by the
102 // InitMatcher Fst method.
103 template <class A>
104 class MatcherBase {
105  public:
106   typedef A Arc;
107   typedef typename A::StateId StateId;
108   typedef typename A::Label Label;
109   typedef typename A::Weight Weight;
110 
~MatcherBase()111   virtual ~MatcherBase() {}
112 
113   virtual MatcherBase<A> *Copy(bool safe = false) const = 0;
114   virtual MatchType Type(bool test) const = 0;
SetState(StateId s)115   void SetState(StateId s) { SetState_(s); }
Find(Label label)116   bool Find(Label label) { return Find_(label); }
Done()117   bool Done() const { return Done_(); }
Value()118   const A& Value() const { return Value_(); }
Next()119   void Next() { Next_(); }
120   virtual const Fst<A> &GetFst() const = 0;
121   virtual uint64 Properties(uint64 props) const = 0;
Flags()122   virtual uint32 Flags() const { return 0; }
123  private:
124   virtual void SetState_(StateId s) = 0;
125   virtual bool Find_(Label label) = 0;
126   virtual bool Done_() const = 0;
127   virtual const A& Value_() const  = 0;
128   virtual void Next_()  = 0;
129 };
130 
131 
132 // A matcher that expects sorted labels on the side to be matched.
133 // If match_type == MATCH_INPUT, epsilons match the implicit self loop
134 // Arc(kNoLabel, 0, Weight::One(), current_state) as well as any
135 // actual epsilon transitions. If match_type == MATCH_OUTPUT, then
136 // Arc(0, kNoLabel, Weight::One(), current_state) is instead matched.
137 template <class F>
138 class SortedMatcher : public MatcherBase<typename F::Arc> {
139  public:
140   typedef F FST;
141   typedef typename F::Arc Arc;
142   typedef typename Arc::StateId StateId;
143   typedef typename Arc::Label Label;
144   typedef typename Arc::Weight Weight;
145 
146   // Labels >= binary_label will be searched for by binary search,
147   // o.w. linear search is used.
148   SortedMatcher(const F &fst, MatchType match_type,
149                 Label binary_label = 1)
150       : fst_(fst.Copy()),
151         s_(kNoStateId),
152         aiter_(0),
153         match_type_(match_type),
154         binary_label_(binary_label),
155         match_label_(kNoLabel),
156         narcs_(0),
157         loop_(kNoLabel, 0, Weight::One(), kNoStateId),
158         error_(false) {
159     switch(match_type_) {
160       case MATCH_INPUT:
161       case MATCH_NONE:
162         break;
163       case MATCH_OUTPUT:
164         swap(loop_.ilabel, loop_.olabel);
165         break;
166       default:
167         FSTERROR() << "SortedMatcher: bad match type";
168         match_type_ = MATCH_NONE;
169         error_ = true;
170     }
171   }
172 
173   SortedMatcher(const SortedMatcher<F> &matcher, bool safe = false)
174       : fst_(matcher.fst_->Copy(safe)),
175         s_(kNoStateId),
176         aiter_(0),
177         match_type_(matcher.match_type_),
178         binary_label_(matcher.binary_label_),
179         match_label_(kNoLabel),
180         narcs_(0),
181         loop_(matcher.loop_),
182         error_(matcher.error_) {}
183 
~SortedMatcher()184   virtual ~SortedMatcher() {
185     if (aiter_)
186       delete aiter_;
187     delete fst_;
188   }
189 
190   virtual SortedMatcher<F> *Copy(bool safe = false) const {
191     return new SortedMatcher<F>(*this, safe);
192   }
193 
Type(bool test)194   virtual MatchType Type(bool test) const {
195     if (match_type_ == MATCH_NONE)
196       return match_type_;
197 
198     uint64 true_prop =  match_type_ == MATCH_INPUT ?
199         kILabelSorted : kOLabelSorted;
200     uint64 false_prop = match_type_ == MATCH_INPUT ?
201         kNotILabelSorted : kNotOLabelSorted;
202     uint64 props = fst_->Properties(true_prop | false_prop, test);
203 
204     if (props & true_prop)
205       return match_type_;
206     else if (props & false_prop)
207       return MATCH_NONE;
208     else
209       return MATCH_UNKNOWN;
210   }
211 
SetState(StateId s)212   void SetState(StateId s) {
213     if (s_ == s)
214       return;
215     s_ = s;
216     if (match_type_ == MATCH_NONE) {
217       FSTERROR() << "SortedMatcher: bad match type";
218       error_ = true;
219     }
220     if (aiter_)
221       delete aiter_;
222     aiter_ = new ArcIterator<F>(*fst_, s);
223     aiter_->SetFlags(kArcNoCache, kArcNoCache);
224     narcs_ = internal::NumArcs(*fst_, s);
225     loop_.nextstate = s;
226   }
227 
Find(Label match_label)228   bool Find(Label match_label) {
229     exact_match_ = true;
230     if (error_) {
231       current_loop_ = false;
232       match_label_ = kNoLabel;
233       return false;
234     }
235     current_loop_ = match_label == 0;
236     match_label_ = match_label == kNoLabel ? 0 : match_label;
237     if (Search()) {
238       return true;
239     } else {
240       return current_loop_;
241     }
242   }
243 
244   // Positions matcher to the first position where inserting
245   // match_label would maintain the sort order.
LowerBound(Label match_label)246   void LowerBound(Label match_label) {
247     exact_match_ = false;
248     current_loop_ = false;
249     if (error_) {
250       match_label_ = kNoLabel;
251       return;
252     }
253     match_label_ = match_label;
254     Search();
255   }
256 
257   // After Find(), returns false if no more exact matches.
258   // After LowerBound(), returns false if no more arcs.
Done()259   bool Done() const {
260     if (current_loop_)
261       return false;
262     if (aiter_->Done())
263       return true;
264     if (!exact_match_)
265       return false;
266     aiter_->SetFlags(
267       match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
268       kArcValueFlags);
269     Label label = match_type_ == MATCH_INPUT ?
270         aiter_->Value().ilabel : aiter_->Value().olabel;
271     return label != match_label_;
272   }
273 
Value()274   const Arc& Value() const {
275     if (current_loop_) {
276       return loop_;
277     }
278     aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
279     return aiter_->Value();
280   }
281 
Next()282   void Next() {
283     if (current_loop_)
284       current_loop_ = false;
285     else
286       aiter_->Next();
287   }
288 
GetFst()289   virtual const F &GetFst() const { return *fst_; }
290 
Properties(uint64 inprops)291   virtual uint64 Properties(uint64 inprops) const {
292     uint64 outprops = inprops;
293     if (error_) outprops |= kError;
294     return outprops;
295   }
296 
Position()297   size_t Position() const { return aiter_ ? aiter_->Position() : 0; }
298 
299  private:
SetState_(StateId s)300   virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)301   virtual bool Find_(Label label) { return Find(label); }
Done_()302   virtual bool Done_() const { return Done(); }
Value_()303   virtual const Arc& Value_() const { return Value(); }
Next_()304   virtual void Next_() { Next(); }
305 
306   bool Search();
307 
308   const F *fst_;
309   StateId s_;                     // Current state
310   ArcIterator<F> *aiter_;         // Iterator for current state
311   MatchType match_type_;          // Type of match to perform
312   Label binary_label_;            // Least label for binary search
313   Label match_label_;             // Current label to be matched
314   size_t narcs_;                  // Current state arc count
315   Arc loop_;                      // For non-consuming symbols
316   bool current_loop_;             // Current arc is the implicit loop
317   bool exact_match_;              // Exact match or lower bound?
318   bool error_;                    // Error encountered
319 
320   void operator=(const SortedMatcher<F> &);  // Disallow
321 };
322 
323 // Returns true iff match to match_label_. Positions arc iterator at
324 // lower bound regardless.
325 template <class F> inline
Search()326 bool SortedMatcher<F>::Search() {
327   aiter_->SetFlags(
328       match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
329       kArcValueFlags);
330   if (match_label_ >= binary_label_) {
331     // Binary search for match.
332     size_t low = 0;
333     size_t high = narcs_;
334     while (low < high) {
335       size_t mid = (low + high) / 2;
336       aiter_->Seek(mid);
337       Label label = match_type_ == MATCH_INPUT ?
338           aiter_->Value().ilabel : aiter_->Value().olabel;
339       if (label > match_label_) {
340         high = mid;
341       } else if (label < match_label_) {
342         low = mid + 1;
343       } else {
344         // find first matching label (when non-determinism)
345         for (size_t i = mid; i > low; --i) {
346           aiter_->Seek(i - 1);
347           label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel :
348               aiter_->Value().olabel;
349           if (label != match_label_) {
350             aiter_->Seek(i);
351             return true;
352           }
353         }
354         return true;
355       }
356     }
357     aiter_->Seek(low);
358     return false;
359   } else {
360     // Linear search for match.
361     for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) {
362       Label label = match_type_ == MATCH_INPUT ?
363           aiter_->Value().ilabel : aiter_->Value().olabel;
364       if (label == match_label_) {
365         return true;
366       }
367       if (label > match_label_)
368         break;
369     }
370     return false;
371   }
372 }
373 
374 
// Specifies whether during matching we rewrite both the input and output sides.
enum MatcherRewriteMode {
  // Rewrites both sides iff the FST is an acceptor.
  MATCHER_REWRITE_AUTO = 0,
  // Always rewrites both sides.
  MATCHER_REWRITE_ALWAYS = 1,
  // Never rewrites both sides.
  MATCHER_REWRITE_NEVER = 2
};
381 
382 
383 // For any requested label that doesn't match at a state, this matcher
384 // considers all transitions that match the label 'rho_label' (rho =
385 // 'rest').  Each such rho transition found is returned with the
386 // rho_label rewritten as the requested label (both sides if an
387 // acceptor, or if 'rewrite_both' is true and both input and output
388 // labels of the found transition are 'rho_label').  If 'rho_label' is
389 // kNoLabel, this special matching is not done.  RhoMatcher is
390 // templated itself on a matcher, which is used to perform the
391 // underlying matching.  By default, the underlying matcher is
392 // constructed by RhoMatcher.  The user can instead pass in this
393 // object; in that case, RhoMatcher takes its ownership.
394 template <class M>
395 class RhoMatcher : public MatcherBase<typename M::Arc> {
396  public:
397   typedef typename M::FST FST;
398   typedef typename M::Arc Arc;
399   typedef typename Arc::StateId StateId;
400   typedef typename Arc::Label Label;
401   typedef typename Arc::Weight Weight;
402 
403   RhoMatcher(const FST &fst,
404              MatchType match_type,
405              Label rho_label = kNoLabel,
406              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
407              M *matcher = 0)
408       : matcher_(matcher ? matcher : new M(fst, match_type)),
409         match_type_(match_type),
410         rho_label_(rho_label),
411         error_(false) {
412     if (match_type == MATCH_BOTH) {
413       FSTERROR() << "RhoMatcher: bad match type";
414       match_type_ = MATCH_NONE;
415       error_ = true;
416     }
417     if (rho_label == 0) {
418       FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label";
419       rho_label_ = kNoLabel;
420       error_ = true;
421     }
422 
423     if (rewrite_mode == MATCHER_REWRITE_AUTO)
424       rewrite_both_ = fst.Properties(kAcceptor, true);
425     else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
426       rewrite_both_ = true;
427     else
428       rewrite_both_ = false;
429   }
430 
431   RhoMatcher(const RhoMatcher<M> &matcher, bool safe = false)
432       : matcher_(new M(*matcher.matcher_, safe)),
433         match_type_(matcher.match_type_),
434         rho_label_(matcher.rho_label_),
435         rewrite_both_(matcher.rewrite_both_),
436         error_(matcher.error_) {}
437 
~RhoMatcher()438   virtual ~RhoMatcher() {
439     delete matcher_;
440   }
441 
442   virtual RhoMatcher<M> *Copy(bool safe = false) const {
443     return new RhoMatcher<M>(*this, safe);
444   }
445 
Type(bool test)446   virtual MatchType Type(bool test) const { return matcher_->Type(test); }
447 
SetState(StateId s)448   void SetState(StateId s) {
449     matcher_->SetState(s);
450     has_rho_ = rho_label_ != kNoLabel;
451   }
452 
Find(Label match_label)453   bool Find(Label match_label) {
454     if (match_label == rho_label_ && rho_label_ != kNoLabel) {
455       FSTERROR() << "RhoMatcher::Find: bad label (rho)";
456       error_ = true;
457       return false;
458     }
459     if (matcher_->Find(match_label)) {
460       rho_match_ = kNoLabel;
461       return true;
462     } else if (has_rho_ && match_label != 0 && match_label != kNoLabel &&
463                (has_rho_ = matcher_->Find(rho_label_))) {
464       rho_match_ = match_label;
465       return true;
466     } else {
467       return false;
468     }
469   }
470 
Done()471   bool Done() const { return matcher_->Done(); }
472 
Value()473   const Arc& Value() const {
474     if (rho_match_ == kNoLabel) {
475       return matcher_->Value();
476     } else {
477       rho_arc_ = matcher_->Value();
478       if (rewrite_both_) {
479         if (rho_arc_.ilabel == rho_label_)
480           rho_arc_.ilabel = rho_match_;
481         if (rho_arc_.olabel == rho_label_)
482           rho_arc_.olabel = rho_match_;
483       } else if (match_type_ == MATCH_INPUT) {
484         rho_arc_.ilabel = rho_match_;
485       } else {
486         rho_arc_.olabel = rho_match_;
487       }
488       return rho_arc_;
489     }
490   }
491 
Next()492   void Next() { matcher_->Next(); }
493 
GetFst()494   virtual const FST &GetFst() const { return matcher_->GetFst(); }
495 
496   virtual uint64 Properties(uint64 props) const;
497 
Flags()498   virtual uint32 Flags() const {
499     if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE)
500       return matcher_->Flags();
501     return matcher_->Flags() | kRequireMatch;
502   }
503 
504  private:
SetState_(StateId s)505   virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)506   virtual bool Find_(Label label) { return Find(label); }
Done_()507   virtual bool Done_() const { return Done(); }
Value_()508   virtual const Arc& Value_() const { return Value(); }
Next_()509   virtual void Next_() { Next(); }
510 
511   M *matcher_;
512   MatchType match_type_;  // Type of match requested
513   Label rho_label_;       // Label that represents the rho transition
514   bool rewrite_both_;     // Rewrite both sides when both are 'rho_label_'
515   bool has_rho_;          // Are there possibly rhos at the current state?
516   Label rho_match_;       // Current label that matches rho transition
517   mutable Arc rho_arc_;   // Arc to return when rho match
518   bool error_;            // Error encountered
519 
520   void operator=(const RhoMatcher<M> &);  // Disallow
521 };
522 
523 template <class M> inline
Properties(uint64 inprops)524 uint64 RhoMatcher<M>::Properties(uint64 inprops) const {
525   uint64 outprops = matcher_->Properties(inprops);
526   if (error_) outprops |= kError;
527 
528   if (match_type_ == MATCH_NONE) {
529     return outprops;
530   } else if (match_type_ == MATCH_INPUT) {
531     if (rewrite_both_) {
532       return outprops & ~(kODeterministic | kNonODeterministic | kString |
533                        kILabelSorted | kNotILabelSorted |
534                        kOLabelSorted | kNotOLabelSorted);
535     } else {
536       return outprops & ~(kODeterministic | kAcceptor | kString |
537                        kILabelSorted | kNotILabelSorted);
538     }
539   } else if (match_type_ == MATCH_OUTPUT) {
540     if (rewrite_both_) {
541       return outprops & ~(kIDeterministic | kNonIDeterministic | kString |
542                        kILabelSorted | kNotILabelSorted |
543                        kOLabelSorted | kNotOLabelSorted);
544     } else {
545       return outprops & ~(kIDeterministic | kAcceptor | kString |
546                        kOLabelSorted | kNotOLabelSorted);
547     }
548   } else {
549     // Shouldn't ever get here.
550     FSTERROR() << "RhoMatcher:: bad match type: " << match_type_;
551     return 0;
552   }
553 }
554 
555 
556 // For any requested label, this matcher considers all transitions
557 // that match the label 'sigma_label' (sigma = "any"), and this in
558 // additions to transitions with the requested label.  Each such sigma
559 // transition found is returned with the sigma_label rewritten as the
560 // requested label (both sides if an acceptor, or if 'rewrite_both' is
561 // true and both input and output labels of the found transition are
562 // 'sigma_label').  If 'sigma_label' is kNoLabel, this special
563 // matching is not done.  SigmaMatcher is templated itself on a
564 // matcher, which is used to perform the underlying matching.  By
565 // default, the underlying matcher is constructed by SigmaMatcher.
566 // The user can instead pass in this object; in that case,
567 // SigmaMatcher takes its ownership.
568 template <class M>
569 class SigmaMatcher : public MatcherBase<typename M::Arc> {
570  public:
571   typedef typename M::FST FST;
572   typedef typename M::Arc Arc;
573   typedef typename Arc::StateId StateId;
574   typedef typename Arc::Label Label;
575   typedef typename Arc::Weight Weight;
576 
577   SigmaMatcher(const FST &fst,
578                MatchType match_type,
579                Label sigma_label = kNoLabel,
580                MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
581                M *matcher = 0)
582       : matcher_(matcher ? matcher : new M(fst, match_type)),
583         match_type_(match_type),
584         sigma_label_(sigma_label),
585         error_(false) {
586     if (match_type == MATCH_BOTH) {
587       FSTERROR() << "SigmaMatcher: bad match type";
588       match_type_ = MATCH_NONE;
589       error_ = true;
590     }
591     if (sigma_label == 0) {
592       FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label";
593       sigma_label_ = kNoLabel;
594       error_ = true;
595     }
596 
597     if (rewrite_mode == MATCHER_REWRITE_AUTO)
598       rewrite_both_ = fst.Properties(kAcceptor, true);
599     else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
600       rewrite_both_ = true;
601     else
602       rewrite_both_ = false;
603   }
604 
605   SigmaMatcher(const SigmaMatcher<M> &matcher, bool safe = false)
606       : matcher_(new M(*matcher.matcher_, safe)),
607         match_type_(matcher.match_type_),
608         sigma_label_(matcher.sigma_label_),
609         rewrite_both_(matcher.rewrite_both_),
610         error_(matcher.error_) {}
611 
~SigmaMatcher()612   virtual ~SigmaMatcher() {
613     delete matcher_;
614   }
615 
616   virtual SigmaMatcher<M> *Copy(bool safe = false) const {
617     return new SigmaMatcher<M>(*this, safe);
618   }
619 
Type(bool test)620   virtual MatchType Type(bool test) const { return matcher_->Type(test); }
621 
SetState(StateId s)622   void SetState(StateId s) {
623     matcher_->SetState(s);
624     has_sigma_ =
625         sigma_label_ != kNoLabel ? matcher_->Find(sigma_label_) : false;
626   }
627 
Find(Label match_label)628   bool Find(Label match_label) {
629     match_label_ = match_label;
630     if (match_label == sigma_label_ && sigma_label_ != kNoLabel) {
631       FSTERROR() << "SigmaMatcher::Find: bad label (sigma)";
632       error_ = true;
633       return false;
634     }
635     if (matcher_->Find(match_label)) {
636       sigma_match_ = kNoLabel;
637       return true;
638     } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel &&
639                matcher_->Find(sigma_label_)) {
640       sigma_match_ = match_label;
641       return true;
642     } else {
643       return false;
644     }
645   }
646 
Done()647   bool Done() const {
648     return matcher_->Done();
649   }
650 
Value()651   const Arc& Value() const {
652     if (sigma_match_ == kNoLabel) {
653       return matcher_->Value();
654     } else {
655       sigma_arc_ = matcher_->Value();
656       if (rewrite_both_) {
657         if (sigma_arc_.ilabel == sigma_label_)
658           sigma_arc_.ilabel = sigma_match_;
659         if (sigma_arc_.olabel == sigma_label_)
660           sigma_arc_.olabel = sigma_match_;
661       } else if (match_type_ == MATCH_INPUT) {
662         sigma_arc_.ilabel = sigma_match_;
663       } else {
664         sigma_arc_.olabel = sigma_match_;
665       }
666       return sigma_arc_;
667     }
668   }
669 
Next()670   void Next() {
671     matcher_->Next();
672     if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) &&
673         (match_label_ > 0)) {
674       matcher_->Find(sigma_label_);
675       sigma_match_ = match_label_;
676     }
677   }
678 
GetFst()679   virtual const FST &GetFst() const { return matcher_->GetFst(); }
680 
681   virtual uint64 Properties(uint64 props) const;
682 
Flags()683   virtual uint32 Flags() const {
684     if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE)
685       return matcher_->Flags();
686     // kRequireMatch temporarily disabled until issues
687     // in //speech/gaudi/annotation/util/denorm are resolved.
688     // return matcher_->Flags() | kRequireMatch;
689     return matcher_->Flags();
690   }
691 
692 private:
SetState_(StateId s)693   virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)694   virtual bool Find_(Label label) { return Find(label); }
Done_()695   virtual bool Done_() const { return Done(); }
Value_()696   virtual const Arc& Value_() const { return Value(); }
Next_()697   virtual void Next_() { Next(); }
698 
699   M *matcher_;
700   MatchType match_type_;   // Type of match requested
701   Label sigma_label_;      // Label that represents the sigma transition
702   bool rewrite_both_;      // Rewrite both sides when both are 'sigma_label_'
703   bool has_sigma_;         // Are there sigmas at the current state?
704   Label sigma_match_;      // Current label that matches sigma transition
705   mutable Arc sigma_arc_;  // Arc to return when sigma match
706   Label match_label_;      // Label being matched
707   bool error_;             // Error encountered
708 
709   void operator=(const SigmaMatcher<M> &);  // disallow
710 };
711 
712 template <class M> inline
Properties(uint64 inprops)713 uint64 SigmaMatcher<M>::Properties(uint64 inprops) const {
714   uint64 outprops = matcher_->Properties(inprops);
715   if (error_) outprops |= kError;
716 
717   if (match_type_ == MATCH_NONE) {
718     return outprops;
719   } else if (rewrite_both_) {
720     return outprops & ~(kIDeterministic | kNonIDeterministic |
721                      kODeterministic | kNonODeterministic |
722                      kILabelSorted | kNotILabelSorted |
723                      kOLabelSorted | kNotOLabelSorted |
724                      kString);
725   } else if (match_type_ == MATCH_INPUT) {
726     return outprops & ~(kIDeterministic | kNonIDeterministic |
727                      kODeterministic | kNonODeterministic |
728                      kILabelSorted | kNotILabelSorted |
729                      kString | kAcceptor);
730   } else if (match_type_ == MATCH_OUTPUT) {
731     return outprops & ~(kIDeterministic | kNonIDeterministic |
732                      kODeterministic | kNonODeterministic |
733                      kOLabelSorted | kNotOLabelSorted |
734                      kString | kAcceptor);
735   } else {
736     // Shouldn't ever get here.
737     FSTERROR() << "SigmaMatcher:: bad match type: " << match_type_;
738     return 0;
739   }
740 }
741 
742 
743 // For any requested label that doesn't match at a state, this matcher
744 // considers the *unique* transition that matches the label 'phi_label'
745 // (phi = 'fail'), and recursively looks for a match at its
746 // destination.  When 'phi_loop' is true, if no match is found but a
747 // phi self-loop is found, then the phi transition found is returned
748 // with the phi_label rewritten as the requested label (both sides if
749 // an acceptor, or if 'rewrite_both' is true and both input and output
750 // labels of the found transition are 'phi_label').  If 'phi_label' is
751 // kNoLabel, this special matching is not done.  PhiMatcher is
752 // templated itself on a matcher, which is used to perform the
753 // underlying matching.  By default, the underlying matcher is
754 // constructed by PhiMatcher. The user can instead pass in this
755 // object; in that case, PhiMatcher takes its ownership.
756 // Warning: phi non-determinism not supported (for simplicity).
757 template <class M>
758 class PhiMatcher : public MatcherBase<typename M::Arc> {
759  public:
760   typedef typename M::FST FST;
761   typedef typename M::Arc Arc;
762   typedef typename Arc::StateId StateId;
763   typedef typename Arc::Label Label;
764   typedef typename Arc::Weight Weight;
765 
766   PhiMatcher(const FST &fst,
767              MatchType match_type,
768              Label phi_label = kNoLabel,
769              bool phi_loop = true,
770              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
771              M *matcher = 0)
772       : matcher_(matcher ? matcher : new M(fst, match_type)),
773         match_type_(match_type),
774         phi_label_(phi_label),
775         state_(kNoStateId),
776         phi_loop_(phi_loop),
777         error_(false) {
778     if (match_type == MATCH_BOTH) {
779       FSTERROR() << "PhiMatcher: bad match type";
780       match_type_ = MATCH_NONE;
781       error_ = true;
782     }
783 
784     if (rewrite_mode == MATCHER_REWRITE_AUTO)
785       rewrite_both_ = fst.Properties(kAcceptor, true);
786     else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
787       rewrite_both_ = true;
788     else
789       rewrite_both_ = false;
790    }
791 
792   PhiMatcher(const PhiMatcher<M> &matcher, bool safe = false)
793       : matcher_(new M(*matcher.matcher_, safe)),
794         match_type_(matcher.match_type_),
795         phi_label_(matcher.phi_label_),
796         rewrite_both_(matcher.rewrite_both_),
797         state_(kNoStateId),
798         phi_loop_(matcher.phi_loop_),
799         error_(matcher.error_) {}
800 
~PhiMatcher()801   virtual ~PhiMatcher() {
802     delete matcher_;
803   }
804 
805   virtual PhiMatcher<M> *Copy(bool safe = false) const {
806     return new PhiMatcher<M>(*this, safe);
807   }
808 
Type(bool test)809   virtual MatchType Type(bool test) const { return matcher_->Type(test); }
810 
SetState(StateId s)811   void SetState(StateId s) {
812     matcher_->SetState(s);
813     state_ = s;
814     has_phi_ = phi_label_ != kNoLabel;
815   }
816 
817   bool Find(Label match_label);
818 
Done()819   bool Done() const { return matcher_->Done(); }
820 
Value()821   const Arc& Value() const {
822     if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) {
823       return matcher_->Value();
824     } else if (phi_match_ == 0) {  // Virtual epsilon loop
825       phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_);
826       if (match_type_ == MATCH_OUTPUT)
827         swap(phi_arc_.ilabel, phi_arc_.olabel);
828       return phi_arc_;
829     } else {
830       phi_arc_ = matcher_->Value();
831       phi_arc_.weight = Times(phi_weight_, phi_arc_.weight);
832       if (phi_match_ != kNoLabel) {  // Phi loop match
833         if (rewrite_both_) {
834           if (phi_arc_.ilabel == phi_label_)
835             phi_arc_.ilabel = phi_match_;
836           if (phi_arc_.olabel == phi_label_)
837             phi_arc_.olabel = phi_match_;
838         } else if (match_type_ == MATCH_INPUT) {
839           phi_arc_.ilabel = phi_match_;
840         } else {
841           phi_arc_.olabel = phi_match_;
842         }
843       }
844       return phi_arc_;
845     }
846   }
847 
Next()848   void Next() { matcher_->Next(); }
849 
GetFst()850   virtual const FST &GetFst() const { return matcher_->GetFst(); }
851 
852   virtual uint64 Properties(uint64 props) const;
853 
Flags()854   virtual uint32 Flags() const {
855     if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE)
856       return matcher_->Flags();
857     return matcher_->Flags() | kRequireMatch;
858   }
859 
860 private:
SetState_(StateId s)861   virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)862   virtual bool Find_(Label label) { return Find(label); }
Done_()863   virtual bool Done_() const { return Done(); }
Value_()864   virtual const Arc& Value_() const { return Value(); }
Next_()865   virtual void Next_() { Next(); }
866 
  M *matcher_;            // Wrapped matcher; deleted by ~PhiMatcher()
  MatchType match_type_;  // Type of match requested
  Label phi_label_;       // Label that represents the phi transition
                          // (kNoLabel disables phi matching; 0 makes
                          // epsilon play the role of phi)
  bool rewrite_both_;     // In Value(), rewrite every side that carries
                          // 'phi_label_', not only the matched side
  bool has_phi_;          // Are there possibly phis at the current state?
                          // (set by SetState() from phi_label_)
  Label phi_match_;       // Label reported for the current phi match:
                          // kNoLabel = none, 0 = virtual epsilon loop,
                          // otherwise the label matched via a phi self-loop
  mutable Arc phi_arc_;   // Scratch arc built and returned by Value()
  StateId state_;         // State where looking for matches
  Weight phi_weight_;     // Product of the weights of phi transitions taken
  bool phi_loop_;         // When true, phi self-loops are allowed and treated
                          // as rho (required for Aho-Corasick)
  bool error_;            // Error encountered; Properties() then adds kError

  void operator=(const PhiMatcher<M> &);  // disallow
881 };
882 
883 template <class M> inline
Find(Label match_label)884 bool PhiMatcher<M>::Find(Label match_label) {
885   if (match_label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) {
886     FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_;
887     error_ = true;
888     return false;
889   }
890   matcher_->SetState(state_);
891   phi_match_ = kNoLabel;
892   phi_weight_ = Weight::One();
893   if (phi_label_ == 0) {          // When 'phi_label_ == 0',
894     if (match_label == kNoLabel)  // there are no more true epsilon arcs,
895       return false;
896     if (match_label == 0) {       // but virtual eps loop need to be returned
897       if (!matcher_->Find(kNoLabel)) {
898         return matcher_->Find(0);
899       } else {
900         phi_match_ = 0;
901         return true;
902       }
903     }
904   }
905   if (!has_phi_ || match_label == 0 || match_label == kNoLabel)
906     return matcher_->Find(match_label);
907   StateId state = state_;
908   while (!matcher_->Find(match_label)) {
909     // Look for phi transition (if phi_label_ == 0, we need to look
910     // for -1 to avoid getting the virtual self-loop)
911     if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_))
912       return false;
913     if (phi_loop_ && matcher_->Value().nextstate == state) {
914       phi_match_ = match_label;
915       return true;
916     }
917     phi_weight_ = Times(phi_weight_, matcher_->Value().weight);
918     state = matcher_->Value().nextstate;
919     matcher_->Next();
920     if (!matcher_->Done()) {
921       FSTERROR() << "PhiMatcher: phi non-determinism not supported";
922       error_ = true;
923     }
924     matcher_->SetState(state);
925   }
926   return true;
927 }
928 
929 template <class M> inline
Properties(uint64 inprops)930 uint64 PhiMatcher<M>::Properties(uint64 inprops) const {
931   uint64 outprops = matcher_->Properties(inprops);
932   if (error_) outprops |= kError;
933 
934   if (match_type_ == MATCH_NONE) {
935     return outprops;
936   } else if (match_type_ == MATCH_INPUT) {
937     if (phi_label_ == 0) {
938       outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
939       outprops |= kNoEpsilons | kNoIEpsilons;
940     }
941     if (rewrite_both_) {
942       return outprops & ~(kODeterministic | kNonODeterministic | kString |
943                        kILabelSorted | kNotILabelSorted |
944                        kOLabelSorted | kNotOLabelSorted);
945     } else {
946       return outprops & ~(kODeterministic | kAcceptor | kString |
947                        kILabelSorted | kNotILabelSorted |
948                        kOLabelSorted | kNotOLabelSorted);
949     }
950   } else if (match_type_ == MATCH_OUTPUT) {
951     if (phi_label_ == 0) {
952       outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
953       outprops |= kNoEpsilons | kNoOEpsilons;
954     }
955     if (rewrite_both_) {
956       return outprops & ~(kIDeterministic | kNonIDeterministic | kString |
957                        kILabelSorted | kNotILabelSorted |
958                        kOLabelSorted | kNotOLabelSorted);
959     } else {
960       return outprops & ~(kIDeterministic | kAcceptor | kString |
961                        kILabelSorted | kNotILabelSorted |
962                        kOLabelSorted | kNotOLabelSorted);
963     }
964   } else {
965     // Shouldn't ever get here.
966     FSTERROR() << "PhiMatcher:: bad match type: " << match_type_;
967     return 0;
968   }
969 }
970 
971 
972 //
973 // MULTI-EPS MATCHER FLAGS
974 //
975 
976 // Return multi-epsilon arcs for Find(kNoLabel).
977 const uint32 kMultiEpsList =  0x00000001;
978 
979 // Return a kNolabel loop for Find(multi_eps).
980 const uint32 kMultiEpsLoop =  0x00000002;
981 
// MultiEpsMatcher: allows treating multiple non-0 labels as
// non-consuming labels, in addition to label 0, which is always
// non-consuming. The precise behavior is controlled by the 'flags'
// argument. By default, the underlying matcher is constructed by
// MultiEpsMatcher. The user can instead pass in an existing matcher
// object; in that case, MultiEpsMatcher takes ownership of it iff
// 'own_matcher' is true.
989 template <class M>
990 class MultiEpsMatcher {
991  public:
992   typedef typename M::FST FST;
993   typedef typename M::Arc Arc;
994   typedef typename Arc::StateId StateId;
995   typedef typename Arc::Label Label;
996   typedef typename Arc::Weight Weight;
997 
998   MultiEpsMatcher(const FST &fst, MatchType match_type,
999                   uint32 flags = (kMultiEpsLoop | kMultiEpsList),
1000                   M *matcher = 0, bool own_matcher = true)
1001       : matcher_(matcher ? matcher : new M(fst, match_type)),
1002         flags_(flags),
1003         own_matcher_(matcher ?  own_matcher : true) {
1004     if (match_type == MATCH_INPUT) {
1005       loop_.ilabel = kNoLabel;
1006       loop_.olabel = 0;
1007     } else {
1008       loop_.ilabel = 0;
1009       loop_.olabel = kNoLabel;
1010     }
1011     loop_.weight = Weight::One();
1012     loop_.nextstate = kNoStateId;
1013   }
1014 
1015   MultiEpsMatcher(const MultiEpsMatcher<M> &matcher, bool safe = false)
1016       : matcher_(new M(*matcher.matcher_, safe)),
1017         flags_(matcher.flags_),
1018         own_matcher_(true),
1019         multi_eps_labels_(matcher.multi_eps_labels_),
1020         loop_(matcher.loop_) {
1021     loop_.nextstate = kNoStateId;
1022   }
1023 
~MultiEpsMatcher()1024   ~MultiEpsMatcher() {
1025     if (own_matcher_)
1026       delete matcher_;
1027   }
1028 
1029   MultiEpsMatcher<M> *Copy(bool safe = false) const {
1030     return new MultiEpsMatcher<M>(*this, safe);
1031   }
1032 
Type(bool test)1033   MatchType Type(bool test) const { return matcher_->Type(test); }
1034 
SetState(StateId s)1035   void SetState(StateId s) {
1036     matcher_->SetState(s);
1037     loop_.nextstate = s;
1038   }
1039 
1040   bool Find(Label match_label);
1041 
Done()1042   bool Done() const {
1043     return done_;
1044   }
1045 
Value()1046   const Arc& Value() const {
1047     return current_loop_ ? loop_ : matcher_->Value();
1048   }
1049 
Next()1050   void Next() {
1051     if (!current_loop_) {
1052       matcher_->Next();
1053       done_ = matcher_->Done();
1054       if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) {
1055         ++multi_eps_iter_;
1056         while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1057                !matcher_->Find(*multi_eps_iter_))
1058           ++multi_eps_iter_;
1059         if (multi_eps_iter_ != multi_eps_labels_.End())
1060           done_ = false;
1061         else
1062           done_ = !matcher_->Find(kNoLabel);
1063 
1064       }
1065     } else {
1066       done_ = true;
1067     }
1068   }
1069 
GetFst()1070   const FST &GetFst() const { return matcher_->GetFst(); }
1071 
Properties(uint64 props)1072   uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
1073 
Flags()1074   uint32 Flags() const { return matcher_->Flags(); }
1075 
AddMultiEpsLabel(Label label)1076   void AddMultiEpsLabel(Label label) {
1077     if (label == 0) {
1078       FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1079     } else {
1080       multi_eps_labels_.Insert(label);
1081     }
1082   }
1083 
RemoveMultiEpsLabel(Label label)1084   void RemoveMultiEpsLabel(Label label) {
1085     if (label == 0) {
1086       FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1087     } else {
1088       multi_eps_labels_.Erase(label);
1089     }
1090   }
1091 
ClearMultiEpsLabels()1092   void ClearMultiEpsLabels() {
1093     multi_eps_labels_.Clear();
1094   }
1095 
1096 private:
1097   M *matcher_;
1098   uint32 flags_;
1099   bool own_matcher_;             // Does this class delete the matcher?
1100 
1101   // Multi-eps label set
1102   CompactSet<Label, kNoLabel> multi_eps_labels_;
1103   typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_;
1104 
1105   bool current_loop_;            // Current arc is the implicit loop
1106   mutable Arc loop_;             // For non-consuming symbols
1107   bool done_;                    // Matching done
1108 
1109   void operator=(const MultiEpsMatcher<M> &);  // Disallow
1110 };
1111 
1112 template <class M> inline
Find(Label match_label)1113 bool MultiEpsMatcher<M>::Find(Label match_label) {
1114   multi_eps_iter_ = multi_eps_labels_.End();
1115   current_loop_ = false;
1116   bool ret;
1117   if (match_label == 0) {
1118     ret = matcher_->Find(0);
1119   } else if (match_label == kNoLabel) {
1120     if (flags_ & kMultiEpsList) {
1121       // return all non-consuming arcs (incl. epsilon)
1122       multi_eps_iter_ = multi_eps_labels_.Begin();
1123       while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1124              !matcher_->Find(*multi_eps_iter_))
1125         ++multi_eps_iter_;
1126       if (multi_eps_iter_ != multi_eps_labels_.End())
1127         ret = true;
1128       else
1129         ret = matcher_->Find(kNoLabel);
1130     } else {
1131       // return all epsilon arcs
1132       ret = matcher_->Find(kNoLabel);
1133     }
1134   } else if ((flags_ & kMultiEpsLoop) &&
1135              multi_eps_labels_.Find(match_label) != multi_eps_labels_.End()) {
1136     // return 'implicit' loop
1137     current_loop_ = true;
1138     ret = true;
1139   } else {
1140     ret = matcher_->Find(match_label);
1141   }
1142   done_ = !ret;
1143   return ret;
1144 }
1145 
1146 
1147 // Generic matcher, templated on the FST definition
1148 // - a wrapper around pointer to specific one.
1149 // Here is a typical use: \code
1150 //   Matcher<StdFst> matcher(fst, MATCH_INPUT);
1151 //   matcher.SetState(state);
1152 //   if (matcher.Find(label))
1153 //     for (; !matcher.Done(); matcher.Next()) {
//       const StdArc &arc = matcher.Value();
1155 //       ...
1156 //     } \endcode
1157 template <class F>
1158 class Matcher {
1159  public:
1160   typedef F FST;
1161   typedef typename F::Arc Arc;
1162   typedef typename Arc::StateId StateId;
1163   typedef typename Arc::Label Label;
1164   typedef typename Arc::Weight Weight;
1165 
Matcher(const F & fst,MatchType match_type)1166   Matcher(const F &fst, MatchType match_type) {
1167     base_ = fst.InitMatcher(match_type);
1168     if (!base_)
1169       base_ = new SortedMatcher<F>(fst, match_type);
1170   }
1171 
1172   Matcher(const Matcher<F> &matcher, bool safe = false) {
1173     base_ = matcher.base_->Copy(safe);
1174   }
1175 
1176   // Takes ownership of the provided matcher
Matcher(MatcherBase<Arc> * base_matcher)1177   Matcher(MatcherBase<Arc>* base_matcher) { base_ = base_matcher; }
1178 
~Matcher()1179   ~Matcher() { delete base_; }
1180 
1181   Matcher<F> *Copy(bool safe = false) const {
1182     return new Matcher<F>(*this, safe);
1183   }
1184 
Type(bool test)1185   MatchType Type(bool test) const { return base_->Type(test); }
SetState(StateId s)1186   void SetState(StateId s) { base_->SetState(s); }
Find(Label label)1187   bool Find(Label label) { return base_->Find(label); }
Done()1188   bool Done() const { return base_->Done(); }
Value()1189   const Arc& Value() const { return base_->Value(); }
Next()1190   void Next() { base_->Next(); }
GetFst()1191   const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }
Properties(uint64 props)1192   uint64 Properties(uint64 props) const { return base_->Properties(props); }
Flags()1193   uint32 Flags() const { return base_->Flags() & kMatcherFlags; }
1194 
1195  private:
1196   MatcherBase<Arc> *base_;
1197 
1198   void operator=(const Matcher<Arc> &);  // disallow
1199 };
1200 
1201 }  // namespace fst
1202 
1203 
1204 
1205 #endif  // FST_LIB_MATCHER_H__
1206