1 // compose.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Compose a PDT and an FST.
20
21 #ifndef FST_EXTENSIONS_PDT_COMPOSE_H__
22 #define FST_EXTENSIONS_PDT_COMPOSE_H__
23
24 #include <list>
25
26 #include <fst/extensions/pdt/pdt.h>
27 #include <fst/compose.h>
28
29 namespace fst {
30
31 // Return paren arcs for Find(kNoLabel).
32 const uint32 kParenList = 0x00000001;
33
34 // Return a kNolabel loop for Find(paren).
35 const uint32 kParenLoop = 0x00000002;
36
37 // This class is a matcher that treats parens as multi-epsilon labels.
38 // It is most efficient if the parens are in a range non-overlapping with
39 // the non-paren labels.
40 template <class F>
41 class ParenMatcher {
42 public:
43 typedef SortedMatcher<F> M;
44 typedef typename M::FST FST;
45 typedef typename M::Arc Arc;
46 typedef typename Arc::StateId StateId;
47 typedef typename Arc::Label Label;
48 typedef typename Arc::Weight Weight;
49
50 ParenMatcher(const FST &fst, MatchType match_type,
51 uint32 flags = (kParenLoop | kParenList))
matcher_(fst,match_type)52 : matcher_(fst, match_type),
53 match_type_(match_type),
54 flags_(flags) {
55 if (match_type == MATCH_INPUT) {
56 loop_.ilabel = kNoLabel;
57 loop_.olabel = 0;
58 } else {
59 loop_.ilabel = 0;
60 loop_.olabel = kNoLabel;
61 }
62 loop_.weight = Weight::One();
63 loop_.nextstate = kNoStateId;
64 }
65
66 ParenMatcher(const ParenMatcher<F> &matcher, bool safe = false)
67 : matcher_(matcher.matcher_, safe),
68 match_type_(matcher.match_type_),
69 flags_(matcher.flags_),
70 open_parens_(matcher.open_parens_),
71 close_parens_(matcher.close_parens_),
72 loop_(matcher.loop_) {
73 loop_.nextstate = kNoStateId;
74 }
75
76 ParenMatcher<F> *Copy(bool safe = false) const {
77 return new ParenMatcher<F>(*this, safe);
78 }
79
Type(bool test)80 MatchType Type(bool test) const { return matcher_.Type(test); }
81
SetState(StateId s)82 void SetState(StateId s) {
83 matcher_.SetState(s);
84 loop_.nextstate = s;
85 }
86
87 bool Find(Label match_label);
88
Done()89 bool Done() const {
90 return done_;
91 }
92
Value()93 const Arc& Value() const {
94 return paren_loop_ ? loop_ : matcher_.Value();
95 }
96
97 void Next();
98
GetFst()99 const FST &GetFst() const { return matcher_.GetFst(); }
100
Properties(uint64 props)101 uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
102
Flags()103 uint32 Flags() const { return matcher_.Flags(); }
104
AddOpenParen(Label label)105 void AddOpenParen(Label label) {
106 if (label == 0) {
107 FSTERROR() << "ParenMatcher: Bad open paren label: 0";
108 } else {
109 open_parens_.Insert(label);
110 }
111 }
112
AddCloseParen(Label label)113 void AddCloseParen(Label label) {
114 if (label == 0) {
115 FSTERROR() << "ParenMatcher: Bad close paren label: 0";
116 } else {
117 close_parens_.Insert(label);
118 }
119 }
120
RemoveOpenParen(Label label)121 void RemoveOpenParen(Label label) {
122 if (label == 0) {
123 FSTERROR() << "ParenMatcher: Bad open paren label: 0";
124 } else {
125 open_parens_.Erase(label);
126 }
127 }
128
RemoveCloseParen(Label label)129 void RemoveCloseParen(Label label) {
130 if (label == 0) {
131 FSTERROR() << "ParenMatcher: Bad close paren label: 0";
132 } else {
133 close_parens_.Erase(label);
134 }
135 }
136
ClearOpenParens()137 void ClearOpenParens() {
138 open_parens_.Clear();
139 }
140
ClearCloseParens()141 void ClearCloseParens() {
142 close_parens_.Clear();
143 }
144
IsOpenParen(Label label)145 bool IsOpenParen(Label label) const {
146 return open_parens_.Member(label);
147 }
148
IsCloseParen(Label label)149 bool IsCloseParen(Label label) const {
150 return close_parens_.Member(label);
151 }
152
153 private:
154 // Advances matcher to next open paren if it exists, returning true.
155 // O.w. returns false.
156 bool NextOpenParen();
157
158 // Advances matcher to next open paren if it exists, returning true.
159 // O.w. returns false.
160 bool NextCloseParen();
161
162 M matcher_;
163 MatchType match_type_; // Type of match to perform
164 uint32 flags_;
165
166 // open paren label set
167 CompactSet<Label, kNoLabel> open_parens_;
168
169 // close paren label set
170 CompactSet<Label, kNoLabel> close_parens_;
171
172
173 bool open_paren_list_; // Matching open paren list
174 bool close_paren_list_; // Matching close paren list
175 bool paren_loop_; // Current arc is the implicit paren loop
176 mutable Arc loop_; // For non-consuming symbols
177 bool done_; // Matching done
178
179 void operator=(const ParenMatcher<F> &); // Disallow
180 };
181
182 template <class M> inline
Find(Label match_label)183 bool ParenMatcher<M>::Find(Label match_label) {
184 open_paren_list_ = false;
185 close_paren_list_ = false;
186 paren_loop_ = false;
187 done_ = false;
188
189 // Returns all parenthesis arcs
190 if (match_label == kNoLabel && (flags_ & kParenList)) {
191 if (open_parens_.LowerBound() != kNoLabel) {
192 matcher_.LowerBound(open_parens_.LowerBound());
193 open_paren_list_ = NextOpenParen();
194 if (open_paren_list_) return true;
195 }
196 if (close_parens_.LowerBound() != kNoLabel) {
197 matcher_.LowerBound(close_parens_.LowerBound());
198 close_paren_list_ = NextCloseParen();
199 if (close_paren_list_) return true;
200 }
201 }
202
203 // Returns 'implicit' paren loop
204 if (match_label > 0 && (flags_ & kParenLoop) &&
205 (IsOpenParen(match_label) || IsCloseParen(match_label))) {
206 paren_loop_ = true;
207 return true;
208 }
209
210 // Returns all other labels
211 if (matcher_.Find(match_label))
212 return true;
213
214 done_ = true;
215 return false;
216 }
217
218 template <class F> inline
Next()219 void ParenMatcher<F>::Next() {
220 if (paren_loop_) {
221 paren_loop_ = false;
222 done_ = true;
223 } else if (open_paren_list_) {
224 matcher_.Next();
225 open_paren_list_ = NextOpenParen();
226 if (open_paren_list_) return;
227
228 if (close_parens_.LowerBound() != kNoLabel) {
229 matcher_.LowerBound(close_parens_.LowerBound());
230 close_paren_list_ = NextCloseParen();
231 if (close_paren_list_) return;
232 }
233 done_ = !matcher_.Find(kNoLabel);
234 } else if (close_paren_list_) {
235 matcher_.Next();
236 close_paren_list_ = NextCloseParen();
237 if (close_paren_list_) return;
238 done_ = !matcher_.Find(kNoLabel);
239 } else {
240 matcher_.Next();
241 done_ = matcher_.Done();
242 }
243 }
244
245 // Advances matcher to next open paren if it exists, returning true.
246 // O.w. returns false.
247 template <class F> inline
NextOpenParen()248 bool ParenMatcher<F>::NextOpenParen() {
249 for (; !matcher_.Done(); matcher_.Next()) {
250 Label label = match_type_ == MATCH_INPUT ?
251 matcher_.Value().ilabel : matcher_.Value().olabel;
252 if (label > open_parens_.UpperBound())
253 return false;
254 if (IsOpenParen(label))
255 return true;
256 }
257 return false;
258 }
259
260 // Advances matcher to next close paren if it exists, returning true.
261 // O.w. returns false.
262 template <class F> inline
NextCloseParen()263 bool ParenMatcher<F>::NextCloseParen() {
264 for (; !matcher_.Done(); matcher_.Next()) {
265 Label label = match_type_ == MATCH_INPUT ?
266 matcher_.Value().ilabel : matcher_.Value().olabel;
267 if (label > close_parens_.UpperBound())
268 return false;
269 if (IsCloseParen(label))
270 return true;
271 }
272 return false;
273 }
274
275
276 template <class F>
277 class ParenFilter {
278 public:
279 typedef typename F::FST1 FST1;
280 typedef typename F::FST2 FST2;
281 typedef typename F::Arc Arc;
282 typedef typename Arc::StateId StateId;
283 typedef typename Arc::Label Label;
284 typedef typename Arc::Weight Weight;
285 typedef typename F::Matcher1 Matcher1;
286 typedef typename F::Matcher2 Matcher2;
287 typedef typename F::FilterState FilterState1;
288 typedef StateId StackId;
289 typedef PdtStack<StackId, Label> ParenStack;
290 typedef IntegerFilterState<StackId> FilterState2;
291 typedef PairFilterState<FilterState1, FilterState2> FilterState;
292 typedef ParenFilter<F> Filter;
293
294 ParenFilter(const FST1 &fst1, const FST2 &fst2,
295 Matcher1 *matcher1 = 0, Matcher2 *matcher2 = 0,
296 const vector<pair<Label, Label> > *parens = 0,
297 bool expand = false, bool keep_parens = true)
filter_(fst1,fst2,matcher1,matcher2)298 : filter_(fst1, fst2, matcher1, matcher2),
299 parens_(parens ? *parens : vector<pair<Label, Label> >()),
300 expand_(expand),
301 keep_parens_(keep_parens),
302 f_(FilterState::NoState()),
303 stack_(parens_),
304 paren_id_(-1) {
305 if (parens) {
306 for (size_t i = 0; i < parens->size(); ++i) {
307 const pair<Label, Label> &p = (*parens)[i];
308 parens_.push_back(p);
309 GetMatcher1()->AddOpenParen(p.first);
310 GetMatcher2()->AddOpenParen(p.first);
311 if (!expand_) {
312 GetMatcher1()->AddCloseParen(p.second);
313 GetMatcher2()->AddCloseParen(p.second);
314 }
315 }
316 }
317 }
318
319 ParenFilter(const Filter &filter, bool safe = false)
320 : filter_(filter.filter_, safe),
321 parens_(filter.parens_),
322 expand_(filter.expand_),
323 keep_parens_(filter.keep_parens_),
324 f_(FilterState::NoState()),
325 stack_(filter.parens_),
326 paren_id_(-1) { }
327
Start()328 FilterState Start() const {
329 return FilterState(filter_.Start(), FilterState2(0));
330 }
331
SetState(StateId s1,StateId s2,const FilterState & f)332 void SetState(StateId s1, StateId s2, const FilterState &f) {
333 f_ = f;
334 filter_.SetState(s1, s2, f_.GetState1());
335 if (!expand_)
336 return;
337
338 ssize_t paren_id = stack_.Top(f.GetState2().GetState());
339 if (paren_id != paren_id_) {
340 if (paren_id_ != -1) {
341 GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second);
342 GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second);
343 }
344 paren_id_ = paren_id;
345 if (paren_id_ != -1) {
346 GetMatcher1()->AddCloseParen(parens_[paren_id_].second);
347 GetMatcher2()->AddCloseParen(parens_[paren_id_].second);
348 }
349 }
350 }
351
FilterArc(Arc * arc1,Arc * arc2)352 FilterState FilterArc(Arc *arc1, Arc *arc2) const {
353 FilterState1 f1 = filter_.FilterArc(arc1, arc2);
354 const FilterState2 &f2 = f_.GetState2();
355 if (f1 == FilterState1::NoState())
356 return FilterState::NoState();
357
358 if (arc1->olabel == kNoLabel && arc2->ilabel) { // arc2 parentheses
359 if (keep_parens_) {
360 arc1->ilabel = arc2->ilabel;
361 } else if (arc2->ilabel) {
362 arc2->olabel = arc1->ilabel;
363 }
364 return FilterParen(arc2->ilabel, f1, f2);
365 } else if (arc2->ilabel == kNoLabel && arc1->olabel) { // arc1 parentheses
366 if (keep_parens_) {
367 arc2->olabel = arc1->olabel;
368 } else {
369 arc1->ilabel = arc2->olabel;
370 }
371 return FilterParen(arc1->olabel, f1, f2);
372 } else {
373 return FilterState(f1, f2);
374 }
375 }
376
FilterFinal(Weight * w1,Weight * w2)377 void FilterFinal(Weight *w1, Weight *w2) const {
378 if (f_.GetState2().GetState() != 0)
379 *w1 = Weight::Zero();
380 filter_.FilterFinal(w1, w2);
381 }
382
383 // Return resp matchers. Ownership stays with filter.
GetMatcher1()384 Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); }
GetMatcher2()385 Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); }
386
Properties(uint64 iprops)387 uint64 Properties(uint64 iprops) const {
388 uint64 oprops = filter_.Properties(iprops);
389 return oprops & kILabelInvariantProperties & kOLabelInvariantProperties;
390 }
391
392 private:
FilterParen(Label label,const FilterState1 & f1,const FilterState2 & f2)393 const FilterState FilterParen(Label label, const FilterState1 &f1,
394 const FilterState2 &f2) const {
395 if (!expand_)
396 return FilterState(f1, f2);
397
398 StackId stack_id = stack_.Find(f2.GetState(), label);
399 if (stack_id < 0) {
400 return FilterState::NoState();
401 } else {
402 return FilterState(f1, FilterState2(stack_id));
403 }
404 }
405
406 F filter_;
407 vector<pair<Label, Label> > parens_;
408 bool expand_; // Expands to FST
409 bool keep_parens_; // Retains parentheses in output
410 FilterState f_; // Current filter state
411 mutable ParenStack stack_;
412 ssize_t paren_id_;
413 };
414
415 // Class to setup composition options for PDT composition.
416 // Default is for the PDT as the first composition argument.
417 template <class Arc, bool left_pdt = true>
418 class PdtComposeFstOptions : public
419 ComposeFstOptions<Arc,
420 ParenMatcher< Fst<Arc> >,
421 ParenFilter<AltSequenceComposeFilter<
422 ParenMatcher< Fst<Arc> > > > > {
423 public:
424 typedef typename Arc::Label Label;
425 typedef ParenMatcher< Fst<Arc> > PdtMatcher;
426 typedef ParenFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
427 typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
428 using COptions::matcher1;
429 using COptions::matcher2;
430 using COptions::filter;
431
432 PdtComposeFstOptions(const Fst<Arc> &ifst1,
433 const vector<pair<Label, Label> > &parens,
434 const Fst<Arc> &ifst2, bool expand = false,
435 bool keep_parens = true) {
436 matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList);
437 matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop);
438
439 filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
440 expand, keep_parens);
441 }
442 };
443
444 // Class to setup composition options for PDT with FST composition.
445 // Specialization is for the FST as the first composition argument.
446 template <class Arc>
447 class PdtComposeFstOptions<Arc, false> : public
448 ComposeFstOptions<Arc,
449 ParenMatcher< Fst<Arc> >,
450 ParenFilter<SequenceComposeFilter<
451 ParenMatcher< Fst<Arc> > > > > {
452 public:
453 typedef typename Arc::Label Label;
454 typedef ParenMatcher< Fst<Arc> > PdtMatcher;
455 typedef ParenFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
456 typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
457 using COptions::matcher1;
458 using COptions::matcher2;
459 using COptions::filter;
460
461 PdtComposeFstOptions(const Fst<Arc> &ifst1,
462 const Fst<Arc> &ifst2,
463 const vector<pair<Label, Label> > &parens,
464 bool expand = false, bool keep_parens = true) {
465 matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop);
466 matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList);
467
468 filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
469 expand, keep_parens);
470 }
471 };
472
// Pre-defined filter choices for PDT composition.
// (No trailing comma after the last enumerator: that is ill-formed before
// C++11, which the rest of this file stays compatible with.)
enum PdtComposeFilter {
  PAREN_FILTER,         // Bar-Hillel construction; keeps parentheses
  EXPAND_FILTER,        // Bar-Hillel + expansion; removes parentheses
  EXPAND_PAREN_FILTER   // Bar-Hillel + expansion; keeps parentheses
};
478
479 struct PdtComposeOptions {
480 bool connect; // Connect output
481 PdtComposeFilter filter_type; // Which pre-defined filter to use
482
483 explicit PdtComposeOptions(bool c, PdtComposeFilter ft = PAREN_FILTER)
connectPdtComposeOptions484 : connect(c), filter_type(ft) {}
PdtComposeOptionsPdtComposeOptions485 PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {}
486 };
487
488 // Composes pushdown transducer (PDT) encoded as an FST (1st arg) and
489 // an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg).
490 // In the PDTs, some transitions are labeled with open or close
491 // parentheses. To be interpreted as a PDT, the parens must balance on
492 // a path (see PdtExpand()). The open-close parenthesis label pairs
493 // are passed in 'parens'.
494 template <class Arc>
495 void Compose(const Fst<Arc> &ifst1,
496 const vector<pair<typename Arc::Label,
497 typename Arc::Label> > &parens,
498 const Fst<Arc> &ifst2,
499 MutableFst<Arc> *ofst,
500 const PdtComposeOptions &opts = PdtComposeOptions()) {
501 bool expand = opts.filter_type != PAREN_FILTER;
502 bool keep_parens = opts.filter_type != EXPAND_FILTER;
503 PdtComposeFstOptions<Arc, true> copts(ifst1, parens, ifst2,
504 expand, keep_parens);
505 copts.gc_limit = 0;
506 *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
507 if (opts.connect)
508 Connect(ofst);
509 }
510
511 // Composes an FST (1st arg) and pushdown transducer (PDT) encoded as
512 // an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg).
513 // In the PDTs, some transitions are labeled with open or close
514 // parentheses. To be interpreted as a PDT, the parens must balance on
515 // a path (see ExpandFst()). The open-close parenthesis label pairs
516 // are passed in 'parens'.
517 template <class Arc>
518 void Compose(const Fst<Arc> &ifst1,
519 const Fst<Arc> &ifst2,
520 const vector<pair<typename Arc::Label,
521 typename Arc::Label> > &parens,
522 MutableFst<Arc> *ofst,
523 const PdtComposeOptions &opts = PdtComposeOptions()) {
524 bool expand = opts.filter_type != PAREN_FILTER;
525 bool keep_parens = opts.filter_type != EXPAND_FILTER;
526 PdtComposeFstOptions<Arc, false> copts(ifst1, ifst2, parens,
527 expand, keep_parens);
528 copts.gc_limit = 0;
529 *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
530 if (opts.connect)
531 Connect(ofst);
532 }
533
534 } // namespace fst
535
536 #endif // FST_EXTENSIONS_PDT_COMPOSE_H__
537