1 // compose.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Compose a PDT and an FST.
20 
21 #ifndef FST_EXTENSIONS_PDT_COMPOSE_H__
22 #define FST_EXTENSIONS_PDT_COMPOSE_H__
23 
24 #include <list>
25 
26 #include <fst/extensions/pdt/pdt.h>
27 #include <fst/compose.h>
28 
29 namespace fst {
30 
31 // Return paren arcs for Find(kNoLabel).
32 const uint32 kParenList =  0x00000001;
33 
34 // Return a kNolabel loop for Find(paren).
35 const uint32 kParenLoop =  0x00000002;
36 
37 // This class is a matcher that treats parens as multi-epsilon labels.
38 // It is most efficient if the parens are in a range non-overlapping with
39 // the non-paren labels.
40 template <class F>
41 class ParenMatcher {
42  public:
43   typedef SortedMatcher<F> M;
44   typedef typename M::FST FST;
45   typedef typename M::Arc Arc;
46   typedef typename Arc::StateId StateId;
47   typedef typename Arc::Label Label;
48   typedef typename Arc::Weight Weight;
49 
50   ParenMatcher(const FST &fst, MatchType match_type,
51                uint32 flags = (kParenLoop | kParenList))
matcher_(fst,match_type)52       : matcher_(fst, match_type),
53         match_type_(match_type),
54         flags_(flags) {
55     if (match_type == MATCH_INPUT) {
56       loop_.ilabel = kNoLabel;
57       loop_.olabel = 0;
58     } else {
59       loop_.ilabel = 0;
60       loop_.olabel = kNoLabel;
61     }
62     loop_.weight = Weight::One();
63     loop_.nextstate = kNoStateId;
64   }
65 
66   ParenMatcher(const ParenMatcher<F> &matcher, bool safe = false)
67       : matcher_(matcher.matcher_, safe),
68         match_type_(matcher.match_type_),
69         flags_(matcher.flags_),
70         open_parens_(matcher.open_parens_),
71         close_parens_(matcher.close_parens_),
72         loop_(matcher.loop_) {
73     loop_.nextstate = kNoStateId;
74   }
75 
76   ParenMatcher<F> *Copy(bool safe = false) const {
77     return new ParenMatcher<F>(*this, safe);
78   }
79 
Type(bool test)80   MatchType Type(bool test) const { return matcher_.Type(test); }
81 
SetState(StateId s)82   void SetState(StateId s) {
83     matcher_.SetState(s);
84     loop_.nextstate = s;
85   }
86 
87   bool Find(Label match_label);
88 
Done()89   bool Done() const {
90     return done_;
91   }
92 
Value()93   const Arc& Value() const {
94     return paren_loop_ ? loop_ : matcher_.Value();
95   }
96 
97   void Next();
98 
GetFst()99   const FST &GetFst() const { return matcher_.GetFst(); }
100 
Properties(uint64 props)101   uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
102 
Flags()103   uint32 Flags() const { return matcher_.Flags(); }
104 
AddOpenParen(Label label)105   void AddOpenParen(Label label) {
106     if (label == 0) {
107       FSTERROR() << "ParenMatcher: Bad open paren label: 0";
108     } else {
109       open_parens_.Insert(label);
110     }
111   }
112 
AddCloseParen(Label label)113   void AddCloseParen(Label label) {
114     if (label == 0) {
115       FSTERROR() << "ParenMatcher: Bad close paren label: 0";
116     } else {
117       close_parens_.Insert(label);
118     }
119   }
120 
RemoveOpenParen(Label label)121   void RemoveOpenParen(Label label) {
122     if (label == 0) {
123       FSTERROR() << "ParenMatcher: Bad open paren label: 0";
124     } else {
125       open_parens_.Erase(label);
126     }
127   }
128 
RemoveCloseParen(Label label)129   void RemoveCloseParen(Label label) {
130     if (label == 0) {
131       FSTERROR() << "ParenMatcher: Bad close paren label: 0";
132     } else {
133       close_parens_.Erase(label);
134     }
135   }
136 
ClearOpenParens()137   void ClearOpenParens() {
138     open_parens_.Clear();
139   }
140 
ClearCloseParens()141   void ClearCloseParens() {
142     close_parens_.Clear();
143   }
144 
IsOpenParen(Label label)145   bool IsOpenParen(Label label) const {
146     return open_parens_.Member(label);
147   }
148 
IsCloseParen(Label label)149   bool IsCloseParen(Label label) const {
150     return close_parens_.Member(label);
151   }
152 
153  private:
154   // Advances matcher to next open paren if it exists, returning true.
155   // O.w. returns false.
156   bool NextOpenParen();
157 
158   // Advances matcher to next open paren if it exists, returning true.
159   // O.w. returns false.
160   bool NextCloseParen();
161 
162   M matcher_;
163   MatchType match_type_;          // Type of match to perform
164   uint32 flags_;
165 
166   // open paren label set
167   CompactSet<Label, kNoLabel> open_parens_;
168 
169   // close paren label set
170   CompactSet<Label, kNoLabel> close_parens_;
171 
172 
173   bool open_paren_list_;         // Matching open paren list
174   bool close_paren_list_;        // Matching close paren list
175   bool paren_loop_;              // Current arc is the implicit paren loop
176   mutable Arc loop_;             // For non-consuming symbols
177   bool done_;                    // Matching done
178 
179   void operator=(const ParenMatcher<F> &);  // Disallow
180 };
181 
182 template <class M> inline
Find(Label match_label)183 bool ParenMatcher<M>::Find(Label match_label) {
184   open_paren_list_ = false;
185   close_paren_list_ = false;
186   paren_loop_ = false;
187   done_ = false;
188 
189   // Returns all parenthesis arcs
190   if (match_label == kNoLabel && (flags_ & kParenList)) {
191     if (open_parens_.LowerBound() != kNoLabel) {
192       matcher_.LowerBound(open_parens_.LowerBound());
193       open_paren_list_ = NextOpenParen();
194       if (open_paren_list_) return true;
195     }
196     if (close_parens_.LowerBound() != kNoLabel) {
197       matcher_.LowerBound(close_parens_.LowerBound());
198       close_paren_list_ = NextCloseParen();
199       if (close_paren_list_) return true;
200     }
201   }
202 
203   // Returns 'implicit' paren loop
204   if (match_label > 0 && (flags_ & kParenLoop) &&
205       (IsOpenParen(match_label) || IsCloseParen(match_label))) {
206     paren_loop_ = true;
207     return true;
208   }
209 
210   // Returns all other labels
211   if (matcher_.Find(match_label))
212     return true;
213 
214   done_ = true;
215   return false;
216 }
217 
218 template <class F> inline
Next()219 void ParenMatcher<F>::Next() {
220   if (paren_loop_) {
221     paren_loop_ = false;
222     done_ = true;
223   } else if (open_paren_list_) {
224     matcher_.Next();
225     open_paren_list_ = NextOpenParen();
226     if (open_paren_list_) return;
227 
228     if (close_parens_.LowerBound() != kNoLabel) {
229       matcher_.LowerBound(close_parens_.LowerBound());
230       close_paren_list_ = NextCloseParen();
231       if (close_paren_list_) return;
232     }
233     done_ = !matcher_.Find(kNoLabel);
234   } else if (close_paren_list_) {
235     matcher_.Next();
236     close_paren_list_ = NextCloseParen();
237     if (close_paren_list_) return;
238     done_ = !matcher_.Find(kNoLabel);
239   } else {
240     matcher_.Next();
241     done_ = matcher_.Done();
242   }
243 }
244 
245 // Advances matcher to next open paren if it exists, returning true.
246 // O.w. returns false.
247 template <class F> inline
NextOpenParen()248 bool ParenMatcher<F>::NextOpenParen() {
249   for (; !matcher_.Done(); matcher_.Next()) {
250     Label label = match_type_ == MATCH_INPUT ?
251         matcher_.Value().ilabel : matcher_.Value().olabel;
252     if (label > open_parens_.UpperBound())
253       return false;
254     if (IsOpenParen(label))
255       return true;
256   }
257   return false;
258 }
259 
260 // Advances matcher to next close paren if it exists, returning true.
261 // O.w. returns false.
262 template <class F> inline
NextCloseParen()263 bool ParenMatcher<F>::NextCloseParen() {
264   for (; !matcher_.Done(); matcher_.Next()) {
265     Label label = match_type_ == MATCH_INPUT ?
266         matcher_.Value().ilabel : matcher_.Value().olabel;
267     if (label > close_parens_.UpperBound())
268       return false;
269     if (IsCloseParen(label))
270       return true;
271   }
272   return false;
273 }
274 
275 
276 template <class F>
277 class ParenFilter {
278  public:
279   typedef typename F::FST1 FST1;
280   typedef typename F::FST2 FST2;
281   typedef typename F::Arc Arc;
282   typedef typename Arc::StateId StateId;
283   typedef typename Arc::Label Label;
284   typedef typename Arc::Weight Weight;
285   typedef typename F::Matcher1 Matcher1;
286   typedef typename F::Matcher2 Matcher2;
287   typedef typename F::FilterState FilterState1;
288   typedef StateId StackId;
289   typedef PdtStack<StackId, Label> ParenStack;
290   typedef IntegerFilterState<StackId> FilterState2;
291   typedef PairFilterState<FilterState1, FilterState2> FilterState;
292   typedef ParenFilter<F> Filter;
293 
294   ParenFilter(const FST1 &fst1, const FST2 &fst2,
295               Matcher1 *matcher1 = 0,  Matcher2 *matcher2 = 0,
296               const vector<pair<Label, Label> > *parens = 0,
297               bool expand = false, bool keep_parens = true)
filter_(fst1,fst2,matcher1,matcher2)298       : filter_(fst1, fst2, matcher1, matcher2),
299         parens_(parens ? *parens : vector<pair<Label, Label> >()),
300         expand_(expand),
301         keep_parens_(keep_parens),
302         f_(FilterState::NoState()),
303         stack_(parens_),
304         paren_id_(-1) {
305     if (parens) {
306       for (size_t i = 0; i < parens->size(); ++i) {
307         const pair<Label, Label>  &p = (*parens)[i];
308         parens_.push_back(p);
309         GetMatcher1()->AddOpenParen(p.first);
310         GetMatcher2()->AddOpenParen(p.first);
311         if (!expand_) {
312           GetMatcher1()->AddCloseParen(p.second);
313           GetMatcher2()->AddCloseParen(p.second);
314         }
315       }
316     }
317   }
318 
319   ParenFilter(const Filter &filter, bool safe = false)
320       : filter_(filter.filter_, safe),
321         parens_(filter.parens_),
322         expand_(filter.expand_),
323         keep_parens_(filter.keep_parens_),
324         f_(FilterState::NoState()),
325         stack_(filter.parens_),
326         paren_id_(-1) { }
327 
Start()328   FilterState Start() const {
329     return FilterState(filter_.Start(), FilterState2(0));
330   }
331 
SetState(StateId s1,StateId s2,const FilterState & f)332   void SetState(StateId s1, StateId s2, const FilterState &f) {
333     f_ = f;
334     filter_.SetState(s1, s2, f_.GetState1());
335     if (!expand_)
336       return;
337 
338     ssize_t paren_id = stack_.Top(f.GetState2().GetState());
339     if (paren_id != paren_id_) {
340       if (paren_id_ != -1) {
341         GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second);
342         GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second);
343       }
344       paren_id_ = paren_id;
345       if (paren_id_ != -1) {
346         GetMatcher1()->AddCloseParen(parens_[paren_id_].second);
347         GetMatcher2()->AddCloseParen(parens_[paren_id_].second);
348       }
349     }
350   }
351 
FilterArc(Arc * arc1,Arc * arc2)352   FilterState FilterArc(Arc *arc1, Arc *arc2) const {
353     FilterState1 f1 = filter_.FilterArc(arc1, arc2);
354     const FilterState2 &f2 = f_.GetState2();
355     if (f1 == FilterState1::NoState())
356       return FilterState::NoState();
357 
358     if (arc1->olabel == kNoLabel && arc2->ilabel) {         // arc2 parentheses
359       if (keep_parens_) {
360         arc1->ilabel = arc2->ilabel;
361       } else if (arc2->ilabel) {
362         arc2->olabel = arc1->ilabel;
363       }
364       return FilterParen(arc2->ilabel, f1, f2);
365     } else if (arc2->ilabel == kNoLabel && arc1->olabel) {  // arc1 parentheses
366       if (keep_parens_) {
367         arc2->olabel = arc1->olabel;
368       } else {
369         arc1->ilabel = arc2->olabel;
370       }
371       return FilterParen(arc1->olabel, f1, f2);
372     } else {
373       return FilterState(f1, f2);
374     }
375   }
376 
FilterFinal(Weight * w1,Weight * w2)377   void FilterFinal(Weight *w1, Weight *w2) const {
378     if (f_.GetState2().GetState() != 0)
379       *w1 = Weight::Zero();
380     filter_.FilterFinal(w1, w2);
381   }
382 
383   // Return resp matchers. Ownership stays with filter.
GetMatcher1()384   Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); }
GetMatcher2()385   Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); }
386 
Properties(uint64 iprops)387   uint64 Properties(uint64 iprops) const {
388     uint64 oprops = filter_.Properties(iprops);
389     return oprops & kILabelInvariantProperties & kOLabelInvariantProperties;
390   }
391 
392  private:
FilterParen(Label label,const FilterState1 & f1,const FilterState2 & f2)393   const FilterState FilterParen(Label label, const FilterState1 &f1,
394                                 const FilterState2 &f2) const {
395     if (!expand_)
396       return FilterState(f1, f2);
397 
398     StackId stack_id = stack_.Find(f2.GetState(), label);
399     if (stack_id < 0) {
400       return FilterState::NoState();
401     } else {
402       return FilterState(f1, FilterState2(stack_id));
403     }
404   }
405 
406   F filter_;
407   vector<pair<Label, Label> > parens_;
408   bool expand_;                    // Expands to FST
409   bool keep_parens_;               // Retains parentheses in output
410   FilterState f_;                  // Current filter state
411   mutable ParenStack stack_;
412   ssize_t paren_id_;
413 };
414 
415 // Class to setup composition options for PDT composition.
416 // Default is for the PDT as the first composition argument.
417 template <class Arc, bool left_pdt = true>
418 class PdtComposeFstOptions : public
419 ComposeFstOptions<Arc,
420                   ParenMatcher< Fst<Arc> >,
421                   ParenFilter<AltSequenceComposeFilter<
422                                 ParenMatcher< Fst<Arc> > > > > {
423  public:
424   typedef typename Arc::Label Label;
425   typedef ParenMatcher< Fst<Arc> > PdtMatcher;
426   typedef ParenFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
427   typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
428   using COptions::matcher1;
429   using COptions::matcher2;
430   using COptions::filter;
431 
432   PdtComposeFstOptions(const Fst<Arc> &ifst1,
433                     const vector<pair<Label, Label> > &parens,
434                        const Fst<Arc> &ifst2, bool expand = false,
435                        bool keep_parens = true) {
436     matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList);
437     matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop);
438 
439     filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
440                            expand, keep_parens);
441   }
442 };
443 
444 // Class to setup composition options for PDT with FST composition.
445 // Specialization is for the FST as the first composition argument.
446 template <class Arc>
447 class PdtComposeFstOptions<Arc, false> : public
448 ComposeFstOptions<Arc,
449                   ParenMatcher< Fst<Arc> >,
450                   ParenFilter<SequenceComposeFilter<
451                                 ParenMatcher< Fst<Arc> > > > > {
452  public:
453   typedef typename Arc::Label Label;
454   typedef ParenMatcher< Fst<Arc> > PdtMatcher;
455   typedef ParenFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
456   typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
457   using COptions::matcher1;
458   using COptions::matcher2;
459   using COptions::filter;
460 
461   PdtComposeFstOptions(const Fst<Arc> &ifst1,
462                        const Fst<Arc> &ifst2,
463                        const vector<pair<Label, Label> > &parens,
464                        bool expand = false, bool keep_parens = true) {
465     matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop);
466     matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList);
467 
468     filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
469                            expand, keep_parens);
470   }
471 };
472 
// Pre-defined filters available for PDT composition.
enum PdtComposeFilter {
  PAREN_FILTER = 0,         // Bar-Hillel construction; keeps parentheses
  EXPAND_FILTER = 1,        // Bar-Hillel + expansion; removes parentheses
  EXPAND_PAREN_FILTER = 2   // Bar-Hillel + expansion; keeps parentheses
};

