1 // randgen.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Classes and functions to generate random paths through an FST.
20
21 #ifndef FST_LIB_RANDGEN_H__
22 #define FST_LIB_RANDGEN_H__
23
#include <climits>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <map>
28
29 #include <fst/accumulator.h>
30 #include <fst/cache.h>
31 #include <fst/dfs-visit.h>
32 #include <fst/mutable-fst.h>
33
34 namespace fst {
35
36 //
37 // ARC SELECTORS - these function objects are used to select a random
38 // transition to take from an FST's state. They should return a number
39 // N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th
40 // transition is selected. If N == NumArcs(), then the final weight at
41 // that state is selected (i.e., the 'super-final' transition is selected).
// Selectors may assume they are called only on states that have at least
// one outgoing transition, are final, or both.
44 //
45
46 // Randomly selects a transition using the uniform distribution.
47 template <class A>
48 struct UniformArcSelector {
49 typedef typename A::StateId StateId;
50 typedef typename A::Weight Weight;
51
52 UniformArcSelector(int seed = time(0)) { srand(seed); }
53
operatorUniformArcSelector54 size_t operator()(const Fst<A> &fst, StateId s) const {
55 double r = rand()/(RAND_MAX + 1.0);
56 size_t n = fst.NumArcs(s);
57 if (fst.Final(s) != Weight::Zero())
58 ++n;
59 return static_cast<size_t>(r * n);
60 }
61 };
62
63
64 // Randomly selects a transition w.r.t. the weights treated as negative
65 // log probabilities after normalizing for the total weight leaving
66 // the state. Weight::zero transitions are disregarded.
67 // Assumes Weight::Value() accesses the floating point
68 // representation of the weight.
69 template <class A>
70 class LogProbArcSelector {
71 public:
72 typedef typename A::StateId StateId;
73 typedef typename A::Weight Weight;
74
75 LogProbArcSelector(int seed = time(0)) { srand(seed); }
76
operator()77 size_t operator()(const Fst<A> &fst, StateId s) const {
78 // Find total weight leaving state
79 double sum = 0.0;
80 for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
81 aiter.Next()) {
82 const A &arc = aiter.Value();
83 sum += exp(-to_log_weight_(arc.weight).Value());
84 }
85 sum += exp(-to_log_weight_(fst.Final(s)).Value());
86
87 double r = rand()/(RAND_MAX + 1.0);
88 double p = 0.0;
89 int n = 0;
90 for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
91 aiter.Next(), ++n) {
92 const A &arc = aiter.Value();
93 p += exp(-to_log_weight_(arc.weight).Value());
94 if (p > r * sum) return n;
95 }
96 return n;
97 }
98
99 private:
100 WeightConvert<Weight, Log64Weight> to_log_weight_;
101 };
102
103 // Convenience definitions
104 typedef LogProbArcSelector<StdArc> StdArcSelector;
105 typedef LogProbArcSelector<LogArc> LogArcSelector;
106
107
108 // Same as LogProbArcSelector but use CacheLogAccumulator to cache
109 // the cummulative weight computations.
110 template <class A>
111 class FastLogProbArcSelector : public LogProbArcSelector<A> {
112 public:
113 typedef typename A::StateId StateId;
114 typedef typename A::Weight Weight;
115 using LogProbArcSelector<A>::operator();
116
117 FastLogProbArcSelector(int seed = time(0))
118 : LogProbArcSelector<A>(seed),
119 seed_(seed) {}
120
operator()121 size_t operator()(const Fst<A> &fst, StateId s,
122 CacheLogAccumulator<A> *accumulator) const {
123 accumulator->SetState(s);
124 ArcIterator< Fst<A> > aiter(fst, s);
125 // Find total weight leaving state
126 double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0,
127 fst.NumArcs(s))).Value();
128 double r = -log(rand()/(RAND_MAX + 1.0));
129 return accumulator->LowerBound(r + sum, &aiter);
130 }
131
Seed()132 int Seed() const { return seed_; }
133 private:
134 int seed_;
135 WeightConvert<Weight, Log64Weight> to_log_weight_;
136 };
137
138 // Random path state info maintained by RandGenFst and passed to samplers.
139 template <typename A>
140 struct RandState {
141 typedef typename A::StateId StateId;
142
143 StateId state_id; // current input FST state
144 size_t nsamples; // # of samples to be sampled at this state
145 size_t length; // length of path to this random state
146 size_t select; // previous sample arc selection
147 const RandState<A> *parent; // previous random state on this path
148
RandStateRandState149 RandState(StateId s, size_t n, size_t l, size_t k, const RandState<A> *p)
150 : state_id(s), nsamples(n), length(l), select(k), parent(p) {}
151
RandStateRandState152 RandState()
153 : state_id(kNoStateId), nsamples(0), length(0), select(0), parent(0) {}
154 };
155
156 // This class, given an arc selector, samples, with raplacement,
157 // multiple random transitions from an FST's state. This is a generic
158 // version with a straight-forward use of the arc selector.
159 // Specializations may be defined for arc selectors for greater
160 // efficiency or special behavior.
161 template <class A, class S>
162 class ArcSampler {
163 public:
164 typedef typename A::StateId StateId;
165 typedef typename A::Weight Weight;
166
167 // The 'max_length' may be interpreted (including ignored) by a
168 // sampler as it chooses. This generic version interprets this literally.
169 ArcSampler(const Fst<A> &fst, const S &arc_selector,
170 int max_length = INT_MAX)
fst_(fst)171 : fst_(fst),
172 arc_selector_(arc_selector),
173 max_length_(max_length) {}
174
175 // Allow updating Fst argument; pass only if changed.
176 ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
177 : fst_(fst ? *fst : sampler.fst_),
178 arc_selector_(sampler.arc_selector_),
179 max_length_(sampler.max_length_) {
180 Reset();
181 }
182
183 // Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is
184 // the length of the path to 'rstate'. Returns true if samples were
185 // collected. No samples may be collected if either there are no (including
186 // 'super-final') transitions leaving that state or if the
187 // 'max_length' has been deemed reached. Use the iterator members to
188 // read the samples. The samples will be in their original order.
Sample(const RandState<A> & rstate)189 bool Sample(const RandState<A> &rstate) {
190 sample_map_.clear();
191 if ((fst_.NumArcs(rstate.state_id) == 0 &&
192 fst_.Final(rstate.state_id) == Weight::Zero()) ||
193 rstate.length == max_length_) {
194 Reset();
195 return false;
196 }
197
198 for (size_t i = 0; i < rstate.nsamples; ++i)
199 ++sample_map_[arc_selector_(fst_, rstate.state_id)];
200 Reset();
201 return true;
202 }
203
204 // More samples?
Done()205 bool Done() const { return sample_iter_ == sample_map_.end(); }
206
207 // Gets the next sample.
Next()208 void Next() { ++sample_iter_; }
209
210 // Returns a pair (N, K) where 0 <= N <= NumArcs(s) and 0 < K <= nsamples.
211 // If N < NumArcs(s), then the N-th transition is specified.
212 // If N == NumArcs(s), then the final weight at that state is
213 // specified (i.e., the 'super-final' transition is specified).
214 // For the specified transition, K repetitions have been sampled.
Value()215 pair<size_t, size_t> Value() const { return *sample_iter_; }
216
Reset()217 void Reset() { sample_iter_ = sample_map_.begin(); }
218
Error()219 bool Error() const { return false; }
220
221 private:
222 const Fst<A> &fst_;
223 const S &arc_selector_;
224 int max_length_;
225
226 // Stores (N, K) as described for Value().
227 map<size_t, size_t> sample_map_;
228 map<size_t, size_t>::const_iterator sample_iter_;
229
230 // disallow
231 ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
232 };
233
234
235 // Specialization for FastLogProbArcSelector.
236 template <class A>
237 class ArcSampler<A, FastLogProbArcSelector<A> > {
238 public:
239 typedef FastLogProbArcSelector<A> S;
240 typedef typename A::StateId StateId;
241 typedef typename A::Weight Weight;
242 typedef CacheLogAccumulator<A> C;
243
244 ArcSampler(const Fst<A> &fst, const S &arc_selector, int max_length = INT_MAX)
fst_(fst)245 : fst_(fst),
246 arc_selector_(arc_selector),
247 max_length_(max_length),
248 accumulator_(new C()) {
249 accumulator_->Init(fst);
250 }
251
252 ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
253 : fst_(fst ? *fst : sampler.fst_),
254 arc_selector_(sampler.arc_selector_),
255 max_length_(sampler.max_length_) {
256 if (fst) {
257 accumulator_ = new C();
258 accumulator_->Init(*fst);
259 } else { // shallow copy
260 accumulator_ = new C(*sampler.accumulator_);
261 }
262 }
263
~ArcSampler()264 ~ArcSampler() {
265 delete accumulator_;
266 }
267
Sample(const RandState<A> & rstate)268 bool Sample(const RandState<A> &rstate) {
269 sample_map_.clear();
270 if ((fst_.NumArcs(rstate.state_id) == 0 &&
271 fst_.Final(rstate.state_id) == Weight::Zero()) ||
272 rstate.length == max_length_) {
273 Reset();
274 return false;
275 }
276
277 for (size_t i = 0; i < rstate.nsamples; ++i)
278 ++sample_map_[arc_selector_(fst_, rstate.state_id, accumulator_)];
279 Reset();
280 return true;
281 }
282
Done()283 bool Done() const { return sample_iter_ == sample_map_.end(); }
Next()284 void Next() { ++sample_iter_; }
Value()285 pair<size_t, size_t> Value() const { return *sample_iter_; }
Reset()286 void Reset() { sample_iter_ = sample_map_.begin(); }
287
Error()288 bool Error() const { return accumulator_->Error(); }
289
290 private:
291 const Fst<A> &fst_;
292 const S &arc_selector_;
293 int max_length_;
294
295 // Stores (N, K) as described for Value().
296 map<size_t, size_t> sample_map_;
297 map<size_t, size_t>::const_iterator sample_iter_;
298 C *accumulator_;
299
300 // disallow
301 ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
302 };
303
304
305 // Options for random path generation with RandGenFst. The template argument
306 // is an arc sampler, typically class 'ArcSampler' above. Ownership of
307 // the sampler is taken by RandGenFst.
308 template <class S>
309 struct RandGenFstOptions : public CacheOptions {
310 S *arc_sampler; // How to sample transitions at a state
311 size_t npath; // # of paths to generate
312 bool weighted; // Output tree weighted by path count; o.w.
313 // output unweighted DAG
314 bool remove_total_weight; // Remove total weight when output is weighted.
315
316 RandGenFstOptions(const CacheOptions &copts, S *samp,
317 size_t n = 1, bool w = true, bool rw = false)
CacheOptionsRandGenFstOptions318 : CacheOptions(copts),
319 arc_sampler(samp),
320 npath(n),
321 weighted(w),
322 remove_total_weight(rw) {}
323 };
324
325
326 // Implementation of RandGenFst.
327 template <class A, class B, class S>
328 class RandGenFstImpl : public CacheImpl<B> {
329 public:
330 using FstImpl<B>::SetType;
331 using FstImpl<B>::SetProperties;
332 using FstImpl<B>::SetInputSymbols;
333 using FstImpl<B>::SetOutputSymbols;
334
335 using CacheBaseImpl< CacheState<B> >::AddArc;
336 using CacheBaseImpl< CacheState<B> >::HasArcs;
337 using CacheBaseImpl< CacheState<B> >::HasFinal;
338 using CacheBaseImpl< CacheState<B> >::HasStart;
339 using CacheBaseImpl< CacheState<B> >::SetArcs;
340 using CacheBaseImpl< CacheState<B> >::SetFinal;
341 using CacheBaseImpl< CacheState<B> >::SetStart;
342
343 typedef B Arc;
344 typedef typename A::Label Label;
345 typedef typename A::Weight Weight;
346 typedef typename A::StateId StateId;
347
RandGenFstImpl(const Fst<A> & fst,const RandGenFstOptions<S> & opts)348 RandGenFstImpl(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
349 : CacheImpl<B>(opts),
350 fst_(fst.Copy()),
351 arc_sampler_(opts.arc_sampler),
352 npath_(opts.npath),
353 weighted_(opts.weighted),
354 remove_total_weight_(opts.remove_total_weight),
355 superfinal_(kNoLabel) {
356 SetType("randgen");
357
358 uint64 props = fst.Properties(kFstProperties, false);
359 SetProperties(RandGenProperties(props, weighted_), kCopyProperties);
360
361 SetInputSymbols(fst.InputSymbols());
362 SetOutputSymbols(fst.OutputSymbols());
363 }
364
RandGenFstImpl(const RandGenFstImpl & impl)365 RandGenFstImpl(const RandGenFstImpl &impl)
366 : CacheImpl<B>(impl),
367 fst_(impl.fst_->Copy(true)),
368 arc_sampler_(new S(*impl.arc_sampler_, fst_)),
369 npath_(impl.npath_),
370 weighted_(impl.weighted_),
371 superfinal_(kNoLabel) {
372 SetType("randgen");
373 SetProperties(impl.Properties(), kCopyProperties);
374 SetInputSymbols(impl.InputSymbols());
375 SetOutputSymbols(impl.OutputSymbols());
376 }
377
~RandGenFstImpl()378 ~RandGenFstImpl() {
379 for (int i = 0; i < state_table_.size(); ++i)
380 delete state_table_[i];
381 delete fst_;
382 delete arc_sampler_;
383 }
384
Start()385 StateId Start() {
386 if (!HasStart()) {
387 StateId s = fst_->Start();
388 if (s == kNoStateId)
389 return kNoStateId;
390 StateId start = state_table_.size();
391 SetStart(start);
392 RandState<A> *rstate = new RandState<A>(s, npath_, 0, 0, 0);
393 state_table_.push_back(rstate);
394 }
395 return CacheImpl<B>::Start();
396 }
397
Final(StateId s)398 Weight Final(StateId s) {
399 if (!HasFinal(s)) {
400 Expand(s);
401 }
402 return CacheImpl<B>::Final(s);
403 }
404
NumArcs(StateId s)405 size_t NumArcs(StateId s) {
406 if (!HasArcs(s)) {
407 Expand(s);
408 }
409 return CacheImpl<B>::NumArcs(s);
410 }
411
NumInputEpsilons(StateId s)412 size_t NumInputEpsilons(StateId s) {
413 if (!HasArcs(s))
414 Expand(s);
415 return CacheImpl<B>::NumInputEpsilons(s);
416 }
417
NumOutputEpsilons(StateId s)418 size_t NumOutputEpsilons(StateId s) {
419 if (!HasArcs(s))
420 Expand(s);
421 return CacheImpl<B>::NumOutputEpsilons(s);
422 }
423
Properties()424 uint64 Properties() const { return Properties(kFstProperties); }
425
426 // Set error if found; return FST impl properties.
Properties(uint64 mask)427 uint64 Properties(uint64 mask) const {
428 if ((mask & kError) &&
429 (fst_->Properties(kError, false) || arc_sampler_->Error())) {
430 SetProperties(kError, kError);
431 }
432 return FstImpl<Arc>::Properties(mask);
433 }
434
InitArcIterator(StateId s,ArcIteratorData<B> * data)435 void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
436 if (!HasArcs(s))
437 Expand(s);
438 CacheImpl<B>::InitArcIterator(s, data);
439 }
440
441 // Computes the outgoing transitions from a state, creating new destination
442 // states as needed.
Expand(StateId s)443 void Expand(StateId s) {
444 if (s == superfinal_) {
445 SetFinal(s, Weight::One());
446 SetArcs(s);
447 return;
448 }
449
450 SetFinal(s, Weight::Zero());
451 const RandState<A> &rstate = *state_table_[s];
452 arc_sampler_->Sample(rstate);
453 ArcIterator< Fst<A> > aiter(*fst_, rstate.state_id);
454 size_t narcs = fst_->NumArcs(rstate.state_id);
455 for (;!arc_sampler_->Done(); arc_sampler_->Next()) {
456 const pair<size_t, size_t> &sample_pair = arc_sampler_->Value();
457 size_t pos = sample_pair.first;
458 size_t count = sample_pair.second;
459 double prob = static_cast<double>(count)/rstate.nsamples;
460 if (pos < narcs) { // regular transition
461 aiter.Seek(sample_pair.first);
462 const A &aarc = aiter.Value();
463 Weight weight = weighted_ ? to_weight_(-log(prob)) : Weight::One();
464 B barc(aarc.ilabel, aarc.olabel, weight, state_table_.size());
465 AddArc(s, barc);
466 RandState<A> *nrstate =
467 new RandState<A>(aarc.nextstate, count, rstate.length + 1,
468 pos, &rstate);
469 state_table_.push_back(nrstate);
470 } else { // super-final transition
471 if (weighted_) {
472 Weight weight = remove_total_weight_ ?
473 to_weight_(-log(prob)) : to_weight_(-log(prob * npath_));
474 SetFinal(s, weight);
475 } else {
476 if (superfinal_ == kNoLabel) {
477 superfinal_ = state_table_.size();
478 RandState<A> *nrstate = new RandState<A>(kNoStateId, 0, 0, 0, 0);
479 state_table_.push_back(nrstate);
480 }
481 for (size_t n = 0; n < count; ++n) {
482 B barc(0, 0, Weight::One(), superfinal_);
483 AddArc(s, barc);
484 }
485 }
486 }
487 }
488 SetArcs(s);
489 }
490
491 private:
492 Fst<A> *fst_;
493 S *arc_sampler_;
494 size_t npath_;
495 vector<RandState<A> *> state_table_;
496 bool weighted_;
497 bool remove_total_weight_;
498 StateId superfinal_;
499 WeightConvert<Log64Weight, Weight> to_weight_;
500
501 void operator=(const RandGenFstImpl<A, B, S> &); // disallow
502 };
503
504
505 // Fst class to randomly generate paths through an FST; details controlled
506 // by RandGenOptionsFst. Output format is a tree weighted by the
507 // path count.
508 template <class A, class B, class S>
509 class RandGenFst : public ImplToFst< RandGenFstImpl<A, B, S> > {
510 public:
511 friend class ArcIterator< RandGenFst<A, B, S> >;
512 friend class StateIterator< RandGenFst<A, B, S> >;
513 typedef B Arc;
514 typedef S Sampler;
515 typedef typename A::Label Label;
516 typedef typename A::Weight Weight;
517 typedef typename A::StateId StateId;
518 typedef CacheState<B> State;
519 typedef RandGenFstImpl<A, B, S> Impl;
520
RandGenFst(const Fst<A> & fst,const RandGenFstOptions<S> & opts)521 RandGenFst(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
522 : ImplToFst<Impl>(new Impl(fst, opts)) {}
523
524 // See Fst<>::Copy() for doc.
525 RandGenFst(const RandGenFst<A, B, S> &fst, bool safe = false)
526 : ImplToFst<Impl>(fst, safe) {}
527
528 // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
529 virtual RandGenFst<A, B, S> *Copy(bool safe = false) const {
530 return new RandGenFst<A, B, S>(*this, safe);
531 }
532
533 virtual inline void InitStateIterator(StateIteratorData<B> *data) const;
534
InitArcIterator(StateId s,ArcIteratorData<B> * data)535 virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const {
536 GetImpl()->InitArcIterator(s, data);
537 }
538
539 private:
540 // Makes visible to friends.
GetImpl()541 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
542
543 void operator=(const RandGenFst<A, B, S> &fst); // Disallow
544 };
545
546
547
548 // Specialization for RandGenFst.
549 template <class A, class B, class S>
550 class StateIterator< RandGenFst<A, B, S> >
551 : public CacheStateIterator< RandGenFst<A, B, S> > {
552 public:
StateIterator(const RandGenFst<A,B,S> & fst)553 explicit StateIterator(const RandGenFst<A, B, S> &fst)
554 : CacheStateIterator< RandGenFst<A, B, S> >(fst, fst.GetImpl()) {}
555
556 private:
557 DISALLOW_COPY_AND_ASSIGN(StateIterator);
558 };
559
560
561 // Specialization for RandGenFst.
562 template <class A, class B, class S>
563 class ArcIterator< RandGenFst<A, B, S> >
564 : public CacheArcIterator< RandGenFst<A, B, S> > {
565 public:
566 typedef typename A::StateId StateId;
567
ArcIterator(const RandGenFst<A,B,S> & fst,StateId s)568 ArcIterator(const RandGenFst<A, B, S> &fst, StateId s)
569 : CacheArcIterator< RandGenFst<A, B, S> >(fst.GetImpl(), s) {
570 if (!fst.GetImpl()->HasArcs(s))
571 fst.GetImpl()->Expand(s);
572 }
573
574 private:
575 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
576 };
577
578
579 template <class A, class B, class S> inline
InitStateIterator(StateIteratorData<B> * data)580 void RandGenFst<A, B, S>::InitStateIterator(StateIteratorData<B> *data) const
581 {
582 data->base = new StateIterator< RandGenFst<A, B, S> >(*this);
583 }
584
// Options for RandGen(). 'arc_selector' is stored by reference, so the
// selector must outlive the options object.
template <class S>
struct RandGenOptions {
  const S &arc_selector;     // chooses a transition at each state
  int max_length;            // maximum path length
  size_t npath;              // number of paths to generate
  bool weighted;             // true: output a tree weighted by path count;
                             // false: output an unweighted union of paths
  bool remove_total_weight;  // when weighted, drop the total path weight

  RandGenOptions(const S &selector, int max_len = INT_MAX,
                 size_t num_paths = 1, bool weight_output = false,
                 bool drop_total_weight = false)
      : arc_selector(selector),
        max_length(max_len),
        npath(num_paths),
        weighted(weight_output),
        remove_total_weight(drop_total_weight) {}
};
603
604
605 template <class IArc, class OArc>
606 class RandGenVisitor {
607 public:
608 typedef typename IArc::Weight Weight;
609 typedef typename IArc::StateId StateId;
610
RandGenVisitor(MutableFst<OArc> * ofst)611 RandGenVisitor(MutableFst<OArc> *ofst) : ofst_(ofst) {}
612
InitVisit(const Fst<IArc> & ifst)613 void InitVisit(const Fst<IArc> &ifst) {
614 ifst_ = &ifst;
615
616 ofst_->DeleteStates();
617 ofst_->SetInputSymbols(ifst.InputSymbols());
618 ofst_->SetOutputSymbols(ifst.OutputSymbols());
619 if (ifst.Properties(kError, false))
620 ofst_->SetProperties(kError, kError);
621 path_.clear();
622 }
623
InitState(StateId s,StateId root)624 bool InitState(StateId s, StateId root) { return true; }
625
TreeArc(StateId s,const IArc & arc)626 bool TreeArc(StateId s, const IArc &arc) {
627 if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
628 path_.push_back(arc);
629 } else {
630 OutputPath();
631 }
632 return true;
633 }
634
BackArc(StateId s,const IArc & arc)635 bool BackArc(StateId s, const IArc &arc) {
636 FSTERROR() << "RandGenVisitor: cyclic input";
637 ofst_->SetProperties(kError, kError);
638 return false;
639 }
640
ForwardOrCrossArc(StateId s,const IArc & arc)641 bool ForwardOrCrossArc(StateId s, const IArc &arc) {
642 OutputPath();
643 return true;
644 }
645
FinishState(StateId s,StateId p,const IArc *)646 void FinishState(StateId s, StateId p, const IArc *) {
647 if (p != kNoStateId && ifst_->Final(s) == Weight::Zero())
648 path_.pop_back();
649 }
650
FinishVisit()651 void FinishVisit() {}
652
653 private:
OutputPath()654 void OutputPath() {
655 if (ofst_->Start() == kNoStateId) {
656 StateId start = ofst_->AddState();
657 ofst_->SetStart(start);
658 }
659
660 StateId src = ofst_->Start();
661 for (size_t i = 0; i < path_.size(); ++i) {
662 StateId dest = ofst_->AddState();
663 OArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
664 ofst_->AddArc(src, arc);
665 src = dest;
666 }
667 ofst_->SetFinal(src, Weight::One());
668 }
669
670 const Fst<IArc> *ifst_;
671 MutableFst<OArc> *ofst_;
672 vector<OArc> path_;
673
674 DISALLOW_COPY_AND_ASSIGN(RandGenVisitor);
675 };
676
677
678 // Randomly generate paths through an FST; details controlled by
679 // RandGenOptions.
680 template<class IArc, class OArc, class Selector>
RandGen(const Fst<IArc> & ifst,MutableFst<OArc> * ofst,const RandGenOptions<Selector> & opts)681 void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst,
682 const RandGenOptions<Selector> &opts) {
683 typedef ArcSampler<IArc, Selector> Sampler;
684 typedef RandGenFst<IArc, OArc, Sampler> RandFst;
685 typedef typename OArc::StateId StateId;
686 typedef typename OArc::Weight Weight;
687
688 Sampler* arc_sampler = new Sampler(ifst, opts.arc_selector, opts.max_length);
689 RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), arc_sampler,
690 opts.npath, opts.weighted,
691 opts.remove_total_weight);
692 RandFst rfst(ifst, fopts);
693 if (opts.weighted) {
694 *ofst = rfst;
695 } else {
696 RandGenVisitor<IArc, OArc> rand_visitor(ofst);
697 DfsVisit(rfst, &rand_visitor);
698 }
699 }
700
701 // Randomly generate a path through an FST with the uniform distribution
702 // over the transitions.
703 template<class IArc, class OArc>
RandGen(const Fst<IArc> & ifst,MutableFst<OArc> * ofst)704 void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst) {
705 UniformArcSelector<IArc> uniform_selector;
706 RandGenOptions< UniformArcSelector<IArc> > opts(uniform_selector);
707 RandGen(ifst, ofst, opts);
708 }
709
710 } // namespace fst
711
712 #endif // FST_LIB_RANDGEN_H__
713