1 // matcher.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Classes to allow matching labels leaving FST states.
20
21 #ifndef FST_LIB_MATCHER_H__
22 #define FST_LIB_MATCHER_H__
23
24 #include <algorithm>
25 #include <set>
26
27 #include <fst/mutable-fst.h> // for all internal FST accessors
28
29
30 namespace fst {
31
32 // MATCHERS - these can find and iterate through requested labels at
33 // FST states. In the simplest form, these are just some associative
34 // map or search keyed on labels. More generally, they may
35 // implement matching special labels that represent sets of labels
36 // such as 'sigma' (all), 'rho' (rest), or 'phi' (fail).
37 // The Matcher interface is:
38 //
39 // template <class F>
40 // class Matcher {
41 // public:
42 // typedef F FST;
//   typedef typename F::Arc Arc;
44 // typedef typename Arc::StateId StateId;
45 // typedef typename Arc::Label Label;
46 // typedef typename Arc::Weight Weight;
47 //
48 // // Required constructors.
49 // Matcher(const F &fst, MatchType type);
50 // // If safe=true, the copy is thread-safe. See Fst<>::Copy()
51 // // for further doc.
52 // Matcher(const Matcher &matcher, bool safe = false);
53 //
54 // // If safe=true, the copy is thread-safe. See Fst<>::Copy()
55 // // for further doc.
56 // Matcher<F> *Copy(bool safe = false) const;
57 //
58 // // Returns the match type that can be provided (depending on
59 // // compatibility of the input FST). It is either
60 // // the requested match type, MATCH_NONE, or MATCH_UNKNOWN.
61 // // If 'test' is false, a constant time test is performed, but
62 // // MATCH_UNKNOWN may be returned. If 'test' is true,
63 // // a definite answer is returned, but may involve more costly
64 // // computation (e.g., visiting the Fst).
65 // MatchType Type(bool test) const;
66 // // Specifies the current state.
67 // void SetState(StateId s);
68 //
69 // // This finds matches to a label at the current state.
70 // // Returns true if a match found. kNoLabel matches any
71 // // 'non-consuming' transitions, e.g., epsilon transitions,
72 // // which do not require a matching symbol.
73 // bool Find(Label label);
74 // // These iterate through any matches found:
75 // bool Done() const; // No more matches.
//   const Arc& Value() const; // Current arc (when !Done)
77 // void Next(); // Advance to next arc (when !Done)
78 // // Initially and after SetState() the iterator methods
79 // // have undefined behavior until Find() is called.
80 //
81 // // Return matcher FST.
82 // const F& GetFst() const;
83 // // This specifies the known Fst properties as viewed from this
84 // // matcher. It takes as argument the input Fst's known properties.
85 // uint64 Properties(uint64 props) const;
86 // };
87
88 //
89 // MATCHER FLAGS (see also kLookAheadFlags in lookahead-matcher.h)
90 //
91 // Matcher prefers being used as the matching side in composition.
92 const uint32 kPreferMatch = 0x00000001;
93
94 // Matcher needs to be used as the matching side in composition.
95 const uint32 kRequireMatch = 0x00000002;
96
97 // Flags used for basic matchers (see also lookahead.h).
98 const uint32 kMatcherFlags = kPreferMatch | kRequireMatch;
99
100 // Matcher interface, templated on the Arc definition; used
101 // for matcher specializations that are returned by the
102 // InitMatcher Fst method.
103 template <class A>
104 class MatcherBase {
105 public:
106 typedef A Arc;
107 typedef typename A::StateId StateId;
108 typedef typename A::Label Label;
109 typedef typename A::Weight Weight;
110
~MatcherBase()111 virtual ~MatcherBase() {}
112
113 virtual MatcherBase<A> *Copy(bool safe = false) const = 0;
114 virtual MatchType Type(bool test) const = 0;
SetState(StateId s)115 void SetState(StateId s) { SetState_(s); }
Find(Label label)116 bool Find(Label label) { return Find_(label); }
Done()117 bool Done() const { return Done_(); }
Value()118 const A& Value() const { return Value_(); }
Next()119 void Next() { Next_(); }
120 virtual const Fst<A> &GetFst() const = 0;
121 virtual uint64 Properties(uint64 props) const = 0;
Flags()122 virtual uint32 Flags() const { return 0; }
123 private:
124 virtual void SetState_(StateId s) = 0;
125 virtual bool Find_(Label label) = 0;
126 virtual bool Done_() const = 0;
127 virtual const A& Value_() const = 0;
128 virtual void Next_() = 0;
129 };
130
131
132 // A matcher that expects sorted labels on the side to be matched.
133 // If match_type == MATCH_INPUT, epsilons match the implicit self loop
134 // Arc(kNoLabel, 0, Weight::One(), current_state) as well as any
135 // actual epsilon transitions. If match_type == MATCH_OUTPUT, then
136 // Arc(0, kNoLabel, Weight::One(), current_state) is instead matched.
137 template <class F>
138 class SortedMatcher : public MatcherBase<typename F::Arc> {
139 public:
140 typedef F FST;
141 typedef typename F::Arc Arc;
142 typedef typename Arc::StateId StateId;
143 typedef typename Arc::Label Label;
144 typedef typename Arc::Weight Weight;
145
146 // Labels >= binary_label will be searched for by binary search,
147 // o.w. linear search is used.
148 SortedMatcher(const F &fst, MatchType match_type,
149 Label binary_label = 1)
150 : fst_(fst.Copy()),
151 s_(kNoStateId),
152 aiter_(0),
153 match_type_(match_type),
154 binary_label_(binary_label),
155 match_label_(kNoLabel),
156 narcs_(0),
157 loop_(kNoLabel, 0, Weight::One(), kNoStateId),
158 error_(false) {
159 switch(match_type_) {
160 case MATCH_INPUT:
161 case MATCH_NONE:
162 break;
163 case MATCH_OUTPUT:
164 swap(loop_.ilabel, loop_.olabel);
165 break;
166 default:
167 FSTERROR() << "SortedMatcher: bad match type";
168 match_type_ = MATCH_NONE;
169 error_ = true;
170 }
171 }
172
173 SortedMatcher(const SortedMatcher<F> &matcher, bool safe = false)
174 : fst_(matcher.fst_->Copy(safe)),
175 s_(kNoStateId),
176 aiter_(0),
177 match_type_(matcher.match_type_),
178 binary_label_(matcher.binary_label_),
179 match_label_(kNoLabel),
180 narcs_(0),
181 loop_(matcher.loop_),
182 error_(matcher.error_) {}
183
~SortedMatcher()184 virtual ~SortedMatcher() {
185 if (aiter_)
186 delete aiter_;
187 delete fst_;
188 }
189
190 virtual SortedMatcher<F> *Copy(bool safe = false) const {
191 return new SortedMatcher<F>(*this, safe);
192 }
193
Type(bool test)194 virtual MatchType Type(bool test) const {
195 if (match_type_ == MATCH_NONE)
196 return match_type_;
197
198 uint64 true_prop = match_type_ == MATCH_INPUT ?
199 kILabelSorted : kOLabelSorted;
200 uint64 false_prop = match_type_ == MATCH_INPUT ?
201 kNotILabelSorted : kNotOLabelSorted;
202 uint64 props = fst_->Properties(true_prop | false_prop, test);
203
204 if (props & true_prop)
205 return match_type_;
206 else if (props & false_prop)
207 return MATCH_NONE;
208 else
209 return MATCH_UNKNOWN;
210 }
211
SetState(StateId s)212 void SetState(StateId s) {
213 if (s_ == s)
214 return;
215 s_ = s;
216 if (match_type_ == MATCH_NONE) {
217 FSTERROR() << "SortedMatcher: bad match type";
218 error_ = true;
219 }
220 if (aiter_)
221 delete aiter_;
222 aiter_ = new ArcIterator<F>(*fst_, s);
223 aiter_->SetFlags(kArcNoCache, kArcNoCache);
224 narcs_ = internal::NumArcs(*fst_, s);
225 loop_.nextstate = s;
226 }
227
Find(Label match_label)228 bool Find(Label match_label) {
229 exact_match_ = true;
230 if (error_) {
231 current_loop_ = false;
232 match_label_ = kNoLabel;
233 return false;
234 }
235 current_loop_ = match_label == 0;
236 match_label_ = match_label == kNoLabel ? 0 : match_label;
237 if (Search()) {
238 return true;
239 } else {
240 return current_loop_;
241 }
242 }
243
244 // Positions matcher to the first position where inserting
245 // match_label would maintain the sort order.
LowerBound(Label match_label)246 void LowerBound(Label match_label) {
247 exact_match_ = false;
248 current_loop_ = false;
249 if (error_) {
250 match_label_ = kNoLabel;
251 return;
252 }
253 match_label_ = match_label;
254 Search();
255 }
256
257 // After Find(), returns false if no more exact matches.
258 // After LowerBound(), returns false if no more arcs.
Done()259 bool Done() const {
260 if (current_loop_)
261 return false;
262 if (aiter_->Done())
263 return true;
264 if (!exact_match_)
265 return false;
266 aiter_->SetFlags(
267 match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
268 kArcValueFlags);
269 Label label = match_type_ == MATCH_INPUT ?
270 aiter_->Value().ilabel : aiter_->Value().olabel;
271 return label != match_label_;
272 }
273
Value()274 const Arc& Value() const {
275 if (current_loop_) {
276 return loop_;
277 }
278 aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
279 return aiter_->Value();
280 }
281
Next()282 void Next() {
283 if (current_loop_)
284 current_loop_ = false;
285 else
286 aiter_->Next();
287 }
288
GetFst()289 virtual const F &GetFst() const { return *fst_; }
290
Properties(uint64 inprops)291 virtual uint64 Properties(uint64 inprops) const {
292 uint64 outprops = inprops;
293 if (error_) outprops |= kError;
294 return outprops;
295 }
296
Position()297 size_t Position() const { return aiter_ ? aiter_->Position() : 0; }
298
299 private:
SetState_(StateId s)300 virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)301 virtual bool Find_(Label label) { return Find(label); }
Done_()302 virtual bool Done_() const { return Done(); }
Value_()303 virtual const Arc& Value_() const { return Value(); }
Next_()304 virtual void Next_() { Next(); }
305
306 bool Search();
307
308 const F *fst_;
309 StateId s_; // Current state
310 ArcIterator<F> *aiter_; // Iterator for current state
311 MatchType match_type_; // Type of match to perform
312 Label binary_label_; // Least label for binary search
313 Label match_label_; // Current label to be matched
314 size_t narcs_; // Current state arc count
315 Arc loop_; // For non-consuming symbols
316 bool current_loop_; // Current arc is the implicit loop
317 bool exact_match_; // Exact match or lower bound?
318 bool error_; // Error encountered
319
320 void operator=(const SortedMatcher<F> &); // Disallow
321 };
322
323 // Returns true iff match to match_label_. Positions arc iterator at
324 // lower bound regardless.
325 template <class F> inline
Search()326 bool SortedMatcher<F>::Search() {
327 aiter_->SetFlags(
328 match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
329 kArcValueFlags);
330 if (match_label_ >= binary_label_) {
331 // Binary search for match.
332 size_t low = 0;
333 size_t high = narcs_;
334 while (low < high) {
335 size_t mid = (low + high) / 2;
336 aiter_->Seek(mid);
337 Label label = match_type_ == MATCH_INPUT ?
338 aiter_->Value().ilabel : aiter_->Value().olabel;
339 if (label > match_label_) {
340 high = mid;
341 } else if (label < match_label_) {
342 low = mid + 1;
343 } else {
344 // find first matching label (when non-determinism)
345 for (size_t i = mid; i > low; --i) {
346 aiter_->Seek(i - 1);
347 label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel :
348 aiter_->Value().olabel;
349 if (label != match_label_) {
350 aiter_->Seek(i);
351 return true;
352 }
353 }
354 return true;
355 }
356 }
357 aiter_->Seek(low);
358 return false;
359 } else {
360 // Linear search for match.
361 for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) {
362 Label label = match_type_ == MATCH_INPUT ?
363 aiter_->Value().ilabel : aiter_->Value().olabel;
364 if (label == match_label_) {
365 return true;
366 }
367 if (label > match_label_)
368 break;
369 }
370 return false;
371 }
372 }
373
374
// Selects whether a special-label matcher rewrites the label on both the
// input and the output side of a matched arc.
enum MatcherRewriteMode {
  MATCHER_REWRITE_AUTO = 0,    // Rewrites both sides iff acceptor.
  MATCHER_REWRITE_ALWAYS = 1,  // Always requests rewriting of both sides.
  MATCHER_REWRITE_NEVER = 2    // Never requests rewriting of both sides.
};
381
382
383 // For any requested label that doesn't match at a state, this matcher
384 // considers all transitions that match the label 'rho_label' (rho =
385 // 'rest'). Each such rho transition found is returned with the
386 // rho_label rewritten as the requested label (both sides if an
387 // acceptor, or if 'rewrite_both' is true and both input and output
388 // labels of the found transition are 'rho_label'). If 'rho_label' is
389 // kNoLabel, this special matching is not done. RhoMatcher is
390 // templated itself on a matcher, which is used to perform the
391 // underlying matching. By default, the underlying matcher is
392 // constructed by RhoMatcher. The user can instead pass in this
393 // object; in that case, RhoMatcher takes its ownership.
394 template <class M>
395 class RhoMatcher : public MatcherBase<typename M::Arc> {
396 public:
397 typedef typename M::FST FST;
398 typedef typename M::Arc Arc;
399 typedef typename Arc::StateId StateId;
400 typedef typename Arc::Label Label;
401 typedef typename Arc::Weight Weight;
402
403 RhoMatcher(const FST &fst,
404 MatchType match_type,
405 Label rho_label = kNoLabel,
406 MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
407 M *matcher = 0)
408 : matcher_(matcher ? matcher : new M(fst, match_type)),
409 match_type_(match_type),
410 rho_label_(rho_label),
411 error_(false) {
412 if (match_type == MATCH_BOTH) {
413 FSTERROR() << "RhoMatcher: bad match type";
414 match_type_ = MATCH_NONE;
415 error_ = true;
416 }
417 if (rho_label == 0) {
418 FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label";
419 rho_label_ = kNoLabel;
420 error_ = true;
421 }
422
423 if (rewrite_mode == MATCHER_REWRITE_AUTO)
424 rewrite_both_ = fst.Properties(kAcceptor, true);
425 else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
426 rewrite_both_ = true;
427 else
428 rewrite_both_ = false;
429 }
430
431 RhoMatcher(const RhoMatcher<M> &matcher, bool safe = false)
432 : matcher_(new M(*matcher.matcher_, safe)),
433 match_type_(matcher.match_type_),
434 rho_label_(matcher.rho_label_),
435 rewrite_both_(matcher.rewrite_both_),
436 error_(matcher.error_) {}
437
~RhoMatcher()438 virtual ~RhoMatcher() {
439 delete matcher_;
440 }
441
442 virtual RhoMatcher<M> *Copy(bool safe = false) const {
443 return new RhoMatcher<M>(*this, safe);
444 }
445
Type(bool test)446 virtual MatchType Type(bool test) const { return matcher_->Type(test); }
447
SetState(StateId s)448 void SetState(StateId s) {
449 matcher_->SetState(s);
450 has_rho_ = rho_label_ != kNoLabel;
451 }
452
Find(Label match_label)453 bool Find(Label match_label) {
454 if (match_label == rho_label_ && rho_label_ != kNoLabel) {
455 FSTERROR() << "RhoMatcher::Find: bad label (rho)";
456 error_ = true;
457 return false;
458 }
459 if (matcher_->Find(match_label)) {
460 rho_match_ = kNoLabel;
461 return true;
462 } else if (has_rho_ && match_label != 0 && match_label != kNoLabel &&
463 (has_rho_ = matcher_->Find(rho_label_))) {
464 rho_match_ = match_label;
465 return true;
466 } else {
467 return false;
468 }
469 }
470
Done()471 bool Done() const { return matcher_->Done(); }
472
Value()473 const Arc& Value() const {
474 if (rho_match_ == kNoLabel) {
475 return matcher_->Value();
476 } else {
477 rho_arc_ = matcher_->Value();
478 if (rewrite_both_) {
479 if (rho_arc_.ilabel == rho_label_)
480 rho_arc_.ilabel = rho_match_;
481 if (rho_arc_.olabel == rho_label_)
482 rho_arc_.olabel = rho_match_;
483 } else if (match_type_ == MATCH_INPUT) {
484 rho_arc_.ilabel = rho_match_;
485 } else {
486 rho_arc_.olabel = rho_match_;
487 }
488 return rho_arc_;
489 }
490 }
491
Next()492 void Next() { matcher_->Next(); }
493
GetFst()494 virtual const FST &GetFst() const { return matcher_->GetFst(); }
495
496 virtual uint64 Properties(uint64 props) const;
497
Flags()498 virtual uint32 Flags() const {
499 if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE)
500 return matcher_->Flags();
501 return matcher_->Flags() | kRequireMatch;
502 }
503
504 private:
SetState_(StateId s)505 virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)506 virtual bool Find_(Label label) { return Find(label); }
Done_()507 virtual bool Done_() const { return Done(); }
Value_()508 virtual const Arc& Value_() const { return Value(); }
Next_()509 virtual void Next_() { Next(); }
510
511 M *matcher_;
512 MatchType match_type_; // Type of match requested
513 Label rho_label_; // Label that represents the rho transition
514 bool rewrite_both_; // Rewrite both sides when both are 'rho_label_'
515 bool has_rho_; // Are there possibly rhos at the current state?
516 Label rho_match_; // Current label that matches rho transition
517 mutable Arc rho_arc_; // Arc to return when rho match
518 bool error_; // Error encountered
519
520 void operator=(const RhoMatcher<M> &); // Disallow
521 };
522
523 template <class M> inline
Properties(uint64 inprops)524 uint64 RhoMatcher<M>::Properties(uint64 inprops) const {
525 uint64 outprops = matcher_->Properties(inprops);
526 if (error_) outprops |= kError;
527
528 if (match_type_ == MATCH_NONE) {
529 return outprops;
530 } else if (match_type_ == MATCH_INPUT) {
531 if (rewrite_both_) {
532 return outprops & ~(kODeterministic | kNonODeterministic | kString |
533 kILabelSorted | kNotILabelSorted |
534 kOLabelSorted | kNotOLabelSorted);
535 } else {
536 return outprops & ~(kODeterministic | kAcceptor | kString |
537 kILabelSorted | kNotILabelSorted);
538 }
539 } else if (match_type_ == MATCH_OUTPUT) {
540 if (rewrite_both_) {
541 return outprops & ~(kIDeterministic | kNonIDeterministic | kString |
542 kILabelSorted | kNotILabelSorted |
543 kOLabelSorted | kNotOLabelSorted);
544 } else {
545 return outprops & ~(kIDeterministic | kAcceptor | kString |
546 kOLabelSorted | kNotOLabelSorted);
547 }
548 } else {
549 // Shouldn't ever get here.
550 FSTERROR() << "RhoMatcher:: bad match type: " << match_type_;
551 return 0;
552 }
553 }
554
555
556 // For any requested label, this matcher considers all transitions
557 // that match the label 'sigma_label' (sigma = "any"), and this in
558 // additions to transitions with the requested label. Each such sigma
559 // transition found is returned with the sigma_label rewritten as the
560 // requested label (both sides if an acceptor, or if 'rewrite_both' is
561 // true and both input and output labels of the found transition are
562 // 'sigma_label'). If 'sigma_label' is kNoLabel, this special
563 // matching is not done. SigmaMatcher is templated itself on a
564 // matcher, which is used to perform the underlying matching. By
565 // default, the underlying matcher is constructed by SigmaMatcher.
566 // The user can instead pass in this object; in that case,
567 // SigmaMatcher takes its ownership.
568 template <class M>
569 class SigmaMatcher : public MatcherBase<typename M::Arc> {
570 public:
571 typedef typename M::FST FST;
572 typedef typename M::Arc Arc;
573 typedef typename Arc::StateId StateId;
574 typedef typename Arc::Label Label;
575 typedef typename Arc::Weight Weight;
576
577 SigmaMatcher(const FST &fst,
578 MatchType match_type,
579 Label sigma_label = kNoLabel,
580 MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
581 M *matcher = 0)
582 : matcher_(matcher ? matcher : new M(fst, match_type)),
583 match_type_(match_type),
584 sigma_label_(sigma_label),
585 error_(false) {
586 if (match_type == MATCH_BOTH) {
587 FSTERROR() << "SigmaMatcher: bad match type";
588 match_type_ = MATCH_NONE;
589 error_ = true;
590 }
591 if (sigma_label == 0) {
592 FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label";
593 sigma_label_ = kNoLabel;
594 error_ = true;
595 }
596
597 if (rewrite_mode == MATCHER_REWRITE_AUTO)
598 rewrite_both_ = fst.Properties(kAcceptor, true);
599 else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
600 rewrite_both_ = true;
601 else
602 rewrite_both_ = false;
603 }
604
605 SigmaMatcher(const SigmaMatcher<M> &matcher, bool safe = false)
606 : matcher_(new M(*matcher.matcher_, safe)),
607 match_type_(matcher.match_type_),
608 sigma_label_(matcher.sigma_label_),
609 rewrite_both_(matcher.rewrite_both_),
610 error_(matcher.error_) {}
611
~SigmaMatcher()612 virtual ~SigmaMatcher() {
613 delete matcher_;
614 }
615
616 virtual SigmaMatcher<M> *Copy(bool safe = false) const {
617 return new SigmaMatcher<M>(*this, safe);
618 }
619
Type(bool test)620 virtual MatchType Type(bool test) const { return matcher_->Type(test); }
621
SetState(StateId s)622 void SetState(StateId s) {
623 matcher_->SetState(s);
624 has_sigma_ =
625 sigma_label_ != kNoLabel ? matcher_->Find(sigma_label_) : false;
626 }
627
Find(Label match_label)628 bool Find(Label match_label) {
629 match_label_ = match_label;
630 if (match_label == sigma_label_ && sigma_label_ != kNoLabel) {
631 FSTERROR() << "SigmaMatcher::Find: bad label (sigma)";
632 error_ = true;
633 return false;
634 }
635 if (matcher_->Find(match_label)) {
636 sigma_match_ = kNoLabel;
637 return true;
638 } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel &&
639 matcher_->Find(sigma_label_)) {
640 sigma_match_ = match_label;
641 return true;
642 } else {
643 return false;
644 }
645 }
646
Done()647 bool Done() const {
648 return matcher_->Done();
649 }
650
Value()651 const Arc& Value() const {
652 if (sigma_match_ == kNoLabel) {
653 return matcher_->Value();
654 } else {
655 sigma_arc_ = matcher_->Value();
656 if (rewrite_both_) {
657 if (sigma_arc_.ilabel == sigma_label_)
658 sigma_arc_.ilabel = sigma_match_;
659 if (sigma_arc_.olabel == sigma_label_)
660 sigma_arc_.olabel = sigma_match_;
661 } else if (match_type_ == MATCH_INPUT) {
662 sigma_arc_.ilabel = sigma_match_;
663 } else {
664 sigma_arc_.olabel = sigma_match_;
665 }
666 return sigma_arc_;
667 }
668 }
669
Next()670 void Next() {
671 matcher_->Next();
672 if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) &&
673 (match_label_ > 0)) {
674 matcher_->Find(sigma_label_);
675 sigma_match_ = match_label_;
676 }
677 }
678
GetFst()679 virtual const FST &GetFst() const { return matcher_->GetFst(); }
680
681 virtual uint64 Properties(uint64 props) const;
682
Flags()683 virtual uint32 Flags() const {
684 if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE)
685 return matcher_->Flags();
686 // kRequireMatch temporarily disabled until issues
687 // in //speech/gaudi/annotation/util/denorm are resolved.
688 // return matcher_->Flags() | kRequireMatch;
689 return matcher_->Flags();
690 }
691
692 private:
SetState_(StateId s)693 virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)694 virtual bool Find_(Label label) { return Find(label); }
Done_()695 virtual bool Done_() const { return Done(); }
Value_()696 virtual const Arc& Value_() const { return Value(); }
Next_()697 virtual void Next_() { Next(); }
698
699 M *matcher_;
700 MatchType match_type_; // Type of match requested
701 Label sigma_label_; // Label that represents the sigma transition
702 bool rewrite_both_; // Rewrite both sides when both are 'sigma_label_'
703 bool has_sigma_; // Are there sigmas at the current state?
704 Label sigma_match_; // Current label that matches sigma transition
705 mutable Arc sigma_arc_; // Arc to return when sigma match
706 Label match_label_; // Label being matched
707 bool error_; // Error encountered
708
709 void operator=(const SigmaMatcher<M> &); // disallow
710 };
711
712 template <class M> inline
Properties(uint64 inprops)713 uint64 SigmaMatcher<M>::Properties(uint64 inprops) const {
714 uint64 outprops = matcher_->Properties(inprops);
715 if (error_) outprops |= kError;
716
717 if (match_type_ == MATCH_NONE) {
718 return outprops;
719 } else if (rewrite_both_) {
720 return outprops & ~(kIDeterministic | kNonIDeterministic |
721 kODeterministic | kNonODeterministic |
722 kILabelSorted | kNotILabelSorted |
723 kOLabelSorted | kNotOLabelSorted |
724 kString);
725 } else if (match_type_ == MATCH_INPUT) {
726 return outprops & ~(kIDeterministic | kNonIDeterministic |
727 kODeterministic | kNonODeterministic |
728 kILabelSorted | kNotILabelSorted |
729 kString | kAcceptor);
730 } else if (match_type_ == MATCH_OUTPUT) {
731 return outprops & ~(kIDeterministic | kNonIDeterministic |
732 kODeterministic | kNonODeterministic |
733 kOLabelSorted | kNotOLabelSorted |
734 kString | kAcceptor);
735 } else {
736 // Shouldn't ever get here.
737 FSTERROR() << "SigmaMatcher:: bad match type: " << match_type_;
738 return 0;
739 }
740 }
741
742
743 // For any requested label that doesn't match at a state, this matcher
744 // considers the *unique* transition that matches the label 'phi_label'
745 // (phi = 'fail'), and recursively looks for a match at its
746 // destination. When 'phi_loop' is true, if no match is found but a
747 // phi self-loop is found, then the phi transition found is returned
748 // with the phi_label rewritten as the requested label (both sides if
749 // an acceptor, or if 'rewrite_both' is true and both input and output
750 // labels of the found transition are 'phi_label'). If 'phi_label' is
751 // kNoLabel, this special matching is not done. PhiMatcher is
752 // templated itself on a matcher, which is used to perform the
753 // underlying matching. By default, the underlying matcher is
754 // constructed by PhiMatcher. The user can instead pass in this
755 // object; in that case, PhiMatcher takes its ownership.
756 // Warning: phi non-determinism not supported (for simplicity).
757 template <class M>
758 class PhiMatcher : public MatcherBase<typename M::Arc> {
759 public:
760 typedef typename M::FST FST;
761 typedef typename M::Arc Arc;
762 typedef typename Arc::StateId StateId;
763 typedef typename Arc::Label Label;
764 typedef typename Arc::Weight Weight;
765
766 PhiMatcher(const FST &fst,
767 MatchType match_type,
768 Label phi_label = kNoLabel,
769 bool phi_loop = true,
770 MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
771 M *matcher = 0)
772 : matcher_(matcher ? matcher : new M(fst, match_type)),
773 match_type_(match_type),
774 phi_label_(phi_label),
775 state_(kNoStateId),
776 phi_loop_(phi_loop),
777 error_(false) {
778 if (match_type == MATCH_BOTH) {
779 FSTERROR() << "PhiMatcher: bad match type";
780 match_type_ = MATCH_NONE;
781 error_ = true;
782 }
783
784 if (rewrite_mode == MATCHER_REWRITE_AUTO)
785 rewrite_both_ = fst.Properties(kAcceptor, true);
786 else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
787 rewrite_both_ = true;
788 else
789 rewrite_both_ = false;
790 }
791
792 PhiMatcher(const PhiMatcher<M> &matcher, bool safe = false)
793 : matcher_(new M(*matcher.matcher_, safe)),
794 match_type_(matcher.match_type_),
795 phi_label_(matcher.phi_label_),
796 rewrite_both_(matcher.rewrite_both_),
797 state_(kNoStateId),
798 phi_loop_(matcher.phi_loop_),
799 error_(matcher.error_) {}
800
~PhiMatcher()801 virtual ~PhiMatcher() {
802 delete matcher_;
803 }
804
805 virtual PhiMatcher<M> *Copy(bool safe = false) const {
806 return new PhiMatcher<M>(*this, safe);
807 }
808
Type(bool test)809 virtual MatchType Type(bool test) const { return matcher_->Type(test); }
810
SetState(StateId s)811 void SetState(StateId s) {
812 matcher_->SetState(s);
813 state_ = s;
814 has_phi_ = phi_label_ != kNoLabel;
815 }
816
817 bool Find(Label match_label);
818
Done()819 bool Done() const { return matcher_->Done(); }
820
Value()821 const Arc& Value() const {
822 if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) {
823 return matcher_->Value();
824 } else if (phi_match_ == 0) { // Virtual epsilon loop
825 phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_);
826 if (match_type_ == MATCH_OUTPUT)
827 swap(phi_arc_.ilabel, phi_arc_.olabel);
828 return phi_arc_;
829 } else {
830 phi_arc_ = matcher_->Value();
831 phi_arc_.weight = Times(phi_weight_, phi_arc_.weight);
832 if (phi_match_ != kNoLabel) { // Phi loop match
833 if (rewrite_both_) {
834 if (phi_arc_.ilabel == phi_label_)
835 phi_arc_.ilabel = phi_match_;
836 if (phi_arc_.olabel == phi_label_)
837 phi_arc_.olabel = phi_match_;
838 } else if (match_type_ == MATCH_INPUT) {
839 phi_arc_.ilabel = phi_match_;
840 } else {
841 phi_arc_.olabel = phi_match_;
842 }
843 }
844 return phi_arc_;
845 }
846 }
847
Next()848 void Next() { matcher_->Next(); }
849
GetFst()850 virtual const FST &GetFst() const { return matcher_->GetFst(); }
851
852 virtual uint64 Properties(uint64 props) const;
853
Flags()854 virtual uint32 Flags() const {
855 if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE)
856 return matcher_->Flags();
857 return matcher_->Flags() | kRequireMatch;
858 }
859
860 private:
SetState_(StateId s)861 virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)862 virtual bool Find_(Label label) { return Find(label); }
Done_()863 virtual bool Done_() const { return Done(); }
Value_()864 virtual const Arc& Value_() const { return Value(); }
Next_()865 virtual void Next_() { Next(); }
866
867 M *matcher_;
868 MatchType match_type_; // Type of match requested
869 Label phi_label_; // Label that represents the phi transition
870 bool rewrite_both_; // Rewrite both sides when both are 'phi_label_'
871 bool has_phi_; // Are there possibly phis at the current state?
872 Label phi_match_; // Current label that matches phi loop
873 mutable Arc phi_arc_; // Arc to return
874 StateId state_; // State where looking for matches
875 Weight phi_weight_; // Product of the weights of phi transitions taken
876 bool phi_loop_; // When true, phi self-loop are allowed and treated
877 // as rho (required for Aho-Corasick)
878 bool error_; // Error encountered
879
880 void operator=(const PhiMatcher<M> &); // disallow
881 };
882
883 template <class M> inline
Find(Label match_label)884 bool PhiMatcher<M>::Find(Label match_label) {
885 if (match_label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) {
886 FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_;
887 error_ = true;
888 return false;
889 }
890 matcher_->SetState(state_);
891 phi_match_ = kNoLabel;
892 phi_weight_ = Weight::One();
893 if (phi_label_ == 0) { // When 'phi_label_ == 0',
894 if (match_label == kNoLabel) // there are no more true epsilon arcs,
895 return false;
896 if (match_label == 0) { // but virtual eps loop need to be returned
897 if (!matcher_->Find(kNoLabel)) {
898 return matcher_->Find(0);
899 } else {
900 phi_match_ = 0;
901 return true;
902 }
903 }
904 }
905 if (!has_phi_ || match_label == 0 || match_label == kNoLabel)
906 return matcher_->Find(match_label);
907 StateId state = state_;
908 while (!matcher_->Find(match_label)) {
909 // Look for phi transition (if phi_label_ == 0, we need to look
910 // for -1 to avoid getting the virtual self-loop)
911 if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_))
912 return false;
913 if (phi_loop_ && matcher_->Value().nextstate == state) {
914 phi_match_ = match_label;
915 return true;
916 }
917 phi_weight_ = Times(phi_weight_, matcher_->Value().weight);
918 state = matcher_->Value().nextstate;
919 matcher_->Next();
920 if (!matcher_->Done()) {
921 FSTERROR() << "PhiMatcher: phi non-determinism not supported";
922 error_ = true;
923 }
924 matcher_->SetState(state);
925 }
926 return true;
927 }
928
929 template <class M> inline
Properties(uint64 inprops)930 uint64 PhiMatcher<M>::Properties(uint64 inprops) const {
931 uint64 outprops = matcher_->Properties(inprops);
932 if (error_) outprops |= kError;
933
934 if (match_type_ == MATCH_NONE) {
935 return outprops;
936 } else if (match_type_ == MATCH_INPUT) {
937 if (phi_label_ == 0) {
938 outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
939 outprops |= kNoEpsilons | kNoIEpsilons;
940 }
941 if (rewrite_both_) {
942 return outprops & ~(kODeterministic | kNonODeterministic | kString |
943 kILabelSorted | kNotILabelSorted |
944 kOLabelSorted | kNotOLabelSorted);
945 } else {
946 return outprops & ~(kODeterministic | kAcceptor | kString |
947 kILabelSorted | kNotILabelSorted |
948 kOLabelSorted | kNotOLabelSorted);
949 }
950 } else if (match_type_ == MATCH_OUTPUT) {
951 if (phi_label_ == 0) {
952 outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
953 outprops |= kNoEpsilons | kNoOEpsilons;
954 }
955 if (rewrite_both_) {
956 return outprops & ~(kIDeterministic | kNonIDeterministic | kString |
957 kILabelSorted | kNotILabelSorted |
958 kOLabelSorted | kNotOLabelSorted);
959 } else {
960 return outprops & ~(kIDeterministic | kAcceptor | kString |
961 kILabelSorted | kNotILabelSorted |
962 kOLabelSorted | kNotOLabelSorted);
963 }
964 } else {
965 // Shouldn't ever get here.
966 FSTERROR() << "PhiMatcher:: bad match type: " << match_type_;
967 return 0;
968 }
969 }
970
971
972 //
973 // MULTI-EPS MATCHER FLAGS
974 //
975
976 // Return multi-epsilon arcs for Find(kNoLabel).
977 const uint32 kMultiEpsList = 0x00000001;
978
979 // Return a kNolabel loop for Find(multi_eps).
980 const uint32 kMultiEpsLoop = 0x00000002;
981
982 // MultiEpsMatcher: allows treating multiple non-0 labels as
983 // non-consuming labels in addition to 0 that is always
984 // non-consuming. Precise behavior controlled by 'flags' argument. By
985 // default, the underlying matcher is constructed by
986 // MultiEpsMatcher. The user can instead pass in this object; in that
987 // case, MultiEpsMatcher takes its ownership iff 'own_matcher' is
988 // true.
989 template <class M>
990 class MultiEpsMatcher {
991 public:
992 typedef typename M::FST FST;
993 typedef typename M::Arc Arc;
994 typedef typename Arc::StateId StateId;
995 typedef typename Arc::Label Label;
996 typedef typename Arc::Weight Weight;
997
998 MultiEpsMatcher(const FST &fst, MatchType match_type,
999 uint32 flags = (kMultiEpsLoop | kMultiEpsList),
1000 M *matcher = 0, bool own_matcher = true)
1001 : matcher_(matcher ? matcher : new M(fst, match_type)),
1002 flags_(flags),
1003 own_matcher_(matcher ? own_matcher : true) {
1004 if (match_type == MATCH_INPUT) {
1005 loop_.ilabel = kNoLabel;
1006 loop_.olabel = 0;
1007 } else {
1008 loop_.ilabel = 0;
1009 loop_.olabel = kNoLabel;
1010 }
1011 loop_.weight = Weight::One();
1012 loop_.nextstate = kNoStateId;
1013 }
1014
1015 MultiEpsMatcher(const MultiEpsMatcher<M> &matcher, bool safe = false)
1016 : matcher_(new M(*matcher.matcher_, safe)),
1017 flags_(matcher.flags_),
1018 own_matcher_(true),
1019 multi_eps_labels_(matcher.multi_eps_labels_),
1020 loop_(matcher.loop_) {
1021 loop_.nextstate = kNoStateId;
1022 }
1023
~MultiEpsMatcher()1024 ~MultiEpsMatcher() {
1025 if (own_matcher_)
1026 delete matcher_;
1027 }
1028
1029 MultiEpsMatcher<M> *Copy(bool safe = false) const {
1030 return new MultiEpsMatcher<M>(*this, safe);
1031 }
1032
Type(bool test)1033 MatchType Type(bool test) const { return matcher_->Type(test); }
1034
SetState(StateId s)1035 void SetState(StateId s) {
1036 matcher_->SetState(s);
1037 loop_.nextstate = s;
1038 }
1039
1040 bool Find(Label match_label);
1041
Done()1042 bool Done() const {
1043 return done_;
1044 }
1045
Value()1046 const Arc& Value() const {
1047 return current_loop_ ? loop_ : matcher_->Value();
1048 }
1049
Next()1050 void Next() {
1051 if (!current_loop_) {
1052 matcher_->Next();
1053 done_ = matcher_->Done();
1054 if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) {
1055 ++multi_eps_iter_;
1056 while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1057 !matcher_->Find(*multi_eps_iter_))
1058 ++multi_eps_iter_;
1059 if (multi_eps_iter_ != multi_eps_labels_.End())
1060 done_ = false;
1061 else
1062 done_ = !matcher_->Find(kNoLabel);
1063
1064 }
1065 } else {
1066 done_ = true;
1067 }
1068 }
1069
GetFst()1070 const FST &GetFst() const { return matcher_->GetFst(); }
1071
Properties(uint64 props)1072 uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
1073
Flags()1074 uint32 Flags() const { return matcher_->Flags(); }
1075
AddMultiEpsLabel(Label label)1076 void AddMultiEpsLabel(Label label) {
1077 if (label == 0) {
1078 FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1079 } else {
1080 multi_eps_labels_.Insert(label);
1081 }
1082 }
1083
RemoveMultiEpsLabel(Label label)1084 void RemoveMultiEpsLabel(Label label) {
1085 if (label == 0) {
1086 FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1087 } else {
1088 multi_eps_labels_.Erase(label);
1089 }
1090 }
1091
ClearMultiEpsLabels()1092 void ClearMultiEpsLabels() {
1093 multi_eps_labels_.Clear();
1094 }
1095
1096 private:
1097 M *matcher_;
1098 uint32 flags_;
1099 bool own_matcher_; // Does this class delete the matcher?
1100
1101 // Multi-eps label set
1102 CompactSet<Label, kNoLabel> multi_eps_labels_;
1103 typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_;
1104
1105 bool current_loop_; // Current arc is the implicit loop
1106 mutable Arc loop_; // For non-consuming symbols
1107 bool done_; // Matching done
1108
1109 void operator=(const MultiEpsMatcher<M> &); // Disallow
1110 };
1111
1112 template <class M> inline
Find(Label match_label)1113 bool MultiEpsMatcher<M>::Find(Label match_label) {
1114 multi_eps_iter_ = multi_eps_labels_.End();
1115 current_loop_ = false;
1116 bool ret;
1117 if (match_label == 0) {
1118 ret = matcher_->Find(0);
1119 } else if (match_label == kNoLabel) {
1120 if (flags_ & kMultiEpsList) {
1121 // return all non-consuming arcs (incl. epsilon)
1122 multi_eps_iter_ = multi_eps_labels_.Begin();
1123 while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1124 !matcher_->Find(*multi_eps_iter_))
1125 ++multi_eps_iter_;
1126 if (multi_eps_iter_ != multi_eps_labels_.End())
1127 ret = true;
1128 else
1129 ret = matcher_->Find(kNoLabel);
1130 } else {
1131 // return all epsilon arcs
1132 ret = matcher_->Find(kNoLabel);
1133 }
1134 } else if ((flags_ & kMultiEpsLoop) &&
1135 multi_eps_labels_.Find(match_label) != multi_eps_labels_.End()) {
1136 // return 'implicit' loop
1137 current_loop_ = true;
1138 ret = true;
1139 } else {
1140 ret = matcher_->Find(match_label);
1141 }
1142 done_ = !ret;
1143 return ret;
1144 }
1145
1146
1147 // Generic matcher, templated on the FST definition
1148 // - a wrapper around pointer to specific one.
1149 // Here is a typical use: \code
1150 // Matcher<StdFst> matcher(fst, MATCH_INPUT);
1151 // matcher.SetState(state);
1152 // if (matcher.Find(label))
1153 // for (; !matcher.Done(); matcher.Next()) {
1154 // StdArc &arc = matcher.Value();
1155 // ...
1156 // } \endcode
1157 template <class F>
1158 class Matcher {
1159 public:
1160 typedef F FST;
1161 typedef typename F::Arc Arc;
1162 typedef typename Arc::StateId StateId;
1163 typedef typename Arc::Label Label;
1164 typedef typename Arc::Weight Weight;
1165
Matcher(const F & fst,MatchType match_type)1166 Matcher(const F &fst, MatchType match_type) {
1167 base_ = fst.InitMatcher(match_type);
1168 if (!base_)
1169 base_ = new SortedMatcher<F>(fst, match_type);
1170 }
1171
1172 Matcher(const Matcher<F> &matcher, bool safe = false) {
1173 base_ = matcher.base_->Copy(safe);
1174 }
1175
1176 // Takes ownership of the provided matcher
Matcher(MatcherBase<Arc> * base_matcher)1177 Matcher(MatcherBase<Arc>* base_matcher) { base_ = base_matcher; }
1178
~Matcher()1179 ~Matcher() { delete base_; }
1180
1181 Matcher<F> *Copy(bool safe = false) const {
1182 return new Matcher<F>(*this, safe);
1183 }
1184
Type(bool test)1185 MatchType Type(bool test) const { return base_->Type(test); }
SetState(StateId s)1186 void SetState(StateId s) { base_->SetState(s); }
Find(Label label)1187 bool Find(Label label) { return base_->Find(label); }
Done()1188 bool Done() const { return base_->Done(); }
Value()1189 const Arc& Value() const { return base_->Value(); }
Next()1190 void Next() { base_->Next(); }
GetFst()1191 const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }
Properties(uint64 props)1192 uint64 Properties(uint64 props) const { return base_->Properties(props); }
Flags()1193 uint32 Flags() const { return base_->Flags() & kMatcherFlags; }
1194
1195 private:
1196 MatcherBase<Arc> *base_;
1197
1198 void operator=(const Matcher<Arc> &); // disallow
1199 };
1200
1201 } // namespace fst
1202
1203
1204
1205 #endif // FST_LIB_MATCHER_H__
1206