1 // factor-weight.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: allauzen@google.com (Cyril Allauzen)
17 //
18 // \file
19 // Classes to factor weights in an FST.
20 
21 #ifndef FST_LIB_FACTOR_WEIGHT_H__
22 #define FST_LIB_FACTOR_WEIGHT_H__
23 
24 #include <algorithm>
25 #include <tr1/unordered_map>
26 using std::tr1::unordered_map;
27 using std::tr1::unordered_multimap;
28 #include <string>
29 #include <utility>
30 using std::pair; using std::make_pair;
31 #include <vector>
32 using std::vector;
33 
34 #include <fst/cache.h>
35 #include <fst/test-properties.h>
36 
37 
38 namespace fst {
39 
40 const uint32 kFactorFinalWeights = 0x00000001;
41 const uint32 kFactorArcWeights   = 0x00000002;
42 
43 template <class Arc>
44 struct FactorWeightOptions : CacheOptions {
45   typedef typename Arc::Label Label;
46   float delta;
47   uint32 mode;         // factor arc weights and/or final weights
48   Label final_ilabel;  // input label of arc created when factoring final w's
49   Label final_olabel;  // output label of arc created when factoring final w's
50 
51   FactorWeightOptions(const CacheOptions &opts, float d,
52                       uint32 m = kFactorArcWeights | kFactorFinalWeights,
53                       Label il = 0, Label ol = 0)
CacheOptionsFactorWeightOptions54       : CacheOptions(opts), delta(d), mode(m), final_ilabel(il),
55         final_olabel(ol) {}
56 
57   explicit FactorWeightOptions(
58       float d, uint32 m = kFactorArcWeights | kFactorFinalWeights,
59       Label il = 0, Label ol = 0)
deltaFactorWeightOptions60       : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {}
61 
62   FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights,
63                       Label il = 0, Label ol = 0)
deltaFactorWeightOptions64       : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {}
65 };
66 
67 
68 // A factor iterator takes as argument a weight w and returns a
69 // sequence of pairs of weights (xi,yi) such that the sum of the
70 // products xi times yi is equal to w. If w is fully factored,
71 // the iterator should return nothing.
72 //
73 // template <class W>
74 // class FactorIterator {
75 //  public:
76 //   FactorIterator(W w);
77 //   bool Done() const;
78 //   void Next();
79 //   pair<W, W> Value() const;
80 //   void Reset();
81 // }
82 
83 
// Trivial factor iterator: it is done as soon as it is constructed, so every
// weight is treated as already fully factored.
template <class W>
class IdentityFactor {
 public:
  // The weight is ignored; nothing is ever factored.
  IdentityFactor(const W &) {}

  // Always true: the sequence of factor pairs is empty.
  bool Done() const {
    return true;
  }

  // No-op; there is nothing to advance over.
  void Next() {}

  // Unused in practice since Done() is always true; returns (One, One).
  std::pair<W, W> Value() const {
    return std::make_pair(W::One(), W::One());
  }