// Options for PDT composition.
struct PdtComposeOptions {
  bool connect;                  // If true, the result is Connect()ed
  PdtComposeFilter filter_type;  // Which pre-defined filter to use

  // Defaults: connect the output and use the plain Bar-Hillel filter.
  PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {}

  explicit PdtComposeOptions(bool do_connect,
                             PdtComposeFilter which_filter = PAREN_FILTER)
      : connect(do_connect), filter_type(which_filter) {}
};
487 
488 // Composes pushdown transducer (PDT) encoded as an FST (1st arg) and
489 // an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg).
490 // In the PDTs, some transitions are labeled with open or close
491 // parentheses. To be interpreted as a PDT, the parens must balance on
492 // a path (see PdtExpand()). The open-close parenthesis label pairs
493 // are passed in 'parens'.
494 template <class Arc>
495 void Compose(const Fst<Arc> &ifst1,
496              const vector<pair<typename Arc::Label,
497                                typename Arc::Label> > &parens,
498              const Fst<Arc> &ifst2,
499              MutableFst<Arc> *ofst,
500              const PdtComposeOptions &opts = PdtComposeOptions()) {
501   bool expand = opts.filter_type != PAREN_FILTER;
502   bool keep_parens = opts.filter_type != EXPAND_FILTER;
503   PdtComposeFstOptions<Arc, true> copts(ifst1, parens, ifst2,
504                                         expand, keep_parens);
505   copts.gc_limit = 0;
506   *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
507   if (opts.connect)
508     Connect(ofst);
509 }
510 
511 // Composes an FST (1st arg) and pushdown transducer (PDT) encoded as
512 // an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg).
513 // In the PDTs, some transitions are labeled with open or close
514 // parentheses. To be interpreted as a PDT, the parens must balance on
515 // a path (see ExpandFst()). The open-close parenthesis label pairs
516 // are passed in 'parens'.
517 template <class Arc>
518 void Compose(const Fst<Arc> &ifst1,
519              const Fst<Arc> &ifst2,
520              const vector<pair<typename Arc::Label,
521                                typename Arc::Label> > &parens,
522              MutableFst<Arc> *ofst,
523              const PdtComposeOptions &opts = PdtComposeOptions()) {
524   bool expand = opts.filter_type != PAREN_FILTER;
525   bool keep_parens = opts.filter_type != EXPAND_FILTER;
526   PdtComposeFstOptions<Arc, false> copts(ifst1, ifst2, parens,
527                                          expand, keep_parens);
528   copts.gc_limit = 0;
529   *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
530   if (opts.connect)
531     Connect(ofst);
532 }
533 
534 }  // namespace fst
535 
536 #endif  // FST_EXTENSIONS_PDT_COMPOSE_H__
537