1 // lookahead-matcher.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Classes to add lookahead to FST matchers, useful e.g. for improving
20 // composition efficiency with certain inputs.
21
22 #ifndef FST_LIB_LOOKAHEAD_MATCHER_H__
23 #define FST_LIB_LOOKAHEAD_MATCHER_H__
24
25 #include <fst/add-on.h>
26 #include <fst/const-fst.h>
27 #include <fst/fst.h>
28 #include <fst/label-reachable.h>
29 #include <fst/matcher.h>
30
31
32 DECLARE_string(save_relabel_ipairs);
33 DECLARE_string(save_relabel_opairs);
34
35 namespace fst {
36
37 // LOOKAHEAD MATCHERS - these have the interface of Matchers (see
38 // matcher.h) and these additional methods:
39 //
40 // template <class F>
41 // class LookAheadMatcher {
42 // public:
43 // typedef F FST;
44 // typedef F::Arc Arc;
45 // typedef typename Arc::StateId StateId;
46 // typedef typename Arc::Label Label;
47 // typedef typename Arc::Weight Weight;
48 //
49 // // Required constructors.
50 // LookAheadMatcher(const F &fst, MatchType match_type);
//   // If safe=true, the copy is thread-safe (except that the lookahead FST is
//   // preserved). See Fst<>::Copy() for further documentation.
53 // LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
54 //
55 // Below are methods for looking ahead for a match to a label and
56 // more generally, to a rational set. Each returns false if there is
57 // definitely not a match and returns true if there possibly is a
58 // match.
59
60 // // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state
61 // // after possibly following epsilon transitions?
62 // bool LookAheadLabel(Label label) const;
63 //
64 // // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an
65 // // arbitrary rational set of strings, specified by an FST and a state
66 // // from which to begin the matching. If the lookahead FST is a
67 // // transducer, this looks on the side different from the matcher
68 // // 'match_type' (cf. composition).
69 //
70 // // Are there paths P from 's' in the lookahead FST that can be read from
71 // // the cur. matcher state?
72 // bool LookAheadFst(const Fst<Arc>& fst, StateId s);
73 //
74 // // Gives an estimate of the combined weight of the paths P in the
75 // // lookahead and matcher FSTs for the last call to LookAheadFst.
76 // // A trivial implementation returns Weight::One(). Non-trivial
77 // // implementations are useful for weight-pushing in composition.
78 // Weight LookAheadWeight() const;
79 //
//   // Is there a single non-epsilon arc found in the lookahead FST
//   // that begins P (after possibly following any epsilons) in the last
//   // call to LookAheadFst? If so, return true and copy it to '*arc', o.w.
83 // // return false. A trivial implementation returns false. Non-trivial
84 // // implementations are useful for label-pushing in composition.
85 // bool LookAheadPrefix(Arc *arc);
86 //
87 // // Optionally pre-specifies the lookahead FST that will be passed
88 // // to LookAheadFst() for possible precomputation. If copy is true,
89 // // then 'fst' is a copy of the FST used in the previous call to
90 // // this method (useful to avoid unnecessary updates).
91 // void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false);
92 //
93 // };
94
95 //
96 // LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h):
97 //
98 // Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT.
99 const uint32 kInputLookAheadMatcher = 0x00000010;
100
101 // Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT.
102 const uint32 kOutputLookAheadMatcher = 0x00000020;
103
104 // A non-trivial implementation of LookAheadWeight() method defined and
105 // should be used?
106 const uint32 kLookAheadWeight = 0x00000040;
107
108 // A non-trivial implementation of LookAheadPrefix() method defined and
109 // should be used?
110 const uint32 kLookAheadPrefix = 0x00000080;
111
112 // Look-ahead of matcher FST non-epsilon arcs?
113 const uint32 kLookAheadNonEpsilons = 0x00000100;
114
115 // Look-ahead of matcher FST epsilon arcs?
116 const uint32 kLookAheadEpsilons = 0x00000200;
117
118 // Ignore epsilon paths for the lookahead prefix? Note this gives
119 // correct results in composition only with an appropriate composition
120 // filter since it depends on the filter blocking the ignored paths.
121 const uint32 kLookAheadNonEpsilonPrefix = 0x00000400;
122
123 // For LabelLookAheadMatcher, save relabeling data to file
124 const uint32 kLookAheadKeepRelabelData = 0x00000800;
125
126 // Flags used for lookahead matchers.
127 const uint32 kLookAheadFlags = 0x00000ff0;
128
129 // LookAhead Matcher interface, templated on the Arc definition; used
130 // for lookahead matcher specializations that are returned by the
131 // InitMatcher() Fst method.
132 template <class A>
133 class LookAheadMatcherBase : public MatcherBase<A> {
134 public:
135 typedef A Arc;
136 typedef typename A::StateId StateId;
137 typedef typename A::Label Label;
138 typedef typename A::Weight Weight;
139
LookAheadMatcherBase()140 LookAheadMatcherBase()
141 : weight_(Weight::One()),
142 prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {}
143
~LookAheadMatcherBase()144 virtual ~LookAheadMatcherBase() {}
145
LookAheadLabel(Label label)146 bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); }
147
LookAheadFst(const Fst<Arc> & fst,StateId s)148 bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
149 return LookAheadFst_(fst, s);
150 }
151
LookAheadWeight()152 Weight LookAheadWeight() const { return weight_; }
153
LookAheadPrefix(Arc * arc)154 bool LookAheadPrefix(Arc *arc) const {
155 if (prefix_arc_.nextstate != kNoStateId) {
156 *arc = prefix_arc_;
157 return true;
158 } else {
159 return false;
160 }
161 }
162
163 virtual void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) = 0;
164
165 protected:
SetLookAheadWeight(const Weight & w)166 void SetLookAheadWeight(const Weight &w) { weight_ = w; }
167
SetLookAheadPrefix(const Arc & arc)168 void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; }
169
ClearLookAheadPrefix()170 void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }
171
172 private:
173 virtual bool LookAheadLabel_(Label label) const = 0;
174 virtual bool LookAheadFst_(const Fst<Arc> &fst,
175 StateId s) = 0; // This must set l.a. weight and
176 // prefix if non-trivial.
177 Weight weight_; // Look-ahead weight
178 Arc prefix_arc_; // Look-ahead prefix arc
179 };
180
181
182 // Don't really lookahead, just declare future looks good regardless.
183 template <class M>
184 class TrivialLookAheadMatcher
185 : public LookAheadMatcherBase<typename M::FST::Arc> {
186 public:
187 typedef typename M::FST FST;
188 typedef typename M::Arc Arc;
189 typedef typename Arc::StateId StateId;
190 typedef typename Arc::Label Label;
191 typedef typename Arc::Weight Weight;
192
TrivialLookAheadMatcher(const FST & fst,MatchType match_type)193 TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
194 : matcher_(fst, match_type) {}
195
196 TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher,
197 bool safe = false)
198 : matcher_(lmatcher.matcher_, safe) {}
199
200 // General matcher methods
201 TrivialLookAheadMatcher<M> *Copy(bool safe = false) const {
202 return new TrivialLookAheadMatcher<M>(*this, safe);
203 }
204
Type(bool test)205 MatchType Type(bool test) const { return matcher_.Type(test); }
SetState(StateId s)206 void SetState(StateId s) { return matcher_.SetState(s); }
Find(Label label)207 bool Find(Label label) { return matcher_.Find(label); }
Done()208 bool Done() const { return matcher_.Done(); }
Value()209 const Arc& Value() const { return matcher_.Value(); }
Next()210 void Next() { matcher_.Next(); }
GetFst()211 virtual const FST &GetFst() const { return matcher_.GetFst(); }
Properties(uint64 props)212 uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
Flags()213 uint32 Flags() const {
214 return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
215 }
216
217 // Look-ahead methods.
LookAheadLabel(Label label)218 bool LookAheadLabel(Label label) const { return true; }
LookAheadFst(const Fst<Arc> & fst,StateId s)219 bool LookAheadFst(const Fst<Arc> &fst, StateId s) {return true; }
LookAheadWeight()220 Weight LookAheadWeight() const { return Weight::One(); }
LookAheadPrefix(Arc * arc)221 bool LookAheadPrefix(Arc *arc) const { return false; }
222 void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {}
223
224 private:
225 // This allows base class virtual access to non-virtual derived-
226 // class members of the same name. It makes the derived class more
227 // efficient to use but unsafe to further derive.
SetState_(StateId s)228 virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)229 virtual bool Find_(Label label) { return Find(label); }
Done_()230 virtual bool Done_() const { return Done(); }
Value_()231 virtual const Arc& Value_() const { return Value(); }
Next_()232 virtual void Next_() { Next(); }
233
LookAheadLabel_(Label l)234 bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
235
LookAheadFst_(const Fst<Arc> & fst,StateId s)236 bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
237 return LookAheadFst(fst, s);
238 }
239
LookAheadWeight_()240 Weight LookAheadWeight_() const { return LookAheadWeight(); }
LookAheadPrefix_(Arc * arc)241 bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); }
242
243 M matcher_;
244 };
245
246 // Look-ahead of one transition. Template argument F accepts flags to
247 // control behavior.
248 template <class M, uint32 F = kLookAheadNonEpsilons | kLookAheadEpsilons |
249 kLookAheadWeight | kLookAheadPrefix>
250 class ArcLookAheadMatcher
251 : public LookAheadMatcherBase<typename M::FST::Arc> {
252 public:
253 typedef typename M::FST FST;
254 typedef typename M::Arc Arc;
255 typedef typename Arc::StateId StateId;
256 typedef typename Arc::Label Label;
257 typedef typename Arc::Weight Weight;
258 typedef NullAddOn MatcherData;
259
260 using LookAheadMatcherBase<Arc>::LookAheadWeight;
261 using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
262 using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
263 using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
264
265 ArcLookAheadMatcher(const FST &fst, MatchType match_type,
266 MatcherData *data = 0)
matcher_(fst,match_type)267 : matcher_(fst, match_type),
268 fst_(matcher_.GetFst()),
269 lfst_(0),
270 s_(kNoStateId) {}
271
272 ArcLookAheadMatcher(const ArcLookAheadMatcher<M, F> &lmatcher,
273 bool safe = false)
274 : matcher_(lmatcher.matcher_, safe),
275 fst_(matcher_.GetFst()),
276 lfst_(lmatcher.lfst_),
277 s_(kNoStateId) {}
278
279 // General matcher methods
280 ArcLookAheadMatcher<M, F> *Copy(bool safe = false) const {
281 return new ArcLookAheadMatcher<M, F>(*this, safe);
282 }
283
Type(bool test)284 MatchType Type(bool test) const { return matcher_.Type(test); }
285
SetState(StateId s)286 void SetState(StateId s) {
287 s_ = s;
288 matcher_.SetState(s);
289 }
290
Find(Label label)291 bool Find(Label label) { return matcher_.Find(label); }
Done()292 bool Done() const { return matcher_.Done(); }
Value()293 const Arc& Value() const { return matcher_.Value(); }
Next()294 void Next() { matcher_.Next(); }
GetFst()295 const FST &GetFst() const { return fst_; }
Properties(uint64 props)296 uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
Flags()297 uint32 Flags() const {
298 return matcher_.Flags() | kInputLookAheadMatcher |
299 kOutputLookAheadMatcher | F;
300 }
301
302 // Writable matcher methods
GetData()303 MatcherData *GetData() const { return 0; }
304
305 // Look-ahead methods.
LookAheadLabel(Label label)306 bool LookAheadLabel(Label label) const { return matcher_.Find(label); }
307
308 // Checks if there is a matching (possibly super-final) transition
309 // at (s_, s).
310 bool LookAheadFst(const Fst<Arc> &fst, StateId s);
311
312 void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
313 lfst_ = &fst;
314 }
315
316 private:
317 // This allows base class virtual access to non-virtual derived-
318 // class members of the same name. It makes the derived class more
319 // efficient to use but unsafe to further derive.
SetState_(StateId s)320 virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)321 virtual bool Find_(Label label) { return Find(label); }
Done_()322 virtual bool Done_() const { return Done(); }
Value_()323 virtual const Arc& Value_() const { return Value(); }
Next_()324 virtual void Next_() { Next(); }
325
LookAheadLabel_(Label l)326 bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
LookAheadFst_(const Fst<Arc> & fst,StateId s)327 bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
328 return LookAheadFst(fst, s);
329 }
330
331 mutable M matcher_;
332 const FST &fst_; // Matcher FST
333 const Fst<Arc> *lfst_; // Look-ahead FST
334 StateId s_; // Matcher state
335 };
336
337 template <class M, uint32 F>
LookAheadFst(const Fst<Arc> & fst,StateId s)338 bool ArcLookAheadMatcher<M, F>::LookAheadFst(const Fst<Arc> &fst, StateId s) {
339 if (&fst != lfst_)
340 InitLookAheadFst(fst);
341
342 bool ret = false;
343 ssize_t nprefix = 0;
344 if (F & kLookAheadWeight)
345 SetLookAheadWeight(Weight::Zero());
346 if (F & kLookAheadPrefix)
347 ClearLookAheadPrefix();
348 if (fst_.Final(s_) != Weight::Zero() &&
349 lfst_->Final(s) != Weight::Zero()) {
350 if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
351 return true;
352 ++nprefix;
353 if (F & kLookAheadWeight)
354 SetLookAheadWeight(Plus(LookAheadWeight(),
355 Times(fst_.Final(s_), lfst_->Final(s))));
356 ret = true;
357 }
358 if (matcher_.Find(kNoLabel)) {
359 if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
360 return true;
361 ++nprefix;
362 if (F & kLookAheadWeight)
363 for (; !matcher_.Done(); matcher_.Next())
364 SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
365 ret = true;
366 }
367 for (ArcIterator< Fst<Arc> > aiter(*lfst_, s);
368 !aiter.Done();
369 aiter.Next()) {
370 const Arc &arc = aiter.Value();
371 Label label = kNoLabel;
372 switch (matcher_.Type(false)) {
373 case MATCH_INPUT:
374 label = arc.olabel;
375 break;
376 case MATCH_OUTPUT:
377 label = arc.ilabel;
378 break;
379 default:
380 FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type";
381 return true;
382 }
383 if (label == 0) {
384 if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
385 return true;
386 if (!(F & kLookAheadNonEpsilonPrefix))
387 ++nprefix;
388 if (F & kLookAheadWeight)
389 SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
390 ret = true;
391 } else if (matcher_.Find(label)) {
392 if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
393 return true;
394 for (; !matcher_.Done(); matcher_.Next()) {
395 ++nprefix;
396 if (F & kLookAheadWeight)
397 SetLookAheadWeight(Plus(LookAheadWeight(),
398 Times(arc.weight,
399 matcher_.Value().weight)));
400 if ((F & kLookAheadPrefix) && nprefix == 1)
401 SetLookAheadPrefix(arc);
402 }
403 ret = true;
404 }
405 }
406 if (F & kLookAheadPrefix) {
407 if (nprefix == 1)
408 SetLookAheadWeight(Weight::One()); // Avoids double counting.
409 else
410 ClearLookAheadPrefix();
411 }
412 return ret;
413 }
414
415
416 // Template argument F accepts flags to control behavior.
417 // It must include precisely one of KInputLookAheadMatcher or
418 // KOutputLookAheadMatcher.
419 template <class M, uint32 F = kLookAheadEpsilons | kLookAheadWeight |
420 kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
421 kLookAheadKeepRelabelData,
422 class S = DefaultAccumulator<typename M::Arc> >
423 class LabelLookAheadMatcher
424 : public LookAheadMatcherBase<typename M::FST::Arc> {
425 public:
426 typedef typename M::FST FST;
427 typedef typename M::Arc Arc;
428 typedef typename Arc::StateId StateId;
429 typedef typename Arc::Label Label;
430 typedef typename Arc::Weight Weight;
431 typedef LabelReachableData<Label> MatcherData;
432
433 using LookAheadMatcherBase<Arc>::LookAheadWeight;
434 using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
435 using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
436 using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
437
438 LabelLookAheadMatcher(const FST &fst, MatchType match_type,
439 MatcherData *data = 0, S *s = 0)
matcher_(fst,match_type)440 : matcher_(fst, match_type),
441 lfst_(0),
442 label_reachable_(0),
443 s_(kNoStateId),
444 error_(false) {
445 if (!(F & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) {
446 FSTERROR() << "LabelLookaheadMatcher: bad matcher flags: " << F;
447 error_ = true;
448 }
449 bool reach_input = match_type == MATCH_INPUT;
450 if (data) {
451 if (reach_input == data->ReachInput())
452 label_reachable_ = new LabelReachable<Arc, S>(data, s);
453 } else if ((reach_input && (F & kInputLookAheadMatcher)) ||
454 (!reach_input && (F & kOutputLookAheadMatcher))) {
455 label_reachable_ = new LabelReachable<Arc, S>(
456 fst, reach_input, s, F & kLookAheadKeepRelabelData);
457 }
458 }
459
460 LabelLookAheadMatcher(const LabelLookAheadMatcher<M, F, S> &lmatcher,
461 bool safe = false)
462 : matcher_(lmatcher.matcher_, safe),
463 lfst_(lmatcher.lfst_),
464 label_reachable_(
465 lmatcher.label_reachable_ ?
466 new LabelReachable<Arc, S>(*lmatcher.label_reachable_) : 0),
467 s_(kNoStateId),
468 error_(lmatcher.error_) {}
469
~LabelLookAheadMatcher()470 ~LabelLookAheadMatcher() {
471 delete label_reachable_;
472 }
473
474 // General matcher methods
475 LabelLookAheadMatcher<M, F, S> *Copy(bool safe = false) const {
476 return new LabelLookAheadMatcher<M, F, S>(*this, safe);
477 }
478
Type(bool test)479 MatchType Type(bool test) const { return matcher_.Type(test); }
480
SetState(StateId s)481 void SetState(StateId s) {
482 if (s_ == s)
483 return;
484 s_ = s;
485 match_set_state_ = false;
486 reach_set_state_ = false;
487 }
488
Find(Label label)489 bool Find(Label label) {
490 if (!match_set_state_) {
491 matcher_.SetState(s_);
492 match_set_state_ = true;
493 }
494 return matcher_.Find(label);
495 }
496
Done()497 bool Done() const { return matcher_.Done(); }
Value()498 const Arc& Value() const { return matcher_.Value(); }
Next()499 void Next() { matcher_.Next(); }
GetFst()500 const FST &GetFst() const { return matcher_.GetFst(); }
501
Properties(uint64 inprops)502 uint64 Properties(uint64 inprops) const {
503 uint64 outprops = matcher_.Properties(inprops);
504 if (error_ || (label_reachable_ && label_reachable_->Error()))
505 outprops |= kError;
506 return outprops;
507 }
508
Flags()509 uint32 Flags() const {
510 if (label_reachable_ && label_reachable_->GetData()->ReachInput())
511 return matcher_.Flags() | F | kInputLookAheadMatcher;
512 else if (label_reachable_ && !label_reachable_->GetData()->ReachInput())
513 return matcher_.Flags() | F | kOutputLookAheadMatcher;
514 else
515 return matcher_.Flags();
516 }
517
518 // Writable matcher methods
GetData()519 MatcherData *GetData() const {
520 return label_reachable_ ? label_reachable_->GetData() : 0;
521 };
522
523 // Look-ahead methods.
LookAheadLabel(Label label)524 bool LookAheadLabel(Label label) const {
525 if (label == 0)
526 return true;
527
528 if (label_reachable_) {
529 if (!reach_set_state_) {
530 label_reachable_->SetState(s_);
531 reach_set_state_ = true;
532 }
533 return label_reachable_->Reach(label);
534 } else {
535 return true;
536 }
537 }
538
539 // Checks if there is a matching (possibly super-final) transition
540 // at (s_, s).
541 template <class L>
542 bool LookAheadFst(const L &fst, StateId s);
543
544 void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
545 lfst_ = &fst;
546 if (label_reachable_)
547 label_reachable_->ReachInit(fst, copy);
548 }
549
550 template <class L>
551 void InitLookAheadFst(const L& fst, bool copy = false) {
552 lfst_ = static_cast<const Fst<Arc> *>(&fst);
553 if (label_reachable_)
554 label_reachable_->ReachInit(fst, copy);
555 }
556
557 private:
558 // This allows base class virtual access to non-virtual derived-
559 // class members of the same name. It makes the derived class more
560 // efficient to use but unsafe to further derive.
SetState_(StateId s)561 virtual void SetState_(StateId s) { SetState(s); }
Find_(Label label)562 virtual bool Find_(Label label) { return Find(label); }
Done_()563 virtual bool Done_() const { return Done(); }
Value_()564 virtual const Arc& Value_() const { return Value(); }
Next_()565 virtual void Next_() { Next(); }
566
LookAheadLabel_(Label l)567 bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
LookAheadFst_(const Fst<Arc> & fst,StateId s)568 bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
569 return LookAheadFst(fst, s);
570 }
571
572 mutable M matcher_;
573 const Fst<Arc> *lfst_; // Look-ahead FST
574 LabelReachable<Arc, S> *label_reachable_; // Label reachability info
575 StateId s_; // Matcher state
576 bool match_set_state_; // matcher_.SetState called?
577 mutable bool reach_set_state_; // reachable_.SetState called?
578 bool error_;
579 };
580
581 template <class M, uint32 F, class S>
582 template <class L> inline
LookAheadFst(const L & fst,StateId s)583 bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) {
584 if (static_cast<const Fst<Arc> *>(&fst) != lfst_)
585 InitLookAheadFst(fst);
586
587 SetLookAheadWeight(Weight::One());
588 ClearLookAheadPrefix();
589
590 if (!label_reachable_)
591 return true;
592
593 label_reachable_->SetState(s_, s);
594 reach_set_state_ = true;
595
596 bool compute_weight = F & kLookAheadWeight;
597 bool compute_prefix = F & kLookAheadPrefix;
598
599 bool reach_input = Type(false) == MATCH_OUTPUT;
600 ArcIterator<L> aiter(fst, s);
601 bool reach_arc = label_reachable_->Reach(&aiter, 0,
602 internal::NumArcs(*lfst_, s),
603 reach_input, compute_weight);
604 Weight lfinal = internal::Final(*lfst_, s);
605 bool reach_final = lfinal != Weight::Zero() && label_reachable_->ReachFinal();
606 if (reach_arc) {
607 ssize_t begin = label_reachable_->ReachBegin();
608 ssize_t end = label_reachable_->ReachEnd();
609 if (compute_prefix && end - begin == 1 && !reach_final) {
610 aiter.Seek(begin);
611 SetLookAheadPrefix(aiter.Value());
612 compute_weight = false;
613 } else if (compute_weight) {
614 SetLookAheadWeight(label_reachable_->ReachWeight());
615 }
616 }
617 if (reach_final && compute_weight)
618 SetLookAheadWeight(reach_arc ?
619 Plus(LookAheadWeight(), lfinal) : lfinal);
620
621 return reach_arc || reach_final;
622 }
623
624
625 // Label-lookahead relabeling class.
626 template <class A>
627 class LabelLookAheadRelabeler {
628 public:
629 typedef typename A::Label Label;
630 typedef LabelReachableData<Label> MatcherData;
631 typedef AddOnPair<MatcherData, MatcherData> D;
632
633 // Relabels matcher Fst - initialization function object.
634 template <typename I>
635 LabelLookAheadRelabeler(I **impl);
636
637 // Relabels arbitrary Fst. Class L should be a label-lookahead Fst.
638 template <class L>
Relabel(MutableFst<A> * fst,const L & mfst,bool relabel_input)639 static void Relabel(MutableFst<A> *fst, const L &mfst,
640 bool relabel_input) {
641 typename L::Impl *impl = mfst.GetImpl();
642 D *data = impl->GetAddOn();
643 LabelReachable<A> reachable(data->First() ?
644 data->First() : data->Second());
645 reachable.Relabel(fst, relabel_input);
646 }
647
648 // Returns relabeling pairs (cf. relabel.h::Relabel()).
649 // Class L should be a label-lookahead Fst.
650 // If 'avoid_collisions' is true, extra pairs are added to
651 // ensure no collisions when relabeling automata that have
652 // labels unseen here.
653 template <class L>
654 static void RelabelPairs(const L &mfst, vector<pair<Label, Label> > *pairs,
655 bool avoid_collisions = false) {
656 typename L::Impl *impl = mfst.GetImpl();
657 D *data = impl->GetAddOn();
658 LabelReachable<A> reachable(data->First() ?
659 data->First() : data->Second());
660 reachable.RelabelPairs(pairs, avoid_collisions);
661 }
662 };
663
664 template <class A>
665 template <typename I> inline
LabelLookAheadRelabeler(I ** impl)666 LabelLookAheadRelabeler<A>::LabelLookAheadRelabeler(I **impl) {
667 Fst<A> &fst = (*impl)->GetFst();
668 D *data = (*impl)->GetAddOn();
669 const string name = (*impl)->Type();
670 bool is_mutable = fst.Properties(kMutable, false);
671 MutableFst<A> *mfst = 0;
672 if (is_mutable) {
673 mfst = static_cast<MutableFst<A> *>(&fst);
674 } else {
675 mfst = new VectorFst<A>(fst);
676 data->IncrRefCount();
677 delete *impl;
678 }
679 if (data->First()) { // reach_input
680 LabelReachable<A> reachable(data->First());
681 reachable.Relabel(mfst, true);
682 if (!FLAGS_save_relabel_ipairs.empty()) {
683 vector<pair<Label, Label> > pairs;
684 reachable.RelabelPairs(&pairs, true);
685 WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs);
686 }
687 } else {
688 LabelReachable<A> reachable(data->Second());
689 reachable.Relabel(mfst, false);
690 if (!FLAGS_save_relabel_opairs.empty()) {
691 vector<pair<Label, Label> > pairs;
692 reachable.RelabelPairs(&pairs, true);
693 WriteLabelPairs(FLAGS_save_relabel_opairs, pairs);
694 }
695 }
696 if (!is_mutable) {
697 *impl = new I(*mfst, name);
698 (*impl)->SetAddOn(data);
699 delete mfst;
700 data->DecrRefCount();
701 }
702 }
703
704
705 // Generic lookahead matcher, templated on the FST definition
706 // - a wrapper around pointer to specific one.
707 template <class F>
708 class LookAheadMatcher {
709 public:
710 typedef F FST;
711 typedef typename F::Arc Arc;
712 typedef typename Arc::StateId StateId;
713 typedef typename Arc::Label Label;
714 typedef typename Arc::Weight Weight;
715 typedef LookAheadMatcherBase<Arc> LBase;
716
LookAheadMatcher(const F & fst,MatchType match_type)717 LookAheadMatcher(const F &fst, MatchType match_type) {
718 base_ = fst.InitMatcher(match_type);
719 if (!base_)
720 base_ = new SortedMatcher<F>(fst, match_type);
721 lookahead_ = false;
722 }
723
724 LookAheadMatcher(const LookAheadMatcher<F> &matcher, bool safe = false) {
725 base_ = matcher.base_->Copy(safe);
726 lookahead_ = matcher.lookahead_;
727 }
728
~LookAheadMatcher()729 ~LookAheadMatcher() { delete base_; }
730
731 // General matcher methods
732 LookAheadMatcher<F> *Copy(bool safe = false) const {
733 return new LookAheadMatcher<F>(*this, safe);
734 }
735
Type(bool test)736 MatchType Type(bool test) const { return base_->Type(test); }
SetState(StateId s)737 void SetState(StateId s) { base_->SetState(s); }
Find(Label label)738 bool Find(Label label) { return base_->Find(label); }
Done()739 bool Done() const { return base_->Done(); }
Value()740 const Arc& Value() const { return base_->Value(); }
Next()741 void Next() { base_->Next(); }
GetFst()742 const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }
743
Properties(uint64 props)744 uint64 Properties(uint64 props) const { return base_->Properties(props); }
745
Flags()746 uint32 Flags() const { return base_->Flags(); }
747
748 // Look-ahead methods
LookAheadLabel(Label label)749 bool LookAheadLabel(Label label) const {
750 if (LookAheadCheck()) {
751 LBase *lbase = static_cast<LBase *>(base_);
752 return lbase->LookAheadLabel(label);
753 } else {
754 return true;
755 }
756 }
757
LookAheadFst(const Fst<Arc> & fst,StateId s)758 bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
759 if (LookAheadCheck()) {
760 LBase *lbase = static_cast<LBase *>(base_);
761 return lbase->LookAheadFst(fst, s);
762 } else {
763 return true;
764 }
765 }
766
LookAheadWeight()767 Weight LookAheadWeight() const {
768 if (LookAheadCheck()) {
769 LBase *lbase = static_cast<LBase *>(base_);
770 return lbase->LookAheadWeight();
771 } else {
772 return Weight::One();
773 }
774 }
775
LookAheadPrefix(Arc * arc)776 bool LookAheadPrefix(Arc *arc) const {
777 if (LookAheadCheck()) {
778 LBase *lbase = static_cast<LBase *>(base_);
779 return lbase->LookAheadPrefix(arc);
780 } else {
781 return false;
782 }
783 }
784
785 void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
786 if (LookAheadCheck()) {
787 LBase *lbase = static_cast<LBase *>(base_);
788 lbase->InitLookAheadFst(fst, copy);
789 }
790 }
791
792 private:
LookAheadCheck()793 bool LookAheadCheck() const {
794 if (!lookahead_) {
795 lookahead_ = base_->Flags() &
796 (kInputLookAheadMatcher | kOutputLookAheadMatcher);
797 if (!lookahead_) {
798 FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
799 }
800 }
801 return lookahead_;
802 }
803
804 MatcherBase<Arc> *base_;
805 mutable bool lookahead_;
806
807 void operator=(const LookAheadMatcher<Arc> &); // disallow
808 };
809
810 } // namespace fst
811
812 #endif // FST_LIB_LOOKAHEAD_MATCHER_H__
813