1 // minimize.h
2 // minimize.h
3 
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Copyright 2005-2010 Google, Inc.
17 // Author: johans@google.com (Johan Schalkwyk)
18 //
19 // \file Functions and classes to minimize a finite state acceptor
20 //
21 
22 #ifndef FST_LIB_MINIMIZE_H__
23 #define FST_LIB_MINIMIZE_H__
24 
25 #include <cmath>
26 
27 #include <algorithm>
28 #include <map>
29 #include <queue>
30 #include <vector>
31 using std::vector;
32 
33 #include <fst/arcsort.h>
34 #include <fst/connect.h>
35 #include <fst/dfs-visit.h>
36 #include <fst/encode.h>
37 #include <fst/factor-weight.h>
38 #include <fst/fst.h>
39 #include <fst/mutable-fst.h>
40 #include <fst/partition.h>
41 #include <fst/push.h>
42 #include <fst/queue.h>
43 #include <fst/reverse.h>
44 #include <fst/state-map.h>
45 
46 
47 namespace fst {
48 
// Comparator for creating a partition. States are ordered by, in turn:
// - hash of the final weight,
// - out degree (number of arcs),
// - per arc, (input label, destination block), visiting arcs in order.
// The criteria applied are selected with the kCompare* flags.
54 template <class A>
55 class StateComparator {
56  public:
57   typedef typename A::StateId StateId;
58   typedef typename A::Weight Weight;
59 
60   static const uint32 kCompareFinal     = 0x00000001;
61   static const uint32 kCompareOutDegree = 0x00000002;
62   static const uint32 kCompareArcs      = 0x00000004;
63   static const uint32 kCompareAll       = 0x00000007;
64 
65   StateComparator(const Fst<A>& fst,
66                   const Partition<typename A::StateId>& partition,
67                   uint32 flags = kCompareAll)
fst_(fst)68       : fst_(fst), partition_(partition), flags_(flags) {}
69 
70   // compare state x with state y based on sort criteria
operator()71   bool operator()(const StateId x, const StateId y) const {
72     // check for final state equivalence
73     if (flags_ & kCompareFinal) {
74       const size_t xfinal = fst_.Final(x).Hash();
75       const size_t yfinal = fst_.Final(y).Hash();
76       if      (xfinal < yfinal) return true;
77       else if (xfinal > yfinal) return false;
78     }
79 
80     if (flags_ & kCompareOutDegree) {
81       // check for # arcs
82       if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
83       if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
84 
85       if (flags_ & kCompareArcs) {
86         // # arcs are equal, check for arc match
87         for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
88              !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
89           const A& arc1 = aiter1.Value();
90           const A& arc2 = aiter2.Value();
91           if (arc1.ilabel < arc2.ilabel) return true;
92           if (arc1.ilabel > arc2.ilabel) return false;
93 
94           if (partition_.class_id(arc1.nextstate) <
95               partition_.class_id(arc2.nextstate)) return true;
96           if (partition_.class_id(arc1.nextstate) >
97               partition_.class_id(arc2.nextstate)) return false;
98         }
99       }
100     }
101 
102     return false;
103   }
104 
105  private:
106   const Fst<A>& fst_;
107   const Partition<typename A::StateId>& partition_;
108   const uint32 flags_;
109 };
110 
111 template <class A> const uint32 StateComparator<A>::kCompareFinal;
112 template <class A> const uint32 StateComparator<A>::kCompareOutDegree;
113 template <class A> const uint32 StateComparator<A>::kCompareArcs;
114 template <class A> const uint32 StateComparator<A>::kCompareAll;
115 
116 
// Computes equivalence classes for cyclic Fsts. For cyclic minimization
// we use the classic Hopcroft minimization algorithm, which is of
//
//   O(E log N),
//
// where E is the number of edges in the machine and N is the number of
// states.
//
// The following paper describes the original algorithm:
//  An N Log N algorithm for minimizing states in a finite automaton
//  by John Hopcroft, January 1971
//
128 template <class A, class Queue>
129 class CyclicMinimizer {
130  public:
131   typedef typename A::Label Label;
132   typedef typename A::StateId StateId;
133   typedef typename A::StateId ClassId;
134   typedef typename A::Weight Weight;
135   typedef ReverseArc<A> RevA;
136 
CyclicMinimizer(const ExpandedFst<A> & fst)137   CyclicMinimizer(const ExpandedFst<A>& fst) {
138     Initialize(fst);
139     Compute(fst);
140   }
141 
~CyclicMinimizer()142   ~CyclicMinimizer() {
143     delete aiter_queue_;
144   }
145 
partition()146   const Partition<StateId>& partition() const {
147     return P_;
148   }
149 
150   // helper classes
151  private:
152   typedef ArcIterator<Fst<RevA> > ArcIter;
153   class ArcIterCompare {
154    public:
ArcIterCompare(const Partition<StateId> & partition)155     ArcIterCompare(const Partition<StateId>& partition)
156         : partition_(partition) {}
157 
ArcIterCompare(const ArcIterCompare & comp)158     ArcIterCompare(const ArcIterCompare& comp)
159         : partition_(comp.partition_) {}
160 
161     // compare two iterators based on there input labels, and proto state
162     // (partition class Ids)
operator()163     bool operator()(const ArcIter* x, const ArcIter* y) const {
164       const RevA& xarc = x->Value();
165       const RevA& yarc = y->Value();
166       return (xarc.ilabel > yarc.ilabel);
167     }
168 
169    private:
170     const Partition<StateId>& partition_;
171   };
172 
173   typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
174   ArcIterQueue;
175 
176   // helper methods
177  private:
178   // prepartitions the space into equivalence classes with
179   //   same final weight
180   //   same # arcs per state
181   //   same outgoing arcs
PrePartition(const Fst<A> & fst)182   void PrePartition(const Fst<A>& fst) {
183     VLOG(5) << "PrePartition";
184 
185     typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
186     StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
187     EquivalenceMap equiv_map(comp);
188 
189     StateIterator<Fst<A> > siter(fst);
190     StateId class_id = P_.AddClass();
191     P_.Add(siter.Value(), class_id);
192     equiv_map[siter.Value()] = class_id;
193     L_.Enqueue(class_id);
194     for (siter.Next(); !siter.Done(); siter.Next()) {
195       StateId  s = siter.Value();
196       typename EquivalenceMap::const_iterator it = equiv_map.find(s);
197       if (it == equiv_map.end()) {
198         class_id = P_.AddClass();
199         P_.Add(s, class_id);
200         equiv_map[s] = class_id;
201         L_.Enqueue(class_id);
202       } else {
203         P_.Add(s, it->second);
204         equiv_map[s] = it->second;
205       }
206     }
207 
208     VLOG(5) << "Initial Partition: " << P_.num_classes();
209   }
210 
211   // - Create inverse transition Tr_ = rev(fst)
212   // - loop over states in fst and split on final, creating two blocks
213   //   in the partition corresponding to final, non-final
Initialize(const Fst<A> & fst)214   void Initialize(const Fst<A>& fst) {
215     // construct Tr
216     Reverse(fst, &Tr_);
217     ILabelCompare<RevA> ilabel_comp;
218     ArcSort(&Tr_, ilabel_comp);
219 
220     // initial split (F, S - F)
221     P_.Initialize(Tr_.NumStates() - 1);
222 
223     // prep partition
224     PrePartition(fst);
225 
226     // allocate arc iterator queue
227     ArcIterCompare comp(P_);
228     aiter_queue_ = new ArcIterQueue(comp);
229   }
230 
231   // partition all classes with destination C
Split(ClassId C)232   void Split(ClassId C) {
233     // Prep priority queue. Open arc iterator for each state in C, and
234     // insert into priority queue.
235     for (PartitionIterator<StateId> siter(P_, C);
236          !siter.Done(); siter.Next()) {
237       StateId s = siter.Value();
238       if (Tr_.NumArcs(s + 1))
239         aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1));
240     }
241 
242     // Now pop arc iterator from queue, split entering equivalence class
243     // re-insert updated iterator into queue.
244     Label prev_label = -1;
245     while (!aiter_queue_->empty()) {
246       ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top();
247       aiter_queue_->pop();
248       if (aiter->Done()) {
249         delete aiter;
250         continue;
251      }
252 
253       const RevA& arc = aiter->Value();
254       StateId from_state = aiter->Value().nextstate - 1;
255       Label   from_label = arc.ilabel;
256       if (prev_label != from_label)
257         P_.FinalizeSplit(&L_);
258 
259       StateId from_class = P_.class_id(from_state);
260       if (P_.class_size(from_class) > 1)
261         P_.SplitOn(from_state);
262 
263       prev_label = from_label;
264       aiter->Next();
265       if (aiter->Done())
266         delete aiter;
267       else
268         aiter_queue_->push(aiter);
269     }
270     P_.FinalizeSplit(&L_);
271   }
272 
273   // Main loop for hopcroft minimization.
Compute(const Fst<A> & fst)274   void Compute(const Fst<A>& fst) {
275     // process active classes (FIFO, or FILO)
276     while (!L_.Empty()) {
277       ClassId C = L_.Head();
278       L_.Dequeue();
279 
280       // split on C, all labels in C
281       Split(C);
282     }
283   }
284 
285   // helper data
286  private:
287   // Partioning of states into equivalence classes
288   Partition<StateId> P_;
289 
290   // L = set of active classes to be processed in partition P
291   Queue L_;
292 
293   // reverse transition function
294   VectorFst<RevA> Tr_;
295 
296   // Priority queue of open arc iterators for all states in the 'splitter'
297   // equivalence class
298   ArcIterQueue* aiter_queue_;
299 };
300 
301 
// Computes equivalence classes for acyclic Fsts. The implementation details
// of this algorithm are documented in the following paper:
//
// Minimization of acyclic deterministic automata in linear time
//  Dominique Revuz
//
// Complexity O(|E|)
//
310 template <class A>
311 class AcyclicMinimizer {
312  public:
313   typedef typename A::Label Label;
314   typedef typename A::StateId StateId;
315   typedef typename A::StateId ClassId;
316   typedef typename A::Weight Weight;
317 
AcyclicMinimizer(const ExpandedFst<A> & fst)318   AcyclicMinimizer(const ExpandedFst<A>& fst) {
319     Initialize(fst);
320     Refine(fst);
321   }
322 
partition()323   const Partition<StateId>& partition() {
324     return partition_;
325   }
326 
327   // helper classes
328  private:
329   // DFS visitor to compute the height (distance) to final state.
330   class HeightVisitor {
331    public:
HeightVisitor()332     HeightVisitor() : max_height_(0), num_states_(0) { }
333 
334     // invoked before dfs visit
InitVisit(const Fst<A> & fst)335     void InitVisit(const Fst<A>& fst) {}
336 
337     // invoked when state is discovered (2nd arg is DFS tree root)
InitState(StateId s,StateId root)338     bool InitState(StateId s, StateId root) {
339       // extend height array and initialize height (distance) to 0
340       for (size_t i = height_.size(); i <= s; ++i)
341         height_.push_back(-1);
342 
343       if (s >= num_states_) num_states_ = s + 1;
344       return true;
345     }
346 
347     // invoked when tree arc examined (to undiscoverted state)
TreeArc(StateId s,const A & arc)348     bool TreeArc(StateId s, const A& arc) {
349       return true;
350     }
351 
352     // invoked when back arc examined (to unfinished state)
BackArc(StateId s,const A & arc)353     bool BackArc(StateId s, const A& arc) {
354       return true;
355     }
356 
357     // invoked when forward or cross arc examined (to finished state)
ForwardOrCrossArc(StateId s,const A & arc)358     bool ForwardOrCrossArc(StateId s, const A& arc) {
359       if (height_[arc.nextstate] + 1 > height_[s])
360         height_[s] = height_[arc.nextstate] + 1;
361       return true;
362     }
363 
364     // invoked when state finished (parent is kNoStateId for tree root)
FinishState(StateId s,StateId parent,const A * parent_arc)365     void FinishState(StateId s, StateId parent, const A* parent_arc) {
366       if (height_[s] == -1) height_[s] = 0;
367       StateId h = height_[s] +  1;
368       if (parent >= 0) {
369         if (h > height_[parent]) height_[parent] = h;
370         if (h > max_height_)     max_height_ = h;
371       }
372     }
373 
374     // invoked after DFS visit
FinishVisit()375     void FinishVisit() {}
376 
max_height()377     size_t max_height() const { return max_height_; }
378 
height()379     const vector<StateId>& height() const { return height_; }
380 
num_states()381     const size_t num_states() const { return num_states_; }
382 
383    private:
384     vector<StateId> height_;
385     size_t max_height_;
386     size_t num_states_;
387   };
388 
389   // helper methods
390  private:
391   // cluster states according to height (distance to final state)
Initialize(const Fst<A> & fst)392   void Initialize(const Fst<A>& fst) {
393     // compute height (distance to final state)
394     HeightVisitor hvisitor;
395     DfsVisit(fst, &hvisitor);
396 
397     // create initial partition based on height
398     partition_.Initialize(hvisitor.num_states());
399     partition_.AllocateClasses(hvisitor.max_height() + 1);
400     const vector<StateId>& hstates = hvisitor.height();
401     for (size_t s = 0; s < hstates.size(); ++s)
402       partition_.Add(s, hstates[s]);
403   }
404 
405   // refine states based on arc sort (out degree, arc equivalence)
Refine(const Fst<A> & fst)406   void Refine(const Fst<A>& fst) {
407     typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
408     StateComparator<A> comp(fst, partition_);
409 
410     // start with tail (height = 0)
411     size_t height = partition_.num_classes();
412     for (size_t h = 0; h < height; ++h) {
413       EquivalenceMap equiv_classes(comp);
414 
415       // sort states within equivalence class
416       PartitionIterator<StateId> siter(partition_, h);
417       equiv_classes[siter.Value()] = h;
418       for (siter.Next(); !siter.Done(); siter.Next()) {
419         const StateId s = siter.Value();
420         typename EquivalenceMap::const_iterator it = equiv_classes.find(s);
421         if (it == equiv_classes.end())
422           equiv_classes[s] = partition_.AddClass();
423         else
424           equiv_classes[s] = it->second;
425       }
426 
427       // create refined partition
428       for (siter.Reset(); !siter.Done();) {
429         const StateId s = siter.Value();
430         const StateId old_class = partition_.class_id(s);
431         const StateId new_class = equiv_classes[s];
432 
433         // a move operation can invalidate the iterator, so
434         // we first update the iterator to the next element
435         // before we move the current element out of the list
436         siter.Next();
437         if (old_class != new_class)
438           partition_.Move(s, new_class);
439       }
440     }
441   }
442 
443  private:
444   Partition<StateId> partition_;
445 };
446 
447 
// Given a partition and a mutable fst, merge states of Fst inplace
// (i.e. destructively). Merging works by taking the first state in
// a class of the partition to be the representative state for the class.
// Each arc is then reconnected to this state. All states in the class
// are merged by adding their arcs to the representative state.
453 template <class A>
MergeStates(const Partition<typename A::StateId> & partition,MutableFst<A> * fst)454 void MergeStates(
455     const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
456   typedef typename A::StateId StateId;
457 
458   vector<StateId> state_map(partition.num_classes());
459   for (size_t i = 0; i < partition.num_classes(); ++i) {
460     PartitionIterator<StateId> siter(partition, i);
461     state_map[i] = siter.Value();  // first state in partition;
462   }
463 
464   // relabel destination states
465   for (size_t c = 0; c < partition.num_classes(); ++c) {
466     for (PartitionIterator<StateId> siter(partition, c);
467          !siter.Done(); siter.Next()) {
468       StateId s = siter.Value();
469       for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
470            !aiter.Done(); aiter.Next()) {
471         A arc = aiter.Value();
472         arc.nextstate = state_map[partition.class_id(arc.nextstate)];
473 
474         if (s == state_map[c])  // first state just set destination
475           aiter.SetValue(arc);
476         else
477           fst->AddArc(state_map[c], arc);
478       }
479     }
480   }
481   fst->SetStart(state_map[partition.class_id(fst->Start())]);
482 
483   Connect(fst);
484 }
485 
486 template <class A>
AcceptorMinimize(MutableFst<A> * fst)487 void AcceptorMinimize(MutableFst<A>* fst) {
488   typedef typename A::StateId StateId;
489   if (!(fst->Properties(kAcceptor | kUnweighted, true))) {
490     FSTERROR() << "FST is not an unweighted acceptor";
491     fst->SetProperties(kError, kError);
492     return;
493   }
494 
495   // connect fst before minimization, handles disconnected states
496   Connect(fst);
497   if (fst->NumStates() == 0) return;
498 
499   if (fst->Properties(kAcyclic, true)) {
500     // Acyclic minimization (revuz)
501     VLOG(2) << "Acyclic Minimization";
502     ArcSort(fst, ILabelCompare<A>());
503     AcyclicMinimizer<A> minimizer(*fst);
504     MergeStates(minimizer.partition(), fst);
505 
506   } else {
507     // Cyclic minimizaton (hopcroft)
508     VLOG(2) << "Cyclic Minimization";
509     CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
510     MergeStates(minimizer.partition(), fst);
511   }
512 
513   // Merge in appropriate semiring
514   ArcUniqueMapper<A> mapper(*fst);
515   StateMap(fst, mapper);
516 }
517 
518 
// In place minimization of deterministic weighted automata and transducers.
// For transducers, when the 'sfst' argument is not null, the algorithm
// produces a compact factorization of the minimal transducer.
//
// In the acyclic case, we use an algorithm from Dominique Revuz that
// is linear in the number of arcs (edges) in the machine.
//  Complexity = O(|E|)
//
// In the cyclic case, we use the classical Hopcroft minimization.
//  Complexity = O(|E| log |N|)
//
530 template <class A>
531 void Minimize(MutableFst<A>* fst,
532               MutableFst<A>* sfst = 0,
533               float delta = kDelta) {
534   uint64 props = fst->Properties(kAcceptor | kIDeterministic|
535                                  kWeighted | kUnweighted, true);
536   if (!(props & kIDeterministic)) {
537     FSTERROR() << "FST is not deterministic";
538     fst->SetProperties(kError, kError);
539     return;
540   }
541 
542   if (!(props & kAcceptor)) {  // weighted transducer
543     VectorFst< GallicArc<A, STRING_LEFT> > gfst;
544     ArcMap(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>());
545     fst->DeleteStates();
546     gfst.SetProperties(kAcceptor, kAcceptor);
547     Push(&gfst, REWEIGHT_TO_INITIAL, delta);
548     ArcMap(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >(delta));
549     EncodeMapper< GallicArc<A, STRING_LEFT> >
550       encoder(kEncodeLabels | kEncodeWeights, ENCODE);
551     Encode(&gfst, &encoder);
552     AcceptorMinimize(&gfst);
553     Decode(&gfst, encoder);
554 
555     if (sfst == 0) {
556       FactorWeightFst< GallicArc<A, STRING_LEFT>,
557         GallicFactor<typename A::Label,
558         typename A::Weight, STRING_LEFT> > fwfst(gfst);
559       SymbolTable *osyms = fst->OutputSymbols() ?
560           fst->OutputSymbols()->Copy() : 0;
561       ArcMap(fwfst, fst, FromGallicMapper<A, STRING_LEFT>());
562       fst->SetOutputSymbols(osyms);
563       delete osyms;
564     } else {
565       sfst->SetOutputSymbols(fst->OutputSymbols());
566       GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
567       ArcMap(gfst, fst, &mapper);
568       fst->SetOutputSymbols(sfst->InputSymbols());
569     }
570   } else if (props & kWeighted) {  // weighted acceptor
571     Push(fst, REWEIGHT_TO_INITIAL, delta);
572     ArcMap(fst, QuantizeMapper<A>(delta));
573     EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
574     Encode(fst, &encoder);
575     AcceptorMinimize(fst);
576     Decode(fst, encoder);
577   } else {  // unweighted acceptor
578     AcceptorMinimize(fst);
579   }
580 }
581 
582 }  // namespace fst
583 
584 #endif  // FST_LIB_MINIMIZE_H__
585