1 // randgen.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Classes and functions to generate random paths through an FST.
20 
21 #ifndef FST_LIB_RANDGEN_H__
22 #define FST_LIB_RANDGEN_H__
23 
24 #include <cmath>
25 #include <cstdlib>
26 #include <ctime>
27 #include <map>
28 
29 #include <fst/accumulator.h>
30 #include <fst/cache.h>
31 #include <fst/dfs-visit.h>
32 #include <fst/mutable-fst.h>
33 
34 namespace fst {
35 
36 //
37 // ARC SELECTORS - these function objects are used to select a random
38 // transition to take from an FST's state. They should return a number
39 // N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th
40 // transition is selected. If N == NumArcs(), then the final weight at
41 // that state is selected (i.e., the 'super-final' transition is selected).
42 // It can be assumed these will not be called unless either there
43 // are transitions leaving the state and/or the state is final.
44 //
45 
46 // Randomly selects a transition using the uniform distribution.
47 template <class A>
48 struct UniformArcSelector {
49   typedef typename A::StateId StateId;
50   typedef typename A::Weight Weight;
51 
52   UniformArcSelector(int seed = time(0)) { srand(seed); }
53 
operatorUniformArcSelector54   size_t operator()(const Fst<A> &fst, StateId s) const {
55     double r = rand()/(RAND_MAX + 1.0);
56     size_t n = fst.NumArcs(s);
57     if (fst.Final(s) != Weight::Zero())
58       ++n;
59     return static_cast<size_t>(r * n);
60   }
61 };
62 
63 
64 // Randomly selects a transition w.r.t. the weights treated as negative
65 // log probabilities after normalizing for the total weight leaving
66 // the state. Weight::zero transitions are disregarded.
67 // Assumes Weight::Value() accesses the floating point
68 // representation of the weight.
69 template <class A>
70 class LogProbArcSelector {
71  public:
72   typedef typename A::StateId StateId;
73   typedef typename A::Weight Weight;
74 
75   LogProbArcSelector(int seed = time(0)) { srand(seed); }
76 
operator()77   size_t operator()(const Fst<A> &fst, StateId s) const {
78     // Find total weight leaving state
79     double sum = 0.0;
80     for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
81          aiter.Next()) {
82       const A &arc = aiter.Value();
83       sum += exp(-to_log_weight_(arc.weight).Value());
84     }
85     sum += exp(-to_log_weight_(fst.Final(s)).Value());
86 
87     double r = rand()/(RAND_MAX + 1.0);
88     double p = 0.0;
89     int n = 0;
90     for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
91          aiter.Next(), ++n) {
92       const A &arc = aiter.Value();
93       p += exp(-to_log_weight_(arc.weight).Value());
94       if (p > r * sum) return n;
95     }
96     return n;
97   }
98 
99  private:
100   WeightConvert<Weight, Log64Weight> to_log_weight_;
101 };
102 
103 // Convenience definitions
104 typedef LogProbArcSelector<StdArc> StdArcSelector;
105 typedef LogProbArcSelector<LogArc> LogArcSelector;
106 
107 
108 // Same as LogProbArcSelector but use CacheLogAccumulator to cache
109 // the cummulative weight computations.
110 template <class A>
111 class FastLogProbArcSelector : public LogProbArcSelector<A> {
112  public:
113   typedef typename A::StateId StateId;
114   typedef typename A::Weight Weight;
115   using LogProbArcSelector<A>::operator();
116 
117   FastLogProbArcSelector(int seed = time(0))
118       : LogProbArcSelector<A>(seed),
119         seed_(seed) {}
120 
operator()121   size_t operator()(const Fst<A> &fst, StateId s,
122                     CacheLogAccumulator<A> *accumulator) const {
123     accumulator->SetState(s);
124     ArcIterator< Fst<A> > aiter(fst, s);
125     // Find total weight leaving state
126     double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0,
127                                                  fst.NumArcs(s))).Value();
128     double r = -log(rand()/(RAND_MAX + 1.0));
129     return accumulator->LowerBound(r + sum, &aiter);
130   }
131 
Seed()132   int Seed() const { return seed_; }
133  private:
134   int seed_;
135   WeightConvert<Weight, Log64Weight> to_log_weight_;
136 };
137 
138 // Random path state info maintained by RandGenFst and passed to samplers.
139 template <typename A>
140 struct RandState {
141   typedef typename A::StateId StateId;
142 
143   StateId state_id;              // current input FST state
144   size_t nsamples;               // # of samples to be sampled at this state
145   size_t length;                 // length of path to this random state
146   size_t select;                 // previous sample arc selection
147   const RandState<A> *parent;    // previous random state on this path
148 
RandStateRandState149   RandState(StateId s, size_t n, size_t l, size_t k, const RandState<A> *p)
150       : state_id(s), nsamples(n), length(l), select(k), parent(p) {}
151 
RandStateRandState152   RandState()
153       : state_id(kNoStateId), nsamples(0), length(0), select(0), parent(0) {}
154 };
155 
156 // This class, given an arc selector, samples, with raplacement,
157 // multiple random transitions from an FST's state. This is a generic
158 // version with a straight-forward use of the arc selector.
159 // Specializations may be defined for arc selectors for greater
160 // efficiency or special behavior.
161 template <class A, class S>
162 class ArcSampler {
163  public:
164   typedef typename A::StateId StateId;
165   typedef typename A::Weight Weight;
166 
167   // The 'max_length' may be interpreted (including ignored) by a
168   // sampler as it chooses. This generic version interprets this literally.
169   ArcSampler(const Fst<A> &fst, const S &arc_selector,
170              int max_length = INT_MAX)
fst_(fst)171       : fst_(fst),
172         arc_selector_(arc_selector),
173         max_length_(max_length) {}
174 
175   // Allow updating Fst argument; pass only if changed.
176   ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
177       : fst_(fst ? *fst : sampler.fst_),
178         arc_selector_(sampler.arc_selector_),
179         max_length_(sampler.max_length_) {
180     Reset();
181   }
182 
183   // Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is
184   // the length of the path to 'rstate'. Returns true if samples were
185   // collected.  No samples may be collected if either there are no (including
186   // 'super-final') transitions leaving that state or if the
187   // 'max_length' has been deemed reached. Use the iterator members to
188   // read the samples. The samples will be in their original order.
Sample(const RandState<A> & rstate)189   bool Sample(const RandState<A> &rstate) {
190     sample_map_.clear();
191     if ((fst_.NumArcs(rstate.state_id) == 0 &&
192          fst_.Final(rstate.state_id) == Weight::Zero()) ||
193         rstate.length == max_length_) {
194       Reset();
195       return false;
196     }
197 
198     for (size_t i = 0; i < rstate.nsamples; ++i)
199       ++sample_map_[arc_selector_(fst_, rstate.state_id)];
200     Reset();
201     return true;
202   }
203 
204   // More samples?
Done()205   bool Done() const { return sample_iter_ == sample_map_.end(); }
206 
207   // Gets the next sample.
Next()208   void Next() { ++sample_iter_; }
209 
210   // Returns a pair (N, K) where 0 <= N <= NumArcs(s) and 0 < K <= nsamples.
211   // If N < NumArcs(s), then the N-th transition is specified.
212   // If N == NumArcs(s), then the final weight at that state is
213   // specified (i.e., the 'super-final' transition is specified).
214   // For the specified transition, K repetitions have been sampled.
Value()215   pair<size_t, size_t> Value() const { return *sample_iter_; }
216 
Reset()217   void Reset() { sample_iter_ = sample_map_.begin(); }
218 
Error()219   bool Error() const { return false; }
220 
221  private:
222   const Fst<A> &fst_;
223   const S &arc_selector_;
224   int max_length_;
225 
226   // Stores (N, K) as described for Value().
227   map<size_t, size_t> sample_map_;
228   map<size_t, size_t>::const_iterator sample_iter_;
229 
230   // disallow
231   ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
232 };
233 
234 
235 // Specialization for FastLogProbArcSelector.
236 template <class A>
237 class ArcSampler<A, FastLogProbArcSelector<A> > {
238  public:
239   typedef FastLogProbArcSelector<A> S;
240   typedef typename A::StateId StateId;
241   typedef typename A::Weight Weight;
242   typedef CacheLogAccumulator<A> C;
243 
244   ArcSampler(const Fst<A> &fst, const S &arc_selector, int max_length = INT_MAX)
fst_(fst)245       : fst_(fst),
246         arc_selector_(arc_selector),
247         max_length_(max_length),
248         accumulator_(new C()) {
249     accumulator_->Init(fst);
250   }
251 
252   ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
253       : fst_(fst ? *fst : sampler.fst_),
254         arc_selector_(sampler.arc_selector_),
255         max_length_(sampler.max_length_) {
256     if (fst) {
257       accumulator_ = new C();
258       accumulator_->Init(*fst);
259     } else {  // shallow copy
260       accumulator_ = new C(*sampler.accumulator_);
261     }
262   }
263 
~ArcSampler()264   ~ArcSampler() {
265     delete accumulator_;
266   }
267 
Sample(const RandState<A> & rstate)268   bool Sample(const RandState<A> &rstate) {
269     sample_map_.clear();
270     if ((fst_.NumArcs(rstate.state_id) == 0 &&
271          fst_.Final(rstate.state_id) == Weight::Zero()) ||
272         rstate.length == max_length_) {
273       Reset();
274       return false;
275     }
276 
277     for (size_t i = 0; i < rstate.nsamples; ++i)
278       ++sample_map_[arc_selector_(fst_, rstate.state_id, accumulator_)];
279     Reset();
280     return true;
281   }
282 
Done()283   bool Done() const { return sample_iter_ == sample_map_.end(); }
Next()284   void Next() { ++sample_iter_; }
Value()285   pair<size_t, size_t> Value() const { return *sample_iter_; }
Reset()286   void Reset() { sample_iter_ = sample_map_.begin(); }
287 
Error()288   bool Error() const { return accumulator_->Error(); }
289 
290  private:
291   const Fst<A> &fst_;
292   const S &arc_selector_;
293   int max_length_;
294 
295   // Stores (N, K) as described for Value().
296   map<size_t, size_t> sample_map_;
297   map<size_t, size_t>::const_iterator sample_iter_;
298   C *accumulator_;
299 
300   // disallow
301   ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
302 };
303 
304 
305 // Options for random path generation with RandGenFst. The template argument
306 // is an arc sampler, typically class 'ArcSampler' above.  Ownership of
307 // the sampler is taken by RandGenFst.
308 template <class S>
309 struct RandGenFstOptions : public CacheOptions {
310   S *arc_sampler;            // How to sample transitions at a state
311   size_t npath;              // # of paths to generate
312   bool weighted;             // Output tree weighted by path count; o.w.
313                              // output unweighted DAG
314   bool remove_total_weight;  // Remove total weight when output is weighted.
315 
316   RandGenFstOptions(const CacheOptions &copts, S *samp,
317                     size_t n = 1, bool w = true, bool rw = false)
CacheOptionsRandGenFstOptions318       : CacheOptions(copts),
319         arc_sampler(samp),
320         npath(n),
321         weighted(w),
322         remove_total_weight(rw) {}
323 };
324 
325 
326 // Implementation of RandGenFst.
327 template <class A, class B, class S>
328 class RandGenFstImpl : public CacheImpl<B> {
329  public:
330   using FstImpl<B>::SetType;
331   using FstImpl<B>::SetProperties;
332   using FstImpl<B>::SetInputSymbols;
333   using FstImpl<B>::SetOutputSymbols;
334 
335   using CacheBaseImpl< CacheState<B> >::AddArc;
336   using CacheBaseImpl< CacheState<B> >::HasArcs;
337   using CacheBaseImpl< CacheState<B> >::HasFinal;
338   using CacheBaseImpl< CacheState<B> >::HasStart;
339   using CacheBaseImpl< CacheState<B> >::SetArcs;
340   using CacheBaseImpl< CacheState<B> >::SetFinal;
341   using CacheBaseImpl< CacheState<B> >::SetStart;
342 
343   typedef B Arc;
344   typedef typename A::Label Label;
345   typedef typename A::Weight Weight;
346   typedef typename A::StateId StateId;
347 
RandGenFstImpl(const Fst<A> & fst,const RandGenFstOptions<S> & opts)348   RandGenFstImpl(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
349       : CacheImpl<B>(opts),
350         fst_(fst.Copy()),
351         arc_sampler_(opts.arc_sampler),
352         npath_(opts.npath),
353         weighted_(opts.weighted),
354         remove_total_weight_(opts.remove_total_weight),
355         superfinal_(kNoLabel) {
356     SetType("randgen");
357 
358     uint64 props = fst.Properties(kFstProperties, false);
359     SetProperties(RandGenProperties(props, weighted_), kCopyProperties);
360 
361     SetInputSymbols(fst.InputSymbols());
362     SetOutputSymbols(fst.OutputSymbols());
363   }
364 
RandGenFstImpl(const RandGenFstImpl & impl)365   RandGenFstImpl(const RandGenFstImpl &impl)
366     : CacheImpl<B>(impl),
367       fst_(impl.fst_->Copy(true)),
368       arc_sampler_(new S(*impl.arc_sampler_, fst_)),
369       npath_(impl.npath_),
370       weighted_(impl.weighted_),
371       superfinal_(kNoLabel) {
372     SetType("randgen");
373     SetProperties(impl.Properties(), kCopyProperties);
374     SetInputSymbols(impl.InputSymbols());
375     SetOutputSymbols(impl.OutputSymbols());
376   }
377 
~RandGenFstImpl()378   ~RandGenFstImpl() {
379     for (int i = 0; i < state_table_.size(); ++i)
380       delete state_table_[i];
381     delete fst_;
382     delete arc_sampler_;
383   }
384 
Start()385   StateId Start() {
386     if (!HasStart()) {
387       StateId s = fst_->Start();
388       if (s == kNoStateId)
389         return kNoStateId;
390       StateId start = state_table_.size();
391       SetStart(start);
392       RandState<A> *rstate = new RandState<A>(s, npath_, 0, 0, 0);
393       state_table_.push_back(rstate);
394     }
395     return CacheImpl<B>::Start();
396   }
397 
Final(StateId s)398   Weight Final(StateId s) {
399     if (!HasFinal(s)) {
400       Expand(s);
401     }
402     return CacheImpl<B>::Final(s);
403   }
404 
NumArcs(StateId s)405   size_t NumArcs(StateId s) {
406     if (!HasArcs(s)) {
407       Expand(s);
408     }
409     return CacheImpl<B>::NumArcs(s);
410   }
411 
NumInputEpsilons(StateId s)412   size_t NumInputEpsilons(StateId s) {
413     if (!HasArcs(s))
414       Expand(s);
415     return CacheImpl<B>::NumInputEpsilons(s);
416   }
417 
NumOutputEpsilons(StateId s)418   size_t NumOutputEpsilons(StateId s) {
419     if (!HasArcs(s))
420       Expand(s);
421     return CacheImpl<B>::NumOutputEpsilons(s);
422   }
423 
Properties()424   uint64 Properties() const { return Properties(kFstProperties); }
425 
426   // Set error if found; return FST impl properties.
Properties(uint64 mask)427   uint64 Properties(uint64 mask) const {
428     if ((mask & kError) &&
429         (fst_->Properties(kError, false) || arc_sampler_->Error())) {
430       SetProperties(kError, kError);
431     }
432     return FstImpl<Arc>::Properties(mask);
433   }
434 
InitArcIterator(StateId s,ArcIteratorData<B> * data)435   void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
436     if (!HasArcs(s))
437       Expand(s);
438     CacheImpl<B>::InitArcIterator(s, data);
439   }
440 
441   // Computes the outgoing transitions from a state, creating new destination
442   // states as needed.
Expand(StateId s)443   void Expand(StateId s) {
444     if (s == superfinal_) {
445       SetFinal(s, Weight::One());
446       SetArcs(s);
447       return;
448     }
449 
450     SetFinal(s, Weight::Zero());
451     const RandState<A> &rstate = *state_table_[s];
452     arc_sampler_->Sample(rstate);
453     ArcIterator< Fst<A> > aiter(*fst_, rstate.state_id);
454     size_t narcs = fst_->NumArcs(rstate.state_id);
455     for (;!arc_sampler_->Done(); arc_sampler_->Next()) {
456       const pair<size_t, size_t> &sample_pair = arc_sampler_->Value();
457       size_t pos = sample_pair.first;
458       size_t count = sample_pair.second;
459       double prob = static_cast<double>(count)/rstate.nsamples;
460       if (pos < narcs) {  // regular transition
461         aiter.Seek(sample_pair.first);
462         const A &aarc = aiter.Value();
463         Weight weight = weighted_ ? to_weight_(-log(prob)) : Weight::One();
464         B barc(aarc.ilabel, aarc.olabel, weight, state_table_.size());
465         AddArc(s, barc);
466         RandState<A> *nrstate =
467             new RandState<A>(aarc.nextstate, count, rstate.length + 1,
468                              pos, &rstate);
469         state_table_.push_back(nrstate);
470       } else {            // super-final transition
471         if (weighted_) {
472           Weight weight = remove_total_weight_ ?
473               to_weight_(-log(prob)) : to_weight_(-log(prob * npath_));
474           SetFinal(s, weight);
475         } else {
476           if (superfinal_ == kNoLabel) {
477             superfinal_ = state_table_.size();
478             RandState<A> *nrstate = new RandState<A>(kNoStateId, 0, 0, 0, 0);
479             state_table_.push_back(nrstate);
480           }
481           for (size_t n = 0; n < count; ++n) {
482             B barc(0, 0, Weight::One(), superfinal_);
483             AddArc(s, barc);
484           }
485         }
486       }
487     }
488     SetArcs(s);
489   }
490 
491  private:
492   Fst<A> *fst_;
493   S *arc_sampler_;
494   size_t npath_;
495   vector<RandState<A> *> state_table_;
496   bool weighted_;
497   bool remove_total_weight_;
498   StateId superfinal_;
499   WeightConvert<Log64Weight, Weight> to_weight_;
500 
501   void operator=(const RandGenFstImpl<A, B, S> &);  // disallow
502 };
503 
504 
505 // Fst class to randomly generate paths through an FST; details controlled
506 // by RandGenOptionsFst. Output format is a tree weighted by the
507 // path count.
508 template <class A, class B, class S>
509 class RandGenFst : public ImplToFst< RandGenFstImpl<A, B, S> > {
510  public:
511   friend class ArcIterator< RandGenFst<A, B, S> >;
512   friend class StateIterator< RandGenFst<A, B, S> >;
513   typedef B Arc;
514   typedef S Sampler;
515   typedef typename A::Label Label;
516   typedef typename A::Weight Weight;
517   typedef typename A::StateId StateId;
518   typedef CacheState<B> State;
519   typedef RandGenFstImpl<A, B, S> Impl;
520 
RandGenFst(const Fst<A> & fst,const RandGenFstOptions<S> & opts)521   RandGenFst(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
522     : ImplToFst<Impl>(new Impl(fst, opts)) {}
523 
524   // See Fst<>::Copy() for doc.
525  RandGenFst(const RandGenFst<A, B, S> &fst, bool safe = false)
526     : ImplToFst<Impl>(fst, safe) {}
527 
528   // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
529   virtual RandGenFst<A, B, S> *Copy(bool safe = false) const {
530     return new RandGenFst<A, B, S>(*this, safe);
531   }
532 
533   virtual inline void InitStateIterator(StateIteratorData<B> *data) const;
534 
InitArcIterator(StateId s,ArcIteratorData<B> * data)535   virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const {
536     GetImpl()->InitArcIterator(s, data);
537   }
538 
539  private:
540   // Makes visible to friends.
GetImpl()541   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
542 
543   void operator=(const RandGenFst<A, B, S> &fst);  // Disallow
544 };
545 
546 
547 
548 // Specialization for RandGenFst.
549 template <class A, class B, class S>
550 class StateIterator< RandGenFst<A, B, S> >
551     : public CacheStateIterator< RandGenFst<A, B, S> > {
552  public:
StateIterator(const RandGenFst<A,B,S> & fst)553   explicit StateIterator(const RandGenFst<A, B, S> &fst)
554     : CacheStateIterator< RandGenFst<A, B, S> >(fst, fst.GetImpl()) {}
555 
556  private:
557   DISALLOW_COPY_AND_ASSIGN(StateIterator);
558 };
559 
560 
561 // Specialization for RandGenFst.
562 template <class A, class B, class S>
563 class ArcIterator< RandGenFst<A, B, S> >
564     : public CacheArcIterator< RandGenFst<A, B, S> > {
565  public:
566   typedef typename A::StateId StateId;
567 
ArcIterator(const RandGenFst<A,B,S> & fst,StateId s)568   ArcIterator(const RandGenFst<A, B, S> &fst, StateId s)
569       : CacheArcIterator< RandGenFst<A, B, S> >(fst.GetImpl(), s) {
570     if (!fst.GetImpl()->HasArcs(s))
571       fst.GetImpl()->Expand(s);
572   }
573 
574  private:
575   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
576 };
577 
578 
579 template <class A, class B, class S> inline
InitStateIterator(StateIteratorData<B> * data)580 void RandGenFst<A, B, S>::InitStateIterator(StateIteratorData<B> *data) const
581 {
582   data->base = new StateIterator< RandGenFst<A, B, S> >(*this);
583 }
584 
// Options for random path generation. The arc selector is held by
// reference, so it must outlive the options object.
template <class S>
struct RandGenOptions {
  // How an arc is selected at a state.
  const S &arc_selector;
  // Maximum path length.
  int max_length;
  // Number of paths to generate.
  size_t npath;
  // If true the output is a tree weighted by path count; otherwise it is an
  // unweighted union of paths.
  bool weighted;
  // Remove total weight when output is weighted.
  bool remove_total_weight;

  // By default: no length limit below INT_MAX, a single path, unweighted
  // output, and the total weight is kept.
  RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1,
                 bool w = false, bool rw = false)
      : arc_selector(sel),
        max_length(len),
        npath(n),
        weighted(w),
        remove_total_weight(rw) {}
};
603 
604 
605 template <class IArc, class OArc>
606 class RandGenVisitor {
607  public:
608   typedef typename IArc::Weight Weight;
609   typedef typename IArc::StateId StateId;
610 
RandGenVisitor(MutableFst<OArc> * ofst)611   RandGenVisitor(MutableFst<OArc> *ofst) : ofst_(ofst) {}
612 
InitVisit(const Fst<IArc> & ifst)613   void InitVisit(const Fst<IArc> &ifst) {
614     ifst_ = &ifst;
615 
616     ofst_->DeleteStates();
617     ofst_->SetInputSymbols(ifst.InputSymbols());
618     ofst_->SetOutputSymbols(ifst.OutputSymbols());
619     if (ifst.Properties(kError, false))
620       ofst_->SetProperties(kError, kError);
621     path_.clear();
622   }
623 
InitState(StateId s,StateId root)624   bool InitState(StateId s, StateId root) { return true; }
625 
TreeArc(StateId s,const IArc & arc)626   bool TreeArc(StateId s, const IArc &arc) {
627     if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
628       path_.push_back(arc);
629     } else {
630       OutputPath();
631     }
632     return true;
633   }
634 
BackArc(StateId s,const IArc & arc)635   bool BackArc(StateId s, const IArc &arc) {
636     FSTERROR() << "RandGenVisitor: cyclic input";
637     ofst_->SetProperties(kError, kError);
638     return false;
639   }
640 
ForwardOrCrossArc(StateId s,const IArc & arc)641   bool ForwardOrCrossArc(StateId s, const IArc &arc) {
642     OutputPath();
643     return true;
644   }
645 
FinishState(StateId s,StateId p,const IArc *)646   void FinishState(StateId s, StateId p, const IArc *) {
647     if (p != kNoStateId && ifst_->Final(s) == Weight::Zero())
648       path_.pop_back();
649   }
650 
FinishVisit()651   void FinishVisit() {}
652 
653  private:
OutputPath()654   void OutputPath() {
655     if (ofst_->Start() == kNoStateId) {
656       StateId start = ofst_->AddState();
657       ofst_->SetStart(start);
658     }
659 
660     StateId src = ofst_->Start();
661     for (size_t i = 0; i < path_.size(); ++i) {
662       StateId dest = ofst_->AddState();
663       OArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
664       ofst_->AddArc(src, arc);
665       src = dest;
666     }
667     ofst_->SetFinal(src, Weight::One());
668   }
669 
670   const Fst<IArc> *ifst_;
671   MutableFst<OArc> *ofst_;
672   vector<OArc> path_;
673 
674   DISALLOW_COPY_AND_ASSIGN(RandGenVisitor);
675 };
676 
677 
678 // Randomly generate paths through an FST; details controlled by
679 // RandGenOptions.
680 template<class IArc, class OArc, class Selector>
RandGen(const Fst<IArc> & ifst,MutableFst<OArc> * ofst,const RandGenOptions<Selector> & opts)681 void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst,
682              const RandGenOptions<Selector> &opts) {
683   typedef ArcSampler<IArc, Selector> Sampler;
684   typedef RandGenFst<IArc, OArc, Sampler> RandFst;
685   typedef typename OArc::StateId StateId;
686   typedef typename OArc::Weight Weight;
687 
688   Sampler* arc_sampler = new Sampler(ifst, opts.arc_selector, opts.max_length);
689   RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), arc_sampler,
690                                    opts.npath, opts.weighted,
691                                    opts.remove_total_weight);
692   RandFst rfst(ifst, fopts);
693   if (opts.weighted) {
694     *ofst = rfst;
695   } else {
696     RandGenVisitor<IArc, OArc> rand_visitor(ofst);
697     DfsVisit(rfst, &rand_visitor);
698   }
699 }
700 
701 // Randomly generate a path through an FST with the uniform distribution
702 // over the transitions.
703 template<class IArc, class OArc>
RandGen(const Fst<IArc> & ifst,MutableFst<OArc> * ofst)704 void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst) {
705   UniformArcSelector<IArc> uniform_selector;
706   RandGenOptions< UniformArcSelector<IArc> > opts(uniform_selector);
707   RandGen(ifst, ofst, opts);
708 }
709 
710 }  // namespace fst
711 
712 #endif  // FST_LIB_RANDGEN_H__
713