// state-map.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
// \file
// Class to map over/transform states, e.g., to sort transitions.
// Consider using this when the operation does not change the number of states.
21
22 #ifndef FST_LIB_STATE_MAP_H__
23 #define FST_LIB_STATE_MAP_H__
24
25 #include <algorithm>
26 #include <tr1/unordered_map>
27 using std::tr1::unordered_map;
28 using std::tr1::unordered_multimap;
29 #include <string>
30 #include <utility>
31 using std::pair; using std::make_pair;
32
33 #include <fst/cache.h>
34 #include <fst/arc-map.h>
35 #include <fst/mutable-fst.h>
36
37
38 namespace fst {
39
// StateMapper Interface - class determines how states are mapped.
41 // Useful for implementing operations that do not change the number of states.
42 //
43 // class StateMapper {
44 // public:
45 // typedef A FromArc;
46 // typedef B ToArc;
47 //
48 // // Typical constructor
49 // StateMapper(const Fst<A> &fst);
50 // // Required copy constructor that allows updating Fst argument;
51 // // pass only if relevant and changed.
52 // StateMapper(const StateMapper &mapper, const Fst<A> *fst = 0);
53 //
54 // // Specifies initial state of result
55 // B::StateId Start() const;
56 // // Specifies state's final weight in result
57 // B::Weight Final(B::StateId s) const;
58 //
59 // // These methods iterate through a state's arcs in result
60 // // Specifies state to iterate over
61 // void SetState(B::StateId s);
62 // // End of arcs?
63 // bool Done() const;
// // Current arc
// const B &Value() const;
67 // // Advance to next arc (when !Done)
68 // void Next();
69 //
70 // // Specifies input symbol table action the mapper requires (see above).
71 // MapSymbolsAction InputSymbolsAction() const;
72 // // Specifies output symbol table action the mapper requires (see above).
73 // MapSymbolsAction OutputSymbolsAction() const;
74 // // This specifies the known properties of an Fst mapped by this
75 // // mapper. It takes as argument the input Fst's known properties.
76 // uint64 Properties(uint64 props) const;
77 // };
78 //
// We include various state map versions below. One dimension of
// variation is whether the mapping mutates its input, writes to a
// new result Fst, or is an on-the-fly Fst. Another dimension is how
// we pass the mapper. We allow passing the mapper by pointer
// for cases where we need to change the state of the user's mapper.
// We also include map versions that pass the mapper
// by value or const reference when this suffices.
86
87 // Maps an arc type A using a mapper function object C, passed
88 // by pointer. This version modifies its Fst input.
89 template<class A, class C>
StateMap(MutableFst<A> * fst,C * mapper)90 void StateMap(MutableFst<A> *fst, C* mapper) {
91 typedef typename A::StateId StateId;
92 typedef typename A::Weight Weight;
93
94 if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
95 fst->SetInputSymbols(0);
96
97 if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
98 fst->SetOutputSymbols(0);
99
100 if (fst->Start() == kNoStateId)
101 return;
102
103 uint64 props = fst->Properties(kFstProperties, false);
104
105 fst->SetStart(mapper->Start());
106
107 for (StateId s = 0; s < fst->NumStates(); ++s) {
108 mapper->SetState(s);
109 fst->DeleteArcs(s);
110 for (; !mapper->Done(); mapper->Next())
111 fst->AddArc(s, mapper->Value());
112 fst->SetFinal(s, mapper->Final(s));
113 }
114
115 fst->SetProperties(mapper->Properties(props), kFstProperties);
116 }
117
118 // Maps an arc type A using a mapper function object C, passed
119 // by value. This version modifies its Fst input.
120 template<class A, class C>
StateMap(MutableFst<A> * fst,C mapper)121 void StateMap(MutableFst<A> *fst, C mapper) {
122 StateMap(fst, &mapper);
123 }
124
125
126 // Maps an arc type A to an arc type B using mapper function
127 // object C, passed by pointer. This version writes the mapped
128 // input Fst to an output MutableFst.
129 template<class A, class B, class C>
StateMap(const Fst<A> & ifst,MutableFst<B> * ofst,C * mapper)130 void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C* mapper) {
131 typedef typename A::StateId StateId;
132 typedef typename A::Weight Weight;
133
134 ofst->DeleteStates();
135
136 if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS)
137 ofst->SetInputSymbols(ifst.InputSymbols());
138 else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
139 ofst->SetInputSymbols(0);
140
141 if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS)
142 ofst->SetOutputSymbols(ifst.OutputSymbols());
143 else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
144 ofst->SetOutputSymbols(0);
145
146 uint64 iprops = ifst.Properties(kCopyProperties, false);
147
148 if (ifst.Start() == kNoStateId) {
149 if (iprops & kError) ofst->SetProperties(kError, kError);
150 return;
151 }
152
153 // Add all states.
154 if (ifst.Properties(kExpanded, false))
155 ofst->ReserveStates(CountStates(ifst));
156 for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next())
157 ofst->AddState();
158
159 ofst->SetStart(mapper->Start());
160
161 for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) {
162 StateId s = siter.Value();
163 mapper->SetState(s);
164 for (; !mapper->Done(); mapper->Next())
165 ofst->AddArc(s, mapper->Value());
166 ofst->SetFinal(s, mapper->Final(s));
167 }
168
169 uint64 oprops = ofst->Properties(kFstProperties, false);
170 ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties);
171 }
172
173 // Maps an arc type A to an arc type B using mapper function
174 // object C, passed by value. This version writes the mapped input
175 // Fst to an output MutableFst.
176 template<class A, class B, class C>
StateMap(const Fst<A> & ifst,MutableFst<B> * ofst,C mapper)177 void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) {
178 StateMap(ifst, ofst, &mapper);
179 }
180
181 typedef CacheOptions StateMapFstOptions;
182
183 template <class A, class B, class C> class StateMapFst;
184
185 // Implementation of delayed StateMapFst.
186 template <class A, class B, class C>
187 class StateMapFstImpl : public CacheImpl<B> {
188 public:
189 using FstImpl<B>::SetType;
190 using FstImpl<B>::SetProperties;
191 using FstImpl<B>::SetInputSymbols;
192 using FstImpl<B>::SetOutputSymbols;
193
194 using CacheImpl<B>::PushArc;
195 using CacheImpl<B>::HasArcs;
196 using CacheImpl<B>::HasFinal;
197 using CacheImpl<B>::HasStart;
198 using CacheImpl<B>::SetArcs;
199 using CacheImpl<B>::SetFinal;
200 using CacheImpl<B>::SetStart;
201
202 friend class StateIterator< StateMapFst<A, B, C> >;
203
204 typedef B Arc;
205 typedef typename B::Weight Weight;
206 typedef typename B::StateId StateId;
207
StateMapFstImpl(const Fst<A> & fst,const C & mapper,const StateMapFstOptions & opts)208 StateMapFstImpl(const Fst<A> &fst, const C &mapper,
209 const StateMapFstOptions& opts)
210 : CacheImpl<B>(opts),
211 fst_(fst.Copy()),
212 mapper_(new C(mapper, fst_)),
213 own_mapper_(true) {
214 Init();
215 }
216
StateMapFstImpl(const Fst<A> & fst,C * mapper,const StateMapFstOptions & opts)217 StateMapFstImpl(const Fst<A> &fst, C *mapper,
218 const StateMapFstOptions& opts)
219 : CacheImpl<B>(opts),
220 fst_(fst.Copy()),
221 mapper_(mapper),
222 own_mapper_(false) {
223 Init();
224 }
225
StateMapFstImpl(const StateMapFstImpl<A,B,C> & impl)226 StateMapFstImpl(const StateMapFstImpl<A, B, C> &impl)
227 : CacheImpl<B>(impl),
228 fst_(impl.fst_->Copy(true)),
229 mapper_(new C(*impl.mapper_, fst_)),
230 own_mapper_(true) {
231 Init();
232 }
233
~StateMapFstImpl()234 ~StateMapFstImpl() {
235 delete fst_;
236 if (own_mapper_) delete mapper_;
237 }
238
Start()239 StateId Start() {
240 if (!HasStart())
241 SetStart(mapper_->Start());
242 return CacheImpl<B>::Start();
243 }
244
Final(StateId s)245 Weight Final(StateId s) {
246 if (!HasFinal(s))
247 SetFinal(s, mapper_->Final(s));
248 return CacheImpl<B>::Final(s);
249 }
250
NumArcs(StateId s)251 size_t NumArcs(StateId s) {
252 if (!HasArcs(s))
253 Expand(s);
254 return CacheImpl<B>::NumArcs(s);
255 }
256
NumInputEpsilons(StateId s)257 size_t NumInputEpsilons(StateId s) {
258 if (!HasArcs(s))
259 Expand(s);
260 return CacheImpl<B>::NumInputEpsilons(s);
261 }
262
NumOutputEpsilons(StateId s)263 size_t NumOutputEpsilons(StateId s) {
264 if (!HasArcs(s))
265 Expand(s);
266 return CacheImpl<B>::NumOutputEpsilons(s);
267 }
268
InitStateIterator(StateIteratorData<A> * data)269 void InitStateIterator(StateIteratorData<A> *data) const {
270 fst_->InitStateIterator(data);
271 }
272
InitArcIterator(StateId s,ArcIteratorData<B> * data)273 void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
274 if (!HasArcs(s))
275 Expand(s);
276 CacheImpl<B>::InitArcIterator(s, data);
277 }
278
Properties()279 uint64 Properties() const { return Properties(kFstProperties); }
280
281 // Set error if found; return FST impl properties.
Properties(uint64 mask)282 uint64 Properties(uint64 mask) const {
283 if ((mask & kError) && (fst_->Properties(kError, false) ||
284 (mapper_->Properties(0) & kError)))
285 SetProperties(kError, kError);
286 return FstImpl<Arc>::Properties(mask);
287 }
288
Expand(StateId s)289 void Expand(StateId s) {
290 // Add exiting arcs.
291 for (mapper_->SetState(s); !mapper_->Done(); mapper_->Next())
292 PushArc(s, mapper_->Value());
293 SetArcs(s);
294 }
295
GetFst()296 const Fst<A> &GetFst() const {
297 return *fst_;
298 }
299
300 private:
Init()301 void Init() {
302 SetType("statemap");
303
304 if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS)
305 SetInputSymbols(fst_->InputSymbols());
306 else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
307 SetInputSymbols(0);
308
309 if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS)
310 SetOutputSymbols(fst_->OutputSymbols());
311 else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
312 SetOutputSymbols(0);
313
314 uint64 props = fst_->Properties(kCopyProperties, false);
315 SetProperties(mapper_->Properties(props));
316 }
317
318 const Fst<A> *fst_;
319 C* mapper_;
320 bool own_mapper_;
321
322 void operator=(const StateMapFstImpl<A, B, C> &); // disallow
323 };
324
325
326 // Maps an arc type A to an arc type B using Mapper function object
327 // C. This version is a delayed Fst.
328 template <class A, class B, class C>
329 class StateMapFst : public ImplToFst< StateMapFstImpl<A, B, C> > {
330 public:
331 friend class ArcIterator< StateMapFst<A, B, C> >;
332
333 typedef B Arc;
334 typedef typename B::Weight Weight;
335 typedef typename B::StateId StateId;
336 typedef CacheState<B> State;
337 typedef StateMapFstImpl<A, B, C> Impl;
338
StateMapFst(const Fst<A> & fst,const C & mapper,const StateMapFstOptions & opts)339 StateMapFst(const Fst<A> &fst, const C &mapper,
340 const StateMapFstOptions& opts)
341 : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {}
342
StateMapFst(const Fst<A> & fst,C * mapper,const StateMapFstOptions & opts)343 StateMapFst(const Fst<A> &fst, C* mapper, const StateMapFstOptions& opts)
344 : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {}
345
StateMapFst(const Fst<A> & fst,const C & mapper)346 StateMapFst(const Fst<A> &fst, const C &mapper)
347 : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {}
348
StateMapFst(const Fst<A> & fst,C * mapper)349 StateMapFst(const Fst<A> &fst, C* mapper)
350 : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {}
351
352 // See Fst<>::Copy() for doc.
353 StateMapFst(const StateMapFst<A, B, C> &fst, bool safe = false)
354 : ImplToFst<Impl>(fst, safe) {}
355
356 // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc.
357 virtual StateMapFst<A, B, C> *Copy(bool safe = false) const {
358 return new StateMapFst<A, B, C>(*this, safe);
359 }
360
InitStateIterator(StateIteratorData<A> * data)361 virtual void InitStateIterator(StateIteratorData<A> *data) const {
362 GetImpl()->InitStateIterator(data);
363 }
364
InitArcIterator(StateId s,ArcIteratorData<B> * data)365 virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const {
366 GetImpl()->InitArcIterator(s, data);
367 }
368
369 protected:
GetImpl()370 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
371
372 private:
373 void operator=(const StateMapFst<A, B, C> &fst); // disallow
374 };
375
376
377 // Specialization for StateMapFst.
378 template <class A, class B, class C>
379 class ArcIterator< StateMapFst<A, B, C> >
380 : public CacheArcIterator< StateMapFst<A, B, C> > {
381 public:
382 typedef typename A::StateId StateId;
383
ArcIterator(const StateMapFst<A,B,C> & fst,StateId s)384 ArcIterator(const StateMapFst<A, B, C> &fst, StateId s)
385 : CacheArcIterator< StateMapFst<A, B, C> >(fst.GetImpl(), s) {
386 if (!fst.GetImpl()->HasArcs(s))
387 fst.GetImpl()->Expand(s);
388 }
389
390 private:
391 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
392 };
393
394 //
395 // Utility Mappers
396 //
397
398 // Mapper that returns its input.
399 template <class A>
400 class IdentityStateMapper {
401 public:
402 typedef A FromArc;
403 typedef A ToArc;
404
405 typedef typename A::StateId StateId;
406 typedef typename A::Weight Weight;
407
IdentityStateMapper(const Fst<A> & fst)408 explicit IdentityStateMapper(const Fst<A> &fst) : fst_(fst), aiter_(0) {}
409
410 // Allows updating Fst argument; pass only if changed.
411 IdentityStateMapper(const IdentityStateMapper<A> &mapper,
412 const Fst<A> *fst = 0)
413 : fst_(fst ? *fst : mapper.fst_), aiter_(0) {}
414
~IdentityStateMapper()415 ~IdentityStateMapper() { delete aiter_; }
416
Start()417 StateId Start() const { return fst_.Start(); }
418
Final(StateId s)419 Weight Final(StateId s) const { return fst_.Final(s); }
420
SetState(StateId s)421 void SetState(StateId s) {
422 if (aiter_) delete aiter_;
423 aiter_ = new ArcIterator< Fst<A> >(fst_, s);
424 }
425
Done()426 bool Done() const { return aiter_->Done(); }
Value()427 const A &Value() const { return aiter_->Value(); }
Next()428 void Next() { aiter_->Next(); }
429
InputSymbolsAction()430 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
OutputSymbolsAction()431 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;}
432
Properties(uint64 props)433 uint64 Properties(uint64 props) const { return props; }
434
435 private:
436 const Fst<A> &fst_;
437 ArcIterator< Fst<A> > *aiter_;
438 };
439
440 template <class A>
441 class ArcSumMapper {
442 public:
443 typedef A FromArc;
444 typedef A ToArc;
445
446 typedef typename A::StateId StateId;
447 typedef typename A::Weight Weight;
448
ArcSumMapper(const Fst<A> & fst)449 explicit ArcSumMapper(const Fst<A> &fst) : fst_(fst), i_(0) {}
450
451 // Allows updating Fst argument; pass only if changed.
452 ArcSumMapper(const ArcSumMapper<A> &mapper,
453 const Fst<A> *fst = 0)
454 : fst_(fst ? *fst : mapper.fst_), i_(0) {}
455
Start()456 StateId Start() const { return fst_.Start(); }
Final(StateId s)457 Weight Final(StateId s) const { return fst_.Final(s); }
458
SetState(StateId s)459 void SetState(StateId s) {
460 i_ = 0;
461 arcs_.clear();
462 arcs_.reserve(fst_.NumArcs(s));
463 for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next())
464 arcs_.push_back(aiter.Value());
465
466 // First sorts the exiting arcs by input label, output label
467 // and destination state and then sums weights of arcs with
468 // the same input label, output label, and destination state.
469 sort(arcs_.begin(), arcs_.end(), comp_);
470 size_t narcs = 0;
471 for (size_t i = 0; i < arcs_.size(); ++i) {
472 if (narcs > 0 && equal_(arcs_[i], arcs_[narcs - 1])) {
473 arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight,
474 arcs_[i].weight);
475 } else {
476 arcs_[narcs++] = arcs_[i];
477 }
478 }
479 arcs_.resize(narcs);
480 }
481
Done()482 bool Done() const { return i_ >= arcs_.size(); }
Value()483 const A &Value() const { return arcs_[i_]; }
Next()484 void Next() { ++i_; }
485
InputSymbolsAction()486 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
OutputSymbolsAction()487 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
488
Properties(uint64 props)489 uint64 Properties(uint64 props) const {
490 return props & kArcSortProperties &
491 kDeleteArcsProperties & kWeightInvariantProperties;
492 }
493
494 private:
495 struct Compare {
operatorCompare496 bool operator()(const A& x, const A& y) {
497 if (x.ilabel < y.ilabel) return true;
498 if (x.ilabel > y.ilabel) return false;
499 if (x.olabel < y.olabel) return true;
500 if (x.olabel > y.olabel) return false;
501 if (x.nextstate < y.nextstate) return true;
502 if (x.nextstate > y.nextstate) return false;
503 return false;
504 }
505 };
506
507 struct Equal {
operatorEqual508 bool operator()(const A& x, const A& y) {
509 return (x.ilabel == y.ilabel &&
510 x.olabel == y.olabel &&
511 x.nextstate == y.nextstate);
512 }
513 };
514
515 const Fst<A> &fst_;
516 Compare comp_;
517 Equal equal_;
518 vector<A> arcs_;
519 ssize_t i_; // current arc position
520
521 void operator=(const ArcSumMapper<A> &); // disallow
522 };
523
524 template <class A>
525 class ArcUniqueMapper {
526 public:
527 typedef A FromArc;
528 typedef A ToArc;
529
530 typedef typename A::StateId StateId;
531 typedef typename A::Weight Weight;
532
ArcUniqueMapper(const Fst<A> & fst)533 explicit ArcUniqueMapper(const Fst<A> &fst) : fst_(fst), i_(0) {}
534
535 // Allows updating Fst argument; pass only if changed.
536 ArcUniqueMapper(const ArcUniqueMapper<A> &mapper,
537 const Fst<A> *fst = 0)
538 : fst_(fst ? *fst : mapper.fst_), i_(0) {}
539
Start()540 StateId Start() const { return fst_.Start(); }
Final(StateId s)541 Weight Final(StateId s) const { return fst_.Final(s); }
542
SetState(StateId s)543 void SetState(StateId s) {
544 i_ = 0;
545 arcs_.clear();
546 arcs_.reserve(fst_.NumArcs(s));
547 for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next())
548 arcs_.push_back(aiter.Value());
549
550 // First sorts the exiting arcs by input label, output label
551 // and destination state and then uniques identical arcs
552 sort(arcs_.begin(), arcs_.end(), comp_);
553 typename vector<A>::iterator unique_end =
554 unique(arcs_.begin(), arcs_.end(), equal_);
555 arcs_.resize(unique_end - arcs_.begin());
556 }
557
Done()558 bool Done() const { return i_ >= arcs_.size(); }
Value()559 const A &Value() const { return arcs_[i_]; }
Next()560 void Next() { ++i_; }
561
InputSymbolsAction()562 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
OutputSymbolsAction()563 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
564
Properties(uint64 props)565 uint64 Properties(uint64 props) const {
566 return props & kArcSortProperties & kDeleteArcsProperties;
567 }
568
569 private:
570 struct Compare {
operatorCompare571 bool operator()(const A& x, const A& y) {
572 if (x.ilabel < y.ilabel) return true;
573 if (x.ilabel > y.ilabel) return false;
574 if (x.olabel < y.olabel) return true;
575 if (x.olabel > y.olabel) return false;
576 if (x.nextstate < y.nextstate) return true;
577 if (x.nextstate > y.nextstate) return false;
578 return false;
579 }
580 };
581
582 struct Equal {
operatorEqual583 bool operator()(const A& x, const A& y) {
584 return (x.ilabel == y.ilabel &&
585 x.olabel == y.olabel &&
586 x.nextstate == y.nextstate &&
587 x.weight == y.weight);
588 }
589 };
590
591 const Fst<A> &fst_;
592 Compare comp_;
593 Equal equal_;
594 vector<A> arcs_;
595 ssize_t i_; // current arc position
596
597 void operator=(const ArcUniqueMapper<A> &); // disallow
598 };
599
600
601 } // namespace fst
602
603 #endif // FST_LIB_STATE_MAP_H__
604