  // No-op; the iterator has no state to rewind.
  void Reset() {}
};
94 
95 
96 // Factor a StringWeight w as 'ab' where 'a' is a label.
97 template <typename L, StringType S = STRING_LEFT>
98 class StringFactor {
99  public:
StringFactor(const StringWeight<L,S> & w)100   StringFactor(const StringWeight<L, S> &w)
101       : weight_(w), done_(w.Size() <= 1) {}
102 
Done()103   bool Done() const { return done_; }
104 
Next()105   void Next() { done_ = true; }
106 
Value()107   pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
108     StringWeightIterator<L, S> iter(weight_);
109     StringWeight<L, S> w1(iter.Value());
110     StringWeight<L, S> w2;
111     for (iter.Next(); !iter.Done(); iter.Next())
112       w2.PushBack(iter.Value());
113     return make_pair(w1, w2);
114   }
115 
Reset()116   void Reset() { done_ = weight_.Size() <= 1; }
117 
118  private:
119   StringWeight<L, S> weight_;
120   bool done_;
121 };
122 
123 
124 // Factor a GallicWeight using StringFactor.
125 template <class L, class W, StringType S = STRING_LEFT>
126 class GallicFactor {
127  public:
GallicFactor(const GallicWeight<L,W,S> & w)128   GallicFactor(const GallicWeight<L, W, S> &w)
129       : weight_(w), done_(w.Value1().Size() <= 1) {}
130 
Done()131   bool Done() const { return done_; }
132 
Next()133   void Next() { done_ = true; }
134 
Value()135   pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
136     StringFactor<L, S> iter(weight_.Value1());
137     GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
138     GallicWeight<L, W, S> w2(iter.Value().second, W::One());
139     return make_pair(w1, w2);
140   }
141 
Reset()142   void Reset() { done_ = weight_.Value1().Size() <= 1; }
143 
144  private:
145   GallicWeight<L, W, S> weight_;
146   bool done_;
147 };
148 
149 
150 // Implementation class for FactorWeight
151 template <class A, class F>
152 class FactorWeightFstImpl
153     : public CacheImpl<A> {
154  public:
155   using FstImpl<A>::SetType;
156   using FstImpl<A>::SetProperties;
157   using FstImpl<A>::SetInputSymbols;
158   using FstImpl<A>::SetOutputSymbols;
159 
160   using CacheBaseImpl< CacheState<A> >::PushArc;
161   using CacheBaseImpl< CacheState<A> >::HasStart;
162   using CacheBaseImpl< CacheState<A> >::HasFinal;
163   using CacheBaseImpl< CacheState<A> >::HasArcs;
164   using CacheBaseImpl< CacheState<A> >::SetArcs;
165   using CacheBaseImpl< CacheState<A> >::SetFinal;
166   using CacheBaseImpl< CacheState<A> >::SetStart;
167 
168   typedef A Arc;
169   typedef typename A::Label Label;
170   typedef typename A::Weight Weight;
171   typedef typename A::StateId StateId;
172   typedef F FactorIterator;
173 
174   struct Element {
ElementElement175     Element() {}
176 
ElementElement177     Element(StateId s, Weight w) : state(s), weight(w) {}
178 
179     StateId state;     // Input state Id
180     Weight weight;     // Residual weight
181   };
182 
FactorWeightFstImpl(const Fst<A> & fst,const FactorWeightOptions<A> & opts)183   FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts)
184       : CacheImpl<A>(opts),
185         fst_(fst.Copy()),
186         delta_(opts.delta),
187         mode_(opts.mode),
188         final_ilabel_(opts.final_ilabel),
189         final_olabel_(opts.final_olabel) {
190     SetType("factor_weight");
191     uint64 props = fst.Properties(kFstProperties, false);
192     SetProperties(FactorWeightProperties(props), kCopyProperties);
193 
194     SetInputSymbols(fst.InputSymbols());
195     SetOutputSymbols(fst.OutputSymbols());
196 
197     if (mode_ == 0)
198       LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: "
199                    << "factoring neither arc weights nor final weights.";
200   }
201 
FactorWeightFstImpl(const FactorWeightFstImpl<A,F> & impl)202   FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl)
203       : CacheImpl<A>(impl),
204         fst_(impl.fst_->Copy(true)),
205         delta_(impl.delta_),
206         mode_(impl.mode_),
207         final_ilabel_(impl.final_ilabel_),
208         final_olabel_(impl.final_olabel_) {
209     SetType("factor_weight");
210     SetProperties(impl.Properties(), kCopyProperties);
211     SetInputSymbols(impl.InputSymbols());
212     SetOutputSymbols(impl.OutputSymbols());
213   }
214 
~FactorWeightFstImpl()215   ~FactorWeightFstImpl() {
216     delete fst_;
217   }
218 
Start()219   StateId Start() {
220     if (!HasStart()) {
221       StateId s = fst_->Start();
222       if (s == kNoStateId)
223         return kNoStateId;
224       StateId start = FindState(Element(fst_->Start(), Weight::One()));
225       SetStart(start);
226     }
227     return CacheImpl<A>::Start();
228   }
229 
Final(StateId s)230   Weight Final(StateId s) {
231     if (!HasFinal(s)) {
232       const Element &e = elements_[s];
233       // TODO: fix so cast is unnecessary
234       Weight w = e.state == kNoStateId
235                  ? e.weight
236                  : (Weight) Times(e.weight, fst_->Final(e.state));
237       FactorIterator f(w);
238       if (!(mode_ & kFactorFinalWeights) || f.Done())
239         SetFinal(s, w);
240       else
241         SetFinal(s, Weight::Zero());
242     }
243     return CacheImpl<A>::Final(s);
244   }
245 
NumArcs(StateId s)246   size_t NumArcs(StateId s) {
247     if (!HasArcs(s))
248       Expand(s);
249     return CacheImpl<A>::NumArcs(s);
250   }
251 
NumInputEpsilons(StateId s)252   size_t NumInputEpsilons(StateId s) {
253     if (!HasArcs(s))
254       Expand(s);
255     return CacheImpl<A>::NumInputEpsilons(s);
256   }
257 
NumOutputEpsilons(StateId s)258   size_t NumOutputEpsilons(StateId s) {
259     if (!HasArcs(s))
260       Expand(s);
261     return CacheImpl<A>::NumOutputEpsilons(s);
262   }
263 
Properties()264   uint64 Properties() const { return Properties(kFstProperties); }
265 
266   // Set error if found; return FST impl properties.
Properties(uint64 mask)267   uint64 Properties(uint64 mask) const {
268     if ((mask & kError) && fst_->Properties(kError, false))
269       SetProperties(kError, kError);
270     return FstImpl<Arc>::Properties(mask);
271   }
272 
InitArcIterator(StateId s,ArcIteratorData<A> * data)273   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
274     if (!HasArcs(s))
275       Expand(s);
276     CacheImpl<A>::InitArcIterator(s, data);
277   }
278 
279 
280   // Find state corresponding to an element. Create new state
281   // if element not found.
FindState(const Element & e)282   StateId FindState(const Element &e) {
283     if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) {
284       while (unfactored_.size() <= e.state)
285         unfactored_.push_back(kNoStateId);
286       if (unfactored_[e.state] == kNoStateId) {
287         unfactored_[e.state] = elements_.size();
288         elements_.push_back(e);
289       }
290       return unfactored_[e.state];
291     } else {
292       typename ElementMap::iterator eit = element_map_.find(e);
293       if (eit != element_map_.end()) {
294         return (*eit).second;
295       } else {
296         StateId s = elements_.size();
297         elements_.push_back(e);
298         element_map_.insert(pair<const Element, StateId>(e, s));
299         return s;
300       }
301     }
302   }
303 
304   // Computes the outgoing transitions from a state, creating new destination
305   // states as needed.
Expand(StateId s)306   void Expand(StateId s) {
307     Element e = elements_[s];
308     if (e.state != kNoStateId) {
309       for (ArcIterator< Fst<A> > ait(*fst_, e.state);
310            !ait.Done();
311            ait.Next()) {
312         const A &arc = ait.Value();
313         Weight w = Times(e.weight, arc.weight);
314         FactorIterator fit(w);
315         if (!(mode_ & kFactorArcWeights) || fit.Done()) {
316           StateId d = FindState(Element(arc.nextstate, Weight::One()));
317           PushArc(s, Arc(arc.ilabel, arc.olabel, w, d));
318         } else {
319           for (; !fit.Done(); fit.Next()) {
320             const pair<Weight, Weight> &p = fit.Value();
321             StateId d = FindState(Element(arc.nextstate,
322                                           p.second.Quantize(delta_)));
323             PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
324           }
325         }
326       }
327     }
328 
329     if ((mode_ & kFactorFinalWeights) &&
330         ((e.state == kNoStateId) ||
331          (fst_->Final(e.state) != Weight::Zero()))) {
332       Weight w = e.state == kNoStateId
333                  ? e.weight
334                  : Times(e.weight, fst_->Final(e.state));
335       for (FactorIterator fit(w);
336            !fit.Done();
337            fit.Next()) {
338         const pair<Weight, Weight> &p = fit.Value();
339         StateId d = FindState(Element(kNoStateId,
340                                       p.second.Quantize(delta_)));
341         PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d));
342       }
343     }
344     SetArcs(s);
345   }
346 
347  private:
348   static const size_t kPrime = 7853;
349 
350   // Equality function for Elements, assume weights have been quantized.
351   class ElementEqual {
352    public:
operator()353     bool operator()(const Element &x, const Element &y) const {
354       return x.state == y.state && x.weight == y.weight;
355     }
356   };
357 
358   // Hash function for Elements to Fst states.
359   class ElementKey {
360    public:
operator()361     size_t operator()(const Element &x) const {
362       return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
363     }
364    private:
365   };
366 
367   typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
368 
369   const Fst<A> *fst_;
370   float delta_;
371   uint32 mode_;               // factoring arc and/or final weights
372   Label final_ilabel_;        // ilabel of arc created when factoring final w's
373   Label final_olabel_;        // olabel of arc created when factoring final w's
374   vector<Element> elements_;  // mapping Fst state to Elements
375   ElementMap element_map_;    // mapping Elements to Fst state
376   // mapping between old/new 'StateId' for states that do not need to
377   // be factored when 'mode_' is '0' or 'kFactorFinalWeights'
378   vector<StateId> unfactored_;
379 
380   void operator=(const FactorWeightFstImpl<A, F> &);  // disallow
381 };
382 
383 template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime;
384 
385 
386 // FactorWeightFst takes as template parameter a FactorIterator as
387 // defined above. The result of weight factoring is a transducer
388 // equivalent to the input whose path weights have been factored
389 // according to the FactorIterator. States and transitions will be
390 // added as necessary. The algorithm is a generalization to arbitrary
391 // weights of the second step of the input epsilon-normalization
392 // algorithm due to Mohri, "Generic epsilon-removal and input
393 // epsilon-normalization algorithms for weighted transducers",
394 // International Journal of Computer Science 13(1): 129-143 (2002).
395 //
396 // This class attaches interface to implementation and handles
397 // reference counting, delegating most methods to ImplToFst.
398 template <class A, class F>
399 class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > {
400  public:
401   friend class ArcIterator< FactorWeightFst<A, F> >;
402   friend class StateIterator< FactorWeightFst<A, F> >;
403 
404   typedef A Arc;
405   typedef typename A::Weight Weight;
406   typedef typename A::StateId StateId;
407   typedef CacheState<A> State;
408   typedef FactorWeightFstImpl<A, F> Impl;
409 
FactorWeightFst(const Fst<A> & fst)410   FactorWeightFst(const Fst<A> &fst)
411       : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {}
412 
FactorWeightFst(const Fst<A> & fst,const FactorWeightOptions<A> & opts)413   FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions<A> &opts)
414       : ImplToFst<Impl>(new Impl(fst, opts)) {}
415 
416   // See Fst<>::Copy() for doc.
FactorWeightFst(const FactorWeightFst<A,F> & fst,bool copy)417   FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy)
418       : ImplToFst<Impl>(fst, copy) {}
419 
420   // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
421   virtual FactorWeightFst<A, F> *Copy(bool copy = false) const {
422     return new FactorWeightFst<A, F>(*this, copy);
423   }
424 
425   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
426 
InitArcIterator(StateId s,ArcIteratorData<A> * data)427   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
428     GetImpl()->InitArcIterator(s, data);
429   }
430 
431  private:
432   // Makes visible to friends.
GetImpl()433   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
434 
435   void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
436 };
437 
438 
439 // Specialization for FactorWeightFst.
440 template<class A, class F>
441 class StateIterator< FactorWeightFst<A, F> >
442     : public CacheStateIterator< FactorWeightFst<A, F> > {
443  public:
StateIterator(const FactorWeightFst<A,F> & fst)444   explicit StateIterator(const FactorWeightFst<A, F> &fst)
445       : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {}
446 };
447 
448 
449 // Specialization for FactorWeightFst.
450 template <class A, class F>
451 class ArcIterator< FactorWeightFst<A, F> >
452     : public CacheArcIterator< FactorWeightFst<A, F> > {
453  public:
454   typedef typename A::StateId StateId;
455 
ArcIterator(const FactorWeightFst<A,F> & fst,StateId s)456   ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
457       : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) {
458     if (!fst.GetImpl()->HasArcs(s))
459       fst.GetImpl()->Expand(s);
460   }
461 
462  private:
463   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
464 };
465 
466 template <class A, class F> inline
InitStateIterator(StateIteratorData<A> * data)467 void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
468 {
469   data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
470 }
471 
472 
473 }  // namespace fst
474 
475 #endif // FST_LIB_FACTOR_WEIGHT_H__
476