1 // minimize.h
2 // minimize.h
3
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Copyright 2005-2010 Google, Inc.
17 // Author: johans@google.com (Johan Schalkwyk)
18 //
19 // \file Functions and classes to minimize a finite state acceptor
20 //
21
22 #ifndef FST_LIB_MINIMIZE_H__
23 #define FST_LIB_MINIMIZE_H__
24
25 #include <cmath>
26
27 #include <algorithm>
28 #include <map>
29 #include <queue>
30 #include <vector>
31 using std::vector;
32
33 #include <fst/arcsort.h>
34 #include <fst/connect.h>
35 #include <fst/dfs-visit.h>
36 #include <fst/encode.h>
37 #include <fst/factor-weight.h>
38 #include <fst/fst.h>
39 #include <fst/mutable-fst.h>
40 #include <fst/partition.h>
41 #include <fst/push.h>
42 #include <fst/queue.h>
43 #include <fst/reverse.h>
44 #include <fst/state-map.h>
45
46
47 namespace fst {
48
49 // comparator for creating partition based on sorting on
50 // - states
51 // - final weight
52 // - out degree,
53 // - (input label, output label, weight, destination_block)
54 template <class A>
55 class StateComparator {
56 public:
57 typedef typename A::StateId StateId;
58 typedef typename A::Weight Weight;
59
60 static const uint32 kCompareFinal = 0x00000001;
61 static const uint32 kCompareOutDegree = 0x00000002;
62 static const uint32 kCompareArcs = 0x00000004;
63 static const uint32 kCompareAll = 0x00000007;
64
65 StateComparator(const Fst<A>& fst,
66 const Partition<typename A::StateId>& partition,
67 uint32 flags = kCompareAll)
fst_(fst)68 : fst_(fst), partition_(partition), flags_(flags) {}
69
70 // compare state x with state y based on sort criteria
operator()71 bool operator()(const StateId x, const StateId y) const {
72 // check for final state equivalence
73 if (flags_ & kCompareFinal) {
74 const size_t xfinal = fst_.Final(x).Hash();
75 const size_t yfinal = fst_.Final(y).Hash();
76 if (xfinal < yfinal) return true;
77 else if (xfinal > yfinal) return false;
78 }
79
80 if (flags_ & kCompareOutDegree) {
81 // check for # arcs
82 if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
83 if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
84
85 if (flags_ & kCompareArcs) {
86 // # arcs are equal, check for arc match
87 for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
88 !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
89 const A& arc1 = aiter1.Value();
90 const A& arc2 = aiter2.Value();
91 if (arc1.ilabel < arc2.ilabel) return true;
92 if (arc1.ilabel > arc2.ilabel) return false;
93
94 if (partition_.class_id(arc1.nextstate) <
95 partition_.class_id(arc2.nextstate)) return true;
96 if (partition_.class_id(arc1.nextstate) >
97 partition_.class_id(arc2.nextstate)) return false;
98 }
99 }
100 }
101
102 return false;
103 }
104
105 private:
106 const Fst<A>& fst_;
107 const Partition<typename A::StateId>& partition_;
108 const uint32 flags_;
109 };
110
111 template <class A> const uint32 StateComparator<A>::kCompareFinal;
112 template <class A> const uint32 StateComparator<A>::kCompareOutDegree;
113 template <class A> const uint32 StateComparator<A>::kCompareArcs;
114 template <class A> const uint32 StateComparator<A>::kCompareAll;
115
116
117 // Computes equivalence classes for cyclic Fsts. For cyclic minimization
118 // we use the classic HopCroft minimization algorithm, which is of
119 //
120 // O(E)log(N),
121 //
122 // where E is the number of edges in the machine and N is number of states.
123 //
124 // The following paper describes the original algorithm
125 // An N Log N algorithm for minimizing states in a finite automaton
126 // by John HopCroft, January 1971
127 //
128 template <class A, class Queue>
129 class CyclicMinimizer {
130 public:
131 typedef typename A::Label Label;
132 typedef typename A::StateId StateId;
133 typedef typename A::StateId ClassId;
134 typedef typename A::Weight Weight;
135 typedef ReverseArc<A> RevA;
136
CyclicMinimizer(const ExpandedFst<A> & fst)137 CyclicMinimizer(const ExpandedFst<A>& fst) {
138 Initialize(fst);
139 Compute(fst);
140 }
141
~CyclicMinimizer()142 ~CyclicMinimizer() {
143 delete aiter_queue_;
144 }
145
partition()146 const Partition<StateId>& partition() const {
147 return P_;
148 }
149
150 // helper classes
151 private:
152 typedef ArcIterator<Fst<RevA> > ArcIter;
153 class ArcIterCompare {
154 public:
ArcIterCompare(const Partition<StateId> & partition)155 ArcIterCompare(const Partition<StateId>& partition)
156 : partition_(partition) {}
157
ArcIterCompare(const ArcIterCompare & comp)158 ArcIterCompare(const ArcIterCompare& comp)
159 : partition_(comp.partition_) {}
160
161 // compare two iterators based on there input labels, and proto state
162 // (partition class Ids)
operator()163 bool operator()(const ArcIter* x, const ArcIter* y) const {
164 const RevA& xarc = x->Value();
165 const RevA& yarc = y->Value();
166 return (xarc.ilabel > yarc.ilabel);
167 }
168
169 private:
170 const Partition<StateId>& partition_;
171 };
172
173 typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
174 ArcIterQueue;
175
176 // helper methods
177 private:
178 // prepartitions the space into equivalence classes with
179 // same final weight
180 // same # arcs per state
181 // same outgoing arcs
PrePartition(const Fst<A> & fst)182 void PrePartition(const Fst<A>& fst) {
183 VLOG(5) << "PrePartition";
184
185 typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
186 StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
187 EquivalenceMap equiv_map(comp);
188
189 StateIterator<Fst<A> > siter(fst);
190 StateId class_id = P_.AddClass();
191 P_.Add(siter.Value(), class_id);
192 equiv_map[siter.Value()] = class_id;
193 L_.Enqueue(class_id);
194 for (siter.Next(); !siter.Done(); siter.Next()) {
195 StateId s = siter.Value();
196 typename EquivalenceMap::const_iterator it = equiv_map.find(s);
197 if (it == equiv_map.end()) {
198 class_id = P_.AddClass();
199 P_.Add(s, class_id);
200 equiv_map[s] = class_id;
201 L_.Enqueue(class_id);
202 } else {
203 P_.Add(s, it->second);
204 equiv_map[s] = it->second;
205 }
206 }
207
208 VLOG(5) << "Initial Partition: " << P_.num_classes();
209 }
210
211 // - Create inverse transition Tr_ = rev(fst)
212 // - loop over states in fst and split on final, creating two blocks
213 // in the partition corresponding to final, non-final
Initialize(const Fst<A> & fst)214 void Initialize(const Fst<A>& fst) {
215 // construct Tr
216 Reverse(fst, &Tr_);
217 ILabelCompare<RevA> ilabel_comp;
218 ArcSort(&Tr_, ilabel_comp);
219
220 // initial split (F, S - F)
221 P_.Initialize(Tr_.NumStates() - 1);
222
223 // prep partition
224 PrePartition(fst);
225
226 // allocate arc iterator queue
227 ArcIterCompare comp(P_);
228 aiter_queue_ = new ArcIterQueue(comp);
229 }
230
231 // partition all classes with destination C
Split(ClassId C)232 void Split(ClassId C) {
233 // Prep priority queue. Open arc iterator for each state in C, and
234 // insert into priority queue.
235 for (PartitionIterator<StateId> siter(P_, C);
236 !siter.Done(); siter.Next()) {
237 StateId s = siter.Value();
238 if (Tr_.NumArcs(s + 1))
239 aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1));
240 }
241
242 // Now pop arc iterator from queue, split entering equivalence class
243 // re-insert updated iterator into queue.
244 Label prev_label = -1;
245 while (!aiter_queue_->empty()) {
246 ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top();
247 aiter_queue_->pop();
248 if (aiter->Done()) {
249 delete aiter;
250 continue;
251 }
252
253 const RevA& arc = aiter->Value();
254 StateId from_state = aiter->Value().nextstate - 1;
255 Label from_label = arc.ilabel;
256 if (prev_label != from_label)
257 P_.FinalizeSplit(&L_);
258
259 StateId from_class = P_.class_id(from_state);
260 if (P_.class_size(from_class) > 1)
261 P_.SplitOn(from_state);
262
263 prev_label = from_label;
264 aiter->Next();
265 if (aiter->Done())
266 delete aiter;
267 else
268 aiter_queue_->push(aiter);
269 }
270 P_.FinalizeSplit(&L_);
271 }
272
273 // Main loop for hopcroft minimization.
Compute(const Fst<A> & fst)274 void Compute(const Fst<A>& fst) {
275 // process active classes (FIFO, or FILO)
276 while (!L_.Empty()) {
277 ClassId C = L_.Head();
278 L_.Dequeue();
279
280 // split on C, all labels in C
281 Split(C);
282 }
283 }
284
285 // helper data
286 private:
287 // Partioning of states into equivalence classes
288 Partition<StateId> P_;
289
290 // L = set of active classes to be processed in partition P
291 Queue L_;
292
293 // reverse transition function
294 VectorFst<RevA> Tr_;
295
296 // Priority queue of open arc iterators for all states in the 'splitter'
297 // equivalence class
298 ArcIterQueue* aiter_queue_;
299 };
300
301
302 // Computes equivalence classes for acyclic Fsts. The implementation details
303 // for this algorithms is documented by the following paper.
304 //
305 // Minimization of acyclic deterministic automata in linear time
306 // Dominque Revuz
307 //
308 // Complexity O(|E|)
309 //
310 template <class A>
311 class AcyclicMinimizer {
312 public:
313 typedef typename A::Label Label;
314 typedef typename A::StateId StateId;
315 typedef typename A::StateId ClassId;
316 typedef typename A::Weight Weight;
317
AcyclicMinimizer(const ExpandedFst<A> & fst)318 AcyclicMinimizer(const ExpandedFst<A>& fst) {
319 Initialize(fst);
320 Refine(fst);
321 }
322
partition()323 const Partition<StateId>& partition() {
324 return partition_;
325 }
326
327 // helper classes
328 private:
329 // DFS visitor to compute the height (distance) to final state.
330 class HeightVisitor {
331 public:
HeightVisitor()332 HeightVisitor() : max_height_(0), num_states_(0) { }
333
334 // invoked before dfs visit
InitVisit(const Fst<A> & fst)335 void InitVisit(const Fst<A>& fst) {}
336
337 // invoked when state is discovered (2nd arg is DFS tree root)
InitState(StateId s,StateId root)338 bool InitState(StateId s, StateId root) {
339 // extend height array and initialize height (distance) to 0
340 for (size_t i = height_.size(); i <= s; ++i)
341 height_.push_back(-1);
342
343 if (s >= num_states_) num_states_ = s + 1;
344 return true;
345 }
346
347 // invoked when tree arc examined (to undiscoverted state)
TreeArc(StateId s,const A & arc)348 bool TreeArc(StateId s, const A& arc) {
349 return true;
350 }
351
352 // invoked when back arc examined (to unfinished state)
BackArc(StateId s,const A & arc)353 bool BackArc(StateId s, const A& arc) {
354 return true;
355 }
356
357 // invoked when forward or cross arc examined (to finished state)
ForwardOrCrossArc(StateId s,const A & arc)358 bool ForwardOrCrossArc(StateId s, const A& arc) {
359 if (height_[arc.nextstate] + 1 > height_[s])
360 height_[s] = height_[arc.nextstate] + 1;
361 return true;
362 }
363
364 // invoked when state finished (parent is kNoStateId for tree root)
FinishState(StateId s,StateId parent,const A * parent_arc)365 void FinishState(StateId s, StateId parent, const A* parent_arc) {
366 if (height_[s] == -1) height_[s] = 0;
367 StateId h = height_[s] + 1;
368 if (parent >= 0) {
369 if (h > height_[parent]) height_[parent] = h;
370 if (h > max_height_) max_height_ = h;
371 }
372 }
373
374 // invoked after DFS visit
FinishVisit()375 void FinishVisit() {}
376
max_height()377 size_t max_height() const { return max_height_; }
378
height()379 const vector<StateId>& height() const { return height_; }
380
num_states()381 const size_t num_states() const { return num_states_; }
382
383 private:
384 vector<StateId> height_;
385 size_t max_height_;
386 size_t num_states_;
387 };
388
389 // helper methods
390 private:
391 // cluster states according to height (distance to final state)
Initialize(const Fst<A> & fst)392 void Initialize(const Fst<A>& fst) {
393 // compute height (distance to final state)
394 HeightVisitor hvisitor;
395 DfsVisit(fst, &hvisitor);
396
397 // create initial partition based on height
398 partition_.Initialize(hvisitor.num_states());
399 partition_.AllocateClasses(hvisitor.max_height() + 1);
400 const vector<StateId>& hstates = hvisitor.height();
401 for (size_t s = 0; s < hstates.size(); ++s)
402 partition_.Add(s, hstates[s]);
403 }
404
405 // refine states based on arc sort (out degree, arc equivalence)
Refine(const Fst<A> & fst)406 void Refine(const Fst<A>& fst) {
407 typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
408 StateComparator<A> comp(fst, partition_);
409
410 // start with tail (height = 0)
411 size_t height = partition_.num_classes();
412 for (size_t h = 0; h < height; ++h) {
413 EquivalenceMap equiv_classes(comp);
414
415 // sort states within equivalence class
416 PartitionIterator<StateId> siter(partition_, h);
417 equiv_classes[siter.Value()] = h;
418 for (siter.Next(); !siter.Done(); siter.Next()) {
419 const StateId s = siter.Value();
420 typename EquivalenceMap::const_iterator it = equiv_classes.find(s);
421 if (it == equiv_classes.end())
422 equiv_classes[s] = partition_.AddClass();
423 else
424 equiv_classes[s] = it->second;
425 }
426
427 // create refined partition
428 for (siter.Reset(); !siter.Done();) {
429 const StateId s = siter.Value();
430 const StateId old_class = partition_.class_id(s);
431 const StateId new_class = equiv_classes[s];
432
433 // a move operation can invalidate the iterator, so
434 // we first update the iterator to the next element
435 // before we move the current element out of the list
436 siter.Next();
437 if (old_class != new_class)
438 partition_.Move(s, new_class);
439 }
440 }
441 }
442
443 private:
444 Partition<StateId> partition_;
445 };
446
447
448 // Given a partition and a mutable fst, merge states of Fst inplace
449 // (i.e. destructively). Merging works by taking the first state in
450 // a class of the partition to be the representative state for the class.
451 // Each arc is then reconnected to this state. All states in the class
452 // are merged by adding there arcs to the representative state.
453 template <class A>
MergeStates(const Partition<typename A::StateId> & partition,MutableFst<A> * fst)454 void MergeStates(
455 const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
456 typedef typename A::StateId StateId;
457
458 vector<StateId> state_map(partition.num_classes());
459 for (size_t i = 0; i < partition.num_classes(); ++i) {
460 PartitionIterator<StateId> siter(partition, i);
461 state_map[i] = siter.Value(); // first state in partition;
462 }
463
464 // relabel destination states
465 for (size_t c = 0; c < partition.num_classes(); ++c) {
466 for (PartitionIterator<StateId> siter(partition, c);
467 !siter.Done(); siter.Next()) {
468 StateId s = siter.Value();
469 for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
470 !aiter.Done(); aiter.Next()) {
471 A arc = aiter.Value();
472 arc.nextstate = state_map[partition.class_id(arc.nextstate)];
473
474 if (s == state_map[c]) // first state just set destination
475 aiter.SetValue(arc);
476 else
477 fst->AddArc(state_map[c], arc);
478 }
479 }
480 }
481 fst->SetStart(state_map[partition.class_id(fst->Start())]);
482
483 Connect(fst);
484 }
485
486 template <class A>
AcceptorMinimize(MutableFst<A> * fst)487 void AcceptorMinimize(MutableFst<A>* fst) {
488 typedef typename A::StateId StateId;
489 if (!(fst->Properties(kAcceptor | kUnweighted, true))) {
490 FSTERROR() << "FST is not an unweighted acceptor";
491 fst->SetProperties(kError, kError);
492 return;
493 }
494
495 // connect fst before minimization, handles disconnected states
496 Connect(fst);
497 if (fst->NumStates() == 0) return;
498
499 if (fst->Properties(kAcyclic, true)) {
500 // Acyclic minimization (revuz)
501 VLOG(2) << "Acyclic Minimization";
502 ArcSort(fst, ILabelCompare<A>());
503 AcyclicMinimizer<A> minimizer(*fst);
504 MergeStates(minimizer.partition(), fst);
505
506 } else {
507 // Cyclic minimizaton (hopcroft)
508 VLOG(2) << "Cyclic Minimization";
509 CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
510 MergeStates(minimizer.partition(), fst);
511 }
512
513 // Merge in appropriate semiring
514 ArcUniqueMapper<A> mapper(*fst);
515 StateMap(fst, mapper);
516 }
517
518
519 // In place minimization of deterministic weighted automata and transducers.
520 // For transducers, then the 'sfst' argument is not null, the algorithm
521 // produces a compact factorization of the minimal transducer.
522 //
523 // In the acyclic case, we use an algorithm from Dominique Revuz that
524 // is linear in the number of arcs (edges) in the machine.
525 // Complexity = O(E)
526 //
527 // In the cyclic case, we use the classical hopcroft minimization.
528 // Complexity = O(|E|log(|N|)
529 //
530 template <class A>
531 void Minimize(MutableFst<A>* fst,
532 MutableFst<A>* sfst = 0,
533 float delta = kDelta) {
534 uint64 props = fst->Properties(kAcceptor | kIDeterministic|
535 kWeighted | kUnweighted, true);
536 if (!(props & kIDeterministic)) {
537 FSTERROR() << "FST is not deterministic";
538 fst->SetProperties(kError, kError);
539 return;
540 }
541
542 if (!(props & kAcceptor)) { // weighted transducer
543 VectorFst< GallicArc<A, STRING_LEFT> > gfst;
544 ArcMap(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>());
545 fst->DeleteStates();
546 gfst.SetProperties(kAcceptor, kAcceptor);
547 Push(&gfst, REWEIGHT_TO_INITIAL, delta);
548 ArcMap(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >(delta));
549 EncodeMapper< GallicArc<A, STRING_LEFT> >
550 encoder(kEncodeLabels | kEncodeWeights, ENCODE);
551 Encode(&gfst, &encoder);
552 AcceptorMinimize(&gfst);
553 Decode(&gfst, encoder);
554
555 if (sfst == 0) {
556 FactorWeightFst< GallicArc<A, STRING_LEFT>,
557 GallicFactor<typename A::Label,
558 typename A::Weight, STRING_LEFT> > fwfst(gfst);
559 SymbolTable *osyms = fst->OutputSymbols() ?
560 fst->OutputSymbols()->Copy() : 0;
561 ArcMap(fwfst, fst, FromGallicMapper<A, STRING_LEFT>());
562 fst->SetOutputSymbols(osyms);
563 delete osyms;
564 } else {
565 sfst->SetOutputSymbols(fst->OutputSymbols());
566 GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
567 ArcMap(gfst, fst, &mapper);
568 fst->SetOutputSymbols(sfst->InputSymbols());
569 }
570 } else if (props & kWeighted) { // weighted acceptor
571 Push(fst, REWEIGHT_TO_INITIAL, delta);
572 ArcMap(fst, QuantizeMapper<A>(delta));
573 EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
574 Encode(fst, &encoder);
575 AcceptorMinimize(fst);
576 Decode(fst, encoder);
577 } else { // unweighted acceptor
578 AcceptorMinimize(fst);
579 }
580 }
581
582 } // namespace fst
583
584 #endif // FST_LIB_MINIMIZE_H__
585